Skip to content

Commit

Permalink
Merge branch 'experimental' of https://github.com/microsoft/Trace int…
Browse files Browse the repository at this point in the history
…o experimental
  • Loading branch information
allenanie committed Nov 20, 2024
2 parents b7174f6 + 58103bb commit 4347bd5
Show file tree
Hide file tree
Showing 6 changed files with 243 additions and 91 deletions.
238 changes: 164 additions & 74 deletions opto/trace/bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class FunModule(Module):
description (str): a description of the operator; see the MessageNode for syntax.
_process_inputs (bool): if True, the input is extracted from the container of nodes; if False, the inputs are passed directly to the underlying function.
trainable (bool): if True, the block of code is treated as a variable in the optimization
traceable_code (bool): if True, the operator's code is traceable by Trace
catch_execution_error (bool): if True, the operator catches the exception raised during the execution of the operator and return ExecutionError.
allow_external_dependencies (bool): if True, the operator allows external dependencies to be used in the operator. Namely, not all nodes used to create the output are in the inputs. In this case, the extra dependencies are stored in the info dictionary with key 'extra_dependencies'.
overwrite_python_recursion (bool): if True, the operator allows the python recursion behavior of calling the decorated function to be overwritten. When True, applying bundle on a recursive function would be the same as calling the function directly. When False, Python's original recursion behavior of decorated functions is preserved.
Expand Down Expand Up @@ -149,7 +150,9 @@ def trainable(self):

@property
def fun(self, *args, **kwargs):
# This is called within trace_nodes context manager.
""" Return a callable function. Return the decorated function if the parameter is None. Otherwise, return the function defined by the parameter. When an exception happens while defining the function with the parameter, raise a trace.ExecutionError. """

# This function should be later called within trace_nodes context manager.
if self.parameter is None:
return self._fun
else:
Expand Down Expand Up @@ -198,17 +201,21 @@ def fun(self, *args, **kwargs):
def name(self):
return get_op_name(self.description)

def forward(self, *args, **kwargs):
"""
All nodes used in the operator fun are added to used_nodes during
the execution. If the output is not a Node, we wrap it as a
MessageNode, whose inputs are nodes in used_nodes.
"""

def _wrap_inputs(self, fun, args, kwargs):
""" Wrap the inputs to a function as nodes when they're not.
fun = self.fun # define the function only once
self.info['fun'] = fun
Args:
fun (callable): the function to be wrapped.
args (list): the positional arguments of the function.
kwargs (dict): the keyword arguments of the function.
Returns:
inputs (dict): the inputs dict to construct the MessageNode (constructed by args and kwargs).
args (list): the wrapped positional arguments.
kwargs (dict): the wrapped keyword arguments.
_args (list): the original positional arguments (including the default values).
_kwargs (dict): the original keyword arguments (including the default values).
"""
## Wrap the inputs as nodes

# add default into kwargs
Expand All @@ -234,7 +241,7 @@ def forward(self, *args, **kwargs):
_, varargs, varkw, _, _, _, _ = inspect.getfullargspec(fun)


# bind the node version of args and kwargs
# bind the node version of args and kwargs
ba = inspect.signature(fun).bind(*args, **kwargs)
spec = ba.arguments

Expand All @@ -253,7 +260,10 @@ def extract_param(n):
inputs[k] = extract_param(v)
assert all([isinstance(n, Node) for n in inputs.values()]), "All values in inputs must be nodes."

return inputs, args, kwargs, _args, _kwargs

def _get_tracer(self):
""" Get a tracer to overwrite the python recursion behavior of calling the decorated function. """

# Define a tracer to deal with recursive function calls
_bundled_func = None
Expand Down Expand Up @@ -282,67 +292,111 @@ def tracer(frame, event, arg = None):
if frame.f_code.co_name in frame.f_globals:
frame.f_globals[frame.f_code.co_name] = _bundled_func
return tracer
return tracer

def _construct_error_comment(self, e):
""" Construct the error comment on the source code and traceback. """
self.info['traceback'] = traceback.format_exc() # This is saved for user debugging
# Construct message to optimizer
error_class = e.__class__.__name__
detail = e.args[0]
cl, exc, tb = sys.exc_info()
assert tb is not None # we're in the except block, so tb should not be None
n_fun_calls = len(traceback.extract_tb(tb))
# Step through the traceback stack
comments = []
base_message = f'({error_class}) {detail}.'
for i, (f, ln) in enumerate(traceback.walk_tb(tb)):
if i>0: # ignore the first one, since that is the try statement above
error_message = base_message if i == n_fun_calls-1 else 'Error raised in function call. See below.'

