I tried running this code on two 80GB A100s and added PEFT's LoRA in train.py. It can train normally when using BasicTrainer; however, when using FSDPTrainer I get the traceback below.
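Roughly, the LoRA addition looks like this (a minimal sketch; the rank, alpha, and target_modules below are placeholders rather than the exact config):

from peft import LoraConfig, get_peft_model

# Placeholder LoRA config; the real target_modules depend on the base model.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],
)
policy = get_peft_model(policy, lora_config)
policy.print_trainable_parameters()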
Traceback (most recent call last):
File "/direct-preference-optimization/train.py", line 127, in main
mp.spawn(worker_main, nprocs=world_size, args=(world_size, config, policy, reference_model), join=True)
File "python3.9/site-packages/torch/multiprocessing/spawn.py", line 282, in spawn
return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
File "python3.9/site-packages/torch/multiprocessing/spawn.py", line 238, in start_processes
while not context.join():
File "python3.9/site-packages/torch/multiprocessing/spawn.py", line 189, in join
raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:
-- Process 1 terminated with the following error:
Traceback (most recent call last):
File "python3.9/site-packages/torch/multiprocessing/spawn.py", line 76, in _wrap
fn(i, *args)
File "direct-preference-optimization/train.py", line 43, in worker_main
trainer = TrainerClass(policy, config, config.seed, config.local_run_dir, reference_model=reference_model, rank=rank, world_size=world_size)
File "direct-preference-optimization/trainers.py", line 469, in __init__
self.policy = FSDP(policy, **shared_fsdp_kwargs, mixed_precision=policy_mp_policy)
File "python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 483, in __init__
_auto_wrap(
File "python3.9/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 102, in _auto_wrap
_recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type]
File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
wrapped_child, num_wrapped_params = _recursive_wrap(
File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
wrapped_child, num_wrapped_params = _recursive_wrap(
File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
wrapped_child, num_wrapped_params = _recursive_wrap(
[Previous line repeated 2 more times]
File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 562, in _recursive_wrap
return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 491, in _wrap
return wrapper_cls(module, **kwargs)
File "python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 509, in __init__
_init_param_handle_from_module(
File "python3.9/site-packages/torch/distributed/fsdp/_init_utils.py", line 603, in _init_param_handle_from_module
_init_param_handle_from_params(state, managed_params, fully_sharded_module)
File "python3.9/site-packages/torch/distributed/fsdp/_init_utils.py", line 615, in _init_param_handle_from_params
handle = FlatParamHandle(
File "python3.9/site-packages/torch/distributed/fsdp/_flat_param.py", line 583, in __init__
self._init_flat_param_and_metadata(
File "python3.9/site-packages/torch/distributed/fsdp/_flat_param.py", line 633, in _init_flat_param_and_metadata
) = self._validate_tensors_to_flatten(params)
File "python3.9/site-packages/torch/distributed/fsdp/_flat_param.py", line 771, in _validate_tensors_to_flatten
raise ValueError(
ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32
I tried casting the LoRA modules to bfloat16, but then other ValueErrors occur.
I also tried use_orig_params=True, but that doesn't work either.
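Concretely, those two attempts look roughly like this (a sketch; shared_fsdp_kwargs and policy_mp_policy are the existing variables in trainers.py):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Attempt 1: cast the LoRA adapter modules to bfloat16 so every parameter
# FSDP flattens has the same dtype (this raised other ValueErrors instead).
for name, module in policy.named_modules():
    if "lora_" in name:
        module.to(torch.bfloat16)

# Attempt 2: in FSDPTrainer.__init__ (trainers.py line 469), keep the
# original parameters instead of only the flattened ones (did not help).
self.policy = FSDP(policy, **shared_fsdp_kwargs,
                   mixed_precision=policy_mp_policy,
                   use_orig_params=True)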
How can I solve this?
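In case it helps narrow things down, a minimal sketch of what the check in _flat_param.py seems to require: every parameter inside one FSDP-wrapped unit must share a dtype, so forcing the whole policy to a single dtype before wrapping (and leaving bf16 compute to policy_mp_policy) would avoid this particular ValueError. Whether that is the right fix for LoRA here is a separate question:

import torch

# Sketch only: make every parameter dtype uniform (float32 here) before the
# model is handed to FSDP, so _validate_tensors_to_flatten no longer raises.
for param in policy.parameters():
    param.data = param.data.to(torch.float32)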