diff --git a/toloka_provider/example_dags/text_classification.ipynb b/toloka_provider/example_dags/text_classification.ipynb index 4e4add6..9ebc421 100644 --- a/toloka_provider/example_dags/text_classification.ipynb +++ b/toloka_provider/example_dags/text_classification.ipynb @@ -222,8 +222,8 @@ " 'https://raw.githubusercontent.com/Toloka/airflow-provider-toloka/main/toloka_provider/example_dags/configs/pool.json')\n", "\n", " project = tlk_tasks.create_project(project_conf)\n", - " exam = tlk_tasks.create_exam_pool(exam_conf, project=project)\n", - " pool = tlk_tasks.create_pool(pool_conf, project=project, exam_pool=exam, expiration=timedelta(days=1))\n", + " exam = tlk_tasks.create_training_pool(exam_conf, project=project)\n", + " pool = tlk_tasks.create_pool(pool_conf, project=project, training_pool=exam, expiration=timedelta(days=1))\n", "\n", " dataset = prepare_datasets(\n", " unlabeled_url='https://raw.githubusercontent.com/Toloka/airflow-provider-toloka/main/toloka_provider/example_dags/data/not_known.csv',\n", @@ -417,8 +417,8 @@ " 'https://raw.githubusercontent.com/Toloka/airflow-provider-toloka/main/toloka_provider/example_dags/configs/pool.json')\n", "\n", "project = tlk_tasks.create_project(project_conf)\n", - "exam = tlk_tasks.create_exam_pool(exam_conf, project=project)\n", - "pool = tlk_tasks.create_pool(pool_conf, project=project, exam_pool=exam, expiration=timedelta(days=1))\n", + "exam = tlk_tasks.create_training_pool(exam_conf, project=project)\n", + "pool = tlk_tasks.create_pool(pool_conf, project=project, training_pool=exam, expiration=timedelta(days=1))\n", "\n", "dataset = prepare_datasets(\n", " unlabeled_url='https://raw.githubusercontent.com/Toloka/airflow-provider-toloka/main/toloka_provider/example_dags/data/not_known.csv',\n", diff --git a/toloka_provider/example_dags/text_classification.py b/toloka_provider/example_dags/text_classification.py index 195f55f..9d866f6 100644 --- a/toloka_provider/example_dags/text_classification.py +++ b/toloka_provider/example_dags/text_classification.py @@ -114,8 +114,8 @@ def aggregate_assignments(assignments): 'https://raw.githubusercontent.com/Toloka/airflow-provider-toloka/main/toloka_provider/example_dags/configs/pool.json') project = tlk_tasks.create_project(project_conf) - exam = tlk_tasks.create_exam_pool(exam_conf, project=project) - pool = tlk_tasks.create_pool(pool_conf, project=project, exam_pool=exam, expiration=timedelta(days=1)) + exam = tlk_tasks.create_training_pool(exam_conf, project=project) + pool = tlk_tasks.create_pool(pool_conf, project=project, training_pool=exam, expiration=timedelta(days=1)) dataset = prepare_datasets( unlabeled_url='https://raw.githubusercontent.com/Toloka/airflow-provider-toloka/main/toloka_provider/example_dags/data/not_known.csv', diff --git a/toloka_provider/tasks/toloka.py b/toloka_provider/tasks/toloka.py index cace0a3..4bfa4ef 100644 --- a/toloka_provider/tasks/toloka.py +++ b/toloka_provider/tasks/toloka.py @@ -4,6 +4,7 @@ import logging from datetime import datetime, timedelta from typing import Optional, Dict, List, Union +import warnings from airflow.decorators import task @@ -38,6 +39,32 @@ def create_project( return toloka_client.create_project(obj) +@task +@serialize_if_default_xcom_backend +@add_headers('airflow') +def create_training_pool( + obj: Union[Training, Dict, str, bytes], + *, + project: Union[Project, str, None] = None, + toloka_conn_id: str = 'toloka_default', +) -> Union[Training, str]: + """Create a Training pool object from given config. + + :param obj: Either a `Training` object itself or a config to make a `Training`. + :param project: Project to assign a training pool to. May pass either an object, config or project_id. + :param toloka_conn_id: Airflow connection with toloka credentials. + :returns: Training object if custom XCom backend is configured or its JSON serialized version otherwise. + + """ + toloka_hook = TolokaHook(toloka_conn_id=toloka_conn_id) + toloka_client = toloka_hook.get_conn() + + obj = structure_from_conf(obj, Training) + if project is not None: + obj.project_id = extract_id(project, Project) + return toloka_client.create_training(obj) + + @task @serialize_if_default_xcom_backend @add_headers('airflow') @@ -55,6 +82,11 @@ def create_exam_pool( :returns: Training object if custom XCom backend is configured or its JSON serialized version otherwise. """ + warnings.warn( + """This function is deprecated. It will be deleted in v1.0.0. + Please use `toloka_provider.tasks.toloka.create_training_pool` instead""", + ) + toloka_hook = TolokaHook(toloka_conn_id=toloka_conn_id) toloka_client = toloka_hook.get_conn() @@ -72,6 +104,7 @@ def create_pool( *, project: Union[Project, str, None] = None, exam_pool: Union[Training, str, None] = None, + training_pool: Union[Training, str, None] = None, expiration: Union[datetime, timedelta, None] = None, toloka_conn_id: str = 'toloka_default', ) -> Union[Pool, str]: @@ -79,7 +112,8 @@ def create_pool( :param obj: Either a `Pool` object itself or a config to make a `Pool`. :param project: Project to assign a pool to. May pass either an object, config or project_id. - :param exam_pool: Related training pool. May pass either an object, config or pool_id. + :param exam_pool: Deprecated param, use `training_pool` instead. + :param training_pool: Related training pool. May pass either an object, config or pool_id. :param expiration: Expiration setting. May pass any of: * `None` if this setting is already present; * `datetime` object to set exact datetime; @@ -88,16 +122,21 @@ def create_pool( :returns: Pool object if custom XCom backend is configured or its JSON serialized version otherwise. """ + warnings.warn( + """`exam_pool` is deprecated param. It will be deleted in v1.0.0. + Please use `training_pool` instead""" + ) toloka_hook = TolokaHook(toloka_conn_id=toloka_conn_id) toloka_client = toloka_hook.get_conn() obj = structure_from_conf(obj, Pool) if project is not None: obj.project_id = extract_id(project, Project) - if exam_pool: + training_pool = training_pool or exam_pool + if training_pool: if obj.quality_control.training_requirement is None: - raise ValueError('pool.quality_control.training_requirement should be set before exam_pool assignment') - obj.quality_control.training_requirement.training_pool_id = extract_id(exam_pool, Training) + raise ValueError('pool.quality_control.training_requirement should be set before training_pool assignment') + obj.quality_control.training_requirement.training_pool_id = extract_id(training_pool, Training) if expiration: obj.will_expire = datetime.now() + expiration if isinstance(expiration, timedelta) else expiration return toloka_client.create_pool(obj) @@ -162,6 +201,29 @@ def open_pool( return toloka_client.open_pool(pool_id) +@task +@serialize_if_default_xcom_backend +@add_headers('airflow') +def open_training_pool( + obj: Union[Training, str], + *, + toloka_conn_id: str = 'toloka_default', +) -> Union[Pool, str]: + """Open given training pool. + + :param obj: Training pool_id or `Training` object of it's config. + :param toloka_conn_id: Airflow connection with toloka credentials. + :returns: Training object if custom XCom backend is configured or its JSON serialized version otherwise. + + """ + toloka_hook = TolokaHook(toloka_conn_id=toloka_conn_id) + toloka_client = toloka_hook.get_conn() + + training = structure_from_conf(obj, Training) + training_id = extract_id(training, Training) + return toloka_client.open_training(training_id) + + @task @serialize_if_default_xcom_backend @add_headers('airflow') @@ -177,6 +239,11 @@ def open_exam_pool( :returns: Training object if custom XCom backend is configured or its JSON serialized version otherwise. """ + warnings.warn( + """This function is deprecated. It will be deleted in v1.0.0. + Please use `toloka_provider.tasks.toloka.open_training_pool` instead""", + ) + toloka_hook = TolokaHook(toloka_conn_id=toloka_conn_id) toloka_client = toloka_hook.get_conn()