Skip to content

Commit

Permalink
Fix returned status code (#287)
Browse files Browse the repository at this point in the history
* fix return status code

* remove ray print rows

* update unittest
  • Loading branch information
Cathy0908 authored Apr 22, 2024
1 parent 4148016 commit 517efe1
Show file tree
Hide file tree
Showing 22 changed files with 97 additions and 26 deletions.
3 changes: 2 additions & 1 deletion data_juicer/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,8 @@ def config_backup(cfg):
target_path = os.path.join(work_dir, os.path.basename(cfg_path))
logger.info(f'Back up the input config file [{cfg_path}] into the '
f'work_dir [{work_dir}]')
shutil.copyfile(cfg_path, target_path)
if not os.path.exists(target_path):
shutil.copyfile(cfg_path, target_path)


def display_config(cfg):
Expand Down
2 changes: 1 addition & 1 deletion demos/tool_quality_classifier/quality_classifier/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from qc_utils import eval, init_spark, load_datasets


@logger.catch
@logger.catch(reraise=True)
def main(positive_datasets=None,
negative_datasets=None,
model='my_quality_model',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
prepare_model)


@logger.catch
@logger.catch(reraise=True)
def main(dataset_path,
result_path,
model='gpt3',
Expand Down
2 changes: 1 addition & 1 deletion demos/tool_quality_classifier/quality_classifier/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from qc_utils import eval, init_spark, load_datasets, shuffle, train


@logger.catch
@logger.catch(reraise=True)
def main(positive_datasets,
negative_datasets,
output_model_path='my_quality_model',
Expand Down
Empty file added tests/tools/__init__.py
Empty file.
70 changes: 70 additions & 0 deletions tests/tools/test_process_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import os
import os.path as osp
import shutil
import subprocess
import sys
import tempfile
import unittest

import yaml

from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase


class ProcessDataTest(DataJuicerTestCaseBase):
    """End-to-end checks that ``tools/process_data.py`` exits with the
    expected status code: 0 on a successful run, 1 when processing fails
    (e.g. the configured text key does not exist in the dataset).
    """

    def setUp(self):
        super().setUp()
        # mkdtemp() creates a unique directory that persists until we remove
        # it ourselves.  The original ``tempfile.TemporaryDirectory().name``
        # discarded the manager object, whose finalizer deletes the directory
        # whenever it happens to be garbage-collected — a race the old
        # ``exists``/``makedirs`` check only papered over.
        self.tmp_dir = tempfile.mkdtemp()

    def tearDown(self):
        super().tearDown()
        # Best-effort cleanup: never let a missing/locked temp dir mask the
        # real test outcome.
        shutil.rmtree(self.tmp_dir, ignore_errors=True)

    def _test_status_code(self, yaml_file, output_path, text_keys):
        """Write a minimal processing config to *yaml_file*, run the CLI
        tool on the bundled demo dataset, and return its exit status (int).

        :param yaml_file: path where the generated YAML config is written.
        :param output_path: ``export_path`` passed to the tool.
        :param text_keys: text key(s) for the config; a nonexistent key
            makes the tool fail, which is exactly what one test exercises.
        """
        # demo dataset shipped three directory levels above this test file
        data_path = osp.join(
            osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))),
            'demos', 'data', 'demo-dataset.jsonl')
        yaml_config = {
            'dataset_path': data_path,
            'text_keys': text_keys,
            'np': 2,
            'export_path': output_path,
            'process': [
                {
                    # a cheap no-arg mapper is enough to drive the pipeline
                    'clean_copyright_mapper': None
                }
            ]
        }

        with open(yaml_file, 'w') as file:
            yaml.dump(yaml_config, file)

        # Invoke the tool with an argument list and shell=False: immune to
        # shell quoting/injection issues in the temp path, and
        # sys.executable guarantees the same interpreter as the test run.
        status_code = subprocess.call(
            [sys.executable, 'tools/process_data.py', '--config', yaml_file])

        return status_code

    def test_status_code_0(self):
        """A valid text key: the tool succeeds and writes the export file."""
        tmp_yaml_file = osp.join(self.tmp_dir, 'config_0.yaml')
        tmp_out_path = osp.join(self.tmp_dir, 'output_0.json')
        text_keys = 'text'

        status_code = self._test_status_code(tmp_yaml_file, tmp_out_path,
                                             text_keys)

        self.assertEqual(status_code, 0)
        self.assertTrue(osp.exists(tmp_out_path))

    def test_status_code_1(self):
        """A nonexistent text key: the tool fails (status 1), no output."""
        tmp_yaml_file = osp.join(self.tmp_dir, 'config_1.yaml')
        tmp_out_path = osp.join(self.tmp_dir, 'output_1.json')
        text_keys = 'keys_not_exists'

        status_code = self._test_status_code(tmp_yaml_file, tmp_out_path,
                                             text_keys)

        self.assertEqual(status_code, 1)
        self.assertFalse(osp.exists(tmp_out_path))


# Allow running this test module directly: ``python tests/tools/test_process_data.py``.
if __name__ == '__main__':
    unittest.main()
2 changes: 1 addition & 1 deletion tools/analyze_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from data_juicer.core import Analyser


