ValueError when using peft on FSDPTrainer #90

Open
AragornHorse opened this issue Nov 5, 2024 · 0 comments
ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32

I tried running this code on two 80GB A100s and added a PEFT LoRA config in train.py:

from peft import LoraConfig, get_peft_model

peft_config = LoraConfig(
    r=config.lora.r,
    lora_alpha=config.lora.alpha,
    lora_dropout=config.lora.dropout,
)

policy = get_peft_model(policy, peft_config)

Training works fine with BasicTrainer, but with FSDPTrainer I get:

Traceback (most recent call last):
  File "/direct-preference-optimization/train.py", line 127, in main
    mp.spawn(worker_main, nprocs=world_size, args=(world_size, config, policy, reference_model), join=True)
  File "python3.9/site-packages/torch/multiprocessing/spawn.py", line 282, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
  File "python3.9/site-packages/torch/multiprocessing/spawn.py", line 238, in start_processes
    while not context.join():
  File "python3.9/site-packages/torch/multiprocessing/spawn.py", line 189, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "python3.9/site-packages/torch/multiprocessing/spawn.py", line 76, in _wrap
    fn(i, *args)
  File "direct-preference-optimization/train.py", line 43, in worker_main
    trainer = TrainerClass(policy, config, config.seed, config.local_run_dir, reference_model=reference_model, rank=rank, world_size=world_size)
  File "direct-preference-optimization/trainers.py", line 469, in __init__
    self.policy = FSDP(policy, **shared_fsdp_kwargs, mixed_precision=policy_mp_policy)
  File "python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 483, in __init__
    _auto_wrap(
  File "python3.9/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 102, in _auto_wrap
    _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
  File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  [Previous line repeated 2 more times]
  File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 562, in _recursive_wrap
    return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
  File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 491, in _wrap
    return wrapper_cls(module, **kwargs)
  File "python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 509, in __init__
    _init_param_handle_from_module(
  File "python3.9/site-packages/torch/distributed/fsdp/_init_utils.py", line 603, in _init_param_handle_from_module
    _init_param_handle_from_params(state, managed_params, fully_sharded_module)
  File "python3.9/site-packages/torch/distributed/fsdp/_init_utils.py", line 615, in _init_param_handle_from_params
    handle = FlatParamHandle(
  File "python3.9/site-packages/torch/distributed/fsdp/_flat_param.py", line 583, in __init__
    self._init_flat_param_and_metadata(
  File "python3.9/site-packages/torch/distributed/fsdp/_flat_param.py", line 633, in _init_flat_param_and_metadata
    ) = self._validate_tensors_to_flatten(params)
  File "python3.9/site-packages/torch/distributed/fsdp/_flat_param.py", line 771, in _validate_tensors_to_flatten
    raise ValueError(
ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32

I tried using bfloat16 in the LoRA modules, but other ValueErrors occur.
I also tried use_orig_params=True, but it doesn't help.
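For reference, here is a minimal standalone reproduction of the dtype mix (toy modules only, not the real PEFT internals), together with the kind of cast I attempted before FSDP wrapping:

```python
import torch
import torch.nn as nn
from collections import Counter

# Toy stand-in: a bfloat16 "base model" with a float32 adapter bolted on.
# (nn.Linear here is illustrative only -- not the real PEFT LoRA modules.)
base = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4)).to(torch.bfloat16)
base.add_module("lora_A", nn.Linear(4, 2))  # new params default to float32

# Both torch.bfloat16 and torch.float32 show up here -- exactly the
# mixed-dtype situation FSDP's FlatParamHandle refuses to flatten.
print(Counter(p.dtype for p in base.parameters()))

# The cast I tried, applied before wrapping with FSDP: force every
# parameter to a single dtype so the flattened tensors are uniform.
for p in base.parameters():
    p.data = p.data.to(torch.bfloat16)

assert len({p.dtype for p in base.parameters()}) == 1
```

In the real script this cast would run on `policy` after `get_peft_model` and before the `FSDP(...)` call in trainers.py, but as noted above, doing something similar on the LoRA modules led to other ValueErrors for me.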

How can I solve this?
