Skip to content

Commit

Permalink
enhance resume for map_fn
Browse files Browse the repository at this point in the history
  • Loading branch information
LZHgrla committed Nov 8, 2023
1 parent 6b99303 commit 00c5d0c
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 1 deletion.
10 changes: 9 additions & 1 deletion xtuner/dataset/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from datasets import DatasetDict
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from mmengine.utils.misc import get_object_from_string

from ..registry import BUILDER, MAP_FUNC
from .utils import Packer, encode_fn
Expand Down Expand Up @@ -76,7 +77,14 @@ def process_hf_dataset(dataset,
# Extract the useful data for training from the original dataset.
if dataset_map_fn is not None:
if isinstance(dataset_map_fn, str):
dataset_map_fn = MAP_FUNC.get(dataset_map_fn)
map_fn_obj = MAP_FUNC.get(
dataset_map_fn) or get_object_from_string(dataset_map_fn)
if map_fn_obj is not None:
dataset_map_fn = map_fn_obj
else:
raise TypeError('dataset_map_fn must be a function or a '
"registered function's string in MAP_FUNC, "
f"but got a string of '{dataset_map_fn}'")

dataset = dataset.map(dataset_map_fn)

Expand Down
4 changes: 4 additions & 0 deletions xtuner/dataset/map_fns/template_map_fn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial

from mmengine.utils.misc import get_object_from_string


def template_map_fn(example, template):
conversation = example.get('conversation', [])
Expand All @@ -22,4 +24,6 @@ def template_map_fn(example, template):


def template_map_fn_factory(template):
    """Build a map function with ``template`` pre-bound.

    Args:
        template: The prompt template forwarded to ``template_map_fn``, or
            a string naming it. A string is resolved to the object it
            names via ``get_object_from_string`` (the resume path).

    Returns:
        functools.partial: ``template_map_fn`` with ``template`` bound as
        a keyword argument.

    Raises:
        ValueError: If ``template`` is a string that cannot be resolved
            to an object.
    """
    if isinstance(template, str):  # for resume
        resolved = get_object_from_string(template)
        if resolved is None:
            # Fail here with the offending name instead of binding None,
            # which would only surface later while mapping the dataset.
            raise ValueError(
                'template must be a prompt template or a string that can '
                f"be resolved to one, but got a string of '{template}'")
        template = resolved
    return partial(template_map_fn, template=template)
3 changes: 3 additions & 0 deletions xtuner/engine/hooks/evaluate_chat_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.utils.misc import get_object_from_string
from transformers import GenerationConfig, StoppingCriteriaList

from xtuner.dataset.utils import expand2square, load_image
Expand Down Expand Up @@ -41,6 +42,8 @@ def __init__(self,
if prompt_template is None:
instruction = '{input}'
else:
if isinstance(prompt_template, str): # for resume
prompt_template = get_object_from_string(prompt_template)
instruction = prompt_template.get('INSTRUCTION', '{input}')
if system != '':
system = prompt_template.get(
Expand Down

0 comments on commit 00c5d0c

Please sign in to comment.