diff --git a/CMakeLists.txt b/CMakeLists.txt index 86fc2f0..e261d51 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,7 +13,7 @@ project( SamplinSafari DESCRIPTION "A research tool to visualize and interactively inspect high-dimensional (quasi) Monte Carlo samplers." VERSION ${VERSION} - LANGUAGES C CXX + LANGUAGES C CXX OBJC ) message(STATUS "C++ compiler is: ${CMAKE_CXX_COMPILER_ID}") @@ -178,8 +178,9 @@ if(portable-file-dialogs_ADDED) target_include_directories(portable-file-dialogs INTERFACE "${portable-file-dialogs_SOURCE_DIR}") endif() -set(HELLOIMGUI_WITH_GLFW ON) -CPMAddPackage("gh:wkjarosz/hello_imgui#745ced9d52be601097e4b7d0837ad67a0b93db32") +# set(HELLOIMGUI_USE_GLFW_METAL ON) +set(HELLOIMGUI_USE_GLFW_OPENGL3 ON) +CPMAddPackage("gh:pthom/hello_imgui#0c8e5e848738dd7160485d96da3aa9e9033c62ad") if(hello_imgui_ADDED) message(STATUS "hello_imgui library added") endif() @@ -334,9 +335,13 @@ hello_imgui_add_app( SamplinSafari src/app.cpp ${CMAKE_CURRENT_BINARY_DIR}/src/common.cpp + src/opengl_check.cpp src/shader.cpp src/shader_gl.cpp + src/shader_metal.mm src/export_to_file.cpp + src/renderpass_gl.cpp + src/renderpass_metal.mm ASSETS_LOCATION ${CMAKE_CURRENT_BINARY_DIR}/assets ) diff --git a/include/app.h b/include/app.h index 49c831d..2b8e10e 100644 --- a/include/app.h +++ b/include/app.h @@ -37,6 +37,7 @@ using namespace linalg::aliases; #include "arcball.h" #include "hello_imgui/hello_imgui.h" #include "misc/cpp/imgui_stdlib.h" +#include "renderpass.h" #include "shader.h" #include #include @@ -111,7 +112,7 @@ class SampleViewer void draw_text(const int2 &pos, const std::string &text, const float4 &col, ImFont *font = nullptr, int align = TextAlign_RIGHT | TextAlign_BOTTOM) const; void draw_points(const float4x4 &mvp, const float4x4 &smash, const float3 &color); - void draw_grid(const float4x4 &mat, int2 size, float alpha) const; + void draw_grid(const float4x4 &mat, int2 size, float alpha); void draw_trigrid(Shader *shader, const float4x4 &mvp, float 
alpha, const int2x3 &count); void draw_2D_points_and_grid(const float4x4 &mvp, int2 dims, int plotIndex); int2 get_draw_range() const; @@ -139,9 +140,9 @@ class SampleViewer bool m_show_1d_projections = false, m_show_point_nums = false, m_show_point_coords = false, m_show_coarse_grid = false, m_show_fine_grid = false, m_show_custom_grid = false, m_show_bbox = false; - Shader *m_3d_point_shader = nullptr, *m_2d_point_shader = nullptr, *m_grid_shader = nullptr; + RenderPass m_render_pass; + Shader *m_3d_point_shader = nullptr, *m_2d_point_shader = nullptr, *m_grid_shader = nullptr; - int2 m_viewport_pos, m_viewport_pos_GL, m_viewport_size; float m_animate_start_time = 0.0f; bool m_subset_by_index = false; diff --git a/include/opengl_check.h b/include/opengl_check.h new file mode 100644 index 0000000..19c27c7 --- /dev/null +++ b/include/opengl_check.h @@ -0,0 +1,19 @@ +/** + \file opengl_check.h +*/ +#pragma once + +bool check_glerror(const char *cmd); + +#if defined(NDEBUG) +#define CHK(cmd) cmd +#else +#define CHK(cmd) \ + do \ + { \ + cmd; \ + while (check_glerror(#cmd)) \ + { \ + } \ + } while (false) +#endif \ No newline at end of file diff --git a/include/renderpass.h b/include/renderpass.h new file mode 100644 index 0000000..29efeff --- /dev/null +++ b/include/renderpass.h @@ -0,0 +1,154 @@ +/** + \file renderpass.h +*/ +#pragma once + +#include "linalg.h" +#include + +using namespace linalg::aliases; + +/// An abstraction for rendering passes that work with OpenGL, OpenGL ES, and Metal. +/* + This is a greatly simplified version of NanoGUI's RenderPass class. Original copyright follows. + ---------- + NanoGUI was developed by Wenzel Jakob . + The widget drawing code is based on the NanoVG demo application + by Mikko Mononen. + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE.txt file. 
+*/ +class RenderPass +{ +public: + /// Depth test + enum class DepthTest + { + Never, + Less, + Equal, + LessEqual, + Greater, + NotEqual, + GreaterEqual, + Always + }; + + /// Culling mode + enum class CullMode + { + Disabled, + Front, + Back + }; + + /** + * Create a new render pass for rendering to the main color and (optionally) depth buffer. + * + * \param write_depth + * Should we write to the depth buffer? + * + * \param clear + * Should \ref enter() begin by clearing all buffers? + */ + RenderPass(bool write_depth = true, bool clear = true); + + ~RenderPass(); + + /** + * Begin the render pass + * + * The specified drawing state (e.g. depth tests, culling mode, blending mode) are automatically set up at this + * point. Later changes between \ref begin() and \ref end() are possible but cause additional OpenGL/GLES/Metal API + * calls. + */ + void begin(); + + /// Finish the render pass + void end(); + + /// Return the clear color for a given color attachment + const float4 &clear_color() const + { + return m_clear_color; + } + + /// Set the clear color for a given color attachment + void set_clear_color(const float4 &color); + + /// Return the clear depth for the depth attachment + float clear_depth() const + { + return m_clear_depth; + } + + /// Set the clear depth for the depth attachment + void set_clear_depth(float depth); + + /// Specify the depth test and depth write mask of this render pass + void set_depth_test(DepthTest depth_test, bool depth_write); + + /// Return the depth test and depth write mask of this render pass + std::pair depth_test() const + { + return {m_depth_test, m_depth_write}; + } + + /// Set the pixel offset and size of the viewport region + void set_viewport(const int2 &offset, const int2 &size); + + /// Return the pixel offset and size of the viewport region + std::pair viewport() + { + return {m_viewport_offset, m_viewport_size}; + } + + /// Specify the culling mode associated with the render pass + void set_cull_mode(CullMode 
mode); + + /// Return the culling mode associated with the render pass + CullMode cull_mode() const + { + return m_cull_mode; + } + + /// Resize all texture targets attached to the render pass + void resize(const int2 &size); + +#if defined(HELLOIMGUI_HAS_METAL) + void *command_encoder() const + { + return m_command_encoder; + } + void *command_buffer() const + { + return m_command_buffer; + } +#endif + +protected: + bool m_clear; + float4 m_clear_color; + float m_clear_depth; + int2 m_viewport_offset; + int2 m_viewport_size; + int2 m_framebuffer_size; + DepthTest m_depth_test; + bool m_depth_write; + CullMode m_cull_mode; + bool m_active; +#if defined(HELLOIMGUI_HAS_OPENGL) + int4 m_viewport_backup, m_scissor_backup; + bool m_depth_test_backup; + bool m_depth_write_backup; + bool m_scissor_test_backup; + bool m_cull_face_backup; + bool m_blend_backup; +#elif defined(HELLOIMGUI_HAS_METAL) + void *m_command_buffer; + void *m_command_encoder; + void *m_pass_descriptor; +// ref m_clear_shader; +#endif +}; diff --git a/include/shader.h b/include/shader.h index 2edd54e..ee5e8e8 100644 --- a/include/shader.h +++ b/include/shader.h @@ -8,6 +8,8 @@ #include #include +class RenderPass; + /// An abstraction for shaders that work with OpenGL, OpenGL ES, and (at some point down the road, hopefully) Metal. /* This is adapted from NanoGUI's Shader class. Copyright follows. @@ -44,6 +46,9 @@ class Shader /** Initialize the shader using the source files (read from the assets directory). + \param render_pass + RenderPass object encoding targets to which color and depth information will be rendered. + \param name A name identifying this shader @@ -53,11 +58,14 @@ class Shader \param fs_filename Filename of the fragment shader source code. 
*/ - Shader(const std::string &name, const std::string &vs_filename, const std::string &fs_filename, - BlendMode blend_mode = BlendMode::None); + Shader(RenderPass *render_pass, const std::string &name, const std::string &vs_filename, + const std::string &fs_filename, BlendMode blend_mode = BlendMode::None); /// Return the render pass associated with this shader - // RenderPass *render_pass() { return m_render_pass; } + RenderPass *render_pass() + { + return m_render_pass; + } /// Return the name of this shader const std::string &name() const @@ -256,7 +264,7 @@ class Shader virtual ~Shader(); protected: - // RenderPass* m_render_pass; + RenderPass *m_render_pass; std::string m_name; std::unordered_map m_buffers; BlendMode m_blend_mode; @@ -271,14 +279,3 @@ class Shader void *m_pipeline_state; #endif }; - -bool check_glerror(const char *cmd); - -#define CHK(cmd) \ - do \ - { \ - cmd; \ - while (check_glerror(#cmd)) \ - { \ - } \ - } while (false) diff --git a/src/app.cpp b/src/app.cpp index 16d04b0..ece0a57 100644 --- a/src/app.cpp +++ b/src/app.cpp @@ -9,6 +9,8 @@ #include "imgui_ext.h" #include "imgui_internal.h" +#include "opengl_check.h" + #include #include #include @@ -374,22 +376,22 @@ SampleViewer::SampleViewer() try { auto quad_verts = - vector{{-0.5f, -0.5f, 0.f}, {-0.5f, 0.5f, 0.0f}, {0.5f, 0.5f, 0.0f}, {0.5f, -0.5f, 0.0f}}; - m_2d_point_shader = new Shader("2D point shader", "shaders/point.vert", "shaders/point.frag", - Shader::BlendMode::AlphaBlend); + vector{{-0.5f, -0.5f, 0.f}, {0.5f, -0.5f, 0.0f}, {0.5f, 0.5f, 0.0f}, {-0.5f, 0.5f, 0.0f}}; + m_2d_point_shader = new Shader(&m_render_pass, "2D point shader", "shaders/point.vert", + "shaders/point.frag", Shader::BlendMode::AlphaBlend); m_2d_point_shader->set_buffer("vertices", quad_verts); m_2d_point_shader->set_buffer_divisor("vertices", 0); - m_3d_point_shader = new Shader("3D point shader", "shaders/point.vert", "shaders/point.frag", - Shader::BlendMode::AlphaBlend); + m_3d_point_shader = new 
Shader(&m_render_pass, "3D point shader", "shaders/point.vert", + "shaders/point.frag", Shader::BlendMode::AlphaBlend); m_3d_point_shader->set_buffer("vertices", quad_verts); m_3d_point_shader->set_buffer_divisor("vertices", 0); - m_grid_shader = - new Shader("Grid shader", "shaders/grid.vert", "shaders/grid.frag", Shader::BlendMode::AlphaBlend); + m_grid_shader = new Shader(&m_render_pass, "Grid shader", "shaders/grid.vert", "shaders/grid.frag", + Shader::BlendMode::AlphaBlend); m_grid_shader->set_buffer( "position", - vector{{-0.5f, -0.5f, 0.5f}, {-0.5f, 1.5f, 0.5f}, {1.5f, 1.5f, 0.5f}, {1.5f, -0.5f, 0.5f}}); + vector{{-0.5f, -0.5f, 0.5f}, {1.5f, -0.5f, 0.5f}, {1.5f, 1.5f, 0.5f}, {-0.5f, 1.5f, 0.5f}}); HelloImGui::Log(HelloImGui::LogLevel::Info, "Successfully initialized GL!"); } @@ -1249,44 +1251,24 @@ void SampleViewer::draw_background() update_points(m_cpu_points_dirty); // - // clear the scene and set up viewports - // - // calculate the viewport sizes - m_viewport_pos_GL = m_viewport_pos = {0, 0}; - m_viewport_size = io.DisplaySize; + // fbsize is the size of the window in pixels while accounting for dpi factor on retina screens. + // for retina displays, io.DisplaySize is the size of the window in points (logical pixels) + // but we need the size in pixels. 
So we scale io.DisplaySize by io.DisplayFramebufferScale + int2 fbscale = io.DisplayFramebufferScale; + auto fbsize = int2{io.DisplaySize} * fbscale; + int2 viewport_offset = {0, 0}; + int2 viewport_size = io.DisplaySize; if (auto id = m_params.dockingParams.dockSpaceIdFromName("MainDockSpace")) { auto central_node = ImGui::DockBuilderGetCentralNode(*id); - m_viewport_size = int2{int(central_node->Size.x), int(central_node->Size.y)}; - m_viewport_pos = int2{int(central_node->Pos.x), int(central_node->Pos.y)}; - // flip y coordinates between ImGui and OpenGL screen coordinates - m_viewport_pos_GL = - int2{int(central_node->Pos.x), int(io.DisplaySize.y - (central_node->Pos.y + central_node->Size.y))}; + viewport_size = int2{int(central_node->Size.x), int(central_node->Size.y)}; + viewport_offset = int2{int(central_node->Pos.x), int(central_node->Pos.y)}; } - // first clear the entire window with the background color - // display_size is the size of the window in pixels while accounting for dpi factor on retina screens. - // for retina displays, io.DisplaySize is the size of the window in points (logical pixels) - // but we need the size in pixels. 
So we scale io.DisplaySize by io.DisplayFramebufferScale - auto display_size = int2{io.DisplaySize} * int2{io.DisplayFramebufferScale}; - CHK(glViewport(0, 0, display_size.x, display_size.y)); - CHK(glClearColor(m_bg_color[0], m_bg_color[1], m_bg_color[2], 1.f)); - CHK(glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)); - - // now set up a new viewport for the rest of the drawing - CHK(glViewport( - m_viewport_pos_GL.x * io.DisplayFramebufferScale.x, m_viewport_pos_GL.y * io.DisplayFramebufferScale.y, - m_viewport_size.x * io.DisplayFramebufferScale.x, m_viewport_size.y * io.DisplayFramebufferScale.y)); - // inform the arcballs of the viewport size for (int i = 0; i < NUM_CAMERA_TYPES; ++i) - m_camera[i].arcball.set_size(m_viewport_size); - - // enable depth testing - CHK(glEnable(GL_DEPTH_TEST)); - CHK(glDepthFunc(GL_LESS)); - CHK(glDepthMask(GL_TRUE)); + m_camera[i].arcball.set_size(viewport_size); // // process camera movement @@ -1299,7 +1281,7 @@ void SampleViewer::draw_background() // on mouse down we start switching to a perspective camera // and start recording the arcball rotation in CAMERA_NEXT set_view(CAMERA_CURRENT); - m_camera[CAMERA_NEXT].arcball.button(int2{io.MousePos} - m_viewport_pos, io.MouseDown[0]); + m_camera[CAMERA_NEXT].arcball.button(int2{io.MousePos} - viewport_offset, io.MouseDown[0]); m_camera[CAMERA_NEXT].camera_type = CAMERA_CURRENT; } @@ -1307,7 +1289,7 @@ void SampleViewer::draw_background() // background (e.g. 
when dismissing a popup with a mouse click) if (io.MouseReleased[0] && !io.MouseDownOwned[0]) { - m_camera[CAMERA_NEXT].arcball.button(int2{io.MousePos} - m_viewport_pos, io.MouseDown[0]); + m_camera[CAMERA_NEXT].arcball.button(int2{io.MousePos} - viewport_offset, io.MouseDown[0]); // since the time between mouse down and up could be shorter // than the animation duration, we override the previous // camera's arcball on mouse up to complete the animation @@ -1316,7 +1298,7 @@ void SampleViewer::draw_background() } if (io.MouseDown[0]) - m_camera[CAMERA_NEXT].arcball.motion(int2{io.MousePos} - m_viewport_pos); + m_camera[CAMERA_NEXT].arcball.motion(int2{io.MousePos} - viewport_offset); } // @@ -1353,7 +1335,19 @@ void SampleViewer::draw_background() camera.arcball.set_state(qslerp(camera0.arcball.state(), camera1.arcball.state(), t)); } - float4x4 mvp = m_camera[CAMERA_CURRENT].matrix(float(m_viewport_size.x) / m_viewport_size.y); + // + // clear the framebuffer and set up the viewport + // + + m_render_pass.resize(fbsize); + m_render_pass.set_viewport(viewport_offset * fbscale, viewport_size * fbscale); + m_render_pass.set_clear_color(float4{m_bg_color, 1.f}); + m_render_pass.set_cull_mode(RenderPass::CullMode::Disabled); + m_render_pass.set_depth_test(RenderPass::DepthTest::Less, true); + + m_render_pass.begin(); + + float4x4 mvp = m_camera[CAMERA_CURRENT].matrix(float(viewport_size.x) / viewport_size.y); // // Now render the points and grids @@ -1371,16 +1365,16 @@ void SampleViewer::draw_background() float4x4 pos = layout_2d_matrix(m_num_dimensions, int2{i, m_num_dimensions - 1}); float4 text_pos = mul(mvp, mul(pos, float4{0.f, -0.5f, -1.0f, 1.0f})); float2 text_2d_pos((text_pos.x / text_pos.w + 1) / 2, (text_pos.y / text_pos.w + 1) / 2); - draw_text(m_viewport_pos + int2(int((text_2d_pos.x) * m_viewport_size.x), - int((1.f - text_2d_pos.y) * m_viewport_size.y) + 16), + draw_text(viewport_offset + int2(int((text_2d_pos.x) * viewport_size.x), + int((1.f - 
text_2d_pos.y) * viewport_size.y) + 16), to_string(i), float4(1.0f, 1.0f, 1.0f, 0.75f), m_regular[16], TextAlign_CENTER | TextAlign_BOTTOM); pos = layout_2d_matrix(m_num_dimensions, int2{0, i + 1}); text_pos = mul(mvp, mul(pos, float4{-0.5f, 0.f, -1.0f, 1.0f})); text_2d_pos = float2((text_pos.x / text_pos.w + 1) / 2, (text_pos.y / text_pos.w + 1) / 2); - draw_text(m_viewport_pos + int2(int((text_2d_pos.x) * m_viewport_size.x) - 4, - int((1.f - text_2d_pos.y) * m_viewport_size.y)), + draw_text(viewport_offset + int2(int((text_2d_pos.x) * viewport_size.x) - 4, + int((1.f - text_2d_pos.y) * viewport_size.y)), to_string(i + 1), float4(1.0f, 1.0f, 1.0f, 0.75f), m_regular[16], TextAlign_RIGHT | TextAlign_MIDDLE); } @@ -1435,8 +1429,8 @@ void SampleViewer::draw_background() { float4 text_pos = mul(mvp, float4{m_3d_points[p] - 0.5f, 1.f}); float2 text_2d_pos((text_pos.x / text_pos.w + 1) / 2, (text_pos.y / text_pos.w + 1) / 2); - int2 draw_pos = m_viewport_pos + int2{int((text_2d_pos.x) * m_viewport_size.x), - int((1.f - text_2d_pos.y) * m_viewport_size.y)}; + int2 draw_pos = viewport_offset + int2{int((text_2d_pos.x) * viewport_size.x), + int((1.f - text_2d_pos.y) * viewport_size.y)}; if (m_show_point_nums) draw_text(draw_pos - int2{0, int(radius / 4)}, fmt::format("{:d}", p), float4(1.0f, 1.0f, 1.0f, 0.75f), m_regular[12], TextAlign_CENTER | TextAlign_BOTTOM); @@ -1447,6 +1441,8 @@ void SampleViewer::draw_background() float4(1.0f, 1.0f, 1.0f, 0.75f), m_regular[11], TextAlign_CENTER | TextAlign_TOP); } } + + m_render_pass.end(); } catch (const std::exception &e) { @@ -1490,16 +1486,20 @@ void SampleViewer::draw_points(const float4x4 &mvp, const float4x4 &smash, const m_3d_point_shader->end(); } -void SampleViewer::draw_grid(const float4x4 &mat, int2 size, float alpha) const +void SampleViewer::draw_grid(const float4x4 &mat, int2 size, float alpha) { m_grid_shader->set_uniform("mvp", mat); m_grid_shader->set_uniform("size", size); m_grid_shader->set_uniform("alpha", 
alpha); + + auto backup = m_render_pass.depth_test(); + m_render_pass.set_depth_test(RenderPass::DepthTest::Always, false); + m_grid_shader->begin(); - CHK(glDepthMask(GL_FALSE)); m_grid_shader->draw_array(Shader::PrimitiveType::TriangleFan, 0, 4); - CHK(glDepthMask(GL_TRUE)); m_grid_shader->end(); + + m_render_pass.set_depth_test(backup.first, backup.second); } /*! diff --git a/src/opengl_check.cpp b/src/opengl_check.cpp new file mode 100644 index 0000000..c3383fc --- /dev/null +++ b/src/opengl_check.cpp @@ -0,0 +1,32 @@ +#if defined(HELLOIMGUI_HAS_OPENGL) + +#include "hello_imgui/hello_imgui.h" +#include "hello_imgui/hello_imgui_include_opengl.h" // cross-platform way to include OpenGL headers +#include + +bool check_glerror(const char *cmd) +{ + GLenum err = glGetError(); + const char *msg = nullptr; + + switch (err) + { + case GL_NO_ERROR: return false; + case GL_INVALID_ENUM: msg = "invalid enumeration"; break; + case GL_INVALID_VALUE: msg = "invalid value"; break; + case GL_INVALID_OPERATION: msg = "invalid operation"; break; + case GL_INVALID_FRAMEBUFFER_OPERATION: msg = "invalid framebuffer operation"; break; + case GL_OUT_OF_MEMORY: msg = "out of memory"; break; +#ifndef __EMSCRIPTEN__ + case GL_STACK_UNDERFLOW: msg = "stack underflow"; break; + case GL_STACK_OVERFLOW: msg = "stack overflow"; break; +#endif + default: msg = "unknown error"; break; + } + + fmt::print(stderr, "OpenGL error ({}) during operation \"{}\"!\n", msg, cmd); + HelloImGui::Log(HelloImGui::LogLevel::Error, "OpenGL error (%s) during operation \"%s\"!\n", msg, cmd); + return true; +} + +#endif // defined(HELLOIMGUI_HAS_OPENGL) diff --git a/src/renderpass_gl.cpp b/src/renderpass_gl.cpp new file mode 100644 index 0000000..670de89 --- /dev/null +++ b/src/renderpass_gl.cpp @@ -0,0 +1,192 @@ +#if defined(HELLOIMGUI_HAS_OPENGL) + +#include "hello_imgui/hello_imgui_include_opengl.h" // cross-platform way to include OpenGL headers +#include "opengl_check.h" +#include "renderpass.h" + +#include 
+ +RenderPass::RenderPass(bool write_depth, bool clear) : + m_clear(clear), m_clear_color(0, 0, 0, 0), m_clear_depth(1.f), m_viewport_offset(0), m_viewport_size(0), + m_framebuffer_size(0), m_depth_test(write_depth ? DepthTest::Less : DepthTest::Always), m_depth_write(write_depth), + m_cull_mode(CullMode::Back), m_active(false) +{ + m_viewport_size = m_framebuffer_size = int2{0, 0}; +} + +RenderPass::~RenderPass() +{ +} + +void RenderPass::begin() +{ +#if !defined(NDEBUG) + if (m_active) + throw std::runtime_error("RenderPass::begin(): render pass is already active!"); +#endif + m_active = true; + + CHK(glGetIntegerv(GL_VIEWPORT, &m_viewport_backup[0])); + CHK(glGetIntegerv(GL_SCISSOR_BOX, &m_scissor_backup[0])); + GLboolean depth_write; + CHK(glGetBooleanv(GL_DEPTH_WRITEMASK, &depth_write)); + m_depth_write_backup = depth_write; + + m_depth_test_backup = glIsEnabled(GL_DEPTH_TEST); + m_scissor_test_backup = glIsEnabled(GL_SCISSOR_TEST); + m_cull_face_backup = glIsEnabled(GL_CULL_FACE); + m_blend_backup = glIsEnabled(GL_BLEND); + + set_viewport(m_viewport_offset, m_viewport_size); + + if (m_clear) // honor the constructor's `clear` flag rather than clearing unconditionally + { + GLenum what = 0; + if (m_depth_write) + { + CHK(glClearDepthf(m_clear_depth)); + what |= GL_DEPTH_BUFFER_BIT; + } + + CHK(glClearColor(m_clear_color.x, m_clear_color.y, m_clear_color.z, m_clear_color.w)); + what |= GL_COLOR_BUFFER_BIT; + + CHK(glClear(what)); + } + + set_depth_test(m_depth_test, m_depth_write); + set_cull_mode(m_cull_mode); + + if (m_blend_backup) + CHK(glDisable(GL_BLEND)); +} + +void RenderPass::end() +{ +#if !defined(NDEBUG) + if (!m_active) + throw std::runtime_error("RenderPass::end(): render pass is not active!"); +#endif + + CHK(glViewport(m_viewport_backup[0], m_viewport_backup[1], m_viewport_backup[2], m_viewport_backup[3])); + CHK(glScissor(m_scissor_backup[0], m_scissor_backup[1], m_scissor_backup[2], m_scissor_backup[3])); + + if (m_depth_test_backup) + CHK(glEnable(GL_DEPTH_TEST)); + else + CHK(glDisable(GL_DEPTH_TEST)); + + 
CHK(glDepthMask(m_depth_write_backup)); + + if (m_scissor_test_backup) + CHK(glEnable(GL_SCISSOR_TEST)); + else + CHK(glDisable(GL_SCISSOR_TEST)); + + if (m_cull_face_backup) + CHK(glEnable(GL_CULL_FACE)); + else + CHK(glDisable(GL_CULL_FACE)); + + if (m_blend_backup) + CHK(glEnable(GL_BLEND)); + else + CHK(glDisable(GL_BLEND)); + + m_active = false; +} + +void RenderPass::resize(const int2 &size) +{ + m_framebuffer_size = size; + m_viewport_offset = int2(0, 0); + m_viewport_size = size; +} + +void RenderPass::set_clear_color(const float4 &color) +{ + m_clear_color = color; +} + +void RenderPass::set_clear_depth(float depth) +{ + m_clear_depth = depth; +} + +void RenderPass::set_viewport(const int2 &offset, const int2 &size) +{ + m_viewport_offset = offset; + m_viewport_size = size; + + if (m_active) + { + int ypos = m_framebuffer_size.y - m_viewport_size.y - m_viewport_offset.y; + CHK(glViewport(m_viewport_offset.x, ypos, m_viewport_size.x, m_viewport_size.y)); + // fmt::print("RenderPass::viewport({}, {}, {}, {})\n", m_viewport_offset.x, ypos, m_viewport_size.x, + // m_viewport_size.y); + CHK(glScissor(m_viewport_offset.x, ypos, m_viewport_size.x, m_viewport_size.y)); + + if (m_viewport_offset == int2(0, 0) && m_viewport_size == m_framebuffer_size) + CHK(glDisable(GL_SCISSOR_TEST)); + else + CHK(glEnable(GL_SCISSOR_TEST)); + } +} + +void RenderPass::set_depth_test(DepthTest depth_test, bool depth_write) +{ + m_depth_test = depth_test; + m_depth_write = depth_write; + + if (m_active) + { + if (depth_test != DepthTest::Always) + { + GLenum func; + switch (depth_test) + { + case DepthTest::Never: func = GL_NEVER; break; + case DepthTest::Less: func = GL_LESS; break; + case DepthTest::Equal: func = GL_EQUAL; break; + case DepthTest::LessEqual: func = GL_LEQUAL; break; + case DepthTest::Greater: func = GL_GREATER; break; + case DepthTest::NotEqual: func = GL_NOTEQUAL; break; + case DepthTest::GreaterEqual: func = GL_GEQUAL; break; + default: throw 
std::runtime_error("RenderPass::set_depth_test(): invalid depth test mode!"); + } + CHK(glEnable(GL_DEPTH_TEST)); + CHK(glDepthFunc(func)); + } + else + { + CHK(glDisable(GL_DEPTH_TEST)); + } + CHK(glDepthMask(depth_write ? GL_TRUE : GL_FALSE)); + // fmt::print("RenderPass::set_depth_test({}, {})\n", (int)depth_test, depth_write); + } +} + +void RenderPass::set_cull_mode(CullMode cull_mode) +{ + m_cull_mode = cull_mode; + + if (m_active) + { + if (cull_mode == CullMode::Disabled) + { + CHK(glDisable(GL_CULL_FACE)); + } + else + { + CHK(glEnable(GL_CULL_FACE)); + if (cull_mode == CullMode::Front) + CHK(glCullFace(GL_FRONT)); + else if (cull_mode == CullMode::Back) + CHK(glCullFace(GL_BACK)); + else + throw std::runtime_error("RenderPass::set_cull_mode(): invalid cull mode!"); + } + } +} + +#endif // defined(HELLOIMGUI_HAS_OPENGL) diff --git a/src/renderpass_metal.mm b/src/renderpass_metal.mm new file mode 100644 index 0000000..6ec12bb --- /dev/null +++ b/src/renderpass_metal.mm @@ -0,0 +1,221 @@ +// The Metal version is still an untested work-in-progress +#if defined(HELLOIMGUI_HAS_METAL) + +#import +#import + +#include "hello_imgui/internal/backend_impls/rendering_metal.h" +#include "renderpass.h" +#include "shader.h" + +RenderPass::RenderPass(bool write_depth, bool clear) : + m_clear(clear), m_clear_color(0, 0, 0, 0), m_clear_depth(1.f), m_viewport_offset(0), m_viewport_size(0), + m_framebuffer_size(0), m_depth_test(write_depth ? 
DepthTest::Less : DepthTest::Always), m_depth_write(write_depth), + m_cull_mode(CullMode::Back), m_active(false), m_command_buffer(nullptr), m_command_encoder(nullptr) +{ + m_pass_descriptor = [MTLRenderPassDescriptor new]; + + set_clear_color(m_clear_color); + set_clear_depth(m_clear_depth); +} + +RenderPass::~RenderPass() +{ + MTLRenderPassDescriptor *pass_descriptor = (MTLRenderPassDescriptor *)m_pass_descriptor; + [pass_descriptor release]; +} + +void RenderPass::begin() +{ +#if !defined(NDEBUG) + if (m_active) + throw std::runtime_error("RenderPass::begin(): render pass is already active!"); +#endif + + auto &gMetalGlobals = HelloImGui::GetMetalGlobals(); + + id command_buffer = [gMetalGlobals.mtlCommandQueue commandBuffer]; + + MTLRenderPassDescriptor *pass_descriptor = (__bridge MTLRenderPassDescriptor *)m_pass_descriptor; + + bool clear_manual = m_clear && (m_viewport_offset != int2(0, 0) || m_viewport_size != m_framebuffer_size); + + pass_descriptor.colorAttachments[0].texture = gMetalGlobals.caMetalDrawable.texture; + pass_descriptor.colorAttachments[0].loadAction = m_clear && !clear_manual ? 
MTLLoadActionClear : MTLLoadActionLoad; + pass_descriptor.colorAttachments[0].storeAction = MTLStoreActionStore; + + id command_encoder = [command_buffer renderCommandEncoderWithDescriptor:pass_descriptor]; + + [command_encoder setFrontFacingWinding:MTLWindingCounterClockwise]; + + m_command_buffer = (__bridge_retained void *)command_buffer; + m_command_encoder = (__bridge_retained void *)command_encoder; + m_active = true; + + set_viewport(m_viewport_offset, m_viewport_size); + + // if (clear_manual) + // { + // MTLDepthStencilDescriptor *depth_desc = [MTLDepthStencilDescriptor new]; + // depth_desc.depthCompareFunction = MTLCompareFunctionAlways; + // depth_desc.depthWriteEnabled = m_targets[0].get() != nullptr; + // id device = (__bridge id)metal_device(); + // id depth_state = [device newDepthStencilStateWithDescriptor:depth_desc]; + // [command_encoder setDepthStencilState:depth_state]; + + // if (!m_clear_shader) + // { + // m_clear_shader = new Shader(this, + + // "clear_shader", + + // /* Vertex shader */ + // R"(using namespace metal; + + // struct VertexOut { + // float4 position [[position]]; + // }; + + // vertex VertexOut vertex_main(const device float2 *position, + // constant float &clear_depth, + // uint id [[vertex_id]]) { + // VertexOut vert; + // vert.position = float4(position[id], clear_depth, 1.f); + // return vert; + // })", + + // /* Fragment shader */ + // R"(using namespace metal; + + // struct VertexOut { + // float4 position [[position]]; + // }; + + // fragment float4 fragment_main(VertexOut vert [[stage_in]], + // constant float4 &clear_color) { + // return clear_color; + // })"); + + // const float positions[] = {-1.f, -1.f, 1.f, -1.f, -1.f, 1.f, 1.f, -1.f, 1.f, 1.f, -1.f, 1.f}; + + // m_clear_shader->set_buffer("position", VariableType::Float32, {6, 2}, positions); + // } + + // m_clear_shader->set_uniform("clear_color", m_clear_color.at(0)); + // m_clear_shader->set_uniform("clear_depth", m_clear_depth); + // 
m_clear_shader->begin(); + // m_clear_shader->draw_array(Shader::PrimitiveType::Triangle, 0, 6, false); + // m_clear_shader->end(); + // } + + set_depth_test(m_depth_test, m_depth_write); + set_cull_mode(m_cull_mode); +} + +void RenderPass::end() +{ +#if !defined(NDEBUG) + if (!m_active) + throw std::runtime_error("RenderPass::end(): render pass is not active!"); +#endif + id command_buffer = (__bridge_transfer id)m_command_buffer; + id command_encoder = (__bridge_transfer id)m_command_encoder; + [command_encoder endEncoding]; + [command_buffer commit]; + m_command_encoder = nullptr; + m_command_buffer = nullptr; + m_active = false; +} + +void RenderPass::resize(const int2 &size) +{ + m_framebuffer_size = size; + m_viewport_offset = int2(0, 0); + m_viewport_size = size; +} + +void RenderPass::set_clear_color(const float4 &color) +{ + m_clear_color = color; + + MTLRenderPassDescriptor *pass_descriptor = (__bridge MTLRenderPassDescriptor *)m_pass_descriptor; + + pass_descriptor.colorAttachments[0].clearColor = MTLClearColorMake(color.x, color.y, color.z, color.w); +} + +void RenderPass::set_clear_depth(float depth) +{ + m_clear_depth = depth; + + MTLRenderPassDescriptor *pass_descriptor = (__bridge MTLRenderPassDescriptor *)m_pass_descriptor; + pass_descriptor.depthAttachment.clearDepth = depth; +} + +void RenderPass::set_viewport(const int2 &offset, const int2 &size) +{ + m_viewport_offset = offset; + m_viewport_size = size; + if (m_active) + { + id command_encoder = (__bridge id)m_command_encoder; + [command_encoder + setViewport:(MTLViewport){(double)offset.x, (double)offset.y, (double)size.x, (double)size.y, 0.0, 1.0}]; + int2 scissor_size = max(min(offset + size, m_framebuffer_size) - offset, int2(0)); + int2 scissor_offset = max(min(offset, m_framebuffer_size), int2(0)); + [command_encoder setScissorRect:(MTLScissorRect){(NSUInteger)scissor_offset.x, (NSUInteger)scissor_offset.y, + (NSUInteger)scissor_size.x, (NSUInteger)scissor_size.y}]; + } +} + +void 
RenderPass::set_depth_test(DepthTest depth_test, bool depth_write) +{ + m_depth_test = depth_test; + m_depth_write = depth_write; + if (m_active) + { + MTLDepthStencilDescriptor *depth_desc = [MTLDepthStencilDescriptor new]; + + MTLCompareFunction func; + switch (depth_test) + { + case DepthTest::Never: func = MTLCompareFunctionNever; break; + case DepthTest::Less: func = MTLCompareFunctionLess; break; + case DepthTest::Equal: func = MTLCompareFunctionEqual; break; + case DepthTest::LessEqual: func = MTLCompareFunctionLessEqual; break; + case DepthTest::Greater: func = MTLCompareFunctionGreater; break; + case DepthTest::NotEqual: func = MTLCompareFunctionNotEqual; break; + case DepthTest::GreaterEqual: func = MTLCompareFunctionGreaterEqual; break; // '>=' has its own Metal enum; Greater would reject equal depths + case DepthTest::Always: func = MTLCompareFunctionAlways; break; + default: throw std::runtime_error("RenderPass::set_depth_test(): invalid depth test mode!"); + } + depth_desc.depthCompareFunction = func; + depth_desc.depthWriteEnabled = depth_write; + + auto &gMetalGlobals = HelloImGui::GetMetalGlobals(); + + id depth_state = + [gMetalGlobals.caMetalLayer.device newDepthStencilStateWithDescriptor:depth_desc]; + id command_encoder = (__bridge id)m_command_encoder; + [command_encoder setDepthStencilState:depth_state]; + } +} + +void RenderPass::set_cull_mode(CullMode cull_mode) +{ + m_cull_mode = cull_mode; + if (m_active) + { + MTLCullMode cull_mode_mtl; + switch (cull_mode) + { + case CullMode::Front: cull_mode_mtl = MTLCullModeFront; break; + case CullMode::Back: cull_mode_mtl = MTLCullModeBack; break; + case CullMode::Disabled: cull_mode_mtl = MTLCullModeNone; break; + default: throw std::runtime_error("RenderPass::set_cull_mode(): invalid cull mode!"); + } + id command_encoder = (__bridge id)m_command_encoder; + [command_encoder setCullMode:cull_mode_mtl]; + } +} + +#endif // defined(HELLOIMGUI_HAS_METAL) \ No newline at end of file diff --git a/src/shader_gl.cpp b/src/shader_gl.cpp index dacbf77..62f148d 100644 --- 
a/src/shader_gl.cpp +++ b/src/shader_gl.cpp @@ -2,6 +2,7 @@ #include "hello_imgui/hello_imgui.h" #include "hello_imgui/hello_imgui_include_opengl.h" // cross-platform way to include OpenGL headers +#include "opengl_check.h" #include "shader.h" #if !defined(GL_HALF_FLOAT) @@ -12,31 +13,6 @@ using std::string; -bool check_glerror(const char *cmd) -{ - GLenum err = glGetError(); - const char *msg = nullptr; - - switch (err) - { - case GL_NO_ERROR: return false; - case GL_INVALID_ENUM: msg = "invalid enumeration"; break; - case GL_INVALID_VALUE: msg = "invalid value"; break; - case GL_INVALID_OPERATION: msg = "invalid operation"; break; - case GL_INVALID_FRAMEBUFFER_OPERATION: msg = "invalid framebuffer operation"; break; - case GL_OUT_OF_MEMORY: msg = "out of memory"; break; -#ifndef __EMSCRIPTEN__ - case GL_STACK_UNDERFLOW: msg = "stack underflow"; break; - case GL_STACK_OVERFLOW: msg = "stack overflow"; break; -#endif - default: msg = "unknown error"; break; - } - - fmt::print(stderr, "OpenGL error ({}) during operation \"{}\"!\n", msg, cmd); - HelloImGui::Log(HelloImGui::LogLevel::Error, "OpenGL error (%s) during operation \"%s\"!\n", msg, cmd); - return true; -} - static GLuint compile_gl_shader(GLenum type, const std::string &name, const std::string &shader_string) { if (shader_string.empty()) @@ -82,10 +58,10 @@ static GLuint compile_gl_shader(GLenum type, const std::string &name, const std: return id; } -Shader::Shader(const std::string &name, const std::string &vs_filename, const std::string &fs_filename, - BlendMode blend_mode) : - m_name(name), - m_blend_mode(blend_mode), m_shader_handle(0) +Shader::Shader(RenderPass *render_pass, const std::string &name, const std::string &vs_filename, + const std::string &fs_filename, BlendMode blend_mode) : + m_render_pass(render_pass), + m_name(name), m_blend_mode(blend_mode), m_shader_handle(0) { string vertex_shader, fragment_shader; { @@ -650,4 +626,4 @@ void Shader::draw_array(PrimitiveType primitive_type, size_t 
offset, size_t coun
     }
 }
 
-#endif
+#endif // defined(HELLOIMGUI_HAS_OPENGL)
diff --git a/src/shader_metal.mm b/src/shader_metal.mm
new file mode 100644
index 0000000..0d4b45e
--- /dev/null
+++ b/src/shader_metal.mm
@@ -0,0 +1,409 @@
+// The Metal version is still an untested work-in-progress
+#if defined(HELLOIMGUI_HAS_METAL)
+
+#include "renderpass.h"
+#include "shader.h"
+#include 
+
+#include "hello_imgui/hello_imgui.h"
+#include "hello_imgui/internal/backend_impls/rendering_metal.h"
+
+#import 
+#import 
+
+#define METAL_BUFFER_THRESHOLD 64
+
+#include 
+
+using std::string;
+
+/// Builds a Metal library from `src` and returns its single entry point.
+///
+/// Input that starts with the bytes "MTLB" is loaded as a precompiled library;
+/// anything else is compiled as Metal source text.
+///
+/// @param device   Metal device used to create the library.
+/// @param name     Shader name; used only in error messages.
+/// @param type_str Stage label (e.g. "vertex"); used only in error messages.
+/// @param src      Shader source text or precompiled library bytes.
+/// @return The library's only function, or nil when `src` is empty.
+/// @throws std::runtime_error if the library cannot be created, does not contain
+///         exactly one function, or that function cannot be instantiated.
+id compile_metal_shader(id device, const std::string &name, const std::string &type_str,
+                        const std::string &src)
+{
+    if (src.empty())
+        return nil;
+
+    id          library = nil;
+    NSError    *error   = nil;
+    std::string activity;
+    if (src.size() > 4 && strncmp(src.data(), "MTLB", 4) == 0)
+    {
+        dispatch_data_t data = dispatch_data_create(src.data(), src.size(), NULL, DISPATCH_DATA_DESTRUCTOR_DEFAULT);
+        library              = [device newLibraryWithData:data error:&error];
+        activity             = "load";
+    }
+    else
+    {
+        NSString          *str  = [NSString stringWithUTF8String:src.c_str()];
+        MTLCompileOptions *opts = [MTLCompileOptions new];
+        library                 = [device newLibraryWithSource:str options:opts error:&error];
+        activity                = "compile";
+    }
+    // Judge success by the returned library, not by `error`: Metal also reports
+    // compile *warnings* through `error` while still returning a usable library.
+    if (!library)
+    {
+        const char *error_shader = error ? [[error description] UTF8String] : "unknown error";
+        throw std::runtime_error(std::string("compile_metal_shader(): unable to ") + activity + " " + type_str +
+                                 " shader \"" + name + "\":\n\n" + error_shader);
+    }
+
+    NSArray *function_names = [library functionNames];
+    if ([function_names count] != 1)
+        throw std::runtime_error("compile_metal_shader(name=\"" + name + "\"): library must contain exactly 1 shader!");
+    NSString *function_name = [function_names objectAtIndex:0];
+
+    id function = [library newFunctionWithName:function_name];
+    if (!function)
+        throw std::runtime_error("compile_metal_shader(name=\"" + name + "\"): function not found!");
+
+    return function;
+}
+
+/// Loads a vertex/fragment shader pair from assets and builds a Metal render pipeline from it.
+///
+/// Every argument reported by pipeline reflection is registered in m_buffers (initially
+/// unbound), plus a reserved "indices" entry for the index buffer.
+///
+/// @param render_pass Render pass whose command encoder begin()/draw_array() encode into.
+/// @param name        Shader name; used in error and diagnostic messages.
+/// @param vs_filename Asset path of the vertex shader.
+/// @param fs_filename Asset path of the fragment shader.
+/// @param blend_mode  BlendMode::AlphaBlend enables source-alpha blending on color attachment 0.
+/// @throws std::runtime_error if an asset cannot be loaded, a shader fails to build, the
+///         pipeline cannot be created, or an argument name is duplicated, reserved
+///         ("indices") or of an unsupported type.
+Shader::Shader(RenderPass *render_pass, const std::string &name, const std::string &vs_filename,
+               const std::string &fs_filename, BlendMode blend_mode) :
+    m_render_pass(render_pass),
+    m_name(name), m_blend_mode(blend_mode), m_pipeline_state(nullptr)
+{
+    auto &gMetalGlobals = HelloImGui::GetMetalGlobals();
+    id    device        = gMetalGlobals.caMetalLayer.device;
+
+    // Copies an asset into a string and frees the asset immediately, so the first
+    // file's data is not leaked if loading the second file throws.
+    auto load_shader_file = [](const string &filename)
+    {
+        auto shader_txt = HelloImGui::LoadAssetFileData(filename.c_str());
+        if (shader_txt.data == nullptr)
+            throw std::runtime_error(fmt::format("Cannot load point shader from file \"{}\"", filename));
+
+        string contents((char *)shader_txt.data, shader_txt.dataSize);
+        HelloImGui::FreeAssetFileData(&shader_txt);
+        return contents;
+    };
+    const string vertex_shader   = load_shader_file(vs_filename);
+    const string fragment_shader = load_shader_file(fs_filename);
+
+    id vertex_func   = compile_metal_shader(device, name, "vertex", vertex_shader);
+    id fragment_func = compile_metal_shader(device, name, "fragment", fragment_shader);
+
+    MTLRenderPipelineDescriptor *pipeline_desc = [MTLRenderPipelineDescriptor new];
+    pipeline_desc.vertexFunction               = vertex_func;
+    pipeline_desc.fragmentFunction             = fragment_func;
+
+    MTLRenderPipelineColorAttachmentDescriptor *att = pipeline_desc.colorAttachments[0];
+    att.pixelFormat                                 = gMetalGlobals.caMetalLayer.pixelFormat;
+
+    if (blend_mode == BlendMode::AlphaBlend)
+    {
+        att.blendingEnabled             = YES;
+        att.rgbBlendOperation           = MTLBlendOperationAdd;
+        att.alphaBlendOperation         = MTLBlendOperationAdd;
+        att.sourceRGBBlendFactor        = MTLBlendFactorSourceAlpha;
+        att.sourceAlphaBlendFactor      = MTLBlendFactorSourceAlpha;
+        att.destinationRGBBlendFactor   = MTLBlendFactorOneMinusSourceAlpha;
+        att.destinationAlphaBlendFactor = MTLBlendFactorOneMinusSourceAlpha;
+    }
+
+    // pipeline_desc.sampleCount = 1;
+
+    NSError                     *error      = nil;
+    MTLRenderPipelineReflection *reflection = nil;
+    id                           pipeline_state =
+        [device newRenderPipelineStateWithDescriptor:pipeline_desc
+                                             options:MTLPipelineOptionArgumentInfo
+                                          reflection:&reflection
+                                               error:&error];
+    // As with library creation, the returned object (not `error`) signals failure.
+    if (!pipeline_state)
+    {
+        const char *error_pipeline = error ? [[error description] UTF8String] : "unknown error";
+        throw std::runtime_error("compile_metal_pipeline(): unable to create render pipeline state!\n\n" +
+                                 std::string(error_pipeline));
+    }
+
+    m_pipeline_state = (__bridge_retained void *)pipeline_state;
+
+    // Registers every reflected argument of one shader stage in m_buffers, tagging it
+    // with the stage-specific buffer/texture/sampler type.
+    auto register_arguments = [this](NSArray *arguments, auto buffer_type, auto texture_type, auto sampler_type)
+    {
+        for (MTLArgument *arg in arguments)
+        {
+            std::string arg_name = [arg.name UTF8String];
+            if (m_buffers.find(arg_name) != m_buffers.end())
+                throw std::runtime_error("Shader::Shader(): \"" + arg_name +
+                                         "\": duplicate argument name in shader code!");
+            else if (arg_name == "indices")
+                throw std::runtime_error("Shader::Shader(): argument name 'indices' is reserved!");
+
+            Buffer &buf = m_buffers[arg_name];
+            buf.index   = arg.index;
+            if (arg.type == MTLArgumentTypeBuffer)
+                buf.type = buffer_type;
+            else if (arg.type == MTLArgumentTypeTexture)
+                buf.type = texture_type;
+            else if (arg.type == MTLArgumentTypeSampler)
+                buf.type = sampler_type;
+            else
+                throw std::runtime_error("Shader::Shader(): \"" + arg_name + "\": unsupported argument type!");
+        }
+    };
+
+    // Vertex arguments first: a fragment argument reusing a vertex argument's name is rejected as a duplicate.
+    register_arguments([reflection vertexArguments], VertexBuffer, VertexTexture, VertexSampler);
+    register_arguments([reflection fragmentArguments], FragmentBuffer, FragmentTexture, FragmentSampler);
+
+    Buffer &buf = m_buffers["indices"];
+    buf.index   = -1;
+    buf.type    = IndexBuffer;
+}
+
+/// Releases every bound resource and the pipeline state.
+///
+/// Storage is released with the same rule set_buffer() uses to allocate it: non-index
+/// buffers of at most METAL_BUFFER_THRESHOLD bytes are host arrays (delete[]); everything
+/// else, including the index buffer regardless of its size, is a retained Metal object.
+Shader::~Shader()
+{
+    for (const auto &[key, buf] : m_buffers)
+    {
+        if (!buf.buffer)
+            continue;
+        if (buf.type == VertexBuffer || buf.type == FragmentBuffer || buf.type == IndexBuffer)
+        {
+            // The index buffer is always a Metal buffer, so it must never be delete[]'d.
+            if (buf.size <= METAL_BUFFER_THRESHOLD && buf.type != IndexBuffer)
+                delete[] (uint8_t *)buf.buffer;
+            else
+                (void)(__bridge_transfer id)buf.buffer;
+        }
+        else if (buf.type == VertexTexture || buf.type == FragmentTexture)
+        {
+            (void)(__bridge_transfer id)buf.buffer;
+        }
+        else if (buf.type == VertexSampler || buf.type == FragmentSampler)
+        {
+            (void)(__bridge_transfer id)buf.buffer;
+        }
+        else
+        {
+            std::cerr << "Shader::~Shader(): unknown buffer type!" << std::endl;
+        }
+    }
+    (void)(__bridge_transfer id)m_pipeline_state;
+}
+
+/// Uploads data for the shader argument `name`.
+///
+/// Non-index buffers of at most METAL_BUFFER_THRESHOLD bytes are kept as a host-side copy
+/// (begin() binds them with set*Bytes). Everything else, and always the index buffer, is
+/// copied into a private Metal buffer via a temporary shared buffer and a blit; the call
+/// waits for that blit to complete before returning.
+///
+/// @param name  Argument name as declared in the shader, or "indices" for the index buffer.
+/// @param dtype Element type; together with the shape it determines the byte size.
+/// @param ndim  Number of valid entries in `shape`; dimensions beyond it (up to 3) count as 1.
+/// @param shape Extent of each dimension.
+/// @param data  Source bytes; must hold at least the computed size.
+/// @throws std::runtime_error if `name` is unknown or does not refer to a buffer argument.
+void Shader::set_buffer(const std::string &name, VariableType dtype, size_t ndim, const size_t *shape, const void *data)
+{
+    auto &gMetalGlobals = HelloImGui::GetMetalGlobals();
+
+    auto it = m_buffers.find(name);
+    if (it == m_buffers.end())
+        throw std::runtime_error("Shader::set_buffer(): could not find argument named \"" + name + "\"");
+    Buffer &buf = it->second;
+    if (!(buf.type == VertexBuffer || buf.type == FragmentBuffer || buf.type == IndexBuffer))
+        throw std::runtime_error("Shader::set_buffer(): argument named \"" + name + "\" is not a buffer!");
+
+    // "indices" is the only entry the constructor tags as IndexBuffer.
+    const bool is_index_buffer = buf.type == IndexBuffer;
+
+    for (size_t i = 0; i < 3; ++i)
+        buf.shape[i] = i < ndim ? shape[i] : 1;
+
+    size_t size = type_size(dtype) * buf.shape[0] * buf.shape[1] * buf.shape[2];
+    if (buf.buffer && buf.size != size)
+    {
+        // Free the old storage with the rule that allocated it (see below): the index
+        // buffer is a Metal buffer even when it is small.
+        if (buf.size <= METAL_BUFFER_THRESHOLD && !is_index_buffer)
+            delete[] (uint8_t *)buf.buffer;
+        else
+            (void)(__bridge_transfer id)buf.buffer;
+        buf.buffer = nullptr;
+    }
+
+    if (size <= METAL_BUFFER_THRESHOLD && !is_index_buffer)
+    {
+        if (!buf.buffer)
+            buf.buffer = new uint8_t[size];
+        memcpy(buf.buffer, data, size);
+    }
+    else
+    {
+        /* Procedure recommended by Apple: create a temporary shared buffer and
+           blit into a private GPU-only buffer */
+        id device = gMetalGlobals.caMetalLayer.device;
+        id mtl_buffer;
+
+        if (buf.buffer)
+            mtl_buffer = (__bridge_transfer id)buf.buffer;
+        else
+            mtl_buffer = [device newBufferWithLength:size options:MTLResourceStorageModePrivate];
+
+        id temp_buffer = [device newBufferWithBytes:data length:size options:MTLResourceStorageModeShared];
+
+        id command_buffer = [gMetalGlobals.mtlCommandQueue commandBuffer];
+        id blit_encoder   = [command_buffer blitCommandEncoder];
+
+        [blit_encoder copyFromBuffer:temp_buffer sourceOffset:0 toBuffer:mtl_buffer destinationOffset:0 size:size];
+
+        [blit_encoder endEncoding];
+        [command_buffer commit];
+        [command_buffer waitUntilCompleted];
+
+        buf.buffer = (__bridge_retained void *)mtl_buffer;
+    }
+
+    buf.dtype = dtype;
+    buf.ndim  = ndim;
+    buf.size  = size;
+}
+
+// void Shader::set_texture(const std::string &name, Texture *texture)
+// {
+//     auto it = m_buffers.find(name);
+//     if (it == m_buffers.end())
+//         throw std::runtime_error("Shader::set_texture(): could not find argument named \"" + name + "\"");
+//     Buffer &buf = m_buffers[name];
+//     if (!(buf.type == VertexTexture || buf.type == FragmentTexture))
+//         throw std::runtime_error("Shader::set_texture(): argument named \"" + name + "\" is not a texture!");
+
+//     if (buf.buffer)
+//     {
+//         (void)(__bridge_transfer id)buf.buffer;
+//         buf.buffer = nullptr;
+//     }
+
+//     buf.buffer = (__bridge_retained void *)((__bridge 
id)texture->texture_handle());
+
+//     std::string sampler_name;
+//     if (name.length() > 8 && name.compare(name.length() - 8, 8, "_texture") == 0)
+//         sampler_name = name.substr(0, name.length() - 8) + "_sampler";
+//     else
+//         sampler_name = name + "_sampler";
+
+//     if (m_buffers.find(sampler_name) != m_buffers.end())
+//     {
+//         /* Also set the sampler state */
+//         Buffer &buf2 = m_buffers[sampler_name];
+
+//         if (buf2.buffer)
+//         {
+//             (void)(__bridge_transfer id)buf2.buffer;
+//             buf2.buffer = nullptr;
+//         }
+
+//         buf2.buffer = (__bridge_retained void *)((__bridge id)texture->sampler_state_handle());
+//     }
+// }
+
+/// Binds the pipeline state and every bound argument to the render pass's command encoder.
+///
+/// Arguments that were never set are reported on stderr and skipped (the index buffer is
+/// skipped silently; it is only consumed by indexed draws in draw_array()).
+///
+/// @throws std::runtime_error if a small host-side buffer has a type that is neither a
+///         vertex nor a fragment buffer.
+void Shader::begin()
+{
+    id pipeline_state = (__bridge id)m_pipeline_state;
+    id command_enc    = (__bridge id)m_render_pass->command_encoder();
+
+    [command_enc setRenderPipelineState:pipeline_state];
+
+    for (const auto &[key, buf] : m_buffers)
+    {
+        bool indices = buf.type == IndexBuffer;
+        if (!buf.buffer)
+        {
+            if (!indices)
+                fprintf(stderr,
+                        "Shader::begin(): shader \"%s\" has an unbound "
+                        "argument \"%s\"!\n",
+                        m_name.c_str(), key.c_str());
+            continue;
+        }
+
+        switch (buf.type)
+        {
+        case VertexTexture:
+        {
+            id texture = (__bridge id)buf.buffer;
+            [command_enc setVertexTexture:texture atIndex:buf.index];
+        }
+        break;
+
+        case FragmentTexture:
+        {
+            id texture = (__bridge id)buf.buffer;
+            [command_enc setFragmentTexture:texture atIndex:buf.index];
+        }
+        break;
+
+        case VertexSampler:
+        {
+            id state = (__bridge id)buf.buffer;
+            [command_enc setVertexSamplerState:state atIndex:buf.index];
+        }
+        break;
+
+        case FragmentSampler:
+        {
+            id state = (__bridge id)buf.buffer;
+            [command_enc setFragmentSamplerState:state atIndex:buf.index];
+        }
+        break;
+
+        default:
+            if (buf.size <= METAL_BUFFER_THRESHOLD && !indices)
+            {
+                // Small non-index buffers are host-side copies and are passed by value.
+                if (buf.type == VertexBuffer)
+                    [command_enc setVertexBytes:buf.buffer length:buf.size atIndex:buf.index];
+                else if (buf.type == FragmentBuffer)
+                    [command_enc setFragmentBytes:buf.buffer length:buf.size atIndex:buf.index];
+                else
+                    throw std::runtime_error("Shader::begin(): unexpected buffer type!");
+            }
+            else
+            {
+                // Metal buffer objects; the index buffer is intentionally not bound here.
+                id buffer = (__bridge id)buf.buffer;
+                if (buf.type == VertexBuffer)
+                    [command_enc setVertexBuffer:buffer offset:0 atIndex:buf.index];
+                else if (buf.type == FragmentBuffer)
+                    [command_enc setFragmentBuffer:buffer offset:0 atIndex:buf.index];
+            }
+            break;
+        }
+    }
+}
+
+/// Counterpart of begin(); nothing needs to be undone in the Metal backend.
+void Shader::end()
+{
+    /* No-op */
+}
+
+/// Encodes a draw call on the render pass's command encoder.
+///
+/// @param primitive_type Primitive topology to draw.
+/// @param offset         First vertex (non-indexed) or first index (indexed; 32-bit indices).
+/// @param count          Number of vertices or indices to draw.
+/// @param indexed        Draw through the "indices" buffer when true.
+/// @param instances      Number of instances to draw; 0 encodes a plain, non-instanced draw.
+///                       NOTE(review): assumes 0 is the "no instancing" value used by
+///                       callers / the header default -- confirm against shader.h.
+/// @throws std::runtime_error on an invalid primitive type, or when an indexed draw is
+///         requested but no index buffer has been set.
+void Shader::draw_array(PrimitiveType primitive_type, size_t offset, size_t count, bool indexed, size_t instances)
+{
+    MTLPrimitiveType primitive_type_mtl;
+    switch (primitive_type)
+    {
+    case PrimitiveType::Point: primitive_type_mtl = MTLPrimitiveTypePoint; break;
+    case PrimitiveType::Line: primitive_type_mtl = MTLPrimitiveTypeLine; break;
+    case PrimitiveType::LineStrip: primitive_type_mtl = MTLPrimitiveTypeLineStrip; break;
+    case PrimitiveType::Triangle: primitive_type_mtl = MTLPrimitiveTypeTriangle; break;
+    case PrimitiveType::TriangleStrip: primitive_type_mtl = MTLPrimitiveTypeTriangleStrip; break;
+    default: throw std::runtime_error("Shader::draw_array(): invalid primitive type!");
+    }
+
+    id command_enc = (__bridge id)m_render_pass->command_encoder();
+
+    if (!indexed)
+    {
+        if (instances == 0)
+            [command_enc drawPrimitives:primitive_type_mtl vertexStart:offset vertexCount:count];
+        else
+            [command_enc drawPrimitives:primitive_type_mtl
+                            vertexStart:offset
+                            vertexCount:count
+                          instanceCount:instances];
+    }
+    else
+    {
+        // Fail with a clear message instead of handing a nil index buffer to Metal.
+        auto it = m_buffers.find("indices");
+        if (it == m_buffers.end() || !it->second.buffer)
+            throw std::runtime_error("Shader::draw_array(): indexed draw requested but no index buffer was set!");
+
+        id index_buffer = (__bridge id)it->second.buffer;
+        if (instances == 0)
+            [command_enc drawIndexedPrimitives:primitive_type_mtl
+                                    indexCount:count
+                                     indexType:MTLIndexTypeUInt32
+                                   indexBuffer:index_buffer
+                             indexBufferOffset:offset * 4];
+        else
+            [command_enc drawIndexedPrimitives:primitive_type_mtl
+                                    indexCount:count
+                                     indexType:MTLIndexTypeUInt32
+                                   indexBuffer:index_buffer
+                             indexBufferOffset:offset * 4
+                                 instanceCount:instances];
+    }
+}
+
+#endif // defined(HELLOIMGUI_HAS_METAL)