Skip to content

Commit

Permalink
[Fix] Fix activation checkpointing bug (#159)
Browse files: browse the repository at this point in the history
* fix lora checkpoint bug

* rename `use_gradient_checkpointing` to `use_activation_checkpointing`

* fix pre-commit
  • Loading branch information
LZHgrla authored Oct 9, 2023
1 parent fcf5ffd commit d118ac4
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions xtuner/model/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,14 @@ def __init__(self,
llm,
lora=None,
peft_model=None,
use_gradient_checkpointing=True):
use_activation_checkpointing=True):
super().__init__()
with LoadWoInit():
self.llm = self._build_from_cfg_or_module(llm)
self.llm.config.use_cache = False
dispatch_modules(self.llm)

if isinstance(lora, dict) or isinstance(lora, Config) or isinstance(
lora, ConfigDict):
self.lora = BUILDER.build(lora)
else:
self.lora = lora
self.peft_model = peft_model
self.use_lora = lora is not None
if self.use_lora:
self._prepare_for_lora(peft_model, use_gradient_checkpointing)
elif use_gradient_checkpointing:
if use_activation_checkpointing:
# For backward compatibility
if hasattr(self.llm, 'enable_input_require_grads'):
self.llm.enable_input_require_grads()
Expand All @@ -49,13 +40,23 @@ def make_inputs_require_grad(module, input, output):
# enable gradient checkpointing for memory efficiency
self.llm.gradient_checkpointing_enable()

if isinstance(lora, dict) or isinstance(lora, Config) or isinstance(
lora, ConfigDict):
self.lora = BUILDER.build(lora)
else:
self.lora = lora
self.peft_model = peft_model
self.use_lora = lora is not None
if self.use_lora:
self._prepare_for_lora(peft_model, use_activation_checkpointing)

self._is_init = True

def _prepare_for_lora(self,
peft_model=None,
use_gradient_checkpointing=True):
self.llm = prepare_model_for_kbit_training(self.llm,
use_gradient_checkpointing)
use_activation_checkpointing=True):
self.llm = prepare_model_for_kbit_training(
self.llm, use_activation_checkpointing)
if self.lora.target_modules is None:
modules = find_all_linear_names(self.llm)
self.lora.target_modules = modules
Expand Down

0 comments on commit d118ac4

Please sign in to comment.