From 772a6d68b118a9e9f73ce7042ec13cc15ff6636d Mon Sep 17 00:00:00 2001
From: Zhihao Lin <36994684+LZHgrla@users.noreply.github.com>
Date: Mon, 9 Oct 2023 13:26:07 +0800
Subject: [PATCH] [Fix] Fix activation checkpointing bug (#159)

* fix lora checkpoint bug

* rename

* fix pre-commit
---
 xtuner/model/sft.py | 29 +++++++++++++++--------------
 1 file changed, 15 insertions(+), 14 deletions(-)

diff --git a/xtuner/model/sft.py b/xtuner/model/sft.py
index 55f21fe7e..66ba0340c 100644
--- a/xtuner/model/sft.py
+++ b/xtuner/model/sft.py
@@ -18,23 +18,14 @@ def __init__(self,
                  llm,
                  lora=None,
                  peft_model=None,
-                 use_gradient_checkpointing=True):
+                 use_activation_checkpointing=True):
         super().__init__()
         with LoadWoInit():
             self.llm = self._build_from_cfg_or_module(llm)
         self.llm.config.use_cache = False
         dispatch_modules(self.llm)
 
-        if isinstance(lora, dict) or isinstance(lora, Config) or isinstance(
-                lora, ConfigDict):
-            self.lora = BUILDER.build(lora)
-        else:
-            self.lora = lora
-        self.peft_model = peft_model
-        self.use_lora = lora is not None
-        if self.use_lora:
-            self._prepare_for_lora(peft_model, use_gradient_checkpointing)
-        elif use_gradient_checkpointing:
+        if use_activation_checkpointing:
             # For backward compatibility
             if hasattr(self.llm, 'enable_input_require_grads'):
                 self.llm.enable_input_require_grads()
@@ -49,13 +40,23 @@ def make_inputs_require_grad(module, input, output):
             # enable gradient checkpointing for memory efficiency
             self.llm.gradient_checkpointing_enable()
 
+        if isinstance(lora, dict) or isinstance(lora, Config) or isinstance(
+                lora, ConfigDict):
+            self.lora = BUILDER.build(lora)
+        else:
+            self.lora = lora
+        self.peft_model = peft_model
+        self.use_lora = lora is not None
+        if self.use_lora:
+            self._prepare_for_lora(peft_model, use_activation_checkpointing)
+
         self._is_init = True
 
     def _prepare_for_lora(self,
                           peft_model=None,
-                          use_gradient_checkpointing=True):
-        self.llm = prepare_model_for_kbit_training(self.llm,
-                                                   use_gradient_checkpointing)
+                          use_activation_checkpointing=True):
+        self.llm = prepare_model_for_kbit_training(
+            self.llm, use_activation_checkpointing)
         if self.lora.target_modules is None:
             modules = find_all_linear_names(self.llm)
             self.lora.target_modules = modules
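
Editor's note (not part of the upstream patch): the sketch below illustrates the ordering this change establishes, enabling activation checkpointing on the base model before attaching LoRA adapters, using the public transformers/peft APIs instead of xtuner's SupervisedFinetune. The model name and LoRA hyperparameters are placeholder assumptions.

# Minimal sketch, assuming `transformers` and `peft` are installed.
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM

llm = AutoModelForCausalLM.from_pretrained('facebook/opt-125m')  # placeholder model
llm.config.use_cache = False  # KV caching is disabled when checkpointing during training

# 1. Enable activation (gradient) checkpointing on the base model first,
#    mirroring the reordering in this patch.
if hasattr(llm, 'enable_input_require_grads'):
    llm.enable_input_require_grads()
llm.gradient_checkpointing_enable()

# 2. Only then prepare for k-bit training and wrap the model with LoRA.
llm = prepare_model_for_kbit_training(llm, use_gradient_checkpointing=True)
lora_cfg = LoraConfig(r=16, lora_alpha=16, target_modules=['q_proj', 'v_proj'])
llm = get_peft_model(llm, lora_cfg)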