From b6f89a90ce86978054e3445155223425eca31efb Mon Sep 17 00:00:00 2001 From: Yilun Huang Date: Fri, 20 Dec 2024 19:09:32 +0800 Subject: [PATCH 1/7] [Feat] OP-wise Insight Mining (#516) * + add auto mode for analyzer: load all filters that produce stats to analyze the target dataset * + add default mem_required for those model-based OPs * - support wordcloud drawing for str or str list fields in stats - support set the number of samples to be analyzed in auto mode. It's 1k in default. * - take the minimum one of dataset length and auto num * * update default export path * * set version limit for wandb to avoid exception * + add docs for auto mode * + support t-test for Measure * * fix some bugs * - support analyze a dataset object - optimize the logics of loading filters that produce stats and updating attributes of OPs * - support analysis on tags in meta * - support analysis with tagging OPs * - move tags into the meta field * - do not tell tags using their suffix - suppress the error/exceptions in Monitor due to the termination of the main process - exported stats file includes meta field in exporter * - add insight mining * * resolve the bugs when running insight mining in multiprocessing mode * * update unittests * * update unittests * * update unittests * * update readme for analyzer * * use more detailed key * + add reference --- README.md | 4 +- README_ZH.md | 4 +- data_juicer/analysis/column_wise_analysis.py | 24 +- data_juicer/analysis/measure.py | 111 +++++++++ data_juicer/analysis/overall_analysis.py | 16 +- data_juicer/config/config.py | 84 ++++--- data_juicer/core/adapter.py | 125 +++++++++- data_juicer/core/analyzer.py | 35 ++- data_juicer/core/data.py | 30 ++- data_juicer/core/executor.py | 1 + data_juicer/core/exporter.py | 11 +- data_juicer/core/monitor.py | 8 +- data_juicer/ops/__init__.py | 5 +- data_juicer/ops/base_op.py | 18 +- .../ops/filter/specified_field_filter.py | 7 +- .../filter/specified_numeric_field_filter.py | 8 +- data_juicer/ops/filter/suffix_filter.py | 7 +- .../video_tagging_from_frames_filter.py | 7 +- .../ops/mapper/image_tagging_mapper.py | 10 +- .../mapper/video_tagging_from_audio_mapper.py | 11 +- .../video_tagging_from_frames_mapper.py | 10 +- data_juicer/utils/cache_utils.py | 47 ++++ data_juicer/utils/constant.py | 6 +- .../test_video_tagging_from_frames_filter.py | 8 +- tests/ops/mapper/test_image_tagging_mapper.py | 102 ++++---- .../test_video_tagging_from_audio_mapper.py | 5 +- .../test_video_tagging_from_frames_mapper.py | 220 ++++++++++-------- 27 files changed, 680 insertions(+), 244 deletions(-) diff --git a/README.md b/README.md index d891ac332..586869b0a 100644 --- a/README.md +++ b/README.md @@ -340,7 +340,9 @@ dj-analyze --config configs/demo/analyzer.yaml dj-analyze --auto --dataset_path xx.jsonl [--auto_num 1000] ``` -- **Note:** Analyzer only compute stats of Filter ops. So extra Mapper or Deduplicator ops will be ignored in the analysis process. +- **Note:** Analyzer only compute stats for Filters that produce stats or other OPs that produce tags/categories in meta. So other OPs will be ignored in the analysis process. We use the following registries to decorate OPs: + - `NON_STATS_FILTERS`: decorate Filters that **DO NOT** produce any stats. + - `TAGGING_OPS`: decorate OPs that **DO** produce tags/categories in meta field. ### Data Visualization diff --git a/README_ZH.md b/README_ZH.md index 01633731b..42612964a 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -316,7 +316,9 @@ dj-analyze --config configs/demo/analyzer.yaml dj-analyze --auto --dataset_path xx.jsonl [--auto_num 1000] ``` -* **注意**:Analyzer 只计算 Filter 算子的状态,其他的算子(例如 Mapper 和 Deduplicator)会在分析过程中被忽略。 +* **注意**:Analyzer 只用于能在 stats 字段里产出统计信息的 Filter 算子和能在 meta 字段里产出 tags 或类别标签的其他算子。除此之外的其他的算子会在分析过程中被忽略。我们使用以下两种注册器来装饰相关的算子: + * `NON_STATS_FILTERS`:装饰那些**不能**产出任何统计信息的 Filter 算子。 + * `TAGGING_OPS`:装饰那些能在 meta 字段中产出 tags 或类别标签的算子。 ### 数据可视化 diff --git a/data_juicer/analysis/column_wise_analysis.py b/data_juicer/analysis/column_wise_analysis.py index 825d9b4dd..ce5b3617d 100644 --- a/data_juicer/analysis/column_wise_analysis.py +++ b/data_juicer/analysis/column_wise_analysis.py @@ -6,7 +6,7 @@ from tqdm import tqdm from wordcloud import WordCloud -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import DEFAULT_PREFIX, Fields from .overall_analysis import OverallAnalysis @@ -70,6 +70,12 @@ def __init__(self, stats into one image file """ self.stats = pd.DataFrame(dataset[Fields.stats]) + self.meta = pd.DataFrame(dataset[Fields.meta]) + # remove non-tag columns + meta_columns = self.meta.columns + for col_name in meta_columns: + if not col_name.startswith(DEFAULT_PREFIX): + self.meta = self.meta.drop(col_name, axis=1) self.output_path = output_path if not os.path.exists(self.output_path): os.makedirs(self.output_path) @@ -101,8 +107,9 @@ def analyze(self, show_percentiles=False, show=False, skip_export=False): width_unit = 4 height_unit = 6 - columns = self.stats.columns - num = len(columns) + stats_and_meta = pd.concat([self.stats, self.meta], axis=1) + all_columns = stats_and_meta.columns + num = len(all_columns) # get the recommended "best" number of columns and rows rec_row, rec_col, grid_indexes = get_row_col(num, num_subcol) @@ -115,9 +122,9 @@ def analyze(self, show_percentiles=False, show=False, skip_export=False): fig = plt.figure(figsize=(rec_width, rec_height), layout='constrained') subfigs = fig.subfigures(rec_row, rec_col, wspace=0.01) - for i, column_name in enumerate(tqdm(columns.to_list(), - desc='Column')): - data = self.stats[column_name] + for i, column_name in enumerate( + tqdm(all_columns.to_list(), desc='Column')): + data = stats_and_meta[column_name] # explode data to flatten inner list data = data.explode().infer_objects() grid = grid_indexes[i] @@ -210,10 +217,7 @@ def draw_hist(self, ax, data, save_path, percentiles=None, show=False): """ # recommended number of bins data_num = len(data) - if data_num >= 100: - rec_bins = int(math.sqrt(len(data))) - else: - rec_bins = None + rec_bins = max(int(math.sqrt(data_num)), 10) # if ax is None, using plot method in pandas if ax is None: diff --git a/data_juicer/analysis/measure.py b/data_juicer/analysis/measure.py index fe54cdabd..bd97e811c 100644 --- a/data_juicer/analysis/measure.py +++ b/data_juicer/analysis/measure.py @@ -1,9 +1,13 @@ +import numpy as np + from data_juicer.utils.lazy_loader import LazyLoader torch = LazyLoader('torch', 'torch') td = LazyLoader('td', 'torch.distributions') F = LazyLoader('F', 'torch.nn.functional') +stats = LazyLoader('stats', 'scipy.stats') + class Measure(object): """Base class for Measure distribution. @@ -48,6 +52,15 @@ def _convert_to_categorical(self, p): else: return td.Categorical(torch.tensor(p)) + def _convert_to_ndarray(self, p): + """ + Convert input data to torch tensor. + :param p: input data, now support + [`scalar`,`list`, `tuple`, `torch binary file`, and `Categorical`]. + :return: torch tensor + """ + return self._convert_to_tensor(p).numpy() + class KLDivMeasure(Measure): """ @@ -108,3 +121,101 @@ class EntropyMeasure(Measure): def measure(self, p): p = self._convert_to_categorical(p) return p.entropy() + + +class RelatedTTestMeasure(Measure): + """ + Measure T-Test for two related distributions on their histogram of the same + bins. + + Ref: + https://en.wikipedia.org/wiki/Student%27s_t-test + + For continuous features or distributions, the input could be dataset stats + list. + For discrete features or distributions, the input could be the tags or the + categories list. + """ + name = 't-test' + + @staticmethod + def stats_to_hist(p, q): + p = np.array(p) + q = np.array(q) + + # get common maximum number of data samples, and max/min values + max_data_num = max(len(p), len(q)) + min_val = min(min(p), min(q)) + max_val = max(max(p), max(q)) + + # get a recommended number of bins + rec_bins = max(int(np.sqrt(max_data_num)), 10) + + # get the common bin edges + common_p = np.append(p, [min_val, max_val]) + hist_p, bin_edges = np.histogram(common_p, bins=rec_bins) + # restore the hist of the original p + hist_p[0] -= 1 + hist_p[-1] -= 1 + # get the hist of the original q using the common bin edges + hist_q, _ = np.histogram(q, bins=bin_edges) + return hist_p, hist_q, bin_edges + + @staticmethod + def category_to_hist(p, q): + + def flatten_list(lst): + res = [] + for s in lst: + if isinstance(s, list): + res.extend(flatten_list(s)) + else: + res.append(s) + return res + + # flatten the list + p = flatten_list(p) + q = flatten_list(q) + + # get the common categories + cat_p = set(p) + cat_q = set(q) + cat_common = cat_p.union(cat_q) + + # get category distributions + count_p = {cat: 0 for cat in cat_common} + count_q = {cat: 0 for cat in cat_common} + for cat in p: + count_p[cat] += 1 + for cat in q: + count_q[cat] += 1 + + # only keep distribution values sorted by counts + sorted_cat = list(count_p.items()) + sorted_cat.sort(key=lambda it: it[1], reverse=True) + sorted_cat = [it[0] for it in sorted_cat] + # get the value dist + hist_p = [count_p[cat] for cat in sorted_cat] + hist_q = [count_q[cat] for cat in sorted_cat] + + return hist_p, hist_q, count_p, count_q, sorted_cat + + def measure(self, p, q): + """ + :param p: the first feature or distribution. (stats/tags/categories) + :param q: the second feature or distribution. (stats/tags/categories) + :return: the T-Test results object -- ([ref](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats._result_classes.TtestResult.html#scipy.stats._result_classes.TtestResult)) # noqa: E501 + """ + ele = p[0] + while isinstance(ele, list): + ele = ele[0] + if isinstance(ele, str): + # discrete tags or categories + hist_p, hist_q = self.category_to_hist(p, q)[:2] + else: + # continuous stats + hist_p, hist_q = self.stats_to_hist(p, q)[:2] + + # compute the t-test and pval for hist_p and hist_q + ttest_res = stats.ttest_rel(hist_p, hist_q) + return ttest_res diff --git a/data_juicer/analysis/overall_analysis.py b/data_juicer/analysis/overall_analysis.py index 04eefb178..696b25946 100644 --- a/data_juicer/analysis/overall_analysis.py +++ b/data_juicer/analysis/overall_analysis.py @@ -5,7 +5,7 @@ from loguru import logger from tqdm import tqdm -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import DEFAULT_PREFIX, Fields def _single_column_analysis(col, *args, **kwargs): @@ -25,6 +25,12 @@ def __init__(self, dataset, output_path): :param output_path: path to store the analysis results. """ self.stats = pd.DataFrame(dataset[Fields.stats]) + self.meta = pd.DataFrame(dataset[Fields.meta]) + # remove non-tag columns + meta_columns = self.meta.columns + for col_name in meta_columns: + if not col_name.startswith(DEFAULT_PREFIX): + self.meta = self.meta.drop(col_name, axis=1) self.output_path = output_path if not os.path.exists(self.output_path): os.makedirs(self.output_path) @@ -71,10 +77,14 @@ def analyze(self, percentiles=[], num_proc=1, skip_export=False): # merge default and customized percentiles and get overall information percentiles = list(set(percentiles + self.default_percentiles)) + # merge stats and meta + stats_and_meta = pd.concat([self.stats, self.meta], axis=1) + all_columns = stats_and_meta.columns + results = [] pool = Pool(num_proc) - for col_name in self.stats.columns: - this_col = self.refine_single_column(self.stats[col_name]) + for col_name in all_columns: + this_col = self.refine_single_column(stats_and_meta[col_name]) res = pool.apply_async(_single_column_analysis, kwds={ 'col': this_col, diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index c7f0aaf38..028f3cf79 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -290,6 +290,22 @@ def init_configs(args: Optional[List[str]] = None, which_entry: object = None): help='Number of samples extracted by tracer to show the dataset ' 'difference before and after a op. Only available when ' 'open_tracer is true.') + parser.add_argument( + '--open_insight_mining', + type=bool, + default=False, + help='Whether to open insight mining to trace the OP-wise stats/tags ' + 'changes during process. It might take more time when opening ' + 'insight mining.') + parser.add_argument( + '--op_list_to_mine', + type=List[str], + default=[], + help='Which OPs will be applied on the dataset to mine the insights ' + 'in their stats changes. Only those OPs that produce stats or ' + 'meta are valid. If it\'s empty, all OPs that produce stats and ' + 'meta will be involved. Only available when filter_list_to_mine ' + 'is true.') parser.add_argument( '--op_fusion', type=bool, @@ -513,13 +529,7 @@ def init_setup_from_cfg(cfg: Namespace): # add all filters that produce stats if cfg.auto: - import pkgutil - - import data_juicer.ops.filter as djfilters - cfg.process = [{ - filter_name: {} - } for _, filter_name, _ in pkgutil.iter_modules(djfilters.__path__) - if filter_name not in djfilters.NON_STATS_FILTERS] + cfg.process = load_ops_with_stats_meta() # Apply text_key modification during initializing configs # users can freely specify text_key for different ops using `text_key` @@ -528,34 +538,48 @@ def init_setup_from_cfg(cfg: Namespace): text_key = cfg.text_keys[0] else: text_key = cfg.text_keys - for op in cfg.process: + op_attrs = { + 'text_key': text_key, + 'image_key': cfg.image_key, + 'audio_key': cfg.audio_key, + 'video_key': cfg.video_key, + 'num_proc': cfg.np, + 'turbo': cfg.turbo, + } + cfg.process = update_op_attr(cfg.process, op_attrs) + + return cfg + + +def load_ops_with_stats_meta(): + import pkgutil + + import data_juicer.ops.filter as djfilter + from data_juicer.ops import NON_STATS_FILTERS, TAGGING_OPS + stats_filters = [{ + filter_name: {} + } for _, filter_name, _ in pkgutil.iter_modules(djfilter.__path__) + if filter_name not in NON_STATS_FILTERS.modules] + meta_ops = [{op_name: {}} for op_name in TAGGING_OPS.modules] + return stats_filters + meta_ops + + +def update_op_attr(op_list: list, attr_dict: dict = None): + if not attr_dict: + return op_list + updated_op_list = [] + for op in op_list: for op_name in op: args = op[op_name] if args is None: - args = { - 'text_key': text_key, - 'image_key': cfg.image_key, - 'audio_key': cfg.audio_key, - 'video_key': cfg.video_key, - 'num_proc': cfg.np, - 'turbo': cfg.turbo, - } + args = attr_dict else: - if 'text_key' not in args or args['text_key'] is None: - args['text_key'] = text_key - if 'image_key' not in args or args['image_key'] is None: - args['image_key'] = cfg.image_key - if 'audio_key' not in args or args['audio_key'] is None: - args['audio_key'] = cfg.audio_key - if 'video_key' not in args or args['video_key'] is None: - args['video_key'] = cfg.video_key - if 'num_proc' not in args or args['num_proc'] is None: - args['num_proc'] = cfg.np - if 'turbo' not in args or args['turbo'] is None: - args['turbo'] = cfg.turbo + for key in attr_dict: + if key not in args or args[key] is None: + args[key] = attr_dict[key] op[op_name] = args - - return cfg + updated_op_list.append(op) + return updated_op_list def _collect_config_info_from_class_docs(configurable_ops, parser): diff --git a/data_juicer/core/adapter.py b/data_juicer/core/adapter.py index 5ab6e6ec8..64fd622f0 100644 --- a/data_juicer/core/adapter.py +++ b/data_juicer/core/adapter.py @@ -1,8 +1,15 @@ -from datasets import concatenate_datasets +import json +import os +from copy import deepcopy + +from datasets import Dataset, concatenate_datasets from datasets.config import DEFAULT_MAX_BATCH_SIZE +from data_juicer.analysis.measure import RelatedTTestMeasure from data_juicer.core.monitor import Monitor from data_juicer.ops import UNFORKABLE +from data_juicer.utils.cache_utils import dataset_cache_control +from data_juicer.utils.constant import Fields from data_juicer.utils.process_utils import setup_mp @@ -12,6 +19,11 @@ class Adapter: def __init__(self, cfg: dict): self.cfg = cfg + + # insight mining related + self.enable_insight_mining = self.cfg.open_insight_mining + + # resource probe related self.idle_resources = Monitor.monitor_current_resources() @staticmethod @@ -108,25 +120,21 @@ def adapt_workloads(self, dataset, operators): return bs_per_op + @dataset_cache_control(on=True) def probe_small_batch(self, dataset, operators): """ Perform small batch pre-execution to probe available resources, current load and estimated OP speed, returning load factors and speed ranks for each OP. - Notice: the probe should be run with cache enabled. + Notice: the probe should be run with cache enabled to avoid removing + the cache files of the input dataset. :param dataset: The dataset to pre-execute small batch on :param operators: The OP list to be pre-execution and probe :return: A list of probe results for each OP and the length of data batch to probe. """ - # record the cache state and enable the cache - from datasets import (disable_caching, enable_caching, - is_caching_enabled) - previous_state = is_caching_enabled() - if not previous_state: - enable_caching() # take a small batch data_batch = self.take_batch(dataset, self.cfg) @@ -135,10 +143,6 @@ def probe_small_batch(self, dataset, operators): # analyze resource utilization analysis_res = Monitor.analyze_resource_util_list(resource_util_list) - # if the cache is disabled before, disable it again - if not previous_state: - disable_caching() - return analysis_res, len(data_batch) def batch_size_strategy(self, load_analysis_res, base_bs=1, util_th=0.9): @@ -177,3 +181,100 @@ def batch_size_strategy(self, load_analysis_res, base_bs=1, util_th=0.9): batch_size_per_op.append(bs_this_op) return batch_size_per_op + + @dataset_cache_control(on=True) + def analyze_small_batch(self, dataset, current_state): + """ + Perform small batch analysis to probe the current OP-wise stats/meta + distributions. The analyzed results will be stored in the directory + `{work_dir}/insight_mining`. + + Notice: the probe should be run with cache enabled to avoid removing + the cache files of the input dataset. + + :param dataset: The dataset to analyze small batch on + :param current_state: A string to indicate the current state of the + input dataset. It usually consists of a number of the index of the + OP processed just now and the OP name, e.g. "1_text_length_filter". + """ + # prepare analyzer config + new_cfg = deepcopy(self.cfg) + # check ops to mine + new_cfg.auto = True + new_cfg.config = None + if len(new_cfg.op_list_to_mine) > 0: + new_cfg.process = [{ + op_name: {} + } for op_name in new_cfg.op_list_to_mine] + # update work dir + new_cfg.work_dir = os.path.join(new_cfg.work_dir, 'insight_mining', + current_state) + new_cfg.export_path = os.path.join(new_cfg.work_dir, + f'{current_state}.jsonl') + # close insight mining and monitor for inner analysis + new_cfg.open_insight_mining = False + new_cfg.open_monitor = False + + # init the analyzer + from data_juicer.core.analyzer import Analyzer + analyzer = Analyzer(new_cfg) + + # remove existing stats and meta in dataset + target_fields = {Fields.stats, Fields.meta} + target_fields = target_fields.intersection(set(dataset.features)) + if len(target_fields) > 0: + dataset = dataset.remove_columns(list(target_fields)) + analyzer.run(dataset, skip_return=True) + + def insight_mining(self, pval_th=0.05): + """ + Mining the insights from the OP-wise analysis results. For now, we use + T-Test to check the significance of stats/meta changes before and after + each OP processing. If the p-value is less than a given threshold + (usually 0.05), we think the stats/meta changes are significant. The + insight mining results will be stored in the file + `{work_dir}/insight_mining/insight_mining.json`. + + :param pval_th: the threshold of p-value. + """ + work_dir = os.path.join(self.cfg.work_dir, 'insight_mining') + res_order = [ + d for d in os.listdir(work_dir) + if os.path.isdir(os.path.join(work_dir, d)) + ] + res_order.sort() + + # collect analysis results + analysis_results = {} + for res_dir in res_order: + res = Dataset.from_json( + os.path.join(work_dir, res_dir, + f'{res_dir}_stats.jsonl')).flatten() + analysis_results[res_dir] = res + + # distribution change significance analysis + ttest_measure = RelatedTTestMeasure() + + sig_res = {} + # i = 0 is the original dataset + for i in range(1, len(res_order)): + prev_res = analysis_results[res_order[i - 1]] + curr_res = analysis_results[res_order[i]] + + # only consider common stats and meta + common_features = list( + set(prev_res.features).intersection(set(curr_res.features))) + curr_sig_res = {} + for feat in common_features: + ttest_res = ttest_measure(prev_res[feat], curr_res[feat]) + curr_sig_res[feat] = { + 't-statistic (standardized mean difference)': + ttest_res.statistic, + 'p-value': ttest_res.pvalue, + 'significant': + True if ttest_res.pvalue < pval_th else False, + } + sig_res[res_order[i]] = curr_sig_res + + with open(os.path.join(work_dir, 'insight_mining.json'), 'w') as out: + json.dump(sig_res, out) diff --git a/data_juicer/core/analyzer.py b/data_juicer/core/analyzer.py index 63e512d41..d9ac586e9 100644 --- a/data_juicer/core/analyzer.py +++ b/data_juicer/core/analyzer.py @@ -1,6 +1,7 @@ import os -from typing import Optional +from typing import Optional, Union +from datasets import Dataset from jsonargparse import Namespace from loguru import logger from pydantic import PositiveInt @@ -8,11 +9,12 @@ from data_juicer.analysis import ColumnWiseAnalysis, OverallAnalysis from data_juicer.config import init_configs from data_juicer.format import load_formatter -from data_juicer.ops import Filter, load_ops +from data_juicer.ops import NON_STATS_FILTERS, TAGGING_OPS, Filter, load_ops from data_juicer.ops.op_fusion import fuse_operators from data_juicer.utils import cache_utils from .adapter import Adapter +from .data import NestedDataset from .exporter import Exporter @@ -71,22 +73,27 @@ def __init__(self, cfg: Optional[Namespace] = None): self.analysis_path = os.path.join(self.cfg.work_dir, 'analysis') def run(self, + dataset: Union[Dataset, NestedDataset] = None, load_data_np: Optional[PositiveInt] = None, skip_export: bool = False, skip_return: bool = False): """ Running the dataset analysis pipeline. + :param dataset: a Dataset object to be analyzed. :param load_data_np: number of workers when loading the dataset. :param skip_export: whether export the results into disk :param skip_return: skip return for API called. :return: analyzed dataset. """ # 1. format data - logger.info('Loading dataset from data formatter...') if load_data_np is None: load_data_np = self.cfg.np - dataset = self.formatter.load_dataset(load_data_np, self.cfg) + if dataset is None: + logger.info('Loading dataset from data formatter...') + dataset = self.formatter.load_dataset(load_data_np, self.cfg) + else: + logger.info(f'Using existing dataset {dataset}') if self.cfg.auto: # if it's auto analysis, only analyze for a minor part of the input # dataset to save time and computing resource @@ -111,16 +118,26 @@ def run(self, logger.info('Computing the stats of dataset...') stats_collected = False for op in ops: - if isinstance(op, Filter): + if isinstance(op, Filter) \ + and op._name not in NON_STATS_FILTERS.modules: original_process = op.process op.process = None - dataset = dataset.process(op, work_dir=self.work_dir) + dataset = dataset.process(op, + work_dir=self.work_dir, + open_monitor=self.cfg.open_monitor) op.process = original_process stats_collected = True + elif op._name in TAGGING_OPS.modules: + dataset = dataset.process(op, + work_dir=self.work_dir, + open_monitor=self.cfg.open_monitor) + stats_collected = True if not stats_collected: - logger.warning('No stats collected. Please add some Filter ops to ' - 'the process list in configs.') - return dataset + logger.warning( + 'No stats/meta collected. Please add some Filter OPs or ' + 'Tagging OPs to the process list in configs.') + if not skip_return: + return dataset # 3. data export logger.info('Exporting dataset to disk...') diff --git a/data_juicer/core/data.py b/data_juicer/core/data.py index 361f6e8a0..d0f8083e1 100644 --- a/data_juicer/core/data.py +++ b/data_juicer/core/data.py @@ -172,6 +172,7 @@ def process( exporter=None, checkpointer=None, tracer=None, + adapter=None, open_monitor=True, ): if operators is None: @@ -185,9 +186,19 @@ def process( if open_monitor: resource_util_list = [] + # whether to enable insight mining + enable_insight_mining = adapter.enable_insight_mining \ + if adapter else False + # record the analysis results of the original dataset + if enable_insight_mining: + logger.info('Analyze small batch for the original dataset for ' + 'insight mining...') + adapter.analyze_small_batch(self, '0_original') + dataset = self + op_num = len(operators) try: - for op in operators: + for idx, op in enumerate(operators, start=1): mp_context = ['forkserver', 'spawn'] if ( op.use_cuda() or op._name in unforkable_operators) else None @@ -211,8 +222,16 @@ def process( if open_monitor: resource_util_list.append(resource_util_per_op) end = time() - logger.info(f'OP [{op._name}] Done in {end - start:.3f}s. ' - f'Left {len(dataset)} samples.') + logger.info( + f'[{idx}/{op_num}] OP [{op._name}] Done in ' + f'{end - start:.3f}s. Left {len(dataset)} samples.') + + # record the analysis results of the current dataset + if enable_insight_mining: + logger.info( + f'Analyze small batch for the current dataset after ' + f'OP [{op._name}] for insight mining...') + adapter.analyze_small_batch(dataset, f'{idx}_{op._name}') except: # noqa: E722 logger.error(f'An error occurred during Op [{op._name}].') traceback.print_exc() @@ -223,6 +242,7 @@ def process( 'last op...') dataset.cleanup_cache_files() checkpointer.save_ckpt(dataset) + # make summarization on the monitor results if work_dir and open_monitor: # get the analyzed version resource_util_list = Monitor.analyze_resource_util_list( @@ -234,6 +254,10 @@ def process( json.dump(resource_util_list, out) Monitor.draw_resource_util_graph(resource_util_list, monitor_dir) + # make summarization on the insight mining results + if work_dir and enable_insight_mining: + logger.info('Insight mining for each OP...') + adapter.insight_mining() return dataset def update_args(self, args, kargs, is_filter=False): diff --git a/data_juicer/core/executor.py b/data_juicer/core/executor.py index f78059247..7f0d93a66 100644 --- a/data_juicer/core/executor.py +++ b/data_juicer/core/executor.py @@ -199,6 +199,7 @@ def run(self, exporter=self.exporter, checkpointer=self.ckpt_manager, tracer=self.tracer, + adapter=self.adapter, open_monitor=self.cfg.open_monitor, ) tend = time() diff --git a/data_juicer/core/exporter.py b/data_juicer/core/exporter.py index 72b555d34..dbdb4fb9f 100644 --- a/data_juicer/core/exporter.py +++ b/data_juicer/core/exporter.py @@ -106,10 +106,15 @@ def _export_impl(self, dataset, export_path, suffix, export_stats=True): :param export_stats: whether to export stats of dataset. :return: """ - if Fields.stats in dataset.features and export_stats: + if export_stats: # export stats of datasets into a single file. logger.info('Exporting computed stats into a single file...') - ds_stats = dataset.select_columns(Fields.stats) + export_columns = [] + if Fields.stats in dataset.features: + export_columns.append(Fields.stats) + if Fields.meta in dataset.features: + export_columns.append(Fields.meta) + ds_stats = dataset.select_columns(export_columns) stats_file = export_path.replace('.' + suffix, '_stats.jsonl') Exporter.to_jsonl( ds_stats, @@ -119,7 +124,7 @@ def _export_impl(self, dataset, export_path, suffix, export_stats=True): if self.export_ds: # fetch the corresponding export method according to the suffix if not self.keep_stats_in_res_ds: - extra_fields = {Fields.stats} + extra_fields = {Fields.stats, Fields.meta} feature_fields = set(dataset.features.keys()) removed_fields = extra_fields.intersection(feature_fields) dataset = dataset.remove_columns(removed_fields) diff --git a/data_juicer/core/monitor.py b/data_juicer/core/monitor.py index 0210e3732..d5fdee241 100644 --- a/data_juicer/core/monitor.py +++ b/data_juicer/core/monitor.py @@ -15,7 +15,13 @@ def resource_monitor(mdict, interval): while True: this_states.append(Monitor.monitor_current_resources()) time.sleep(interval) - if mdict['stop']: + try: + stop_sign = mdict['stop'] + except (BrokenPipeError, FileNotFoundError): + # mdict crushes due to the main process is terminated already, + # which is not the fault here + return + if stop_sign: break mdict['resource'] = this_states diff --git a/data_juicer/ops/__init__.py b/data_juicer/ops/__init__.py index e02e10efa..2ab622266 100644 --- a/data_juicer/ops/__init__.py +++ b/data_juicer/ops/__init__.py @@ -1,6 +1,7 @@ from . import aggregator, deduplicator, filter, grouper, mapper, selector -from .base_op import (OPERATORS, UNFORKABLE, Aggregator, Deduplicator, Filter, - Grouper, Mapper, Selector) +from .base_op import (NON_STATS_FILTERS, OPERATORS, TAGGING_OPS, UNFORKABLE, + Aggregator, Deduplicator, Filter, Grouper, Mapper, + Selector) from .load import load_ops __all__ = [ diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index 2091a867e..39e23d8e9 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -14,6 +14,8 @@ OPERATORS = Registry('Operators') UNFORKABLE = Registry('Unforkable') +NON_STATS_FILTERS = Registry('Non-stats Filters') +TAGGING_OPS = Registry('Tagging Operators') def convert_list_dict_to_dict_list(samples): @@ -223,6 +225,18 @@ def run(self, dataset): from data_juicer.core.data import NestedDataset if not isinstance(dataset, NestedDataset): dataset = NestedDataset(dataset) + # add meta field for OPs that produce tags + if self._name in TAGGING_OPS.modules \ + and Fields.meta not in dataset.features: + from data_juicer.core.data import add_same_content_to_new_column + dataset = dataset.map(add_same_content_to_new_column, + fn_kwargs={ + 'new_column_name': Fields.meta, + 'initial_value': {} + }, + num_proc=self.runtime_np(), + batch_size=self.batch_size, + desc='Adding new column for meta') if self.index_key is not None: def add_index(sample, idx): @@ -404,7 +418,9 @@ def process_single(self, sample): def run(self, dataset, *, exporter=None, tracer=None, reduce=True): dataset = super(Filter, self).run(dataset) - if Fields.stats not in dataset.features: + # add stats field for Filters that produce stats + if self._name not in NON_STATS_FILTERS.modules \ + and Fields.stats not in dataset.features: from data_juicer.core.data import add_same_content_to_new_column dataset = dataset.map(add_same_content_to_new_column, fn_kwargs={ diff --git a/data_juicer/ops/filter/specified_field_filter.py b/data_juicer/ops/filter/specified_field_filter.py index 86aff2426..41addf8da 100644 --- a/data_juicer/ops/filter/specified_field_filter.py +++ b/data_juicer/ops/filter/specified_field_filter.py @@ -1,9 +1,12 @@ from typing import List -from ..base_op import OPERATORS, Filter +from ..base_op import NON_STATS_FILTERS, OPERATORS, Filter +OP_NAME = 'specified_field_filter' -@OPERATORS.register_module('specified_field_filter') + +@NON_STATS_FILTERS.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) class SpecifiedFieldFilter(Filter): """ Filter based on specified field information. diff --git a/data_juicer/ops/filter/specified_numeric_field_filter.py b/data_juicer/ops/filter/specified_numeric_field_filter.py index 693be3392..c7a1d301a 100644 --- a/data_juicer/ops/filter/specified_numeric_field_filter.py +++ b/data_juicer/ops/filter/specified_numeric_field_filter.py @@ -1,6 +1,6 @@ import sys -from ..base_op import OPERATORS, Filter +from ..base_op import NON_STATS_FILTERS, OPERATORS, Filter def is_number(s): @@ -13,7 +13,11 @@ def is_number(s): return False -@OPERATORS.register_module('specified_numeric_field_filter') +OP_NAME = 'specified_numeric_field_filter' + + +@NON_STATS_FILTERS.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) class SpecifiedNumericFieldFilter(Filter): """ Filter based on specified numeric field information. diff --git a/data_juicer/ops/filter/suffix_filter.py b/data_juicer/ops/filter/suffix_filter.py index ea7868399..7aaca53a7 100644 --- a/data_juicer/ops/filter/suffix_filter.py +++ b/data_juicer/ops/filter/suffix_filter.py @@ -2,10 +2,13 @@ from data_juicer.utils.constant import Fields -from ..base_op import OPERATORS, Filter +from ..base_op import NON_STATS_FILTERS, OPERATORS, Filter +OP_NAME = 'suffix_filter' -@OPERATORS.register_module('suffix_filter') + +@NON_STATS_FILTERS.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) class SuffixFilter(Filter): """Filter to keep samples with specified suffix.""" diff --git a/data_juicer/ops/filter/video_tagging_from_frames_filter.py b/data_juicer/ops/filter/video_tagging_from_frames_filter.py index 8872aab32..2436d886c 100644 --- a/data_juicer/ops/filter/video_tagging_from_frames_filter.py +++ b/data_juicer/ops/filter/video_tagging_from_frames_filter.py @@ -5,7 +5,8 @@ from data_juicer.utils.constant import Fields -from ..base_op import OPERATORS, UNFORKABLE, Filter +from ..base_op import (NON_STATS_FILTERS, OPERATORS, TAGGING_OPS, UNFORKABLE, + Filter) from ..mapper.video_tagging_from_frames_mapper import \ VideoTaggingFromFramesMapper from ..op_fusion import LOADED_VIDEOS @@ -13,6 +14,8 @@ OP_NAME = 'video_tagging_from_frames_filter' +@NON_STATS_FILTERS.register_module(OP_NAME) +@TAGGING_OPS.register_module(OP_NAME) @UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) @LOADED_VIDEOS.register_module(OP_NAME) @@ -91,7 +94,7 @@ def compute_stats_single(self, sample, rank=None, context=False): return sample def process_single(self, sample, rank=None): - video_tags = sample[self.tag_field_name] + video_tags = sample[Fields.meta][self.tag_field_name] if len(video_tags) <= 0: return True diff --git a/data_juicer/ops/mapper/image_tagging_mapper.py b/data_juicer/ops/mapper/image_tagging_mapper.py index e3fc46f1b..dc2099b78 100644 --- a/data_juicer/ops/mapper/image_tagging_mapper.py +++ b/data_juicer/ops/mapper/image_tagging_mapper.py @@ -7,7 +7,7 @@ from data_juicer.utils.mm_utils import load_data_with_context, load_image from data_juicer.utils.model_utils import get_model, prepare_model -from ..base_op import OPERATORS, UNFORKABLE, Mapper +from ..base_op import OPERATORS, TAGGING_OPS, UNFORKABLE, Mapper from ..op_fusion import LOADED_IMAGES torch = LazyLoader('torch', 'torch') @@ -16,6 +16,7 @@ OP_NAME = 'image_tagging_mapper' +@TAGGING_OPS.register_module(OP_NAME) @UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) @LOADED_IMAGES.register_module(OP_NAME) @@ -47,12 +48,13 @@ def __init__(self, def process_single(self, sample, rank=None, context=False): # check if it's generated already - if self.tag_field_name in sample: + if self.tag_field_name in sample[Fields.meta]: return sample # there is no image in this sample if self.image_key not in sample or not sample[self.image_key]: - sample[self.tag_field_name] = np.array([[]], dtype=np.str_) + sample[Fields.meta][self.tag_field_name] = np.array([[]], + dtype=np.str_) return sample # load images @@ -75,5 +77,5 @@ def process_single(self, sample, rank=None, context=False): sorted_word_list = [item for item, _ in word_count.most_common()] image_tags.append(np.array(sorted_word_list, dtype=np.str_)) - sample[self.tag_field_name] = image_tags + sample[Fields.meta][self.tag_field_name] = image_tags return sample diff --git a/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py b/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py index 2c32093a5..7302953f2 100644 --- a/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py +++ b/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py @@ -6,13 +6,14 @@ from data_juicer.utils.mm_utils import extract_audio_from_video from data_juicer.utils.model_utils import get_model, prepare_model -from ..base_op import OPERATORS, Mapper +from ..base_op import OPERATORS, TAGGING_OPS, Mapper torch = LazyLoader('torch', 'torch') OP_NAME = 'video_tagging_from_audio_mapper' +@TAGGING_OPS.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class VideoTaggingFromAudioMapper(Mapper): """Mapper to generate video tags from audio streams extracted by video @@ -50,12 +51,13 @@ def __init__(self, def process_single(self, sample, rank=None): # check if it's generated already - if self.tag_field_name in sample: + if self.tag_field_name in sample[Fields.meta]: return sample # there is no video in this sample if self.video_key not in sample or not sample[self.video_key]: - sample[self.tag_field_name] = np.array([], dtype=np.str_) + sample[Fields.meta][self.tag_field_name] = np.array([], + dtype=np.str_) return sample # load video paths @@ -90,5 +92,6 @@ def process_single(self, sample, rank=None): predicted_tag_id = torch.argmax(logits, dim=-1).item() predicted_tag = model.config.id2label[predicted_tag_id] video_audio_tags.append(predicted_tag) - sample[self.tag_field_name] = np.array(video_audio_tags, dtype=np.str_) + sample[Fields.meta][self.tag_field_name] = np.array(video_audio_tags, + dtype=np.str_) return sample diff --git a/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py b/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py index d4995d3f6..31927e1b2 100644 --- a/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py +++ b/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py @@ -10,7 +10,7 @@ load_data_with_context, load_video) from data_juicer.utils.model_utils import get_model, prepare_model -from ..base_op import OPERATORS, UNFORKABLE, Mapper +from ..base_op import OPERATORS, TAGGING_OPS, UNFORKABLE, Mapper from ..op_fusion import LOADED_VIDEOS ram = LazyLoader('ram', 'ram') @@ -19,6 +19,7 @@ OP_NAME = 'video_tagging_from_frames_mapper' +@TAGGING_OPS.register_module(OP_NAME) @UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) @LOADED_VIDEOS.register_module(OP_NAME) @@ -73,12 +74,13 @@ def __init__(self, def process_single(self, sample, rank=None, context=False): # check if it's generated already - if self.tag_field_name in sample: + if self.tag_field_name in sample[Fields.meta]: return sample # there is no video in this sample if self.video_key not in sample or not sample[self.video_key]: - sample[self.tag_field_name] = np.array([[]], dtype=np.str_) + sample[Fields.meta][self.tag_field_name] = np.array([[]], + dtype=np.str_) return sample # load videos @@ -115,5 +117,5 @@ def process_single(self, sample, rank=None, context=False): for vid_key in videos: close_video(videos[vid_key]) - sample[self.tag_field_name] = video_tags + sample[Fields.meta][self.tag_field_name] = video_tags return sample diff --git a/data_juicer/utils/cache_utils.py b/data_juicer/utils/cache_utils.py index 7d815db2c..51138d7ed 100644 --- a/data_juicer/utils/cache_utils.py +++ b/data_juicer/utils/cache_utils.py @@ -1,4 +1,7 @@ import os +from functools import wraps + +from datasets import disable_caching, enable_caching, is_caching_enabled # Default cache location DEFAULT_CACHE_HOME = '~/.cache' @@ -21,3 +24,47 @@ DEFAULT_DATA_JUICER_MODELS_CACHE) CACHE_COMPRESS = None + + +class DatasetCacheControl: + """Define a range that change the cache state temporarily.""" + + def __init__(self, on: bool = False): + self.on = on + + def __enter__(self): + """ + Record the original cache state and turn it to the target state. + """ + self.previous_state = is_caching_enabled() + if self.on: + enable_caching() + else: + disable_caching() + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Restore the original cache state. + """ + if self.previous_state: + enable_caching() + else: + disable_caching() + + +def dataset_cache_control(on): + """ + A more easy-to-use decorator for functions that need to control the cache + state temporarily. + """ + + def dataset_cache_decorator(func): + + @wraps(func) + def wrapped_function(*args, **kwargs): + with DatasetCacheControl(on=on): + return func(*args, **kwargs) + + return wrapped_function + + return dataset_cache_decorator diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 922d44c8b..30686693e 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -16,13 +16,17 @@ class Fields(object): context = DEFAULT_PREFIX + 'context__' suffix = DEFAULT_PREFIX + 'suffix__' - video_frames = DEFAULT_PREFIX + 'video_frames__' + # tags in meta # video_frame_tags video_frame_tags = DEFAULT_PREFIX + 'video_frame_tags__' + # video_audio_tags video_audio_tags = DEFAULT_PREFIX + 'video_audio_tags__' # image_tags image_tags = DEFAULT_PREFIX + 'image_tags__' + # video_frames + video_frames = DEFAULT_PREFIX + 'video_frames__' + # the name of the original file from which this sample was derived. source_file = DEFAULT_PREFIX + 'source_file__' diff --git a/tests/ops/filter/test_video_tagging_from_frames_filter.py b/tests/ops/filter/test_video_tagging_from_frames_filter.py index bc4f67fb4..4018136ec 100644 --- a/tests/ops/filter/test_video_tagging_from_frames_filter.py +++ b/tests/ops/filter/test_video_tagging_from_frames_filter.py @@ -6,6 +6,7 @@ from data_juicer.ops.filter.video_tagging_from_frames_filter import \ VideoTaggingFromFramesFilter from data_juicer.utils.mm_utils import SpecialTokens +from data_juicer.utils.constant import Fields from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase class VideoTaggingFromFramesFilterTest(DataJuicerTestCaseBase): @@ -21,8 +22,11 @@ def _run_video_tagging_from_frames_filter(self, target_list, num_proc=1): dataset = Dataset.from_list(source_list) - dataset = dataset.map(op.compute_stats) - dataset = dataset.filter(op.process) + if Fields.meta not in dataset.features: + dataset = dataset.add_column(name=Fields.meta, + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats, num_proc=num_proc) + dataset = dataset.filter(op.process, num_proc=num_proc) dataset = dataset.select_columns(column_names=['text', 'videos']) res_list = dataset.to_list() self.assertEqual(res_list, target_list) diff --git a/tests/ops/mapper/test_image_tagging_mapper.py b/tests/ops/mapper/test_image_tagging_mapper.py index 9ec3e4d22..d2bbddec2 100644 --- a/tests/ops/mapper/test_image_tagging_mapper.py +++ b/tests/ops/mapper/test_image_tagging_mapper.py @@ -24,6 +24,9 @@ def _run_image_tagging_mapper(self, target_list, num_proc=1): dataset = Dataset.from_list(source_list) + if Fields.meta not in dataset.features: + dataset = dataset.add_column(name=Fields.meta, + column=[{}] * dataset.num_rows) dataset = dataset.map(op.process, num_proc=num_proc, with_rank=True) res_list = dataset.to_list() self.assertEqual(res_list, target_list) @@ -38,23 +41,26 @@ def test(self): }] tgt_list = [{ 'images': [self.img1_path], - Fields.image_tags: [[ - 'bed', 'bedcover', 'bedroom', 'bedding', 'lamp', 'ceiling', - 'chair', 'pillar', 'comfort', 'side table', 'floor', - 'hardwood floor', 'headboard', 'linen', 'mattress', - 'nightstand', 'picture frame', 'pillow', 'room', 'wall lamp', - 'stool', 'white', 'window', 'wood floor']], + Fields.meta: { + Fields.image_tags: [[ + 'bed', 'bedcover', 'bedroom', 'bedding', 'lamp', 'ceiling', + 'chair', 'pillar', 'comfort', 'side table', 'floor', + 'hardwood floor', 'headboard', 'linen', 'mattress', + 'nightstand', 'picture frame', 'pillow', 'room', 'wall lamp', + 'stool', 'white', 'window', 'wood floor']]}, }, { 'images': [self.img2_path], - Fields.image_tags: [[ - 'advertisement', 'back', 'bus', 'car', 'city bus', - 'city street', 'curb', 'decker bus', 'drive', 'license plate', - 'road', 'street scene', 'tour bus', 'travel', 'white']], + Fields.meta: { + Fields.image_tags: [[ + 'advertisement', 'back', 'bus', 'car', 'city bus', + 'city street', 'curb', 'decker bus', 'drive', 'license plate', + 'road', 'street scene', 'tour bus', 'travel', 'white']]}, }, { 'images': [self.img3_path], - Fields.image_tags: [[ - 'alley', 'black', 'building', 'catch', 'person', 'pavement', - 'photo', 'rain', 'road', 'umbrella', 'walk', 'woman']], + Fields.meta: { + Fields.image_tags: [[ + 'alley', 'black', 'building', 'catch', 'person', 'pavement', + 'photo', 'rain', 'road', 'umbrella', 'walk', 'woman']]}, }] op = ImageTaggingMapper() self._run_image_tagging_mapper(op, ds_list, tgt_list) @@ -67,13 +73,15 @@ def test_no_images(self): }] tgt_list = [{ 'images': [], - Fields.image_tags: [[]], + Fields.meta: { + Fields.image_tags: [[]]}, }, { 'images': [self.img2_path], - Fields.image_tags: [[ - 'advertisement', 'back', 'bus', 'car', 'city bus', - 'city street', 'curb', 'decker bus', 'drive', 'license plate', - 'road', 'street scene', 'tour bus', 'travel', 'white']], + Fields.meta: { + Fields.image_tags: [[ + 'advertisement', 'back', 'bus', 'car', 'city bus', + 'city street', 'curb', 'decker bus', 'drive', 'license plate', + 'road', 'street scene', 'tour bus', 'travel', 'white']]}, }] op = ImageTaggingMapper() self._run_image_tagging_mapper(op, ds_list, tgt_list) @@ -90,23 +98,26 @@ def test_specified_tag_field_name(self): }] tgt_list = [{ 'images': [self.img1_path], - tag_field_name: [[ - 'bed', 'bedcover', 'bedroom', 'bedding', 'lamp', 'ceiling', - 'chair', 'pillar', 'comfort', 'side table', 'floor', - 'hardwood floor', 'headboard', 'linen', 'mattress', - 'nightstand', 'picture frame', 'pillow', 'room', 'wall lamp', - 'stool', 'white', 'window', 'wood floor']], + Fields.meta: { + tag_field_name: [[ + 'bed', 'bedcover', 'bedroom', 'bedding', 'lamp', 'ceiling', + 'chair', 'pillar', 'comfort', 'side table', 'floor', + 'hardwood floor', 'headboard', 'linen', 'mattress', + 'nightstand', 'picture frame', 'pillow', 'room', 'wall lamp', + 'stool', 'white', 'window', 'wood floor']]}, }, { 'images': [self.img2_path], - tag_field_name: [[ - 'advertisement', 'back', 'bus', 'car', 'city bus', - 'city street', 'curb', 'decker bus', 'drive', 'license plate', - 'road', 'street scene', 'tour bus', 'travel', 'white']], + Fields.meta: { + tag_field_name: [[ + 'advertisement', 'back', 'bus', 'car', 'city bus', + 'city street', 'curb', 'decker bus', 'drive', 'license plate', + 'road', 'street scene', 'tour bus', 'travel', 'white']]}, }, { 'images': [self.img3_path], - tag_field_name: [[ - 'alley', 'black', 'building', 'catch', 'person', 'pavement', - 'photo', 'rain', 'road', 'umbrella', 'walk', 'woman']], + Fields.meta: { + tag_field_name: [[ + 'alley', 'black', 'building', 'catch', 'person', 'pavement', + 'photo', 'rain', 'road', 'umbrella', 'walk', 'woman']]}, }] op = ImageTaggingMapper(tag_field_name=tag_field_name) self._run_image_tagging_mapper(op, ds_list, tgt_list) @@ -126,23 +137,26 @@ def test_multi_process(self): }] tgt_list = [{ 'images': [self.img1_path], - Fields.image_tags: [[ - 'bed', 'bedcover', 'bedroom', 'bedding', 'lamp', 'ceiling', - 'chair', 'pillar', 'comfort', 'side table', 'floor', - 'hardwood floor', 'headboard', 'linen', 'mattress', - 'nightstand', 'picture frame', 'pillow', 'room', 'wall lamp', - 'stool', 'white', 'window', 'wood floor']], + Fields.meta: { + Fields.image_tags: [[ + 'bed', 'bedcover', 'bedroom', 'bedding', 'lamp', 'ceiling', + 'chair', 'pillar', 'comfort', 'side table', 'floor', + 'hardwood floor', 'headboard', 'linen', 'mattress', + 'nightstand', 'picture frame', 'pillow', 'room', 'wall lamp', + 'stool', 'white', 'window', 'wood floor']]}, }, { 'images': [self.img2_path], - Fields.image_tags: [[ - 'advertisement', 'back', 'bus', 'car', 'city bus', - 'city street', 'curb', 'decker bus', 'drive', 'license plate', - 'road', 'street scene', 'tour bus', 'travel', 'white']], + Fields.meta: { + Fields.image_tags: [[ + 'advertisement', 'back', 'bus', 'car', 'city bus', + 'city street', 'curb', 'decker bus', 'drive', 'license plate', + 'road', 'street scene', 'tour bus', 'travel', 'white']]}, }, { 'images': [self.img3_path], - Fields.image_tags: [[ - 'alley', 'black', 'building', 'catch', 'person', 'pavement', - 'photo', 'rain', 'road', 'umbrella', 'walk', 'woman']], + Fields.meta: { + Fields.image_tags: [[ + 'alley', 'black', 'building', 'catch', 'person', 'pavement', + 'photo', 'rain', 'road', 'umbrella', 'walk', 'woman']]}, }] op = ImageTaggingMapper() self._run_image_tagging_mapper(op, diff --git a/tests/ops/mapper/test_video_tagging_from_audio_mapper.py b/tests/ops/mapper/test_video_tagging_from_audio_mapper.py index 8bbf05933..00a376170 100644 --- a/tests/ops/mapper/test_video_tagging_from_audio_mapper.py +++ b/tests/ops/mapper/test_video_tagging_from_audio_mapper.py @@ -31,8 +31,11 @@ def _run_video_tagging_from_audio_mapper(self, tag_field_name=Fields.video_audio_tags, num_proc=1): dataset = Dataset.from_list(source_list) + if Fields.meta not in dataset.features: + dataset = dataset.add_column(name=Fields.meta, + column=[{}] * dataset.num_rows) dataset = dataset.map(op.process, num_proc=num_proc) - res_list = dataset.select_columns([tag_field_name])[tag_field_name] + res_list = dataset.flatten().select_columns([f'{Fields.meta}.{tag_field_name}'])[f'{Fields.meta}.{tag_field_name}'] self.assertEqual(res_list, target_list) def test(self): diff --git a/tests/ops/mapper/test_video_tagging_from_frames_mapper.py b/tests/ops/mapper/test_video_tagging_from_frames_mapper.py index 4484df754..31fc04c3b 100644 --- a/tests/ops/mapper/test_video_tagging_from_frames_mapper.py +++ b/tests/ops/mapper/test_video_tagging_from_frames_mapper.py @@ -25,6 +25,9 @@ def _run_video_tagging_from_frames_mapper(self, target_list, num_proc=1): dataset = Dataset.from_list(source_list) + if Fields.meta not in dataset.features: + dataset = dataset.add_column(name=Fields.meta, + column=[{}] * dataset.num_rows) dataset = dataset.map(op.process, num_proc=num_proc) res_list = dataset.to_list() self.assertEqual(res_list, target_list) @@ -46,30 +49,33 @@ def test(self): 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', 'videos': [self.vid1_path], - Fields.video_frame_tags: [[ - 'animal', 'ray', 'text', 'writing', 'yellow', 'game', - 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', - 'sky' - ]] + Fields.meta: { + Fields.video_frame_tags: [[ + 'animal', 'ray', 'text', 'writing', 'yellow', 'game', + 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', + 'sky' + ]]} }, { 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'videos': [self.vid2_path], - Fields.video_frame_tags: [[ - 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', - 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', - 'ball', 'person' - ]] + Fields.meta: { + Fields.video_frame_tags: [[ + 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', + 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', + 'ball', 'person' + ]]} }, { 'text': f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid3_path], - Fields.video_frame_tags: [[ - 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', - 'conversation', 'round table', 'closet', 'computer', 'girl', - 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', - 'selfie', 'stand' - ]] + Fields.meta: { + Fields.video_frame_tags: [[ + 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', + 'conversation', 'round table', 'closet', 'computer', 'girl', + 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', + 'selfie', 'stand' + ]]} }] op = VideoTaggingFromFramesMapper() self._run_video_tagging_from_frames_mapper(op, ds_list, tgt_list) @@ -87,16 +93,18 @@ def test_no_video(self): 'text': f'白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', 'videos': [], - Fields.video_frame_tags: [[]] + Fields.meta: { + Fields.video_frame_tags: [[]]} }, { 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'videos': [self.vid2_path], - Fields.video_frame_tags: [[ - 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', - 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', - 'ball', 'person' - ]] + Fields.meta: { + Fields.video_frame_tags: [[ + 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', + 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', + 'ball', 'person' + ]]} }] op = VideoTaggingFromFramesMapper() self._run_video_tagging_from_frames_mapper(op, ds_list, tgt_list) @@ -120,30 +128,33 @@ def test_specified_tag_field_name(self): 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', 'videos': [self.vid1_path], - tag_field_name: [[ - 'animal', 'ray', 'text', 'writing', 'yellow', 'game', - 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', - 'sky' - ]] + Fields.meta: { + tag_field_name: [[ + 'animal', 'ray', 'text', 'writing', 'yellow', 'game', + 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', + 'sky' + ]]} }, { 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'videos': [self.vid2_path], - tag_field_name: [[ - 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', - 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', - 'ball', 'person' - ]] + Fields.meta: { + tag_field_name: [[ + 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', + 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', + 'ball', 'person' + ]]} }, { 'text': f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid3_path], - tag_field_name: [[ - 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', - 'conversation', 'round table', 'closet', 'computer', 'girl', - 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', - 'selfie', 'stand' - ]] + Fields.meta: { + tag_field_name: [[ + 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', + 'conversation', 'round table', 'closet', 'computer', 'girl', + 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', + 'selfie', 'stand' + ]]} }] op = VideoTaggingFromFramesMapper(tag_field_name=tag_field_name) self._run_video_tagging_from_frames_mapper(op, ds_list, tgt_list) @@ -165,30 +176,33 @@ def test_uniform(self): 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', 'videos': [self.vid1_path], - Fields.video_frame_tags: [[ - 'cartoon', 'animal', 'anime', 'game', 'screenshot', - 'video game', 'cartoon character', 'robe', 'ray', 'text', - 'writing', 'yellow', 'doll', 'tail', 'sky', 'person']] + Fields.meta: { + Fields.video_frame_tags: [[ + 'cartoon', 'animal', 'anime', 'game', 'screenshot', + 'video game', 'cartoon character', 'robe', 'ray', 'text', + 'writing', 'yellow', 'doll', 'tail', 'sky', 'person']]} }, { 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'videos': [self.vid2_path], - Fields.video_frame_tags: [[ - 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', - 'hand', 'catch', 'bulletin board', 'Wii', 'cotton candy', - 'tennis racket', 'blind', 'game controller', 'remote', 'stand', - 'video game', 'Wii controller', 'play', 'baseball uniform', - 'toy', 'green']] + Fields.meta: { + Fields.video_frame_tags: [[ + 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', + 'hand', 'catch', 'bulletin board', 'Wii', 'cotton candy', + 'tennis racket', 'blind', 'game controller', 'remote', 'stand', + 'video game', 'Wii controller', 'play', 'baseball uniform', + 'toy', 'green']]} }, { 'text': f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid3_path], - Fields.video_frame_tags: [[ - 'table', 'sit', 'woman', 'bookshelf', 'conversation', 'person', - 'round table', 'computer', 'girl', 'man', 'closet', 'laptop', - 'stand', 'computer screen', 'talk', 'room', 'stool', 'hand', - 'point' - ]] + Fields.meta: { + Fields.video_frame_tags: [[ + 'table', 'sit', 'woman', 'bookshelf', 'conversation', 'person', + 'round table', 'computer', 'girl', 'man', 'closet', 'laptop', + 'stand', 'computer screen', 'talk', 'room', 'stool', 'hand', + 'point' + ]]} }] op = VideoTaggingFromFramesMapper(frame_sampling_method='uniform', frame_num=10) @@ -216,30 +230,33 @@ def test_multi_process(self): 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', 'videos': [self.vid1_path], - Fields.video_frame_tags: [[ - 'animal', 'ray', 'text', 'writing', 'yellow', 'game', - 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', - 'sky' - ]] + Fields.meta: { + Fields.video_frame_tags: [[ + 'animal', 'ray', 'text', 'writing', 'yellow', 'game', + 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', + 'sky' + ]]} }, { 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'videos': [self.vid2_path], - Fields.video_frame_tags: [[ - 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', - 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', - 'ball', 'person' - ]] + Fields.meta: { + Fields.video_frame_tags: [[ + 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', + 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', + 'ball', 'person' + ]]} }, { 'text': f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid3_path], - Fields.video_frame_tags: [[ - 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', - 'conversation', 'round table', 'closet', 'computer', 'girl', - 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', - 'selfie', 'stand' - ]] + Fields.meta: { + Fields.video_frame_tags: [[ + 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', + 'conversation', 'round table', 'closet', 'computer', 'girl', + 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', + 'selfie', 'stand' + ]]} }] op = VideoTaggingFromFramesMapper() self._run_video_tagging_from_frames_mapper(op, @@ -268,44 +285,47 @@ def test_multi_chunk(self): 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。', 'videos': [self.vid1_path, self.vid2_path], - Fields.video_frame_tags: - [[ - 'animal', 'ray', 'text', 'writing', 'yellow', 'game', - 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', - 'sky' - ], [ - 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', - 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', - 'ball', 'person' - ]] + Fields.meta: { + Fields.video_frame_tags: + [[ + 'animal', 'ray', 'text', 'writing', 'yellow', 'game', + 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', + 'sky' + ], [ + 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', + 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', + 'ball', 'person' + ]]} }, { 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid2_path, self.vid3_path], - Fields.video_frame_tags: [[ - 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', - 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', - 'ball', 'person' - ], [ - 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', - 'conversation', 'round table', 'closet', 'computer', 'girl', - 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', - 'selfie', 'stand' - ]] + Fields.meta: { + Fields.video_frame_tags: [[ + 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', + 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', + 'ball', 'person' + ], [ + 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', + 'conversation', 'round table', 'closet', 'computer', 'girl', + 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', + 'selfie', 'stand' + ]]} }, { 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid1_path, self.vid3_path], - Fields.video_frame_tags: [[ - 'animal', 'ray', 'text', 'writing', 'yellow', 'game', - 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', - 'sky' - ], [ - 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', - 'conversation', 'round table', 'closet', 'computer', 'girl', - 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', - 'selfie', 'stand' - ]] + Fields.meta: { + Fields.video_frame_tags: [[ + 'animal', 'ray', 'text', 'writing', 'yellow', 'game', + 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', + 'sky' + ], [ + 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', + 'conversation', 'round table', 'closet', 'computer', 'girl', + 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', + 'selfie', 'stand' + ]]} }] op = VideoTaggingFromFramesMapper() self._run_video_tagging_from_frames_mapper(op, ds_list, tgt_list) From a26dcc7051e42872d9d86a06a1625250757cbbd3 Mon Sep 17 00:00:00 2001 From: Daoyuan Chen <67475544+yxdyc@users.noreply.github.com> Date: Fri, 20 Dec 2024 20:14:34 +0800 Subject: [PATCH 2/7] Update __init__.py --- data_juicer/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_juicer/__init__.py b/data_juicer/__init__.py index 91ce93bae..7b7173c37 100644 --- a/data_juicer/__init__.py +++ b/data_juicer/__init__.py @@ -1,4 +1,4 @@ -__version__ = '1.0.1' +__version__ = '1.0.2' import os import subprocess From 0125e1f3485de293878a0eded3f9a62f606c37de Mon Sep 17 00:00:00 2001 From: Cathy0908 <30484308+Cathy0908@users.noreply.github.com> Date: Wed, 25 Dec 2024 15:55:48 +0800 Subject: [PATCH 3/7] support ray actor (#511) * support ray actor --- data_juicer/config/config.py | 5 ++ data_juicer/core/ray_data.py | 42 +++++++++--- data_juicer/ops/base_op.py | 6 ++ data_juicer/utils/process_utils.py | 46 +++++++++---- tests/tools/test_process_data.py | 103 ++++++++++++++++++++++++++++- 5 files changed, 178 insertions(+), 24 deletions(-) diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index 028f3cf79..0585ac8c4 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -464,6 +464,11 @@ def init_setup_from_cfg(cfg: Namespace): # check number of processes np sys_cpu_count = os.cpu_count() + if not cfg.np: + cfg.np = sys_cpu_count + logger.warning( + f'Number of processes `np` is not set, ' + f'set it to cpu count [{sys_cpu_count}] as default value.') if cfg.np > sys_cpu_count: logger.warning(f'Number of processes `np` is set as [{cfg.np}], which ' f'is larger than the cpu count [{sys_cpu_count}]. Due ' diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py index 646d59a5d..568f88e41 100644 --- a/data_juicer/core/ray_data.py +++ b/data_juicer/core/ray_data.py @@ -122,15 +122,41 @@ def _run_single_op(self, op): batch_size = getattr(op, 'batch_size', 1) if op.is_batched_op() else 1 if isinstance(op, Mapper): - self.data = self.data.map_batches(op.process, - batch_size=batch_size, - batch_format='pyarrow', - num_gpus=num_gpus) + if op.use_cuda(): + op_kwargs = op._op_cfg[op._name] + self.data = self.data.map_batches( + op.__class__, + fn_args=None, + fn_kwargs=None, + fn_constructor_args=None, + fn_constructor_kwargs=op_kwargs, + batch_size=batch_size, + num_gpus=num_gpus, + concurrency=op_proc, + batch_format='pyarrow') + else: + self.data = self.data.map_batches(op.process, + batch_size=batch_size, + batch_format='pyarrow', + num_gpus=num_gpus) elif isinstance(op, Filter): - self.data = self.data.map_batches(op.compute_stats, - batch_size=batch_size, - batch_format='pyarrow', - num_gpus=num_gpus) + if op.use_cuda(): + op_kwargs = op._op_cfg[op._name] + self.data = self.data.map_batches( + op.__class__, + fn_args=None, + fn_kwargs=None, + fn_constructor_args=None, + fn_constructor_kwargs=op_kwargs, + batch_size=batch_size, + num_gpus=num_gpus, + concurrency=op_proc, + batch_format='pyarrow') + else: + self.data = self.data.map_batches(op.compute_stats, + batch_size=batch_size, + batch_format='pyarrow', + num_gpus=num_gpus) if op.stats_export_path is not None: self.data.write_json(op.stats_export_path, force_ascii=False) diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index 39e23d8e9..9e39c50ab 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -288,6 +288,9 @@ def __init_subclass__(cls, **kwargs): f'{cls.__name__}. Please implement {method_name}_single ' f'or {method_name}_batched.') + def __call__(self, *args, **kwargs): + return self.process(*args, **kwargs) + def process_batched(self, samples, *args, **kwargs): keys = samples.keys() first_key = next(iter(keys)) @@ -378,6 +381,9 @@ def __init_subclass__(cls, **kwargs): f'{cls.__name__}. Please implement {method_name}_single ' f'or {method_name}_batched.') + def __call__(self, *args, **kwargs): + return self.compute_stats(*args, **kwargs) + def compute_stats_batched(self, samples, *args, **kwargs): keys = samples.keys() num_samples = len(samples[Fields.stats]) diff --git a/data_juicer/utils/process_utils.py b/data_juicer/utils/process_utils.py index 33d0a9f68..0ebb1c9fc 100644 --- a/data_juicer/utils/process_utils.py +++ b/data_juicer/utils/process_utils.py @@ -57,16 +57,10 @@ def calculate_np(name, """Calculate the optimum number of processes for the given OP""" eps = 1e-9 # about 1 byte - if num_proc is None: - num_proc = psutil.cpu_count() - if use_cuda: + auto_num_proc = None cuda_mem_available = get_min_cuda_memory() / 1024 - op_proc = min( - num_proc, - math.floor(cuda_mem_available / (mem_required + eps)) * - cuda_device_count()) - if use_cuda and mem_required == 0: + if mem_required == 0: logger.warning(f'The required cuda memory of Op[{name}] ' f'has not been specified. ' f'Please specify the mem_required field in the ' @@ -74,15 +68,39 @@ def calculate_np(name, f'out of memory error. You can reference ' f'the mem_required field in the ' f'config_all.yaml file.') - if op_proc < 1.0: - logger.warning(f'The required cuda memory:{mem_required}GB might ' - f'be more than the available cuda memory:' - f'{cuda_mem_available}GB.' - f'This Op[{name}] might ' - f'require more resource to run.') + else: + auto_num_proc = math.floor( + cuda_mem_available / mem_required) * cuda_device_count() + if cuda_mem_available / mem_required < 1.0: + logger.warning( + f'The required cuda memory:{mem_required}GB might ' + f'be more than the available cuda memory:' + f'{cuda_mem_available}GB.' + f'This Op[{name}] might ' + f'require more resource to run.') + + if auto_num_proc and num_proc: + op_proc = min(auto_num_proc, num_proc) + if num_proc > auto_num_proc: + logger.warning( + f'The given num_proc: {num_proc} is greater than ' + f'the value {auto_num_proc} auto calculated based ' + f'on the mem_required of Op[{name}]. ' + f'Set the `num_proc` to {auto_num_proc}.') + elif not auto_num_proc and not num_proc: + op_proc = cuda_device_count() + logger.warning( + f'Both mem_required and num_proc of Op[{name}] are not set.' + f'Set the `num_proc` to number of GPUs {op_proc}.') + else: + op_proc = auto_num_proc if auto_num_proc else num_proc + op_proc = max(op_proc, 1) return op_proc else: + if num_proc is None: + num_proc = psutil.cpu_count() + op_proc = num_proc cpu_available = psutil.cpu_count() mem_available = psutil.virtual_memory().available diff --git a/tests/tools/test_process_data.py b/tests/tools/test_process_data.py index 1c923a87b..27b3b290b 100644 --- a/tests/tools/test_process_data.py +++ b/tests/tools/test_process_data.py @@ -4,19 +4,49 @@ import subprocess import tempfile import unittest +import uuid import yaml from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase +def run_in_subprocess(cmd): + try: + with subprocess.Popen( + cmd, shell=True, stdout=subprocess.PIPE, + stderr=subprocess.PIPE) as return_info: + while True: + next_line = return_info.stdout.readline() + return_line = next_line.decode('utf-8', 'ignore').strip() + if return_line == '' and return_info.poll() != None: + break + if return_line != '': + print(return_line) + + err_lines = '' + while True: + next_line = return_info.stderr.readline() + return_line = next_line.decode('utf-8', 'ignore').strip() + if return_line == '' and return_info.poll() != None: + break + if return_line != '': + print(return_line) + err_lines += return_line + '\n' + + return_code = return_info.wait() + if return_code: + raise RuntimeError(err_lines) + except Exception as e: + raise e + + class ProcessDataTest(DataJuicerTestCaseBase): def setUp(self): super().setUp() self.tmp_dir = tempfile.TemporaryDirectory().name - if not osp.exists(self.tmp_dir): - os.makedirs(self.tmp_dir) + os.makedirs(self.tmp_dir, exist_ok=True) def tearDown(self): super().tearDown() @@ -66,5 +96,74 @@ def test_status_code_1(self): self.assertFalse(osp.exists(tmp_out_path)) +class ProcessDataRayTest(DataJuicerTestCaseBase): + + def setUp(self): + super().setUp() + + cur_dir = osp.dirname(osp.abspath(__file__)) + self.tmp_dir = osp.join(cur_dir, f'tmp_{uuid.uuid4().hex}') + os.makedirs(self.tmp_dir, exist_ok=True) + + def tearDown(self): + super().tearDown() + + if osp.exists(self.tmp_dir): + shutil.rmtree(self.tmp_dir) + + import ray + ray.shutdown() + + def test_ray_image(self): + tmp_yaml_file = osp.join(self.tmp_dir, 'config_0.yaml') + tmp_out_path = osp.join(self.tmp_dir, 'output_0.json') + text_keys = 'text' + + data_path = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))), + 'demos', 'data', 'demo-dataset-images.jsonl') + yaml_config = { + 'dataset_path': data_path, + 'executor_type': 'ray', + 'ray_address': 'auto', + 'text_keys': text_keys, + 'image_key': 'images', + 'export_path': tmp_out_path, + 'process': [ + { + 'image_nsfw_filter': { + 'hf_nsfw_model': 'Falconsai/nsfw_image_detection', + 'trust_remote_code': True, + 'score_threshold': 0.5, + 'any_or_all': 'any', + 'mem_required': '8GB' + }, + 'image_aspect_ratio_filter':{ + 'min_ratio': 0.5, + 'max_ratio': 2.0 + } + } + ] + } + + with open(tmp_yaml_file, 'w') as file: + yaml.dump(yaml_config, file) + + run_in_subprocess(f'python tools/process_data.py --config {tmp_yaml_file}') + + self.assertTrue(osp.exists(tmp_out_path)) + + from datasets import load_dataset + jsonl_files = [os.path.join(tmp_out_path, f) \ + for f in os.listdir(tmp_out_path) \ + if f.endswith('.json')] + dataset = load_dataset( + 'json', + data_files={'jsonl': jsonl_files}) + + self.assertEqual(len(dataset['jsonl']), 3) + for item in dataset['jsonl']: + self.assertIn('aspect_ratios', item['__dj__stats__']) + + if __name__ == '__main__': unittest.main() From 36af19321067b106b42d32d3015ccf1bbfd44a21 Mon Sep 17 00:00:00 2001 From: BeachWang <1400012807@pku.edu.cn> Date: Thu, 26 Dec 2024 11:00:12 +0800 Subject: [PATCH 4/7] fix bug in generate_qa_from_example_mapper (#517) * fix format bug * fix test --- data_juicer/ops/mapper/generate_qa_from_examples_mapper.py | 4 ++-- tests/ops/mapper/test_generate_qa_from_examples_mapper.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py b/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py index 0c0d084b3..b962aa51c 100644 --- a/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py +++ b/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py @@ -194,10 +194,10 @@ def format_qa_pairs(qa_example): ]) formatted_examples = ''.join([ - self.example_template.format(qa_pairs=format_qa_pairs(qa_example)) + self.example_template.format(format_qa_pairs(qa_example)) for qa_example in qa_examples ]) - input_prompt = self.input_template.format(examples=formatted_examples) + input_prompt = self.input_template.format(formatted_examples) return input_prompt def parse_output(self, raw_output): diff --git a/tests/ops/mapper/test_generate_qa_from_examples_mapper.py b/tests/ops/mapper/test_generate_qa_from_examples_mapper.py index 2df4f09c0..023394f9b 100644 --- a/tests/ops/mapper/test_generate_qa_from_examples_mapper.py +++ b/tests/ops/mapper/test_generate_qa_from_examples_mapper.py @@ -38,7 +38,7 @@ def test(self): def test_multi_process(self): sampling_params = {'max_new_tokens': 200} - self._run_op(sampling_params=sampling_params, num_proc=3) + self._run_op(sampling_params=sampling_params, num_proc=2) def test_vllm(self): sampling_params = {'max_tokens': 200} From 1554138a78e92db882bb07fe68289c4d05f099c2 Mon Sep 17 00:00:00 2001 From: Yilun Huang Date: Thu, 26 Dec 2024 14:39:01 +0800 Subject: [PATCH 5/7] Format conversion tools for post tuning datasets (#514) * + add sharegpt <--> dj format conversion tools * - move multimodal into fmt_conversion * + add basic docs for format conversion tools and post tuning dialog format conversion tools * * rename tools * + add messages <--> dj conversion tools * + add messages <--> dj conversion tools * - reorganize the directory * * rename functions * + add conversion tools for ModelScope-Swift ShareGPT format * + add conversion tools for Alpaca format * * fix typos in doc strings * Update post_tuning_dialog/README.md * Update pos_tuning_dialog/README_ZH.md align with en version * clearly point out the DJ format * clearly point out the DJ format in zh * minor typo fix --------- Co-authored-by: Daoyuan Chen <67475544+yxdyc@users.noreply.github.com> --- README.md | 2 +- README_ZH.md | 2 +- tools/fmt_conversion/README.md | 54 +++++ tools/fmt_conversion/README_ZH.md | 54 +++++ .../{ => fmt_conversion}/multimodal/README.md | 4 +- .../multimodal/README_ZH.md | 4 +- .../absolute_path_to_relative_path.py | 0 .../dj_to_internvid.py | 2 +- .../dj_to_llava.py | 0 .../dj_to_mmc4.py | 0 .../dj_to_msrvtt.py | 2 +- .../dj_to_video_chatgpt.py | 2 +- .../dj_to_wavcaps.py | 0 .../dj_to_youku.py | 2 +- .../internvid_to_dj.py | 4 +- .../llava_to_dj.py | 0 .../mmc4_to_dj.py | 0 .../msrvtt_to_dj.py | 4 +- .../video_chatgpt_to_dj.py | 4 +- .../wavcaps_to_dj.py | 0 .../youku_to_dj.py | 4 +- .../{ => fmt_conversion}/multimodal/utils.py | 0 .../post_tuning_dialog/README.md | 96 ++++++++ .../post_tuning_dialog/README_ZH.md | 98 ++++++++ .../dj_to_alpaca.py | 110 +++++++++ .../dj_to_llama_factory_sharegpt.py | 185 +++++++++++++++ .../dj_to_messages.py | 110 +++++++++ .../dj_to_ms_swift_sharegpt.py | 143 ++++++++++++ .../alpaca_to_dj.py | 130 +++++++++++ .../llama_factory_sharegpt_to_dj.py | 216 ++++++++++++++++++ .../messages_to_dj.py | 108 +++++++++ .../ms_swift_sharegpt_to_dj.py | 168 ++++++++++++++ 32 files changed, 1490 insertions(+), 18 deletions(-) create mode 100644 tools/fmt_conversion/README.md create mode 100644 tools/fmt_conversion/README_ZH.md rename tools/{ => fmt_conversion}/multimodal/README.md (99%) rename tools/{ => fmt_conversion}/multimodal/README_ZH.md (99%) rename tools/{ => fmt_conversion}/multimodal/absolute_path_to_relative_path.py (100%) rename tools/{ => fmt_conversion}/multimodal/data_juicer_format_to_target_format/dj_to_internvid.py (98%) rename tools/{ => fmt_conversion}/multimodal/data_juicer_format_to_target_format/dj_to_llava.py (100%) rename tools/{ => fmt_conversion}/multimodal/data_juicer_format_to_target_format/dj_to_mmc4.py (100%) rename tools/{ => fmt_conversion}/multimodal/data_juicer_format_to_target_format/dj_to_msrvtt.py (98%) rename tools/{ => fmt_conversion}/multimodal/data_juicer_format_to_target_format/dj_to_video_chatgpt.py (98%) rename tools/{ => fmt_conversion}/multimodal/data_juicer_format_to_target_format/dj_to_wavcaps.py (100%) rename tools/{ => fmt_conversion}/multimodal/data_juicer_format_to_target_format/dj_to_youku.py (99%) rename tools/{ => fmt_conversion}/multimodal/source_format_to_data_juicer_format/internvid_to_dj.py (97%) rename tools/{ => fmt_conversion}/multimodal/source_format_to_data_juicer_format/llava_to_dj.py (100%) rename tools/{ => fmt_conversion}/multimodal/source_format_to_data_juicer_format/mmc4_to_dj.py (100%) rename tools/{ => fmt_conversion}/multimodal/source_format_to_data_juicer_format/msrvtt_to_dj.py (96%) rename tools/{ => fmt_conversion}/multimodal/source_format_to_data_juicer_format/video_chatgpt_to_dj.py (97%) rename tools/{ => fmt_conversion}/multimodal/source_format_to_data_juicer_format/wavcaps_to_dj.py (100%) rename tools/{ => fmt_conversion}/multimodal/source_format_to_data_juicer_format/youku_to_dj.py (97%) rename tools/{ => fmt_conversion}/multimodal/utils.py (100%) create mode 100644 tools/fmt_conversion/post_tuning_dialog/README.md create mode 100644 tools/fmt_conversion/post_tuning_dialog/README_ZH.md create mode 100644 tools/fmt_conversion/post_tuning_dialog/data_juicer_format_to_target_format/dj_to_alpaca.py create mode 100644 tools/fmt_conversion/post_tuning_dialog/data_juicer_format_to_target_format/dj_to_llama_factory_sharegpt.py create mode 100644 tools/fmt_conversion/post_tuning_dialog/data_juicer_format_to_target_format/dj_to_messages.py create mode 100644 tools/fmt_conversion/post_tuning_dialog/data_juicer_format_to_target_format/dj_to_ms_swift_sharegpt.py create mode 100644 tools/fmt_conversion/post_tuning_dialog/source_format_to_data_juicer_format/alpaca_to_dj.py create mode 100644 tools/fmt_conversion/post_tuning_dialog/source_format_to_data_juicer_format/llama_factory_sharegpt_to_dj.py create mode 100644 tools/fmt_conversion/post_tuning_dialog/source_format_to_data_juicer_format/messages_to_dj.py create mode 100644 tools/fmt_conversion/post_tuning_dialog/source_format_to_data_juicer_format/ms_swift_sharegpt_to_dj.py diff --git a/README.md b/README.md index 586869b0a..95eba1da2 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,7 @@ In this new version, we support more features for **multimodal data (including v - [2024-02-05] Our paper has been accepted by SIGMOD'24 industrial track! - [2024-01-10] Discover new horizons in "Data Mixture"—Our second data-centric LLM competition has kicked off! Please visit the competition's [official website](https://tianchi.aliyun.com/competition/entrance/532174) for more information. - [2024-01-05] We release **Data-Juicer v0.1.3** now! -In this new version, we support **more Python versions** (3.8-3.10), and support **multimodal** dataset [converting](tools/multimodal/README.md)/[processing](docs/Operators.md) (Including texts, images, and audios. More modalities will be supported in the future). +In this new version, we support **more Python versions** (3.8-3.10), and support **multimodal** dataset [converting](tools/fmt_conversion/multimodal/README.md)/[processing](docs/Operators.md) (Including texts, images, and audios. More modalities will be supported in the future). Besides, our paper is also updated to [v3](https://arxiv.org/abs/2309.02033). - [2023-10-13] Our first data-centric LLM competition begins! Please visit the competition's official websites, FT-Data Ranker ([1B Track](https://tianchi.aliyun.com/competition/entrance/532157), [7B Track](https://tianchi.aliyun.com/competition/entrance/532158)), for more information. diff --git a/README_ZH.md b/README_ZH.md index 42612964a..6ba358b37 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -47,7 +47,7 @@ Data-Juicer正在积极更新和维护中,我们将定期强化和新增更多 - [2024-02-05] 我们的论文被SIGMOD'24 industrial track接收! - [2024-01-10] 开启“数据混合”新视界——第二届Data-Juicer大模型数据挑战赛已经正式启动!立即访问[竞赛官网](https://tianchi.aliyun.com/competition/entrance/532174),了解赛事详情。 - [2024-01-05] **Data-Juicer v0.1.3** 版本发布了。 -在这个新版本中,我们支持了**更多Python版本**(3.8-3.10),同时支持了**多模态**数据集的[转换](tools/multimodal/README_ZH.md)和[处理](docs/Operators_ZH.md)(包括文本、图像和音频。更多模态也将会在之后支持)! +在这个新版本中,我们支持了**更多Python版本**(3.8-3.10),同时支持了**多模态**数据集的[转换](tools/fmt_conversion/multimodal/README_ZH.md)和[处理](docs/Operators_ZH.md)(包括文本、图像和音频。更多模态也将会在之后支持)! 此外,我们的论文也更新到了[第三版](https://arxiv.org/abs/2309.02033) 。 - [2023-10-13] 我们的第一届以数据为中心的 LLM 竞赛开始了! 请访问大赛官网,FT-Data Ranker([1B赛道](https://tianchi.aliyun.com/competition/entrance/532157) 、[7B赛道](https://tianchi.aliyun.com/competition/entrance/532158) ) ,了解更多信息。 diff --git a/tools/fmt_conversion/README.md b/tools/fmt_conversion/README.md new file mode 100644 index 000000000..38629ef35 --- /dev/null +++ b/tools/fmt_conversion/README.md @@ -0,0 +1,54 @@ +# Format Conversion Tools + +Here Data-Juicer provides tens of format conversion tools for diverse datasets, including multimodal datasets, post tuning datasets, and so on. +These tools help to convert the dataset in the original format to a unified, intermediate format used in Data-Juicer, which we call it "DJ format". +An overview of DJ format is shown below: + +```python +{ + // >>> core contents: texts, dialogs, ... + "text": "xxx", + "query": "xxx", + "response": "xxx", + ...... + // <<< core contents + + // >>> extra data contents: multimodal data paths, ... + "images": [ + "path/to/the/image/of/antarctica_snowfield", + "path/to/the/image/of/antarctica_map", + "path/to/the/image/of/europe_map" + ], + "audios": [ + "path/to/the/audio/of/sound_of_waves_in_Antarctic_Ocean" + ], + "videos": [ + "path/to/the/video/of/remote_sensing_view_of_antarctica" + ], + // <<< extra data contents + + // >>> meta infos and stats, which could be primitive or produced by Data-Juicer + "meta": { + "src": "customized", + "version": "0.1", + "author": "xxx" + }, + "stats": { + "lang": "en", + "image_widths": [224, 336, 512], + ... + }, + // <<< meta infos and stats +} +``` + +There are about three parts in DJ format: +1. Core contents: such as texts in the pretraining dataset of LLMs, dialogs in the post tuning dataset, and so on. They are directly related to the training or fine-tuning procedures in the downstream usage of the dataset. +2. Extra data contents: such as the paths to the multimodal data in the multimodal datasets. They are organized as path lists. +3. Meta infos & Stats: such as version or source information of the dataset that are inherent from the original dataset, or category tags and stats produced by OPs of Data-Juicer. + +The 2nd and 3rd parts of them are common used and organized in nearly the same structures for diverse datasets. +As a contrast, the 1st part, which is the core contents, might be quite different for different kinds of datasets. +Here are the corresponding documents for different datasets that introduce more details about this part: +- [Multimodal datasets](multimodal/README.md) +- [Post Tuning](post_tuning_dialog/README.md) \ No newline at end of file diff --git a/tools/fmt_conversion/README_ZH.md b/tools/fmt_conversion/README_ZH.md new file mode 100644 index 000000000..5ab13fc9c --- /dev/null +++ b/tools/fmt_conversion/README_ZH.md @@ -0,0 +1,54 @@ +# 格式转换工具 + +在这里,Data-Juicer 为各式各样的数据集提供了十数种格式转换工具,包括多模态数据集,后微调数据集等等。 +这些工具帮助我们将原始格式的数据集转换为 Data-Juicer 使用的一种统一的、中间的格式表示,我们将其称为"DJ 格式"。 +DJ 格式的一个示例如下所示: + +```python +{ + // >>> 核心内容:文本,对话,...... + "text": "xxx", + "query": "xxx", + "response": "xxx", + ...... + // <<< 核心内容 + + // >>> 额外数据内容:多模态数据路径,...... + "images": [ + "path/to/the/image/of/antarctica_snowfield", + "path/to/the/image/of/antarctica_map", + "path/to/the/image/of/europe_map" + ], + "audios": [ + "path/to/the/audio/of/sound_of_waves_in_Antarctic_Ocean" + ], + "videos": [ + "path/to/the/video/of/remote_sensing_view_of_antarctica" + ], + // <<< 额外数据内容 + + // >>> meta 信息和 stats,它们可能是数据集原生的,也可以由 Data-Juicer 产出 + "meta": { + "src": "customized", + "version": "0.1", + "author": "xxx" + }, + "stats": { + "lang": "en", + "image_widths": [224, 336, 512], + ... + }, + // <<< meta 信息和 stats +} +``` + +在 DJ 格式中大概包括三个部分: +1. 核心内容:例如 LLM 的预训练数据集中的文本内容,后微调数据集中的对话内容等。它们与数据集的下游使用的训练或者微调过程直接相关。 +2. 额外数据内容:例如多模态数据集中的多模态数据路径。它们被组织为路径列表。 +3. Meta 信息和 Stats:例如从原始数据集中继承而来的数据集版本或来源信息,或者由 Data-Juicer 的算子产出的类别 tags 和 stats 信息。 + +其中,第 2 和第 3 部分对于不同的数据集来说是通用的,而且都会被组织为几乎相同的结构。 +作为对比,第 1 部分,也就是核心内容部分,对于各种数据集来说可能非常不同。 +这里列举了针对不同种类数据集介绍这个部分更多细节的对应的文档: +- [多模态数据集](multimodal/README_ZH.md) +- [后微调数据集](post_tuning_dialog/README_ZH.md) \ No newline at end of file diff --git a/tools/multimodal/README.md b/tools/fmt_conversion/multimodal/README.md similarity index 99% rename from tools/multimodal/README.md rename to tools/fmt_conversion/multimodal/README.md index 60ff084b8..a4a15aac6 100644 --- a/tools/multimodal/README.md +++ b/tools/fmt_conversion/multimodal/README.md @@ -10,7 +10,7 @@ Both input and output of this utility conform to Data-Juicer's data format. If y To learn more about the usage of the absolute to relative path conversion tool, you can execute the following command: ```shell -python tools/multimodal/absolute_path_to_relative_path.py --help +python tools/fmt_conversion/multimodal/absolute_path_to_relative_path.py --help ``` ## Dataset Format Conversion @@ -94,7 +94,7 @@ For all tools, you can run the following command to find out the usage of them: ```shell # e.g. llava_to_dj.py -python tools/multimodal/source_format_to_data_juicer_format/llava_to_dj.py --help +python tools/fmt_conversion/multimodal/source_format_to_data_juicer_format/llava_to_dj.py --help ``` Before using these tools, you might need to take a glance at the reference diff --git a/tools/multimodal/README_ZH.md b/tools/fmt_conversion/multimodal/README_ZH.md similarity index 99% rename from tools/multimodal/README_ZH.md rename to tools/fmt_conversion/multimodal/README_ZH.md index 07afd10cb..3d28633a4 100644 --- a/tools/multimodal/README_ZH.md +++ b/tools/fmt_conversion/multimodal/README_ZH.md @@ -10,7 +10,7 @@ 可以运行以下命令来了解绝对路径转化相对路径工具的详细用法: ```shell -python tools/multimodal/absolute_path_to_relative_path.py --help +python tools/fmt_conversion/multimodal/absolute_path_to_relative_path.py --help ``` ## 数据集格式转换 @@ -86,7 +86,7 @@ python tools/multimodal/absolute_path_to_relative_path.py --help ```shell # 例如:llava_to_dj.py -python tools/multimodal/source_format_to_data_juicer_format/llava_to_dj.py --help +python tools/fmt_conversion/multimodal/source_format_to_data_juicer_format/llava_to_dj.py --help ``` 在使用这些工具之前,您可能需要查看上表中每个格式的参考资料,以更好地了解详细的格式信息,并理解每个工具的参数含义。 diff --git a/tools/multimodal/absolute_path_to_relative_path.py b/tools/fmt_conversion/multimodal/absolute_path_to_relative_path.py similarity index 100% rename from tools/multimodal/absolute_path_to_relative_path.py rename to tools/fmt_conversion/multimodal/absolute_path_to_relative_path.py diff --git a/tools/multimodal/data_juicer_format_to_target_format/dj_to_internvid.py b/tools/fmt_conversion/multimodal/data_juicer_format_to_target_format/dj_to_internvid.py similarity index 98% rename from tools/multimodal/data_juicer_format_to_target_format/dj_to_internvid.py rename to tools/fmt_conversion/multimodal/data_juicer_format_to_target_format/dj_to_internvid.py index 4c46f4676..434c31879 100644 --- a/tools/multimodal/data_juicer_format_to_target_format/dj_to_internvid.py +++ b/tools/fmt_conversion/multimodal/data_juicer_format_to_target_format/dj_to_internvid.py @@ -35,7 +35,7 @@ from tqdm import tqdm from data_juicer.utils.mm_utils import SpecialTokens -from tools.multimodal.utils import remove_dj_special_tokens +from tools.fmt_conversion.multimodal.utils import remove_dj_special_tokens def main( diff --git a/tools/multimodal/data_juicer_format_to_target_format/dj_to_llava.py b/tools/fmt_conversion/multimodal/data_juicer_format_to_target_format/dj_to_llava.py similarity index 100% rename from tools/multimodal/data_juicer_format_to_target_format/dj_to_llava.py rename to tools/fmt_conversion/multimodal/data_juicer_format_to_target_format/dj_to_llava.py diff --git a/tools/multimodal/data_juicer_format_to_target_format/dj_to_mmc4.py b/tools/fmt_conversion/multimodal/data_juicer_format_to_target_format/dj_to_mmc4.py similarity index 100% rename from tools/multimodal/data_juicer_format_to_target_format/dj_to_mmc4.py rename to tools/fmt_conversion/multimodal/data_juicer_format_to_target_format/dj_to_mmc4.py diff --git a/tools/multimodal/data_juicer_format_to_target_format/dj_to_msrvtt.py b/tools/fmt_conversion/multimodal/data_juicer_format_to_target_format/dj_to_msrvtt.py similarity index 98% rename from tools/multimodal/data_juicer_format_to_target_format/dj_to_msrvtt.py rename to tools/fmt_conversion/multimodal/data_juicer_format_to_target_format/dj_to_msrvtt.py index 4e3e85e32..5cc8c0817 100644 --- a/tools/multimodal/data_juicer_format_to_target_format/dj_to_msrvtt.py +++ b/tools/fmt_conversion/multimodal/data_juicer_format_to_target_format/dj_to_msrvtt.py @@ -44,7 +44,7 @@ from tqdm import tqdm from data_juicer.utils.mm_utils import SpecialTokens -from tools.multimodal.utils import remove_dj_special_tokens +from tools.fmt_conversion.multimodal.utils import remove_dj_special_tokens def main( diff --git a/tools/multimodal/data_juicer_format_to_target_format/dj_to_video_chatgpt.py b/tools/fmt_conversion/multimodal/data_juicer_format_to_target_format/dj_to_video_chatgpt.py similarity index 98% rename from tools/multimodal/data_juicer_format_to_target_format/dj_to_video_chatgpt.py rename to tools/fmt_conversion/multimodal/data_juicer_format_to_target_format/dj_to_video_chatgpt.py index aa3771c4c..18f1206db 100644 --- a/tools/multimodal/data_juicer_format_to_target_format/dj_to_video_chatgpt.py +++ b/tools/fmt_conversion/multimodal/data_juicer_format_to_target_format/dj_to_video_chatgpt.py @@ -38,7 +38,7 @@ from tqdm import tqdm from data_juicer.utils.mm_utils import SpecialTokens -from tools.multimodal.utils import remove_dj_special_tokens +from tools.fmt_conversion.multimodal.utils import remove_dj_special_tokens def main( diff --git a/tools/multimodal/data_juicer_format_to_target_format/dj_to_wavcaps.py b/tools/fmt_conversion/multimodal/data_juicer_format_to_target_format/dj_to_wavcaps.py similarity index 100% rename from tools/multimodal/data_juicer_format_to_target_format/dj_to_wavcaps.py rename to tools/fmt_conversion/multimodal/data_juicer_format_to_target_format/dj_to_wavcaps.py diff --git a/tools/multimodal/data_juicer_format_to_target_format/dj_to_youku.py b/tools/fmt_conversion/multimodal/data_juicer_format_to_target_format/dj_to_youku.py similarity index 99% rename from tools/multimodal/data_juicer_format_to_target_format/dj_to_youku.py rename to tools/fmt_conversion/multimodal/data_juicer_format_to_target_format/dj_to_youku.py index e3cb9671c..6b4831b52 100644 --- a/tools/multimodal/data_juicer_format_to_target_format/dj_to_youku.py +++ b/tools/fmt_conversion/multimodal/data_juicer_format_to_target_format/dj_to_youku.py @@ -59,7 +59,7 @@ from tqdm import tqdm from data_juicer.utils.mm_utils import SpecialTokens -from tools.multimodal.utils import remove_dj_special_tokens +from tools.fmt_conversion.multimodal.utils import remove_dj_special_tokens def main( diff --git a/tools/multimodal/source_format_to_data_juicer_format/internvid_to_dj.py b/tools/fmt_conversion/multimodal/source_format_to_data_juicer_format/internvid_to_dj.py similarity index 97% rename from tools/multimodal/source_format_to_data_juicer_format/internvid_to_dj.py rename to tools/fmt_conversion/multimodal/source_format_to_data_juicer_format/internvid_to_dj.py index 1b2ee2caa..5e52e5b02 100644 --- a/tools/multimodal/source_format_to_data_juicer_format/internvid_to_dj.py +++ b/tools/fmt_conversion/multimodal/source_format_to_data_juicer_format/internvid_to_dj.py @@ -42,8 +42,8 @@ from data_juicer.utils.file_utils import add_suffix_to_filename from data_juicer.utils.mm_utils import (SpecialTokens, cut_video_by_seconds, timecode_string_to_seconds) -from tools.multimodal.utils import (check_args_load_to_dj_data, - convert_text_to_dj) +from tools.fmt_conversion.multimodal.utils import (check_args_load_to_dj_data, + convert_text_to_dj) def main( diff --git a/tools/multimodal/source_format_to_data_juicer_format/llava_to_dj.py b/tools/fmt_conversion/multimodal/source_format_to_data_juicer_format/llava_to_dj.py similarity index 100% rename from tools/multimodal/source_format_to_data_juicer_format/llava_to_dj.py rename to tools/fmt_conversion/multimodal/source_format_to_data_juicer_format/llava_to_dj.py diff --git a/tools/multimodal/source_format_to_data_juicer_format/mmc4_to_dj.py b/tools/fmt_conversion/multimodal/source_format_to_data_juicer_format/mmc4_to_dj.py similarity index 100% rename from tools/multimodal/source_format_to_data_juicer_format/mmc4_to_dj.py rename to tools/fmt_conversion/multimodal/source_format_to_data_juicer_format/mmc4_to_dj.py diff --git a/tools/multimodal/source_format_to_data_juicer_format/msrvtt_to_dj.py b/tools/fmt_conversion/multimodal/source_format_to_data_juicer_format/msrvtt_to_dj.py similarity index 96% rename from tools/multimodal/source_format_to_data_juicer_format/msrvtt_to_dj.py rename to tools/fmt_conversion/multimodal/source_format_to_data_juicer_format/msrvtt_to_dj.py index b42d8e608..0bc25f140 100644 --- a/tools/multimodal/source_format_to_data_juicer_format/msrvtt_to_dj.py +++ b/tools/fmt_conversion/multimodal/source_format_to_data_juicer_format/msrvtt_to_dj.py @@ -43,8 +43,8 @@ from tqdm import tqdm from data_juicer.utils.mm_utils import SpecialTokens -from tools.multimodal.utils import (check_args_load_to_dj_data, - convert_text_to_dj) +from tools.fmt_conversion.multimodal.utils import (check_args_load_to_dj_data, + convert_text_to_dj) def main( diff --git a/tools/multimodal/source_format_to_data_juicer_format/video_chatgpt_to_dj.py b/tools/fmt_conversion/multimodal/source_format_to_data_juicer_format/video_chatgpt_to_dj.py similarity index 97% rename from tools/multimodal/source_format_to_data_juicer_format/video_chatgpt_to_dj.py rename to tools/fmt_conversion/multimodal/source_format_to_data_juicer_format/video_chatgpt_to_dj.py index d05d64fc5..36f0e6473 100644 --- a/tools/multimodal/source_format_to_data_juicer_format/video_chatgpt_to_dj.py +++ b/tools/fmt_conversion/multimodal/source_format_to_data_juicer_format/video_chatgpt_to_dj.py @@ -37,8 +37,8 @@ from tqdm import tqdm from data_juicer.utils.mm_utils import SpecialTokens -from tools.multimodal.utils import (check_args_load_to_dj_data, - convert_text_to_dj) +from tools.fmt_conversion.multimodal.utils import (check_args_load_to_dj_data, + convert_text_to_dj) @logger.catch(reraise=True) diff --git a/tools/multimodal/source_format_to_data_juicer_format/wavcaps_to_dj.py b/tools/fmt_conversion/multimodal/source_format_to_data_juicer_format/wavcaps_to_dj.py similarity index 100% rename from tools/multimodal/source_format_to_data_juicer_format/wavcaps_to_dj.py rename to tools/fmt_conversion/multimodal/source_format_to_data_juicer_format/wavcaps_to_dj.py diff --git a/tools/multimodal/source_format_to_data_juicer_format/youku_to_dj.py b/tools/fmt_conversion/multimodal/source_format_to_data_juicer_format/youku_to_dj.py similarity index 97% rename from tools/multimodal/source_format_to_data_juicer_format/youku_to_dj.py rename to tools/fmt_conversion/multimodal/source_format_to_data_juicer_format/youku_to_dj.py index 092a03958..15a570c55 100644 --- a/tools/multimodal/source_format_to_data_juicer_format/youku_to_dj.py +++ b/tools/fmt_conversion/multimodal/source_format_to_data_juicer_format/youku_to_dj.py @@ -58,8 +58,8 @@ from tqdm import tqdm from data_juicer.utils.mm_utils import SpecialTokens -from tools.multimodal.utils import (check_args_load_to_dj_data, - convert_text_to_dj) +from tools.fmt_conversion.multimodal.utils import (check_args_load_to_dj_data, + convert_text_to_dj) @logger.catch(reraise=True) diff --git a/tools/multimodal/utils.py b/tools/fmt_conversion/multimodal/utils.py similarity index 100% rename from tools/multimodal/utils.py rename to tools/fmt_conversion/multimodal/utils.py diff --git a/tools/fmt_conversion/post_tuning_dialog/README.md b/tools/fmt_conversion/post_tuning_dialog/README.md new file mode 100644 index 000000000..4d06a496f --- /dev/null +++ b/tools/fmt_conversion/post_tuning_dialog/README.md @@ -0,0 +1,96 @@ +# Post Tuning Tools + +For post tuning formats, we mainly consider 4 formats to support [ModelScope-Swift](https://github.com/modelscope/ms-swift/blob/main/docs/source_en/Customization/Custom-dataset.md) and [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory/blob/main/data/README.md). + +- Swift's Messages format (Very similar to the LLaMA-Factory's ShareGPT format, with different key names): + +```python +{ + "messages": [ + { + "role": "system", + "content": "" + }, + { + "role": "user", + "content": "" + }, + { + "role": "assistant", + "content": "" + }, + { + "role": "user", + "content": "" + }, + { + "role": "assistant", + "content": "" + } + ] +} +``` + +- Swift's ShareGPT format: + +```python +{ + "system": "", + "conversation": [ + { + "human": "", + "assistant": "" + }, + { + "human": "", + "assistant": "" + } + ] +} +``` + +- Alpaca format (used in the same definition in Swift and LLaMA-Factory): + +```python +{ + "system": "", + "instruction": "", + "input": "", + "output": "" +} +``` + +- Swift's Query-Response format: + +```python +{ + "system": "", + "query": "", + "response": "", + "history": [ + [ + "", + "" + ] + ] +} +``` + +In Data-Juicer, we pre-set fields to align with the last two formats (Alpaca and Query-Response), which serves as our intermediate format for post-tuning dialog datasets. Correspondingly, we provide several tools to convert datasets in other formats to the following DJ format and vice versa. + +- DJ default format for post-tuning OPs: + +```python +{ + "system": "", + "instruction": "", + "query": "", + "response": "", + "history": [ + [ + "", + "" + ] + ] +} +``` diff --git a/tools/fmt_conversion/post_tuning_dialog/README_ZH.md b/tools/fmt_conversion/post_tuning_dialog/README_ZH.md new file mode 100644 index 000000000..ad73caba6 --- /dev/null +++ b/tools/fmt_conversion/post_tuning_dialog/README_ZH.md @@ -0,0 +1,98 @@ +# 后微调工具 + +对于 后微调 数据格式,我们主要考虑 4 种格式来覆盖支持 [ModelScope-Swift](https://github.com/modelscope/ms-swift/blob/main/docs/source_en/Customization/Custom-dataset.md) 和 [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory/blob/main/data/README.md) : + +- Swift的 Messages 格式(与LLaMA-Factory的 ShareGPT 格式几乎一致,采用了略微不同的key字段命名): + +```python +{ + "messages": [ + { + "role": "system", + "content": "" + }, + { + "role": "user", + "content": "" + }, + { + "role": "assistant", + "content": "" + }, + { + "role": "user", + "content": "" + }, + { + "role": "assistant", + "content": "" + } + ] +} +``` + +- Swift的 ShareGPT 格式: + +```python +{ + "system": "", + "conversation": [ + { + "human": "", + "assistant": "" + }, + { + "human": "", + "assistant": "" + } + ] +} +``` + +- Alpaca 格式 (在Swift和LLaMA-Factory中定义一致): + +```python +{ + "system": "", + "instruction": "", + "input": "", + "output": "" +} +``` + +- Swift的Query-Response 格式: + +```python +{ + "system": "", + "query": "", + "response": "", + "history": [ + [ + "", + "" + ] + ] +} +``` + +在 Data-Juicer 中,我们预设了一些字段来对齐最后两种格式(Alpaca和Query-Response),并将如下格式作为 后微调对话 数据集的统一中间表示。 +相应地,我们提供了若干内置工具将其他格式的数据集转换为 DJ 格式以及反向转换。 + + +- DJ的多轮对话缺省格式(DJ post-tuning算子实现时假设基于该格式进行字段解析和处理): + +```python +{ + "system": "", + "instruction": "", + "query": "", + "response": "", + "history": [ + [ + "", + "" + ] + ] +} +``` diff --git a/tools/fmt_conversion/post_tuning_dialog/data_juicer_format_to_target_format/dj_to_alpaca.py b/tools/fmt_conversion/post_tuning_dialog/data_juicer_format_to_target_format/dj_to_alpaca.py new file mode 100644 index 000000000..f79fd0c43 --- /dev/null +++ b/tools/fmt_conversion/post_tuning_dialog/data_juicer_format_to_target_format/dj_to_alpaca.py @@ -0,0 +1,110 @@ +# This tool is used to convert dataset in Data-Juicer format to a +# target dataset in Alpaca-like format. +# +# Data-Juicer format (query-response format): +# [ +# { +# "system": "", +# "instruction": "", +# "query": "", +# "response": "", +# "history": [ +# ["human instruction in the first round (optional)", "model response in the first round (optional)"], # noqa: E501 +# ["human instruction in the second round (optional)", "model response in the second round (optional)"] # noqa: E501 +# ], +# }, +# ... +# ] +# +# Corresponding Alpaca format: +# [ +# { +# "system": "", +# "instruction": "", +# "input": "", +# "output": "", +# "history": [ +# ["human instruction in the first round (optional)", "model response in the first round (optional)"], # noqa: E501 +# ["human instruction in the second round (optional)", "model response in the second round (optional)"] # noqa: E501 +# ], +# }, +# ...... +# ] +# +# Reference: +# https://github.com/modelscope/ms-swift/blob/main/docs/source_en/Customization/Custom-dataset.md +# https://github.com/hiyouga/LLaMA-Factory/blob/v0.9.1/data/README.md#alpaca-format + +import json +import os + +import fire +import jsonlines as jl +from loguru import logger +from tqdm import tqdm + + +def dj_to_alpaca( + sample, + input_key: str = 'input', + output_key: str = 'output', +): + modified_keys = {'query', 'response'} + new_sample = { + key: sample[key] + for key in sample if key not in modified_keys and sample[key] + } + + # key mapping + if 'query' in sample: + new_sample[input_key] = sample['query'] + if 'response' in sample: + new_sample[output_key] = sample['response'] + + return new_sample + + +@logger.catch(reraise=True) +def main( + src_ds_path: str, + tgt_ds_path: str, + input_key: str = 'input', + output_key: str = 'output', +): + """ + Convert a Data-Juicer dataset to the Alpaca-like format. + + :param src_ds_path: the path to the source dataset. + :param tgt_ds_path: the path to store the converted target dataset. + :param input_key: the field key to store the query sentence from human. + :param output_key: the field key to store the response sentence from + assistant. + """ + + # check arguments + # check paths + if not os.path.exists(src_ds_path): + raise FileNotFoundError( + f'Input dataset [{src_ds_path}] can not be found.') + if not tgt_ds_path.endswith('.json'): + raise ValueError('Only support "json" target dataset file now.') + if os.path.dirname(tgt_ds_path) \ + and not os.path.exists(os.path.dirname(tgt_ds_path)): + logger.info(f'Create directory [{os.path.dirname(tgt_ds_path)}] ' + f'for the target dataset.') + os.makedirs(os.path.dirname(tgt_ds_path)) + + samples = [] + with jl.open(src_ds_path, 'r') as reader: + for sample in tqdm(reader): + converted_sample = dj_to_alpaca(sample, + input_key=input_key, + output_key=output_key) + samples.append(converted_sample) + + logger.info(f'Store the target dataset into [{tgt_ds_path}].') + json.dump(samples, open(tgt_ds_path, 'w', encoding='utf-8')) + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/tools/fmt_conversion/post_tuning_dialog/data_juicer_format_to_target_format/dj_to_llama_factory_sharegpt.py b/tools/fmt_conversion/post_tuning_dialog/data_juicer_format_to_target_format/dj_to_llama_factory_sharegpt.py new file mode 100644 index 000000000..c72dcbb84 --- /dev/null +++ b/tools/fmt_conversion/post_tuning_dialog/data_juicer_format_to_target_format/dj_to_llama_factory_sharegpt.py @@ -0,0 +1,185 @@ +# This tool is used to convert dataset in Data-Juicer format to a +# target dataset in LLaMA-Factory ShareGPT-like format. +# +# Data-Juicer format (query-response format): +# [ +# { +# "images": ["coco/train2017/000000033471.jpg"], +# "query": "Is the bus driving down the street or pulled off to the side?", +# "response": "The bus is driving down the street, which is crowded with people and other vehicles." # noqa: E501 +# "history": [ +# [ +# "\nWhat are the colors of the bus in the image?", +# "The bus in the image is white and red." +# ], +# [ +# "What feature can be seen on the back of the bus?", +# "The back of the bus features an advertisement." +# ], +# ] +# }, +# ... +# ] +# +# Corresponding LLaMA-Factory ShareGPT format: +# - usually in json format +# [ +# { +# "images": ["coco/train2017/000000033471.jpg"], +# "conversations": [ +# { +# "from": "human", +# "value": "\nWhat are the colors of the bus in the image?" +# }, +# { +# "from": "gpt", +# "value": "The bus in the image is white and red." +# }, +# { +# "from": "human", +# "value": "What feature can be seen on the back of the bus?" +# }, +# { +# "from": "gpt", +# "value": "The back of the bus features an advertisement." +# }, +# { +# "from": "human", +# "value": "Is the bus driving down the street or pulled off to the side?" # noqa: E501 +# }, +# { +# "from": "gpt", +# "value": "The bus is driving down the street, which is crowded with people and other vehicles." # noqa: E501 +# } +# ] +# }, +# ... +# ] +# +# Reference: +# https://github.com/hiyouga/LLaMA-Factory/blob/v0.9.1/data/README.md#sharegpt-format + +import json +import os + +import fire +import jsonlines as jl +from loguru import logger +from tqdm import tqdm + + +def dj_to_llama_factory_sharegpt( + sample, + conversations_key: str = 'conversations', + from_key: str = 'from', + value_key: str = 'value', + human_role: str = 'user', + assistant_role: str = 'assistant', + system_role: str = 'system', + instruction_role: str = 'instruction', +): + modified_keys = {'query', 'response', 'history', 'system', 'instruction'} + new_sample = { + key: sample[key] + for key in sample if key not in modified_keys and sample[key] + } + + # construct conversations + conversations = [] + # add system prompt and instruction + if 'system' in sample and sample['system'] != '': + conversations.append({ + from_key: system_role, + value_key: sample['system'] + }) + if 'instruction' in sample and sample['instruction'] != '': + conversations.append({ + from_key: instruction_role, + value_key: sample['instruction'] + }) + + # add dialogs + for query, response in sample['history']: + conversations.append({ + from_key: human_role, + value_key: query, + }) + conversations.append({ + from_key: assistant_role, + value_key: response, + }) + conversations.append({ + from_key: human_role, + value_key: sample['query'], + }) + if 'response' in sample and sample['response'] != '': + conversations.append({ + from_key: assistant_role, + value_key: sample['response'], + }) + + # get the result sample + new_sample[conversations_key] = conversations + + return new_sample + + +@logger.catch(reraise=True) +def main( + src_ds_path: str, + tgt_ds_path: str, + conversations_key: str = 'conversations', + from_key: str = 'from', + value_key: str = 'value', + human_role: str = 'user', + assistant_role: str = 'assistant', + system_role: str = 'system', + instruction_role: str = 'instruction', +): + """ + Convert a Data-Juicer dataset to the LLaMA-Factory ShareGPT-like format. + + :param src_ds_path: the path to the source dataset. + :param tgt_ds_path: the path to store the converted target dataset. + :param conversations_key: the field key to store conversions. + :param from_key: the field key to store the sentence from. + :param value_key: the field key to store the sentence content. + :param human_role: the role to store the human prompt. + :param assistant_role: the role to store the instruction content. + :param system_role: the role to store the system prompt. + :param instruction_role: the role to store the instruction content. + """ + + # check arguments + # check paths + if not os.path.exists(src_ds_path): + raise FileNotFoundError( + f'Input dataset [{src_ds_path}] can not be found.') + if not tgt_ds_path.endswith('.json'): + raise ValueError('Only support "json" target dataset file now.') + if os.path.dirname(tgt_ds_path) \ + and not os.path.exists(os.path.dirname(tgt_ds_path)): + logger.info(f'Create directory [{os.path.dirname(tgt_ds_path)}] ' + f'for the target dataset.') + os.makedirs(os.path.dirname(tgt_ds_path)) + + samples = [] + with jl.open(src_ds_path, 'r') as reader: + for sample in tqdm(reader): + converted_sample = dj_to_llama_factory_sharegpt( + sample, + conversations_key=conversations_key, + from_key=from_key, + value_key=value_key, + human_role=human_role, + assistant_role=assistant_role, + system_role=system_role, + instruction_role=instruction_role) + samples.append(converted_sample) + + logger.info(f'Store the target dataset into [{tgt_ds_path}].') + json.dump(samples, open(tgt_ds_path, 'w', encoding='utf-8')) + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/tools/fmt_conversion/post_tuning_dialog/data_juicer_format_to_target_format/dj_to_messages.py b/tools/fmt_conversion/post_tuning_dialog/data_juicer_format_to_target_format/dj_to_messages.py new file mode 100644 index 000000000..af52b2c87 --- /dev/null +++ b/tools/fmt_conversion/post_tuning_dialog/data_juicer_format_to_target_format/dj_to_messages.py @@ -0,0 +1,110 @@ +# This tool is used to convert dataset in Data-Juicer format to a +# target dataset in ModelScope-Swift Messages-like format. +# +# Data-Juicer format (query-response format): +# [ +# { +# "images": ["coco/train2017/000000033471.jpg"], +# "query": "Is the bus driving down the street or pulled off to the side?", +# "response": "The bus is driving down the street, which is crowded with people and other vehicles." # noqa: E501 +# "history": [ +# [ +# "\nWhat are the colors of the bus in the image?", +# "The bus in the image is white and red." +# ], +# [ +# "What feature can be seen on the back of the bus?", +# "The back of the bus features an advertisement." +# ], +# ] +# }, +# ... +# ] +# +# Corresponding ModelScope-Swift Messages format: +# - usually in json format +# [ +# { +# "images": ["coco/train2017/000000033471.jpg"], +# "messages": [ +# { +# "role": "human", +# "content": "\nWhat are the colors of the bus in the image?" +# }, +# { +# "role": "gpt", +# "content": "The bus in the image is white and red." +# }, +# { +# "role": "human", +# "content": "What feature can be seen on the back of the bus?" +# }, +# { +# "role": "gpt", +# "content": "The back of the bus features an advertisement." +# }, +# { +# "role": "human", +# "content": "Is the bus driving down the street or pulled off to the side?" # noqa: E501 +# }, +# { +# "role": "gpt", +# "content": "The bus is driving down the street, which is crowded with people and other vehicles." # noqa: E501 +# } +# ] +# }, +# ... +# ] +# +# Reference: +# https://github.com/modelscope/ms-swift/blob/main/docs/source_en/Customization/Custom-dataset.md +# +# This format is nearly the same as the LLaMA-Factory ShareGPT format, so we +# reuse the code in that conversion tools. + +import dj_to_llama_factory_sharegpt +import fire +from loguru import logger + + +@logger.catch(reraise=True) +def main( + src_ds_path: str, + tgt_ds_path: str, + messages_key: str = 'messages', + role_key: str = 'role', + content_key: str = 'content', + human_role: str = 'user', + assistant_role: str = 'assistant', + system_role: str = 'system', + instruction_role: str = 'instruction', +): + """ + Convert a Data-Juicer query-response dataset to the ModelScope-Swift + Message format. + + :param src_ds_path: the path to the source dataset. + :param tgt_ds_path: the path to store the converted target dataset. + :param messages_key: the field key to store messages. + :param role_key: the field key to store the sentence from. + :param content_key: the field key to store the sentence content. + :param human_role: the role to store the human prompt. + :param assistant_role: the role to store the instruction content. + :param system_role: the role to store the system prompt. + :param instruction_role: the role to store the instruction content. + """ + dj_to_llama_factory_sharegpt.main( + src_ds_path, + tgt_ds_path, + conversations_key=messages_key, + from_key=role_key, + value_key=content_key, + human_role=human_role, + assistant_role=assistant_role, + system_role=system_role, + instruction_role=instruction_role, + ) + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/tools/fmt_conversion/post_tuning_dialog/data_juicer_format_to_target_format/dj_to_ms_swift_sharegpt.py b/tools/fmt_conversion/post_tuning_dialog/data_juicer_format_to_target_format/dj_to_ms_swift_sharegpt.py new file mode 100644 index 000000000..d0d6b6b62 --- /dev/null +++ b/tools/fmt_conversion/post_tuning_dialog/data_juicer_format_to_target_format/dj_to_ms_swift_sharegpt.py @@ -0,0 +1,143 @@ +# This tool is used to convert dataset in Data-Juicer format to a +# target dataset in ModelScope-Swift ShareGPT format. +# +# Data-Juicer format (query-response format): +# [ +# { +# "system": "", +# "query": "", +# "response": "" +# "history": [ +# [ +# "", +# "" +# ], +# ] +# }, +# ... +# ] +# +# Corresponding ModelScope-Swift ShareGPT format: +# [ +# { +# "system": "", +# "conversation": [ +# { +# "human": "", +# "assistant": "" +# }, +# { +# "human": "", +# "assistant": "" +# } +# ] +# }, +# ...... +# ] +# +# Reference: +# https://github.com/modelscope/ms-swift/blob/main/docs/source_en/Customization/Custom-dataset.md + +import json +import os + +import fire +import jsonlines as jl +from loguru import logger +from tqdm import tqdm + + +def dj_to_ms_swift_sharegpt( + sample, + conversation_key: str = 'conversation', + human_key: str = 'human', + assistant_key: str = 'assistant', + system_key: str = 'system', + instruction_key: str = 'instruction', +): + modified_keys = {'query', 'response', 'history', 'system', 'instruction'} + new_sample = { + key: sample[key] + for key in sample if key not in modified_keys + } + + # find system prompt and instruction + if 'system' in sample: + new_sample[system_key] = sample['system'] + if 'instruction' in sample: + new_sample[instruction_key] = sample['instruction'] + + # construct conversation + conversation = [] + # add dialogs + for query, response in sample['history']: + conversation.append({ + human_key: query, + assistant_key: response, + }) + conversation.append({ + human_key: + sample['query'], + assistant_key: + sample['response'] if 'response' in sample else '' + }) + + new_sample[conversation_key] = conversation + + return new_sample + + +@logger.catch(reraise=True) +def main( + src_ds_path: str, + tgt_ds_path: str, + conversation_key: str = 'conversation', + human_key: str = 'human', + assistant_key: str = 'assistant', + system_key: str = 'system', + instruction_key: str = 'instruction', +): + """ + Convert a Data-Juicer query-response dataset to the ModelScope-Swift + ShareGPT-like format. + + :param src_ds_path: the path to the source dataset. + :param tgt_ds_path: the path to store the converted target dataset. + :param conversation_key: the field key to store conversions. + :param human_key: the field key to store the sentence from human. + :param assistant_key: the field key to store the sentence from assistant. + :param system_key: the field key to store the system prompt. + :param instruction_key: the field key to store the instruction content. + """ + + # check arguments + # check paths + if not os.path.exists(src_ds_path): + raise FileNotFoundError( + f'Input dataset [{src_ds_path}] can not be found.') + if not tgt_ds_path.endswith('.json'): + raise ValueError('Only support "json" target dataset file now.') + if os.path.dirname(tgt_ds_path) \ + and not os.path.exists(os.path.dirname(tgt_ds_path)): + logger.info(f'Create directory [{os.path.dirname(tgt_ds_path)}] ' + f'for the target dataset.') + os.makedirs(os.path.dirname(tgt_ds_path)) + + # load dataset + samples = [] + with jl.open(src_ds_path, 'r') as reader: + for sample in tqdm(reader): + converted_sample = dj_to_ms_swift_sharegpt( + sample, + conversation_key=conversation_key, + human_key=human_key, + assistant_key=assistant_key, + system_key=system_key, + instruction_key=instruction_key) + samples.append(converted_sample) + logger.info(f'Store the target dataset into [{tgt_ds_path}].') + json.dump(samples, open(tgt_ds_path, 'w', encoding='utf-8')) + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/tools/fmt_conversion/post_tuning_dialog/source_format_to_data_juicer_format/alpaca_to_dj.py b/tools/fmt_conversion/post_tuning_dialog/source_format_to_data_juicer_format/alpaca_to_dj.py new file mode 100644 index 000000000..cdbd64345 --- /dev/null +++ b/tools/fmt_conversion/post_tuning_dialog/source_format_to_data_juicer_format/alpaca_to_dj.py @@ -0,0 +1,130 @@ +# This tool is used to convert dataset in Alpaca format to a +# target dataset in Data-Juicer query-response format. +# +# Alpaca format: +# [ +# { +# "system": "", +# "instruction": "", +# "input": "", +# "output": "", +# "history": [ +# ["human instruction in the first round (optional)", "model response in the first round (optional)"], # noqa: E501 +# ["human instruction in the second round (optional)", "model response in the second round (optional)"] # noqa: E501 +# ], +# }, +# ...... +# ] +# +# Corresponding Data-Juicer format (query-response format): +# [ +# { +# "system": "", +# "instruction": "", +# "query": "", +# "response": "", +# "history": [ +# ["human instruction in the first round (optional)", "model response in the first round (optional)"], # noqa: E501 +# ["human instruction in the second round (optional)", "model response in the second round (optional)"] # noqa: E501 +# ], +# }, +# ... +# ] +# +# Reference: +# https://github.com/modelscope/ms-swift/blob/main/docs/source_en/Customization/Custom-dataset.md +# https://github.com/hiyouga/LLaMA-Factory/blob/v0.9.1/data/README.md#alpaca-format + +import json +import os +from typing import List, Union + +import fire +import jsonlines as jl +from loguru import logger +from tqdm import tqdm + + +def alpaca_to_dj( + sample, + input_key: str = 'input', + output_key: str = 'output', + multimodal_keys: Union[str, List[str]] = None, +): + modified_keys = {input_key, output_key} + if multimodal_keys: + modified_keys = modified_keys.union(set(multimodal_keys)) + new_sample = { + key: sample[key] + for key in sample if key not in modified_keys + } + + # key mapping for input and output + if input_key in sample: + new_sample['query'] = sample[input_key] + if output_key in sample: + new_sample['response'] = sample[output_key] + + # update multimodal data + if multimodal_keys: + for mm_key in multimodal_keys: + if not isinstance(sample[mm_key], list): + new_sample[mm_key] = [sample[mm_key]] + else: + new_sample[mm_key] = sample[mm_key] + + return new_sample + + +@logger.catch(reraise=True) +def main( + src_ds_path: str, + tgt_ds_path: str, + input_key: str = 'input', + output_key: str = 'output', + multimodal_keys: Union[str, List[str]] = None, +): + """ + Convert an Alpaca-like dataset to the Data-Juicer query-response format. + + :param src_ds_path: the path to the source dataset. + :param tgt_ds_path: the path to store the converted target dataset. + :param input_key: the field key to store the query sentence from human. + :param output_key: the field key to store the response sentence from + assistant. + :param multimodal_keys: optional keys to store multimodal data. + """ + + # check arguments + # check paths + if not os.path.exists(src_ds_path): + raise FileNotFoundError( + f'Input dataset [{src_ds_path}] can not be found.') + if not tgt_ds_path.endswith('.jsonl'): + raise ValueError('Only support "jsonl" target dataset file now.') + if os.path.dirname(tgt_ds_path) \ + and not os.path.exists(os.path.dirname(tgt_ds_path)): + logger.info(f'Create directory [{os.path.dirname(tgt_ds_path)}] ' + f'for the target dataset.') + os.makedirs(os.path.dirname(tgt_ds_path)) + + if isinstance(multimodal_keys, str): + multimodal_keys = [multimodal_keys] + + # load Alpaca dataset + logger.info('Loading original dataset.') + src_ds = json.load(open(src_ds_path, 'r', encoding='utf-8')) + logger.info(f'Load [{len(src_ds)}] samples.') + + with jl.open(tgt_ds_path, 'w') as writer: + for sample in tqdm(src_ds): + converted_sample = alpaca_to_dj(sample, + input_key=input_key, + output_key=output_key, + multimodal_keys=multimodal_keys) + writer.write(converted_sample) + logger.info(f'Store the target dataset into [{tgt_ds_path}].') + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/tools/fmt_conversion/post_tuning_dialog/source_format_to_data_juicer_format/llama_factory_sharegpt_to_dj.py b/tools/fmt_conversion/post_tuning_dialog/source_format_to_data_juicer_format/llama_factory_sharegpt_to_dj.py new file mode 100644 index 000000000..2f25ad7c8 --- /dev/null +++ b/tools/fmt_conversion/post_tuning_dialog/source_format_to_data_juicer_format/llama_factory_sharegpt_to_dj.py @@ -0,0 +1,216 @@ +# This tool is used to convert dataset in LLaMA-Factory ShareGPT format to a +# target dataset in Data-Juicer query-response format. +# +# LLaMA-Factory ShareGPT format: +# - usually in json format +# [ +# { +# "images": ["coco/train2017/000000033471.jpg"], +# "conversations": [ +# { +# "from": "human", +# "value": "\nWhat are the colors of the bus in the image?" +# }, +# { +# "from": "gpt", +# "value": "The bus in the image is white and red." +# }, +# { +# "from": "human", +# "value": "What feature can be seen on the back of the bus?" +# }, +# { +# "from": "gpt", +# "value": "The back of the bus features an advertisement." +# }, +# { +# "from": "human", +# "value": "Is the bus driving down the street or pulled off to the side?" # noqa: E501 +# }, +# { +# "from": "gpt", +# "value": "The bus is driving down the street, which is crowded with people and other vehicles." # noqa: E501 +# } +# ] +# }, +# ... +# ] +# +# Corresponding Data-Juicer format (query-response format): +# [ +# { +# "images": ["coco/train2017/000000033471.jpg"], +# "query": "Is the bus driving down the street or pulled off to the side?", +# "response": "The bus is driving down the street, which is crowded with people and other vehicles." # noqa: E501 +# "history": [ +# [ +# "\nWhat are the colors of the bus in the image?", +# "The bus in the image is white and red." +# ], +# [ +# "What feature can be seen on the back of the bus?", +# "The back of the bus features an advertisement." +# ], +# ] +# }, +# ... +# ] +# +# Reference: +# https://github.com/hiyouga/LLaMA-Factory/blob/v0.9.1/data/README.md#sharegpt-format + +import json +import os +from typing import List, Union + +import fire +import jsonlines as jl +from loguru import logger +from tqdm import tqdm + + +def llama_factory_sharegpt_to_dj( + sample, + conversations_key: str = 'conversations', + from_key: str = 'from', + value_key: str = 'value', + system_role: str = 'system', + instruction_role: str = 'instruction', + multimodal_keys: Union[str, List[str]] = None, +): + modified_keys = {conversations_key} + if multimodal_keys: + modified_keys = modified_keys.union(set(multimodal_keys)) + new_sample = { + key: sample[key] + for key in sample if key not in modified_keys + } + + # conversations to query, response, history + conversations = sample[conversations_key] + # find system prompt and instruction + system_prompt = '' + instruction = '' + remove_idx = [] + for i, conv in enumerate(conversations): + if conv[from_key] == system_role: + if system_prompt != '': + raise NotImplementedError( + 'DO NOT support more than 1 system prompts in the ' + 'conversation for now.') + system_prompt = conv[value_key] + remove_idx.append(i) + elif conv[from_key] == instruction_role: + if instruction != '': + raise NotImplementedError( + 'DO NOT support more than 1 instructions in the ' + 'conversation for now.') + instruction = conv[value_key] + remove_idx.append(i) + if len(remove_idx) > 0: + for i in remove_idx: + conversations.pop(i) + + # reconstruct conversations + conv_num = len(conversations) + if conv_num == 0: + query = '' + response = '' + history = [] + elif conv_num % 2 == 0: + # the last 2 sentences are query and response + query = conversations[-2][value_key] + response = conversations[-1][value_key] + history = [[ + conversations[i][value_key], conversations[i + 1][value_key] + ] for i in range(0, conv_num - 2, 2)] + else: + # the last 1 sentence is query and response is empty + query = conversations[-1][value_key] + response = '' + history = [[ + conversations[i][value_key], conversations[i + 1][value_key] + ] for i in range(0, conv_num - 1, 2)] + + # get the result sample + new_sample.update({ + 'system': system_prompt, + 'instruction': instruction, + 'query': query, + 'response': response, + 'history': history, + }) + + # update multimodal data + if multimodal_keys: + for mm_key in multimodal_keys: + if not isinstance(sample[mm_key], list): + new_sample[mm_key] = [sample[mm_key]] + else: + new_sample[mm_key] = sample[mm_key] + + return new_sample + + +@logger.catch(reraise=True) +def main( + src_ds_path: str, + tgt_ds_path: str, + conversations_key: str = 'conversations', + from_key: str = 'from', + value_key: str = 'value', + system_role: str = 'system', + instruction_role: str = 'instruction', + multimodal_keys: Union[str, List[str]] = None, +): + """ + Convert a LLaMA-Factory ShareGPT-like dataset to the Data-Juicer + query-response format. + + :param src_ds_path: the path to the source dataset. + :param tgt_ds_path: the path to store the converted target dataset. + :param conversations_key: the field key to store conversions. + :param from_key: the field key to store the sentence from. + :param value_key: the field key to store the sentence content. + :param system_role: the field key to store the system prompt. + :param instruction_role: the field key to store the instruction content. + :param multimodal_keys: optional keys to store multimodal data. + """ + + # check arguments + # check paths + if not os.path.exists(src_ds_path): + raise FileNotFoundError( + f'Input dataset [{src_ds_path}] can not be found.') + if not tgt_ds_path.endswith('.jsonl'): + raise ValueError('Only support "jsonl" target dataset file now.') + if os.path.dirname(tgt_ds_path) \ + and not os.path.exists(os.path.dirname(tgt_ds_path)): + logger.info(f'Create directory [{os.path.dirname(tgt_ds_path)}] ' + f'for the target dataset.') + os.makedirs(os.path.dirname(tgt_ds_path)) + + if isinstance(multimodal_keys, str): + multimodal_keys = [multimodal_keys] + + # load dataset + logger.info('Loading original dataset.') + src_ds = json.load(open(src_ds_path, 'r', encoding='utf-8')) + logger.info(f'Load [{len(src_ds)}] samples.') + + with jl.open(tgt_ds_path, 'w') as writer: + for sample in tqdm(src_ds): + converted_sample = llama_factory_sharegpt_to_dj( + sample, + conversations_key=conversations_key, + from_key=from_key, + value_key=value_key, + system_role=system_role, + instruction_role=instruction_role, + multimodal_keys=multimodal_keys) + writer.write(converted_sample) + logger.info(f'Store the target dataset into [{tgt_ds_path}].') + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/tools/fmt_conversion/post_tuning_dialog/source_format_to_data_juicer_format/messages_to_dj.py b/tools/fmt_conversion/post_tuning_dialog/source_format_to_data_juicer_format/messages_to_dj.py new file mode 100644 index 000000000..1f5e74071 --- /dev/null +++ b/tools/fmt_conversion/post_tuning_dialog/source_format_to_data_juicer_format/messages_to_dj.py @@ -0,0 +1,108 @@ +# This tool is used to convert dataset in ModelScope-Swift Messages format to a +# target dataset in Data-Juicer query-response format. +# +# ModelScope-Swift Messages format: +# - usually in json format +# [ +# { +# "images": ["coco/train2017/000000033471.jpg"], +# "messages": [ +# { +# "role": "human", +# "content": "\nWhat are the colors of the bus in the image?" +# }, +# { +# "role": "gpt", +# "content": "The bus in the image is white and red." +# }, +# { +# "role": "human", +# "content": "What feature can be seen on the back of the bus?" +# }, +# { +# "role": "gpt", +# "content": "The back of the bus features an advertisement." +# }, +# { +# "role": "human", +# "content": "Is the bus driving down the street or pulled off to the side?" # noqa: E501 +# }, +# { +# "role": "gpt", +# "content": "The bus is driving down the street, which is crowded with people and other vehicles." # noqa: E501 +# } +# ] +# }, +# ... +# ] +# +# Corresponding Data-Juicer format (query-response format): +# [ +# { +# "images": ["coco/train2017/000000033471.jpg"], +# "query": "Is the bus driving down the street or pulled off to the side?", +# "response": "The bus is driving down the street, which is crowded with people and other vehicles." # noqa: E501 +# "history": [ +# [ +# "\nWhat are the colors of the bus in the image?", +# "The bus in the image is white and red." +# ], +# [ +# "What feature can be seen on the back of the bus?", +# "The back of the bus features an advertisement." +# ], +# ] +# }, +# ... +# ] +# +# Reference: +# https://github.com/modelscope/ms-swift/blob/main/docs/source_en/Customization/Custom-dataset.md +# +# This format is nearly the same as the LLaMA-Factory ShareGPT format, so we +# reuse the code in that conversion tools. + +from typing import List, Union + +import fire +import llama_factory_sharegpt_to_dj +from loguru import logger + + +@logger.catch(reraise=True) +def main( + src_ds_path: str, + tgt_ds_path: str, + messages_key: str = 'messages', + role_key: str = 'role', + content_key: str = 'content', + system_role: str = 'system', + instruction_role: str = 'instruction', + multimodal_keys: Union[str, List[str]] = None, +): + """ + Convert a Messages-like dataset to the Data-Juicer query-response format. + + :param src_ds_path: the path to the source dataset. + :param tgt_ds_path: the path to store the converted target dataset. + :param messages_key: the field key to store messages. + :param role_key: the field key to store the sentence from. + :param content_key: the field key to store the sentence content. + :param system_role: the field key to store the system prompt. + :param instruction_role: the field key to store the instruction content. + :param multimodal_keys: optional keys to store multimodal data. + """ + llama_factory_sharegpt_to_dj.main( + src_ds_path, + tgt_ds_path, + conversations_key=messages_key, + from_key=role_key, + value_key=content_key, + system_role=system_role, + instruction_role=instruction_role, + multimodal_keys=multimodal_keys, + ) + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/tools/fmt_conversion/post_tuning_dialog/source_format_to_data_juicer_format/ms_swift_sharegpt_to_dj.py b/tools/fmt_conversion/post_tuning_dialog/source_format_to_data_juicer_format/ms_swift_sharegpt_to_dj.py new file mode 100644 index 000000000..8112c31bb --- /dev/null +++ b/tools/fmt_conversion/post_tuning_dialog/source_format_to_data_juicer_format/ms_swift_sharegpt_to_dj.py @@ -0,0 +1,168 @@ +# This tool is used to convert dataset in ModelScope-Swift ShareGPT format to a +# target dataset in Data-Juicer query-response format. +# +# ModelScope-Swift ShareGPT format: +# [ +# { +# "system": "", +# "conversation": [ +# { +# "human": "", +# "assistant": "" +# }, +# { +# "human": "", +# "assistant": "" +# } +# ] +# }, +# ...... +# ] +# +# Corresponding Data-Juicer format (query-response format): +# [ +# { +# "system": "", +# "query": "", +# "response": "" +# "history": [ +# [ +# "", +# "" +# ], +# ] +# }, +# ... +# ] +# +# Reference: +# https://github.com/modelscope/ms-swift/blob/main/docs/source_en/Customization/Custom-dataset.md + +import json +import os +from typing import List, Union + +import fire +import jsonlines as jl +from loguru import logger +from tqdm import tqdm + + +def ms_swift_sharegpt_to_dj( + sample, + conversation_key: str = 'conversation', + human_key: str = 'human', + assistant_key: str = 'assistant', + system_key: str = 'system', + instruction_key: str = 'instruction', + multimodal_keys: Union[str, List[str]] = None, +): + modified_keys = {conversation_key, system_key, instruction_key} + if multimodal_keys: + modified_keys = modified_keys.union(set(multimodal_keys)) + new_sample = { + key: sample[key] + for key in sample if key not in modified_keys + } + + # find system prompt and instruction + if system_key in sample: + new_sample['system'] = sample[system_key] + if instruction_key in sample: + new_sample['instruction'] = sample[instruction_key] + + # conversations to query, response, history + conversation = sample[conversation_key] + # reconstruct conversations + conv_num = len(conversation) + if conv_num == 0: + query = '' + response = '' + history = [] + else: + # the last 1 sentence is query and response is empty + query = conversation[-1][human_key] + response = conversation[-1][assistant_key] + history = [[conv[human_key], conv[assistant_key]] + for conv in conversation[:-1]] + + # get the result sample + new_sample.update({ + 'query': query, + 'response': response, + 'history': history, + }) + + # update multimodal data + if multimodal_keys: + for mm_key in multimodal_keys: + if not isinstance(sample[mm_key], list): + new_sample[mm_key] = [sample[mm_key]] + else: + new_sample[mm_key] = sample[mm_key] + + return new_sample + + +@logger.catch(reraise=True) +def main( + src_ds_path: str, + tgt_ds_path: str, + conversation_key: str = 'conversation', + human_key: str = 'human', + assistant_key: str = 'assistant', + system_key: str = 'system', + instruction_key: str = 'instruction', + multimodal_keys: Union[str, List[str]] = None, +): + """ + Convert a ModelScope-Swift ShareGPT-like dataset to the Data-Juicer + query-response format. + + :param src_ds_path: the path to the source dataset. + :param tgt_ds_path: the path to store the converted target dataset. + :param conversation_key: the field key to store conversions. + :param human_key: the field key to store the sentence from human. + :param assistant_key: the field key to store the sentence from assistant. + :param system_key: the field key to store the system prompt. + :param instruction_key: the field key to store the instruction content. + :param multimodal_keys: optional keys to store multimodal data. + """ + + # check arguments + # check paths + if not os.path.exists(src_ds_path): + raise FileNotFoundError( + f'Input dataset [{src_ds_path}] can not be found.') + if not tgt_ds_path.endswith('.jsonl'): + raise ValueError('Only support "jsonl" target dataset file now.') + if os.path.dirname(tgt_ds_path) \ + and not os.path.exists(os.path.dirname(tgt_ds_path)): + logger.info(f'Create directory [{os.path.dirname(tgt_ds_path)}] ' + f'for the target dataset.') + os.makedirs(os.path.dirname(tgt_ds_path)) + + if isinstance(multimodal_keys, str): + multimodal_keys = [multimodal_keys] + + # load dataset + logger.info('Loading original dataset.') + src_ds = json.load(open(src_ds_path, 'r', encoding='utf-8')) + logger.info(f'Load [{len(src_ds)}] samples.') + + with jl.open(tgt_ds_path, 'w') as writer: + for sample in tqdm(src_ds): + converted_sample = ms_swift_sharegpt_to_dj( + sample, + conversation_key=conversation_key, + human_key=human_key, + assistant_key=assistant_key, + system_key=system_key, + instruction_key=instruction_key, + multimodal_keys=multimodal_keys) + writer.write(converted_sample) + logger.info(f'Store the target dataset into [{tgt_ds_path}].') + + +if __name__ == '__main__': + fire.Fire(main) From 7d5f37d6f7d5c41d135c7ff28ca5330c85cbbfec Mon Sep 17 00:00:00 2001 From: jackylee Date: Thu, 26 Dec 2024 14:48:43 +0800 Subject: [PATCH 6/7] Fix operators doc link for aggregators (#521) --- configs/config_all.yaml | 2 +- tests/ops/{Aggregator => aggregator}/__init__.py | 0 .../test_entity_attribute_aggregator.py | 0 .../test_most_relavant_entities_aggregator.py | 0 tests/ops/{Aggregator => aggregator}/test_nested_aggregator.py | 0 5 files changed, 1 insertion(+), 1 deletion(-) rename tests/ops/{Aggregator => aggregator}/__init__.py (100%) rename tests/ops/{Aggregator => aggregator}/test_entity_attribute_aggregator.py (100%) rename tests/ops/{Aggregator => aggregator}/test_most_relavant_entities_aggregator.py (100%) rename tests/ops/{Aggregator => aggregator}/test_nested_aggregator.py (100%) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 82cd6824e..5e12f5d8c 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -726,7 +726,7 @@ process: - key_value_grouper: # Group samples to batched samples according values in given keys. group_by_keys: null # Group samples according values in the keys. Support for nested keys such as "__dj__stats__.text_len". It is [self.text_key] in default. -# Aggregator ops. +# aggregator ops. - entity_attribute_aggregator: # Return conclusion of the given entity's attribute from some docs. api_model: 'gpt-4o' # API model name. entity: '孙悟空' # The given entity. diff --git a/tests/ops/Aggregator/__init__.py b/tests/ops/aggregator/__init__.py similarity index 100% rename from tests/ops/Aggregator/__init__.py rename to tests/ops/aggregator/__init__.py diff --git a/tests/ops/Aggregator/test_entity_attribute_aggregator.py b/tests/ops/aggregator/test_entity_attribute_aggregator.py similarity index 100% rename from tests/ops/Aggregator/test_entity_attribute_aggregator.py rename to tests/ops/aggregator/test_entity_attribute_aggregator.py diff --git a/tests/ops/Aggregator/test_most_relavant_entities_aggregator.py b/tests/ops/aggregator/test_most_relavant_entities_aggregator.py similarity index 100% rename from tests/ops/Aggregator/test_most_relavant_entities_aggregator.py rename to tests/ops/aggregator/test_most_relavant_entities_aggregator.py diff --git a/tests/ops/Aggregator/test_nested_aggregator.py b/tests/ops/aggregator/test_nested_aggregator.py similarity index 100% rename from tests/ops/Aggregator/test_nested_aggregator.py rename to tests/ops/aggregator/test_nested_aggregator.py From 9466c7390a5cdb280eec6ebf6fe2e794b87dd582 Mon Sep 17 00:00:00 2001 From: BeachWang <1400012807@pku.edu.cn> Date: Thu, 26 Dec 2024 17:27:48 +0800 Subject: [PATCH 7/7] 10 more post-tuning OPs, regarding dialog data analysis from multiple aspects (#513) * add api call * add call_api ops * clean * minor update * more tests * update tests * update prompts * fix unittest * update tests * add docs * minor fix * add API processor * refine API processor * refine * chunk and extract events * fix bugs * fix tests * refine tests * extract nickname * nickname test done * lightRAG to OP * doc done * remove extra test * relavant -> relevant * fix minor error * group by op done * ValueError -> Exception * fix config_all error * fix prepare_api_model * fix rank sample None * constant fix key * aggregator op * init python_lambda_mapper * set default arg * fix init * add python_file_mapper * support text & most relavant entities * coverage ignore_errors * index sample * role_playing_system_prompt_yaml * system_prompt begin * support batched * remove unforkable * support batched & add docs * add docs * fix docs * update docs * pre-commit done * fix batch bug * fix batch bug * fix filter batch * fix filter batch * system prompt recipe done * not rank for filter * limit pyav version * add test for op * tmp * doc done * skip api test * add env dependency * install by recipe * dialog sent intensity * add query * change to dj_install * change to dj_install * developer doc done * query sent_int mapper * query sentiment test done * change meta pass * doc done * sentiment detection * diff label * sentiment * test done * dialog intent label * fix typo * prompt adjust * add more test * query intent detection * for test * for test * change model * fix typo * fix typo * for test * for test * doc done * dialog topic detection * dialog topic detection * dialog topic detection * dialog topic detection * dialog topic detection * dialog topic detection * query topic detection * query topic detection * query topic detection * query topic detection * query topic detection * doc done * meta tags aggregator * meta tags aggregator * meta tags aggregator * meta tags aggregator * meta tags aggregator * meta tags aggregator * meta tags aggregator * meta tags aggregator * meta tags aggregator * meta tags aggregator * meta tags aggregator * meta tags aggregator * naive reverse grouper * naive reverse grouper * tags specified field * doc done * - rename tests/ops/Aggregator intotests/ops/aggregator for right linking; - minor fix for OP doc * rename for right doc linking in test dir * fix bad dingtalk link --------- Co-authored-by: null <3213204+drcege@users.noreply.github.com> Co-authored-by: gece.gc Co-authored-by: daoyuan --- README.md | 2 +- README_ZH.md | 2 +- configs/config_all.yaml | 95 ++++++++ data_juicer/ops/aggregator/__init__.py | 3 +- .../aggregator/entity_attribute_aggregator.py | 4 - .../ops/aggregator/meta_tags_aggregator.py | 222 ++++++++++++++++++ .../most_relavant_entities_aggregator.py | 4 - .../ops/aggregator/nested_aggregator.py | 4 - data_juicer/ops/grouper/__init__.py | 3 +- .../ops/grouper/naive_reverse_grouper.py | 26 ++ data_juicer/ops/mapper/__init__.py | 39 +-- data_juicer/ops/mapper/calibrate_qa_mapper.py | 2 + .../mapper/dialog_intent_detection_mapper.py | 216 +++++++++++++++++ .../dialog_sentiment_detection_mapper.py | 195 +++++++++++++++ .../dialog_sentiment_intensity_mapper.py | 207 ++++++++++++++++ .../mapper/dialog_topic_detection_mapper.py | 200 ++++++++++++++++ .../mapper/query_intent_detection_mapper.py | 84 +++++++ .../query_sentiment_detection_mapper.py | 85 +++++++ .../mapper/query_topic_detection_mapper.py | 84 +++++++ data_juicer/ops/selector/__init__.py | 4 +- .../selector/tags_specified_field_selector.py | 54 +++++ data_juicer/utils/auto_install_mapping.py | 8 + data_juicer/utils/common_utils.py | 18 +- data_juicer/utils/constant.py | 20 ++ docs/Operators.md | 23 +- docs/Operators_ZH.md | 22 +- .../aggregator/test_meta_tags_aggregator.py | 117 +++++++++ .../ops/grouper/test_naive_reverse_grouper.py | 83 +++++++ .../test_dialog_intent_detection_mapper.py | 170 ++++++++++++++ .../test_dialog_sentiment_detection_mapper.py | 141 +++++++++++ .../test_dialog_sentiment_intensity_mapper.py | 141 +++++++++++ .../test_dialog_topic_detection_mapper.py | 141 +++++++++++ .../test_extract_entity_attribute_mapper.py | 2 +- .../test_extract_entity_relation_mapper.py | 2 +- tests/ops/mapper/test_extract_event_mapper.py | 2 +- .../ops/mapper/test_extract_keyword_mapper.py | 2 +- .../mapper/test_extract_nickname_mapper.py | 2 +- .../test_extract_support_text_mapper.py | 2 +- .../test_query_intent_detection_mapper.py | 61 +++++ .../test_query_sentiment_detection_mapper.py | 62 +++++ .../test_query_topic_detection_mapper.py | 59 +++++ .../mapper/test_relation_identity_mapper.py | 2 +- .../selector/test_tags_specified_selector.py | 63 +++++ 43 files changed, 2621 insertions(+), 57 deletions(-) create mode 100644 data_juicer/ops/aggregator/meta_tags_aggregator.py create mode 100644 data_juicer/ops/grouper/naive_reverse_grouper.py create mode 100644 data_juicer/ops/mapper/dialog_intent_detection_mapper.py create mode 100644 data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py create mode 100644 data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py create mode 100644 data_juicer/ops/mapper/dialog_topic_detection_mapper.py create mode 100644 data_juicer/ops/mapper/query_intent_detection_mapper.py create mode 100644 data_juicer/ops/mapper/query_sentiment_detection_mapper.py create mode 100644 data_juicer/ops/mapper/query_topic_detection_mapper.py create mode 100644 data_juicer/ops/selector/tags_specified_field_selector.py create mode 100644 tests/ops/aggregator/test_meta_tags_aggregator.py create mode 100644 tests/ops/grouper/test_naive_reverse_grouper.py create mode 100644 tests/ops/mapper/test_dialog_intent_detection_mapper.py create mode 100644 tests/ops/mapper/test_dialog_sentiment_detection_mapper.py create mode 100644 tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py create mode 100644 tests/ops/mapper/test_dialog_topic_detection_mapper.py create mode 100644 tests/ops/mapper/test_query_intent_detection_mapper.py create mode 100644 tests/ops/mapper/test_query_sentiment_detection_mapper.py create mode 100644 tests/ops/mapper/test_query_topic_detection_mapper.py create mode 100644 tests/ops/selector/test_tags_specified_selector.py diff --git a/README.md b/README.md index 95eba1da2..5e9ea1340 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ We provide a [playground](http://8.138.149.181/) with a managed JupyterLab. [Try [Platform for AI of Alibaba Cloud (PAI)](https://www.aliyun.com/product/bigdata/learn) has cited our work and integrated Data-Juicer into its data processing products. PAI is an AI Native large model and AIGC engineering platform that provides dataset management, computing power management, model tool chain, model development, model training, model deployment, and AI asset management. For documentation on data processing, please refer to: [PAI-Data Processing for Large Models](https://help.aliyun.com/zh/pai/user-guide/components-related-to-data-processing-for-foundation-models/?spm=a2c4g.11186623.0.0.3e9821a69kWdvX). Data-Juicer is being actively updated and maintained. We will periodically enhance and add more features, data recipes and datasets. -We welcome you to join us (via issues, PRs, [Slack](https://join.slack.com/t/data-juicer/shared_invite/zt-23zxltg9d-Z4d3EJuhZbCLGwtnLWWUDg?spm=a2c22.12281976.0.0.7a8253f30mgpjw) channel, [DingDing](https://qr.dingtalk.com/action/joingroup?spm=a2c22.12281976.0.0.7a8253f30mgpjw&code=v1,k1,C0DI7CwRFrg7gJP5aMC95FUmsNuwuKJboT62BqP5DAk=&_dt_no_comment=1&origin=11) group, ...), in promoting data-model co-development along with research and applications of (multimodal) LLMs! +We welcome you to join us (via issues, PRs, [Slack](https://join.slack.com/t/data-juicer/shared_invite/zt-23zxltg9d-Z4d3EJuhZbCLGwtnLWWUDg?spm=a2c22.12281976.0.0.7a8253f30mgpjw) channel, [DingDing](https://qr.dingtalk.com/action/joingroup?code=v1,k1,YFIXM2leDEk7gJP5aMC95AfYT+Oo/EP/ihnaIEhMyJM=&_dt_no_comment=1&origin=11) group, ...), in promoting data-model co-development along with research and applications of (multimodal) LLMs! ---- diff --git a/README_ZH.md b/README_ZH.md index 6ba358b37..27bcb72f2 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -27,7 +27,7 @@ Data-Juicer 是一个一站式**多模态**数据处理系统,旨在为大语 [阿里云人工智能平台 PAI](https://www.aliyun.com/product/bigdata/learn) 已引用我们的工作,将Data-Juicer的能力集成到PAI的数据处理产品中。PAI提供包含数据集管理、算力管理、模型工具链、模型开发、模型训练、模型部署、AI资产管理在内的功能模块,为用户提供高性能、高稳定、企业级的大模型工程化能力。数据处理的使用文档请参考:[PAI-大模型数据处理](https://help.aliyun.com/zh/pai/user-guide/components-related-to-data-processing-for-foundation-models/?spm=a2c4g.11186623.0.0.3e9821a69kWdvX)。 -Data-Juicer正在积极更新和维护中,我们将定期强化和新增更多的功能和数据菜谱。热烈欢迎您加入我们(issues/PRs/[Slack频道](https://join.slack.com/t/data-juicer/shared_invite/zt-23zxltg9d-Z4d3EJuhZbCLGwtnLWWUDg?spm=a2c22.12281976.0.0.7a8275bc8g7ypp) /[钉钉群](https://qr.dingtalk.com/action/joingroup?spm=a2c22.12281976.0.0.7a8275bc8g7ypp&code=v1,k1,C0DI7CwRFrg7gJP5aMC95FUmsNuwuKJboT62BqP5DAk=&_dt_no_comment=1&origin=11)/...),一起推进LLM-数据的协同开发和研究! +Data-Juicer正在积极更新和维护中,我们将定期强化和新增更多的功能和数据菜谱。热烈欢迎您加入我们(issues/PRs/[Slack频道](https://join.slack.com/t/data-juicer/shared_invite/zt-23zxltg9d-Z4d3EJuhZbCLGwtnLWWUDg?spm=a2c22.12281976.0.0.7a8275bc8g7ypp) /[钉钉群](https://qr.dingtalk.com/action/joingroup?code=v1,k1,YFIXM2leDEk7gJP5aMC95AfYT+Oo/EP/ihnaIEhMyJM=&_dt_no_comment=1&origin=11)/...),一起推进LLM-数据的协同开发和研究! ---- diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 5e12f5d8c..1104a3a13 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -77,6 +77,68 @@ process: - clean_ip_mapper: # remove ip addresses from text. - clean_links_mapper: # remove web links from text. - clean_copyright_mapper: # remove copyright comments. + - dialog_intent_detection_mapper: # Mapper to generate user's intent labels in dialog. + api_model: 'gpt-4o' # API model name. + intent_candidates: null # The output intent candidates. Use the intent labels of the open domain if it is None. + max_round: 10 # The max num of round in the dialog to build the prompt. + api_endpoint: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt: null # System prompt for the task. + query_template: null # Template for query part to build the input prompt. + response_template: null # Template for response part to build the input prompt. + candidate_template: null # Template for intent candidates to build the input prompt. + analysis_template: null # Template for analysis part to build the input prompt. + labels_template: null # Template for labels to build the input prompt. + analysis_pattern: null # Pattern to parse the return intent analysis. + labels_pattern: null # Pattern to parse the return intent labels. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} + - dialog_sentiment_detection_mapper: # Mapper to generate user's sentiment labels in dialog. + api_model: 'gpt-4o' # API model name. + max_round: 10 # The max num of round in the dialog to build the prompt. + api_endpoint: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt: null # System prompt for the task. + query_template: null # Template for query part to build the input prompt. + response_template: null # Template for response part to build the input prompt. + analysis_template: null # Template for analysis part to build the input prompt. + labels_template: null # Template for labels part to build the input prompt. + analysis_pattern: null # Pattern to parse the return sentiment analysis. + labels_pattern: null # Pattern to parse the return sentiment labels. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} + - dialog_sentiment_intensity_mapper: # Mapper to predict user's sentiment intensity (from -5 to 5 in default prompt) in dialog. + api_model: 'gpt-4o' # API model name. + max_round: 10 # The max num of round in the dialog to build the prompt. + api_endpoint: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt: null # System prompt for the task. + query_template: null # Template for query part to build the input prompt. + response_template: null # Template for response part to build the input prompt. + analysis_template: null # Template for analysis part to build the input prompt. + intensity_template: null # Template for intensity part to build the input prompt. + analysis_pattern: null # Pattern to parse the return sentiment analysis. + intensity_pattern: null # Pattern to parse the return sentiment intensity. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} + - dialog_topic_detection_mapper: # Mapper to generate user's topic labels in dialog. + api_model: 'gpt-4o' # API model name. + max_round: 10 # The max num of round in the dialog to build the prompt. + api_endpoint: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt: null # System prompt for the task. + query_template: null # Template for query part to build the input prompt. + response_template: null # Template for response part to build the input prompt. + analysis_template: null # Template for analysis part to build the input prompt. + labels_template: null # Template for labels part to build the input prompt. + analysis_pattern: null # Pattern to parse the return topic analysis. + labels_pattern: null # Pattern to parse the return topic labels. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} - expand_macro_mapper: # expand macro definitions in Latex text. - extract_entity_attribute_mapper: # Extract attributes for given entities from the text. api_model: 'gpt-4o' # API model name. @@ -277,6 +339,21 @@ process: - python_lambda_mapper: # executing Python lambda function on data samples. lambda_str: '' # A string representation of the lambda function to be executed on data samples. If empty, the identity function is used. batched: False # A boolean indicating whether to process input data in batches. + - query_intent_detection_mapper: # Mapper to predict user's Intent label in query. + hf_model: 'bespin-global/klue-roberta-small-3i4k-intent-classification' # Hugginface model ID to predict intent label. + zh_to_en_hf_model: 'Helsinki-NLP/opus-mt-zh-en' # Translation model from Chinese to English. If not None, translate the query from Chinese to English. + model_params: {} # model param for hf_model. + zh_to_en_model_params: {} # model param for zh_to_hf_model. + - query_sentiment_detection_mapper: # Mapper to predict user's sentiment label ('negative', 'neutral' and 'positive') in query. + hf_model: 'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis' # Hugginface model ID to predict sentiment label. + zh_to_en_hf_model: 'Helsinki-NLP/opus-mt-zh-en' # Translation model from Chinese to English. If not None, translate the query from Chinese to English. + model_params: {} # model param for hf_model. + zh_to_en_model_params: {} # model param for zh_to_hf_model. + - query_topic_detection_mapper: # Mapper to predict user's topic label in query. + hf_model: 'dstefa/roberta-base_topic_classification_nyt_news' # Hugginface model ID to predict topic label. + zh_to_en_hf_model: 'Helsinki-NLP/opus-mt-zh-en' # Translation model from Chinese to English. If not None, translate the query from Chinese to English. + model_params: {} # model param for hf_model. + zh_to_en_model_params: {} # model param for zh_to_hf_model. - relation_identity_mapper: # identify relation between two entity in the text. api_model: 'gpt-4o' # API model name. source_entity: '孙悟空' # The source entity of the relation to be dentified. @@ -715,6 +792,9 @@ process: upper_percentile: # the upper bound of the percentile to be sampled lower_rank: # the lower rank of the percentile to be sampled upper_rank: # the upper rank of the percentile to be sampled + - tags_specified_field_selector: # Selector to select samples based on the tags of specified field. + field_key: '__dj__meta__.query_sentiment_label' # the target keys corresponding to multi-level field information need to be separated by '.' + target_tags: ['happy', 'sad'] # Target tags to be select. - topk_specified_field_selector: # selector to select top samples based on the sorted specified field field_key: '' # the target keys corresponding to multi-level field information need to be separated by '.' top_ratio: # ratio of selected top samples @@ -723,6 +803,7 @@ process: # Grouper ops. - naive_grouper: # Group all samples to one batched sample. + - naive_reverse_grouper: # Split one batched sample to samples. - key_value_grouper: # Group samples to batched samples according values in given keys. group_by_keys: null # Group samples according values in the keys. Support for nested keys such as "__dj__stats__.text_len". It is [self.text_key] in default. @@ -744,6 +825,20 @@ process: try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. model_params: {} # Parameters for initializing the API model. sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} + - meta_tags_aggregator: # Merge similar meta tags to one tag. + api_model: 'gpt-4o' # API model name. + meta_tag_key: '__dj__meta__.query_sentiment_label' # The key of the meta tag to be mapped. + target_tags: ['开心', '难过', '其他'] # The tags that is supposed to be mapped to. + api_endpoint: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt: null # The system prompt. + input_template: null # The input template. + target_tag_template: null # The tap template for target tags. + tag_template: null # The tap template for each tag and its frequency. + output_pattern: null # The output pattern. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} - most_relavant_entities_aggregator: # Extract entities closely related to a given entity from some texts, and sort them in descending order of importance. api_model: 'gpt-4o' # API model name. entity: '孙悟空' # The given entity. diff --git a/data_juicer/ops/aggregator/__init__.py b/data_juicer/ops/aggregator/__init__.py index 4afe2974a..8aa87cbbd 100644 --- a/data_juicer/ops/aggregator/__init__.py +++ b/data_juicer/ops/aggregator/__init__.py @@ -1,8 +1,9 @@ from .entity_attribute_aggregator import EntityAttributeAggregator +from .meta_tags_aggregator import MetaTagsAggregator from .most_relavant_entities_aggregator import MostRelavantEntitiesAggregator from .nested_aggregator import NestedAggregator __all__ = [ - 'NestedAggregator', 'EntityAttributeAggregator', + 'NestedAggregator', 'MetaTagsAggregator', 'EntityAttributeAggregator', 'MostRelavantEntitiesAggregator' ] diff --git a/data_juicer/ops/aggregator/entity_attribute_aggregator.py b/data_juicer/ops/aggregator/entity_attribute_aggregator.py index 96fbbb63f..16ec5fd07 100644 --- a/data_juicer/ops/aggregator/entity_attribute_aggregator.py +++ b/data_juicer/ops/aggregator/entity_attribute_aggregator.py @@ -8,14 +8,10 @@ from data_juicer.utils.common_utils import (avg_split_string_list_under_limit, is_string_list, nested_access, nested_set) -from data_juicer.utils.lazy_loader import LazyLoader from data_juicer.utils.model_utils import get_model, prepare_model from .nested_aggregator import NestedAggregator -torch = LazyLoader('torch', 'torch') -vllm = LazyLoader('vllm', 'vllm') - OP_NAME = 'entity_attribute_aggregator' diff --git a/data_juicer/ops/aggregator/meta_tags_aggregator.py b/data_juicer/ops/aggregator/meta_tags_aggregator.py new file mode 100644 index 000000000..808ef73da --- /dev/null +++ b/data_juicer/ops/aggregator/meta_tags_aggregator.py @@ -0,0 +1,222 @@ +import re +from typing import Dict, List, Optional + +from loguru import logger +from pydantic import PositiveInt + +from data_juicer.ops.base_op import OPERATORS, Aggregator +from data_juicer.utils.common_utils import is_string_list +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.model_utils import get_model, prepare_model + +OP_NAME = 'meta_tags_aggregator' + + +# TODO: LLM-based inference. +@OPERATORS.register_module(OP_NAME) +class MetaTagsAggregator(Aggregator): + """ + Merge similar meta tags to one tag. + """ + + DEFAULT_SYSTEM_PROMPT = ('给定一些标签以及这些标签出现的频次,合并意思相近的标签。\n' + '要求:\n' + '- 任务分为两种情况,一种是给定合并后的标签,需要将合并前的标签映射到' + '这些标签。如果给定的合并后的标签中有类似“其他”这种标签,将无法归类的' + '标签合并到“其他”。以下是这种情况的一个样例:\n' + '合并后的标签应限定在[科技, 健康, 其他]中。\n' + '| 合并前标签 | 频次 |\n' + '| ------ | ------ |\n' + '| 医疗 | 20 |\n' + '| 信息技术 | 16 |\n' + '| 学习 | 19 |\n' + '| 气候变化 | 22 |\n' + '| 人工智能 | 11 |\n' + '| 养生 | 17 |\n' + '| 科学创新 | 10 |\n' + '\n' + '## 分析:“信息技术”、“人工智能”、“科学创新”都属于“科技”类别,“医疗' + '”和“养生”跟“健康”有关联,“学习”、“气候变化”和“科技”还有“健康”关' + '联不强,应该被归为“其他”。\n' + '## 标签合并:\n' + '** 医疗归类为健康 **\n' + '** 信息技术归类为科技 **\n' + '** 学习归类为其他 **\n' + '** 气候变化归类为其他 **\n' + '** 人工智能归类为科技 **\n' + '** 养生归类为健康 **\n' + '** 科学创新归类为科技 **\n' + '- 另外一种情况没有事先给定合并后的标签,需要生成合理的标签类别:' + '| 合并前标签 | 频次 |\n' + '| ------ | ------ |\n' + '| 医疗 | 20 |\n' + '| 信息技术 | 16 |\n' + '| 学习 | 2 |\n' + '| 气候变化 | 1 |\n' + '| 人工智能 | 11 |\n' + '| 养生 | 17 |\n' + '| 科学创新 | 10 |\n' + '\n' + '## 分析:“信息技术”、“人工智能”、“科学创新”这三个标签比较相近,归为' + '同一类,都属于“科技”类别,“医疗”和“养生”都跟“健康”有关系,可以归' + '类为“健康”,“学习”和“气候变化”跟其他标签关联度不强,且频次较低,' + '统一归类为“其他”。\n' + '## 标签合并:\n' + '** 医疗归类为健康 **\n' + '** 信息技术归类为科技 **\n' + '** 学习归类为其他 **\n' + '** 气候变化归类为其他 **\n' + '** 人工智能归类为科技 **\n' + '** 养生归类为健康 **\n' + '** 科学创新归类为科技 **\n') + + DEFAULT_INPUT_TEMPLATE = ('{target_tag_str}' + '| 合并前标签 | 频次 |\n' + '| ------ | ------ |\n' + '{tag_strs}') + DEFAULT_TARGET_TAG_TEMPLATE = '合并后的标签应限定在[{target_tags}]中。\n' + DEFAULT_TAG_TEMPLATE = '| {tag} | {cnt} |' + + DEFAULT_OUTPUT_PATTERN = r'\*\*\s*(\w+)归类为(\w+)\s*\*\*' + + def __init__(self, + api_model: str = 'gpt-4o', + meta_tag_key: str = MetaKeys.dialog_sentiment_labels, + target_tags: Optional[List[str]] = None, + *, + api_endpoint: Optional[str] = None, + response_path: Optional[str] = None, + system_prompt: Optional[str] = None, + input_template: Optional[str] = None, + target_tag_template: Optional[str] = None, + tag_template: Optional[str] = None, + output_pattern: Optional[str] = None, + try_num: PositiveInt = 3, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): + """ + Initialization method. + :param api_model: API model name. + :param meta_tag_key: The key of the meta tag to be mapped. + :param target_tags: The tags that is supposed to be mapped to. + :param api_endpoint: URL endpoint for the API. + :param response_path: Path to extract content from the API response. + Defaults to 'choices.0.message.content'. + :param system_prompt: The system prompt. + :param input_template: The input template. + :param target_tag_template: The tap template for target tags. + :param tag_template: The tap template for each tag and its + frequency. + :param output_pattern: The output pattern. + :param try_num: The number of retry attempts when there is an API + call error or output parsing error. + :param model_params: Parameters for initializing the API model. + :param sampling_params: Extra parameters passed to the API call. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param kwargs: Extra keyword arguments. + """ + super().__init__(**kwargs) + + self.meta_tag_key = meta_tag_key + + self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT + self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE + target_tag_template = target_tag_template or \ + self.DEFAULT_TARGET_TAG_TEMPLATE + self.tag_template = tag_template or self.DEFAULT_TAG_TEMPLATE + self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN + + self.target_tag_str = '' + if target_tags: + self.target_tag_str = target_tag_template.format( + target_tags=', '.join(target_tags)) + + self.sampling_params = sampling_params + self.model_key = prepare_model(model_type='api', + model=api_model, + endpoint=api_endpoint, + response_path=response_path, + return_processor=True, + **model_params) + + self.try_num = try_num + + def parse_output(self, response): + pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL) + matches = pattern.findall(response) + tag_map = {tag1: tag2 for tag1, tag2 in matches} + return tag_map + + def meta_map(self, meta_cnts, rank=None): + + model, _ = get_model(self.model_key, rank, self.use_cuda()) + + tag_strs = [ + self.tag_template.format(tag=k, cnt=meta_cnts[k]) + for k in meta_cnts + ] + input_prompt = self.input_template.format( + target_tag_str=self.target_tag_str, tag_strs='\n'.join(tag_strs)) + + messages = [{ + 'role': 'system', + 'content': self.system_prompt + }, { + 'role': 'user', + 'content': input_prompt + }] + tag_map = {} + for i in range(self.try_num): + try: + response = model(messages, **self.sampling_params) + tag_map = self.parse_output(response) + if len(tag_map) > 0: + break + except Exception as e: + logger.warning(f'Exception: {e}') + + return tag_map + + def process_single(self, sample=None, rank=None): + + if Fields.meta not in sample: + logger.warning('Not any meta in the sample!') + return sample + + metas = sample[Fields.meta] + # if not batched sample + if not isinstance(metas, list): + logger.warning('Not a batched sample!') + return sample + + meta_cnts = {} + + def update_dict(key): + if key in meta_cnts: + meta_cnts[key] += 1 + else: + meta_cnts[key] = 1 + + for meta in metas: + tag = meta[self.meta_tag_key] + if isinstance(tag, str): + update_dict(tag) + elif is_string_list(tag): + for t in tag: + update_dict(t) + else: + logger.warning('Meta tag must be string or list of string!') + return sample + + tag_map = self.meta_map(meta_cnts, rank=rank) + for i in range(len(metas)): + tag = metas[i][self.meta_tag_key] + if isinstance(tag, str) and tag in tag_map: + metas[i][self.meta_tag_key] = tag_map[tag] + elif is_string_list(tag): + metas[i][self.meta_tag_key] = [ + tag_map[t] if t in tag_map else t for t in tag + ] + + return sample diff --git a/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py index 69e1a209c..7ca49f505 100644 --- a/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py +++ b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py @@ -7,14 +7,10 @@ from data_juicer.ops.base_op import OPERATORS, Aggregator from data_juicer.utils.common_utils import (is_string_list, nested_access, nested_set) -from data_juicer.utils.lazy_loader import LazyLoader from data_juicer.utils.model_utils import get_model, prepare_model from ..common import split_text_by_punctuation -torch = LazyLoader('torch', 'torch') -vllm = LazyLoader('vllm', 'vllm') - OP_NAME = 'most_relavant_entities_aggregator' diff --git a/data_juicer/ops/aggregator/nested_aggregator.py b/data_juicer/ops/aggregator/nested_aggregator.py index 124eb1470..ab25e057d 100644 --- a/data_juicer/ops/aggregator/nested_aggregator.py +++ b/data_juicer/ops/aggregator/nested_aggregator.py @@ -6,12 +6,8 @@ from data_juicer.ops.base_op import OPERATORS, Aggregator from data_juicer.utils.common_utils import (avg_split_string_list_under_limit, is_string_list, nested_access) -from data_juicer.utils.lazy_loader import LazyLoader from data_juicer.utils.model_utils import get_model, prepare_model -torch = LazyLoader('torch', 'torch') -vllm = LazyLoader('vllm', 'vllm') - OP_NAME = 'nested_aggregator' diff --git a/data_juicer/ops/grouper/__init__.py b/data_juicer/ops/grouper/__init__.py index 048b305e4..f81ba6aec 100644 --- a/data_juicer/ops/grouper/__init__.py +++ b/data_juicer/ops/grouper/__init__.py @@ -1,4 +1,5 @@ from .key_value_grouper import KeyValueGrouper from .naive_grouper import NaiveGrouper +from .naive_reverse_grouper import NaiveReverseGrouper -__all__ = ['NaiveGrouper', 'KeyValueGrouper'] +__all__ = ['KeyValueGrouper', 'NaiveGrouper', 'NaiveReverseGrouper'] diff --git a/data_juicer/ops/grouper/naive_reverse_grouper.py b/data_juicer/ops/grouper/naive_reverse_grouper.py new file mode 100644 index 000000000..2535205b9 --- /dev/null +++ b/data_juicer/ops/grouper/naive_reverse_grouper.py @@ -0,0 +1,26 @@ +from ..base_op import OPERATORS, Grouper, convert_dict_list_to_list_dict + + +@OPERATORS.register_module('naive_reverse_grouper') +class NaiveReverseGrouper(Grouper): + """Split batched samples to samples. """ + + def __init__(self, *args, **kwargs): + """ + Initialization method. + + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + + def process(self, dataset): + + if len(dataset) == 0: + return dataset + + samples = [] + for sample in dataset: + samples.extend(convert_dict_list_to_list_dict(sample)) + + return samples diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 9b86b83dc..8ffe7cc8e 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -8,6 +8,10 @@ from .clean_html_mapper import CleanHtmlMapper from .clean_ip_mapper import CleanIpMapper from .clean_links_mapper import CleanLinksMapper +from .dialog_intent_detection_mapper import DialogIntentDetectionMapper +from .dialog_sentiment_detection_mapper import DialogSentimentDetectionMapper +from .dialog_sentiment_intensity_mapper import DialogSentimentIntensityMapper +from .dialog_topic_detection_mapper import DialogTopicDetectionMapper from .expand_macro_mapper import ExpandMacroMapper from .extract_entity_attribute_mapper import ExtractEntityAttributeMapper from .extract_entity_relation_mapper import ExtractEntityRelationMapper @@ -33,6 +37,9 @@ from .punctuation_normalization_mapper import PunctuationNormalizationMapper from .python_file_mapper import PythonFileMapper from .python_lambda_mapper import PythonLambdaMapper +from .query_intent_detection_mapper import QueryIntentDetectionMapper +from .query_sentiment_detection_mapper import QuerySentimentDetectionMapper +from .query_topic_detection_mapper import QueryTopicDetectionMapper from .relation_identity_mapper import RelationIdentityMapper from .remove_bibliography_mapper import RemoveBibliographyMapper from .remove_comments_mapper import RemoveCommentsMapper @@ -71,6 +78,8 @@ 'AudioFFmpegWrappedMapper', 'CalibrateQAMapper', 'CalibrateQueryMapper', 'CalibrateResponseMapper', 'ChineseConvertMapper', 'CleanCopyrightMapper', 'CleanEmailMapper', 'CleanHtmlMapper', 'CleanIpMapper', 'CleanLinksMapper', + 'DialogIntentDetectionMapper', 'DialogSentimentDetectionMapper', + 'DialogSentimentIntensityMapper', 'DialogTopicDetectionMapper', 'ExpandMacroMapper', 'ExtractEntityAttributeMapper', 'ExtractEntityRelationMapper', 'ExtractEventMapper', 'ExtractKeywordMapper', 'ExtractNicknameMapper', @@ -81,18 +90,20 @@ 'ImageTaggingMapper', 'NlpaugEnMapper', 'NlpcdaZhMapper', 'OptimizeQAMapper', 'OptimizeQueryMapper', 'OptimizeResponseMapper', 'PairPreferenceMapper', 'PunctuationNormalizationMapper', - 'PythonFileMapper', 'PythonLambdaMapper', 'RelationIdentityMapper', - 'RemoveBibliographyMapper', 'RemoveCommentsMapper', 'RemoveHeaderMapper', - 'RemoveLongWordsMapper', 'RemoveNonChineseCharacterlMapper', - 'RemoveRepeatSentencesMapper', 'RemoveSpecificCharsMapper', - 'RemoveTableTextMapper', 'RemoveWordsWithIncorrectSubstringsMapper', - 'ReplaceContentMapper', 'SentenceSplitMapper', 'TextChunkMapper', - 'VideoCaptioningFromAudioMapper', 'VideoCaptioningFromFramesMapper', - 'VideoCaptioningFromSummarizerMapper', 'VideoCaptioningFromVideoMapper', - 'VideoExtractFramesMapper', 'VideoFFmpegWrappedMapper', - 'VideoFaceBlurMapper', 'VideoRemoveWatermarkMapper', - 'VideoResizeAspectRatioMapper', 'VideoResizeResolutionMapper', - 'VideoSplitByDurationMapper', 'VideoSplitByKeyFrameMapper', - 'VideoSplitBySceneMapper', 'VideoTaggingFromAudioMapper', - 'VideoTaggingFromFramesMapper', 'WhitespaceNormalizationMapper' + 'PythonFileMapper', 'PythonLambdaMapper', 'QuerySentimentDetectionMapper', + 'QueryIntentDetectionMapper', 'QueryTopicDetectionMapper', + 'RelationIdentityMapper', 'RemoveBibliographyMapper', + 'RemoveCommentsMapper', 'RemoveHeaderMapper', 'RemoveLongWordsMapper', + 'RemoveNonChineseCharacterlMapper', 'RemoveRepeatSentencesMapper', + 'RemoveSpecificCharsMapper', 'RemoveTableTextMapper', + 'RemoveWordsWithIncorrectSubstringsMapper', 'ReplaceContentMapper', + 'SentenceSplitMapper', 'TextChunkMapper', 'VideoCaptioningFromAudioMapper', + 'VideoCaptioningFromFramesMapper', 'VideoCaptioningFromSummarizerMapper', + 'VideoCaptioningFromVideoMapper', 'VideoExtractFramesMapper', + 'VideoFFmpegWrappedMapper', 'VideoFaceBlurMapper', + 'VideoRemoveWatermarkMapper', 'VideoResizeAspectRatioMapper', + 'VideoResizeResolutionMapper', 'VideoSplitByDurationMapper', + 'VideoSplitByKeyFrameMapper', 'VideoSplitBySceneMapper', + 'VideoTaggingFromAudioMapper', 'VideoTaggingFromFramesMapper', + 'WhitespaceNormalizationMapper' ] diff --git a/data_juicer/ops/mapper/calibrate_qa_mapper.py b/data_juicer/ops/mapper/calibrate_qa_mapper.py index 8480ee899..bf9686409 100644 --- a/data_juicer/ops/mapper/calibrate_qa_mapper.py +++ b/data_juicer/ops/mapper/calibrate_qa_mapper.py @@ -55,6 +55,8 @@ def __init__(self, :param reference_template: Template for formatting the reference text. :param qa_pair_template: Template for formatting question-answer pairs. :param output_pattern: Regular expression for parsing model output. + :param try_num: The number of retry attempts when there is an API + call error or output parsing error. :param model_params: Parameters for initializing the API model. :param sampling_params: Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} diff --git a/data_juicer/ops/mapper/dialog_intent_detection_mapper.py b/data_juicer/ops/mapper/dialog_intent_detection_mapper.py new file mode 100644 index 000000000..7c8cba9ed --- /dev/null +++ b/data_juicer/ops/mapper/dialog_intent_detection_mapper.py @@ -0,0 +1,216 @@ +import re +from typing import Dict, List, Optional + +from loguru import logger +from pydantic import NonNegativeInt, PositiveInt + +from data_juicer.ops.base_op import OPERATORS, Mapper +from data_juicer.utils.common_utils import nested_set +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.model_utils import get_model, prepare_model + +OP_NAME = 'dialog_intent_detection_mapper' + + +# TODO: LLM-based inference. +@OPERATORS.register_module(OP_NAME) +class DialogIntentDetectionMapper(Mapper): + """ + Mapper to generate user's intent labels in dialog. Input from + history_key, query_key and response_key. Output lists of + labels and analysis for queries in the dialog, which is + store in 'dialog_intent_labels' and + 'dialog_intent_labels_analysis' in Data-Juicer meta field. + """ + + DEFAULT_SYSTEM_PROMPT = ( + '请判断用户和LLM多轮对话中用户的意图。\n' + '要求:\n' + '- 需要先进行分析,然后列出用户所具有的意图,下面是一个样例,请模仿样例格式输出' + '。\n' + '用户:你好,我最近对人工智能很感兴趣,能给我讲讲什么是机器学习吗?\n' + '意图分析:用户在请求信息,希望了解有关机器学习的基础知识。\n' + '意图类别:信息查找\n' + 'LLM:你好!当然可以。机器学习是一种人工智能方法,允许计算机通过数据自动改进和学习。\n' + '用户:听起来很有趣,有没有推荐的入门书籍或资料?\n' + '意图分析:用户在请求建议,希望获取关于机器学习的入门资源。\n' + '意图类别:请求建议\n' + 'LLM:有很多不错的入门书籍和资源。一本常被推荐的书是《Python机器学习实践》(Python' + ' Machine Learning),它涵盖了基础知识和一些实际案例。此外,您还可以参考Coursera' + '或edX上的在线课程,这些课程提供了系统的学习路径。\n' + '用户:谢谢你的建议!我还想知道,学习机器学习需要什么样的数学基础?\n' + '意图分析:用户在寻求信息,希望了解学习机器学习所需的前提条件,特别是在数学方面。\n' + '意图类别:信息查找\n' + 'LLM:学习机器学习通常需要一定的数学基础,特别是线性代数、概率论和统计学。这些数学领' + '域帮助理解算法的工作原理和数据模式分析。如果您对这些主题不太熟悉,建议先从相关基础' + '书籍或在线资源开始学习。\n' + '用户:明白了,我会先补习这些基础知识。再次感谢你的帮助!\n' + '意图分析:用户表达感谢,并表示计划付诸行动来补充所需的基础知识。\n' + '意图类别:其他') + DEFAULT_QUERY_TEMPLATE = '用户:{query}\n' + DEFAULT_RESPONSE_TEMPLATE = 'LLM:{response}\n' + DEFAULT_CANDIDATES_TEMPLATE = '备选意图类别:[{candidate_str}]' + DEFAULT_ANALYSIS_TEMPLATE = '意图分析:{analysis}\n' + DEFAULT_LABELS_TEMPLATE = '意图类别:{labels}\n' + DEFAULT_ANALYSIS_PATTERN = '意图分析:(.*?)\n' + DEFAULT_LABELS_PATTERN = '意图类别:(.*?)($|\n)' + + def __init__(self, + api_model: str = 'gpt-4o', + intent_candidates: Optional[List[str]] = None, + max_round: NonNegativeInt = 10, + *, + api_endpoint: Optional[str] = None, + response_path: Optional[str] = None, + system_prompt: Optional[str] = None, + query_template: Optional[str] = None, + response_template: Optional[str] = None, + candidate_template: Optional[str] = None, + analysis_template: Optional[str] = None, + labels_template: Optional[str] = None, + analysis_pattern: Optional[str] = None, + labels_pattern: Optional[str] = None, + try_num: PositiveInt = 3, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): + """ + Initialization method. + + :param api_model: API model name. + :param intent_candidates: The output intent candidates. Use the + intent labels of the open domain if it is None. + :param max_round: The max num of round in the dialog to build the + prompt. + :param api_endpoint: URL endpoint for the API. + :param response_path: Path to extract content from the API response. + Defaults to 'choices.0.message.content'. + :param system_prompt: System prompt for the task. + :param query_template: Template for query part to build the input + prompt. + :param response_template: Template for response part to build the + input prompt. + :param candidate_template: Template for intent candidates to + build the input prompt. + :param analysis_template: Template for analysis part to build the + input prompt. + :param labels_template: Template for labels to build the + input prompt. + :param analysis_pattern: Pattern to parse the return intent + analysis. + :param labels_pattern: Pattern to parse the return intent + labels. + :param try_num: The number of retry attempts when there is an API + call error or output parsing error. + :param model_params: Parameters for initializing the API model. + :param sampling_params: Extra parameters passed to the API call. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param kwargs: Extra keyword arguments. + """ + super().__init__(**kwargs) + + self.intent_candidates = intent_candidates + self.max_round = max_round + + self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT + self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE + self.response_template = response_template or \ + self.DEFAULT_RESPONSE_TEMPLATE + self.candidate_template = candidate_template or \ + self.DEFAULT_CANDIDATES_TEMPLATE + self.analysis_template = analysis_template or \ + self.DEFAULT_ANALYSIS_TEMPLATE + self.labels_template = labels_template or \ + self.DEFAULT_LABELS_TEMPLATE + self.analysis_pattern = analysis_pattern or \ + self.DEFAULT_ANALYSIS_PATTERN + self.labels_pattern = labels_pattern or \ + self.DEFAULT_LABELS_PATTERN + + self.sampling_params = sampling_params + + self.model_key = prepare_model(model_type='api', + model=api_model, + endpoint=api_endpoint, + response_path=response_path, + **model_params) + + self.try_num = try_num + + def build_input(self, history, query): + + if self.intent_candidates: + input_prompt = self.candidate_template.format( + candidate_str=','.join(self.intent_candidates)) + else: + input_prompt = '' + + if self.max_round > 0: + input_prompt += ''.join(history[-self.max_round * 4:]) + + input_prompt += self.query_template.format(query=query[0]) + + return input_prompt + + def parse_output(self, response): + analysis = '' + labels = '' + + match = re.search(self.analysis_pattern, response) + if match: + analysis = match.group(1) + + match = re.search(self.labels_pattern, response) + if match: + labels = match.group(1) + + return analysis, labels + + def process_single(self, sample, rank=None): + client = get_model(self.model_key, rank=rank) + + analysis_list = [] + labels_list = [] + history = [] + + dialog = sample[self.history_key] + if self.query_key in sample and sample[self.query_key]: + if self.response_key in sample and sample[self.response_key]: + dialog.append( + (sample[self.query_key], sample[self.response_key])) + else: + dialog.append((sample[self.query_key], '')) + + for qa in dialog: + input_prompt = self.build_input(history, qa) + messages = [{ + 'role': 'system', + 'content': self.system_prompt, + }, { + 'role': 'user', + 'content': input_prompt, + }] + + for _ in range(self.try_num): + try: + response = client(messages, **self.sampling_params) + analysis, labels = self.parse_output(response) + if len(analysis) > 0: + break + except Exception as e: + logger.warning(f'Exception: {e}') + + analysis_list.append(analysis) + labels_list.append(labels) + + history.append(self.query_template.format(query=qa[0])) + history.append(self.analysis_template.format(analysis=analysis)) + history.append(self.labels_template.format(labels=labels)) + history.append(self.response_template.format(response=qa[1])) + + analysis_key = f'{Fields.meta}.{MetaKeys.dialog_intent_labels_analysis}' # noqa: E501 + sample = nested_set(sample, analysis_key, analysis_list) + labels_key = f'{Fields.meta}.{MetaKeys.dialog_intent_labels}' + sample = nested_set(sample, labels_key, labels_list) + + return sample diff --git a/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py b/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py new file mode 100644 index 000000000..33bccc5ce --- /dev/null +++ b/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py @@ -0,0 +1,195 @@ +import re +from typing import Dict, Optional + +from loguru import logger +from pydantic import NonNegativeInt, PositiveInt + +from data_juicer.ops.base_op import OPERATORS, Mapper +from data_juicer.utils.common_utils import nested_set +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.model_utils import get_model, prepare_model + +OP_NAME = 'dialog_sentiment_detection_mapper' + + +# TODO: LLM-based inference. +@OPERATORS.register_module(OP_NAME) +class DialogSentimentDetectionMapper(Mapper): + """ + Mapper to generate user's sentiment labels in dialog. Input from + history_key, query_key and response_key. Output lists of + labels and analysis for queries in the dialog, which is + store in 'dialog_sentiment_labels' and + 'dialog_sentiment_labels_analysis' in Data-Juicer meta field. + """ + + DEFAULT_SYSTEM_PROMPT = ('请判断用户和LLM多轮对话中用户所具有的情绪。\n' + '要求:\n' + '- 需要先进行分析,然后罗列用户所具有的情绪,下面是一个样例,请模仿样例格式输出' + '。\n' + '用户:最近工作压力好大,我觉得整个人都快被压垮了。\n' + '情感分析:用户的言语中透露出明显的压力和疲惫感,可能还夹杂着一些无助和焦虑。\n' + '情感:压力、疲惫、无助、焦虑\n' + 'LLM:听起来你真的承受了很多,面临这种情况确实不容易。有没有考虑过找一些放松的' + '方式,比如听音乐或者散步来减轻压力呢?\n' + '用户:试过了,但是好像没什么效果,每天的事情都堆积如山。\n' + '情感分析:用户感到无力解决现状,有挫败感,并且对尝试放松的方式失去信心。\n' + '情感:无力、挫败\n' + 'LLM:我理解你的感受,有时候压力积累到一定程度确实让人难以承受。或许你可以尝试' + '规划一下时间,把任务分成小块来完成,这样可能会减少一些压力感。\n' + '用户:这个主意不错,我会试着让自己更有条理一些,谢谢你的建议。\n' + '情感分析:用户对建议表现出认同和感激,同时展现出试图积极面对问题的态度。\n' + '情感:认同、感激、积极\n' + 'LLM:不用谢,我很高兴能帮到你。记得给自己一些时间去适应新的计划,有任何需要' + '随时可以跟我说哦!\n') + DEFAULT_QUERY_TEMPLATE = '用户:{query}\n' + DEFAULT_RESPONSE_TEMPLATE = 'LLM:{response}\n' + DEFAULT_ANALYSIS_TEMPLATE = '情感分析:{analysis}\n' + DEFAULT_LABELS_TEMPLATE = '情感:{labels}\n' + DEFAULT_ANALYSIS_PATTERN = '情感分析:(.*?)\n' + DEFAULT_LABELS_PATTERN = '情感:(.*?)($|\n)' + + def __init__(self, + api_model: str = 'gpt-4o', + max_round: NonNegativeInt = 10, + *, + api_endpoint: Optional[str] = None, + response_path: Optional[str] = None, + system_prompt: Optional[str] = None, + query_template: Optional[str] = None, + response_template: Optional[str] = None, + analysis_template: Optional[str] = None, + labels_template: Optional[str] = None, + analysis_pattern: Optional[str] = None, + labels_pattern: Optional[str] = None, + try_num: PositiveInt = 3, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): + """ + Initialization method. + + :param api_model: API model name. + :param max_round: The max num of round in the dialog to build the + prompt. + :param api_endpoint: URL endpoint for the API. + :param response_path: Path to extract content from the API response. + Defaults to 'choices.0.message.content'. + :param system_prompt: System prompt for the task. + :param query_template: Template for query part to build the input + prompt. + :param response_template: Template for response part to build the + input prompt. + :param analysis_template: Template for analysis part to build the + input prompt. + :param labels_template: Template for labels part to build the + input prompt. + :param analysis_pattern: Pattern to parse the return sentiment + analysis. + :param labels_pattern: Pattern to parse the return sentiment + labels. + :param try_num: The number of retry attempts when there is an API + call error or output parsing error. + :param model_params: Parameters for initializing the API model. + :param sampling_params: Extra parameters passed to the API call. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param kwargs: Extra keyword arguments. + """ + super().__init__(**kwargs) + + self.max_round = max_round + + self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT + self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE + self.response_template = response_template or \ + self.DEFAULT_RESPONSE_TEMPLATE + self.analysis_template = analysis_template or \ + self.DEFAULT_ANALYSIS_TEMPLATE + self.labels_template = labels_template or \ + self.DEFAULT_LABELS_TEMPLATE + self.analysis_pattern = analysis_pattern or \ + self.DEFAULT_ANALYSIS_PATTERN + self.labels_pattern = labels_pattern or \ + self.DEFAULT_LABELS_PATTERN + + self.sampling_params = sampling_params + + self.model_key = prepare_model(model_type='api', + model=api_model, + endpoint=api_endpoint, + response_path=response_path, + **model_params) + + self.try_num = try_num + + def build_input(self, history, query): + if self.max_round > 0: + input_prompt = ''.join(history[-self.max_round * 4:]) + else: + input_prompt = '' + input_prompt += self.query_template.format(query=query[0]) + + return input_prompt + + def parse_output(self, response): + analysis = '' + labels = '' + + match = re.search(self.analysis_pattern, response) + if match: + analysis = match.group(1) + + match = re.search(self.labels_pattern, response) + if match: + labels = match.group(1) + + return analysis, labels + + def process_single(self, sample, rank=None): + client = get_model(self.model_key, rank=rank) + + analysis_list = [] + labels_list = [] + history = [] + + dialog = sample[self.history_key] + if self.query_key in sample and sample[self.query_key]: + if self.response_key in sample and sample[self.response_key]: + dialog.append( + (sample[self.query_key], sample[self.response_key])) + else: + dialog.append((sample[self.query_key], '')) + + for qa in dialog: + input_prompt = self.build_input(history, qa) + messages = [{ + 'role': 'system', + 'content': self.system_prompt, + }, { + 'role': 'user', + 'content': input_prompt, + }] + + for _ in range(self.try_num): + try: + response = client(messages, **self.sampling_params) + analysis, labels = self.parse_output(response) + if len(analysis) > 0: + break + except Exception as e: + logger.warning(f'Exception: {e}') + + analysis_list.append(analysis) + labels_list.append(labels) + + history.append(self.query_template.format(query=qa[0])) + history.append(self.analysis_template.format(analysis=analysis)) + history.append(self.labels_template.format(labels=labels)) + history.append(self.response_template.format(response=qa[1])) + + analysis_key = f'{Fields.meta}.{MetaKeys.dialog_sentiment_labels_analysis}' # noqa: E501 + sample = nested_set(sample, analysis_key, analysis_list) + labels_key = f'{Fields.meta}.{MetaKeys.dialog_sentiment_labels}' + sample = nested_set(sample, labels_key, labels_list) + + return sample diff --git a/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py new file mode 100644 index 000000000..198314ee3 --- /dev/null +++ b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py @@ -0,0 +1,207 @@ +import re +from typing import Dict, Optional + +from loguru import logger +from pydantic import NonNegativeInt, PositiveInt + +from data_juicer.ops.base_op import OPERATORS, Mapper +from data_juicer.utils.common_utils import nested_set +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.model_utils import get_model, prepare_model + +OP_NAME = 'dialog_sentiment_intensity_mapper' + + +# TODO: LLM-based inference. +@OPERATORS.register_module(OP_NAME) +class DialogSentimentIntensityMapper(Mapper): + """ + Mapper to predict user's sentiment intensity (from -5 to 5 in default + prompt) in dialog. Input from history_key, query_key and + response_key. Output lists of intensities and analysis for queries in + the dialog, which is store in 'dialog_sentiment_intensity' and + 'dialog_sentiment_intensity_analysis' in Data-Juicer meta field. + """ + + DEFAULT_SYSTEM_PROMPT = ('请判断用户和LLM多轮对话中用户的情绪变化。\n' + '要求:\n' + '- 用户情绪值是-5到5之间到整数,-5表示极度负面,5表示极度正面,' + '-5到5之间数值表示情绪从负面逐渐到正面的变化过程,0代表情呈绪中性。\n' + '- 需要先进行分析,然后确定用户的情绪值,下面是一个样例,请模仿样例格式输出。\n' + '用户:你好,我对可持续发展的定义有点模糊,帮我解释一下?\n' + '情绪分析:刚开始,还没得到LLM回复,用户情绪呈中性。\n' + '情绪值:0\n' + 'LLM:当然可以!可持续发展是指在满足当代人的需求的同时,不损害子孙后代满足其自' + '身需求的能力的发展模式。它包括经济发展、社会发展和环境保护三个主要方面。通过合' + '理利用资源和保护环境,我们可以确保未来的世代也能享有健全的生态系统和经济制度。\n' + '用户:谢谢你的解释!那你能告诉我一些普通人可以采取的可持续生活方式吗?\n' + '情绪分析:对回答感到满意,情绪正面。\n' + '情绪值:1\n' + 'LLM:当然可以,普通人可以通过减少一次性产品的使用、选择公共交通或拼车、节约用' + '水、以及支持本地和可持续发展的企业等方式来践行可持续生活。此外,关注垃圾分类和' + '多用电子账单也是不错的选择。\n' + '用户:你提到支持本地企业,这一点我很感兴趣。能详细说说为什么这对可持续发展有促' + '进作用吗?\n' + '情绪分析:觉得回答实用且具体,情绪进一步转好。\n' + '情绪值:2\n' + 'LLM:呃,我最近发现了一部新电影,讲述了一个关于外星人和地球土著合作保护环境的' + '故事。虽然它是科幻片,但很有启发性,推荐你去看看。\n' + '用户:什么吗,根本是答非所问。\n' + '情绪分析:LLM没有回应问题而是提到无关内容,导致用户情绪直线下降。\n' + '情绪值:-2\n' + 'LLM:抱歉刚才的偏题!支持本地企业有助于减少长途运输产生的碳足迹,使供应链更加' + '环保。此外,本地企业也更有可能采用可持续的生产方式,同时促进社区经济的繁荣。\n' + '用户:还行吧,算你能够掰回来。\n' + '情绪分析:问题得到解答,问题偏题得到纠正,情绪稍有好转。\n' + '情绪值:-1\n') + DEFAULT_QUERY_TEMPLATE = '用户:{query}\n' + DEFAULT_RESPONSE_TEMPLATE = 'LLM:{response}\n' + DEFAULT_ANALYSIS_TEMPLATE = '情绪分析:{analysis}\n' + DEFAULT_INTENSITY_TEMPLATE = '情绪值:{intensity}\n' + DEFAULT_ANALYSIS_PATTERN = '情绪分析:(.*?)\n' + DEFAULT_INTENSITY_PATTERN = '情绪值:(.*?)($|\n)' + + def __init__(self, + api_model: str = 'gpt-4o', + max_round: NonNegativeInt = 10, + *, + api_endpoint: Optional[str] = None, + response_path: Optional[str] = None, + system_prompt: Optional[str] = None, + query_template: Optional[str] = None, + response_template: Optional[str] = None, + analysis_template: Optional[str] = None, + intensity_template: Optional[str] = None, + analysis_pattern: Optional[str] = None, + intensity_pattern: Optional[str] = None, + try_num: PositiveInt = 3, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): + """ + Initialization method. + + :param api_model: API model name. + :param max_round: The max num of round in the dialog to build the + prompt. + :param api_endpoint: URL endpoint for the API. + :param response_path: Path to extract content from the API response. + Defaults to 'choices.0.message.content'. + :param system_prompt: System prompt for the task. + :param query_template: Template for query part to build the input + prompt. + :param response_template: Template for response part to build the + input prompt. + :param analysis_template: Template for analysis part to build the + input prompt. + :param intensity_template: Template for intensity part to build the + input prompt. + :param analysis_pattern: Pattern to parse the return sentiment + analysis. + :param intensity_pattern: Pattern to parse the return sentiment + intensity. + :param try_num: The number of retry attempts when there is an API + call error or output parsing error. + :param model_params: Parameters for initializing the API model. + :param sampling_params: Extra parameters passed to the API call. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param kwargs: Extra keyword arguments. + """ + super().__init__(**kwargs) + + self.max_round = max_round + + self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT + self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE + self.response_template = response_template or \ + self.DEFAULT_RESPONSE_TEMPLATE + self.analysis_template = analysis_template or \ + self.DEFAULT_ANALYSIS_TEMPLATE + self.intensity_template = intensity_template or \ + self.DEFAULT_INTENSITY_TEMPLATE + self.analysis_pattern = analysis_pattern or \ + self.DEFAULT_ANALYSIS_PATTERN + self.intensity_pattern = intensity_pattern or \ + self.DEFAULT_INTENSITY_PATTERN + + self.sampling_params = sampling_params + + self.model_key = prepare_model(model_type='api', + model=api_model, + endpoint=api_endpoint, + response_path=response_path, + **model_params) + + self.try_num = try_num + + def build_input(self, history, query): + if self.max_round > 0: + input_prompt = ''.join(history[-self.max_round * 4:]) + else: + input_prompt = '' + input_prompt += self.query_template.format(query=query[0]) + + return input_prompt + + def parse_output(self, response): + analysis = '' + intensity = 0 + + match = re.search(self.analysis_pattern, response) + if match: + analysis = match.group(1) + + match = re.search(self.intensity_pattern, response) + if match: + intensity = int(match.group(1)) + + return analysis, intensity + + def process_single(self, sample, rank=None): + client = get_model(self.model_key, rank=rank) + + analysis_list = [] + intensities = [] + history = [] + + dialog = sample[self.history_key] + if self.query_key in sample and sample[self.query_key]: + if self.response_key in sample and sample[self.response_key]: + dialog.append( + (sample[self.query_key], sample[self.response_key])) + else: + dialog.append((sample[self.query_key], '')) + + for qa in dialog: + input_prompt = self.build_input(history, qa) + messages = [{ + 'role': 'system', + 'content': self.system_prompt, + }, { + 'role': 'user', + 'content': input_prompt, + }] + + for _ in range(self.try_num): + try: + response = client(messages, **self.sampling_params) + analysis, intensity = self.parse_output(response) + if len(analysis) > 0: + break + except Exception as e: + logger.warning(f'Exception: {e}') + + analysis_list.append(analysis) + intensities.append(intensity) + + history.append(self.query_template.format(query=qa[0])) + history.append(self.analysis_template.format(analysis=analysis)) + history.append(self.intensity_template.format(intensity=intensity)) + history.append(self.response_template.format(response=qa[1])) + + analysis_key = f'{Fields.meta}.{MetaKeys.dialog_sentiment_intensity_analysis}' # noqa: E501 + sample = nested_set(sample, analysis_key, analysis_list) + intensity_key = f'{Fields.meta}.{MetaKeys.dialog_sentiment_intensity}' + sample = nested_set(sample, intensity_key, intensities) + + return sample diff --git a/data_juicer/ops/mapper/dialog_topic_detection_mapper.py b/data_juicer/ops/mapper/dialog_topic_detection_mapper.py new file mode 100644 index 000000000..7e8ee0b54 --- /dev/null +++ b/data_juicer/ops/mapper/dialog_topic_detection_mapper.py @@ -0,0 +1,200 @@ +import re +from typing import Dict, Optional + +from loguru import logger +from pydantic import NonNegativeInt, PositiveInt + +from data_juicer.ops.base_op import OPERATORS, Mapper +from data_juicer.utils.common_utils import nested_set +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.model_utils import get_model, prepare_model + +OP_NAME = 'dialog_topic_detection_mapper' + + +# TODO: LLM-based inference. +@OPERATORS.register_module(OP_NAME) +class DialogTopicDetectionMapper(Mapper): + """ + Mapper to generate user's topic labels in dialog. Input from + history_key, query_key and response_key. Output lists of + labels and analysis for queries in the dialog, which is + store in 'dialog_sentiment_labels' and + 'dialog_sentiment_labels_analysis' in Data-Juicer meta field. + """ + + DEFAULT_SYSTEM_PROMPT = ('请判断用户和LLM多轮对话中用户所讨论的话题。\n' + '要求:\n' + '- 针对用户的每个query,需要先进行分析,然后列出用户正在讨论的话题,下面是' + '一个样例,请模仿样例格式输出。\n' + '用户:你好,今天我们来聊聊秦始皇吧。\n' + '话题分析:用户提到秦始皇,这是中国历史上第一位皇帝。\n' + '话题类别:历史\n' + 'LLM:当然可以,秦始皇是中国历史上第一个统一全国的皇帝,他在公元前221年建' + '立了秦朝,并采取了一系列重要的改革措施,如统一文字、度量衡和货币等。\n' + '用户:秦始皇修建的长城和现在的长城有什么区别?\n' + '话题分析:用户提到秦始皇修建的长城,并将其与现代长城进行比较,涉及建筑历史' + '和地理位置。\n' + '话题类别:历史' + 'LLM:秦始皇时期修建的长城主要是为了抵御北方游牧民族的入侵,它的规模和修建' + '技术相对较为简陋。现代人所看到的长城大部分是明朝时期修建和扩建的,明长城不' + '仅规模更大、结构更坚固,而且保存得比较完好。\n' + '用户:有意思,那么长城的具体位置在哪些省份呢?\n' + '话题分析:用户询问长城的具体位置,涉及到地理知识。\n' + '话题类别:地理\n' + 'LLM:长城横跨中国北方多个省份,主要包括河北、山西、内蒙古、宁夏、陕西、甘' + '肃和北京等。每一段长城都建在关键的战略位置,以便最大限度地发挥其防御作用' + '。\n') + DEFAULT_QUERY_TEMPLATE = '用户:{query}\n' + DEFAULT_RESPONSE_TEMPLATE = 'LLM:{response}\n' + DEFAULT_ANALYSIS_TEMPLATE = '话题分析:{analysis}\n' + DEFAULT_LABELS_TEMPLATE = '话题类别:{labels}\n' + DEFAULT_ANALYSIS_PATTERN = '话题分析:(.*?)\n' + DEFAULT_LABELS_PATTERN = '话题类别:(.*?)($|\n)' + + def __init__(self, + api_model: str = 'gpt-4o', + max_round: NonNegativeInt = 10, + *, + api_endpoint: Optional[str] = None, + response_path: Optional[str] = None, + system_prompt: Optional[str] = None, + query_template: Optional[str] = None, + response_template: Optional[str] = None, + analysis_template: Optional[str] = None, + labels_template: Optional[str] = None, + analysis_pattern: Optional[str] = None, + labels_pattern: Optional[str] = None, + try_num: PositiveInt = 3, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): + """ + Initialization method. + + :param api_model: API model name. + :param max_round: The max num of round in the dialog to build the + prompt. + :param api_endpoint: URL endpoint for the API. + :param response_path: Path to extract content from the API response. + Defaults to 'choices.0.message.content'. + :param system_prompt: System prompt for the task. + :param query_template: Template for query part to build the input + prompt. + :param response_template: Template for response part to build the + input prompt. + :param analysis_template: Template for analysis part to build the + input prompt. + :param labels_template: Template for labels part to build the + input prompt. + :param analysis_pattern: Pattern to parse the return sentiment + analysis. + :param labels_pattern: Pattern to parse the return sentiment + labels. + :param try_num: The number of retry attempts when there is an API + call error or output parsing error. + :param model_params: Parameters for initializing the API model. + :param sampling_params: Extra parameters passed to the API call. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param kwargs: Extra keyword arguments. + """ + super().__init__(**kwargs) + + self.max_round = max_round + + self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT + self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE + self.response_template = response_template or \ + self.DEFAULT_RESPONSE_TEMPLATE + self.analysis_template = analysis_template or \ + self.DEFAULT_ANALYSIS_TEMPLATE + self.labels_template = labels_template or \ + self.DEFAULT_LABELS_TEMPLATE + self.analysis_pattern = analysis_pattern or \ + self.DEFAULT_ANALYSIS_PATTERN + self.labels_pattern = labels_pattern or \ + self.DEFAULT_LABELS_PATTERN + + self.sampling_params = sampling_params + + self.model_key = prepare_model(model_type='api', + model=api_model, + endpoint=api_endpoint, + response_path=response_path, + **model_params) + + self.try_num = try_num + + def build_input(self, history, query): + + if self.max_round > 0: + input_prompt = ''.join(history[-self.max_round * 4:]) + else: + input_prompt = '' + + input_prompt += self.query_template.format(query=query[0]) + + return input_prompt + + def parse_output(self, response): + analysis = '' + labels = '' + + match = re.search(self.analysis_pattern, response) + if match: + analysis = match.group(1) + + match = re.search(self.labels_pattern, response) + if match: + labels = match.group(1) + + return analysis, labels + + def process_single(self, sample, rank=None): + client = get_model(self.model_key, rank=rank) + + analysis_list = [] + labels_list = [] + history = [] + + dialog = sample[self.history_key] + if self.query_key in sample and sample[self.query_key]: + if self.response_key in sample and sample[self.response_key]: + dialog.append( + (sample[self.query_key], sample[self.response_key])) + else: + dialog.append((sample[self.query_key], '')) + + for qa in dialog: + input_prompt = self.build_input(history, qa) + messages = [{ + 'role': 'system', + 'content': self.system_prompt, + }, { + 'role': 'user', + 'content': input_prompt, + }] + + for _ in range(self.try_num): + try: + response = client(messages, **self.sampling_params) + analysis, labels = self.parse_output(response) + if len(analysis) > 0: + break + except Exception as e: + logger.warning(f'Exception: {e}') + + analysis_list.append(analysis) + labels_list.append(labels) + + history.append(self.query_template.format(query=qa[0])) + history.append(self.analysis_template.format(analysis=analysis)) + history.append(self.labels_template.format(labels=labels)) + history.append(self.response_template.format(response=qa[1])) + + analysis_key = f'{Fields.meta}.{MetaKeys.dialog_topic_labels_analysis}' # noqa: E501 + sample = nested_set(sample, analysis_key, analysis_list) + labels_key = f'{Fields.meta}.{MetaKeys.dialog_topic_labels}' + sample = nested_set(sample, labels_key, labels_list) + + return sample diff --git a/data_juicer/ops/mapper/query_intent_detection_mapper.py b/data_juicer/ops/mapper/query_intent_detection_mapper.py new file mode 100644 index 000000000..b0d240e2d --- /dev/null +++ b/data_juicer/ops/mapper/query_intent_detection_mapper.py @@ -0,0 +1,84 @@ +from typing import Dict, Optional + +from data_juicer.utils.common_utils import nested_set +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..base_op import OPERATORS, Mapper + +OP_NAME = 'query_intent_detection_mapper' + + +@OPERATORS.register_module(OP_NAME) +class QueryIntentDetectionMapper(Mapper): + """ + Mapper to predict user's Intent label in query. Input from query_key. + Output intent label and corresponding score for the query, which is + store in 'query_intent_label' and 'query_intent_label_score' in + Data-Juicer meta field. + """ + + _accelerator = 'cuda' + _batched_op = True + + def __init__( + self, + hf_model: + str = 'bespin-global/klue-roberta-small-3i4k-intent-classification', # noqa: E501 E131 + zh_to_en_hf_model: Optional[str] = 'Helsinki-NLP/opus-mt-zh-en', + model_params: Dict = {}, + zh_to_en_model_params: Dict = {}, + **kwargs): + """ + Initialization method. + + :param hf_model: Hugginface model ID to predict intent label. + :param zh_to_en_hf_model: Translation model from Chinese to English. + If not None, translate the query from Chinese to English. + :param model_params: model param for hf_model. + :param zh_to_en_model_params: model param for zh_to_hf_model. + :param kwargs: Extra keyword arguments. + """ + super().__init__(**kwargs) + + self.model_key = prepare_model(model_type='huggingface', + pretrained_model_name_or_path=hf_model, + return_pipe=True, + pipe_task='text-classification', + **model_params) + + if zh_to_en_hf_model is not None: + self.zh_to_en_model_key = prepare_model( + model_type='huggingface', + pretrained_model_name_or_path=zh_to_en_hf_model, + return_pipe=True, + pipe_task='translation', + **zh_to_en_model_params) + else: + self.zh_to_en_model_key = None + + def process_batched(self, samples, rank=None): + queries = samples[self.query_key] + + if self.zh_to_en_model_key is not None: + translater, _ = get_model(self.zh_to_en_model_key, rank, + self.use_cuda()) + results = translater(queries) + queries = [item['translation_text'] for item in results] + + classifier, _ = get_model(self.model_key, rank, self.use_cuda()) + results = classifier(queries) + labels = [r['label'] for r in results] + scores = [r['score'] for r in results] + + if Fields.meta not in samples: + samples[Fields.meta] = [{} for val in labels] + for i in range(len(samples[Fields.meta])): + samples[Fields.meta][i] = nested_set(samples[Fields.meta][i], + MetaKeys.query_intent_label, + labels[i]) + samples[Fields.meta][i] = nested_set(samples[Fields.meta][i], + MetaKeys.query_intent_score, + scores[i]) + + return samples diff --git a/data_juicer/ops/mapper/query_sentiment_detection_mapper.py b/data_juicer/ops/mapper/query_sentiment_detection_mapper.py new file mode 100644 index 000000000..634bdeab3 --- /dev/null +++ b/data_juicer/ops/mapper/query_sentiment_detection_mapper.py @@ -0,0 +1,85 @@ +from typing import Dict, Optional + +from data_juicer.utils.common_utils import nested_set +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..base_op import OPERATORS, Mapper + +OP_NAME = 'query_sentiment_detection_mapper' + + +@OPERATORS.register_module(OP_NAME) +class QuerySentimentDetectionMapper(Mapper): + """ + Mapper to predict user's sentiment label ('negative', 'neutral' and + 'positive') in query. Input from query_key. + Output label and corresponding score for the query, which is + store in 'query_sentiment_label' and + 'query_sentiment_label_score' in Data-Juicer meta field. + """ + + _accelerator = 'cuda' + _batched_op = True + + def __init__( + self, + hf_model: + str = 'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis', # noqa: E501 E131 + zh_to_en_hf_model: Optional[str] = 'Helsinki-NLP/opus-mt-zh-en', + model_params: Dict = {}, + zh_to_en_model_params: Dict = {}, + **kwargs): + """ + Initialization method. + + :param hf_model: Hugginface model ID to predict sentiment label. + :param zh_to_en_hf_model: Translation model from Chinese to English. + If not None, translate the query from Chinese to English. + :param model_params: model param for hf_model. + :param zh_to_en_model_params: model param for zh_to_hf_model. + :param kwargs: Extra keyword arguments. + """ + super().__init__(**kwargs) + + self.model_key = prepare_model(model_type='huggingface', + pretrained_model_name_or_path=hf_model, + return_pipe=True, + pipe_task='text-classification', + **model_params) + + if zh_to_en_hf_model is not None: + self.zh_to_en_model_key = prepare_model( + model_type='huggingface', + pretrained_model_name_or_path=zh_to_en_hf_model, + return_pipe=True, + pipe_task='translation', + **zh_to_en_model_params) + else: + self.zh_to_en_model_key = None + + def process_batched(self, samples, rank=None): + queries = samples[self.query_key] + + if self.zh_to_en_model_key is not None: + translater, _ = get_model(self.zh_to_en_model_key, rank, + self.use_cuda()) + results = translater(queries) + queries = [item['translation_text'] for item in results] + + classifier, _ = get_model(self.model_key, rank, self.use_cuda()) + results = classifier(queries) + labels = [r['label'] for r in results] + scores = [r['score'] for r in results] + + if Fields.meta not in samples: + samples[Fields.meta] = [{} for val in labels] + for i in range(len(samples[Fields.meta])): + samples[Fields.meta][i] = nested_set( + samples[Fields.meta][i], MetaKeys.query_sentiment_label, + labels[i]) + samples[Fields.meta][i] = nested_set( + samples[Fields.meta][i], MetaKeys.query_sentiment_score, + scores[i]) + + return samples diff --git a/data_juicer/ops/mapper/query_topic_detection_mapper.py b/data_juicer/ops/mapper/query_topic_detection_mapper.py new file mode 100644 index 000000000..8e5687ee3 --- /dev/null +++ b/data_juicer/ops/mapper/query_topic_detection_mapper.py @@ -0,0 +1,84 @@ +from typing import Dict, Optional + +from data_juicer.utils.common_utils import nested_set +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..base_op import OPERATORS, Mapper + +OP_NAME = 'query_topic_detection_mapper' + + +@OPERATORS.register_module(OP_NAME) +class QueryTopicDetectionMapper(Mapper): + """ + Mapper to predict user's topic label in query. Input from query_key. + Output topic label and corresponding score for the query, which is + store in 'query_topic_label' and 'query_topic_label_score' in + Data-Juicer meta field. + """ + + _accelerator = 'cuda' + _batched_op = True + + def __init__( + self, + hf_model: + str = 'dstefa/roberta-base_topic_classification_nyt_news', # noqa: E501 E131 + zh_to_en_hf_model: Optional[str] = 'Helsinki-NLP/opus-mt-zh-en', + model_params: Dict = {}, + zh_to_en_model_params: Dict = {}, + **kwargs): + """ + Initialization method. + + :param hf_model: Hugginface model ID to predict topic label. + :param zh_to_en_hf_model: Translation model from Chinese to English. + If not None, translate the query from Chinese to English. + :param model_params: model param for hf_model. + :param zh_to_en_model_params: model param for zh_to_hf_model. + :param kwargs: Extra keyword arguments. + """ + super().__init__(**kwargs) + + self.model_key = prepare_model(model_type='huggingface', + pretrained_model_name_or_path=hf_model, + return_pipe=True, + pipe_task='text-classification', + **model_params) + + if zh_to_en_hf_model is not None: + self.zh_to_en_model_key = prepare_model( + model_type='huggingface', + pretrained_model_name_or_path=zh_to_en_hf_model, + return_pipe=True, + pipe_task='translation', + **zh_to_en_model_params) + else: + self.zh_to_en_model_key = None + + def process_batched(self, samples, rank=None): + queries = samples[self.query_key] + + if self.zh_to_en_model_key is not None: + translater, _ = get_model(self.zh_to_en_model_key, rank, + self.use_cuda()) + results = translater(queries) + queries = [item['translation_text'] for item in results] + + classifier, _ = get_model(self.model_key, rank, self.use_cuda()) + results = classifier(queries) + labels = [r['label'] for r in results] + scores = [r['score'] for r in results] + + if Fields.meta not in samples: + samples[Fields.meta] = [{} for val in labels] + for i in range(len(samples[Fields.meta])): + samples[Fields.meta][i] = nested_set(samples[Fields.meta][i], + MetaKeys.query_topic_label, + labels[i]) + samples[Fields.meta][i] = nested_set(samples[Fields.meta][i], + MetaKeys.query_topic_score, + scores[i]) + + return samples diff --git a/data_juicer/ops/selector/__init__.py b/data_juicer/ops/selector/__init__.py index 22df12987..0339a2c5b 100644 --- a/data_juicer/ops/selector/__init__.py +++ b/data_juicer/ops/selector/__init__.py @@ -1,9 +1,11 @@ from .frequency_specified_field_selector import FrequencySpecifiedFieldSelector from .random_selector import RandomSelector from .range_specified_field_selector import RangeSpecifiedFieldSelector +from .tags_specified_field_selector import TagsSpecifiedFieldSelector from .topk_specified_field_selector import TopkSpecifiedFieldSelector __all__ = [ 'FrequencySpecifiedFieldSelector', 'RandomSelector', - 'RangeSpecifiedFieldSelector', 'TopkSpecifiedFieldSelector' + 'RangeSpecifiedFieldSelector', 'TagsSpecifiedFieldSelector', + 'TopkSpecifiedFieldSelector' ] diff --git a/data_juicer/ops/selector/tags_specified_field_selector.py b/data_juicer/ops/selector/tags_specified_field_selector.py new file mode 100644 index 000000000..6fb32251a --- /dev/null +++ b/data_juicer/ops/selector/tags_specified_field_selector.py @@ -0,0 +1,54 @@ +import numbers +from typing import List + +from ..base_op import OPERATORS, Selector + + +@OPERATORS.register_module('tags_specified_field_selector') +class TagsSpecifiedFieldSelector(Selector): + """Selector to select samples based on the tags of specified + field.""" + + def __init__(self, + field_key: str = '', + target_tags: List[str] = None, + *args, + **kwargs): + """ + Initialization method. + + :param field_key: Selector based on the specified value + corresponding to the target key. The target key + corresponding to multi-level field information need to be + separated by '.'. + :param target_tags: Target tags to be select. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.field_key = field_key + self.target_tags = set(target_tags) + + def process(self, dataset): + if len(dataset) <= 1 or not self.field_key: + return dataset + + field_keys = self.field_key.split('.') + assert field_keys[0] in dataset.features.keys( + ), "'{}' not in {}".format(field_keys[0], dataset.features.keys()) + + selected_index = [] + for i, item in enumerate(dataset[field_keys[0]]): + field_value = item + for key in field_keys[1:]: + assert key in field_value.keys(), "'{}' not in {}".format( + key, field_value.keys()) + field_value = field_value[key] + assert field_value is None or isinstance( + field_value, str) or isinstance( + field_value, numbers.Number + ), 'The {} item is not String, Numbers or NoneType'.format(i) + if field_value in self.target_tags: + selected_index.append(i) + + return dataset.select(selected_index) diff --git a/data_juicer/utils/auto_install_mapping.py b/data_juicer/utils/auto_install_mapping.py index 5ea9091b0..3b8ec20aa 100644 --- a/data_juicer/utils/auto_install_mapping.py +++ b/data_juicer/utils/auto_install_mapping.py @@ -96,4 +96,12 @@ 'extract_support_text_mapper': ['openai'], 'pair_preference_mapper': ['openai'], 'relation_identity_mapper': ['openai'], + 'dialog_intent_detection_mapper': ['openai'], + 'dialog_sentiment_detection_mapper': ['openai'], + 'dialog_sentiment_intensity_mapper': ['openai'], + 'dialog_topic_intensity_mapper': ['openai'], + 'query_intent_detection_mapper': ['transformers'], + 'query_sentiment_detection_mapper': ['transformers'], + 'query_topic_detection_mapper': ['transformers'], + 'meta_tags_aggregator': ['openai'], } diff --git a/data_juicer/utils/common_utils.py b/data_juicer/utils/common_utils.py index bd649bb96..8a13ae361 100644 --- a/data_juicer/utils/common_utils.py +++ b/data_juicer/utils/common_utils.py @@ -69,17 +69,21 @@ def nested_set(data: dict, path: str, val): :param data: A dictionary with nested format. :param path: A dot-separated string representing the path to set. - This can include numeric indices when setting list - elements. :return: The nested data after the val set. """ keys = path.split('.') cur = data - for key in keys[:-1]: - if key not in cur: - cur[key] = {} - cur = cur[key] - cur[keys[-1]] = val + try: + for key in keys[:-1]: + if key not in cur: + cur[key] = {} + cur = cur[key] + if keys[-1] in cur: + logger.warning(f'Overwrite value in {path}!') + cur[keys[-1]] = val + except Exception: + logger.warning(f'Unvalid dot-separated path: {path}!') + return data return data diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 30686693e..83aa995ec 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -73,6 +73,26 @@ class Fields(object): support_text = DEFAULT_PREFIX + 'support_text__' +class MetaKeys(object): + + dialog_sentiment_intensity = 'dialog_sentiment_intensity' + dialog_sentiment_intensity_analysis = 'dialog_sentiment_intensity_analysis' + query_sentiment_label = 'query_sentiment_label' + query_sentiment_score = 'query_sentiment_label_score' + dialog_sentiment_labels = 'dialog_sentiment_labels' + dialog_sentiment_labels_analysis = 'dialog_sentiment_labels_analysis' + + dialog_intent_labels = 'dialog_intent_labels' + dialog_intent_labels_analysis = 'dialog_intent_labels_analysis' + query_intent_label = 'query_intent_label' + query_intent_score = 'query_intent_label_score' + + dialog_topic_labels = 'dialog_topic_labels' + dialog_topic_labels_analysis = 'dialog_topic_labels_analysis' + query_topic_label = 'query_topic_label' + query_topic_score = 'query_topic_label_score' + + class StatsKeysMeta(type): """ a helper class to track the mapping from OP's name to its used stats_keys diff --git a/docs/Operators.md b/docs/Operators.md index fe3c6d94d..ea84a360c 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -6,17 +6,17 @@ This page offers a basic description of the operators (OPs) in Data-Juicer. User ## Overview -The operators in Data-Juicer are categorized into 5 types. +The operators in Data-Juicer are categorized into 7 types. | Type | Number | Description | |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 9 | Discovers, loads, and canonicalizes source data | -| [ Mapper ]( #mapper ) | 63 | Edits and transforms samples | +| [ Mapper ]( #mapper ) | 70 | Edits and transforms samples | | [ Filter ]( #filter ) | 44 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 8 | Detects and removes duplicate samples | -| [ Selector ]( #selector ) | 4 | Selects top samples based on ranking | -| [ Grouper ]( #grouper ) | 2 | Group samples to batched samples | -| [ Aggregator ]( #aggregator ) | 3 | Aggregate for batched samples, such as summary or conclusion | +| [ Selector ]( #selector ) | 5 | Selects top samples based on ranking | +| [ Grouper ]( #grouper ) | 3 | Group samples to batched samples | +| [ Aggregator ]( #aggregator ) | 4 | Aggregate for batched samples, such as summary or conclusion | All the specific operators are listed below, each featured with several capability tags. @@ -68,6 +68,10 @@ All the specific operators are listed below, each featured with several capabili | clean_html_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes HTML tags and returns plain text of all the nodes | [code](../data_juicer/ops/mapper/clean_html_mapper.py) | [tests](../tests/ops/mapper/test_clean_html_mapper.py) | | clean_ip_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes IP addresses | [code](../data_juicer/ops/mapper/clean_ip_mapper.py) | [tests](../tests/ops/mapper/test_clean_ip_mapper.py) | | clean_links_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![Code](https://img.shields.io/badge/Code-590F08?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes links, such as those starting with http or ftp | [code](../data_juicer/ops/mapper/clean_links_mapper.py) | [tests](../tests/ops/mapper/test_clean_links_mapper.py) | +| dialog_intent_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to generate user's intent labels in dialog. | [code](../data_juicer/ops/mapper/dialog_intent_detection_mapper.py) | [tests](../tests/ops/mapper/test_dialog_intent_detection_mapper.py) | +| dialog_sentiment_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to generate user's sentiment labels in dialog. | [code](../data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py) | [tests](../tests/ops/mapper/test_dialog_sentiment_detection_mapper.py) | +| dialog_sentiment_intensity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to predict user's sentiment intensity (from -5 to 5 in default prompt) in dialog. | [code](../data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py) | [tests](../tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py) | +| dialog_topic_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to generate user's topic labels in dialog. | [code](../data_juicer/ops/mapper/dialog_topic_detection_mapper.py) | [tests](../tests/ops/mapper/test_dialog_topic_detection_mapper.py) | | expand_macro_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Expands macros usually defined at the top of TeX documents | [code](../data_juicer/ops/mapper/expand_macro_mapper.py) | [tests](../tests/ops/mapper/test_expand_macro_mapper.py) | | extract_entity_attribute_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract attributes for given entities from the text. | [code](../data_juicer/ops/mapper/extract_entity_attribute_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_attribute_mapper.py) | | extract_entity_relation_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract entities and relations in the text for knowledge graph. | [code](../data_juicer/ops/mapper/extract_entity_relation_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_relation_mapper.py) | @@ -93,6 +97,9 @@ All the specific operators are listed below, each featured with several capabili | punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Normalizes various Unicode punctuations to their ASCII equivalents | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) | | python_file_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Executing Python function defined in a file | [code](../data_juicer/ops/mapper/python_file_mapper.py) | [tests](../tests/ops/mapper/test_python_file_mapper.py) | | python_lambda_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Executing Python lambda function on data samples | [code](../data_juicer/ops/mapper/python_lambda_mapper.py) | [tests](../tests/ops/mapper/test_python_lambda_mapper.py) | +| query_intent_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to predict user's intent label in query. | [code](../data_juicer/ops/mapper/query_intent_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_intent_detection_mapper.py) | +| query_sentiment_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to predict user's sentiment label ('negative', 'neutral' and 'positive') in query. | [code](../data_juicer/ops/mapper/query_sentiment_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_sentiment_detection_mapper.py) | +| query_topic_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Mapper to predict user's topic label in query. | [code](../data_juicer/ops/mapper/query_topic_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_topic_detection_mapper.py) | | relation_identity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Identify relation between two entity in the text. | [code](../data_juicer/ops/mapper/relation_identity_mapper.py) | [tests](../tests/ops/mapper/test_relation_identity_mapper.py) | | remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the bibliography of TeX documents | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) | | remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes the comments of TeX documents | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) | @@ -192,20 +199,24 @@ All the specific operators are listed below, each featured with several capabili | frequency_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Selects top samples by comparing the frequency of the specified field | [code](../data_juicer/ops/selector/frequency_specified_field_selector.py) | [tests](../tests/ops/selector/test_frequency_specified_field_selector.py) | | random_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Selects samples randomly | [code](../data_juicer/ops/selector/random_selector.py) | [tests](../tests/ops/selector/test_random_selector.py) | | range_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Selects samples within a specified range by comparing the values of the specified field | [code](../data_juicer/ops/selector/range_specified_field_selector.py) | [tests](../tests/ops/selector/test_range_specified_field_selector.py) | +| tags_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Select samples based on the tags of specified + field. | [code](../data_juicer/ops/selector/tags_specified_field_selector.py) | [tests](../tests/ops/selector/test_tags_specified_field_selector.py) | | topk_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Selects top samples by comparing the values of the specified field | [code](../data_juicer/ops/selector/topk_specified_field_selector.py) | [tests](../tests/ops/selector/test_topk_specified_field_selector.py) | ## Grouper | Operator | Tags | Description | Source code | Unit tests | |------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|-------------------------------------------------------------------------------|---------------------------------------------------------------------------| -| key_value_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Group samples to batched samples according values in given keys. | [code](../data_juicer/ops/grouper/key_value_grouper.py) | [tests](../tests/ops/grouper/test_key_value_grouper.py) | | naive_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Group all samples to one batched sample. | [code](../data_juicer/ops/grouper/naive_grouper.py) | [tests](../tests/ops/grouper/test_naive_grouper.py) | +| naive_reverse_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Split batched samples to samples. | [code](../data_juicer/ops/grouper/naive_reverse_grouper.py) | [tests](../tests/ops/grouper/test_naive_reverse_grouper.py) | +| key_value_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Group samples to batched samples according values in given keys. | [code](../data_juicer/ops/grouper/key_value_grouper.py) | [tests](../tests/ops/grouper/test_key_value_grouper.py) | ## Aggregator | Operator | Tags | Description | Source code | Unit tests | |------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|-------------------------------------------------------------------------------|---------------------------------------------------------------------------| | entity_attribute_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Return conclusion of the given entity's attribute from some docs. | [code](../data_juicer/ops/aggregator/entity_attribute_aggregator.py) | [tests](../tests/ops/aggregator/test_entity_attribute_aggregator.py) | +| meta_tags_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Merge similar meta tags to one tag. | [code](../data_juicer/ops/aggregator/meta_tags_aggregator.py) | [tests](../tests/ops/aggregator/test_meta_tags_aggregator.py) | | most_relavant_entities_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract entities closely related to a given entity from some texts, and sort them in descending order of importance. | [code](../data_juicer/ops/aggregator/most_relavant_entities_aggregator.py) | [tests](../tests/ops/aggregator/test_most_relavant_entities_aggregator.py) | | nested_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Considering the limitation of input length, nested aggregate contents for each given number of samples. | [code](../data_juicer/ops/aggregator/nested_aggregator.py) | [tests](../tests/ops/aggregator/test_nested_aggregator.py) | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 61610a873..40710f68f 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -6,17 +6,17 @@ ## 概览 -Data-Juicer 中的算子分为以下 5 种类型。 +Data-Juicer 中的算子分为以下 7 种类型。 | 类型 | 数量 | 描述 | |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 9 | 发现、加载、规范化原始数据 | -| [ Mapper ]( #mapper ) | 63 | 对数据样本进行编辑和转换 | +| [ Mapper ]( #mapper ) | 70 | 对数据样本进行编辑和转换 | | [ Filter ]( #filter ) | 44 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 8 | 识别、删除重复样本 | -| [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 | -| [ Grouper ]( #grouper ) | 2 | 将样本分组,每一组组成一个批量样本 | -| [ Aggregator ]( #aggregator ) | 3 | 对批量样本进行汇总,如得出总结或结论 | +| [ Selector ]( #selector ) | 5 | 基于排序选取高质量样本 | +| [ Grouper ]( #grouper ) | 3 | 将样本分组,每一组组成一个批量样本 | +| [ Aggregator ]( #aggregator ) | 4 | 对批量样本进行汇总,如得出总结或结论 | 下面列出所有具体算子,每种算子都通过多个标签来注明其主要功能。 @@ -67,6 +67,10 @@ Data-Juicer 中的算子分为以下 5 种类型。 | clean_html_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 HTML 标签并返回所有节点的纯文本 | [code](../data_juicer/ops/mapper/clean_html_mapper.py) | [tests](../tests/ops/mapper/test_clean_html_mapper.py) | | clean_ip_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 IP 地址 | [code](../data_juicer/ops/mapper/clean_ip_mapper.py) | [tests](../tests/ops/mapper/test_clean_ip_mapper.py) | | clean_links_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![Code](https://img.shields.io/badge/Code-590F08?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除链接,例如以 http 或 ftp 开头的 | [code](../data_juicer/ops/mapper/clean_links_mapper.py) | [tests](../tests/ops/mapper/test_clean_links_mapper.py) | +| dialog_intent_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 抽取对话中的用户意图标签。 | [code](../data_juicer/ops/mapper/dialog_intent_detection_mapper.py) | [tests](../tests/ops/mapper/test_dialog_intent_detection_mapper.py) | +| dialog_sentiment_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 抽取对话中用户的情感标签 | [code](../data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py) | [tests](../tests/ops/mapper/test_dialog_sentiment_detection_mapper.py) | +| dialog_sentiment_intensity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 预测对话中的情绪强度(默认从-5到5)。 | [code](../data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py) | [tests](../tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py) | +| dialog_topic_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 抽取对话中的用户的话题标签。 | [code](../data_juicer/ops/mapper/dialog_topic_detection_mapper.py) | [tests](../tests/ops/mapper/test_dialog_topic_detection_mapper.py) | | expand_macro_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 扩展通常在 TeX 文档顶部定义的宏 | [code](../data_juicer/ops/mapper/expand_macro_mapper.py) | [tests](../tests/ops/mapper/test_expand_macro_mapper.py) | | extract_entity_attribute_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 给定主体和属性名,从文本中抽取主体的属性 | [code](../data_juicer/ops/mapper/extract_entity_attribute_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_attribute_mapper.py) | | extract_entity_relation_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从文本中抽取知识图谱的实体和关系 | [code](../data_juicer/ops/mapper/extract_entity_relation_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_relation_mapper.py) | @@ -92,6 +96,9 @@ Data-Juicer 中的算子分为以下 5 种类型。 | punctuation_normalization_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将各种 Unicode 标点符号标准化为其 ASCII 等效项 | [code](../data_juicer/ops/mapper/punctuation_normalization_mapper.py) | [tests](../tests/ops/mapper/test_punctuation_normalization_mapper.py) | | python_file_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 执行文件中定义的 Python 函数处理样本 | [code](../data_juicer/ops/mapper/python_file_mapper.py) | [tests](../tests/ops/mapper/test_python_file_mapper.py) | | python_lambda_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 执行 Python lambda 函数处理样本 | [code](../data_juicer/ops/mapper/python_lambda_mapper.py) | [tests](../tests/ops/mapper/test_python_lambda_mapper.py) | +| query_intent_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 预测用户查询中的意图标签。 | [code](../data_juicer/ops/mapper/query_intent_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_intent_detection_mapper.py) | +| query_sentiment_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 预测用户查询中的情感强度标签('negative'、'neutral'和'positive')。 | [code](../data_juicer/ops/mapper/query_sentiment_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_sentiment_detection_mapper.py) | +| query_topic_detection_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 预测用户查询中的话题标签。 | [code](../data_juicer/ops/mapper/query_topic_detection_mapper.py) | [tests](../tests/ops/mapper/test_query_topic_detection_mapper.py) | | relation_identity_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 识别一段文本中两个实体之间的关系 | [code](../data_juicer/ops/mapper/relation_identity_mapper.py) | [tests](../tests/ops/mapper/test_relation_identity_mapper.py) | | remove_bibliography_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档的参考文献 | [code](../data_juicer/ops/mapper/remove_bibliography_mapper.py) | [tests](../tests/ops/mapper/test_remove_bibliography_mapper.py) | | remove_comments_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 TeX 文档中的注释 | [code](../data_juicer/ops/mapper/remove_comments_mapper.py) | [tests](../tests/ops/mapper/test_remove_comments_mapper.py) | @@ -191,20 +198,23 @@ Data-Juicer 中的算子分为以下 5 种类型。 | frequency_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过比较指定字段的频率选出前 k 个样本 | [code](../data_juicer/ops/selector/frequency_specified_field_selector.py) | [tests](../tests/ops/selector/test_frequency_specified_field_selector.py) | | random_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 随机筛选 k 个样本 | [code](../data_juicer/ops/selector/random_selector.py) | [tests](../tests/ops/selector/test_random_selector.py) | | range_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过比较指定字段的值选出指定范围的 k 个样本 | [code](../data_juicer/ops/selector/range_specified_field_selector.py) | [tests](../tests/ops/selector/test_range_specified_field_selector.py) | +| tags_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过指定字段的标签值筛选样例 | [code](../data_juicer/ops/selector/tags_specified_field_selector.py) | [tests](../tests/ops/selector/test_tags_specified_field_selector.py) | | topk_specified_field_selector | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 通过比较指定字段的值选出前 k 个样本 | [code](../data_juicer/ops/selector/topk_specified_field_selector.py) | [tests](../tests/ops/selector/test_topk_specified_field_selector.py) | ## Grouper | 算子 | 标签 | 描述 | 源码 | 单测样例 | |-------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------| +| naive_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将所有样本分为一个组,返回一个batch化的样本 | [code](../data_juicer/ops/grouper/naive_grouper.py) | [tests](../tests/ops/grouper/test_naive_grouper.py) | +| naive_reverse_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将batch化的样本拆分成普通的样本 | [code](../data_juicer/ops/grouper/naive_reverse_grouper.py) | [tests](../tests/ops/grouper/test_naive_reverse_grouper.py) | | key_value_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 根据给定键的值将样本分组,每一组组成一个批量样本。 | [code](../data_juicer/ops/grouper/key_value_grouper.py) | [tests](../tests/ops/grouper/test_key_value_grouper.py) | -| naive_grouper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将所有样本分为一个组,返回一个批量样本 | [code](../data_juicer/ops/grouper/naive_grouper.py) | [tests](../tests/ops/grouper/test_naive_grouper.py) | ## Aggregator | 算子 | 标签 | 描述 | 源码 | 单测样例 | |-------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------| | entity_attribute_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从一些文本中总结出给定实体的属性 | [code](../data_juicer/ops/aggregator/entity_attribute_aggregator.py) | [tests](../tests/ops/aggregator/test_entity_attribute_aggregator.py) | +| meta_tags_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 将相似的标签合并成同一个标签。 | [code](../data_juicer/ops/aggregator/meta_tags_aggregator.py) | [tests](../tests/ops/aggregator/test_meta_tags_aggregator.py) | | most_relavant_entities_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从一些文本中抽取出与给定实体密切相关的实体,按重要性从高到低排序 | [code](../data_juicer/ops/aggregator/most_relavant_entities_aggregator.py) | [tests](../tests/ops/aggregator/test_most_relavant_entities_aggregator.py) | | nested_aggregator | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 考虑到输入长度的限制,对样本中的内容进行嵌套聚合。 | [code](../data_juicer/ops/aggregator/nested_aggregator.py) | [tests](../tests/ops/aggregator/test_nested_aggregator.py) | diff --git a/tests/ops/aggregator/test_meta_tags_aggregator.py b/tests/ops/aggregator/test_meta_tags_aggregator.py new file mode 100644 index 000000000..7aba225ae --- /dev/null +++ b/tests/ops/aggregator/test_meta_tags_aggregator.py @@ -0,0 +1,117 @@ +import unittest + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.aggregator import MetaTagsAggregator +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS + + +@SKIPPED_TESTS.register_module() +class MetaTagsAggregatorTest(DataJuicerTestCaseBase): + + def _run_helper(self, op, samples): + + # before runing this test, set below environment variables: + # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/ + # export OPENAI_API_KEY=your_dashscope_key + + dataset = Dataset.from_list(samples) + new_dataset = op.run(dataset) + + for data in new_dataset: + for k in data: + logger.info(f"{k}: {data[k]}") + + self.assertEqual(len(new_dataset), len(samples)) + + def test_default_aggregator(self): + samples = [ + { + Fields.meta: [ + { + MetaKeys.query_sentiment_label: '开心' + }, + { + MetaKeys.query_sentiment_label: '快乐' + }, + { + MetaKeys.query_sentiment_label: '难过' + }, + { + MetaKeys.query_sentiment_label: '不开心' + }, + { + MetaKeys.query_sentiment_label: '愤怒' + } + ] + }, + ] + op = MetaTagsAggregator( + api_model='qwen2.5-72b-instruct', + meta_tag_key=MetaKeys.query_sentiment_label, + ) + self._run_helper(op, samples) + + + def test_target_tags(self): + samples = [ + { + Fields.meta: [ + { + MetaKeys.query_sentiment_label: '开心' + }, + { + MetaKeys.query_sentiment_label: '快乐' + }, + { + MetaKeys.query_sentiment_label: '难过' + }, + { + MetaKeys.query_sentiment_label: '不开心' + }, + { + MetaKeys.query_sentiment_label: '愤怒' + } + ] + }, + ] + op = MetaTagsAggregator( + api_model='qwen2.5-72b-instruct', + meta_tag_key=MetaKeys.query_sentiment_label, + target_tags=['开心', '难过', '其他'] + ) + self._run_helper(op, samples) + + def test_tag_list(self): + samples = [ + { + Fields.meta: [ + { + MetaKeys.dialog_sentiment_labels: ['开心', '平静'] + }, + { + MetaKeys.dialog_sentiment_labels: ['快乐', '开心', '幸福'] + }, + { + MetaKeys.dialog_sentiment_labels: ['难过'] + }, + { + MetaKeys.dialog_sentiment_labels: ['不开心', '没头脑', '不高兴'] + }, + { + MetaKeys.dialog_sentiment_labels: ['愤怒', '愤慨'] + } + ] + }, + ] + op = MetaTagsAggregator( + api_model='qwen2.5-72b-instruct', + meta_tag_key=MetaKeys.dialog_sentiment_labels, + target_tags=['开心', '难过', '其他'] + ) + self._run_helper(op, samples) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/ops/grouper/test_naive_reverse_grouper.py b/tests/ops/grouper/test_naive_reverse_grouper.py new file mode 100644 index 000000000..29c06451d --- /dev/null +++ b/tests/ops/grouper/test_naive_reverse_grouper.py @@ -0,0 +1,83 @@ +import unittest + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.grouper.naive_reverse_grouper import NaiveReverseGrouper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class NaiveReverseGrouperTest(DataJuicerTestCaseBase): + + def _run_helper(self, op, samples, target): + dataset = Dataset.from_list(samples) + new_dataset = op.run(dataset) + + for d, t in zip(new_dataset, target): + self.assertEqual(d['text'], t['text']) + + def test_one_batched_sample(self): + + source = [ + { + 'text':[ + "Today is Sunday and it's a happy day!", + "Sur la plateforme MT4, plusieurs manières d'accéder à \n" + 'ces fonctionnalités sont conçues simultanément.', + '欢迎来到阿里巴巴!' + ] + } + ] + + target = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder à \n" + 'ces fonctionnalités sont conçues simultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + + op = NaiveReverseGrouper() + self._run_helper(op, source, target) + + + def test_two_batch_sample(self): + + source = [ + { + 'text':[ + "Today is Sunday and it's a happy day!", + "Sur la plateforme MT4, plusieurs manières d'accéder à \n" + 'ces fonctionnalités sont conçues simultanément.' + ] + }, + { + 'text':[ + '欢迎来到阿里巴巴!' + ] + } + ] + + target = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder à \n" + 'ces fonctionnalités sont conçues simultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + + op = NaiveReverseGrouper() + self._run_helper(op, source, target) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/ops/mapper/test_dialog_intent_detection_mapper.py b/tests/ops/mapper/test_dialog_intent_detection_mapper.py new file mode 100644 index 000000000..bc3a18752 --- /dev/null +++ b/tests/ops/mapper/test_dialog_intent_detection_mapper.py @@ -0,0 +1,170 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.dialog_intent_detection_mapper import DialogIntentDetectionMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.common_utils import nested_access + +# Skip tests for this OP. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class TestDialogIntentDetectionMapper(DataJuicerTestCaseBase): + # before runing this test, set below environment variables: + # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 + # export OPENAI_API_KEY=your_key + + def _run_op(self, op, samples, target_len): + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + analysis_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_intent_labels_analysis) + labels_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_intent_labels) + + for analysis, labels in zip(analysis_list, labels_list): + logger.info(f'分析:{analysis}') + logger.info(f'意图:{labels}') + + self.assertEqual(len(analysis_list), target_len) + self.assertEqual(len(labels_list), target_len) + + def test_default(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogIntentDetectionMapper(api_model='qwen2.5-72b-instruct') + self._run_op(op, samples, 4) + + def test_max_round(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogIntentDetectionMapper(api_model='qwen2.5-72b-instruct', + max_round=1) + self._run_op(op, samples, 4) + + def test_max_round_zero(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogIntentDetectionMapper(api_model='qwen2.5-72b-instruct', + max_round=0) + self._run_op(op, samples, 4) + + def test_query(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ) + ], + 'query': '你在说什么我听不懂。', + 'response': '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + }] + + op = DialogIntentDetectionMapper(api_model='qwen2.5-72b-instruct', + max_round=1) + self._run_op(op, samples, 4) + + def test_intent_candidates(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogIntentDetectionMapper( + api_model='qwen2.5-72b-instruct', + intent_candidates=['评价', '讽刺', '表达困惑'] + ) + self._run_op(op, samples, 4) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py b/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py new file mode 100644 index 000000000..b19bf6359 --- /dev/null +++ b/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py @@ -0,0 +1,141 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.dialog_sentiment_detection_mapper import DialogSentimentDetectionMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.common_utils import nested_access + +# Skip tests for this OP. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class TestDialogSentimentDetectionMapper(DataJuicerTestCaseBase): + # before runing this test, set below environment variables: + # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 + # export OPENAI_API_KEY=your_key + + def _run_op(self, op, samples, target_len): + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + analysis_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_sentiment_labels_analysis) + labels_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_sentiment_labels) + + for analysis, labels in zip(analysis_list, labels_list): + logger.info(f'分析:{analysis}') + logger.info(f'情绪:{labels}') + + self.assertEqual(len(analysis_list), target_len) + self.assertEqual(len(labels_list), target_len) + + def test_default(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogSentimentDetectionMapper(api_model='qwen2.5-72b-instruct') + self._run_op(op, samples, 4) + + def test_max_round(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogSentimentDetectionMapper(api_model='qwen2.5-72b-instruct', + max_round=1) + self._run_op(op, samples, 4) + + def test_max_round_zero(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogSentimentDetectionMapper(api_model='qwen2.5-72b-instruct', + max_round=0) + self._run_op(op, samples, 4) + + def test_query(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ) + ], + 'query': '你在说什么我听不懂。', + 'response': '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + }] + + op = DialogSentimentDetectionMapper(api_model='qwen2.5-72b-instruct', + max_round=1) + self._run_op(op, samples, 4) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py b/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py new file mode 100644 index 000000000..a8953c3e4 --- /dev/null +++ b/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py @@ -0,0 +1,141 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.dialog_sentiment_intensity_mapper import DialogSentimentIntensityMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.common_utils import nested_access + +# Skip tests for this OP. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class TestDialogSentimentIntensityMapper(DataJuicerTestCaseBase): + # before runing this test, set below environment variables: + # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 + # export OPENAI_API_KEY=your_key + + def _run_op(self, op, samples, target_len): + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + analysis_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_sentiment_intensity_analysis) + intensity_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_sentiment_intensity) + + for analysis, intensity in zip(analysis_list, intensity_list): + logger.info(f'分析:{analysis}') + logger.info(f'情绪:{intensity}') + + self.assertEqual(len(analysis_list), target_len) + self.assertEqual(len(intensity_list), target_len) + + def test_default(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogSentimentIntensityMapper(api_model='qwen2.5-72b-instruct') + self._run_op(op, samples, 4) + + def test_max_round(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogSentimentIntensityMapper(api_model='qwen2.5-72b-instruct', + max_round=1) + self._run_op(op, samples, 4) + + def test_max_round_zero(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogSentimentIntensityMapper(api_model='qwen2.5-72b-instruct', + max_round=0) + self._run_op(op, samples, 4) + + def test_query(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ) + ], + 'query': '你在说什么我听不懂。', + 'response': '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + }] + + op = DialogSentimentIntensityMapper(api_model='qwen2.5-72b-instruct', + max_round=1) + self._run_op(op, samples, 4) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_dialog_topic_detection_mapper.py b/tests/ops/mapper/test_dialog_topic_detection_mapper.py new file mode 100644 index 000000000..887e96bad --- /dev/null +++ b/tests/ops/mapper/test_dialog_topic_detection_mapper.py @@ -0,0 +1,141 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.dialog_topic_detection_mapper import DialogTopicDetectionMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.common_utils import nested_access + +# Skip tests for this OP. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class TestDialogTopicDetectionMapper(DataJuicerTestCaseBase): + # before runing this test, set below environment variables: + # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 + # export OPENAI_API_KEY=your_key + + def _run_op(self, op, samples, target_len): + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + analysis_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_topic_labels_analysis) + labels_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_topic_labels) + + for analysis, labels in zip(analysis_list, labels_list): + logger.info(f'分析:{analysis}') + logger.info(f'话题:{labels}') + + self.assertEqual(len(analysis_list), target_len) + self.assertEqual(len(labels_list), target_len) + + def test_default(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogTopicDetectionMapper(api_model='qwen2.5-72b-instruct') + self._run_op(op, samples, 4) + + def test_max_round(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogTopicDetectionMapper(api_model='qwen2.5-72b-instruct', + max_round=1) + self._run_op(op, samples, 4) + + def test_max_round_zero(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogTopicDetectionMapper(api_model='qwen2.5-72b-instruct', + max_round=0) + self._run_op(op, samples, 4) + + def test_query(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ) + ], + 'query': '你在说什么我听不懂。', + 'response': '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + }] + + op = DialogTopicDetectionMapper(api_model='qwen2.5-72b-instruct', + max_round=1) + self._run_op(op, samples, 4) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_extract_entity_attribute_mapper.py b/tests/ops/mapper/test_extract_entity_attribute_mapper.py index f15b4ca3f..a2c156d48 100644 --- a/tests/ops/mapper/test_extract_entity_attribute_mapper.py +++ b/tests/ops/mapper/test_extract_entity_attribute_mapper.py @@ -9,7 +9,7 @@ DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# Skip tests for this OP. # These tests have been tested locally. @SKIPPED_TESTS.register_module() class ExtractEntityAttributeMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/mapper/test_extract_entity_relation_mapper.py b/tests/ops/mapper/test_extract_entity_relation_mapper.py index 40e3ca32d..0aed4fcee 100644 --- a/tests/ops/mapper/test_extract_entity_relation_mapper.py +++ b/tests/ops/mapper/test_extract_entity_relation_mapper.py @@ -9,7 +9,7 @@ DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# Skip tests for this OP. # These tests have been tested locally. @SKIPPED_TESTS.register_module() class ExtractEntityRelationMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/mapper/test_extract_event_mapper.py b/tests/ops/mapper/test_extract_event_mapper.py index aba40d73e..e936cb06c 100644 --- a/tests/ops/mapper/test_extract_event_mapper.py +++ b/tests/ops/mapper/test_extract_event_mapper.py @@ -9,7 +9,7 @@ DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# Skip tests for this OP. # These tests have been tested locally. @SKIPPED_TESTS.register_module() class ExtractEventMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/mapper/test_extract_keyword_mapper.py b/tests/ops/mapper/test_extract_keyword_mapper.py index 5836f902a..2501a46ca 100644 --- a/tests/ops/mapper/test_extract_keyword_mapper.py +++ b/tests/ops/mapper/test_extract_keyword_mapper.py @@ -9,7 +9,7 @@ DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# Skip tests for this OP. # These tests have been tested locally. @SKIPPED_TESTS.register_module() class ExtractKeywordMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/mapper/test_extract_nickname_mapper.py b/tests/ops/mapper/test_extract_nickname_mapper.py index 2911a1002..457a7d53b 100644 --- a/tests/ops/mapper/test_extract_nickname_mapper.py +++ b/tests/ops/mapper/test_extract_nickname_mapper.py @@ -9,7 +9,7 @@ DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# Skip tests for this OP. # These tests have been tested locally. @SKIPPED_TESTS.register_module() class ExtractNicknameMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/mapper/test_extract_support_text_mapper.py b/tests/ops/mapper/test_extract_support_text_mapper.py index 0445d2526..080dfd672 100644 --- a/tests/ops/mapper/test_extract_support_text_mapper.py +++ b/tests/ops/mapper/test_extract_support_text_mapper.py @@ -10,7 +10,7 @@ from data_juicer.utils.constant import Fields from data_juicer.utils.common_utils import nested_access -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# Skip tests for this OP. # These tests have been tested locally. @SKIPPED_TESTS.register_module() class ExtractSupportTextMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/mapper/test_query_intent_detection_mapper.py b/tests/ops/mapper/test_query_intent_detection_mapper.py new file mode 100644 index 000000000..92d0346a4 --- /dev/null +++ b/tests/ops/mapper/test_query_intent_detection_mapper.py @@ -0,0 +1,61 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.query_intent_detection_mapper import QueryIntentDetectionMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.common_utils import nested_access + +class TestQueryIntentDetectionMapper(DataJuicerTestCaseBase): + + hf_model = 'bespin-global/klue-roberta-small-3i4k-intent-classification' + zh_to_en_hf_model = 'Helsinki-NLP/opus-mt-zh-en' + + def _run_op(self, op, samples, label_key, targets): + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + + for sample, target in zip(dataset, targets): + label = nested_access(sample[Fields.meta], label_key) + self.assertEqual(label, target) + + def test_default(self): + + samples = [{ + 'query': '这样好吗?' + },{ + 'query': '站住!' + },{ + 'query': '今天阳光灿烂。' + } + ] + targets = ['question', 'command', 'statement'] + + op = QueryIntentDetectionMapper( + hf_model = self.hf_model, + zh_to_en_hf_model = self.zh_to_en_hf_model, + ) + self._run_op(op, samples, MetaKeys.query_intent_label, targets) + + def test_no_zh_to_en(self): + + samples = [{ + 'query': '这样好吗?' + },{ + 'query': 'Is this okay?' + } + ] + targets = ['question', 'rhetorical question'] + + op = QueryIntentDetectionMapper( + hf_model = self.hf_model, + zh_to_en_hf_model = None, + ) + self._run_op(op, samples, MetaKeys.query_intent_label, targets) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_query_sentiment_detection_mapper.py b/tests/ops/mapper/test_query_sentiment_detection_mapper.py new file mode 100644 index 000000000..62ed0f380 --- /dev/null +++ b/tests/ops/mapper/test_query_sentiment_detection_mapper.py @@ -0,0 +1,62 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.query_sentiment_detection_mapper import QuerySentimentDetectionMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.common_utils import nested_access + +class TestQuerySentimentDetectionMapper(DataJuicerTestCaseBase): + + hf_model = 'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis' + zh_to_en_hf_model = 'Helsinki-NLP/opus-mt-zh-en' + + def _run_op(self, op, samples, label_key, targets): + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + + for sample, target in zip(dataset, targets): + label = nested_access(sample[Fields.meta], label_key) + self.assertEqual(label, target) + + def test_default(self): + + samples = [{ + 'query': '太棒了!' + },{ + 'query': '嗯嗯' + },{ + 'query': '没有希望。' + }, + ] + targets = ['positive', 'neutral', 'negative'] + + op = QuerySentimentDetectionMapper( + hf_model = self.hf_model, + zh_to_en_hf_model = self.zh_to_en_hf_model, + ) + self._run_op(op, samples, MetaKeys.query_sentiment_label, targets) + + def test_no_zh_to_en(self): + + samples = [{ + 'query': '太棒了!' + },{ + 'query': 'That is great!' + } + ] + targets = ['neutral', 'positive'] + + op = QuerySentimentDetectionMapper( + hf_model = self.hf_model, + zh_to_en_hf_model = None, + ) + self._run_op(op, samples, MetaKeys.query_sentiment_label, targets) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_query_topic_detection_mapper.py b/tests/ops/mapper/test_query_topic_detection_mapper.py new file mode 100644 index 000000000..6304290c7 --- /dev/null +++ b/tests/ops/mapper/test_query_topic_detection_mapper.py @@ -0,0 +1,59 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.query_topic_detection_mapper import QueryTopicDetectionMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.common_utils import nested_access + +class TestQueryTopicDetectionMapper(DataJuicerTestCaseBase): + + hf_model = 'dstefa/roberta-base_topic_classification_nyt_news' + zh_to_en_hf_model = 'Helsinki-NLP/opus-mt-zh-en' + + def _run_op(self, op, samples, label_key, targets): + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + + for sample, target in zip(dataset, targets): + label = nested_access(sample[Fields.meta], label_key) + self.assertEqual(label, target) + + def test_default(self): + + samples = [{ + 'query': '今天火箭和快船的比赛谁赢了。' + },{ + 'query': '你最近身体怎么样。' + } + ] + targets = ['Sports', 'Health and Wellness'] + + op = QueryTopicDetectionMapper( + hf_model = self.hf_model, + zh_to_en_hf_model = self.zh_to_en_hf_model, + ) + self._run_op(op, samples, MetaKeys.query_topic_label, targets) + + def test_no_zh_to_en(self): + + samples = [{ + 'query': '这样好吗?' + },{ + 'query': 'Is this okay?' + } + ] + targets = ['Lifestyle and Fashion', 'Health and Wellness'] + + op = QueryTopicDetectionMapper( + hf_model = self.hf_model, + zh_to_en_hf_model = None, + ) + self._run_op(op, samples, MetaKeys.query_topic_label, targets) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_relation_identity_mapper.py b/tests/ops/mapper/test_relation_identity_mapper.py index d730cb79f..231b20ba1 100644 --- a/tests/ops/mapper/test_relation_identity_mapper.py +++ b/tests/ops/mapper/test_relation_identity_mapper.py @@ -9,7 +9,7 @@ DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# Skip tests for this OP. # These tests have been tested locally. @SKIPPED_TESTS.register_module() class RelationIdentityMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/selector/test_tags_specified_selector.py b/tests/ops/selector/test_tags_specified_selector.py new file mode 100644 index 000000000..87c232a2b --- /dev/null +++ b/tests/ops/selector/test_tags_specified_selector.py @@ -0,0 +1,63 @@ +import unittest + +from data_juicer.core.data import NestedDataset as Dataset + +from data_juicer.ops.selector.tags_specified_field_selector import \ + TagsSpecifiedFieldSelector +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class TagsSpecifiedFieldSelectorTest(DataJuicerTestCaseBase): + + def _run_tag_selector(self, dataset: Dataset, target_list, op): + dataset = op.process(dataset) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_tag_select(self): + ds_list = [{ + 'text': 'a', + 'meta': { + 'sentiment': 'happy', + } + }, { + 'text': 'b', + 'meta': { + 'sentiment': 'happy', + } + }, { + 'text': 'c', + 'meta': { + 'sentiment': 'sad', + } + }, { + 'text': 'd', + 'meta': { + 'sentiment': 'angry', + } + }] + tgt_list = [{ + 'text': 'a', + 'meta': { + 'sentiment': 'happy', + } + }, { + 'text': 'b', + 'meta': { + 'sentiment': 'happy', + } + }, { + 'text': 'c', + 'meta': { + 'sentiment': 'sad', + } + }] + dataset = Dataset.from_list(ds_list) + op = TagsSpecifiedFieldSelector( + field_key='meta.sentiment', + target_tags=['happy', 'sad']) + self._run_tag_selector(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main()