diff --git a/docs/_toc.yml b/docs/_toc.yml index 67144cd0..5b9ea4c2 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -12,6 +12,11 @@ parts: - file: quickstart/quick_start_2 - file: quickstart/virtualhome + - caption: FAQ + numbered: false + chapters: + - file: faq/faq + - caption: πŸ“šTutorials chapters: - file: tutorials/basic_tutorial diff --git a/docs/faq/faq.md b/docs/faq/faq.md new file mode 100644 index 00000000..07a22078 --- /dev/null +++ b/docs/faq/faq.md @@ -0,0 +1,47 @@ +# FAQ + +### Difference to Libraries like TextGrad + +TextGrad is both a library and an optimizer algorithm. Currently, we support three optimizers: + +- OPRO: [Large Language Models as Optimizers](https://arxiv.org/abs/2309.03409) +- TextGrad: [TextGrad: Automatic "Differentiation" via Text](https://arxiv.org/abs/2406.07496) +- OptoPrime: [Our proposed algorithm](https://arxiv.org/abs/2406.16218) -- using the entire computational graph to perform parameter update. It is 2-3x + faster than TextGrad. + +Using our framework, you can seamlessly switch between different optimizers: + +```python +optimizer1 = OptoPrime(strange_sort_list.parameters()) +optimizer2 = OPRO(strange_sort_list.parameters()) +optimizer3 = TextGrad(strange_sort_list.parameters()) +``` + +Here is a summary of the optimizers: + +| | Computation Graph | Code as Functions | Library Support | Supported Optimizers | Speed | Large Graph | +|-----------------------------------|-------------------|-------------------|------------------|---------------------------|-------------|-------------| +| OPRO | ❌ | ❌ | ❌ | OPRO | ⚑️ | βœ… | +| TextGrad | βœ… | ❌ | βœ… | TextGrad | 🐌 | βœ… | +| Trace | βœ… | βœ… | βœ… | OPRO, OptoPrime, TextGrad | ⚑ | βœ… | + +The table evaluates the frameworks in the following aspects: + +- Computation Graph: Whether the optimizer leverages the computation graph of the workflow. +- Code as Functions: Whether the framework allows users to write actual executable Python functions and not require + users to wrap them in strings. +- Library Support: Whether the framework has a library to support the optimizer. +- Speed: TextGrad is about 2-3x slower than OptoPrime (Trace). OPRO has no concept of computational graph, therefore is very fast. +- Large Graph: OptoPrime (Trace) represents the entire computation graph in context, therefore, might have issue with graphs that have more than hundreds of operations. TextGrad does not have the context-length issue, however, might be very slow on large graphs. + +We provide a comparison to validate our implementation of TextGrad in Trace: + +

+ drawing +

