Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace exam with training #13

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions toloka_provider/example_dags/text_classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@
" 'https://raw.githubusercontent.com/Toloka/airflow-provider-toloka/main/toloka_provider/example_dags/configs/pool.json')\n",
"\n",
" project = tlk_tasks.create_project(project_conf)\n",
" exam = tlk_tasks.create_exam_pool(exam_conf, project=project)\n",
" pool = tlk_tasks.create_pool(pool_conf, project=project, exam_pool=exam, expiration=timedelta(days=1))\n",
" exam = tlk_tasks.create_training_pool(exam_conf, project=project)\n",
" pool = tlk_tasks.create_pool(pool_conf, project=project, training_pool=exam, expiration=timedelta(days=1))\n",
"\n",
" dataset = prepare_datasets(\n",
" unlabeled_url='https://raw.githubusercontent.com/Toloka/airflow-provider-toloka/main/toloka_provider/example_dags/data/not_known.csv',\n",
Expand Down Expand Up @@ -417,8 +417,8 @@
" 'https://raw.githubusercontent.com/Toloka/airflow-provider-toloka/main/toloka_provider/example_dags/configs/pool.json')\n",
"\n",
"project = tlk_tasks.create_project(project_conf)\n",
"exam = tlk_tasks.create_exam_pool(exam_conf, project=project)\n",
"pool = tlk_tasks.create_pool(pool_conf, project=project, exam_pool=exam, expiration=timedelta(days=1))\n",
"exam = tlk_tasks.create_training_pool(exam_conf, project=project)\n",
"pool = tlk_tasks.create_pool(pool_conf, project=project, training_pool=exam, expiration=timedelta(days=1))\n",
"\n",
"dataset = prepare_datasets(\n",
" unlabeled_url='https://raw.githubusercontent.com/Toloka/airflow-provider-toloka/main/toloka_provider/example_dags/data/not_known.csv',\n",
Expand Down
4 changes: 2 additions & 2 deletions toloka_provider/example_dags/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ def aggregate_assignments(assignments):
'https://raw.githubusercontent.com/Toloka/airflow-provider-toloka/main/toloka_provider/example_dags/configs/pool.json')

project = tlk_tasks.create_project(project_conf)
exam = tlk_tasks.create_exam_pool(exam_conf, project=project)
pool = tlk_tasks.create_pool(pool_conf, project=project, exam_pool=exam, expiration=timedelta(days=1))
exam = tlk_tasks.create_training_pool(exam_conf, project=project)
pool = tlk_tasks.create_pool(pool_conf, project=project, training_pool=exam, expiration=timedelta(days=1))

