Skip to content

Commit

Permalink
store function recorder in a format of assign_name@index
Browse files Browse the repository at this point in the history
  • Loading branch information
Xinyu302 committed Sep 4, 2023
1 parent 9a6ff6f commit 4c5d27b
Showing 1 changed file with 54 additions and 16 deletions.
70 changes: 54 additions & 16 deletions mmengine/hooks/recorder_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import ast
import inspect
import logging
import os.path as osp
import pickle
import textwrap
import types
Expand All @@ -13,23 +14,34 @@
from . import Hook


def function_with_index(function, index):
return function + '@' + str(index)


class FunctionRecorderTransformer(ast.NodeTransformer):

def __init__(self, target):
def __init__(self, target, target_index):
super().__init__()
self._target = target
self._target_index = set(target_index)
self.count = 0

def visit_Assign(self, node):
if node.targets[0].id != self._target:
return node
self.count += 1
if self.count not in self._target_index:
return node
update_messagehub_node = ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id='message_hub', ctx=ast.Load()),
attr='update_info',
ctx=ast.Load()),
args=[
ast.Constant(value=node.targets[0].id),
ast.Constant(
value=function_with_index(node.targets[0].id,
self.count)),
ast.Name(id=node.targets[0].id, ctx=ast.Load())
],
keywords=[]))
Expand All @@ -48,10 +60,6 @@ def _get_tensor_key(target, attribute=None):
return target


# def get_node_name(func_name):
# return 'tmp_func_' + func_name


class FuncCallVisitor(ast.NodeTransformer):

def __init__(self, func_name):
Expand Down Expand Up @@ -98,10 +106,13 @@ def _get_target_attribute(self):
func_chain = self._target.split('.')
func_chain.append(self._attribute)
assert len(func_chain) >= 2
attr = ast.Attribute(value=ast.Name(id=func_chain[0], ctx=ast.Load()), attr=func_chain[1], ctx=ast.Load())
attr = ast.Attribute(
value=ast.Name(id=func_chain[0], ctx=ast.Load()),
attr=func_chain[1],
ctx=ast.Load())
for ele in func_chain[2:]:
attr = ast.Attribute(value=attr, attr=ele, ctx=ast.Load())
return attr
return attr

def visit_Assign(self, node):
self.function_visitor.visit(node)
Expand All @@ -112,7 +123,8 @@ def visit_Assign(self, node):
targets=[ast.Name(id=assign_node_name, ctx=ast.Store())],
value=assign_right_node)
if self._attribute:
assign_node_name = _get_tensor_key(self._target, self._attribute)
assign_node_name = _get_tensor_key(self._target,
self._attribute)
ast_arg2 = self._get_target_attribute()
else:
ast_arg2 = ast.Name(id=assign_node_name, ctx=ast.Load())
Expand Down Expand Up @@ -142,12 +154,13 @@ def rewrite(self, ast_tree):
@RECORDERS.register_module()
class FunctionRecorder(Recorder):

def __init__(self, target: str):
def __init__(self, target: str, index: list):
super().__init__(target)
self.index = index
self.visit_assign = self._get_transformer_class()

def _get_transformer_class(self):
return FunctionRecorderTransformer(self._target)
return FunctionRecorderTransformer(self._target, self.index)

def rewrite(self, ast_tree):
return self.visit_assign.visit(ast_tree)
Expand Down Expand Up @@ -184,21 +197,33 @@ def __init__(
self._recorders: Dict[str, Recorder] = {}
self.print_modification = print_modification
self.save_dir = save_dir # type: ignore
if filename_tmpl is None:
self.filename_tmpl = 'record_epoch_{}.pth'

if recorders is None or len(recorders) == 0:
raise ValueError('recorders not initialized')
for recorder in recorders:
target = recorder.get('target')
attribute = recorder.get('attribute')
tensor_key = _get_tensor_key(target, attribute)

if target is None:
print_log(
'`RecorderHook` cannot be initialized '
'because recorder has no target',
logger='current',
level=logging.WARNING)
tensor_key = _get_tensor_key(target, attribute)
self.tensor_dict[tensor_key] = list()
if recorder.get('type') == 'FunctionRecorder':
index = recorder.get('index')
if isinstance(index, list):
for i in index:
self.tensor_dict[function_with_index(target,
i)] = list()
elif isinstance(index, int):
self.tensor_dict[function_with_index(target,
index)] = list()
elif recorder.get('type') == 'AttributeRecorder':
self.tensor_dict[tensor_key] = list()
self._recorders[tensor_key] = RECORDERS.build(recorder)

def _modify_func(self, func):
Expand Down Expand Up @@ -268,8 +293,21 @@ def after_train_iter(self,
for key in self.tensor_dict.keys():
self.tensor_dict[key].append(self.message_hub.get_info(key))

def _save_record(self, step):
recorder_file_name = self.filename_tmpl.format(step)
path = osp.join(self.save_dir, recorder_file_name)
with open(path, 'wb') as f:
pickle.dump(self.tensor_dict, f)

def _init_tensor_dict(self):
for k in self.tensor_dict.keys():
self.tensor_dict[k] = list()

def after_train_epoch(self, runner) -> None:
step = runner.epoch + 1
runner.logger.info(f'Saving record at {runner.epoch + 1} epochs')
self._save_record(step)
self._init_tensor_dict()

def after_train(self, runner) -> None:
data = pickle.dumps(self.tensor_dict)
print(data)
# use self.save_dir to save data
runner.model.forward = self.origin_forward

0 comments on commit 4c5d27b

Please sign in to comment.