Training_script_refactor #54
Conversation
I don't think this is totally ready yet, but at this point I think it's probably worth it to (1) merge this into the
```python
gptconf = Llama2ModelArgs(**model_args)
model = Llama2Model(gptconf)
state_dict = checkpoint["model"]
# fix the keys of the state dictionary :(
```
Are we sure this still happens?
No idea tbqh, that's just something I copied and never checked.
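For reference: in nanoGPT, which this training code descends from, that fix strips the `_orig_mod.` prefix that `torch.compile` adds to saved parameter names, so it is still needed whenever the checkpoint was saved from a compiled model. A minimal sketch of that cleanup, assuming this code inherited the same situation (the checkpoint path is hypothetical):

```python
import torch

checkpoint = torch.load("out/ckpt.pt", map_location="cpu")  # hypothetical path
state_dict = checkpoint["model"]

# torch.compile wraps the module and prefixes every parameter name with
# "_orig_mod."; stripping the prefix makes the keys match an uncompiled model.
unwanted_prefix = "_orig_mod."
for key in list(state_dict.keys()):
    if key.startswith(unwanted_prefix):
        state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)

# The cleaned state dict can then be loaded into the model built above:
# model.load_state_dict(state_dict)
```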
```python
def set_iteration_params(config, train_ds, validation_ds) -> IterationParams:
    num_batches = len(train_ds) // config.batch_size
    num_steps = num_batches // config.gradient_accumulation_steps
    eval_iters = min(12, len(validation_ds) // config.batch_size)
```
what is the 12?
Something that was already there that I didn't want to change without understanding it first: https://github.com/delphi-suite/delphi/pull/31/files#diff-c113425bb7a4b6c38858b09ca918bc89bd243d139cdca30ab5dd386d9690935bR94
As far as I can tell this is a constant set by Karpathy so that evaluation uses at most 12 batches.
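If that reading is right, the constant just caps how many validation batches each evaluation pass consumes. A quick worked example with made-up sizes:

```python
batch_size = 64
validation_size = 1000  # hypothetical number of validation samples

# 1000 // 64 = 15 full batches are available, but the cap keeps eval cheap:
eval_iters = min(12, validation_size // batch_size)
assert eval_iters == 12  # each eval pass sees 12 * 64 = 768 samples

# With a small validation set, the floor division wins instead:
assert min(12, 500 // 64) == 7
```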
```python
def _default_indices(self):
    return list(range(len(self.batched_tokens)))

def shuffle(self, epoch: int):
```
I don't know if this is correct. `shuffle` alters the list in place; why are we changing the indices instead of the list?
Idempotency. I want `shuffle(x)` to result in the same state regardless of when it's called - so not dependent on the state of the shuffling before it's called. By making it a deterministic shuffle of the default indices we can guarantee that it doesn't matter when you call `shuffle(x)`, you'll still get the same result.

Technically we don't need this for reproducibility if we commit to always calling shuffle with the same arguments in the same sequence, but I prefer to have it work this way to minimize the amount of state that needs to be kept track of when debugging - I don't want to have to repeat every shuffle in a given sequence to reproduce a problem.
I like it that way
I just didn't see where you then apply the indices to the list, but if you do it somewhere I trust you.
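A minimal sketch of the idempotent shuffle described above, filling in the parts not shown in the diff (the class name and the epoch-seeded `random.Random` are assumptions about how the determinism is achieved):

```python
import random

class TokenBatches:
    def __init__(self, batched_tokens):
        self.batched_tokens = batched_tokens
        self.indices = self._default_indices()

    def _default_indices(self):
        return list(range(len(self.batched_tokens)))

    def shuffle(self, epoch: int):
        # Always restart from the default ordering and seed on the epoch,
        # so shuffle(x) produces the same permutation no matter how many
        # shuffles happened before it - idempotent and easy to reproduce.
        self.indices = self._default_indices()
        random.Random(epoch).shuffle(self.indices)

    def __getitem__(self, i):
        # The permutation is applied at access time, through the index list,
        # so batched_tokens itself is never reordered.
        return self.batched_tokens[self.indices[i]]
```

In a design like this, the indices get applied at lookup time rather than by mutating the underlying list.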
I left a couple of comments; the only noticeable thing is on the shuffle. The rest is minor points and then Mamba stuff that I should do.
Seems good for now, but in the future we should consider adding the dataset as a parameter to the config.
Could be both - a config param that defaults to a pre-defined constant.
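As a sketch of that compromise (the field and constant names are made up):

```python
from dataclasses import dataclass

# Hypothetical pre-defined constant for the current default dataset.
DEFAULT_DATASET = "delphi-suite/tinystories-tokenized"

@dataclass
class TrainingConfig:
    batch_size: int = 64
    gradient_accumulation_steps: int = 4
    # Configurable, but falls back to the constant when not overridden.
    dataset: str = DEFAULT_DATASET
```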
```python
def __init__(self, tokenized_docs, max_seq_len, device):
    self.device = device
    self.tokenized_docs = tokenized_docs
    self.doc_len = len(tokenized_docs[0]["tokens"])
```
"doc_len" = "document_length" = context length?
Document length, but I don't think that's context length. `max_seq_len` is model context length - or more accurately, how long training samples are.
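To illustrate the distinction with made-up numbers: each entry in `tokenized_docs` has a fixed token count (`doc_len`), while `max_seq_len` is how long each training sample is, so the two need not be equal (assuming simple non-overlapping chunking):

```python
# Three pre-tokenized documents, each the same length.
tokenized_docs = [
    {"tokens": [5, 17, 2, 9, 41, 8, 3, 0]},
    {"tokens": [7, 22, 1, 4, 19, 6, 2, 0]},
    {"tokens": [3, 11, 9, 8, 25, 7, 1, 0]},
]

doc_len = len(tokenized_docs[0]["tokens"])  # 8: tokens per document
max_seq_len = 4                             # length of each training sample

# One 8-token document yields two non-overlapping 4-token training samples.
assert doc_len // max_seq_len == 2
```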
```python
# setup eval callbacks
eval_callbacks = [save_checkpoint_if_needed]
```
Maybe I am missing something, but if `eval_callbacks` only includes saving the checkpoint, where does it log to wandb?
nvm, found it in line 44 lol
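The diff only shows checkpointing in the list; per the comment above, the wandb logging apparently lives elsewhere in the script, though in this kind of design it could equally be another callback. A sketch of the callback mechanism under that assumption (the callback signature and metric names are made up):

```python
import wandb

def log_to_wandb(eval_results: dict, step: int):
    # Push the eval metrics to Weights & Biases at the current step.
    wandb.log({"val/loss": eval_results["loss"]}, step=step)

def save_checkpoint_if_needed(eval_results: dict, step: int):
    ...  # checkpointing logic lives in the training script

eval_callbacks = [save_checkpoint_if_needed, log_to_wandb]

# Each evaluation pass then fires every callback in order:
# for callback in eval_callbacks:
#     callback(eval_results, step)
```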
Looks good to me. As discussed in Discord, I will start a training run tomorrow after the PR is merged to see whether the performance matches our current models.
This is a WIP refactor of the training script. It's not in a good state and there's a high probability I broke something while refactoring, but it has one redeeming quality: it runs. Does it run correctly and produce usable models? Excellent question, I have no idea yet.
From `src/`: `python delphi/train/training.py`