if i==1 and self.parameter is not None: # this is the trainable function defined by exec, which needs special treatment. inspect.getsource doesn't work here.
comment = self.generate_comment(self.parameter._data, error_message, ln, 1)
comment_backup = self.generate_comment(self.parameter._data, base_message, ln, 1)
else:
try:
f_source, f_source_ln = self.get_source(f, bug_mode=True)
except OSError: # OSError: could not get source code
# we reach the compiled C level, so the previous level is actually the bottom
comments[-1] = comment_backup # replace the previous comment
break # exit the loop
comment = self.generate_comment(f_source, error_message, ln, f_source_ln)
comment_backup = self.generate_comment(f_source, base_message, ln, f_source_ln)
comments.append(comment)
commented_code = '\n\n'.join(comments)
self.info['error_comment'] = commented_code + f"\n{base_message}"
output = e
return output

def sync_call_fun(self, fun, *_args, **_kwargs):
    """Call the operator ``fun`` synchronously and return its output.

    When ``overwrite_python_recursion`` is set and the operator has no
    parameter, the tracer from ``self._get_tracer()`` is installed for the
    duration of the call. The tracer that was active on entry is always
    restored, including when ``fun`` raises.

    Args:
        fun (callable): the function to call.
        *_args: positional arguments passed to ``fun``.
        **_kwargs: keyword arguments passed to ``fun``.

    Returns:
        The output of ``fun``. If ``catch_execution_error`` is True and
        ``fun`` raises an ``Exception``, the value returned by
        ``self._construct_error_comment`` is returned instead.

    Raises:
        Whatever ``fun`` raises, when ``catch_execution_error`` is False.
    """
    oldtracer = sys.gettrace()
    try:
        if self.overwrite_python_recursion and self.parameter is None:  # Overwrite the python recursion behavior
            # Running a tracer would slow down the execution, so we only do this when necessary.
            sys.settrace(self._get_tracer())

        if not self.catch_execution_error:
            return fun(*_args, **_kwargs)

        try:
            return fun(*_args, **_kwargs)
        except Exception as e:
            return self._construct_error_comment(e)
    finally:
        # Restore the previous tracer even if fun raised; otherwise the
        # recursion tracer would stay installed for the rest of the thread.
        sys.settrace(oldtracer)

## Execute self.fun
with trace_nodes() as used_nodes:
# After exit, used_nodes contains the nodes whose data attribute is read in the operator fun.
async def async_call_fun(self, fun, *_args, **_kwargs):
    """Await the operator ``fun`` and return its output. Async counterpart of sync_call_fun.

    When ``overwrite_python_recursion`` is set and the operator has no
    parameter, the tracer from ``self._get_tracer()`` is installed before
    the call. The tracer that was active on entry is always restored,
    including when ``fun`` raises.

    NOTE(review): sys.settrace is per-thread, so the tracer presumably stays
    installed while other tasks run during the await -- confirm this is
    acceptable.

    Args:
        fun (callable): the coroutine function to call.
        *_args: positional arguments passed to ``fun``.
        **_kwargs: keyword arguments passed to ``fun``.

    Returns:
        The awaited output of ``fun``. If ``catch_execution_error`` is True
        and ``fun`` raises an ``Exception``, the value returned by
        ``self._construct_error_comment`` is returned instead.

    Raises:
        Whatever ``fun`` raises, when ``catch_execution_error`` is False.
    """
    oldtracer = sys.gettrace()
    try:
        if self.overwrite_python_recursion and self.parameter is None:  # Overwrite the python recursion behavior
            # Running a tracer would slow down the execution, so we only do this when necessary.
            sys.settrace(self._get_tracer())

        if not self.catch_execution_error:
            return await fun(*_args, **_kwargs)

        try:
            return await fun(*_args, **_kwargs)
        except Exception as e:
            return self._construct_error_comment(e)
    finally:
        # Restore the previous tracer even if fun raised.
        sys.settrace(oldtracer)

# Construct the inputs to call self.fun
if self._process_inputs:
if self.traceable_code:
_args, _kwargs = detach_inputs(args), detach_inputs(kwargs)
else:
_args, _kwargs = to_data(args), to_data(kwargs)
# else the inputs are passed directly to the function
# so we don't change _args and _kwargs

oldtracer = sys.gettrace()
if self.overwrite_python_recursion and self.parameter is None: # Overwrite the python recursion behavior
sys.settrace(tracer)
# add an except here
if self.catch_execution_error:
try:
output = fun(*_args, **_kwargs)
except Exception as e:
# Construct the error comment on the source code and traceback
self.info['traceback'] = traceback.format_exc() # This is saved for user debugging
# Construct message to optimizer
error_class = e.__class__.__name__
detail = e.args[0]
cl, exc, tb = sys.exc_info()
n_fun_calls = len(traceback.extract_tb(tb))
# Step through the traceback stack
comments = []
base_message = f'({error_class}) {detail}.'
for i, (f, ln) in enumerate(traceback.walk_tb(tb)):
if i>0: # ignore the first one, since that is the try statement above
error_message = base_message if i == n_fun_calls-1 else 'Error raised in function call. See below.'

