From dbdec7a2d823637d153cfa6ba43f53e53e3d95b6 Mon Sep 17 00:00:00 2001 From: fpshuang Date: Mon, 3 Jun 2024 12:55:14 +0800 Subject: [PATCH] fix custom tracer get attr --- .../qat/base/qat_openvino_resnet18_10e_8xb32_in1k.py | 8 ++++---- mmrazor/models/task_modules/tracer/fx/custom_tracer.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/configs/quantization/qat/base/qat_openvino_resnet18_10e_8xb32_in1k.py b/configs/quantization/qat/base/qat_openvino_resnet18_10e_8xb32_in1k.py index 261af7abb..18ce35471 100644 --- a/configs/quantization/qat/base/qat_openvino_resnet18_10e_8xb32_in1k.py +++ b/configs/quantization/qat/base/qat_openvino_resnet18_10e_8xb32_in1k.py @@ -1,4 +1,4 @@ -_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] +_base_ = ['mmpretrain::resnet/resnet18_8xb32_in1k.py'] resnet = _base_.model float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 @@ -18,7 +18,7 @@ _scope_='mmrazor', type='MMArchitectureQuant', data_preprocessor=dict( - type='mmcls.ClsDataPreprocessor', + type='mmpretrain.ClsDataPreprocessor', num_classes=1000, # RGB format normalization parameters mean=[123.675, 116.28, 103.53], @@ -33,8 +33,8 @@ tracer=dict( type='mmrazor.CustomTracer', skipped_methods=[ - 'mmcls.models.heads.ClsHead._get_loss', - 'mmcls.models.heads.ClsHead._get_predictions' + 'mmpretrain.models.heads.ClsHead._get_loss', + 'mmpretrain.models.heads.ClsHead._get_predictions' ]))) optim_wrapper = dict( diff --git a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py index 68d5f0809..17e717cae 100644 --- a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py +++ b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py @@ -146,9 +146,9 @@ def _get_attrs(target, attrs): for special_node in special_nodes: if special_node in node.args or \ special_node in node.kwargs.values(): - origin_module = getattr(model, special_node.target) + origin_module = _get_attrs(model, special_node.target) setattr(module_dict[special_node.target], node.target, - getattr(origin_module, node.target)) + _get_attrs(origin_module, node.target)) return module_dict