Skip to content

Commit

Permalink
Create specformer.ipynb
Browse files Browse the repository at this point in the history
  • Loading branch information
lhparker1 authored Feb 28, 2024
1 parent 843a09f commit df08771
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions notebooks/specformer.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from fillm.run.model import *
from datasets import load_dataset

CACHE_DIR = '/mnt/ceph/users/lparker/datasets_astroclip'
dataset = load_dataset('/mnt/home/lparker/Documents/AstroFoundationModel/AstroCLIP/astroclip_datasets/legacy_survey.py', cache_dir=CACHE_DIR)
dataset.set_format(type='torch', columns=['spectrum'])

def load_model_from_ckpt(ckpt_path: str):
"""
Load a model from a checkpoint.
"""
if Path(ckpt_path).is_dir():
ckpt_path = Path(ckpt_path) / "ckpt.pt"

chkpt = torch.load(ckpt_path)
config = chkpt["config"]
state_dict = chkpt["model"]
model_name = config["model"]['kind']
model_keys = get_model_keys(model_name)

model_args = {k: config['model'][k] for k in model_keys}

model_ctr, config_cls = model_registry[model_name]
model_config = config_cls(**model_args)
model_ = model_ctr(model_config)
model_.load_state_dict(state_dict)

return {"model": model_, "config": config}

model_path = "/mnt/home/sgolkar/ceph/saves/fillm/run-seqformer-2708117"
out = load_model_from_ckpt(model_path)

config = out['config']
spec_model = out['model']

0 comments on commit df08771

Please sign in to comment.