Skip to content

Commit

Permalink
first attempt at adding pricing plan and unit when starting job
Browse files Browse the repository at this point in the history
  • Loading branch information
bisgaard-itis committed Oct 18, 2023
1 parent b02f9f8 commit f05c481
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Annotated, Final
from uuid import UUID

from fastapi import APIRouter, Depends, status
from fastapi import APIRouter, Depends, Request, status
from fastapi.exceptions import HTTPException
from fastapi.responses import RedirectResponse
from fastapi_pagination.api import create_page
Expand All @@ -15,6 +15,7 @@
from models_library.api_schemas_webserver.wallets import WalletGet
from models_library.clusters import ClusterID
from models_library.projects_nodes_io import BaseFileLink
from pydantic import ValidationError, parse_obj_as
from pydantic.types import PositiveInt
from servicelib.logging_utils import log_context

Expand All @@ -31,6 +32,7 @@
JobMetadata,
JobMetadataUpdate,
JobOutputs,
JobPricingSpecification,
JobStatus,
)
from ...models.schemas.solvers import Solver, SolverKeyId
Expand Down Expand Up @@ -78,6 +80,13 @@ def _raise_if_job_not_associated_with_solver(
)


def _get_pricing_plan_and_unit(request: Request) -> JobPricingSpecification | None:
try:
return parse_obj_as(JobPricingSpecification, request.headers)
except ValidationError:
return None


# JOBS ---------------
#
# - Similar to docker container's API design (container = job and image = solver)
Expand Down Expand Up @@ -283,11 +292,13 @@ async def delete_job(
response_model=JobStatus,
)
async def start_job(
request: Request,
solver_key: SolverKeyId,
version: VersionStr,
job_id: JobID,
user_id: Annotated[PositiveInt, Depends(get_current_user_id)],
director2_api: Annotated[DirectorV2Api, Depends(get_api_client(DirectorV2Api))],
webserver_api: Annotated[AuthSession, Depends(get_webserver_session)],
product_name: Annotated[str, Depends(get_product_name)],
groups_extra_properties_repository: Annotated[
GroupsExtraPropertiesRepository,
Expand All @@ -303,15 +314,28 @@ async def start_job(
job_name = _compose_job_resource_name(solver_key, version, job_id)
_logger.debug("Start Job '%s'", job_name)

task = await director2_api.start_computation(
project_id=job_id,
user_id=user_id,
product_name=product_name,
cluster_id=cluster_id,
groups_extra_properties_repository=groups_extra_properties_repository,
)
job_status: JobStatus = create_jobstatus_from_task(task)
return job_status
if pricing_spec := _get_pricing_plan_and_unit(request):
with log_context(_logger, logging.DEBUG, "Set pricing plan and unit"):
project: ProjectGet = await webserver_api.get_project(project_id=job_id)
node_ids = list(project.workbench.keys())
assert len(node_ids) == 1 # nosec
await webserver_api.put_project_node_pricing_plan_and_unit(
project_id=job_id,
node_id=UUID(node_ids[0]),
pricing_plan=pricing_spec.pricing_plan,
pricing_unit=pricing_spec.pricing_unit,
)

with log_context(_logger, logging.DEBUG, "Starting job"):
task = await director2_api.start_computation(
project_id=job_id,
user_id=user_id,
product_name=product_name,
cluster_id=cluster_id,
groups_extra_properties_repository=groups_extra_properties_repository,
)
job_status: JobStatus = create_jobstatus_from_task(task)
return job_status


@router.post(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
from pydantic import (
BaseModel,
ConstrainedInt,
Extra,
Field,
HttpUrl,
PositiveInt,
StrictBool,
StrictFloat,
StrictInt,
Expand Down Expand Up @@ -276,3 +278,11 @@ class Config(BaseConfig):
"stopped_at": None,
}
}


class JobPricingSpecification(BaseModel):
pricing_plan: PositiveInt
pricing_unit: PositiveInt

class Config:
extra = Extra.ignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from models_library.projects import ProjectID
from models_library.rest_pagination import Page
from models_library.utils.fastapi_encoders import jsonable_encoder
from pydantic import ValidationError
from pydantic import PositiveInt, ValidationError
from pydantic.errors import PydanticErrorMixin
from servicelib.aiohttp.long_running_tasks.server import TaskStatus
from servicelib.error_codes import create_error_code
Expand Down Expand Up @@ -392,6 +392,20 @@ async def get_project_node_pricing_unit(
data = Envelope[PricingUnitGet].parse_raw(response.text).data
return data

async def put_project_node_pricing_plan_and_unit(
self,
project_id: UUID,
node_id: UUID,
pricing_plan: PositiveInt,
pricing_unit: PositiveInt,
) -> None:
with _handle_webserver_api_errors():
response = await self.client.put(
f"/projects/{project_id}/nodes/{node_id}/pricing-plans/{pricing_plan}/pricing-units/{pricing_unit}",
cookies=self.session_cookies,
)
response.raise_for_status()

# WALLETS -------------------------------------------------

async def get_wallet(self, wallet_id: int) -> WalletGet:
Expand Down

0 comments on commit f05c481

Please sign in to comment.