From 34cb847df0487e4bba278e50e06b801f326d1a4c Mon Sep 17 00:00:00 2001 From: chinganc Date: Tue, 19 Nov 2024 10:01:13 -0800 Subject: [PATCH 1/7] Refactor bundle to handle error during async call. --- opto/trace/bundle.py | 239 +++++++++++++++++++++---------- tests/unit_tests/test_asyncio.py | 52 +++++++ 2 files changed, 216 insertions(+), 75 deletions(-) diff --git a/opto/trace/bundle.py b/opto/trace/bundle.py index cad56496..43902cb7 100644 --- a/opto/trace/bundle.py +++ b/opto/trace/bundle.py @@ -67,6 +67,7 @@ class FunModule(Module): description (str): a description of the operator; see the MessageNode for syntax. _process_inputs (bool): if True, the input is extracted from the container of nodes; if False, the inputs are passed directly to the underlying function. trainable (bool): if True, the block of code is treated as a variable in the optimization + traceable_code (bool): if True, the operator's code is traceable by Trace catch_execution_error (bool): if True, the operator catches the exception raised during the execution of the operator and return ExecutionError. allow_external_dependencies (bool): if True, the operator allows external dependencies to be used in the operator. Namely, not all nodes used to create the output are in the inputs. In this case, the extra dependencies are stored in the info dictionary with key 'extra_dependencies'. overwrite_python_recursion (bool): if True, the operator allows the python recursion behavior of calling the decorated function to be overwritten. When true, applying bundle on a recursive function, would be the same as calling the function directly. When False, the Python's oriignal recursion behavior of decorated functions is preserved. @@ -149,7 +150,9 @@ def trainable(self): @property def fun(self, *args, **kwargs): - # This is called within trace_nodes context manager. + """ Return a callable function. Return the decorated function if the parameter is None. Otherwise, return the function defined by the parameter. When exception happens during the defining the function with the parameter, raise a trace.ExecutionError. """ + + # This function should be later called within trace_nodes context manager. if self.parameter is None: return self._fun else: @@ -198,17 +201,21 @@ def fun(self, *args, **kwargs): def name(self): return get_op_name(self.description) - def forward(self, *args, **kwargs): - """ - All nodes used in the operator fun are added to used_nodes during - the execution. If the output is not a Node, we wrap it as a - MessageNode, whose inputs are nodes in used_nodes. - """ - + def _wrap_inputs(self, fun, args, kwargs): + """ Wrap the inputs to a function as nodes when they're not. - fun = self.fun # define the function only once - self.info['fun'] = fun + Args: + fun (callable): the function to be wrapped. + args (list): the positional arguments of the function. + kwargs (dict): the keyword arguments of the function. + Returns: + inputs (dict): the inputs dict to construct the MessageNode (constructed by args and kwargs). + args (list): the wrapped positional arguments. + kwargs (dict): the wrapped keyword arguments. + _args (list): the original positional arguments (including the default values). + _kwargs (dict): the original keyword arguments (including the default values). + """ ## Wrap the inputs as nodes # add default into kwargs @@ -224,7 +231,6 @@ def forward(self, *args, **kwargs): kwargs[k] = v # convert args and kwargs to nodes, except for FunModule _args, _kwargs = args, kwargs # back up - args = [node(a, name=fullargspec.args[i] if not isinstance(a, Node) else None) if not isinstance(a, FunModule) else a for i, a in enumerate(args)] kwargs = {k: node(v, name=k if not isinstance(v, Node) else None) if not isinstance(v, FunModule) else v for k, v in kwargs.items()} @@ -234,7 +240,7 @@ def forward(self, *args, **kwargs): _, varargs, varkw, _, _, _, _ = inspect.getfullargspec(fun) - # bind the node version of args and kwargs + # bind the node version of args and kwargs ba = inspect.signature(fun).bind(*args, **kwargs) spec = ba.arguments @@ -253,7 +259,10 @@ def extract_param(n): inputs[k] = extract_param(v) assert all([isinstance(n, Node) for n in inputs.values()]), "All values in inputs must be nodes." + return inputs, args, kwargs, _args, _kwargs + def _get_tracer(self): + """ Get a tracer to overwrite the python recursion behavior of calling the decorated function. """ # Define a tracer to deal with recursive function calls _bundled_func = None @@ -282,67 +291,111 @@ def tracer(frame, event, arg = None): if frame.f_code.co_name in frame.f_globals: frame.f_globals[frame.f_code.co_name] = _bundled_func return tracer + return tracer + + def _construct_error_comment(self, e): + """ Construct the error comment on the source code and traceback. """ + self.info['traceback'] = traceback.format_exc() # This is saved for user debugging + # Construct message to optimizer + error_class = e.__class__.__name__ + detail = e.args[0] + cl, exc, tb = sys.exc_info() + assert tb is not None # we're in the except block, so tb should not be None + n_fun_calls = len(traceback.extract_tb(tb)) + # Step through the traceback stack + comments = [] + base_message = f'({error_class}) {detail}.' + for i, (f, ln) in enumerate(traceback.walk_tb(tb)): + if i>0: # ignore the first one, since that is the try statement above + error_message = base_message if i == n_fun_calls-1 else 'Error raised in function call. See below.' + + if i==1 and self.parameter is not None: # this is the trainable function defined by exec, which needs special treatment. inspect.getsource doesn't work here. + comment = self.generate_comment(self.parameter._data, error_message, ln, 1) + comment_backup = self.generate_comment(self.parameter._data, base_message, ln, 1) + else: + try: + f_source, f_source_ln = self.get_source(f, bug_mode=True) + except OSError: # OSError: could not get source code + # we reach the compiled C level, so the previous level is actually the bottom + comments[-1] = comment_backup # replace the previous comment + break # exit the loop + comment = self.generate_comment(f_source, error_message, ln, f_source_ln) + comment_backup = self.generate_comment(f_source, base_message, ln, f_source_ln) + comments.append(comment) + commented_code = '\n\n'.join(comments) + self.info['error_comment'] = commented_code + f"\n{base_message}" + output = e + return output + + def sync_call_fun(self, fun, *_args, **_kwargs): + """ Call the operator fun and return the output. Catch the exception if catch_execution_error is True. """ + oldtracer = sys.gettrace() + if self.overwrite_python_recursion and self.parameter is None: # Overwrite the python recursion behavior + # Running a tracer would slow down the execution, so we only do this when necessary. + sys.settrace(self._get_tracer()) + + if self.catch_execution_error: + try: + output = fun(*_args, **_kwargs) + except Exception as e: + output = self._construct_error_comment(e) + else: + output = fun(*_args, **_kwargs) + sys.settrace(oldtracer) + return output - ## Execute self.fun - with trace_nodes() as used_nodes: - # After exit, used_nodes contains the nodes whose data attribute is read in the operator fun. + async def async_call_fun(self, fun, *_args, **_kwargs): + oldtracer = sys.gettrace() + if self.overwrite_python_recursion and self.parameter is None: # Overwrite the python recursion behavior + # Running a tracer would slow down the execution, so we only do this when necessary. + sys.settrace(self._get_tracer()) + + if self.catch_execution_error: + try: + output = await fun(*_args, **_kwargs) + except Exception as e: + output = self._construct_error_comment(e) + else: + output = await fun(*_args, **_kwargs) - # args, kwargs are nodes - # _args, _kwargs are the original inputs (_kwargs inlcudes the defaults) + sys.settrace(oldtracer) + return output - # Construct the inputs to call self.fun - if self._process_inputs: - if self.traceable_code: - _args, _kwargs = detach_inputs(args), detach_inputs(kwargs) - else: - _args, _kwargs = to_data(args), to_data(kwargs) - # else the inputs are passed directly to the function - # so we don't change _args and _kwargs - - oldtracer = sys.gettrace() - if self.overwrite_python_recursion and self.parameter is None: # Overwrite the python recursion behavior - sys.settrace(tracer) - # add an except here - if self.catch_execution_error: - try: - output = fun(*_args, **_kwargs) - except Exception as e: - # Construct the error comment on the source code and traceback - self.info['traceback'] = traceback.format_exc() # This is saved for user debugging - # Construct message to optimizer - error_class = e.__class__.__name__ - detail = e.args[0] - cl, exc, tb = sys.exc_info() - n_fun_calls = len(traceback.extract_tb(tb)) - # Step through the traceback stack - comments = [] - base_message = f'({error_class}) {detail}.' - for i, (f, ln) in enumerate(traceback.walk_tb(tb)): - if i>0: # ignore the first one, since that is the try statement above - error_message = base_message if i == n_fun_calls-1 else 'Error raised in function call. See below.' - - if i==1 and self.parameter is not None: # this is the trainable function defined by exec, which needs special treatment. inspect.getsource doesn't work here. - comment = self.generate_comment(self.parameter._data, error_message, ln, 1) - comment_backup = self.generate_comment(self.parameter._data, base_message, ln, 1) - else: - try: - f_source, f_source_ln = self.get_source(f, bug_mode=True) - except OSError: # OSError: could not get source code - # we reach the compiled C level, so the previous level is actually the bottom - comments[-1] = comment_backup # replace the previous comment - break # exit the loop - comment = self.generate_comment(f_source, error_message, ln, f_source_ln) - comment_backup = self.generate_comment(f_source, base_message, ln, f_source_ln) - comments.append(comment) - commented_code = '\n\n'.join(comments) - self.info['error_comment'] = commented_code + f"\n{base_message}" - output = e - else: - output = fun(*_args, **_kwargs) - sys.settrace(oldtracer) + def preprocess_inputs(self, args, kwargs, _args, _kwargs): + # NOTE This function must be put inside the used_nodes context manager. + """ Preprocess the inputs for the operator fun. - # logging inputs and output of the function call + Args: + _args (list): the original positional arguments. This includes the default values. + _kwargs (dict): the original keyword arguments. This includes the default values. + args (list): the wrapped positional arguments. + kwargs (dict): the wrapped keyword arguments. + """ + # Construct the inputs to call self.fun + if self._process_inputs: # This is for handling hierarchical graph + if self.traceable_code: + _args, _kwargs = detach_inputs(args), detach_inputs(kwargs) + else: # NOTE Extract data from the nodes and pass them to the function; This line must be put inside the used_nodes context manager. + _args, _kwargs = to_data(args), to_data(kwargs) # read node.data; this ensures the inputs are treated as used nodes + # else the inputs are passed directly to the function + # so we don't change _args and _kwargs + return _args, _kwargs # this will be passed as the input to the function + + def postprocess_output(self, output, fun, _args, _kwargs, used_nodes, inputs): + """ + Wrap the output as a MessageNode. Log the inputs and output of the function call. + + Args: + output (Any): the output of the operator fun. + fun (callable): the operator fun. + _args (list): the original positional arguments. This includes the default values. + _kwargs (dict): the original keyword arguments. This includes the default values. + used_nodes (List[Node]): the nodes used in the operator fun. + inputs (Dict[str, Node]): the inputs of the operator fun. + """ + + # Log inputs and output of the function call self.info["output"] = output self.info['inputs']["args"] = _args self.info['inputs']["kwargs"] = _kwargs @@ -361,15 +414,51 @@ def tracer(frame, event, arg = None): inputs = {} # We don't need to keep track of the inputs if we are not tracing. # Wrap the output as a MessageNode or an ExceptionNode nodes = self.wrap(output, inputs, external_dependencies) + return nodes - # If the output is a corountine, we return a coroutine. - if nodes._data is not None and inspect.iscoroutine(nodes._data): - async def _run_coro(): - nodes._data = await nodes._data - return nodes - return _run_coro() + def forward(self, *args, **kwargs): + fun = self.fun # Define the function (only once) + self.info['fun'] = fun + if inspect.iscoroutinefunction(fun): + return self.async_forward(fun, *args, **kwargs) # Return a coroutine that returns a MessageNode else: - return nodes + return self.sync_forward(fun, *args, **kwargs) # Return a MessageNode + + def sync_forward(self, fun, *args, **kwargs): + """ + Call the operator fun and return a MessageNode. All nodes used in + the operator fun are added to used_nodes during the execution. If + the output is not a Node, we wrap it as a MessageNode, whose inputs + are nodes in used_nodes. Sync version. + """ + # Wrap the inputs as nodes + inputs, args, kwargs, _args, _kwargs = self._wrap_inputs(fun, args, kwargs) + ## Execute fun + with trace_nodes() as used_nodes: + # After exit, used_nodes contains the nodes whose data attribute is read in the operator fun. + _args, _kwargs = self.preprocess_inputs(args, kwargs, _args, _kwargs) + output = self.sync_call_fun(fun, *_args, **_kwargs) + # Wrap the output as a MessageNode or an ExceptionNode + nodes = self.postprocess_output(output, fun, _args, _kwargs, used_nodes, inputs) + return nodes + + async def async_forward(self, fun, *args, **kwargs): + """ + Call the operator fun and return a MessageNode. All nodes used in + the operator fun are added to used_nodes during the execution. If + the output is not a Node, we wrap it as a MessageNode, whose inputs + are nodes in used_nodes. Async version. + """ + # Wrap the inputs as nodes + inputs, args, kwargs, _args, _kwargs = self._wrap_inputs(fun, args, kwargs) + ## Execute fun + with trace_nodes() as used_nodes: + # After exit, used_nodes contains the nodes whose data attribute is read in the operator fun. + _args, _kwargs = self.preprocess_inputs(args, kwargs, _args, _kwargs) + output = await self.async_call_fun(fun, *_args, **_kwargs) # use await to call the async function + # Wrap the output as a MessageNode or an ExceptionNode + nodes = self.postprocess_output(output, fun, _args, _kwargs, used_nodes, inputs) + return nodes def wrap(self, output: Any, inputs: Union[List[Node], Dict[str, Node]], external_dependencies: List[Node]): """Wrap the output as a MessageNode of inputs as the parents.""" diff --git a/tests/unit_tests/test_asyncio.py b/tests/unit_tests/test_asyncio.py index 3d86c2ce..041b2a43 100644 --- a/tests/unit_tests/test_asyncio.py +++ b/tests/unit_tests/test_asyncio.py @@ -9,6 +9,7 @@ async def basic(a=0): return 'basic' async def main(): + # single task a = trace.node('a') st = time.time() x = await basic(a) @@ -25,6 +26,7 @@ async def main(): async def main2(): + # multiple tasks a = trace.node('a') st = time.time() x, y, z = await asyncio.gather(basic(a), basic(a), basic(a)) # run in parallel @@ -46,3 +48,53 @@ async def main2(): asyncio.run(main2()) + + +@trace.bundle() +async def error(a=0): + raise ValueError('error') + +async def main3(): + # error handling + a = trace.node('a') + st = time.time() + try: + x = await error(a) + except trace.ExecutionError as e: + print(e) + x = e + ed = time.time() + print("Time taken: ", ed - st) + print(type(x), 'developer message:', x) + assert isinstance(x, trace.ExecutionError) + x = x.exception_node + print(type(x), 'optimizer message:', x.data) + assert isinstance(x, trace.nodes.MessageNode) + assert a in x.parents + assert len(x.parents) == 1 + +asyncio.run(main3()) + +async def main4(): + # multiple error handling + a = trace.node('a') + b = trace.node('b') + c = trace.node('c') + st = time.time() + try: + x, y, z = await asyncio.gather(error(a), error(b), error(c)) # run in parallel + except trace.ExecutionError as e: + # print(e) + x = e # This will catch the first error + print(e.exception_node.parents) + ed = time.time() + print("Time taken: ", ed - st) + print(type(x), 'developer message:', x) + assert isinstance(x, trace.ExecutionError) + x = x.exception_node + print(type(x), 'optimizer message:', x.data) + assert isinstance(x, trace.nodes.MessageNode) + assert a in x.parents + assert len(x.parents) == 1 + +asyncio.run(main4()) \ No newline at end of file From 611ac7941745ddafbceaf0079c83b99aeac2e90a Mon Sep 17 00:00:00 2001 From: chinganc Date: Tue, 19 Nov 2024 11:08:13 -0800 Subject: [PATCH 2/7] Update version number. --- opto/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/opto/version.py b/opto/version.py index b9bf8369..acf3be3e 100644 --- a/opto/version.py +++ b/opto/version.py @@ -1 +1 @@ -__version__ = "0.1.2.2" \ No newline at end of file +__version__ = "0.1.3" \ No newline at end of file From 3dcd504a33db94d5f9f89707f0ff8bd8fe26fb33 Mon Sep 17 00:00:00 2001 From: chinganc Date: Tue, 19 Nov 2024 11:35:10 -0800 Subject: [PATCH 3/7] Add missing __init__.py --- opto/utils/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 opto/utils/__init__.py diff --git a/opto/utils/__init__.py b/opto/utils/__init__.py new file mode 100644 index 00000000..e69de29b From fbb3d40f46ba2699efed315763efd5fe5c484a1a Mon Sep 17 00:00:00 2001 From: chinganc Date: Tue, 19 Nov 2024 11:39:03 -0800 Subject: [PATCH 4/7] Update test_llm to omit the test when oai config is missing. --- tests/unit_tests/test_llm.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/tests/unit_tests/test_llm.py b/tests/unit_tests/test_llm.py index 37c0fe0b..3e4843de 100644 --- a/tests/unit_tests/test_llm.py +++ b/tests/unit_tests/test_llm.py @@ -1,21 +1,25 @@ from opto.utils.llm import AutoGenLLM from opto.optimizers.utils import print_color -llm = AutoGenLLM() -system_prompt = 'You are a helpful assistant.' -user_prompt = "Hello world." +try: + llm = AutoGenLLM() + system_prompt = 'You are a helpful assistant.' + user_prompt = "Hello world." -messages = [{"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}] + messages = [{"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}] -output = llm(messages=messages) -# Alternatively, you can use the following code: -# output = llm.create(messages=messages) + output = llm(messages=messages) + # Alternatively, you can use the following code: + # output = llm.create(messages=messages) -response = output.choices[0].message.content + response = output.choices[0].message.content -print_color(f'System: {system_prompt}', 'red') -print_color(f'User: {user_prompt}', 'blue') -print_color(f'LLM: {response}', 'green') + print_color(f'System: {system_prompt}', 'red') + print_color(f'User: {user_prompt}', 'blue') + print_color(f'LLM: {response}', 'green') +except FileNotFoundError as e: + print_color(f'Error: {e}', 'red') + print_colorf('Omit the test.', 'yellow') \ No newline at end of file From 418a6cdc030909460dfc0a610f2ae580d4fabc30 Mon Sep 17 00:00:00 2001 From: chinganc Date: Tue, 19 Nov 2024 11:40:59 -0800 Subject: [PATCH 5/7] Fix a typo. --- tests/unit_tests/test_llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/test_llm.py b/tests/unit_tests/test_llm.py index 3e4843de..31f33ee3 100644 --- a/tests/unit_tests/test_llm.py +++ b/tests/unit_tests/test_llm.py @@ -22,4 +22,4 @@ print_color(f'LLM: {response}', 'green') except FileNotFoundError as e: print_color(f'Error: {e}', 'red') - print_colorf('Omit the test.', 'yellow') \ No newline at end of file + print_color('Omit the test.', 'yellow') \ No newline at end of file From a0382d8bdc35592d80c5b5e91da95b57105a9f8a Mon Sep 17 00:00:00 2001 From: chinganc Date: Tue, 19 Nov 2024 11:43:12 -0800 Subject: [PATCH 6/7] Update test_copy to omit testing when oai config is not found. --- tests/unit_tests/test_copy.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/test_copy.py b/tests/unit_tests/test_copy.py index cadcbf74..e3453498 100644 --- a/tests/unit_tests/test_copy.py +++ b/tests/unit_tests/test_copy.py @@ -21,5 +21,9 @@ def fun(x): optimizer = OptoPrime([x]) optimizer2 = copy.deepcopy(optimizer) -llm = AutoGenLLM() -copy.deepcopy(llm) \ No newline at end of file +try: + llm = AutoGenLLM() + copy.deepcopy(llm) +except FileNotFoundError as e: + print(f'Error: {e}') + print('Omit the test.') \ No newline at end of file From 58103bbf14d453739327c5bf412aa0993d913ea3 Mon Sep 17 00:00:00 2001 From: chinganc Date: Tue, 19 Nov 2024 11:45:39 -0800 Subject: [PATCH 7/7] Update test_copy to omit testing when oai config is not found. --- tests/unit_tests/test_copy.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/test_copy.py b/tests/unit_tests/test_copy.py index e3453498..d9b7dcc2 100644 --- a/tests/unit_tests/test_copy.py +++ b/tests/unit_tests/test_copy.py @@ -18,10 +18,12 @@ def fun(x): x = trace.node('x', trainable=True) copy.deepcopy(x) -optimizer = OptoPrime([x]) -optimizer2 = copy.deepcopy(optimizer) + try: + optimizer = OptoPrime([x]) + optimizer2 = copy.deepcopy(optimizer) + llm = AutoGenLLM() copy.deepcopy(llm) except FileNotFoundError as e: