diff --git a/modulus/models/sfno/sfnonet.py b/modulus/models/sfno/sfnonet.py index 7b0b75b9a5..5e272b1c99 100644 --- a/modulus/models/sfno/sfnonet.py +++ b/modulus/models/sfno/sfnonet.py @@ -731,10 +731,13 @@ def __init__( ) # self.pos_embed = nn.Parameter( torch.zeros(1, self.embed_dim, self.img_shape_eff[0], self.img_shape_eff[1]) ) self.pos_embed.is_shared_mp = ["matmul"] - trunc_normal_(self.pos_embed, std=0.02) self.apply(self._init_weights) + # doing weight init of pos_embed after other layers resolves a segfault + if isinstance(self.pos_embed, nn.Parameter): + trunc_normal_(self.pos_embed, std=0.02) + def _init_weights(self, m): """Helper routine for weight initialization""" if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):