diff --git a/slapo/primitives/replace.py b/slapo/primitives/replace.py index 159fe617..f79c3fc7 100644 --- a/slapo/primitives/replace.py +++ b/slapo/primitives/replace.py @@ -280,3 +280,31 @@ def apply(sch, new_mod_or_func, target_ops=None, name=None, concrete_args=None): sch, sch.group, ) + + +@register_primitive() +class ReplaceAllPrimitive(Primitive): + """Replace all the specified submodules with the new module. + + Parameters + ---------- + sch : Schedule + The schedule with the module/function to be replaced. + target_mod_type : Type + A target nn.Module type to be replaced. + make_mod_fn : FunctionType + A function that takes the original module and generate a new module. + """ + + @staticmethod + def name(): + return "replace_all" + + @staticmethod + def apply(sch, target_mod_type, make_mod_fn): + module_names = dict(sch.mod.named_modules()).keys() + for name in module_names: + subsch = sch[name] + if isinstance(subsch.mod, target_mod_type): + new_mod = make_mod_fn(name, subsch.mod) + subsch.replace(new_mod) diff --git a/tests/test_replace.py b/tests/test_replace.py index d691f30f..91754092 100644 --- a/tests/test_replace.py +++ b/tests/test_replace.py @@ -33,6 +33,62 @@ def forward(self, x): assert isinstance(sch["activation"].mod, nn.GELU) +def test_replace_all_module(): + class SubMod(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(1024, 1024) + self.activation = nn.ReLU() + + def forward(self, x): + x = self.linear(x) + x = self.activation(x) + return x + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(1024, 1024) + self.act1 = nn.ReLU() + self.fc2 = nn.Linear(1024, 1024) + self.act2 = nn.ReLU() + self.submod = SubMod() + + def forward(self, x): + x = self.fc1(x) + x = self.act1(x) + x = self.fc2(x) + x = self.act2(x) + x = self.submod(x) + return x + + model = Model() + sch = slapo.create_schedule(model) + + def make_gelu(name, mod): + return nn.GELU() + + sch.replace_all(nn.ReLU, make_gelu) + assert isinstance(sch["act1"].mod, nn.GELU) + assert isinstance(sch["act2"].mod, nn.GELU) + assert isinstance(sch["submod.activation"].mod, nn.GELU) + + # test giving different shape of parameters + def make_linear(name, mod): + if name == "fc1": + in_feat, out_feat = 1024, 1025 + elif name == "fc2": + in_feat, out_feat = 1025, 1026 + else: + in_feat, out_feat = 1026, 1027 + return nn.Linear(in_feat, out_feat) + + sch.replace_all(nn.Linear, make_linear) + assert sch["fc1"].mod.out_features == 1025 + assert sch["fc2"].mod.out_features == 1026 + assert sch["submod.linear"].mod.out_features == 1027 + + def test_vertical_replacement(): class Model(nn.Module): def __init__(self):