Merge pull request #28 from eosphoros-ai/lora
update: Update the README document and optimize the code structure
Showing 18 changed files with 726 additions and 852 deletions.
The first hunk updates the `__init__.py` of the arguments package (the path is inferred from the imports below) to re-export the new argument groups; a usage sketch follows the hunk:

@@ -1,5 +1,11 @@
 from .data_args import DataArguments
+from .gen_args import GenerationArguments
+from .lora_args import LoraArguments
 from .model_args import ModelArguments
+from .quant_args import QuantArguments
 from .train_args import TrainingArguments

-__all__ = ['DataArguments', 'ModelArguments','TrainingArguments']
+__all__ = [
+    'DataArguments', 'GenerationArguments', 'ModelArguments',
+    'TrainingArguments', 'LoraArguments', 'QuantArguments'
+]
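A minimal sketch (not part of the diff) of how grouped argument dataclasses like these are typically consumed, assuming the `transformers` library's `HfArgumentParser` and assuming the package is importable as `arguments`:

```python
# Sketch only: the package name `arguments` and this wiring are assumptions,
# not taken from the diff.
from transformers import HfArgumentParser

from arguments import (DataArguments, GenerationArguments, LoraArguments,
                       ModelArguments, QuantArguments, TrainingArguments)

# HfArgumentParser turns each dataclass into a group of CLI flags.
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments,
                           LoraArguments, QuantArguments, GenerationArguments))
(model_args, data_args, training_args,
 lora_args, quant_args, gen_args) = parser.parse_args_into_dataclasses()
```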
The next file is new (35 lines added) and defines the generation hyperparameters; from the imports in `__init__.py` this is `gen_args.py`. The two-part help string for `max_new_tokens` was missing a separating space ("loopsif"), fixed here:

@@ -0,0 +1,35 @@
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional


@dataclass
class GenerationArguments:
    # For more hyperparameters check:
    # https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
    # Length arguments
    max_new_tokens: Optional[int] = field(
        default=256,
        metadata={"help": "Maximum number of new tokens to be generated in evaluation or prediction loops "
                          "if predict_with_generate is set."}
    )
    min_new_tokens: Optional[int] = field(
        default=None,
        metadata={"help": "Minimum number of new tokens to generate."}
    )

    # Generation strategy
    do_sample: Optional[bool] = field(default=False)
    num_beams: Optional[int] = field(default=1)
    num_beam_groups: Optional[int] = field(default=1)
    penalty_alpha: Optional[float] = field(default=None)
    use_cache: Optional[bool] = field(default=False)

    # Hyperparameters for logit manipulation
    temperature: Optional[float] = field(default=1.0)
    top_k: Optional[int] = field(default=50)
    top_p: Optional[float] = field(default=1.0)
    typical_p: Optional[float] = field(default=1.0)
    diversity_penalty: Optional[float] = field(default=0.0)
    repetition_penalty: Optional[float] = field(default=1.0)
    length_penalty: Optional[float] = field(default=1.0)
    no_repeat_ngram_size: Optional[int] = field(default=0)
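Every field above mirrors a `transformers.GenerationConfig` parameter, and the `asdict` import suggests the dataclass is meant to be unpacked into one. A hedged sketch of that conversion (an assumption, not shown in the diff):

```python
# Sketch only: the GenerationConfig conversion is inferred from the asdict
# import above, not taken from the diff.
from dataclasses import asdict

from transformers import GenerationConfig

gen_args = GenerationArguments(max_new_tokens=512, do_sample=True, top_p=0.9)
# Field names match GenerationConfig parameters, so the dataclass unpacks directly.
generation_config = GenerationConfig(**asdict(gen_args))
```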
`lora_args.py` is also new (16 lines) and groups the LoRA hyperparameters. The original comment on `max_memory_MB` was garbled ("80GB1") and contradicted the 8000 MB default; it is restated below:

@@ -0,0 +1,16 @@
from dataclasses import dataclass, field


@dataclass
class LoraArguments:
    # LoRA rank r: the number of columns of matrix A and rows of matrix B
    lora_r: int = field(default=64, metadata={'help': 'Lora R dimension.'})
    # Scaling factor
    lora_alpha: float = field(default=16, metadata={'help': 'Lora alpha.'})
    lora_dropout: float = field(default=0.0,
                                metadata={'help': 'Lora dropout.'})
    # Memory available on each GPU, in MB (an 80GB high-end A100 would be 80000)
    max_memory_MB: int = field(default=8000,
                               metadata={'help': 'Free memory per gpu.'})
    lora_weight_path: str = ''
    bias: str = 'none'
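A minimal sketch of how these fields could map onto a `peft.LoraConfig` (an illustration under the assumption that the `peft` library is used for the adapters; the mapping is not shown in the diff):

```python
# Sketch only: assumes the peft library; this mapping is illustrative.
from peft import LoraConfig

lora_args = LoraArguments()
lora_config = LoraConfig(
    r=lora_args.lora_r,               # rank of the A/B update matrices
    lora_alpha=lora_args.lora_alpha,  # scaling factor alpha
    lora_dropout=lora_args.lora_dropout,
    bias=lora_args.bias,
    task_type='CAUSAL_LM',            # assumption: a causal-LM fine-tune
)
```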
Finally, `quant_args.py` (34 new lines) collects the quantization settings and validates them in `__post_init__`. Two comments are repaired here: the garbled optimizer note ("options0") and "quadratic quantization", which should read "double quantization":

@@ -0,0 +1,34 @@
from dataclasses import dataclass, field


@dataclass
class QuantArguments:
    # Use 8-bit Adam; alternatives such as LION or Sophia exist, and DeepSpeed
    # also offers several 1-bit optimizers
    adam8bit: bool = field(default=False, metadata={'help': 'Use 8-bit adam.'})
    # Whether to use double (nested) quantization
    double_quant: bool = field(
        default=True,
        metadata={
            'help':
            'Compress the quantization statistics through double quantization.'
        })
    # Quantization data type: either fp4 or nf4
    quant_type: str = field(
        default='nf4',
        metadata={
            'help':
            'Quantization data type to use. Should be one of `fp4` or `nf4`.'
        })
    # Bit width to quantize to; the default is 4.
    bits: int = field(default=4, metadata={'help': 'How many bits to use.'})

    def __post_init__(self):
        if self.bits is not None:
            assert self.bits in [
                4, 8
            ], 'We only accept 4-bit or 8-bit quantization.'

        if self.quant_type is not None:
            assert self.quant_type in [
                'nf4', 'fp4'
            ], 'We only accept `nf4` or `fp4` quantization type.'
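These fields line up with `transformers.BitsAndBytesConfig`. A hedged sketch of that mapping (an assumption for illustration; the diff does not show it):

```python
# Sketch only: assumes transformers' BitsAndBytesConfig; the compute dtype
# choice below is an assumption, not taken from the diff.
import torch
from transformers import BitsAndBytesConfig

quant_args = QuantArguments()  # bits=4, quant_type='nf4', double_quant=True
bnb_config = BitsAndBytesConfig(
    load_in_4bit=quant_args.bits == 4,
    load_in_8bit=quant_args.bits == 8,
    bnb_4bit_quant_type=quant_args.quant_type,
    bnb_4bit_use_double_quant=quant_args.double_quant,
    bnb_4bit_compute_dtype=torch.bfloat16,  # assumption: bf16 compute
)
```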