Skip to content

Commit

Permalink
correct load v3 (#125)
Browse files — browse the repository at this point in the history
  • Loading branch information
teslacool authored Jun 19, 2023
1 parent d3ff076 commit b1c89a2
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions scripts/translate_jax_params.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -57,7 +57,7 @@ class Param:
param: Union[torch.Tensor, List[torch.Tensor]]
param_type: ParamType = ParamType.Other
stacked: bool = False
swap: bool = False
swap: bool = False


def _process_translations_dict(d, top_layer=True):
Expand Down Expand Up @@ -101,6 +101,7 @@ def stacked(param_dict_list, out=None):
param=[param.param for param in v],
param_type=v[0].param_type,
stacked=True,
swap=v[0].swap
)

out[k] = stacked_param
Expand All @@ -122,7 +123,12 @@ def assign(translation_dict, orig_weights):
try:
weights = list(map(param_type.transformation, weights))
for p, w in zip(ref, weights):
p.copy_(w)
if param.swap:
index = p.shape[0]//2
p[:index].copy_(w[index:])
p[index:].copy_(w[:index])
else:
p.copy_(w)
except:
print(k)
print(ref[0].shape)
Expand Down

0 comments on commit b1c89a2

Please sign in to comment.