diff --git a/src/autoqasm/api.py b/src/autoqasm/api.py index d4a567c..ce5cb33 100644 --- a/src/autoqasm/api.py +++ b/src/autoqasm/api.py @@ -339,13 +339,17 @@ def _convert_subroutine( for i, param in enumerate(inspect.signature(f).parameters.values()) if param.annotation == aq_types.QubitIdentifierType } + + # Map args and kwargs to function signature + bound_args = inspect.signature(oqpy_sub).bind(*[oqpy_program, *args], **kwargs) + args = [ (aq_instructions.qubits._qubit(arg) if i in quantum_indices else arg) - for i, arg in enumerate(args) + for i, arg in enumerate(bound_args.args[1:]) ] # Process the program - subroutine_function_call = oqpy_sub(oqpy_program, *args, **kwargs) + subroutine_function_call = oqpy_sub(oqpy_program, *args) program_conversion_context.register_args(args) # Mark that we are finished processing this function @@ -357,8 +361,13 @@ def _convert_subroutine( _wrap_for_oqpy_subroutine(_dummy_function(f), options) ) + # Map args and kwargs to function signature + bound_args = inspect.signature(oqpy_sub).bind(*((oqpy_program, *args)), **kwargs) + + args = bound_args.args[1:] + # Process the program - subroutine_function_call = oqpy_sub(oqpy_program, *args, **kwargs) + subroutine_function_call = oqpy_sub(oqpy_program, *args) # Add the subroutine invocation to the program ret_type = subroutine_function_call.subroutine_decl.return_type diff --git a/test/unit_tests/autoqasm/test_api.py b/test/unit_tests/autoqasm/test_api.py index aa932d2..ee7ec1b 100644 --- a/test/unit_tests/autoqasm/test_api.py +++ b/test/unit_tests/autoqasm/test_api.py @@ -1240,3 +1240,29 @@ def main(): h __qubits__[2]; h __qubits__[3];""" assert main.build().to_ir() == expected_ir + + +def test_subroutine_call_with_kwargs(): + """Test that subroutine call works with keyword arguments""" + + @aq.subroutine + def test(a: int, b: int) -> None: + aq.instructions.h(a) + aq.instructions.h(b) + + @aq.main(num_qubits=2) + def main(): + test(0, b=1) # Test with one keyword argument + test(a=2, b=3) # Test with keyword argument + test(b=5, a=4) # Test with keyword argument in any order + + expected = """OPENQASM 3.0; +def test(int[32] a, int[32] b) { + h __qubits__[a]; + h __qubits__[b]; +} +qubit[2] __qubits__; +test(0, 1); +test(2, 3); +test(4, 5);""" + assert main.build().to_ir() == expected