From 10a5f40aa5c51bd6822742b707acd2390ce2cc23 Mon Sep 17 00:00:00 2001 From: Ryan McCauley <32387857+ryanmccauley211@users.noreply.github.com> Date: Fri, 18 Mar 2022 06:04:46 +0000 Subject: [PATCH] Reuse shader wrappers and shader data (#2062) * reuse shader wrappers and shader data arrays * Update uniforms Co-authored-by: Laith Bahodi <70682032+hydrobeam@users.noreply.github.com> Co-authored-by: Darylgolden --- .../opengl/opengl_vectorized_mobject.py | 75 ++++++++++--------- 1 file changed, 38 insertions(+), 37 deletions(-) diff --git a/manim/mobject/opengl/opengl_vectorized_mobject.py b/manim/mobject/opengl/opengl_vectorized_mobject.py index 47ea18b82e..46dffe11b2 100644 --- a/manim/mobject/opengl/opengl_vectorized_mobject.py +++ b/manim/mobject/opengl/opengl_vectorized_mobject.py @@ -12,6 +12,7 @@ from manim import config from manim.constants import * from manim.mobject.opengl.opengl_mobject import OpenGLMobject, OpenGLPoint +from manim.renderer.shader_wrapper import ShaderWrapper from manim.utils.bezier import ( bezier, get_quadratic_approximation_of_cubic, @@ -132,6 +133,12 @@ def __init__( if stroke_color: self.stroke_color = Color(stroke_color) + self.fill_data = None + self.stroke_data = None + self.fill_shader_wrapper = None + self.stroke_shader_wrapper = None + self.init_shader_data() + def get_group_class(self): return OpenGLVGroup @@ -1489,8 +1496,6 @@ def flip(self, *args, **kwargs): # For shaders def init_shader_data(self): - from ...renderer.shader_wrapper import ShaderWrapper - self.fill_data = np.zeros(0, dtype=self.fill_dtype) self.stroke_data = np.zeros(0, dtype=self.stroke_dtype) self.fill_shader_wrapper = ShaderWrapper( @@ -1511,30 +1516,23 @@ def refresh_shader_wrapper_id(self): return self def get_fill_shader_wrapper(self): - from ...renderer.shader_wrapper import ShaderWrapper + self.update_fill_shader_wrapper() + return self.fill_shader_wrapper - return ShaderWrapper( - vert_data=self.get_fill_shader_data(), - vert_indices=self.get_triangulation(), - shader_folder=self.fill_shader_folder, - render_primitive=moderngl.TRIANGLES, - uniforms=self.get_fill_uniforms(), - depth_test=self.depth_test, - ) + def update_fill_shader_wrapper(self): + self.fill_shader_wrapper.vert_data = self.get_fill_shader_data() + self.fill_shader_wrapper.vert_indices = self.get_triangulation() + self.fill_shader_wrapper.uniforms = self.get_fill_uniforms() def get_stroke_shader_wrapper(self): - from ...renderer.shader_wrapper import ShaderWrapper + self.update_stroke_shader_wrapper() + return self.stroke_shader_wrapper - return ShaderWrapper( - vert_data=self.get_stroke_shader_data(), - shader_folder=self.stroke_shader_folder, - render_primitive=moderngl.TRIANGLES, - uniforms=self.get_stroke_uniforms(), - depth_test=self.depth_test, - ) + def update_stroke_shader_wrapper(self): + self.stroke_shader_wrapper.vert_data = self.get_stroke_shader_data() + self.stroke_shader_wrapper.uniforms = self.get_stroke_uniforms() def get_shader_wrapper_list(self): - # Build up data lists fill_shader_wrappers = [] stroke_shader_wrappers = [] @@ -1580,31 +1578,34 @@ def get_fill_uniforms(self): def get_stroke_shader_data(self): points = self.points - stroke_data = np.zeros(len(points), dtype=OpenGLVMobject.stroke_dtype) + if len(self.stroke_data) != len(points): + self.stroke_data = np.zeros(len(points), dtype=OpenGLVMobject.stroke_dtype) - nppc = self.n_points_per_curve - stroke_data["point"] = points - stroke_data["prev_point"][:nppc] = points[-nppc:] - stroke_data["prev_point"][nppc:] = points[:-nppc] - stroke_data["next_point"][:-nppc] = points[nppc:] - stroke_data["next_point"][-nppc:] = points[:nppc] + if "points" not in self.locked_data_keys: + nppc = self.n_points_per_curve + self.stroke_data["point"] = points + self.stroke_data["prev_point"][:nppc] = points[-nppc:] + self.stroke_data["prev_point"][nppc:] = points[:-nppc] + self.stroke_data["next_point"][:-nppc] = points[nppc:] + self.stroke_data["next_point"][-nppc:] = points[:nppc] - self.read_data_to_shader(stroke_data, "color", "stroke_rgba") - self.read_data_to_shader(stroke_data, "stroke_width", "stroke_width") - self.read_data_to_shader(stroke_data, "unit_normal", "unit_normal") + self.read_data_to_shader(self.stroke_data, "color", "stroke_rgba") + self.read_data_to_shader(self.stroke_data, "stroke_width", "stroke_width") + self.read_data_to_shader(self.stroke_data, "unit_normal", "unit_normal") - return stroke_data + return self.stroke_data def get_fill_shader_data(self): points = self.points - fill_data = np.zeros(len(points), dtype=OpenGLVMobject.fill_dtype) - fill_data["vert_index"][:, 0] = range(len(points)) + if len(self.fill_data) != len(points): + self.fill_data = np.zeros(len(points), dtype=OpenGLVMobject.fill_dtype) + self.fill_data["vert_index"][:, 0] = range(len(points)) - self.read_data_to_shader(fill_data, "point", "points") - self.read_data_to_shader(fill_data, "color", "fill_rgba") - self.read_data_to_shader(fill_data, "unit_normal", "unit_normal") + self.read_data_to_shader(self.fill_data, "point", "points") + self.read_data_to_shader(self.fill_data, "color", "fill_rgba") + self.read_data_to_shader(self.fill_data, "unit_normal", "unit_normal") - return fill_data + return self.fill_data def refresh_shader_data(self): self.get_fill_shader_data()