From 4b0dddca355e17c32dec9b55f7d63dfb061ef40c Mon Sep 17 00:00:00 2001 From: Maximilian Jugl Date: Mon, 17 Jun 2024 14:12:35 +0200 Subject: [PATCH] feat: add helper function for safely building urls --- project/common.py | 19 +++++++++++ project/hub.py | 79 +++++++++++++++++++++++++++++++------------- tests/test_common.py | 8 +++++ 3 files changed, 83 insertions(+), 23 deletions(-) create mode 100644 project/common.py create mode 100644 tests/test_common.py diff --git a/project/common.py b/project/common.py new file mode 100644 index 0000000..a0cb4e5 --- /dev/null +++ b/project/common.py @@ -0,0 +1,19 @@ +import urllib.parse + + +def build_url( + scheme="", netloc="", path="", query: dict[str, str] | None = None, fragment="" +): + if query is None: + query = {} + + return urllib.parse.urlunsplit( + ( + scheme, + netloc, + path, + # square brackets must not be encoded to support central filtering stuff + urllib.parse.urlencode(query, safe="[]"), + fragment, + ), + ) diff --git a/project/hub.py b/project/hub.py index b2b0cf0..163e1af 100644 --- a/project/hub.py +++ b/project/hub.py @@ -1,14 +1,16 @@ import time +import urllib.parse from datetime import datetime from io import BytesIO from typing import TypeVar, Generic, Literal, Optional -from urllib.parse import urljoin from uuid import UUID import httpx from pydantic import BaseModel from starlette import status +from project.common import build_url + BucketType = Literal["CODE", "TEMP", "RESULT"] ResourceT = TypeVar("ResourceT") @@ -98,6 +100,13 @@ def __init__( force_acquire_on_init=False, ): self.base_url = base_url + + base_url_parts = urllib.parse.urlsplit(base_url) + + self._base_scheme = base_url_parts[0] + self._base_netloc = base_url_parts[1] + self._base_path = base_url_parts[2] + self._username = username self._password = password self._token_expiration_leeway_seconds = token_expiration_leeway_seconds @@ -107,9 +116,18 @@ def __init__( if force_acquire_on_init: self._acquire_token() + def _format_url(self, path: str, query: dict[str, str] = None): + return build_url( + self._base_scheme, + self._base_netloc, + urllib.parse.urljoin(self._base_path, path), + query, + "", + ) + def _acquire_token(self): r = httpx.post( - urljoin(self.base_url, "/token"), + self._format_url("/token"), json={ "grant_type": "password", "username": self._username, @@ -152,9 +170,24 @@ def __init__( self.base_url = base_url self.auth_client = auth_client + base_url_parts = urllib.parse.urlsplit(base_url) + + self._base_scheme = base_url_parts[0] + self._base_netloc = base_url_parts[1] + self._base_path = base_url_parts[2] + + def _format_url(self, path: str, query: dict[str, str] = None): + return build_url( + self._base_scheme, + self._base_netloc, + urllib.parse.urljoin(self._base_path, path), + query, + "", + ) + def create_project(self, name: str) -> Project: r = httpx.post( - urljoin(self.base_url, "/projects"), + self._format_url("/projects"), headers=self.auth_client.get_auth_bearer_header(), json={ "name": name, @@ -166,7 +199,7 @@ def create_project(self, name: str) -> Project: def delete_project(self, project_id: str | UUID): r = httpx.delete( - urljoin(self.base_url, f"/projects/{project_id}"), + self._format_url(f"/projects/{str(project_id)}"), headers=self.auth_client.get_auth_bearer_header(), ) @@ -174,7 +207,7 @@ def delete_project(self, project_id: str | UUID): def get_project_list(self) -> ResourceList[Project]: r = httpx.get( - urljoin(self.base_url, "/projects"), + self._format_url("/projects"), headers=self.auth_client.get_auth_bearer_header(), ) @@ -183,7 +216,7 @@ def get_project_list(self) -> ResourceList[Project]: def get_project_by_id(self, project_id: str | UUID) -> Project | None: r = httpx.get( - urljoin(self.base_url, f"/projects/{project_id}"), + self._format_url(f"/projects/{str(project_id)}"), headers=self.auth_client.get_auth_bearer_header(), ) @@ -195,7 +228,7 @@ def get_project_by_id(self, project_id: str | UUID) -> Project | None: def create_analysis(self, name: str, project_id: str | UUID) -> Analysis: r = httpx.post( - urljoin(self.base_url, "/analyses"), + self._format_url("/analyses"), headers=self.auth_client.get_auth_bearer_header(), json={ "name": name, @@ -208,7 +241,7 @@ def create_analysis(self, name: str, project_id: str | UUID) -> Analysis: def delete_analysis(self, analysis_id: str | UUID): r = httpx.delete( - urljoin(self.base_url, f"/analyses/{analysis_id}"), + self._format_url(f"/analyses/{str(analysis_id)}"), headers=self.auth_client.get_auth_bearer_header(), ) @@ -216,7 +249,7 @@ def delete_analysis(self, analysis_id: str | UUID): def get_analysis_list(self) -> ResourceList[Analysis]: r = httpx.get( - urljoin(self.base_url, "/analyses"), + self._format_url("/analyses"), headers=self.auth_client.get_auth_bearer_header(), ) @@ -225,7 +258,7 @@ def get_analysis_list(self) -> ResourceList[Analysis]: def get_analysis_by_id(self, analysis_id: str | UUID) -> Analysis | None: r = httpx.get( - urljoin(self.base_url, f"/analyses/{analysis_id}"), + self._format_url(f"/analyses/{str(analysis_id)}"), headers=self.auth_client.get_auth_bearer_header(), ) @@ -237,7 +270,7 @@ def get_analysis_by_id(self, analysis_id: str | UUID) -> Analysis | None: def get_bucket_list(self) -> ResourceList[Bucket]: r = httpx.get( - urljoin(self.base_url, "/storage/buckets"), + self._format_url("/storage/buckets"), headers=self.auth_client.get_auth_bearer_header(), ) @@ -246,7 +279,7 @@ def get_bucket_list(self) -> ResourceList[Bucket]: def get_bucket_by_id(self, bucket_id: str | UUID) -> Bucket | None: r = httpx.get( - urljoin(self.base_url, f"/storage/buckets/{bucket_id}"), + self._format_url(f"/storage/buckets/{bucket_id}"), headers=self.auth_client.get_auth_bearer_header(), ) @@ -258,7 +291,7 @@ def get_bucket_by_id(self, bucket_id: str | UUID) -> Bucket | None: def get_bucket_file_list(self) -> ResourceList[BucketFile]: r = httpx.get( - urljoin(self.base_url, "/storage/bucket-files"), + self._format_url("/storage/bucket-files"), headers=self.auth_client.get_auth_bearer_header(), ) @@ -277,7 +310,7 @@ def upload_to_bucket( file = BytesIO(file) r = httpx.post( - urljoin(self.base_url, f"/storage/buckets/{bucket_id_or_name}/upload"), + self._format_url(f"/storage/buckets/{bucket_id_or_name}/upload"), headers=self.auth_client.get_auth_bearer_header(), files={"file": (file_name, file, content_type)}, ) @@ -287,7 +320,7 @@ def upload_to_bucket( def get_analysis_bucket_file_list(self) -> ResourceList[AnalysisBucketFile]: r = httpx.get( - urljoin(self.base_url, "/analysis-bucket-files"), + self._format_url("/analysis-bucket-files"), headers=self.auth_client.get_auth_bearer_header(), ) @@ -298,12 +331,12 @@ def get_analysis_bucket( self, analysis_id: str | UUID, bucket_type: BucketType ) -> AnalysisBucket: r = httpx.get( - urljoin( - self.base_url, - "/analysis-buckets?filter[analysis_id]=" - + str(analysis_id) - + "&filter[type]=" - + str(bucket_type), + self._format_url( + "/analysis-buckets", + query={ + "filter[analysis_id]": str(analysis_id), + "filter[type]": str(bucket_type), + }, ), headers=self.auth_client.get_auth_bearer_header(), ) @@ -323,7 +356,7 @@ def link_bucket_file_to_analysis( root=True, ) -> AnalysisBucketFile: r = httpx.post( - urljoin(self.base_url, "/analysis-bucket-files"), + self._format_url("/analysis-bucket-files"), headers=self.auth_client.get_auth_bearer_header(), json={ "bucket_id": str(analysis_bucket_id), @@ -339,7 +372,7 @@ def link_bucket_file_to_analysis( def stream_bucket_file(self, bucket_file_id: str | UUID, chunk_size=1024): with httpx.stream( "GET", - urljoin(self.base_url, f"/storage/bucket-files/{bucket_file_id}/stream"), + self._format_url(f"/storage/bucket-files/{bucket_file_id}/stream"), headers=self.auth_client.get_auth_bearer_header(), ) as r: for b in r.iter_bytes(chunk_size=chunk_size): diff --git a/tests/test_common.py b/tests/test_common.py new file mode 100644 index 0000000..f1b8fb5 --- /dev/null +++ b/tests/test_common.py @@ -0,0 +1,8 @@ +from project.common import build_url + + +def test_build_url(): + assert ( + build_url("http", "privateaim.de", "analysis", {"foo": "bar"}, "baz") + == "http://privateaim.de/analysis?foo=bar#baz" + )