dataset = prepare_datasets(
unlabeled_url='https://raw.githubusercontent.com/Toloka/airflow-provider-toloka/main/toloka_provider/example_dags/data/not_known.csv',
Expand Down
75 changes: 71 additions & 4 deletions toloka_provider/tasks/toloka.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from datetime import datetime, timedelta
from typing import Optional, Dict, List, Union
import warnings

from airflow.decorators import task

Expand Down Expand Up @@ -38,6 +39,32 @@ def create_project(
return toloka_client.create_project(obj)


@task
@serialize_if_default_xcom_backend
@add_headers('airflow')
def create_training_pool(
    obj: Union[Training, Dict, str, bytes],
    *,
    project: Union[Project, str, None] = None,
    toloka_conn_id: str = 'toloka_default',
) -> Union[Training, str]:
    """Build a `Training` from the given config and create it in Toloka.

    :param obj: A ready `Training` object, or a config that `structure_from_conf`
        can turn into one.
    :param project: Optional project to attach the training pool to. When given,
        its id (resolved via `extract_id`) overrides `project_id` on the training.
    :param toloka_conn_id: Airflow connection with toloka credentials.
    :returns: Training object if custom XCom backend is configured or its JSON
        serialized version otherwise.
    """
    # Obtain the client first, so connection problems surface before the
    # config is parsed (same ordering as the other tasks in this module).
    client = TolokaHook(toloka_conn_id=toloka_conn_id).get_conn()

    training = structure_from_conf(obj, Training)
    if project is not None:
        training.project_id = extract_id(project, Project)

    return client.create_training(training)


@task
@serialize_if_default_xcom_backend
@add_headers('airflow')
Expand All @@ -55,6 +82,11 @@ def create_exam_pool(
:returns: Training object if custom XCom backend is configured or its JSON serialized version otherwise.

"""
warnings.warn(
"""This function is deprecated. It will be deleted in v1.0.0.
Please use `toloka_provider.tasks.toloka.create_training_pool` instead""",
)

toloka_hook = TolokaHook(toloka_conn_id=toloka_conn_id)
toloka_client = toloka_hook.get_conn()

Expand All @@ -72,14 +104,16 @@ def create_pool(
*,
project: Union[Project, str, None] = None,
exam_pool: Union[Training, str, None] = None,
training_pool: Union[Training, str, None] = None,
expiration: Union[datetime, timedelta, None] = None,
toloka_conn_id: str = 'toloka_default',
) -> Union[Pool, str]:
"""Create a Pool object from given config.

:param obj: Either a `Pool` object itself or a config to make a `Pool`.
:param project: Project to assign a pool to. May pass either an object, config or project_id.
:param exam_pool: Related training pool. May pass either an object, config or pool_id.
:param exam_pool: Deprecated param, use `training_pool` instead.
:param training_pool: Related training pool. May pass either an object, config or pool_id.
:param expiration: Expiration setting. May pass any of:
* `None` if this setting is already present;
* `datetime` object to set exact datetime;
Expand All @@ -88,16 +122,21 @@ def create_pool(
:returns: Pool object if custom XCom backend is configured or its JSON serialized version otherwise.

"""
warnings.warn(
"""`exam_pool` is deprecated param. It will be deleted in v1.0.0.
Please use `training_pool` instead"""
)
toloka_hook = TolokaHook(toloka_conn_id=toloka_conn_id)
toloka_client = toloka_hook.get_conn()

obj = structure_from_conf(obj, Pool)
if project is not None:
obj.project_id = extract_id(project, Project)
if exam_pool:
training_pool = training_pool or exam_pool
if training_pool:
if obj.quality_control.training_requirement is None:
raise ValueError('pool.quality_control.training_requirement should be set before exam_pool assignment')
obj.quality_control.training_requirement.training_pool_id = extract_id(exam_pool, Training)
raise ValueError('pool.quality_control.training_requirement should be set before training_pool assignment')
obj.quality_control.training_requirement.training_pool_id = extract_id(training_pool, Training)
if expiration:
obj.will_expire = datetime.now() + expiration if isinstance(expiration, timedelta) else expiration
return toloka_client.create_pool(obj)
Expand Down Expand Up @@ -162,6 +201,29 @@ def open_pool(
return toloka_client.open_pool(pool_id)


@task
@serialize_if_default_xcom_backend
@add_headers('airflow')
def open_training_pool(
    obj: Union[Training, str],
    *,
    toloka_conn_id: str = 'toloka_default',
) -> Union[Training, str]:
    """Open given training pool.

    :param obj: Training pool_id, a `Training` object, or its config.
    :param toloka_conn_id: Airflow connection with toloka credentials.
    :returns: Training object if custom XCom backend is configured or its JSON serialized version otherwise.

    """
    toloka_hook = TolokaHook(toloka_conn_id=toloka_conn_id)
    toloka_client = toloka_hook.get_conn()

    # Normalise the input to a `Training` first, then resolve its id, so every
    # accepted input form ends up as a single id for the client call.
    training = structure_from_conf(obj, Training)
    training_id = extract_id(training, Training)
    return toloka_client.open_training(training_id)


@task
@serialize_if_default_xcom_backend
@add_headers('airflow')
Expand All @@ -177,6 +239,11 @@ def open_exam_pool(
:returns: Training object if custom XCom backend is configured or its JSON serialized version otherwise.

"""
warnings.warn(
"""This function is deprecated. It will be deleted in v1.0.0.
Please use `toloka_provider.tasks.toloka.open_training_pool` instead""",
)

toloka_hook = TolokaHook(toloka_conn_id=toloka_conn_id)
toloka_client = toloka_hook.get_conn()

Expand Down