Skip to content

Commit

Permalink
add grouper and aggregator op for system_prompt (#500)
Browse files Browse the repository at this point in the history
* chunk and extract events

* fix bugs

* fix tests

* refine tests

* extract nickname

* nickname test done

* lightRAG to OP

* doc done

* remove extra test

* relavant -> relevant

* fix minor error

* group by op done

* ValueError -> Exception

* fix config_all error

* fix prepare_api_model

* fix rank sample None

* constant fix key

* aggregator op

* init python_lambda_mapper

* set default arg

* fix init

* add python_file_mapper

* support text & most relavant entities

* coverage ignore_errors

* index sample

* role_playing_system_prompt_yaml

* system_prompt begin

* support batched

* remove unforkable

* support batched & add docs

* add docs

* fix docs

* update docs

* pre-commit done

* fix batch bug

* fix batch bug

* fix filter batch

* fix filter batch

* system prompt recipe done

* not rank for filter

* limit pyav version

* add test for op

* tmp

* doc done

* skip api test

* add env dependency

* install by recipe

* change to dj_install

* change to dj_install

* developer doc done

---------

Co-authored-by: null <[email protected]>
Co-authored-by: gece.gc <[email protected]>
  • Loading branch information
3 people authored Dec 19, 2024
1 parent 46062f8 commit b4811a0
Show file tree
Hide file tree
Showing 44 changed files with 2,306 additions and 135 deletions.
83 changes: 81 additions & 2 deletions configs/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ process:
- clean_copyright_mapper: # remove copyright comments.
- expand_macro_mapper: # expand macro definitions in Latex text.
- extract_entity_attribute_mapper: # Extract attributes for given entities from the text.
api_model: 'gpt-4o' # API model name.
query_entities: ["孙悟空", "猪八戒"] # Entity list to be queried.
query_attributes: ["人物性格"] # Attribute list to be queried.
api_model: 'gpt-4o' # API model name.
entity_key: '__dj__entity__' # The field name to store the given main entity for attribute extraction.
entity_attribute_key: '__dj__attribute__' # The field name to store the given attribute to be extracted.
attribute_desc_key: '__dj__attribute_description__' # The field name to store the extracted attribute description.
Expand Down Expand Up @@ -153,6 +153,18 @@ process:
drop_text: false # Whether to drop the text in the output.
model_params: {} # Parameters for initializing the API model.
sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
- extract_support_text_mapper: # extract support sub text for a summary.
api_model: 'gpt-4o' # API model name.
summary_key: '__dj__event_description__' # The field name to store the input summary. Supports nested keys such as "__dj__stats__.text_len".
support_text_key: '__dj__support_text__' # The field name to store the output support text for the summary.
api_endpoint: null # URL endpoint for the API.
response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
system_prompt: null # System prompt for the task.
input_template: null # Template for building the model input.
try_num: 3 # The number of retry attempts when there is an API call error or output parsing error.
drop_text: false # Whether to drop the text in the output.
model_params: {} # Parameters for initializing the API model.
sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
- fix_unicode_mapper: # fix unicode errors in text.
- generate_qa_from_examples_mapper: # mapper to generate question and answer pairs from examples.
hf_model: 'Qwen/Qwen2.5-7B-Instruct' # Model name on huggingface to generate question and answer pairs.
Expand Down Expand Up @@ -259,12 +271,27 @@ process:
model_params: {} # Parameters for initializing the API model.
sampling_params: {} # Extra parameters passed to the API call.
- punctuation_normalization_mapper: # normalize unicode punctuations to English punctuations.
- python_python_mapper: # executing Python lambda function defined in a file.
- python_file_mapper: # execute a Python function defined in a file.
file_path: '' # The path to the Python file containing the function to be executed.
function_name: 'process_single' # The name of the function defined in the file to be executed.
- python_lambda_mapper: # execute a Python lambda function on data samples.
lambda_str: '' # A string representation of the lambda function to be executed on data samples. If empty, the identity function is used.
batched: False # A boolean indicating whether to process input data in batches.
- relation_identity_mapper: # identify the relation between two entities in the text.
api_model: 'gpt-4o' # API model name.
source_entity: '孙悟空' # The source entity of the relation to be identified.
target_entity: '猪八戒' # The target entity of the relation to be identified.
input_key: null # The input field key in the samples. Supports nested keys such as "__dj__stats__.text_len". Defaults to text_key.
output_key: null # The output field key in the samples. Supports nested keys such as "__dj__stats__.text_len". Defaults to input_key.
api_endpoint: null # URL endpoint for the API.
response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
system_prompt_template: null # System prompt template for the task. It needs to be specified with entity1 and entity2.
input_template: null # Template for building the model input.
output_pattern_template: null # Regular expression template for parsing model output.
try_num: 3 # The number of retry attempts when there is an API call error or output parsing error.
drop_text: false # Whether to drop the text in the output.
model_params: {} # Parameters for initializing the API model.
sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
- remove_bibliography_mapper: # remove bibliography from Latex text.
- remove_comments_mapper: # remove comments from Latex text, code, etc.
doc_type: tex # comment type you want to remove. Only support 'tex' for now.
Expand Down Expand Up @@ -693,3 +720,55 @@ process:
top_ratio: # ratio of selected top samples
topk: # number of selected top samples
reverse: True # determine the sorting rule, if reverse=True, then sort in descending order

# Grouper ops.
- naive_grouper: # Group all samples into one batched sample.
- key_value_grouper: # Group samples into batched samples according to the values of the given keys.
group_by_keys: null # Group samples according to the values of these keys. Supports nested keys such as "__dj__stats__.text_len". Defaults to [self.text_key].

# Aggregator ops.
- entity_attribute_aggregator: # Return a conclusion about the given entity's attribute from some docs.
api_model: 'gpt-4o' # API model name.
entity: '孙悟空' # The given entity.
attribute: '人物经历' # The given attribute.
input_key: null # The input field key in the samples. Supports nested keys such as "__dj__stats__.text_len". Defaults to text_key.
output_key: null # The output field key in the samples. Supports nested keys such as "__dj__stats__.text_len". Defaults to the same key as input_key.
word_limit: 100 # The output length (in words) to request in the prompt.
max_token_num: null # The max number of total tokens across the sub documents. No limit if it is None.
api_endpoint: null # URL endpoint for the API.
response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
system_prompt_template: null # System prompt template for the task. It needs to be specified with the given entity and attribute.
example_prompt: null # The example part in the system prompt.
input_template: null # The input template.
output_pattern_template: null # The output template.
try_num: 3 # The number of retry attempts when there is an API call error or output parsing error.
model_params: {} # Parameters for initializing the API model.
sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
- most_relavant_entities_aggregator: # Extract entities closely related to a given entity from some texts, and sort them in descending order of importance.
api_model: 'gpt-4o' # API model name.
entity: '孙悟空' # The given entity.
query_entity_type: '人物' # The type of the relevant entities to be queried.
input_key: null # The input field key in the samples. Supports nested keys such as "__dj__stats__.text_len". Defaults to text_key.
output_key: null # The output field key in the samples. Supports nested keys such as "__dj__stats__.text_len". Defaults to the same key as input_key.
max_token_num: null # The max number of total tokens across the sub documents. No limit if it is None.
api_endpoint: null # URL endpoint for the API.
response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
system_prompt_template: null # System prompt template for the task. It needs to be specified with the given entity and entity_type.
input_template: null # The input template.
output_pattern: null # The output pattern.
try_num: 3 # The number of retry attempts when there is an API call error or output parsing error.
model_params: {} # Parameters for initializing the API model.
sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
- nested_aggregator: # Given the limitation of input length, aggregate contents in a nested manner for each given number of samples.
api_model: 'gpt-4o' # API model name.
input_key: null # The input field key in the samples. Supports nested keys such as "__dj__stats__.text_len". Defaults to text_key.
output_key: null # The output field key in the samples. Supports nested keys such as "__dj__stats__.text_len". Defaults to the same key as input_key.
max_token_num: null # The max number of total tokens across the sub documents. No limit if it is None.
api_endpoint: null # URL endpoint for the API.
response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
system_prompt: null # The system prompt.
sub_doc_template: null # The template for input text in each sample.
input_template: null # The input template.
try_num: 3 # The number of retry attempts when there is an API call error or output parsing error.
model_params: {} # Parameters for initializing the API model.
sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
7 changes: 6 additions & 1 deletion data_juicer/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,8 +565,13 @@ def sort_op_by_types_and_names(op_name_classes):
if 'deduplicator' in name]
selector_ops = [(name, c) for (name, c) in op_name_classes
if 'selector' in name]
grouper_ops = [(name, c) for (name, c) in op_name_classes
if 'grouper' in name]
aggregator_ops = [(name, c) for (name, c) in op_name_classes
if 'aggregator' in name]
ops_sorted_by_types = sorted(mapper_ops) + sorted(filter_ops) + sorted(
deduplicator_ops) + sorted(selector_ops)
deduplicator_ops) + sorted(selector_ops) + sorted(grouper_ops) + \
sorted(aggregator_ops)
return ops_sorted_by_types


