Skip to content

Commit

Permalink
Merge pull request #92 from VikParuchuri/revert-72-pruning
Browse files (browse the repository at this point in the history)
Revert "Prune decoder languages before loading the model to save memory"
  • Loading branch information
VikParuchuri authored May 8, 2024
2 parents 0137c35 + 98c617c commit 27ccb3d
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions surya/model/recognition/model.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,10 +12,6 @@
def load_model(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE, langs: Optional[List[int]] = None):
config = VisionEncoderDecoderConfig.from_pretrained(checkpoint)

# Prune moe experts that are not needed before loading the model
if langs:
config.decoder.langs = {lang_iso : lang_int for lang_iso, lang_int in config.decoder.langs.items() if lang_int in langs}

decoder_config = vars(config.decoder)
decoder = MBartMoEConfig(**decoder_config)
config.decoder = decoder
Expand All @@ -33,6 +29,10 @@ def load_model(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT, device=settings
assert isinstance(model.decoder, MBartMoE)
assert isinstance(model.encoder, VariableDonutSwinModel)

# Prune moe experts that are not needed
if langs is not None:
model.decoder.prune_moe_experts(langs)

model = model.to(device)
model = model.eval()
print(f"Loading recognition model {checkpoint} on device {device} with dtype {dtype}")
Expand Down

0 comments on commit 27ccb3d

Please sign in to comment.