@logger.catch
@logger.catch(reraise=True)
def main():
analyser = Analyser()
analyser.run()
Expand Down
2 changes: 1 addition & 1 deletion tools/hpo/execute_hpo_3sigma.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from data_juicer.utils.constant import StatsKeys


@logger.catch
@logger.catch(reraise=True)
def main():

path_k_sigma_recipe = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
from data_juicer.utils.mm_utils import SpecialTokens


@logger.catch
@logger.catch(reraise=True)
def main(
dj_ds_path: str,
target_llava_ds_path: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
from data_juicer.utils.mm_utils import SpecialTokens


@logger.catch
@logger.catch(reraise=True)
def main(
dj_ds_path: str,
target_mmc4_ds_path: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
from data_juicer.utils.mm_utils import SpecialTokens


@logger.catch
@logger.catch(reraise=True)
def main(
dj_ds_path: str,
target_wavcaps_ds_path: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
from data_juicer.utils.mm_utils import SpecialTokens


@logger.catch
@logger.catch(reraise=True)
def main(
llava_ds_path: str,
target_ds_path: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
from data_juicer.utils.mm_utils import SpecialTokens


@logger.catch
@logger.catch(reraise=True)
def main(
mmc4_ds_path: str,
target_ds_path: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
convert_text_to_dj)


@logger.catch
@logger.catch(reraise=True)
def main(
video_chatgpt_ds_path: str,
target_ds_dj_path: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def get_all_files(dirname):
return result


@logger.catch
@logger.catch(reraise=True)
def main(
wavcaps_json_path: str,
wavcaps_audio_path: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
convert_text_to_dj)


@logger.catch
@logger.catch(reraise=True)
def main(
youku_ds_path: str,
target_ds_path: str,
Expand Down
4 changes: 2 additions & 2 deletions tools/preprocess/raw_arxiv_to_jsonl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from loguru import logger


@logger.catch
@logger.catch(reraise=True)
def tex_proj_loader(file_or_dir_path: pathlib.Path):
"""
Load the tex files from a tar file or a gzip file.
Expand Down Expand Up @@ -69,7 +69,7 @@ def tex_proj_loader(file_or_dir_path: pathlib.Path):
return files_and_content


@logger.catch
@logger.catch(reraise=True)
def convert_tar_to_jsonl(tar_fp, jsonl_fp, tmp_dir):
"""
Extract the contents of tex files from tar file, convert and
Expand Down
12 changes: 6 additions & 6 deletions tools/preprocess/raw_stackexchange_to_jsonl.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from tqdm import tqdm


@logger.catch
@logger.catch(reraise=True)
def get_sites_count(path, topk=28):
"""
Take top-K sites(`.xml`) by its size of content
Expand Down Expand Up @@ -57,7 +57,7 @@ def get_sites_count(path, topk=28):
return counts, sites


@logger.catch
@logger.catch(reraise=True)
def get_parents(site, counts):
"""
Find all answers's parent id, and groups by parent id
Expand Down Expand Up @@ -90,7 +90,7 @@ def get_parents(site, counts):
return parents


@logger.catch
@logger.catch(reraise=True)
def get_qapairs(site, counts, parents):
"""
Find and group all matched pairs of question and answer in site file
Expand Down Expand Up @@ -140,7 +140,7 @@ def get_qapairs(site, counts, parents):
return qa_pairs


@logger.catch
@logger.catch(reraise=True)
def process_qa_pair(pair, site_name, site_count):
"""
Sort answers by their score for question in qa pair sample,
Expand Down Expand Up @@ -171,7 +171,7 @@ def process_qa_pair(pair, site_name, site_count):
}


@logger.catch
@logger.catch(reraise=True)
def process_site(site, counts, src_dir, target_dir, num_proc=24):
"""
Convert one raw Stack Exchange site data to jsonl file.
Expand Down Expand Up @@ -207,7 +207,7 @@ def process_site(site, counts, src_dir, target_dir, num_proc=24):
f.write(json.dumps(result) + '\n')


@logger.catch
@logger.catch(reraise=True)
def main(src_dir, target_dir, topk=28, num_proc=1):
"""
Convert the raw Stack Exchange data downloaded from from Archive
Expand Down
2 changes: 1 addition & 1 deletion tools/process_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from data_juicer.core import Executor


@logger.catch
@logger.catch(reraise=True)
def main():
cfg = init_configs()
if cfg.executor_type == 'default':
Expand Down
2 changes: 1 addition & 1 deletion tools/quality_classifier/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from tools.quality_classifier.qc_utils import eval, init_spark, load_datasets


@logger.catch
@logger.catch(reraise=True)
def main(positive_datasets=None,
negative_datasets=None,
model='my_quality_model',
Expand Down
2 changes: 1 addition & 1 deletion tools/quality_classifier/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
prepare_model)


@logger.catch
@logger.catch(reraise=True)
def predict_score(dataset_path,
result_path,
model='gpt3',
Expand Down
2 changes: 1 addition & 1 deletion tools/quality_classifier/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
shuffle, train)


@logger.catch
@logger.catch(reraise=True)
def main(positive_datasets,
negative_datasets,
output_model_path='my_quality_model',
Expand Down

0 comments on commit 517efe1

Please sign in to comment.