Skip to content

Commit

Permalink
fix custom tracer get attr
Browse files Browse the repository at this point in the history
  • Loading branch information
fpshuang committed Jun 3, 2024
1 parent 90c7af1 commit dbdec7a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py']
_base_ = ['mmpretrain::resnet/resnet18_8xb32_in1k.py']

resnet = _base_.model
float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501
Expand All @@ -18,7 +18,7 @@
_scope_='mmrazor',
type='MMArchitectureQuant',
data_preprocessor=dict(
type='mmcls.ClsDataPreprocessor',
type='mmpretrain.ClsDataPreprocessor',
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
Expand All @@ -33,8 +33,8 @@
tracer=dict(
type='mmrazor.CustomTracer',
skipped_methods=[
'mmcls.models.heads.ClsHead._get_loss',
'mmcls.models.heads.ClsHead._get_predictions'
'mmpretrain.models.heads.ClsHead._get_loss',
'mmpretrain.models.heads.ClsHead._get_predictions'
])))

optim_wrapper = dict(
Expand Down
4 changes: 2 additions & 2 deletions mmrazor/models/task_modules/tracer/fx/custom_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ def _get_attrs(target, attrs):
for special_node in special_nodes:
if special_node in node.args or \
special_node in node.kwargs.values():
origin_module = getattr(model, special_node.target)
origin_module = _get_attrs(model, special_node.target)
setattr(module_dict[special_node.target], node.target,
getattr(origin_module, node.target))
_get_attrs(origin_module, node.target))

return module_dict

Expand Down

0 comments on commit dbdec7a

Please sign in to comment.