if i==1 and self.parameter is not None: # this is the trainable function defined by exec, which needs special treatment. inspect.getsource doesn't work here.
comment = self.generate_comment(self.parameter._data, error_message, ln, 1)
comment_backup = self.generate_comment(self.parameter._data, base_message, ln, 1)
else:
try:
f_source, f_source_ln = self.get_source(f, bug_mode=True)
except OSError: # OSError: could not get source code
# we reach the compiled C level, so the previous level is actually the bottom
comments[-1] = comment_backup # replace the previous comment
break # exit the loop
comment = self.generate_comment(f_source, error_message, ln, f_source_ln)
comment_backup = self.generate_comment(f_source, base_message, ln, f_source_ln)
comments.append(comment)
commented_code = '\n\n'.join(comments)
self.info['error_comment'] = commented_code + f"\n{base_message}"
output = e
else:
output = fun(*_args, **_kwargs)
sys.settrace(oldtracer)
def preprocess_inputs(self, args, kwargs, _args, _kwargs):
    """Choose the arguments that are actually passed to the operator fun.

    NOTE This function must be called inside the used_nodes context manager.

    Args:
        args (list): the wrapped positional arguments.
        kwargs (dict): the wrapped keyword arguments.
        _args (list): the original positional arguments. This includes the default values.
        _kwargs (dict): the original keyword arguments. This includes the default values.

    Returns:
        tuple: ``(_args, _kwargs)`` to be passed as the input to the function.
    """
    if not self._process_inputs:
        # The inputs are passed directly to the function, so the original
        # _args and _kwargs are left unchanged.
        return _args, _kwargs

    # This is for handling hierarchical graph.
    if self.traceable_code:
        return detach_inputs(args), detach_inputs(kwargs)

    # NOTE Extract data from the nodes and pass them to the function. Reading
    # node.data ensures the inputs are treated as used nodes, which is why
    # this must run inside the used_nodes context manager.
    return to_data(args), to_data(kwargs)

def postprocess_output(self, output, fun, _args, _kwargs, used_nodes, inputs):
"""
Wrap the output as a MessageNode. Log the inputs and output of the function call.
Args:
output (Any): the output of the operator fun.
fun (callable): the operator fun.
_args (list): the original positional arguments. This includes the default values.
_kwargs (dict): the original keyword arguments. This includes the default values.
used_nodes (List[Node]): the nodes used in the operator fun.
inputs (Dict[str, Node]): the inputs of the operator fun.
"""

# logging inputs and output of the function call
# Log inputs and output of the function call
self.info["output"] = output
self.info['inputs']["args"] = _args
self.info['inputs']["kwargs"] = _kwargs
Expand All @@ -361,15 +415,51 @@ def tracer(frame, event, arg = None):
inputs = {} # We don't need to keep track of the inputs if we are not tracing.
# Wrap the output as a MessageNode or an ExceptionNode
nodes = self.wrap(output, inputs, external_dependencies)
return nodes

# If the output is a coroutine, we return a coroutine.
if nodes._data is not None and inspect.iscoroutine(nodes._data):
async def _run_coro():
nodes._data = await nodes._data
return nodes
return _run_coro()
def forward(self, *args, **kwargs):
    """Run the operator, dispatching to the sync or async implementation.

    The function is read from ``self.fun`` once and recorded in
    ``self.info['fun']``.

    Returns:
        The result of ``self.async_forward`` (a coroutine that returns a
        MessageNode) when the function is a coroutine function, otherwise
        the result of ``self.sync_forward`` (a MessageNode).
    """
    operator = self.fun  # Define the function (only once)
    self.info['fun'] = operator
    if not inspect.iscoroutinefunction(operator):
        return self.sync_forward(operator, *args, **kwargs)  # Return a MessageNode
    return self.async_forward(operator, *args, **kwargs)  # Return a coroutine that returns a MessageNode

def sync_forward(self, fun, *args, **kwargs):
    """
    Call the operator fun and return a MessageNode. Sync version.

    All nodes used in the operator fun are added to used_nodes during the
    execution. If the output is not a Node, it is wrapped as a MessageNode
    whose inputs are the nodes in used_nodes.
    """
    # Wrap the inputs as nodes
    node_inputs, node_args, node_kwargs, raw_args, raw_kwargs = self._wrap_inputs(fun, args, kwargs)

    ## Execute fun
    with trace_nodes() as used_nodes:
        # After exit, used_nodes contains the nodes whose data attribute is read in the operator fun.
        call_args, call_kwargs = self.preprocess_inputs(node_args, node_kwargs, raw_args, raw_kwargs)
        result = self.sync_call_fun(fun, *call_args, **call_kwargs)

    # Wrap the output as a MessageNode or an ExceptionNode
    return self.postprocess_output(result, fun, call_args, call_kwargs, used_nodes, node_inputs)

