Skip to content

Commit

Permalink
Merge branch 'master' into add-getters-for-pricing-plan-and-unit
Browse files Browse the repository at this point in the history
  • Loading branch information
bisgaard-itis authored Oct 18, 2023
2 parents 50c053d + faaf530 commit 9bed7a8
Show file tree
Hide file tree
Showing 54 changed files with 1,739 additions and 1,204 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""adding pricing and harware info to comp_tasks
Revision ID: 3810966d1534
Revises: 5c62b190e124
Create Date: 2023-10-17 14:35:21.032940+00:00
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "3810966d1534"
down_revision = "5c62b190e124"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"comp_tasks",
sa.Column(
"pricing_info", postgresql.JSONB(astext_type=sa.Text()), nullable=True
),
)
op.add_column(
"comp_tasks",
sa.Column(
"hardware_info", postgresql.JSONB(astext_type=sa.Text()), nullable=True
),
)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("comp_tasks", "hardware_info")
op.drop_column("comp_tasks", "pricing_info")
# ### end Alembic commands ###
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,19 @@ class NodeClass(enum.Enum):
),
column_created_datetime(timezone=True),
column_modified_datetime(timezone=True),
sa.Column(
"pricing_info",
postgresql.JSONB,
nullable=True,
doc="Billing information of this task",
),
sa.Column(
"hardware_info",
postgresql.JSONB,
nullable=True,
doc="Harware information of this task",
),
# ------
sa.UniqueConstraint("project_id", "node_id", name="project_node_uniqueness"),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

import sqlalchemy
from aiopg.sa.connection import SAConnection
from simcore_postgres_database.models.projects_node_to_pricing_unit import (
projects_node_to_pricing_unit,
)
from sqlalchemy.dialects.postgresql import insert as pg_insert

from .errors import ForeignKeyViolation, UniqueViolation
from .models.projects_nodes import projects_nodes
Expand Down Expand Up @@ -193,3 +197,67 @@ async def delete(self, connection: SAConnection, *, node_id: uuid.UUID) -> None:
& (projects_nodes.c.node_id == f"{node_id}")
)
await connection.execute(delete_stmt)

async def get_project_node_pricing_unit_id(
self, connection: SAConnection, *, node_uuid: uuid.UUID
) -> tuple | None:
"""get a pricing unit that is connected to the project node or None if there is non connected
NOTE: Do not use this in an asyncio.gather call as this will fail!
"""
result = await connection.execute(
sqlalchemy.select(
projects_node_to_pricing_unit.c.pricing_plan_id,
projects_node_to_pricing_unit.c.pricing_unit_id,
)
.select_from(
projects_nodes.join(
projects_node_to_pricing_unit,
projects_nodes.c.project_node_id
== projects_node_to_pricing_unit.c.project_node_id,
)
)
.where(
(projects_nodes.c.project_uuid == f"{self.project_uuid}")
& (projects_nodes.c.node_id == f"{node_uuid}")
)
)
row = await result.fetchone()
if row:
return (row[0], row[1])
return None

async def connect_pricing_unit_to_project_node(
self,
connection: SAConnection,
*,
node_uuid: uuid.UUID,
pricing_plan_id: int,
pricing_unit_id: int,
) -> None:
result = await connection.scalar(
sqlalchemy.select(projects_nodes.c.project_node_id).where(
(projects_nodes.c.project_uuid == f"{self.project_uuid}")
& (projects_nodes.c.node_id == f"{node_uuid}")
)
)
project_node_id = int(result) if result else 0

insert_stmt = pg_insert(projects_node_to_pricing_unit).values(
project_node_id=project_node_id,
pricing_plan_id=pricing_plan_id,
pricing_unit_id=pricing_unit_id,
created=sqlalchemy.func.now(),
modified=sqlalchemy.func.now(),
)
on_update_stmt = insert_stmt.on_conflict_do_update(
index_elements=[
projects_node_to_pricing_unit.c.project_node_id,
],
set_={
"pricing_plan_id": insert_stmt.excluded.pricing_plan_id,
"pricing_unit_id": insert_stmt.excluded.pricing_unit_id,
"modified": sqlalchemy.func.now(),
},
)
await connection.execute(on_update_stmt)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from fastapi import Request

from ...modules.resource_usage_tracker_client import ResourceUsageTrackerClient


