diff --git a/ragatouille/models/colbert.py b/ragatouille/models/colbert.py index 95d0dba..e482de7 100644 --- a/ragatouille/models/colbert.py +++ b/ragatouille/models/colbert.py @@ -336,6 +336,8 @@ def _batch_search(self, query: list[str], k: int): def train(self, data_dir, training_config: ColBERTConfig): training_config = ColBERTConfig.from_existing(self.config, training_config) training_config.nway = 2 + if training_config.nranks < 2: + training_config.avoid_fork_if_possible = True with Run().context(self.run_config): trainer = Trainer( triples=str(data_dir / "triples.train.colbert.jsonl"),