Expand Down
8 changes: 5 additions & 3 deletions data_juicer/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from . import deduplicator, filter, mapper, selector
from .base_op import (OPERATORS, UNFORKABLE, Deduplicator, Filter, Mapper,
Selector)
from . import aggregator, deduplicator, filter, grouper, mapper, selector
from .base_op import (OPERATORS, UNFORKABLE, Aggregator, Deduplicator, Filter,
Grouper, Mapper, Selector)
from .load import load_ops

__all__ = [
Expand All @@ -9,4 +9,6 @@
'Mapper',
'Deduplicator',
'Selector',
'Grouper',
'Aggregator',
]
8 changes: 8 additions & 0 deletions data_juicer/ops/aggregator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Aggregator operators package: import each aggregator class from its module
# so that it can be accessed directly from this package.
from .entity_attribute_aggregator import EntityAttributeAggregator
# NOTE: "relavant" is the spelling used by the module and class names
# themselves; it must stay in sync with those identifiers.
from .most_relavant_entities_aggregator import MostRelavantEntitiesAggregator
from .nested_aggregator import NestedAggregator

# Names exported when this package is star-imported.
__all__ = [
'NestedAggregator', 'EntityAttributeAggregator',
'MostRelavantEntitiesAggregator'
]
Loading

0 comments on commit b4811a0

Please sign in to comment.