diff --git a/src/delphi/train/gigaconfig.py b/src/delphi/train/gigaconfig.py index bfb1e487..538b18f7 100644 --- a/src/delphi/train/gigaconfig.py +++ b/src/delphi/train/gigaconfig.py @@ -36,6 +36,7 @@ class GigaConfig: batch_size: int = ( 64 # if gradient_accumulation_steps > 1, this is the micro-batch size ) + # TODO: delete this, use doc size always max_seq_len: int = 256 vocab_size: int = 32000 # the Llama 2 tokenizer has 32K tokens # model @@ -59,6 +60,7 @@ class GigaConfig: min_lr: float = 0.0 # should be ~learning_rate/10 per Chinchill # reproducibility seed = 1337 + # TODO: seeds for batch ordering and weight initialization # debugging train_sample_limit: int = -1 # -1 implies no limit val_sample_limit: int = -1 diff --git a/src/delphi/train/tokenized_chunks_dataset.py b/src/delphi/train/tokenized_chunks_dataset.py index 87d13b8a..92bf67aa 100644 --- a/src/delphi/train/tokenized_chunks_dataset.py +++ b/src/delphi/train/tokenized_chunks_dataset.py @@ -8,9 +8,7 @@ class TokenizedChunksDataset(Dataset): def __init__(self, tokenized_docs, max_seq_len, device): self.device = device self.tokenized_docs = tokenized_docs - self.doc_len = len(tokenized_docs[0]["tokens"]) self.max_len = max_seq_len - self._total_tokens = self.doc_len * len(self.tokenized_docs) self.batched_tokens = ( torch.Tensor() ) # will be initialized in initialize_samples @@ -39,7 +37,7 @@ def shuffle(self, epoch: int): shuffle_list(self.indices, seed=epoch) def __len__(self): - return self._total_tokens // self.max_len + return len(self.batched_tokens) def get_sample_window(self, idx): return self.batched_tokens[idx % len(self.batched_tokens), :] diff --git a/src/delphi/train/training.py b/src/delphi/train/training.py index 2d6b6198..ad6dddb8 100644 --- a/src/delphi/train/training.py +++ b/src/delphi/train/training.py @@ -72,7 +72,3 @@ def run_training(config: GigaConfig): ) if breaknow: break - - -if __name__ == "__main__": - run_training(debug_config) diff --git a/src/delphi/train/utils.py b/src/delphi/train/utils.py index 53a7cb9a..086ab90e 100644 --- a/src/delphi/train/utils.py +++ b/src/delphi/train/utils.py @@ -23,6 +23,8 @@ def load_config(config_path): return json.load(file) +# TODO: make this configurable. Set from config if available +# def get_device() -> torch.device def get_device() -> str: # cuda if available; else mps if apple silicon; else cpu if torch.cuda.is_available(): @@ -211,6 +213,7 @@ def load_model_training_state(config: GigaConfig, device: str) -> ModelTrainingS print("Initializing a new model from scratch") model = initialize_model(**model_args) checkpoint = None + # TODO: resume from huggingface model elif config.init_from == "resume": print(f"Resuming training from {config.out_dir}") model_mid_train = resume_model(Path(config.out_dir), device, **model_args)