+ +To produce this table, we ran the TextGrad pip-installed repo on 2024-10-30, and we also include the numbers reported in the TextGrad paper. +The LLM APIs are called around the same time to ensure a fair comparison. TextGrad paper's result was reported in 2024-06. + +### Difference to Libraries like AutoGen, AG2, OpenAI Swarm, Llama Stack + diff --git a/opto/optimizers/optoprime.py b/opto/optimizers/optoprime.py index 351e590b..faa73695 100644 --- a/opto/optimizers/optoprime.py +++ b/opto/optimizers/optoprime.py @@ -1,18 +1,18 @@ + from typing import Any, List, Dict, Union, Tuple from dataclasses import dataclass, asdict -from opto.trace.nodes import ParameterNode, Node, MessageNode -from opto.optimizers.optimizer import Optimizer - -from opto.trace.propagators import TraceGraph, GraphPropagator from textwrap import dedent, indent -from opto.trace.propagators.propagators import Propagator -from opto.optimizers.buffers import FIFOBuffer import autogen import warnings import json - import re import copy +from opto.trace.nodes import ParameterNode, Node, MessageNode +from opto.trace.propagators import TraceGraph, GraphPropagator +from opto.trace.propagators.propagators import Propagator +from opto.optimizers.optimizer import Optimizer +from opto.optimizers.buffers import FIFOBuffer +from opto.utils.llm import AutoGenLLM def get_fun_name(node: MessageNode): @@ -252,7 +252,7 @@ class OptoPrime(Optimizer): def __init__( self, parameters: List[ParameterNode], - config_list: List = None, # autogen config_dict + LLM: AutoGenLLM = None, *args, propagator: Propagator = None, objective: Union[None, str] = None, @@ -267,11 +267,7 @@ def __init__( ): super().__init__(parameters, *args, propagator=propagator, **kwargs) self.ignore_extraction_error = ignore_extraction_error - if config_list is None: - config_list = autogen.config_list_from_json("OAI_CONFIG_LIST") - if filter_dict is not None: - config_list = autogen.filter_config_list(config_list, filter_dict) - self.llm = autogen.OpenAIWrapper(config_list=config_list) + self.llm = LLM or AutoGenLLM() self.objective = objective or self.default_objective self.example_problem = ProblemInstance.problem_template.format( instruction=self.default_objective, @@ -510,13 +506,13 @@ def call_llm( messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}] try: # Try tp force it to be a json object - response = self.llm.create( + response = self.llm( messages=messages, response_format={"type": "json_object"}, max_tokens=max_tokens, ) except Exception: - response = self.llm.create(messages=messages, max_tokens=max_tokens) + response = self.llm(messages=messages, max_tokens=max_tokens) response = response.choices[0].message.content if verbose: diff --git a/opto/optimizers/utils.py b/opto/optimizers/utils.py new file mode 100644 index 00000000..c81295f4 --- /dev/null +++ b/opto/optimizers/utils.py @@ -0,0 +1,13 @@ +def print_color(message, color=None, logger=None): + colors = { + 'red': '\033[91m', + 'green': '\033[92m', + 'yellow': '\033[93m', + 'blue': '\033[94m', + 'magenta': '\033[95m', + 'cyan': '\033[96m' + } + print(f"{colors.get(color, '')}{message}\033[0m") # Default to no color if invalid color is provided + + if logger is not None: + logger.log(message) \ No newline at end of file diff --git a/opto/trace/bundle.py b/opto/trace/bundle.py index 076e3597..8c1115fa 100644 --- a/opto/trace/bundle.py +++ b/opto/trace/bundle.py @@ -5,6 +5,8 @@ import re import sys import traceback +import asyncio + from typing import List, Dict, Callable, Union, Any from opto.trace.broadcast import recursive_conversion @@ -24,9 +26,22 @@ def bundle( allow_external_dependencies=False, overwrite_python_recursion=False, ): - """ - Wrap a function as a FunModule, which returns node objects. - The input signature to the wrapped function stays the same. bundle can be used with other decorators so long as they are not named 'bundle'. + """Wrap a function as a FunModule which returns node objects. + + The input signature to the wrapped function stays the same. bundle can be used with other decorators + so long as they are not named 'bundle'. + + Args: + description (str, optional): Description of the operator. Defaults to None. + traceable_code (bool, optional): Whether the operator's code is traceable by Trace. Defaults to False. + _process_inputs (bool, optional): Whether to extract input from container of nodes. Defaults to True. + trainable (bool, optional): Whether block of code is treated as variable in optimization. Defaults to False. + catch_execution_error (bool, optional): Whether to catch exceptions during operator execution. Defaults to True. + allow_external_dependencies (bool, optional): Whether to allow external dependencies. Defaults to False. + overwrite_python_recursion (bool, optional): Whether to overwrite Python recursion behavior. Defaults to False. + + Returns: + FunModule: The wrapped function that returns node objects. """ prev_f_locals = inspect.stack()[1].frame.f_locals def decorator(fun): @@ -65,6 +80,7 @@ class FunModule(Module): description (str): a description of the operator; see the MessageNode for syntax. _process_inputs (bool): if True, the input is extracted from the container of nodes; if False, the inputs are passed directly to the underlying function. trainable (bool): if True, the block of code is treated as a variable in the optimization + traceable_code (bool): if True, the operator's code is traceable by Trace catch_execution_error (bool): if True, the operator catches the exception raised during the execution of the operator and return ExecutionError. allow_external_dependencies (bool): if True, the operator allows external dependencies to be used in the operator. Namely, not all nodes used to create the output are in the inputs. In this case, the extra dependencies are stored in the info dictionary with key 'extra_dependencies'. overwrite_python_recursion (bool): if True, the operator allows the python recursion behavior of calling the decorated function to be overwritten. When true, applying bundle on a recursive function, would be the same as calling the function directly. When False, the Python's oriignal recursion behavior of decorated functions is preserved. @@ -147,7 +163,9 @@ def trainable(self): @property def fun(self, *args, **kwargs): - # This is called within trace_nodes context manager. + """ Return a callable function. Return the decorated function if the parameter is None. Otherwise, return the function defined by the parameter. When exception happens during the defining the function with the parameter, raise a trace.ExecutionError. """ + + # This function should be later called within trace_nodes context manager. if self.parameter is None: return self._fun else: @@ -196,17 +214,21 @@ def fun(self, *args, **kwargs): def name(self): return get_op_name(self.description) - def forward(self, *args, **kwargs): - """ - All nodes used in the operator fun are added to used_nodes during - the execution. If the output is not a Node, we wrap it as a - MessageNode, whose inputs are nodes in used_nodes. - """ - + def _wrap_inputs(self, fun, args, kwargs): + """ Wrap the inputs to a function as nodes when they're not. - fun = self.fun # define the function only once - self.info['fun'] = fun + Args: + fun (callable): the function to be wrapped. + args (list): the positional arguments of the function. + kwargs (dict): the keyword arguments of the function. + Returns: + inputs (dict): the inputs dict to construct the MessageNode (constructed by args and kwargs). + args (list): the wrapped positional arguments. + kwargs (dict): the wrapped keyword arguments. + _args (list): the original positional arguments (including the default values). + _kwargs (dict): the original keyword arguments (including the default values). + """ ## Wrap the inputs as nodes # add default into kwargs @@ -232,7 +254,7 @@ def forward(self, *args, **kwargs): _, varargs, varkw, _, _, _, _ = inspect.getfullargspec(fun) - # bind the node version of args and kwargs + # bind the node version of args and kwargs ba = inspect.signature(fun).bind(*args, **kwargs) spec = ba.arguments @@ -251,7 +273,10 @@ def extract_param(n): inputs[k] = extract_param(v) assert all([isinstance(n, Node) for n in inputs.values()]), "All values in inputs must be nodes." + return inputs, args, kwargs, _args, _kwargs + def _get_tracer(self): + """ Get a tracer to overwrite the python recursion behavior of calling the decorated function. """ # Define a tracer to deal with recursive function calls _bundled_func = None @@ -280,67 +305,111 @@ def tracer(frame, event, arg = None): if frame.f_code.co_name in frame.f_globals: frame.f_globals[frame.f_code.co_name] = _bundled_func return tracer + return tracer + + def _construct_error_comment(self, e): + """ Construct the error comment on the source code and traceback. """ + self.info['traceback'] = traceback.format_exc() # This is saved for user debugging + # Construct message to optimizer + error_class = e.__class__.__name__ + detail = e.args[0] + cl, exc, tb = sys.exc_info() + assert tb is not None # we're in the except block, so tb should not be None + n_fun_calls = len(traceback.extract_tb(tb)) + # Step through the traceback stack + comments = [] + base_message = f'({error_class}) {detail}.' + for i, (f, ln) in enumerate(traceback.walk_tb(tb)): + if i>0: # ignore the first one, since that is the try statement above + error_message = base_message if i == n_fun_calls-1 else 'Error raised in function call. See below.' + + if i==1 and self.parameter is not None: # this is the trainable function defined by exec, which needs special treatment. inspect.getsource doesn't work here. + comment = self.generate_comment(self.parameter._data, error_message, ln, 1) + comment_backup = self.generate_comment(self.parameter._data, base_message, ln, 1) + else: + try: + f_source, f_source_ln = self.get_source(f, bug_mode=True) + except OSError: # OSError: could not get source code + # we reach the compiled C level, so the previous level is actually the bottom + comments[-1] = comment_backup # replace the previous comment + break # exit the loop + comment = self.generate_comment(f_source, error_message, ln, f_source_ln) + comment_backup = self.generate_comment(f_source, base_message, ln, f_source_ln) + comments.append(comment) + commented_code = '\n\n'.join(comments) + self.info['error_comment'] = commented_code + f"\n{base_message}" + output = e + return output + + def sync_call_fun(self, fun, *_args, **_kwargs): + """ Call the operator fun and return the output. Catch the exception if catch_execution_error is True. """ + oldtracer = sys.gettrace() + if self.overwrite_python_recursion and self.parameter is None: # Overwrite the python recursion behavior + # Running a tracer would slow down the execution, so we only do this when necessary. + sys.settrace(self._get_tracer()) + + if self.catch_execution_error: + try: + output = fun(*_args, **_kwargs) + except Exception as e: + output = self._construct_error_comment(e) + else: + output = fun(*_args, **_kwargs) + sys.settrace(oldtracer) + return output - ## Execute self.fun - with trace_nodes() as used_nodes: - # After exit, used_nodes contains the nodes whose data attribute is read in the operator fun. + async def async_call_fun(self, fun, *_args, **_kwargs): + oldtracer = sys.gettrace() + if self.overwrite_python_recursion and self.parameter is None: # Overwrite the python recursion behavior + # Running a tracer would slow down the execution, so we only do this when necessary. + sys.settrace(self._get_tracer()) - # args, kwargs are nodes - # _args, _kwargs are the original inputs (_kwargs inlcudes the defaults) + if self.catch_execution_error: + try: + output = await fun(*_args, **_kwargs) + except Exception as e: + output = self._construct_error_comment(e) + else: + output = await fun(*_args, **_kwargs) - # Construct the inputs to call self.fun - if self._process_inputs: - if self.traceable_code: - _args, _kwargs = detach_inputs(args), detach_inputs(kwargs) - else: - _args, _kwargs = to_data(args), to_data(kwargs) - # else the inputs are passed directly to the function - # so we don't change _args and _kwargs - - oldtracer = sys.gettrace() - if self.overwrite_python_recursion and self.parameter is None: # Overwrite the python recursion behavior - sys.settrace(tracer) - # add an except here - if self.catch_execution_error: - try: - output = fun(*_args, **_kwargs) - except Exception as e: - # Construct the error comment on the source code and traceback - self.info['traceback'] = traceback.format_exc() # This is saved for user debugging - # Construct message to optimizer - error_class = e.__class__.__name__ - detail = e.args[0] - cl, exc, tb = sys.exc_info() - n_fun_calls = len(traceback.extract_tb(tb)) - # Step through the traceback stack - comments = [] - base_message = f'({error_class}) {detail}.' - for i, (f, ln) in enumerate(traceback.walk_tb(tb)): - if i>0: # ignore the first one, since that is the try statement above - error_message = base_message if i == n_fun_calls-1 else 'Error raised in function call. See below.' - - if i==1 and self.parameter is not None: # this is the trainable function defined by exec, which needs special treatment. inspect.getsource doesn't work here. - comment = self.generate_comment(self.parameter._data, error_message, ln, 1) - comment_backup = self.generate_comment(self.parameter._data, base_message, ln, 1) - else: - try: - f_source, f_source_ln = self.get_source(f, bug_mode=True) - except OSError: # OSError: could not get source code - # we reach the compiled C level, so the previous level is actually the bottom - comments[-1] = comment_backup # replace the previous comment - break # exit the loop - comment = self.generate_comment(f_source, error_message, ln, f_source_ln) - comment_backup = self.generate_comment(f_source, base_message, ln, f_source_ln) - comments.append(comment) - commented_code = '\n\n'.join(comments) - self.info['error_comment'] = commented_code + f"\n{base_message}" - output = e - else: - output = fun(*_args, **_kwargs) - sys.settrace(oldtracer) + sys.settrace(oldtracer) + return output + + def preprocess_inputs(self, args, kwargs, _args, _kwargs): + # NOTE This function must be put inside the used_nodes context manager. + """ Preprocess the inputs for the operator fun. - # logging inputs and output of the function call + Args: + _args (list): the original positional arguments. This includes the default values. + _kwargs (dict): the original keyword arguments. This includes the default values. + args (list): the wrapped positional arguments. + kwargs (dict): the wrapped keyword arguments. + """ + # Construct the inputs to call self.fun + if self._process_inputs: # This is for handling hierarchical graph + if self.traceable_code: + _args, _kwargs = detach_inputs(args), detach_inputs(kwargs) + else: # NOTE Extract data from the nodes and pass them to the function; This line must be put inside the used_nodes context manager. + _args, _kwargs = to_data(args), to_data(kwargs) # read node.data; this ensures the inputs are treated as used nodes + # else the inputs are passed directly to the function + # so we don't change _args and _kwargs + return _args, _kwargs # this will be passed as the input to the function + + def postprocess_output(self, output, fun, _args, _kwargs, used_nodes, inputs): + """ + Wrap the output as a MessageNode. Log the inputs and output of the function call. + + Args: + output (Any): the output of the operator fun. + fun (callable): the operator fun. + _args (list): the original positional arguments. This includes the default values. + _kwargs (dict): the original keyword arguments. This includes the default values. + used_nodes (List[Node]): the nodes used in the operator fun. + inputs (Dict[str, Node]): the inputs of the operator fun. + """ + + # Log inputs and output of the function call self.info["output"] = output self.info['inputs']["args"] = _args self.info['inputs']["kwargs"] = _kwargs @@ -361,6 +430,50 @@ def tracer(frame, event, arg = None): nodes = self.wrap(output, inputs, external_dependencies) return nodes + def forward(self, *args, **kwargs): + fun = self.fun # Define the function (only once) + self.info['fun'] = fun + if inspect.iscoroutinefunction(fun): + return self.async_forward(fun, *args, **kwargs) # Return a coroutine that returns a MessageNode + else: + return self.sync_forward(fun, *args, **kwargs) # Return a MessageNode + + def sync_forward(self, fun, *args, **kwargs): + """ + Call the operator fun and return a MessageNode. All nodes used in + the operator fun are added to used_nodes during the execution. If + the output is not a Node, we wrap it as a MessageNode, whose inputs + are nodes in used_nodes. Sync version. + """ + # Wrap the inputs as nodes + inputs, args, kwargs, _args, _kwargs = self._wrap_inputs(fun, args, kwargs) + ## Execute fun + with trace_nodes() as used_nodes: + # After exit, used_nodes contains the nodes whose data attribute is read in the operator fun. + _args, _kwargs = self.preprocess_inputs(args, kwargs, _args, _kwargs) + output = self.sync_call_fun(fun, *_args, **_kwargs) + # Wrap the output as a MessageNode or an ExceptionNode + nodes = self.postprocess_output(output, fun, _args, _kwargs, used_nodes, inputs) + return nodes + + async def async_forward(self, fun, *args, **kwargs): + """ + Call the operator fun and return a MessageNode. All nodes used in + the operator fun are added to used_nodes during the execution. If + the output is not a Node, we wrap it as a MessageNode, whose inputs + are nodes in used_nodes. Async version. + """ + # Wrap the inputs as nodes + inputs, args, kwargs, _args, _kwargs = self._wrap_inputs(fun, args, kwargs) + ## Execute fun + with trace_nodes() as used_nodes: + # After exit, used_nodes contains the nodes whose data attribute is read in the operator fun. + _args, _kwargs = self.preprocess_inputs(args, kwargs, _args, _kwargs) + output = await self.async_call_fun(fun, *_args, **_kwargs) # use await to call the async function + # Wrap the output as a MessageNode or an ExceptionNode + nodes = self.postprocess_output(output, fun, _args, _kwargs, used_nodes, inputs) + return nodes + def wrap(self, output: Any, inputs: Union[List[Node], Dict[str, Node]], external_dependencies: List[Node]): """Wrap the output as a MessageNode of inputs as the parents.""" # Some nodes are used in the operator fun, we need to wrap the output as a MessageNode. @@ -409,27 +522,22 @@ def generate_comment(self, code: str, comment: str, comment_line_number: int, ba def get_source(self, obj: Any, bug_mode=False): """ Get the source code of the function and its line number, excluding the @bundle decorator line. - - Allowable two types of usages: - - Decorator style: - - @blah - ... - @bundle # or @ ....bundle() - ... - def fun(...): # ... - .... - - - or inline usage - - bundle()(fun) # or ....bundle()(fun) - bug_mode=True means We are in the forward() function, but there is an error during execution. The error can be caused by a lambda function which does not have `def` in the source code. We turn off the error raising in the end of this function. + + Allowable two types of usages: + + Examples: + + >>> @blah + >>> ... + >>> @bundle # or @ ....bundle() + >>> def fun(...): # ... + >>> .... + or inline usage + >>> bundle()(fun) # or ....bundle()(fun) """ source = inspect.getsource(obj) # the source includes @bundle, or @trace.bundle, etc. we will remove those parts. line_number = int(inspect.getsourcelines(obj)[1]) # line number of obj diff --git a/opto/trace/nodes.py b/opto/trace/nodes.py index b37749a8..4e522936 100644 --- a/opto/trace/nodes.py +++ b/opto/trace/nodes.py @@ -7,39 +7,27 @@ import heapq - def node(data, name=None, trainable=False, description=None, constraint=None): - """Create a Node object from data. If the data is already a Node, it will be returned as is. - This function is provided for the convenience of the user and should be used instead of directly invoking the Node class. - - Parameters - ---------- - data: The data to create the Node from. - - name: (optional) The name of the Node. - - trainable: (optional) A boolean indicating whether the Node is trainable or not. Default is False. - - description: (optional) A string describing the data. - - constraint: (optional) A string describing any constraint that the data should obey. - - Code Description - ---------- - The node function allows users to create Node objects from data. - The function first checks if the trainable parameter is True. - If it is, it checks if the data is already a Node. - If it is, it extracts the underlying data and updates the name if a new name is provided. - It then creates a ParameterNode object with the extracted data, name, trainable set to True, and the provided constraint. - If the message is not already a Node, it creates a new ParameterNode object with the message as the data, - the provided name, trainable set to True, and the provided constraint. - - If the trainable parameter is False, the function checks if the message is already a Node. - If it is, it checks if a name is provided. - If a name is provided, it issues a warning that the name is ignored because the message is already a Node. - It then returns the message as is. - If the message is not already a Node, it creates a new Node object with the message as the data, - the provided name, and the provided constraint. + """Create a Node object from data. + + Args: + data: The data to create the Node from. + name (str, optional): The name of the Node. + trainable (bool, optional): Whether the Node is trainable. Defaults to False. + description (str, optional): A string describing the data. + constraint (str, optional): A string describing any constraint that the data should obey. + + Returns: + Node: A Node object containing the data. + + Notes: + If trainable=True: + - If data is already a Node, extracts underlying data and updates name + - Creates ParameterNode with extracted data, name, trainable=True and constraint + + If trainable=False: + - If data is already a Node, returns it (with warning if name provided) + - Otherwise creates new Node with data, name and constraint """ assert type(description) is str or description is None @@ -63,46 +51,38 @@ def node(data, name=None, trainable=False, description=None, constraint=None): class Graph: """Graph is a registry of all the nodes, forming a Directed Acyclic Graph (DAG). - - Attributes - ---------- - TRACE: A class-level boolean attribute that determines whether the graph is traced when creating MessageNode. Default is True. - - _nodes: An instance-level attribute, which is a defaultdict of lists, used as a lookup table to find nodes by name. - - Code Description - ---------- - The Graph class manages and organizes nodes in a Directed Acyclic Graph (DAG). - It provides methods to register nodes, clear the graph, retrieve nodes by name, and identify root nodes. - - Note - ---------- - The `register` method assumes that elements in `_nodes` are never removed, - which is important for maintaining the integrity of node names. + + Attributes: + TRACE (bool): A class-level boolean attribute that determines whether the graph is traced when creating MessageNode. Default is True. + _nodes (defaultdict): An instance-level attribute, which is a defaultdict of lists, used as a lookup table to find nodes by name. + + Notes: + The Graph class manages and organizes nodes in a Directed Acyclic Graph (DAG). + It provides methods to register nodes, clear the graph, retrieve nodes by name, and identify root nodes. + The `register` method assumes that elements in `_nodes` are never removed, + which is important for maintaining the integrity of node names. """ TRACE = True # When True, we trace the graph when creating MessageNode. When False, we don't trace the graph. def __init__(self): - """Initialize the Graph object, setting up the `_nodes` attribute as a defaultdict of lists to store nodes by their names. + """Initialize the Graph object. + + The initialization sets up the `_nodes` attribute as a defaultdict of lists to store nodes by their names. """ self._nodes = defaultdict(list) # a lookup table to find nodes by name def clear(self): - """Remove all nodes from the graph by deleting each node and reinitializing the `_nodes` attribute. + """Remove all nodes from the graph. - Code Description - ---------- - The clear function iterates over the current nodes stored in the _nodes attribute and deletes each node. - After all nodes have been deleted, it reinitializes the _nodes attribute to an empty defaultdict of lists. + The clear function iterates over the current nodes stored in the _nodes attribute and deletes each node. + After all nodes have been deleted, it reinitializes the _nodes attribute to an empty defaultdict of lists. This ensures that the graph is completely cleared and ready to be repopulated with new nodes if necessary. - The function is called in unit tests to reset the state of the graph between test cases, - ensuring that each test runs with a clean slate and is not affected by the state left by previous tests. - - Note - ---------- - After calling clear, any references to the previously stored nodes will become invalid. + Notes: + After calling clear, any references to the previously stored nodes will become invalid. + The function is called in unit tests to reset the state of the graph between test cases, + ensuring that each test runs with a clean slate and is not affected by the state left by previous tests. """ for node in self._nodes.values(): del node @@ -110,23 +90,17 @@ def clear(self): # self._levels = defaultdict(list) def register(self, node): - """Add a node to the graph, ensuring that the node is an instance of the Node class and that its name follows the expected format (containing a colon). - This method also handles name scoping and assigns a unique name to the node based on its position in the DAG. - - Parameters - ---------- - node: The node object to be registered in the graph. + """Add a node to the graph. - Code Description - ---------- - After checking that the input is a `Node` and its name has the right format, the function splits the name of the node into the `name` variable and the identifier. - The function then checks if there are any name scopes defined in the `NAME_SCOPES` list. If the length of the list is greater than 0, the name is prefixed with the last scope in the list followed by a "/". This allows for scoping of node names. - Finally, the function adds the node to the `_nodes` dictionary using the modified name as the key. The `_name` attribute of the node is set to the modified name followed by the index of the node in the list of nodes with the same name. + Args: + node: The node object to be registered in the graph. - Note - ---------- - The `register` function should only be called after the node has been properly initialized and its name has been set. - The function assumes that elements in the `_nodes` dictionary never get removed. + Notes: + The register function should only be called after the node has been properly initialized and its name has been set. + The function assumes that elements in the `_nodes` dictionary never get removed. + After checking that the input is a `Node` and its name has the right format, the function splits the name of the node into the `name` variable and the identifier. + The function then checks if there are any name scopes defined in the `NAME_SCOPES` list. If the length of the list is greater than 0, the name is prefixed with the last scope in the list followed by a "/". This allows for scoping of node names. + Finally, the function adds the node to the `_nodes` dictionary using the modified name as the key. The `_name` attribute of the node is set to the modified name followed by the index of the node in the list of nodes with the same name. """ assert isinstance(node, Node) assert len(node.name.split(":")) == 2 @@ -135,45 +109,50 @@ def register(self, node): name = NAME_SCOPES[-1] + "/" + name self._nodes[name].append(node) node._name = ( - name + ":" + str(len(self._nodes[name]) - 1) + name + ":" + str(len(self._nodes[name]) - 1) ) # NOTE assume elements in self._nodes never get removed. # self._levels[node._level].append(node) def get(self, name): - """The `get` method retrieves a node from the graph by its name, which includes an identifier. - - Parameters - ---------- - name: A string in the format "name:id", where "name" is the name of the node and "id" is the identifier of the node. - - Code Description - ---------- - The get function is designed to extract and return a specific node from the graph. The input parameter 'name' is expected to be a string formatted as "name:id". - The function first splits this string into two parts: 'name' and 'id', using the colon (":") as the delimiter. - The 'name' part represents the name of the node, and the 'id' part represents the identifier of the node, which is then converted to an integer. - The function then accesses the '_nodes' dictionary attribute of the graph object, using the 'name' as the key to retrieve the list of nodes associated with that name. - Finally, it returns the node at the position specified by the integer 'id' within that list. - - Note - ---------- - Ensure that the 'name' parameter is correctly formatted as "name:id" before calling this function. - The function assumes that the '_nodes' attribute is a dictionary where each key is a node name and the corresponding value is a list of nodes. - The 'id' should be a valid index within the list of nodes for the given 'name'. + """Retrieve a node from the graph by its name. + + Args: + name (str): A string in the format "name:id", where "name" is the name of the node and "id" is the identifier of the node. + + Returns: + Node: The requested node from the graph. + + Notes: + Ensure that the 'name' parameter is correctly formatted as "name:id" before calling this function. + The function assumes that the '_nodes' attribute is a dictionary where each key is a node name and the corresponding value is a list of nodes. + The 'id' should be a valid index within the list of nodes for the given 'name'. """ name, id = name.split(":") return self._nodes[name][int(id)] @property def roots(self): - """The `roots` property returns a list of all root nodes in the graph. A root node is identified by its `is_root` attribute.""" + """Get all root nodes in the graph. + + Returns: + list: A list of all root nodes in the graph. A root node is identified by its `is_root` attribute. + """ return [v for vv in self._nodes.values() for v in vv if v.is_root] def __str__(self): - """The `__str__` method provides a string representation of the `_nodes` attribute, useful for debugging and logging.""" + """Get string representation of the graph. + + Returns: + str: String representation of the `_nodes` attribute, useful for debugging and logging. + """ return str(self._nodes) def __len__(self): - """The `__len__` method returns the total number of nodes in the graph by summing the lengths of all lists in the `_nodes` dictionary.""" + """Get total number of nodes in the graph. + + Returns: + int: The total number of nodes in the graph by summing the lengths of all lists in the `_nodes` dictionary. + """ # This is the number of nodes in the graph return sum([len(v) for v in self._nodes.values()]) @@ -185,82 +164,75 @@ def __len__(self): T = TypeVar("T") """Graph is a registry of all the nodes, forming a Directed Acyclic Graph (DAG). - - Attributes - ---------- - TRACE: A class-level boolean attribute that determines whether the graph is traced when creating MessageNode. Default is True. - - _nodes: An instance-level attribute, which is a defaultdict of lists, used as a lookup table to find nodes by name. - - Code Description - ---------- - The Graph class manages and organizes nodes in a Directed Acyclic Graph (DAG). - It provides methods to register nodes, clear the graph, retrieve nodes by name, and identify root nodes. - - Note - ---------- - The `register` method assumes that elements in `_nodes` are never removed, - which is important for maintaining the integrity of node names. - """ + + Attributes: + TRACE (bool): A class-level boolean attribute that determines whether the graph is traced when creating MessageNode. Default is True. + _nodes (defaultdict): An instance-level attribute, which is a defaultdict of lists, used as a lookup table to find nodes by name. + + Notes: + The Graph class manages and organizes nodes in a Directed Acyclic Graph (DAG). + It provides methods to register nodes, clear the graph, retrieve nodes by name, and identify root nodes. + The `register` method assumes that elements in `_nodes` are never removed, + which is important for maintaining the integrity of node names. +""" + class AbstractNode(Generic[T]): """AbstractNode represents an abstract data node in a directed graph. - - Attributes - ---------- - `data`: The data stored in the node. - `parents`: The list of parent nodes. - `children`: The list of child nodes. - `name`: The name of the node. - `py_name`: The name of the node without the ":" character. - `id`: The ID of the node. - `level`: The level of the node in the graph. - `is_root`: A boolean indicating whether the node is a root node. - `is_leaf`: A boolean indicating whether the node is a leaf node. - - Code Description - ---------- - The `AbstractNode` class is meant to be subclassed and extended to create specific types of nodes. - The node can have multiple parents and children, forming a directed graph structure. - The node has a name, which is used to identify it within the graph. - The `py_name` attribute is the same as the name attribute, but with the ":" character removed. - - The node can be initialized with a value, an optional name, and an optional trainable flag. - If the value is an instance of the `Node` class, the node will be initialized as a reference to that node, otherwise, the value will be stored directly in the node. - The default name is generated based on the type of the value and a version number which serves as the identifier, separated by ":". - - The `AbstractNode` class provides several properties to access its attributes. The `data` property allows access to the stored data. - If the node is being traced within a context, the `data` property adds the node to the list of nodes used in that context. - The `parents` property returns a list of parent nodes, and the `children` property returns a list of child nodes. - The `name` property returns the name of the node, and the `py_name` property returns the name without the ":" character. - The `id` property returns the version number/identifier extracted from the name. - The `level` property returns the level of the node in the DAG. - The `is_root` property returns True if the node has no parents, and the `is_leaf` property returns True if the node has no children. - - The `AbstractNode` class also provides internal methods to add parents and children to the node. - The `_add_child` method adds a child node to the node's list of children. - The `_add_parent` method adds a parent node to the node's list of parents and updates the level of the node based on the parent's level. - - The `AbstractNode` class overrides the `__str__` method to provide a string representation of the node. The representation includes the name, the type of the data, and the data itself. - The `AbstractNode` class implements the `__deepcopy__` method to create a deep copy of the node. This allows the node to be detached from the original graph. - The `AbstractNode` class provides comparison methods `lt` and `gt` to compare the levels of two nodes in the DAG. + + Attributes: + data: The data stored in the node. + parents: The list of parent nodes. + children: The list of child nodes. + name: The name of the node. + py_name: The name of the node without the ":" character. + id: The ID of the node. + level: The level of the node in the graph. + is_root: A boolean indicating whether the node is a root node. + is_leaf: A boolean indicating whether the node is a leaf node. + + Notes: + The `AbstractNode` class is meant to be subclassed and extended to create specific types of nodes. + The node can have multiple parents and children, forming a directed graph structure. + The node has a name, which is used to identify it within the graph. + The `py_name` attribute is the same as the name attribute, but with the ":" character removed. + + The node can be initialized with a value, an optional name, and an optional trainable flag. + If the value is an instance of the `Node` class, the node will be initialized as a reference to that node, otherwise, the value will be stored directly in the node. + The default name is generated based on the type of the value and a version number which serves as the identifier, separated by ":". + + The `AbstractNode` class provides several properties to access its attributes. The `data` property allows access to the stored data. + If the node is being traced within a context, the `data` property adds the node to the list of nodes used in that context. + The `parents` property returns a list of parent nodes, and the `children` property returns a list of child nodes. + The `name` property returns the name of the node, and the `py_name` property returns the name without the ":" character. + The `id` property returns the version number/identifier extracted from the name. + The `level` property returns the level of the node in the DAG. + The `is_root` property returns True if the node has no parents, and the `is_leaf` property returns True if the node has no children. + + The `AbstractNode` class also provides internal methods to add parents and children to the node. + The `_add_child` method adds a child node to the node's list of children. + The `_add_parent` method adds a parent node to the node's list of parents and updates the level of the node based on the parent's level. + + The `AbstractNode` class overrides the `__str__` method to provide a string representation of the node. The representation includes the name, the type of the data, and the data itself. + The `AbstractNode` class implements the `__deepcopy__` method to create a deep copy of the node. This allows the node to be detached from the original graph. + The `AbstractNode` class provides comparison methods `lt` and `gt` to compare the levels of two nodes in the DAG. """ def __init__(self, value, *, name=None, trainable=False) -> None: """Initialize an instance of the AbstractNode class. - Parameters - ---------- - value: The value to be assigned to the node. - name: The name of the node (optional). - trainable: A boolean indicating whether the node is trainable or not (optional). + Args: + value: The value to be assigned to the node. + name (str, optional): The name of the node. Defaults to None. + trainable (bool, optional): Whether the node is trainable or not. Defaults to False. - Code Description - ---------- - During initialization, this function generates a default name for the node based on the type of the `value` parameter. If the `name` parameter is provided, it is appended to the default name. The format of the name is "type:version", where the version is set to 0 if no name is provided. - If the `value` parameter is an instance of the Node class, the `_data` attribute of the current node is set to the `_data` attribute of the `value` parameter, and the `_name` attribute is set to the `_name` attribute of the `value` parameter if no name is provided. - Otherwise, the `_data` attribute is set to the `value` parameter itself, and the `_name` attribute is set to the default name. - Finally, the function calls the `register` function of the GRAPH object to register the current node in the graph. + Notes: + During initialization, this function generates a default name for the node based on the type of the `value` parameter. + If the `name` parameter is provided, it is appended to the default name. The format of the name is "type:version", where the version is set to 0 if no name is provided. + If the `value` parameter is an instance of the Node class, the `_data` attribute of the current node is set to the `_data` attribute of the `value` parameter, + and the `_name` attribute is set to the `_name` attribute of the `value` parameter if no name is provided. + Otherwise, the `_data` attribute is set to the `value` parameter itself, and the `_name` attribute is set to the default name. + Finally, the function calls the `register` function of the GRAPH object to register the current node in the graph. """ self._parents = [] self._children = [] @@ -276,11 +248,15 @@ def __init__(self, value, *, name=None, trainable=False) -> None: @property def data(self): - """Retrieve the internal data of a node, potentially adding the node to a list of used nodes if certain conditions are met. - - Note - ---------- - This function assumes that the "_data" attribute exists within the node object. If this attribute is not present, an AttributeError will be raised. + """Retrieve the internal data of a node. + + Returns: + Any: The internal data stored in the node. + + Notes: + If within a trace_nodes context and GRAPH.TRACE is True, adds the node to USED_NODES. + This function assumes that the "_data" attribute exists within the node object. + If this attribute is not present, an AttributeError will be raised. """ if len(USED_NODES) > 0 and GRAPH.TRACE: # We're within trace_nodes context. USED_NODES[-1].add(self) @@ -288,58 +264,112 @@ def data(self): @property def parents(self): - """Access the parents of a node. It is an essential part of the graph structure and is used in various operations such as graph traversal and feedback propagation.""" + """Get the parents of a node. + + Returns: + list: The list of parent nodes. + + Notes: + This property is an essential part of the graph structure and is used in various operations + such as graph traversal and feedback propagation. + """ return self._parents @property def children(self): - """Access the children of a node. This property is essential for accessing the hierarchical structure of nodes, allowing traversal and manipulation of the DAG.""" + """Get the children of a node. + + Returns: + list: The list of child nodes. + + Notes: + This property is essential for accessing the hierarchical structure of nodes, + allowing traversal and manipulation of the DAG. + """ return self._children @property def name(self): - """This property is set when the node is registered in the graph. It is a combination of the node's name and its index in the list of nodes with the same name. The index is incremented each time a new node with the same name is registered. This assumes that elements in the `_nodes` dictionary of the graph never get removed.""" + """Get the name of the node. + + Returns: + str: The name of the node. + + Notes: + This property is set when the node is registered in the graph. + It is a combination of the node's name and its index in the list of nodes with the same name. + The index is incremented each time a new node with the same name is registered. + This assumes that elements in the `_nodes` dictionary of the graph never get removed. + """ return self._name @property def py_name(self): + """Get the Python-friendly name of the node. + + Returns: + str: The name of the node with ":" characters removed. + """ return self.name.replace(":", "") @property def id(self): - """The `name` property is a string formatted as "name:identifier". This property splits that string using the colon (":") delimiter and returns the second part, which corresponds to the identifier. - This identifier is typically a unique part of the node's name, distinguishing it from other nodes with the same base name. - Ensure that the `name` attribute contains a colon (":") to avoid index errors during the split operation. + """Get the identifier part of the node's name. + + Returns: + str: The identifier portion of the node's name (part after the colon). + + Notes: + The `name` property is a string formatted as "name:identifier". + This property splits that string using the colon (":") delimiter and returns the second part, + which corresponds to the identifier. + Ensure that the `name` attribute contains a colon (":") to avoid index errors during the split operation. """ return self.name.split(":")[1] @property def level(self): - """The level of a node in the graph. The level is determined by the maximum level of its parents plus one. The level of a root node is 0.""" + """Get the level of the node in the graph. + + Returns: + int: The level of the node. + + Notes: + The level is determined by the maximum level of its parents plus one. + The level of a root node is 0. + """ return self._level @property def is_root(self): - """A boolean indicating whether the node is a root node in a graph structure. A root node has no parents.""" + """Check if the node is a root node. + + Returns: + bool: True if the node has no parents, False otherwise. + """ return len(self.parents) == 0 @property def is_leaf(self): - """A boolean indicating whether the node is a leaf node in a graph structure. A leaf node has no children.""" + """Check if the node is a leaf node. + + Returns: + bool: True if the node has no children, False otherwise. + """ return len(self.children) == 0 def _add_child(self, child): """Add a child node to the current node. - Parameters - ---------- - child: The child node to be added. + Args: + child: The child node to be added. - Code Description - ---------- - 1. The `_add_child` function first checks if the child node is not the same as the current node itself. If it is, it raises an assertion error (no self-loops allowed in the DAG). - 2. It then checks if the child node is an instance of the `Node` class. If it is not, it raises a different assertion error. - 3. Finally, it calls the `_add_parent` function of the child node, passing the current node as the parent. + Notes: + The function first checks if the child node is not the same as the current node itself. + If it is, it raises an assertion error (no self-loops allowed in the DAG). + It then checks if the child node is an instance of the `Node` class. + If it is not, it raises a different assertion error. + Finally, it calls the `_add_parent` function of the child node, passing the current node as the parent. """ assert child is not self, "Cannot add self as a child." assert isinstance(child, Node), f"{child} is not a Node." @@ -347,18 +377,23 @@ def _add_child(self, child): def _add_parent(self, parent): """Add a parent node to the current node. - - Parameters - ---------- - parent: The parent node to be added. - - Code Description - ---------- - 1. The `_add_parent` function asserts that the parent node is not the same as the current node itself. This check prevents self-loops in the DAG. - 2. It then asserts that the parent node is an instance of the `Node` class. This check ensures that only valid nodes can be added as parents. - 3. If both checks pass, the function proceeds to add the current node as a child to the parent node by appending it to the parent's `_children` attribute. Similarly, it adds the parent node to the current node's `_parents` attribute. - 4. Finally, the function calls the _update_level method to update the level attribute of the current node. It passes the maximum value between the current node's _level attribute and the parent node's _level attribute plus one as the new level value. - This ensures that the hierarchical structure of the nodes is maintained correctly, with child nodes always having a level greater than or equal to their parent nodes. + + Args: + parent: The parent node to be added. + + Notes: + The function asserts that the parent node is not the same as the current node itself. + This check prevents self-loops in the DAG. + It then asserts that the parent node is an instance of the `Node` class. + This check ensures that only valid nodes can be added as parents. + If both checks pass, the function proceeds to add the current node as a child to the parent node + by appending it to the parent's `_children` attribute. Similarly, it adds the parent node to + the current node's `_parents` attribute. + Finally, the function calls the _update_level method to update the level attribute of the current node. + It passes the maximum value between the current node's _level attribute and the parent node's _level + attribute plus one as the new level value. + This ensures that the hierarchical structure of the nodes is maintained correctly, + with child nodes always having a level greater than or equal to their parent nodes. """ assert parent is not self, "Cannot add self as a parent." assert isinstance(parent, Node), f"{parent} is {type(parent)}, which is not a Node." @@ -367,15 +402,14 @@ def _add_parent(self, parent): self._update_level(max(self._level, parent._level + 1)) # Update the level, because the parent is added def _update_level(self, new_level): - """Update the level attribute of the current node to a new specified level. - - Parameters - ---------- - new_level: The new level to which the node's level attribute should be updated. Must be an integer. + """Update the level attribute of the current node. - Note - ---------- - The function does not perform any validation or checks on the new_level parameter; it directly assigns it to the _level attribute. + Args: + new_level (int): The new level to which the node's level attribute should be updated. + + Notes: + The function does not perform any validation or checks on the new_level parameter; + it directly assigns it to the _level attribute. """ # GRAPH._levels[self._level].remove(self) # this uses the == operator which compares values. We need to compare references. self._level = new_level @@ -383,39 +417,44 @@ def _update_level(self, new_level): # assert all([len(GRAPH._levels[i]) > 0 for i in range(len(GRAPH._levels))]), "Some levels are empty." def __str__(self) -> str: - """Return a string representation of the node, including its name, data type, and data value. + """Get string representation of the node. + + Returns: + str: A string containing the node's name, data type, and data value. - Code Description - ---------- - The `__str__` method constructs a string representation of the node by concatenating the node's name, the data type of the node's data, and the actual data stored in the node. - Doing str(node) allows us to look up the node in the feedback dictionary maintained by Trace during the backward pass easily. - - Note - ---------- - Ensure that the node has been properly initialized and registered before calling this method to avoid any unexpected behavior. + Notes: + The `__str__` method constructs a string representation of the node by concatenating + the node's name, the data type of the node's data, and the actual data stored in the node. + Doing str(node) allows us to look up the node in the feedback dictionary maintained by + Trace during the backward pass easily. + Ensure that the node has been properly initialized and registered before calling this + method to avoid any unexpected behavior. """ # str(node) allows us to look up in the feedback dictionary easily return f"Node: ({self.name}, dtype={type(self._data)}, data={self._data})" def __deepcopy__(self, memo): - """Create a deep copy of the node, which is detached from the original graph. - - Parameters - ---------- - memo: A dictionary used to keep track of objects that have already been copied to avoid infinite recursion during the deep copy process. - - Code Description - ---------- - The new instance will be a completely independent copy of the original, with no shared references to mutable objects. - 1. The function starts by obtaining the class of the current instance (`cls = self.__class__`). - 2. It then creates a new, uninitialized instance of this class (`result = cls.__new__(cls)`). - 3. The `memo` dictionary is updated to associate the original instance's ID with the new instance (`memo[id(self)] = result`). - This helps in tracking already copied objects to prevent infinite loops. - 4. The function iterates over all the attributes of the original instance (`for k, v in self.__dict__.items():`). - 5. For attributes named `_parents` or `_children`, it sets these attributes in the new instance to empty lists (`setattr(result, k, [])`). - This ensures that the new instance starts with no parent or child nodes. - 6. For all other attributes, it performs a deep copy of the attribute's value and assigns it to the new instance (`setattr(result, k, copy.deepcopy(v, memo))`). - 7. Finally, the new instance is returned (`return result`). + """Create a deep copy of the node. + + Args: + memo (dict): A dictionary used to keep track of objects that have already been copied. + + Returns: + AbstractNode: A new instance that is a deep copy of the current node. + + Notes: + The new instance will be a completely independent copy of the original, + with no shared references to mutable objects. + 1. The function starts by obtaining the class of the current instance (`cls = self.__class__`). + 2. It then creates a new, uninitialized instance of this class (`result = cls.__new__(cls)`). + 3. The `memo` dictionary is updated to associate the original instance's ID with the new instance. + This helps in tracking already copied objects to prevent infinite loops. + 4. The function iterates over all the attributes of the original instance. + 5. For attributes named `_parents` or `_children`, it sets these attributes in the new instance + to empty lists. This ensures that the new instance starts with no parent or child nodes. + 6. For all other attributes, it performs a deep copy of the attribute's value and assigns it + to the new instance. + 7. Finally, the new instance is returned. """ cls = self.__class__ result = cls.__new__(cls) @@ -430,30 +469,34 @@ def __deepcopy__(self, memo): return result def lt(self, other): - """Less than comparison based on the level attribute of the nodes. - - Parameters - ---------- - other: The other node to compare against. + """Compare if this node's level is less than another node's level. + + Args: + other: The other node to compare against. - Note - ---------- - This method is used to compare the levels of two nodes in the DAG. - Therefore it checks if the negated level of the current node (`-self._level`) is less than the negated level of the other node (`-other._level`) + Returns: + bool: True if this node's level is less than the other node's level. + + Notes: + This method is used to compare the levels of two nodes in the DAG. + Therefore it checks if the negated level of the current node (`-self._level`) + is less than the negated level of the other node (`-other._level`) """ return -self._level < -other._level def gt(self, other): - """Greater than comparison based on the level attribute of the nodes. - - Parameters - ---------- - other: The other node to compare against. - - Note - ---------- - This method is used to compare the levels of two nodes in the DAG. - Therefore it checks if the negated level of the current node (`-self._level`) is greater than the negated level of the other node (`-other._level`) + """Compare if this node's level is greater than another node's level. + + Args: + other: The other node to compare against. + + Returns: + bool: True if this node's level is greater than the other node's level. + + Notes: + This method is used to compare the levels of two nodes in the DAG. + Therefore it checks if the negated level of the current node (`-self._level`) + is greater than the negated level of the other node (`-other._level`) """ return -self._level > -other._level @@ -464,15 +507,21 @@ def gt(self, other): def get_op_name(description): """Extract the operator type from the description. - - Parameters - ---------- - description: A string containing the description of the node. - - Code Description - ---------- - The `get_op_name` function takes a description as input and uses regular expression to search for the operator type enclosed in square brackets at the beginning of the description. - If a match is found, the operator type is extracted and returned. Otherwise, a `ValueError` is raised with a specific error message. + + Args: + description (str): A string containing the description of the node. + + Returns: + str: The extracted operator type. + + Raises: + ValueError: If the description does not contain an operator type in square brackets. + + Notes: + The `get_op_name` function takes a description as input and uses regular expression + to search for the operator type enclosed in square brackets at the beginning of the description. + If a match is found, the operator type is extracted and returned. + Otherwise, a `ValueError` is raised with a specific error message. """ assert type(description) is str, f"Description must be a string, but it is {type(description)}: {description}." match = re.search(r"^\[([^\[\]]+)\]", description) @@ -482,22 +531,21 @@ def get_op_name(description): else: raise ValueError(f"The description '{description}' must contain the operator type in square brackets.") + class NodeVizStyleGuide: - """A class to provide a standardized way to visualize nodes in a graph, particularly for use with graph visualization tools like Graphviz. - - Attributes - ---------- - style: A string that defines the style of the visualization. Default is 'default'. - print_limit: An integer that sets the maximum number of characters to print for node descriptions and content. Default is 100. + """A class to provide a standardized way to visualize nodes in a graph. + + Attributes: + style (str): Defines the style of the visualization. Default is 'default'. + print_limit (int): Sets the maximum number of characters to print for node descriptions and content. Default is 100. """ def __init__(self, style='default', print_limit=100): - """Initialize the NodeVizStyleGuide with a specified style and print limit. - - Parameters - ---------- - style: A string defining the style of the visualization. Default is 'default'. - print_limit: An integer setting the maximum number of characters to print for node descriptions and content. Default is 100. + """Initialize the NodeVizStyleGuide. + + Args: + style (str, optional): The style of visualization to use. Defaults to 'default'. + print_limit (int, optional): Maximum characters to print for descriptions and content. Defaults to 100. """ self.style = style self.print_limit = print_limit @@ -505,17 +553,19 @@ def __init__(self, style='default', print_limit=100): def get_attrs(self, x): """Get the attributes for a node based on the style guide. - Parameters - ---------- - x: The node for which attributes are to be generated. + Args: + x: The node for which attributes are to be generated. + + Returns: + dict: Dictionary of visualization attributes for the node. - Code Description - ---------- - The `get_attrs` method takes a node `x` as input and returns a dictionary of attributes for the node. - The attributes include the label, shape, fill color, and style of the node, which are determined based on the node's properties and the style guide. - The method calls other helper methods to construct the label, determine the node shape, assign a color, and set the style. + Notes: + The attributes include the label, shape, fill color, and style of the node, + which are determined based on the node's properties and the style guide. + The method calls other helper methods to construct the label, determine the node shape, + assign a color, and set the style. """ - attrs= { + attrs = { 'label': self.get_label(x), 'shape': self.get_node_shape(x), 'fillcolor': self.get_color(x), @@ -524,17 +574,19 @@ def get_attrs(self, x): return attrs def get_label(self, x): - """Construct a label for a node based on its name, description, and content. + """Construct a label for a node. - Parameters - ---------- - x: The node for which the label is to be constructed. + Args: + x: The node for which the label is to be constructed. - Note - ---------- - Using a colon in the name can cause problems in graph visualization tools like Graphviz. - To avoid issues, the label is constructed by combining the node's Python name, truncated description, and content. - If the description or content exceeds the print limit, it is truncated and appended with an ellipsis. + Returns: + str: The constructed label string. + + Notes: + Using a colon in the name can cause problems in graph visualization tools like Graphviz. + To avoid issues, the label is constructed by combining the node's Python name, + truncated description, and content. + If the description or content exceeds the print limit, it is truncated and appended with an ellipsis. """ # using colon in the name causes problems in graphviz description = x.description @@ -552,15 +604,18 @@ def get_label(self, x): return text + content def get_node_shape(self, x): - """Determine the shape of a node based on its type. + """Determine the shape of a node. + + Args: + x: The node for which the shape is to be determined. - Parameters - ---------- - x: The node for which the shape is to be determined. + Returns: + str: The shape to use for the node. - Note - ---------- - The shape of a node is determined based on its type. ParameterNode types are represented as 'box', while other types are represented as 'ellipse'. + Notes: + The shape of a node is determined based on its type. + ParameterNode types are represented as 'box', + while other types are represented as 'ellipse'. """ if type(x) == ParameterNode: return 'box' @@ -568,15 +623,18 @@ def get_node_shape(self, x): return "ellipse" def get_color(self, x): - """Assign a color to a node based on its type. + """Assign a color to a node. + + Args: + x: The node for which the color is to be assigned. - Parameters - ---------- - x: The node for which the color is to be assigned. + Returns: + str: The color to use for the node. - Note - ---------- - The color of a node is determined based on its type. ExceptionNode types are colored 'firebrick1', and ParameterNode types are colored 'lightgray'. + Notes: + The color of a node is determined based on its type. + ExceptionNode types are colored 'firebrick1', + and ParameterNode types are colored 'lightgray'. """ if type(x) == ExceptionNode: return 'firebrick1' @@ -586,52 +644,55 @@ def get_color(self, x): return "" def get_style(self, x): - """Set the style of a node based on its properties. + """Set the style of a node. - Parameters - ---------- - x: The node for which the style is to be set. + Args: + x: The node for which the style is to be set. - Note - ---------- - The style of a node is set to 'filled,solid' if the node is trainable; otherwise, it returns an empty string. + Returns: + str: The style string for the node. + + Notes: + The style of a node is set to 'filled,solid' if the node is trainable; + otherwise, it returns an empty string. """ return 'filled,solid' if x.trainable else "" + class NodeVizStyleGuideColorful(NodeVizStyleGuide): """A class to provide a colorful style guide for visualizing nodes in a graph. - Attributes - ---------- - style: A string defining the style of the visualization. Default is 'default'. - print_limit: An integer setting the maximum number of characters to print for node descriptions and content. Default is 100. + Attributes: + style (str): Defines the style of the visualization. Default is 'default'. + print_limit (int): Sets the maximum number of characters to print for node descriptions and content. Default is 100. """ def __init__(self, style='default', print_limit=100): - """Initialize the NodeVizStyleGuideColorful with a specified style and print limit. + """Initialize the NodeVizStyleGuideColorful. - Parameters - ---------- - style: A string defining the style of the visualization. Default is 'default'. - print_limit: An integer setting the maximum number of characters to print for node descriptions and content. Default is 100. + Args: + style (str, optional): The style of visualization to use. Defaults to 'default'. + print_limit (int, optional): Maximum characters to print for descriptions and content. Defaults to 100. """ self.style = style self.print_limit = print_limit def get_attrs(self, x): """Get the attributes for a node based on the colorful style guide. - - Parameters - ---------- - x: The node for which attributes are to be generated. - - Code Description - ---------- - The `get_attrs` method takes a node `x` as input and returns a dictionary of attributes for the node. - The attributes include the label, shape, fill color, style, border color, and border width of the node, which are determined based on the node's properties and the style guide. - The method calls other helper methods to construct the label, determine the node shape, assign a color, and set the style. - """ - attrs= { + + Args: + x: The node for which attributes are to be generated. + + Returns: + dict: Dictionary of visualization attributes for the node. + + Notes: + The attributes include the label, shape, fill color, style, border color, + and border width of the node, which are determined based on the node's properties + and the style guide. The method calls other helper methods to construct the label, + determine the node shape, assign a color, and set the style. + """ + attrs = { 'label': self.get_label(x), 'shape': self.get_node_shape(x), 'fillcolor': self.get_color(x), @@ -642,15 +703,18 @@ def get_attrs(self, x): return attrs def get_border_color(self, x): - """Assign a border color to a node based on its type. + """Assign a border color to a node. + + Args: + x: The node for which the border color is to be assigned. - Parameters - ---------- - x: The node for which the border color is to be assigned. + Returns: + str: The border color to use for the node. - Note - ---------- - The border color of a node is determined based on its type. ExceptionNode types are colored 'firebrick1', and ParameterNode types are colored 'black'. + Notes: + The border color of a node is determined based on its type. + ExceptionNode types are colored 'firebrick1', + and ParameterNode types are colored 'black'. """ if type(x) == ExceptionNode: return 'black' @@ -658,17 +722,20 @@ def get_border_color(self, x): return '#FF7E79' return "#5C9BD5" - + def get_color(self, x): - """Assign a fill color to a node based on its type. + """Assign a fill color to a node. + + Args: + x: The node for which the fill color is to be assigned. - Parameters - ---------- - x: The node for which the fill color is to be assigned. + Returns: + str: The fill color to use for the node. - Note - ---------- - The fill color of a node is determined based on its type. ExceptionNode types are colored 'firebrick1', and ParameterNode types are colored 'lightgray'. + Notes: + The fill color of a node is determined based on its type. + ExceptionNode types are colored 'firebrick1', + and ParameterNode types are colored 'lightgray'. """ if type(x) == ExceptionNode: return 'firebrick1' @@ -676,60 +743,69 @@ def get_color(self, x): return '#FFE5E5' return "#DEEBF6" - + def get_style(self, x): - """Set the style of a node always as if it is trainable.""" + """Set the style of a node always as if it is trainable. + + Args: + x: The node for which the style is to be set. + + Returns: + str: The style string 'filled,solid'. + """ return 'filled,solid' class Node(AbstractNode[T]): """A data node in a directed graph, this is a basic data structure of Trace. - - Attributes - ---------- - trainable: A boolean indicating whether the node is trainable or not. - _feedback: A dictionary of feedback from children nodes. - _description: A string describing the node. - _constraint: A string describing all constraints that the data in the node should satisfy. - _backwarded: A boolean indicating whether the backward method has been called. - _info: A dictionary containing additional information about the node. - _dependencies: A dictionary of dependencies on parameters and expandable nodes. - - Code Description - ---------- - The `Node` class extends the `AbstractNode` class to represent a data node in a directed graph. - It includes additional attributes and methods to handle feedback, constraints, and dependencies. - The node can be marked as trainable, and it can store feedback from children nodes. - The node has a description and additional information associated with it. - The node can also track dependencies on parameters and expandable nodes, which are nodes that depend on parameters not visible in the current graph level. - - Note - ---------- - The `Node` class is meant to be subclassed and extended to create specific types of nodes. - The feedback mechanism is analogous to gradients in machine learning and is used to propagate information back through the graph. - The feedback mechanism is designed to support non-commutative aggregation, so feedback should be handled carefully to maintain the correct order of operations. + + Args: + value (Any): The value to be assigned to the node. + name (str, optional): The name of the node. + trainable (bool, optional): Whether the node is trainable or not. Defaults to False. + description (str, optional): String describing the node. Defaults to "[Node] This is a node in a computational graph." + constraint (Union[None, str], optional): String describing constraints that the data should satisfy. Defaults to None. + info (Union[None, Dict], optional): Dictionary containing additional information about the node. Defaults to None. + + Attributes: + trainable (bool): Whether the node is trainable or not. + _feedback (dict): Dictionary of feedback from children nodes. + _description (str): String describing the node. + _constraint (str): String describing all constraints that the data should satisfy. + _backwarded (bool): Whether the backward method has been called. + _info (dict): Dictionary containing additional information about the node. + _dependencies (dict): Dictionary of dependencies on parameters and expandable nodes. + + Notes: + The Node class extends AbstractNode to represent a data node in a directed graph. + It includes attributes and methods to handle feedback, constraints, and dependencies. + The node can be marked as trainable and store feedback from children nodes. + The feedback mechanism is analogous to gradients in machine learning and propagates + information back through the graph. The feedback mechanism supports non-commutative + aggregation, so feedback should be handled carefully to maintain correct operation order. + The node can track dependencies on parameters and expandable nodes (nodes that depend + on parameters not visible in the current graph level). """ def __init__( - self, - value: Any, - *, - name: str = None, - trainable: bool = False, - description: str = "[Node] This is a node in a computational graph.", - constraint: Union[None, str] = None, - info: Union[None, Dict] = None, + self, + value: Any, + *, + name: str = None, + trainable: bool = False, + description: str = "[Node] This is a node in a computational graph.", + constraint: Union[None, str] = None, + info: Union[None, Dict] = None, ) -> None: """Initialize an instance of the Node class. - Parameters - ---------- - value: The value to be assigned to the node. - name: The name of the node (optional). - trainable: A boolean indicating whether the node is trainable or not (optional). - description: A string describing the node (optional). - constraint: A string describing constraints on the node (optional). - info: A dictionary containing additional information about the node (optional). + Args: + value: The value to be assigned to the node. + name: The name of the node (optional). + trainable: A boolean indicating whether the node is trainable or not (optional). + description: A string describing the node (optional). + constraint: A string describing constraints on the node (optional). + info: A dictionary containing additional information about the node (optional). """ if description == "" or description is None: @@ -751,12 +827,15 @@ def __init__( self._constraint = constraint # A constraint on the node self._backwarded = False # True if backward has been called self._info = info # Additional information about the node - self._dependencies = {'parameter': set(), 'expandable': set()} # A dictionary of dependencies on parameters and expandable nodes; expandable nodes are those who depened on parameters not visible in the current graph level. + self._dependencies = {'parameter': set(), + 'expandable': set()} # A dictionary of dependencies on parameters and expandable nodes; expandable nodes are those who depened on parameters not visible in the current graph level. def zero_feedback(self): # set feedback to zero """Zero out the feedback of the node. - zero_feedback should be used judiciously within the feedback propagation process to avoid unintended loss of feedback data. - It is specifically designed to be used after feedback has been successfully propagated to parent nodes. + + Notes: + zero_feedback should be used judiciously within the feedback propagation process to avoid unintended loss of feedback data. + It is specifically designed to be used after feedback has been successfully propagated to parent nodes. """ self._feedback = defaultdict(list) @@ -782,103 +861,99 @@ def type(self): @property def parameter_dependencies(self): - """ The depended parameters. + """The depended parameters. - Note - ---------- - Ensure that the '_dependencies' attribute is properly initialized and contains a 'parameter' key with a corresponding value before calling the parameter_dependencies function to avoid potential KeyError exceptions. + Notes: + Ensure that the '_dependencies' attribute is properly initialized and contains a 'parameter' key + with a corresponding value before calling the parameter_dependencies function to avoid potential + KeyError exceptions. """ return self._dependencies['parameter'] @property def expandable_dependencies(self): - """ The depended expandable nodes, where expandable nodes are those who depend on parameters not visible in the current graph level. - - Note - ---------- - Ensure that the '_dependencies' attribute is properly initialized and contains an 'expandable' key with a corresponding value before calling the expandable_dependencies function to avoid potential KeyError exceptions + """The depended expandable nodes. + + Notes: + Expandable nodes are those who depend on parameters not visible in the current graph level. + Ensure that the '_dependencies' attribute is properly initialized and contains an 'expandable' key + with a corresponding value before calling the expandable_dependencies function to avoid potential + KeyError exceptions. """ return self._dependencies['expandable'] def _add_feedback(self, child, feedback): """Add feedback from a child. - - Parameters - ---------- - child: The child node from which feedback is received. - feedback: The feedback received from the child node. + + Args: + child: The child node from which feedback is received. + feedback: The feedback received from the child node. """ self._feedback[child].append(feedback) # This is not traced def _set(self, value: Any): """Set the value of the node. If value is Node, it will be unwrapped. - - Parameters - ---------- - value: The value to be assigned to the node. - - Note - ---------- - The `_set` method sets the `_data` attribute of the node to the provided `value`. - If the `value` is an instance of the `Node` class, the `_data` attribute of the current node is set to the `_data` attribute of the `value` parameter. - Otherwise, the `_data` attribute is set to the `value` parameter itself. - When `_data` is set using `_set`, that usage is not traced. + + Args: + value: The value to be assigned to the node. + + Notes: + The `_set` method sets the `_data` attribute of the node to the provided `value`. + If the `value` is an instance of the `Node` class, the `_data` attribute of the current node + is set to the `_data` attribute of the `value` parameter. + Otherwise, the `_data` attribute is set to the `value` parameter itself. + When `_data` is set using `_set`, that usage is not traced. """ if isinstance(value, Node): value = value.data self._data = value def _itemize(self): # for priority queue - """Return a tuple containing the node's level; useful for maintaining priority queues of nodes in a DAG.""" + """Return a tuple containing the node's level. + + Returns: + tuple: A tuple containing (-level, id, self) used for priority queue ordering. + """ return (-self.level, id(self), self) def backward( - self, - feedback: Any = "", - propagator=None, - retain_graph=False, - visualize=False, - simple_visualization=True, - reverse_plot=False, - print_limit=100, + self, + feedback: Any = "", + propagator=None, + retain_graph=False, + visualize=False, + simple_visualization=True, + reverse_plot=False, + print_limit=100, ): - """Performs a backward pass in a computational graph. - This function propagates feedback from the current node to its parents, updates the graph visualization if required, and returns the resulting graph. - - Parameters - ---------- - feedback: The feedback given to the current node. - propagator: A function that takes in a node and a feedback, and returns a dict of {parent: parent_feedback}. If not provided, a default `GraphPropagator` object is used. - retain_graph: If True, the graph will be retained after backward pass. - visualize: If True, the graph will be visualized using graphviz. - simple_visualization: If True, identity operators will be skipped in the visualization; otherwise, they will be included. - reverse_plot: if True, plot the graph in reverse order (from child to parent). - print_limit: The maximum number of characters to print for node descriptions and content. - - Code Description - ---------- - The function checks if the current node has already been backwarded. If it has, an `AttributeError` is raised. - Otherwise, the function adds the feedback to the current node by calling the `_add_feedback` method of the node object. - The feedback is initialized with a special "FEEDBACK_ORACLE" node and the propagated feedback from the `propagator` object. - - If the current node has no parents, indicating that it is a root node, the function checks if visualization is enabled. - If it is, the current node is added to the `digraph` object with the appropriate style attributes. Finally, the function returns the `digraph` object. - - If the current node has parents, indicating that it is not a root node, the function initializes a priority queue. - The priority queue is used to process the nodes in the correct order during the backward pass. - The function enters a loop that continues until the `queue` is empty. - In each iteration, a node is popped from the `queue` and processed. - The node is checked to ensure it has parents and is an instance of the `MessageNode` class. If not, an `AttributeError` is raised. - - The function propagates information from the current node to its parents by calling the `propagator` object with the current node as the argument. - The `propagator` object computes the propagated feedback based on the child node's description, data, and feedback. - The propagated feedback is then added to the parents of the current node by calling the `_add_feedback` method of each parent node. - - After processing the parents of the current node, the `_backwarded` attribute of the current node is updated to indicate that it has been backwarded. - This attribute is set to `True` unless the `retain_graph` parameter is set to `True`. - - The loop continues until the `queue` is empty, indicating that all the nodes have been processed. Finally, the function returns the `digraph` object. + """Performs a backward pass in a computational graph. + + This function propagates feedback from the current node to its parents, updates the graph + visualization if required, and returns the resulting graph. + + Args: + feedback: The feedback given to the current node. + propagator: A function that takes in a node and a feedback, and returns a dict of {parent: parent_feedback}. + If not provided, a default `GraphPropagator` object is used. + retain_graph: If True, the graph will be retained after backward pass. + visualize: If True, the graph will be visualized using graphviz. + simple_visualization: If True, identity operators will be skipped in the visualization. + reverse_plot: if True, plot the graph in reverse order (from child to parent). + print_limit: The maximum number of characters to print for node descriptions and content. + + Returns: + digraph: The visualization graph object if visualize=True, None otherwise. + + Raises: + AttributeError: If the node has already been backwarded. + + Notes: + The function checks if the current node has already been backwarded. If it has, an AttributeError is raised. + For root nodes (no parents), only visualization is performed if enabled. + For non-root nodes, feedback is propagated through the graph using a priority queue to ensure correct ordering. + The propagator computes feedback for parent nodes based on the current node's description, data and feedback. + Visualization is handled using graphviz if enabled, with options to simplify the graph by skipping identity operators. """ if propagator is None: from opto.trace.propagators.graph_propagator import GraphPropagator # this avoids circular import @@ -957,22 +1032,32 @@ def backward( return digraph def clone(self): - """Create and return a duplicate of the current Node object.""" + """Create and return a duplicate of the current Node object. + + Returns: + Node: A clone of the current node. + """ import opto.trace.operators as ops return ops.clone(self) def detach(self): - """Create and return a deep copy of the current instance of the Node class.""" + """Create and return a deep copy of the current instance of the Node class. + + Returns: + Node: A deep copy of the current node. + """ return copy.deepcopy(self) # Get attribute and call operators def getattr(self, key): """Get the attribute of the node with the specified key. - - Parameters - ---------- - key: The key of the attribute to get. + + Args: + key: The key of the attribute to get. + + Returns: + Node: A node containing the requested attribute. """ import opto.trace.operators as ops @@ -981,11 +1066,13 @@ def getattr(self, key): def call(self, fun: str, *args, **kwargs): """Call the function with the specified arguments and keyword arguments. - Parameters - ---------- - fun: The function to call. - args: The arguments to pass to the function. - kwargs: The keyword arguments to pass to the function. + Args: + fun: The function to call. + *args: The arguments to pass to the function. + **kwargs: The keyword arguments to pass to the function. + + Returns: + Node: The result of the function call wrapped in a node. """ args = (node(arg) for arg in args) # convert args to nodes kwargs = {k: node(v) for k, v in kwargs.items()} @@ -994,14 +1081,16 @@ def call(self, fun: str, *args, **kwargs): def __call__(self, *args, **kwargs): """Call the function with the specified arguments and keyword arguments. - Parameters - ---------- - args: The arguments to pass to the function. - kwargs: The keyword arguments to pass to the function. + Args: + *args: The arguments to pass to the function. + **kwargs: The keyword arguments to pass to the function. + + Returns: + Node: The result of the function call wrapped in a node. - Note - ---------- - By using the `__call__` method, the Node object can be used as if it were a regular callable function, providing a seamless interface for function invocation. + Notes: + By using the __call__ method, the Node object can be used as if it were a regular + callable function, providing a seamless interface for function invocation. """ import opto.trace.operators as ops @@ -1012,10 +1101,12 @@ def __call__(self, *args, **kwargs): # container magic methods def len(self): """Return the length of the node. - - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + + Returns: + Node: A node containing the length value. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1024,13 +1115,14 @@ def len(self): def __getitem__(self, key): """Get the item at the specified key. - Parameters - ---------- - key: The key of the item to get. + Args: + key: The key of the item to get. + + Returns: + Node: A node containing the requested item. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1039,13 +1131,14 @@ def __getitem__(self, key): def __contains__(self, item): """Check if the item is contained in the node. - Parameters - ---------- - item: The item to check for containment. + Args: + item: The item to check for containment. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Returns: + Node: A node containing the boolean result. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1054,10 +1147,12 @@ def __contains__(self, item): # Unary operators and functions def __pos__(self): """Return the positive value of the node. - - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + + Returns: + Node: A node containing the positive value. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1066,9 +1161,11 @@ def __pos__(self): def __neg__(self): """Return the negative value of the node. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Returns: + Node: A node containing the negative value. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1077,9 +1174,11 @@ def __neg__(self): def __abs__(self): """Return the absolute value of the node. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Returns: + Node: A node containing the absolute value. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1088,9 +1187,11 @@ def __abs__(self): def __invert__(self): """Return the inverted value of the node. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Returns: + Node: A node containing the inverted value. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1099,13 +1200,14 @@ def __invert__(self): def __round__(self, n=None): """Return the rounded value of the node. - Parameters - ---------- - n: The number of decimal places to round to (optional). + Args: + n: The number of decimal places to round to (optional). + + Returns: + Node: A node containing the rounded value. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1114,9 +1216,11 @@ def __round__(self, n=None): def __floor__(self): """Return the floor value of the node. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Returns: + Node: A node containing the floor value. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1125,9 +1229,11 @@ def __floor__(self): def __ceil__(self): """Return the ceiling value of the node. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Returns: + Node: A node containing the ceiling value. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1136,9 +1242,11 @@ def __ceil__(self): def __trunc__(self): """Return the truncated value of the node. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Returns: + Node: A node containing the truncated value. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1148,13 +1256,14 @@ def __trunc__(self): def __add__(self, other): """Return the sum of the node and another value. - Parameters - ---------- - other: The value to add to the node. + Args: + other: The value to add to the node. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Returns: + Node: A node containing the sum. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1166,26 +1275,28 @@ def __add__(self, other): def __radd__(self, other): """Return the sum of another value and the node. - Parameters - ---------- - other: The value to add to the node. + Args: + other: The value to add to the node. + + Returns: + Node: A node containing the sum. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ return node(other) + self def __sub__(self, other): """Return the difference between the node and another value. - Parameters - ---------- - other: The value to subtract from the node. + Args: + other: The value to subtract from the node. + + Returns: + Node: A node containing the difference. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1194,26 +1305,28 @@ def __sub__(self, other): def __rsub__(self, other): """Return the difference between another value and the node. - Parameters - ---------- - other: The value to subtract the node from. + Args: + other: The value to subtract the node from. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Returns: + Node: A node containing the difference. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ return node(other) - self def __mul__(self, other): """Return the product of the node and another value. - Parameters - ---------- - other: The value to multiply the node by. + Args: + other: The value to multiply the node by. + + Returns: + Node: A node containing the product. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1222,26 +1335,28 @@ def __mul__(self, other): def __rmul__(self, other): """Return the product of another value and the node. - Parameters - ---------- - other: The value to multiply the node by. + Args: + other: The value to multiply the node by. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Returns: + Node: A node containing the product. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ return self * node(other) def __floordiv__(self, other): """Return the floor division of the node by another value. - Parameters - ---------- - other: The value to divide the node by. + Args: + other: The value to divide the node by. + + Returns: + Node: A node containing the floor division result. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1250,26 +1365,28 @@ def __floordiv__(self, other): def __rfloordiv__(self, other): """Return the floor division of another value by the node. - Parameters - ---------- - other: The value to divide by the node. + Args: + other: The value to divide by the node. + + Returns: + Node: A node containing the floor division result. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ return node(other) // self def __truediv__(self, other): """Return the true division of the node by another value. - Parameters - ---------- - other: The value to divide the node by. + Args: + other: The value to divide the node by. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Returns: + Node: A node containing the division result. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1278,26 +1395,28 @@ def __truediv__(self, other): def __rtruediv__(self, other): """Return the true division of another value by the node. - Parameters - ---------- - other: The value to divide by the node. + Args: + other: The value to divide by the node. + + Returns: + Node: A node containing the division result. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ return node(other) / self def __div__(self, other): """Return the division of the node by another value. - Parameters - ---------- - other: The value to divide the node by. + Args: + other: The value to divide the node by. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Returns: + Node: A node containing the division result. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1305,27 +1424,29 @@ def __div__(self, other): def __rdiv__(self, other): """Return the division of another value by the node. - - Parameters - ---------- - other: The value to divide the node by. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Args: + other: The value to divide by the node. + + Returns: + Node: A node containing the division result. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ return node(other) / self def __mod__(self, other): """Return the modulo of the node by another value. - Parameters - ---------- - other: The value to divide the node by. + Args: + other: The value to divide the node by. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Returns: + Node: A node containing the modulo result. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1334,26 +1455,28 @@ def __mod__(self, other): def __rmod__(self, other): """Return the modulo of another value by the node. - Parameters - ---------- - other: The value to divide the node by. + Args: + other: The value to divide by the node. + + Returns: + Node: A node containing the modulo result. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ - return node(other) % self + return node(other) % self def __divmod__(self, other): """Return the division and modulo of the node by another value. - Parameters - ---------- - other: The value to divide the node by. + Args: + other: The value to divide the node by. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Returns: + Node: A node containing the division and modulo results. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1361,27 +1484,29 @@ def __divmod__(self, other): def __rdivmod__(self, other): """Return the division and modulo of another value by the node. - - Parameters - ---------- - other: The value to divide the node by. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Args: + other: The value to divide by the node. + + Returns: + Node: A node containing the division and modulo results. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ return divmod(node(other), self) def __pow__(self, other): """Return the power of the node raised to another value. - - Parameters - ---------- - other: The value to divide the node by. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Args: + other: The exponent value. + + Returns: + Node: A node containing the power result. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1389,27 +1514,29 @@ def __pow__(self, other): def __rpow__(self, other): """Return the power of another value raised to the node. - - Parameters - ---------- - other: The value to divide the node by. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Args: + other: The base value. + + Returns: + Node: A node containing the power result. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ return node(other) ** self def __lshift__(self, other): """Return the left shift of the node by another value. - - Parameters - ---------- - other: The value to divide the node by. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Args: + other: The number of positions to shift. + + Returns: + Node: A node containing the left shift result. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1418,26 +1545,28 @@ def __lshift__(self, other): def __rlshift__(self, other): """Return the left shift of another value by the node. - Parameters - ---------- - other: The value to divide the node by. + Args: + other: The value to shift. + + Returns: + Node: A node containing the left shift result. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ return node(other) << self def __rshift__(self, other): """Return the right shift of the node by another value. - Parameters - ---------- - other: The value to divide the node by. + Args: + other: The number of positions to shift. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Returns: + Node: A node containing the right shift result. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1446,26 +1575,28 @@ def __rshift__(self, other): def __rrshift__(self, other): """Return the right shift of another value by the node. - Parameters - ---------- - other: The value to divide the node by. + Args: + other: The value to shift. + + Returns: + Node: A node containing the right shift result. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ return node(other) >> self def __and__(self, other): """Return the bitwise AND of the node and another value. - Parameters - ---------- - other: The value to divide the node by. + Args: + other: The value to AND with. + + Returns: + Node: A node containing the AND result. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1473,27 +1604,29 @@ def __and__(self, other): def __rand__(self, other): """Return the bitwise AND of another value and the node. - - Parameters - ---------- - other: The value to divide the node by. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Args: + other: The value to AND with. + + Returns: + Node: A node containing the AND result. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ return node(other) & self def __or__(self, other): """Return the bitwise OR of the node and another value. - Parameters - ---------- - other: The value to divide the node by. + Args: + other: The value to OR with. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Returns: + Node: A node containing the OR result. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1502,26 +1635,28 @@ def __or__(self, other): def __ror__(self, other): """Return the bitwise OR of another value and the node. - Parameters - ---------- - other: The value to divide the node by. + Args: + other: The value to OR with. + + Returns: + Node: A node containing the OR result. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ return node(other) | self def __xor__(self, other): """Return the bitwise XOR of the node and another value. - Parameters - ---------- - other: The value to divide the node by. + Args: + other: The value to XOR with. + + Returns: + Node: A node containing the XOR result. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ import opto.trace.operators as ops @@ -1530,29 +1665,26 @@ def __xor__(self, other): def __rxor__(self, other): """Return the bitwise XOR of another value and the node. - Parameters - ---------- - other: The value to divide the node by. + Args: + other: The value to XOR with. - Note - ---------- - We overload magic methods that return a value. This method returns a MessageNode. + Returns: + Node: A node containing the XOR result. + + Notes: + We overload magic methods that return a value. This method returns a MessageNode. """ return node(other) ^ self def __iter__(self): """Return an iterator for the node. - Code Description - ---------- - The __iter__ method is designed to make the Node object iterable. It does this by determining the appropriate iterable class to use based on the type of the Node object's data attribute. - It handles various types of collections such as lists, tuples, sets, and dictionaries, and returns an iterable object accordingly. - This ensures that the Node object can be iterated over seamlessly, regardless of the type of its data attribute. + Returns: + iterator: An iterator over the node's data. - Note - ---------- - The Node object must have a data attribute that is a list, tuple, set, or dictionary. - The iterate function called by __iter__ handles the conversion of sets to lists and wraps items in lists or dictionaries with node objects. + Notes: + The Node object must have a data attribute that is a list, tuple, set, or dictionary. + The iterate function called by __iter__ handles the conversion of sets to lists and wraps items in lists or dictionaries with node objects. """ import opto.trace.iterators as it @@ -1561,11 +1693,13 @@ def __iter__(self): def __len__(self): """Return the length of the node. - Note - ---------- - __len__ restricts return type to be integer - Therefore, this method only returns integer - If a Node/MessageNode is desired to be returned, call node.len() instead + Returns: + int: The length of the node's data. + + Notes: + __len__ restricts return type to be integer. + Therefore, this method only returns integer. + If a Node/MessageNode is desired to be returned, call node.len() instead. """ return len(self._data) @@ -1576,14 +1710,15 @@ def __len__(self): def __lt__(self, other): """Check if the node is less than another value. - Parameters - ---------- - other: The value to compare the node to. + Args: + other: The value to compare the node to. - Note - ---------- - If a logic operator is used in an if-statement, it will return a boolean value. - Otherwise, it will return a MessageNode. + Returns: + Node: A node containing the comparison result. + + Notes: + If a logic operator is used in an if-statement, it will return a boolean value. + Otherwise, it will return a MessageNode. """ import opto.trace.operators as ops @@ -1595,14 +1730,15 @@ def __lt__(self, other): def __le__(self, other): """Check if the node is less than or equal to another value. - Parameters - ---------- - other: The value to compare the node to. + Args: + other: The value to compare the node to. + + Returns: + Node: A node containing the comparison result. - Note - ---------- - If a logic operator is used in an if-statement, it will return a boolean value. - Otherwise, it will return a MessageNode. + Notes: + If a logic operator is used in an if-statement, it will return a boolean value. + Otherwise, it will return a MessageNode. """ import opto.trace.operators as ops @@ -1614,14 +1750,15 @@ def __le__(self, other): def __gt__(self, other): """Check if the node is greater than another value. - Parameters - ---------- - other: The value to compare the node to. + Args: + other: The value to compare the node to. + + Returns: + Node: A node containing the comparison result. - Note - ---------- - If a logic operator is used in an if-statement, it will return a boolean value. - Otherwise, it will return a MessageNode. + Notes: + If a logic operator is used in an if-statement, it will return a boolean value. + Otherwise, it will return a MessageNode. """ import opto.trace.operators as ops @@ -1633,14 +1770,15 @@ def __gt__(self, other): def __ge__(self, other): """Check if the node is greater than or equal to another value. - Parameters - ---------- - other: The value to compare the node to. + Args: + other: The value to compare the node to. - Note - ---------- - If a logic operator is used in an if-statement, it will return a boolean value. - Otherwise, it will return a MessageNode. + Returns: + Node: A node containing the comparison result. + + Notes: + If a logic operator is used in an if-statement, it will return a boolean value. + Otherwise, it will return a MessageNode. """ import opto.trace.operators as ops @@ -1654,13 +1792,14 @@ def __ge__(self, other): def __eq__(self, other): """Check if the node is equal to another value. - Parameters - ---------- - other: The value to compare the node to. + Args: + other: The value to compare the node to. + + Returns: + bool: True if the values are equal, False otherwise. - Note - ---------- - __eq__ restricts return type to be bool; otherwise, it will create issues (for example, the "in" operator will not work). + Notes: + __eq__ restricts return type to be bool; otherwise, it will create issues (for example, the "in" operator will not work). """ # import opto.trace.operators as ops # return ops.eq(self, node(other)) @@ -1671,14 +1810,15 @@ def __eq__(self, other): def eq(self, other): """Check if the node is equal to another value. - Parameters - ---------- - other: The value to compare the node to. + Args: + other: The value to compare the node to. - Note - ---------- - If a logic operator is used in an if-statement, it will return a boolean value. - Otherwise, it will return a MessageNode. + Returns: + Node: A node containing the comparison result. + + Notes: + If a logic operator is used in an if-statement, it will return a boolean value. + Otherwise, it will return a MessageNode. """ import opto.trace.operators as ops return ops.eq(self, node(other)) @@ -1686,28 +1826,35 @@ def eq(self, other): def neq(self, other): """Check if the node is not equal to another value. - Parameters - ---------- - other: The value to compare the node to. + Args: + other: The value to compare the node to. + + Returns: + Node: A node containing the comparison result. - Note - ---------- - If a logic operator is used in an if-statement, it will return a boolean value. - Otherwise, it will return a MessageNode. + Notes: + If a logic operator is used in an if-statement, it will return a boolean value. + Otherwise, it will return a MessageNode. """ import opto.trace.operators as ops return ops.neq(self, node(other)) def __hash__(self): - """Return the hash value of the node.""" + """Return the hash value of the node. + + Returns: + int: The hash value of the node. + """ return super().__hash__() def __bool__(self): """Return the boolean value of the node. - Note - ---------- - The access to the `_data` attribute happening in this method is not traced. + Returns: + bool: The boolean value of the node's data. + + Notes: + The access to the `_data` attribute happening in this method is not traced. """ # not tracing this conversion return bool(self._data) @@ -1824,14 +1971,14 @@ def append(self, *args, **kwargs): class ParameterNode(Node[T]): # This is a shorthand of a trainable Node. def __init__( - self, - value, - *, - name=None, - trainable=True, - description="[ParameterNode] This is a ParameterNode in a computational graph.", - constraint=None, - info=None, + self, + value, + *, + name=None, + trainable=True, + description="[ParameterNode] This is a ParameterNode in a computational graph.", + constraint=None, + info=None, ) -> None: if description is None or description == "": description = "[ParameterNode] This is a ParameterNode in a computational graph." @@ -1851,28 +1998,41 @@ def __str__(self) -> str: class MessageNode(Node[T]): - """Output of an operator. - - description: a string to describe the operator it begins with - [operator_name] and then describes the operator. When referring to - inputs use the keys in args (if args is a dict), or the names of the - nodes in args (if args is a list). Here're some examples: - - MessageNode(node_a, inputs=[node_a], description="[identity] This is an identity operator.") - MessageNode(copy_node_a, inputs=[node_a], description="[copy] This is a copy operator.") - MesssageNode(1, inputs={'a':node_a, 'b':node_b}, description="[Add] This is an add operator of a and b.") + """A node representing the output of an operator. + + The description string should begin with [operator_name] followed by details about the operator. + When referring to inputs in the description, use either: + - The keys in args (if args is a dict) + - The names of the nodes in args (if args is a list) + + Examples: + >>> MessageNode(node_a, inputs=[node_a], + >>> description="[identity] This is an identity operator.") + >>> MessageNode(copy_node_a, inputs=[node_a], + >>> description="[copy] This is a copy operator.") + >>> MessageNode(1, inputs={'a':node_a, 'b':node_b}, + >>> description="[Add] This is an add operator of a and b.") + + Attributes: + value: The output value of the operator + inputs (Union[List[Node], Dict[str, Node]]): Input nodes to the operator + description (str): Description string starting with [operator_name] + constraint: Optional constraints on the output + name (str, optional): Name of the node + info (optional): Additional operator information """ + # TODO document what needs to go into info def __init__( - self, - value, - *, - inputs: Union[List[Node], Dict[str, Node]], # extra - description: str, - constraint=None, - name=None, - info=None, + self, + value, + *, + inputs: Union[List[Node], Dict[str, Node]], # extra + description: str, + constraint=None, + name=None, + info=None, ) -> None: super().__init__(value, name=name, description=description, constraint=constraint, info=info) @@ -1892,10 +2052,9 @@ def __init__( self._add_parent(v) self._add_dependencies(v) # Initializes the dependencies on parameter and expandable nodes - if len(self.hidden_dependencies)>0: + if len(self.hidden_dependencies) > 0: self._dependencies['expandable'].add(self) - @property def inputs(self): return copy.copy(self._inputs) @@ -1922,9 +2081,10 @@ def hidden_dependencies(self): # this needs to be recursive output = self.info['output'] if isinstance(self.info, dict) and \ - isinstance(output, Node) and all(isinstance(i, Node) for i in inputs): # traceable code + isinstance(output, Node) and all(isinstance(i, Node) for i in inputs): # traceable code # The inner function is traceable. - diff = diff | (output.parameter_dependencies - self.parameter_dependencies) # add extra parameters explicitly used in the inner function + diff = diff | ( + output.parameter_dependencies - self.parameter_dependencies) # add extra parameters explicitly used in the inner function extra_expandable = output.expandable_dependencies - self.expandable_dependencies for n in extra_expandable: # add extra hidden dependencies diff = diff | n.hidden_dependencies @@ -1941,14 +2101,14 @@ class ExceptionNode(MessageNode[T]): """Node containing the exception message.""" def __init__( - self, - value: Exception, - *, - inputs: Union[List[Node], Dict[str, Node]], - description: str = "[ExceptionNode] This is node containing the error of execution.", - constraint=None, - name=None, - info=None, + self, + value: Exception, + *, + inputs: Union[List[Node], Dict[str, Node]], + description: str = "[ExceptionNode] This is node containing the error of execution.", + constraint=None, + name=None, + info=None, ) -> None: e = value error_type = re.search(r"", str(type(e))).group(1) @@ -1960,7 +2120,7 @@ def create_feedback(self, style='simple'): assert style in ('simple', 'full') feedback = self._data if style in ('line', 'full'): - if type(self.info)==dict and self.info.get('error_comment') is not None: + if type(self.info) == dict and self.info.get('error_comment') is not None: feedback = self.info['error_comment'] return feedback diff --git a/opto/utils/__init__.py b/opto/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/opto/utils/llm.py b/opto/utils/llm.py new file mode 100644 index 00000000..e181f6be --- /dev/null +++ b/opto/utils/llm.py @@ -0,0 +1,103 @@ +from typing import List, Tuple, Dict, Any, Callable, Union +import time +import autogen # We import autogen here to avoid the need of installing autogen + +class AbstractModel: + """ + A minimal abstraction of a model api that refreshes the model every + reset_freq seconds (this is useful for long-running models that may require + refreshing certificates or memory management). + """ + def __init__(self, factory: Callable, reset_freq: Union[int, None] = None) -> None: + """ + Args: + factory: A function that takes no arguments and returns a model that is callable. + reset_freq: The number of seconds after which the model should be + refreshed. If None, the model is never refreshed. + """ + self.factory = factory + self._model = self.factory() + self.reset_freq = reset_freq + self._init_time = time.time() + + @property + def model(self): + # Overwrite this when subclassing + return self._model + + # This is the main API + def __call__(self, *args, **kwargs) -> Any: + """ The call function handles refreshing the model if needed. """ + if self.reset_freq is not None and time.time() - self._init_time > self.reset_freq: + self._model = self.factory() + self._init_time = time.time() + return self.model(*args, **kwargs) + + def __getstate__(self): + state = self.__dict__.copy() + state['_model'] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self._model = self.factory() + +class AutoGenLLM(AbstractModel): + """ This is the main class Trace uses to interact with the model. It is a wrapper around autogen's OpenAIWrapper. For using models not supported by autogen, subclass AutoGenLLM and override the `_factory` and `create` method. """ + + def __init__(self, config_list: List = None, filter_dict: Dict = None, reset_freq: Union[int, None] = None) -> None: + if config_list is None: + config_list = autogen.config_list_from_json("OAI_CONFIG_LIST") + if filter_dict is not None: + config_list = autogen.filter_config_list(config_list, filter_dict) + + factory = lambda *args, **kwargs : self._factory(config_list) + super().__init__(factory, reset_freq) + + @classmethod + def _factory(cls, config_list): + return autogen.OpenAIWrapper(config_list=config_list) + + @property + def model(self): + return lambda *args, **kwargs : self.create(*args, **kwargs) + + # This is main API. We use the API of autogen's OpenAIWrapper + def create(self, **config: Any) -> autogen.ModelClient.ModelClientResponseProtocol: + """Make a completion for a given config using available clients. + Besides the kwargs allowed in openai's [or other] client, we allow the following additional kwargs. + The config in each client will be overridden by the config. + + Args: + - context (Dict | None): The context to instantiate the prompt or messages. Default to None. + It needs to contain keys that are used by the prompt template or the filter function. + E.g., `prompt="Complete the following sentence: {prefix}, context={"prefix": "Today I feel"}`. + The actual prompt will be: + "Complete the following sentence: Today I feel". + More examples can be found at [templating](/docs/Use-Cases/enhanced_inference#templating). + - cache (AbstractCache | None): A Cache object to use for response cache. Default to None. + Note that the cache argument overrides the legacy cache_seed argument: if this argument is provided, + then the cache_seed argument is ignored. If this argument is not provided or None, + then the cache_seed argument is used. + - agent (AbstractAgent | None): The object responsible for creating a completion if an agent. + - (Legacy) cache_seed (int | None) for using the DiskCache. Default to 41. + An integer cache_seed is useful when implementing "controlled randomness" for the completion. + None for no caching. + Note: this is a legacy argument. It is only used when the cache argument is not provided. + - filter_func (Callable | None): A function that takes in the context and the response + and returns a boolean to indicate whether the response is valid. E.g., + + ```python + def yes_or_no_filter(context, response): + return context.get("yes_or_no_choice", False) is False or any( + text in ["Yes.", "No."] for text in client.extract_text_or_completion_object(response) + ) + ``` + + - allow_format_str_template (bool | None): Whether to allow format string template in the config. Default to false. + - api_version (str | None): The api version. Default to None. E.g., "2024-02-01". + Raises: + - RuntimeError: If all declared custom model clients are not registered + - APIError: If any model client create call raises an APIError + """ + return self._model.create(**config) diff --git a/opto/version.py b/opto/version.py index b9bf8369..acf3be3e 100644 --- a/opto/version.py +++ b/opto/version.py @@ -1 +1 @@ -__version__ = "0.1.2.2" \ No newline at end of file +__version__ = "0.1.3" \ No newline at end of file diff --git a/tests/unit_tests/test_asyncio.py b/tests/unit_tests/test_asyncio.py new file mode 100644 index 00000000..041b2a43 --- /dev/null +++ b/tests/unit_tests/test_asyncio.py @@ -0,0 +1,100 @@ + +import asyncio +import time +from opto import trace + +@trace.bundle() +async def basic(a=0): + await asyncio.sleep(1) + return 'basic' + +async def main(): + # single task + a = trace.node('a') + st = time.time() + x = await basic(a) + ed = time.time() + print("Time taken: ", ed - st) + print(type(x), x) + assert type(x) == trace.nodes.MessageNode + assert x == 'basic' + assert a in x.parents + assert len(x.parents) == 1 + + +asyncio.run(main()) + + +async def main2(): + # multiple tasks + a = trace.node('a') + st = time.time() + x, y, z = await asyncio.gather(basic(a), basic(a), basic(a)) # run in parallel + ed = time.time() + print("Time taken: ", ed - st) + + assert type(x) == trace.nodes.MessageNode + assert x == 'basic' + assert a in x.parents + assert len(x.parents) == 1 + assert type(y) == trace.nodes.MessageNode + assert y == 'basic' + assert a in y.parents + assert len(y.parents) == 1 + assert type(z) == trace.nodes.MessageNode + assert z == 'basic' + assert a in z.parents + assert len(z.parents) == 1 + + +asyncio.run(main2()) + + +@trace.bundle() +async def error(a=0): + raise ValueError('error') + +async def main3(): + # error handling + a = trace.node('a') + st = time.time() + try: + x = await error(a) + except trace.ExecutionError as e: + print(e) + x = e + ed = time.time() + print("Time taken: ", ed - st) + print(type(x), 'developer message:', x) + assert isinstance(x, trace.ExecutionError) + x = x.exception_node + print(type(x), 'optimizer message:', x.data) + assert isinstance(x, trace.nodes.MessageNode) + assert a in x.parents + assert len(x.parents) == 1 + +asyncio.run(main3()) + +async def main4(): + # multiple error handling + a = trace.node('a') + b = trace.node('b') + c = trace.node('c') + st = time.time() + try: + x, y, z = await asyncio.gather(error(a), error(b), error(c)) # run in parallel + except trace.ExecutionError as e: + # print(e) + x = e # This will catch the first error + print(e.exception_node.parents) + ed = time.time() + print("Time taken: ", ed - st) + print(type(x), 'developer message:', x) + assert isinstance(x, trace.ExecutionError) + x = x.exception_node + print(type(x), 'optimizer message:', x.data) + assert isinstance(x, trace.nodes.MessageNode) + assert a in x.parents + assert len(x.parents) == 1 + +asyncio.run(main4()) \ No newline at end of file diff --git a/tests/unit_tests/test_copy.py b/tests/unit_tests/test_copy.py new file mode 100644 index 00000000..d9b7dcc2 --- /dev/null +++ b/tests/unit_tests/test_copy.py @@ -0,0 +1,31 @@ +from opto import trace +from opto.optimizers import OptoPrime +from opto.utils.llm import AutoGenLLM +import copy + + +x = trace.node('x') +copy.deepcopy(x) + + + +@trace.bundle(trainable=True) +def fun(x): + pass + +copy.deepcopy(fun.parameter) + + +x = trace.node('x', trainable=True) +copy.deepcopy(x) + + +try: + optimizer = OptoPrime([x]) + optimizer2 = copy.deepcopy(optimizer) + + llm = AutoGenLLM() + copy.deepcopy(llm) +except FileNotFoundError as e: + print(f'Error: {e}') + print('Omit the test.') \ No newline at end of file diff --git a/tests/unit_tests/test_llm.py b/tests/unit_tests/test_llm.py new file mode 100644 index 00000000..31f33ee3 --- /dev/null +++ b/tests/unit_tests/test_llm.py @@ -0,0 +1,25 @@ +from opto.utils.llm import AutoGenLLM +from opto.optimizers.utils import print_color + +try: + llm = AutoGenLLM() + system_prompt = 'You are a helpful assistant.' + user_prompt = "Hello world." + + + messages = [{"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}] + + output = llm(messages=messages) + # Alternatively, you can use the following code: + # output = llm.create(messages=messages) + + response = output.choices[0].message.content + + + print_color(f'System: {system_prompt}', 'red') + print_color(f'User: {user_prompt}', 'blue') + print_color(f'LLM: {response}', 'green') +except FileNotFoundError as e: + print_color(f'Error: {e}', 'red') + print_color('Omit the test.', 'yellow') \ No newline at end of file