diff --git a/cadquery/cqgi.py b/cadquery/cqgi.py index 87f2ce917..10b9ff188 100644 --- a/cadquery/cqgi.py +++ b/cadquery/cqgi.py @@ -2,6 +2,7 @@ The CadQuery Gateway Interface. Provides classes and tools for executing CadQuery scripts """ +import sys import ast import traceback import time @@ -65,7 +66,9 @@ def _find_vars(self): assignment_finder = ConstantAssignmentFinder(self.metadata) for node in self.ast_tree.body: - if isinstance(node, ast.Assign): + if isinstance(node, ast.AnnAssign): + assignment_finder.visit_AnnAssign(node) + elif isinstance(node, ast.Assign): assignment_finder.visit_Assign(node) def _find_descriptions(self): @@ -564,6 +567,31 @@ def handle_assignment(self, var_name, value_node): print("Unable to handle assignment for variable '%s'" % var_name) pass + def handle_ann_assignment(self, var_name, annotation_id, value_node): + try: + if annotation_id == "int" or annotation_id == "float": + self.cqModel.add_script_parameter( + InputParameter.create( + value_node, var_name, NumberParameterType, value_node.n + ) + ) + elif annotation_id == "str": + self.cqModel.add_script_parameter( + InputParameter.create( + value_node, var_name, StringParameterType, value_node.s + ) + ) + elif annotation_id == "bool": + self.cqModel.add_script_parameter( + InputParameter.create( + value_node, var_name, BooleanParameterType, value_node.s + ) + ) + + except: + print("Unable to handle annotated assignment for variable '%s'" % var_name) + pass + def visit_Assign(self, node): try: @@ -595,3 +623,21 @@ def visit_Assign(self, node): print("Unable to handle assignment for node '%s'" % ast.dump(left_side)) return node + + def visit_AnnAssign(self, node): + left_side = node.target + + # do not handle Attribute or Subscript + if not isinstance(left_side, ast.Name): + return + + annTypes = ["int", "float", "str", "bool"] + + if ( + hasattr(node, "annotation") + and isinstance(node.annotation, ast.Name) + and node.annotation.id in annTypes + ): + self.handle_ann_assignment(left_side.id, node.annotation.id, node.value) + + return node diff --git a/partcad.yaml b/partcad.yaml index 87d9b95b1..5f05a841c 100644 --- a/partcad.yaml +++ b/partcad.yaml @@ -1,3 +1,6 @@ +# This is a PartCAD package. +# See https://partcad.org/ and https://github.com/openvmp/partcad for more information. + name: /pub/examples/script/cadquery desc: CadQuery examples url: https://github.com/CadQuery/cadquery diff --git a/tests/test_cqgi.py b/tests/test_cqgi.py index 33f371ad5..c3ec3006d 100644 --- a/tests/test_cqgi.py +++ b/tests/test_cqgi.py @@ -15,7 +15,8 @@ TESTSCRIPT = textwrap.dedent( """ height=2.0 - width=3.0 + width:float=3.0 + transparent=False (a,b) = (1.0,1.0) o = (2, 2, 0) foo="bar" @@ -29,9 +30,10 @@ """ height=2.0 width=3.0 + transparent:bool=False (a,b) = (1.0,1.0) o = (2, 2, 0) - foo="bar" + foo:str="bar" debug(foo, { "color": 'yellow' } ) result = "%s|%s|%s|%s|%s" % ( str(height) , str(width) , foo , str(a) , str(o) ) show_object(result) @@ -45,7 +47,8 @@ def test_parser(self): model = cqgi.CQModel(TESTSCRIPT) metadata = model.metadata self.assertEqual( - set(metadata.parameters.keys()), {"height", "width", "a", "b", "foo", "o"} + set(metadata.parameters.keys()), + {"height", "width", "transparent", "a", "b", "foo", "o"}, ) def test_build_with_debug(self): @@ -135,7 +138,7 @@ def test_that_two_results_are_returned(self): """ h = 1 show_object(h) - h = 2 + h: int = 2 show_object(h) """ ) @@ -166,6 +169,16 @@ def test_that_assigning_string_to_number_fails(self): result = cqgi.parse(script).build({"h": "a string"}) self.assertTrue(isinstance(result.exception, cqgi.InvalidParameterError)) + def test_that_assigning_string_to_annotated_list_fails(self): + script = textwrap.dedent( + """ + h: list[float] = [20.0] + show_object(h) + """ + ) + result = cqgi.parse(script).build({"h": "a string"}) + self.assertTrue(isinstance(result.exception, cqgi.InvalidParameterError)) + def test_that_assigning_unknown_var_fails(self): script = textwrap.dedent( """ @@ -222,7 +235,10 @@ def test_that_only_top_level_vars_are_detected(self): def do_stuff(): x = 1 - y = 2 + y: int = 2 + class Foo: + z = 3 + zz: int = 4 show_object( "result" ) """