def get_rut_client(request: Request) -> ResourceUsageTrackerClient:
return ResourceUsageTrackerClient.get_from_state(request.app)
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from models_library.clusters import DEFAULT_CLUSTER_ID
from models_library.projects import ProjectAtDB, ProjectID
from models_library.projects_nodes_io import NodeID
from models_library.services import ServiceKey, ServiceKeyVersion, ServiceVersion
from models_library.services import ServiceKeyVersion
from models_library.users import UserID
from models_library.utils.fastapi_encoders import jsonable_encoder
from pydantic import AnyHttpUrl, parse_obj_as
Expand All @@ -47,6 +47,7 @@
ClusterAccessForbiddenError,
ClusterNotFoundError,
ComputationalRunNotFoundError,
PricingPlanUnitNotFoundError,
ProjectNotFoundError,
SchedulerError,
)
Expand All @@ -62,7 +63,7 @@
from ...modules.db.repositories.projects import ProjectsRepository
from ...modules.db.repositories.users import UsersRepository
from ...modules.director_v0 import DirectorV0Client
from ...modules.resource_usage_client import ResourceUsageApi
from ...modules.resource_usage_tracker_client import ResourceUsageTrackerClient
from ...utils.computations import (
find_deprecated_tasks,
get_pipeline_state_from_task_states,
Expand All @@ -82,6 +83,7 @@
from ..dependencies.catalog import get_catalog_client
from ..dependencies.database import get_repository
from ..dependencies.director_v0 import get_director_v0_client
from ..dependencies.rut_client import get_rut_client
from ..dependencies.scheduler import get_scheduler
from .computations_tasks import analyze_pipeline

Expand Down Expand Up @@ -122,6 +124,7 @@ async def create_computation( # noqa: C901, PLR0912
scheduler: Annotated[BaseCompScheduler, Depends(get_scheduler)],
catalog_client: Annotated[CatalogClient, Depends(get_catalog_client)],
users_repo: Annotated[UsersRepository, Depends(get_repository(UsersRepository))],
rut_client: Annotated[ResourceUsageTrackerClient, Depends(get_rut_client)],
) -> ComputationGet:
log.debug(
"User %s is creating a new computation from project %s",
Expand Down Expand Up @@ -205,6 +208,8 @@ async def create_computation( # noqa: C901, PLR0912
published_nodes=min_computation_nodes if computation.start_pipeline else [],
user_id=computation.user_id,
product_name=computation.product_name,
rut_client=rut_client,
is_wallet=bool(computation.wallet_info),
)

if computation.start_pipeline:
Expand All @@ -225,25 +230,10 @@ async def create_computation( # noqa: C901, PLR0912
# Billing info
wallet_id = None
wallet_name = None
pricing_plan_id = None
pricing_unit_id = None
pricing_unit_cost_id = None
if computation.wallet_info:
wallet_id = computation.wallet_info.wallet_id
wallet_name = computation.wallet_info.wallet_name

resource_usage_api = ResourceUsageApi.get_from_state(request.app)
# NOTE: MD/SAN -> add real service version/key and store in DB, issue: https://github.com/ITISFoundation/osparc-issues/issues/1131
(
pricing_plan_id,
pricing_unit_id,
pricing_unit_cost_id,
) = await resource_usage_api.get_default_service_pricing_plan_and_pricing_unit(
computation.product_name,
ServiceKey("simcore/services/comp/itis/sleeper"),
ServiceVersion("2.1.6"),
)

await scheduler.run_new_pipeline(
computation.user_id,
computation.project_id,
Expand All @@ -259,9 +249,6 @@ async def create_computation( # noqa: C901, PLR0912
user_email=await users_repo.get_user_email(computation.user_id),
wallet_id=wallet_id,
wallet_name=wallet_name,
pricing_plan_id=pricing_plan_id,
pricing_unit_id=pricing_unit_id,
pricing_unit_cost_id=pricing_unit_cost_id,
),
use_on_demand_clusters=computation.use_on_demand_clusters,
)
Expand Down Expand Up @@ -317,6 +304,8 @@ async def create_computation( # noqa: C901, PLR0912
raise HTTPException(
status_code=status.HTTP_406_NOT_ACCEPTABLE, detail=f"{e}"
) from e
except PricingPlanUnitNotFoundError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"{e}") from e


@router.get(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
osparc_variables_substitutions,
rabbitmq,
remote_debug,
resource_usage_client,
resource_usage_tracker_client,
storage,
)
from .errors import (
Expand Down Expand Up @@ -176,7 +176,7 @@ def init_app(settings: AppSettings | None = None) -> FastAPI:
comp_scheduler.setup(app)

if settings.DIRECTOR_V2_RESOURCE_USAGE_TRACKER:
resource_usage_client.setup(app)
resource_usage_tracker_client.setup(app)

node_rights.setup(app)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,13 @@ def __init__(self, project_id: ProjectID):
super().__init__(f"project {project_id} not found")


class PricingPlanUnitNotFoundError(DirectorException):
"""Pricing plan unit not found error"""

def __init__(self, msg: str):
super().__init__(msg)


class PipelineNotFoundError(DirectorException):
"""Pipeline not found error"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ class RunMetadataDict(TypedDict, total=False):
user_email: str
wallet_id: int | None
wallet_name: str | None
pricing_plan_id: int | None
pricing_unit_id: int | None
pricing_unit_cost_id: int | None


class CompRunsAtDB(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ class CompTaskAtDB(BaseModel):
)
created: datetime.datetime
modified: datetime.datetime
# Additional information about price and hardware (ex. AWS EC2 instance type)
pricing_info: dict | None
hardware_info: dict | None

@validator("state", pre=True)
@classmethod
Expand Down Expand Up @@ -214,6 +217,12 @@ class Config:
"last_heartbeat": None,
"created": "2022-05-20 13:28:31.139+00",
"modified": "2023-06-23 15:58:32.833081+00",
"pricing_info": {
"pricing_plan_id": 1,
"pricing_unit_id": 1,
"pricing_unit_cost_id": 1,
},
"hardware_info": {"aws_ec2_instance": ["aws-specific-instance"]},
}
for image_example in Image.Config.schema_extra["examples"]
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -380,9 +380,15 @@ async def _process_started_tasks(
),
wallet_id=run_metadata.get("wallet_id"),
wallet_name=run_metadata.get("wallet_name"),
pricing_plan_id=run_metadata.get("pricing_plan_id"),
pricing_unit_id=run_metadata.get("pricing_unit_id"),
pricing_unit_cost_id=run_metadata.get("pricing_unit_cost_id"),
pricing_plan_id=t.pricing_info.get("pricing_plan_id")
if t.pricing_info
else None,
pricing_unit_id=t.pricing_info.get("pricing_unit_id")
if t.pricing_info
else None,
pricing_unit_cost_id=t.pricing_info.get("pricing_unit_cost_id")
if t.pricing_info
else None,
product_name=run_metadata.get(
"product_name", UNDEFINED_STR_METADATA
),
Expand Down
Loading

0 comments on commit 9bed7a8

Please sign in to comment.