From 8695dd1fa0a6fc6d54624ef478e2952f0eb6f56f Mon Sep 17 00:00:00 2001
From: jerryzhuang
Date: Mon, 2 Dec 2024 19:18:19 +1100
Subject: [PATCH] feat: bump accelerate to 1.0.0

Signed-off-by: jerryzhuang
---
 .github/workflows/preset-image-build.yml      | 12 +++++++---
 docker/presets/models/tfs/Dockerfile          |  5 ++--
 .../workspace/dependencies/requirements.txt   |  4 ++--
 .../text-generation/inference_api.py          |  2 +-
 .../workspace/tuning/text-generation/cli.py   |  9 +++++--
 .../tuning/text-generation/fine_tuning.py     | 21 +++++++++-------
 .../tuning/text-generation/parser.py          | 24 ++++++++++++++++---
 7 files changed, 55 insertions(+), 22 deletions(-)

diff --git a/.github/workflows/preset-image-build.yml b/.github/workflows/preset-image-build.yml
index 8b130a5bf..92e68190e 100644
--- a/.github/workflows/preset-image-build.yml
+++ b/.github/workflows/preset-image-build.yml
@@ -23,7 +23,10 @@ on:
         type: boolean
         default: false
         description: "Run all models for build"
-
+      force-run-all-public:
+        type: boolean
+        default: false
+        description: "Run all public models for build"
 env:
   GO_VERSION: "1.22"
   BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
@@ -49,14 +52,17 @@ jobs:
 
       - name: Set FORCE_RUN_ALL Flag
         id: set_force_run_all
-        run: echo "FORCE_RUN_ALL=${{ github.event_name == 'workflow_dispatch' && github.event.inputs.force-run-all == 'true' }}" >> $GITHUB_OUTPUT
-
+        run: |
+          echo "FORCE_RUN_ALL=${{ github.event_name == 'workflow_dispatch' && github.event.inputs.force-run-all == 'true' }}" >> $GITHUB_OUTPUT
+          echo "FORCE_RUN_ALL_PUBLIC=${{ github.event_name == 'workflow_dispatch' && github.event.inputs.force-run-all-public == 'true' }}" >> $GITHUB_OUTPUT
+
       # This script should output a JSON array of model names
       - name: Determine Affected Models
         id: affected_models
         run: |
           PR_BRANCH=${{ env.BRANCH_NAME }} \
           FORCE_RUN_ALL=${{ steps.set_force_run_all.outputs.FORCE_RUN_ALL }} \
+          FORCE_RUN_ALL_PUBLIC=${{ steps.set_force_run_all.outputs.FORCE_RUN_ALL_PUBLIC }} \
           python3 .github/workflows/kind-cluster/determine_models.py
 
       - name: Print Determined Models
diff --git a/docker/presets/models/tfs/Dockerfile b/docker/presets/models/tfs/Dockerfile
index fa369114c..63aa4612b 100644
--- a/docker/presets/models/tfs/Dockerfile
+++ b/docker/presets/models/tfs/Dockerfile
@@ -7,6 +7,9 @@ ARG VERSION
 # Set the working directory
 WORKDIR /workspace
 
+# Model weights
+COPY ${WEIGHTS_PATH} /workspace/weights
+
 COPY kaito/presets/workspace/dependencies/requirements.txt /workspace/requirements.txt
 RUN pip install --no-cache-dir -r /workspace/requirements.txt
 
@@ -26,8 +29,6 @@ COPY kaito/presets/workspace/inference/vllm/inference_api.py /workspace/vllm/inf
 # Chat template
 ADD kaito/presets/workspace/inference/chat_templates /workspace/chat_templates
 
-# Model weights
-COPY ${WEIGHTS_PATH} /workspace/weights
 RUN echo $VERSION > /workspace/version.txt && \
     ln -s /workspace/weights /workspace/tfs/weights && \
     ln -s /workspace/weights /workspace/vllm/weights
diff --git a/presets/workspace/dependencies/requirements.txt b/presets/workspace/dependencies/requirements.txt
index 2502fd63f..39a574e74 100644
--- a/presets/workspace/dependencies/requirements.txt
+++ b/presets/workspace/dependencies/requirements.txt
@@ -2,9 +2,9 @@
 
 # Core Dependencies
 vllm==0.6.3
-transformers >= 4.45.0
+transformers == 4.45.0
 torch==2.4.0
-accelerate==0.30.1
+accelerate==1.0.0
 fastapi>=0.111.0,<0.112.0 # Allow patch updates
 pydantic>=2.9
 uvicorn[standard]>=0.29.0,<0.30.0 # Allow patch updates
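
For context on the workflow change above: the two flags are exported as step
outputs and passed to determine_models.py as environment variables. A minimal
sketch of how the script might consume them; the script itself is not part of
this patch, and the all_models argument and its "private" attribute are
hypothetical names:

    import json
    import os

    def determine_affected_models(all_models):
        """Select the models to build based on the workflow dispatch flags."""
        force_all = os.environ.get("FORCE_RUN_ALL", "false") == "true"
        force_all_public = os.environ.get("FORCE_RUN_ALL_PUBLIC", "false") == "true"

        if force_all:
            selected = all_models  # every model, public or private
        elif force_all_public:
            selected = [m for m in all_models if not m.get("private", False)]
        else:
            selected = []  # the real script would fall back to git-diff detection

        # The workflow expects a JSON array of model names on stdout.
        print(json.dumps([m["name"] for m in selected]))
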
diff --git a/presets/workspace/inference/text-generation/inference_api.py b/presets/workspace/inference/text-generation/inference_api.py
index 871dfbbc9..23de55487 100644
--- a/presets/workspace/inference/text-generation/inference_api.py
+++ b/presets/workspace/inference/text-generation/inference_api.py
@@ -103,7 +103,7 @@ def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
         resolved_chat_template = Path(chat_template).read_text()
 
     logger.info("Chat template loaded successfully")
-    logger.info("Chat template: %s", resolved_chat_template)
+    logger.info("Chat template:\n%s", resolved_chat_template)
     return resolved_chat_template
 
diff --git a/presets/workspace/tuning/text-generation/cli.py b/presets/workspace/tuning/text-generation/cli.py
index d6c0b0680..9e848a8de 100644
--- a/presets/workspace/tuning/text-generation/cli.py
+++ b/presets/workspace/tuning/text-generation/cli.py
@@ -2,8 +2,6 @@
 # Licensed under the MIT license.
 import os
 from dataclasses import dataclass, field
-from datetime import datetime
-from enum import Enum, auto
 from typing import Any, Dict, List, Optional
 
 import torch
@@ -78,6 +76,7 @@ class ModelConfig:
     load_in_8bit: bool = field(default=False, metadata={"help": "Load model in 8-bit mode"})
     torch_dtype: Optional[str] = field(default=None, metadata={"help": "The torch dtype for the pre-trained model"})
     device_map: str = field(default="auto", metadata={"help": "The device map for the pre-trained model"})
+    chat_template: Optional[str] = field(default=None, metadata={"help": "The file path to the chat template, or the template in single-line form for the specified model"})
 
     def __post_init__(self):
         """
@@ -89,6 +88,12 @@ def __post_init__(self):
         elif not isinstance(self.torch_dtype, torch.dtype):
             raise ValueError(f"Invalid torch dtype: {self.torch_dtype}")
 
+    def get_tokenizer_args(self):
+        return {k: v for k, v in self.__dict__.items() if k not in ["torch_dtype", "chat_template"]}
+
+    def get_model_args(self):
+        return {k: v for k, v in self.__dict__.items() if k not in ["chat_template"]}
+
 @dataclass
 class QuantizationConfig(BitsAndBytesConfig):
     """
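
The two helpers added to ModelConfig split its fields between the model and
tokenizer loaders: chat_template is consumed separately by the tuning script,
and torch_dtype is additionally dropped for the tokenizer since
AutoTokenizer.from_pretrained does not accept it. A minimal illustration,
assuming the remaining ModelConfig fields keep their defaults; the template
path is made up:

    from cli import ModelConfig

    cfg = ModelConfig(torch_dtype="bfloat16",
                      chat_template="/workspace/chat_templates/default.jinja")

    cfg.get_model_args()      # keeps torch_dtype, drops chat_template
    cfg.get_tokenizer_args()  # drops both torch_dtype and chat_template
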
diff --git a/presets/workspace/tuning/text-generation/fine_tuning.py b/presets/workspace/tuning/text-generation/fine_tuning.py
index d334fbbf1..213670e57 100644
--- a/presets/workspace/tuning/text-generation/fine_tuning.py
+++ b/presets/workspace/tuning/text-generation/fine_tuning.py
@@ -2,26 +2,26 @@
 # Licensed under the MIT license.
 import logging
 import os
-import sys
 from dataclasses import asdict
 from datetime import datetime
-from parser import parse_configs
+from parser import parse_configs, load_chat_template
 
 import torch
-import transformers
 from accelerate import Accelerator
 from dataset import DatasetManager
 from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
-                          BitsAndBytesConfig, HfArgumentParser, Trainer,
-                          TrainerCallback, TrainerControl, TrainerState,
-                          TrainingArguments)
+                          BitsAndBytesConfig,
+                          TrainerCallback, TrainerControl, TrainerState)
 from trl import SFTTrainer
 
 # Initialize logger
 logger = logging.getLogger(__name__)
 debug_mode = os.environ.get('DEBUG_MODE', 'false').lower() == 'true'
-logging.basicConfig(level=logging.DEBUG if debug_mode else logging.INFO)
+logging.basicConfig(
+    level=logging.DEBUG if debug_mode else logging.INFO,
+    format='%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s',
+    datefmt='%m-%d %H:%M:%S')
 
 CONFIG_YAML = os.environ.get('YAML_FILE_PATH', '/mnt/config/training_config.yaml')
 parsed_configs = parse_configs(CONFIG_YAML)
@@ -36,7 +36,7 @@
 accelerator = Accelerator()
 
 # Load Model Args
-model_args = asdict(model_config)
+model_args = model_config.get_model_args()
 if accelerator.distributed_type != "NO":  # Meaning we require distributed training
     logger.debug("Setting device map for distributed training")
     model_args["device_map"] = {"": accelerator.process_index}
@@ -47,10 +47,13 @@
 enable_qlora = bnb_config.is_quantizable()
 
 # Load the Pre-Trained Tokenizer
-tokenizer_args = {key: value for key, value in model_args.items() if key != "torch_dtype"}
+tokenizer_args = model_config.get_tokenizer_args()
+resolved_chat_template = load_chat_template(model_config.chat_template)
 tokenizer = AutoTokenizer.from_pretrained(**tokenizer_args)
 if not tokenizer.pad_token:
     tokenizer.pad_token = tokenizer.eos_token
+if resolved_chat_template is not None:
+    tokenizer.chat_template = resolved_chat_template
 if dc_args.mlm and tokenizer.mask_token is None:
     logger.warning(
         "This tokenizer does not have a mask token which is necessary for masked language modeling. "
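
The tokenizer override above uses the standard transformers mechanism:
assigning tokenizer.chat_template makes apply_chat_template render the custom
template instead of the model's built-in one. A quick standalone check; the
weights path and the template are illustrative:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("/workspace/weights")
    tokenizer.chat_template = (
        "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"
    )
    print(tokenizer.apply_chat_template(
        [{"role": "user", "content": "hello"}], tokenize=False))
    # prints "user: hello"
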
diff --git a/presets/workspace/tuning/text-generation/parser.py b/presets/workspace/tuning/text-generation/parser.py
index 619ebdfbb..4c07b68c8 100644
--- a/presets/workspace/tuning/text-generation/parser.py
+++ b/presets/workspace/tuning/text-generation/parser.py
@@ -1,14 +1,17 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
-import os
-import sys
-from collections import defaultdict
+import logging
 from dataclasses import asdict, fields
+import codecs
+from pathlib import Path
+from typing import Optional
 
 import yaml
 from cli import (DatasetConfig, ExtDataCollator, ExtLoraConfig, ModelConfig,
                  QuantizationConfig)
 from transformers import HfArgumentParser, TrainingArguments
 
+logger = logging.getLogger(__name__)
+
 # Mapping from config section names to data classes
 CONFIG_CLASS_MAP = {
     'ModelConfig': ModelConfig,
@@ -69,3 +72,18 @@ def parse_configs(config_yaml):
         parsed_configs[section_name] = CONFIG_CLASS_MAP[section_name](**filtered_config)
 
     return parsed_configs
+
+def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
+    logger.info("Loading chat template: %s", chat_template)
+    if chat_template is None:
+        return None
+
+    JINJA_CHARS = "{}\n"
+    if any(c in chat_template for c in JINJA_CHARS):
+        resolved_chat_template = codecs.decode(chat_template, "unicode_escape")
+    else:
+        resolved_chat_template = Path(chat_template).read_text()
+
+    logger.info("Chat template loaded successfully")
+    logger.info("Chat template:\n%s", resolved_chat_template)
+    return resolved_chat_template
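
load_chat_template mirrors the helper of the same name in inference_api.py
(touched above): a value containing any of the characters in JINJA_CHARS is
treated as an inline Jinja template, with escaped sequences such as a literal
"\n" decoded into real newlines; any other value is treated as a path to a
template file. Illustrative calls; the file path is hypothetical:

    load_chat_template(None)                                       # -> None
    load_chat_template("{{ messages[0]['content'] }}")             # inline template ("{" detected)
    load_chat_template("/workspace/chat_templates/default.jinja")  # read from disk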