
Training_script_refactor #54

Merged: 49 commits into training-script from training_script_refactor, Mar 8, 2024

Conversation

@jaidhyani (Collaborator) commented Mar 6, 2024

This is a WIP refactor of the training script. It's not in a good state and there's a high probability I broke something while refactoring, but it has one redeeming quality: it runs. Does it run correctly and produce usable models? Excellent question, I have no idea yet.

From src/: python delphi/train/training.py

@jaidhyani marked this pull request as ready for review March 7, 2024 16:27
@jaidhyani (Collaborator, Author):

I don't think this is totally ready yet, but at this point I think it's probably worth it to (1) merge this into the training-script branch and (2) then merge training-script into main.

gptconf = Llama2ModelArgs(**model_args)
model = Llama2Model(gptconf)
state_dict = checkpoint["model"]
# fix the keys of the state dictionary :(
Contributor:

Are we sure this still happens?

@jaidhyani (Collaborator, Author):

No idea tbqh, that's just something I copied and never checked.
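For context, the key fix in question is likely the nanoGPT-style cleanup of the "_orig_mod." prefix that torch.compile adds to parameter names when a compiled model is saved. A minimal sketch, assuming that pattern; the helper name and checkpoint layout below are illustrative, not the actual delphi code:

def fix_state_dict_keys(state_dict: dict, prefix: str = "_orig_mod.") -> dict:
    # Strip the torch.compile wrapper prefix from checkpoint keys (assumed cause).
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

# Usage, mirroring the snippet above (checkpoint layout assumed):
# checkpoint = torch.load(checkpoint_path, map_location="cpu")
# model.load_state_dict(fix_state_dict_keys(checkpoint["model"]))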

def set_iteration_params(config, train_ds, validation_ds) -> IterationParams:
    num_batches = len(train_ds) // config.batch_size
    num_steps = num_batches // config.gradient_accumulation_steps
    eval_iters = min(12, len(validation_ds) // config.batch_size)
Contributor:

What is the 12?

@jaidhyani (Collaborator, Author):

Something that was already there that I didn't want to change without understanding it first: https://github.com/delphi-suite/delphi/pull/31/files#diff-c113425bb7a4b6c38858b09ca918bc89bd243d139cdca30ab5dd386d9690935bR94

Contributor:

As far as I can tell, this is a constant set by Karpathy to log at least every 12 batches.
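If the magic number ever needs to be made explicit, one option is to lift it into a named constant or config default. A hypothetical sketch; the names here are not the project's actual identifiers:

MAX_EVAL_ITERS = 12  # hypothetical name for the Karpathy-style cap discussed above

def get_eval_iters(validation_len: int, batch_size: int, max_eval_iters: int = MAX_EVAL_ITERS) -> int:
    # Evaluate on at most max_eval_iters batches, but never more batches
    # than the validation set actually contains.
    return min(max_eval_iters, validation_len // batch_size)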

def _default_indices(self):
    return list(range(len(self.batched_tokens)))

def shuffle(self, epoch: int):
Contributor:

I don't know if this is correct. Shuffle alters the list in place; why are we changing the indices instead of the list?

@jaidhyani (Collaborator, Author):

Idempotency. I want shuffle(x) to result in the same state regardless of when it's called - so not dependent on the state of the shuffling before it's called. By making it a deterministic shuffle of the default indices, we can guarantee that no matter when you call shuffle(x), you'll get the same result.

Technically we don't need this for reproducibility if we commit to always calling shuffle with the same arguments in the same sequence, but I prefer to have it work this way to minimize the amount of state that needs to be tracked when debugging - I don't want to have to repeat every shuffle in a given sequence to reproduce a problem.

Contributor:

I like it that way

Contributor:

I just didn't see where you then apply the indices to the list, but if you do it somewhere, I trust you.
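To make the pattern concrete: a minimal sketch of the epoch-deterministic shuffle described above, including where the indices get applied at read time. The class and attribute names are illustrative, not the exact delphi implementation.

import random

class ShuffledTokenBatches:
    # Illustrative only: shuffle(epoch) is a pure function of the epoch number.

    def __init__(self, batched_tokens):
        self.batched_tokens = batched_tokens
        self.indices = self._default_indices()

    def _default_indices(self):
        return list(range(len(self.batched_tokens)))

    def shuffle(self, epoch: int):
        # Always start from the default order and seed with the epoch, so the
        # result is the same no matter what shuffles happened before.
        indices = self._default_indices()
        random.Random(epoch).shuffle(indices)
        self.indices = indices

    def __getitem__(self, i):
        # The shuffled indices are applied here, when batches are read.
        return self.batched_tokens[self.indices[i]]

    def __len__(self):
        return len(self.batched_tokens)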

@SrGonao (Contributor) left a comment:

I left a couple of comments. The only notable thing is the shuffle; the rest is minor points, plus the Mamba stuff that I should do.

Contributor:

seems good for now, but in the future we should consider adding the dataset as a parameter to the config

@jaidhyani (Collaborator, Author):

Could be both - a config param that defaults to a pre-defined constant.
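A hypothetical sketch of that middle ground; the dataclass, field name, and default value are placeholders rather than the project's actual config:

from dataclasses import dataclass

DEFAULT_DATASET = "org/example-tokenized-dataset"  # placeholder, not a real dataset id

@dataclass
class TrainingConfig:
    batch_size: int = 64
    gradient_accumulation_steps: int = 4
    # Configurable, but falls back to the pre-defined constant when unset.
    dataset: str = DEFAULT_DATASET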

def __init__(self, tokenized_docs, max_seq_len, device):
    self.device = device
    self.tokenized_docs = tokenized_docs
    self.doc_len = len(tokenized_docs[0]["tokens"])
Contributor:

"doc_len" = "document_length" = context length?

@jaidhyani (Collaborator, Author):

Document length, but I don't think that's context length. max_seq_len is model context length - or more accurately, how long training samples are.
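To illustrate the distinction (purely hypothetically; this is not necessarily how the dataset class slices its data): doc_len is the length of each stored tokenized document, while max_seq_len is how long each training sample should be, so documents may get re-chunked into samples.

def chunk_docs(tokenized_docs: list[dict], max_seq_len: int) -> list[list[int]]:
    # Assumes every document has the same token length, as in the snippet above.
    doc_len = len(tokenized_docs[0]["tokens"])
    samples = []
    for doc in tokenized_docs:
        # Slice each fixed-length document into max_seq_len-sized training samples.
        for start in range(0, doc_len - max_seq_len + 1, max_seq_len):
            samples.append(doc["tokens"][start : start + max_seq_len])
    return samples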


# setup eval callbacks
eval_callbacks = [save_checkpoint_if_needed]
Contributor:

maybe I am missing something, but if eval_callbacks only includes saving the checkpoint, where does it log to wandb?

Contributor:

nvm, found it in line 44 lol
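For anyone skimming the thread, a hedged sketch of the callback pattern under discussion; the callback names and signatures are hypothetical, and (as noted above) the wandb logging in the actual script lives elsewhere:

from typing import Callable

# Hypothetical: each callback runs after an evaluation pass with the latest metrics.
EvalCallback = Callable[[int, dict], None]

def save_checkpoint_if_needed(step: int, metrics: dict) -> None:
    print(f"step {step}: would save a checkpoint here")

def log_to_wandb(step: int, metrics: dict) -> None:
    # In a real script this would be something like wandb.log(metrics, step=step).
    print(f"step {step}: would log {metrics} to wandb")

eval_callbacks: list[EvalCallback] = [save_checkpoint_if_needed, log_to_wandb]

# After an eval pass:
for callback in eval_callbacks:
    callback(100, {"val/loss": 1.23})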

@jannik-brinkmann (Contributor) left a comment:

Looks good to me. As discussed in Discord, I will start a training run tomorrow after the PR is merged to see whether the performance matches our current models.

@jaidhyani merged commit a00425b into training-script on Mar 8, 2024
1 check passed
@jaidhyani deleted the training_script_refactor branch March 20, 2024 00:40