fix: re-enable training multi-task model with shared weights
percevalw committed Nov 13, 2024
1 parent 696a410 commit f41456d
Showing 3 changed files with 73 additions and 29 deletions.
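The situation this commit fixes, as a self-contained sketch: in a multi-task pipeline, several trainable components can share the same underlying weights (for instance one transformer feeding both an NER head and a qualifier head), so a single parameter may belong to more than one pipe. The snippet below is a generic PyTorch illustration of that setup, not edsnlp code (the module names are made up); it shows why "freeze everything outside the current phase" must be decided per parameter rather than per pipe.

import torch
import torch.nn as nn

# Two task heads sharing one encoder: the encoder's parameters belong
# to both "pipes" at once.
encoder = nn.Linear(16, 8)                           # shared weights
ner_head = nn.Sequential(encoder, nn.Linear(8, 4))
qlf_head = nn.Sequential(encoder, nn.Linear(8, 2))

phase_pipes = [ner_head]                             # phase 1 trains NER only
trained_params = {p for pipe in phase_pipes for p in pipe.parameters()}
for p in {*ner_head.parameters(), *qlf_head.parameters()}:
    p.requires_grad_(p in trained_params)

# The shared encoder stays trainable because it is reachable from a
# trained pipe, even though qlf_head as a whole is left out of the phase.
assert all(p.requires_grad for p in encoder.parameters())

The trainer change below applies the same idea by intersecting the optimizer's parameters with the parameters of the pipes trained in the current phase.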
74 changes: 48 additions & 26 deletions edsnlp/training/trainer.py
@@ -408,13 +408,14 @@ def train(
+ "".join(f"\n - {i + 1}: {', '.join(n)}" for i, n in enumerate(phases))
)

optim_base = optimizer
all_params = set(nlp.parameters())
if optim_base is None:
optim = optimizer
del optimizer
if optim is None:
warnings.warn(
"No optimizer provided, using default optimizer with default " "parameters"
)
optimizer = default_optim(
optim = default_optim(
[nlp.get_pipe(name) for name in trainable_pipe_names],
max_steps=max_steps,
**{
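For context, the branch above is the usual warn-and-default fallback when the caller passes no optimizer. A minimal standalone sketch of the pattern (AdamW and this signature are illustrative assumptions, not edsnlp's actual default_optim):

import warnings

import torch

def get_optim(optimizer, params, lr=1e-3):
    # Fallback pattern: warn, then build a default optimizer.
    if optimizer is None:
        warnings.warn("No optimizer provided, using default optimizer")
        optimizer = torch.optim.AdamW(params, lr=lr)
    return optimizer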
@@ -435,56 +436,66 @@ def train(

     for phase_i, pipe_names in enumerate(phases):
         trained_pipes = [nlp.get_pipe(name) for name in pipe_names]
+        trained_pipes_params = {p for pipe in trained_pipes for p in pipe.parameters()}
+        phase_training_data = [
+            td
+            for td in train_data
+            if td.pipe_names is None or set(td.pipe_names) & set(pipe_names)
+        ]
+
         with nlp.select_pipes(disable=trainable_pipe_names - set(pipe_names)):
             accelerator.print(f"Phase {phase_i + 1}: training {', '.join(pipe_names)}")
             set_seed(seed)
 
-            grad_params = {p for g in optimizer.param_groups for p in g["params"]}
+            optim_params = {p for g in optim.param_groups for p in g["params"]}
+            grad_params = set()
+            for param in all_params:
+                has_grad_param = param in optim_params and param in trained_pipes_params
+                if has_grad_param:
+                    grad_params.add(param)
+                param.requires_grad_(has_grad_param)
 
             accelerator.print(
                 "Optimizing groups:"
                 + "".join(
-                    f"\n - {g.get('selector', '*') + ':' if 'selector' in g else ''} "
-                    f"{len(g['params'])} weight tensors "
-                    f"({sum(p.numel() for p in g['params']):,} parameters)"
-                    for g in optimizer.param_groups
+                    "\n - {} {} weight tensors ({:,} parameters)".format(
+                        g.get("selector", "*") + ":" if "selector" in g else "",
+                        len([p for p in g["params"] if p in grad_params]),
+                        sum([p.numel() for p in g["params"] if p in grad_params]),
+                    )
+                    for g in optim.param_groups
                 )
             )
             accelerator.print(
                 f"Keeping frozen {len(all_params - grad_params):} weight tensors "
                 f"({sum(p.numel() for p in all_params - grad_params):,} parameters)"
             )
-            for param in all_params:
-                param.requires_grad_(param in grad_params)
 
             nlp.train(True)
 
-            cumulated_data = defaultdict(lambda: 0.0, count=0)
-
             iterator = iter(
                 zip(
                     *(
                         td(nlp, device).set_processing(
                             num_cpu_workers=num_workers,
                             process_start_method="spawn",
                         )
-                        for td in train_data
-                        if td.pipe_names is None or set(td.pipe_names) & set(pipe_names)
+                        for td in phase_training_data
                     )
                 )
             )
-            (optimizer, *trained_pipes) = accelerator.prepare(optimizer, *trained_pipes)
-            if hasattr(optimizer, "initialize"):
-                optimizer.initialize()
+            if hasattr(optim, "initialize"):
+                optim.initialize()
+            (accel_optim, *trained_pipes) = accelerator.prepare(optim, *trained_pipes)
 
+            cumulated_data = defaultdict(lambda: 0.0, count=0)
             all_metrics = []
             set_seed(seed)
 
-            logger = (
+            with (
                 RichTablePrinter(LOGGER_FIELDS, auto_refresh=False)
                 if is_main_process
                 else nullcontext()
-            )
-            with logger:
+            ) as logger:
                 # Training loop
                 for step in trange(
                     max_steps + 1,
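A minimal sketch of the selection rule introduced above, outside the trainer (plain PyTorch, made-up module names): a parameter receives gradients only if the optimizer knows about it and it belongs to a pipe trained in the current phase, so weights shared between pipes stay trainable in every phase that touches one of their owners.

import torch
import torch.nn as nn

shared = nn.Linear(4, 4)                  # e.g. a common transformer
ner_out = nn.Linear(4, 2)
qlf_out = nn.Linear(4, 3)
ner_pipe = nn.Sequential(shared, ner_out)
qlf_pipe = nn.Sequential(shared, qlf_out)

# The optimizer covers the shared encoder and the NER head only.
optim = torch.optim.AdamW([
    {"params": list(shared.parameters()), "lr": 5e-5},
    {"params": list(ner_out.parameters()), "lr": 1e-3},
])

all_params = {*ner_pipe.parameters(), *qlf_pipe.parameters()}
optim_params = {p for g in optim.param_groups for p in g["params"]}
trained_pipes_params = set(ner_pipe.parameters())   # this phase trains NER

grad_params = set()
for param in all_params:
    has_grad = param in optim_params and param in trained_pipes_params
    if has_grad:
        grad_params.add(param)
    param.requires_grad_(has_grad)

# shared + ner_out will train; qlf_out stays frozen until its own phase.
assert grad_params == optim_params

Note also the reordering at the end of the hunk: optim.initialize() (when available) now runs on the raw optimizer before accelerator.prepare, and the prepared wrapper gets its own name, accel_optim, which the loop below uses for zero_grad and step.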
@@ -500,7 +511,7 @@ def train(
                         all_metrics.append(
                             {
                                 "step": step,
-                                "lr": optimizer.param_groups[0]["lr"],
+                                "lr": accel_optim.param_groups[0]["lr"],
                                 **cumulated_data,
                                 **scores,
                             }
@@ -515,9 +526,18 @@ def train(
                     if step == max_steps:
                         break
 
-                    optimizer.zero_grad()
+                    accel_optim.zero_grad()
 
-                    batches = list(flatten(list(next(iterator))))
+                    batches = list(next(iterator))
+                    batches_pipe_names = list(
+                        flatten(
+                            [
+                                [td.pipe_names or pipe_names] * len(b)
+                                for td, b in zip(phase_training_data, batches)
+                            ]
+                        )
+                    )
+                    batches = list(flatten(batches))
 
                     # Synchronize stats between sub-batches across workers
                     input_stats = {}
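Worked example of the pairing built above, with plain lists standing in for edsnlp batches: next(iterator) yields one sub-batch list per training dataset, and each flattened sub-batch gets tagged with the pipes its dataset is allowed to feed (td.pipe_names, defaulting to every pipe of the phase).

from itertools import chain

pipe_names = ["ner", "qual"]               # pipes of the current phase
dataset_pipe_names = [None, ["ner"]]       # td.pipe_names for two datasets
batches = [["doc_a", "doc_b"], ["doc_c"]]  # one sub-batch list per dataset

batches_pipe_names = list(chain.from_iterable(
    [names or pipe_names] * len(b)
    for names, b in zip(dataset_pipe_names, batches)
))
flat_batches = list(chain.from_iterable(batches))

print(list(zip(flat_batches, batches_pipe_names)))
# [('doc_a', ['ner', 'qual']), ('doc_b', ['ner', 'qual']), ('doc_c', ['ner'])]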
@@ -529,10 +549,12 @@ def train(
                        set_flat_stats(b, input_stats)
 
                    output_stats = defaultdict(lambda: 0.0)
-                    for batch in batches:
+                    for batch, batch_pipe_names in zip(batches, batches_pipe_names):
                        loss = torch.zeros((), device=accelerator.device)
                        with nlp.cache():
                            for name, pipe in zip(pipe_names, trained_pipes):
+                                if name not in batch_pipe_names:
+                                    continue
                                res = dict(pipe(batch[name]))
                                if "loss" in res:
                                    res["loss"] = res["loss"] * loss_scales.get(name, 1)
@@ -564,7 +586,7 @@ def train(

                     del input_stats, output_stats
                     accelerator.clip_grad_norm_(grad_params, max_grad_norm)
-                    optimizer.step()
+                    accel_optim.step()
 
                 del iterator

@@ -66,7 +66,12 @@ optimizer:
   groups:
     "^transformer": false
     ".*":
-      lr: 1e-3
+      lr:
+        "@schedules": linear
+        start_value: 1e-3
+        max_value: 2e-3
+        warmup_rate: 0.5
+        total_steps: ${ train.max_steps }
 
 # 📚 DATA
 train_data:
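For reference, the linear schedule configured above ramps the learning rate from start_value to max_value over the warmup span, then decays it. A small standalone Python sketch of that shape; this is my reading of the config, and the exact formula behind edsnlp's "@schedules": linear may differ:

def linear_schedule(step, total_steps, start_value, max_value, warmup_rate):
    # Assumed semantics: linear warmup to max_value, then linear decay to 0.
    warmup_steps = warmup_rate * total_steps
    if step < warmup_steps:
        return start_value + (max_value - start_value) * step / warmup_steps
    return max_value * (total_steps - step) / (total_steps - warmup_steps)

for step in (0, 25, 50, 75, 100):          # e.g. train.max_steps = 100
    print(step, round(linear_schedule(step, 100, 1e-3, 2e-3, 0.5), 6))
# 0 0.001 | 25 0.0015 | 50 0.002 | 75 0.001 | 100 0.0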
21 changes: 19 additions & 2 deletions tests/training/test_train.py
@@ -84,9 +84,26 @@ def __call__(self, obj):
         return doc
 
 
-def test_ner_qualif_train(run_in_test_dir, tmp_path):
+def test_ner_qualif_train_diff_bert(run_in_test_dir, tmp_path):
     set_seed(42)
-    config = Config.from_disk("ner_qlf_config.yml")
+    config = Config.from_disk("ner_qlf_diff_bert_config.yml")
     shutil.rmtree(tmp_path, ignore_errors=True)
     kwargs = Config.resolve(config["train"], registry=registry, root=config)
     nlp = train(**kwargs, output_dir=tmp_path, cpu=True)
+    scorer = GenericScorer(**kwargs["scorer"])
+    val_data = kwargs["val_data"]
+    last_scores = scorer(nlp, val_data)
+
+    # Check empty doc
+    nlp("")
+
+    assert last_scores["ner"]["micro"]["f"] > 0.4
+    assert last_scores["qual"]["micro"]["f"] > 0.4
+
+
+def test_ner_qualif_train_same_bert(run_in_test_dir, tmp_path):
+    set_seed(42)
+    config = Config.from_disk("ner_qlf_same_bert_config.yml")
+    shutil.rmtree(tmp_path, ignore_errors=True)
+    kwargs = Config.resolve(config["train"], registry=registry, root=config)
+    nlp = train(**kwargs, output_dir=tmp_path, cpu=True)