async def async_forward(self, fun, *args, **kwargs):
    """
    Call the operator fun and return a MessageNode. Async version.

    All nodes used in the operator fun are added to used_nodes during the
    execution. If the output is not a Node, it is wrapped as a MessageNode
    whose inputs are the nodes in used_nodes.
    """
    # Wrap the inputs as nodes
    node_inputs, node_args, node_kwargs, raw_args, raw_kwargs = self._wrap_inputs(fun, args, kwargs)

    ## Execute fun
    with trace_nodes() as used_nodes:
        # After exit, used_nodes contains the nodes whose data attribute is read in the operator fun.
        call_args, call_kwargs = self.preprocess_inputs(node_args, node_kwargs, raw_args, raw_kwargs)
        # use await to call the async function
        result = await self.async_call_fun(fun, *call_args, **call_kwargs)

    # Wrap the output as a MessageNode or an ExceptionNode
    return self.postprocess_output(result, fun, call_args, call_kwargs, used_nodes, node_inputs)

def wrap(self, output: Any, inputs: Union[List[Node], Dict[str, Node]], external_dependencies: List[Node]):
"""Wrap the output as a MessageNode of inputs as the parents."""
Expand Down
Empty file added opto/utils/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion opto/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.2.2"
__version__ = "0.1.3"
52 changes: 52 additions & 0 deletions tests/unit_tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ async def basic(a=0):
return 'basic'

async def main():
# single task
a = trace.node('a')
st = time.time()
x = await basic(a)
Expand All @@ -25,6 +26,7 @@ async def main():


async def main2():
# multiple tasks
a = trace.node('a')
st = time.time()
x, y, z = await asyncio.gather(basic(a), basic(a), basic(a)) # run in parallel
Expand All @@ -46,3 +48,53 @@ async def main2():


asyncio.run(main2())


# Bundled coroutine that always fails; used by the error-handling tests below.
@trace.bundle()
async def error(a=0):
    raise ValueError('error')

async def main3():
    """Error handling: awaiting a failing bundled coroutine raises trace.ExecutionError."""
    a = trace.node('a')
    start = time.time()
    try:
        x = await error(a)
    except trace.ExecutionError as exc:
        print(exc)
        x = exc
    elapsed = time.time() - start
    print("Time taken: ", elapsed)

    # Developer-facing view: the raised exception itself.
    print(type(x), 'developer message:', x)
    assert isinstance(x, trace.ExecutionError)

    # Optimizer-facing view: the node attached to the exception.
    x = x.exception_node
    print(type(x), 'optimizer message:', x.data)
    assert isinstance(x, trace.nodes.MessageNode)
    assert a in x.parents
    assert len(x.parents) == 1

asyncio.run(main3())

async def main4():
    """Multiple error handling: gather several failing bundled coroutines."""
    a = trace.node('a')
    b = trace.node('b')
    c = trace.node('c')
    start = time.time()
    try:
        x, y, z = await asyncio.gather(error(a), error(b), error(c))  # run in parallel
    except trace.ExecutionError as exc:
        x = exc  # This will catch the first error
        print(exc.exception_node.parents)
    elapsed = time.time() - start
    print("Time taken: ", elapsed)

    # Developer-facing view: the raised exception itself.
    print(type(x), 'developer message:', x)
    assert isinstance(x, trace.ExecutionError)

    # Optimizer-facing view: the node attached to the exception.
    x = x.exception_node
    print(type(x), 'optimizer message:', x.data)
    assert isinstance(x, trace.nodes.MessageNode)
    assert a in x.parents
    assert len(x.parents) == 1

asyncio.run(main4())
14 changes: 10 additions & 4 deletions tests/unit_tests/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,14 @@ def fun(x):

x = trace.node('x', trainable=True)
copy.deepcopy(x)
optimizer = OptoPrime([x])
optimizer2 = copy.deepcopy(optimizer)

llm = AutoGenLLM()
copy.deepcopy(llm)

try:
optimizer = OptoPrime([x])
optimizer2 = copy.deepcopy(optimizer)

llm = AutoGenLLM()
copy.deepcopy(llm)
except FileNotFoundError as e:
print(f'Error: {e}')
print('Omit the test.')
Loading

0 comments on commit 4347bd5

Please sign in to comment.