Skip to content

Commit

Permalink
Reuse shader wrappers and shader data (ManimCommunity#2062)
Browse files Browse the repository at this point in the history
* reuse shader wrappers and shader data arrays

* Update uniforms

Co-authored-by: Laith Bahodi <[email protected]>
Co-authored-by: Darylgolden <[email protected]>
  • Loading branch information
3 people authored Mar 18, 2022
1 parent 91a1f55 commit 10a5f40
Showing 1 changed file with 38 additions and 37 deletions.
75 changes: 38 additions & 37 deletions manim/mobject/opengl/opengl_vectorized_mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from manim import config
from manim.constants import *
from manim.mobject.opengl.opengl_mobject import OpenGLMobject, OpenGLPoint
from manim.renderer.shader_wrapper import ShaderWrapper
from manim.utils.bezier import (
bezier,
get_quadratic_approximation_of_cubic,
Expand Down Expand Up @@ -132,6 +133,12 @@ def __init__(
if stroke_color:
self.stroke_color = Color(stroke_color)

self.fill_data = None
self.stroke_data = None
self.fill_shader_wrapper = None
self.stroke_shader_wrapper = None
self.init_shader_data()

def get_group_class(self):
return OpenGLVGroup

Expand Down Expand Up @@ -1489,8 +1496,6 @@ def flip(self, *args, **kwargs):

# For shaders
def init_shader_data(self):
from ...renderer.shader_wrapper import ShaderWrapper

self.fill_data = np.zeros(0, dtype=self.fill_dtype)
self.stroke_data = np.zeros(0, dtype=self.stroke_dtype)
self.fill_shader_wrapper = ShaderWrapper(
Expand All @@ -1511,30 +1516,23 @@ def refresh_shader_wrapper_id(self):
return self

def get_fill_shader_wrapper(self):
from ...renderer.shader_wrapper import ShaderWrapper
self.update_fill_shader_wrapper()
return self.fill_shader_wrapper

return ShaderWrapper(
vert_data=self.get_fill_shader_data(),
vert_indices=self.get_triangulation(),
shader_folder=self.fill_shader_folder,
render_primitive=moderngl.TRIANGLES,
uniforms=self.get_fill_uniforms(),
depth_test=self.depth_test,
)
def update_fill_shader_wrapper(self):
self.fill_shader_wrapper.vert_data = self.get_fill_shader_data()
self.fill_shader_wrapper.vert_indices = self.get_triangulation()
self.fill_shader_wrapper.uniforms = self.get_fill_uniforms()

def get_stroke_shader_wrapper(self):
from ...renderer.shader_wrapper import ShaderWrapper
self.update_stroke_shader_wrapper()
return self.stroke_shader_wrapper

return ShaderWrapper(
vert_data=self.get_stroke_shader_data(),
shader_folder=self.stroke_shader_folder,
render_primitive=moderngl.TRIANGLES,
uniforms=self.get_stroke_uniforms(),
depth_test=self.depth_test,
)
def update_stroke_shader_wrapper(self):
self.stroke_shader_wrapper.vert_data = self.get_stroke_shader_data()
self.stroke_shader_wrapper.uniforms = self.get_stroke_uniforms()

def get_shader_wrapper_list(self):

# Build up data lists
fill_shader_wrappers = []
stroke_shader_wrappers = []
Expand Down Expand Up @@ -1580,31 +1578,34 @@ def get_fill_uniforms(self):

def get_stroke_shader_data(self):
points = self.points
stroke_data = np.zeros(len(points), dtype=OpenGLVMobject.stroke_dtype)
if len(self.stroke_data) != len(points):
self.stroke_data = np.zeros(len(points), dtype=OpenGLVMobject.stroke_dtype)

nppc = self.n_points_per_curve
stroke_data["point"] = points
stroke_data["prev_point"][:nppc] = points[-nppc:]
stroke_data["prev_point"][nppc:] = points[:-nppc]
stroke_data["next_point"][:-nppc] = points[nppc:]
stroke_data["next_point"][-nppc:] = points[:nppc]
if "points" not in self.locked_data_keys:
nppc = self.n_points_per_curve
self.stroke_data["point"] = points
self.stroke_data["prev_point"][:nppc] = points[-nppc:]
self.stroke_data["prev_point"][nppc:] = points[:-nppc]
self.stroke_data["next_point"][:-nppc] = points[nppc:]
self.stroke_data["next_point"][-nppc:] = points[:nppc]

self.read_data_to_shader(stroke_data, "color", "stroke_rgba")
self.read_data_to_shader(stroke_data, "stroke_width", "stroke_width")
self.read_data_to_shader(stroke_data, "unit_normal", "unit_normal")
self.read_data_to_shader(self.stroke_data, "color", "stroke_rgba")
self.read_data_to_shader(self.stroke_data, "stroke_width", "stroke_width")
self.read_data_to_shader(self.stroke_data, "unit_normal", "unit_normal")

return stroke_data
return self.stroke_data

def get_fill_shader_data(self):
points = self.points
fill_data = np.zeros(len(points), dtype=OpenGLVMobject.fill_dtype)
fill_data["vert_index"][:, 0] = range(len(points))
if len(self.fill_data) != len(points):
self.fill_data = np.zeros(len(points), dtype=OpenGLVMobject.fill_dtype)
self.fill_data["vert_index"][:, 0] = range(len(points))

self.read_data_to_shader(fill_data, "point", "points")
self.read_data_to_shader(fill_data, "color", "fill_rgba")
self.read_data_to_shader(fill_data, "unit_normal", "unit_normal")
self.read_data_to_shader(self.fill_data, "point", "points")
self.read_data_to_shader(self.fill_data, "color", "fill_rgba")
self.read_data_to_shader(self.fill_data, "unit_normal", "unit_normal")

return fill_data
return self.fill_data

def refresh_shader_data(self):
self.get_fill_shader_data()
Expand Down

0 comments on commit 10a5f40

Please sign in to comment.