Skip to content

Commit

Permalink
Merge pull request #44 from PrivateAIM/build-url-util
Browse files Browse the repository at this point in the history
feat: add helper function for safely building urls
  • Loading branch information
mjugl authored Jun 17, 2024
2 parents d1bd544 + 4b0dddc commit 6af73a8
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 23 deletions.
19 changes: 19 additions & 0 deletions project/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import urllib.parse


def build_url(
scheme="", netloc="", path="", query: dict[str, str] | None = None, fragment=""
):
if query is None:
query = {}

return urllib.parse.urlunsplit(
(
scheme,
netloc,
path,
# square brackets must not be encoded to support central filtering stuff
urllib.parse.urlencode(query, safe="[]"),
fragment,
),
)
79 changes: 56 additions & 23 deletions project/hub.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import time
import urllib.parse
from datetime import datetime
from io import BytesIO
from typing import TypeVar, Generic, Literal, Optional
from urllib.parse import urljoin
from uuid import UUID

import httpx
from pydantic import BaseModel
from starlette import status

from project.common import build_url

BucketType = Literal["CODE", "TEMP", "RESULT"]
ResourceT = TypeVar("ResourceT")

Expand Down Expand Up @@ -98,6 +100,13 @@ def __init__(
force_acquire_on_init=False,
):
self.base_url = base_url

base_url_parts = urllib.parse.urlsplit(base_url)

self._base_scheme = base_url_parts[0]
self._base_netloc = base_url_parts[1]
self._base_path = base_url_parts[2]

self._username = username
self._password = password
self._token_expiration_leeway_seconds = token_expiration_leeway_seconds
Expand All @@ -107,9 +116,18 @@ def __init__(
if force_acquire_on_init:
self._acquire_token()

def _format_url(self, path: str, query: dict[str, str] = None):
return build_url(
self._base_scheme,
self._base_netloc,
urllib.parse.urljoin(self._base_path, path),
query,
"",
)

def _acquire_token(self):
r = httpx.post(
urljoin(self.base_url, "/token"),
self._format_url("/token"),
json={
"grant_type": "password",
"username": self._username,
Expand Down Expand Up @@ -152,9 +170,24 @@ def __init__(
self.base_url = base_url
self.auth_client = auth_client

base_url_parts = urllib.parse.urlsplit(base_url)

self._base_scheme = base_url_parts[0]
self._base_netloc = base_url_parts[1]
self._base_path = base_url_parts[2]

def _format_url(self, path: str, query: dict[str, str] = None):
return build_url(
self._base_scheme,
self._base_netloc,
urllib.parse.urljoin(self._base_path, path),
query,
"",
)

def create_project(self, name: str) -> Project:
r = httpx.post(
urljoin(self.base_url, "/projects"),
self._format_url("/projects"),
headers=self.auth_client.get_auth_bearer_header(),
json={
"name": name,
Expand All @@ -166,15 +199,15 @@ def create_project(self, name: str) -> Project:

def delete_project(self, project_id: str | UUID):
r = httpx.delete(
urljoin(self.base_url, f"/projects/{project_id}"),
self._format_url(f"/projects/{str(project_id)}"),
headers=self.auth_client.get_auth_bearer_header(),
)

r.raise_for_status()

def get_project_list(self) -> ResourceList[Project]:
r = httpx.get(
urljoin(self.base_url, "/projects"),
self._format_url("/projects"),
headers=self.auth_client.get_auth_bearer_header(),
)

Expand All @@ -183,7 +216,7 @@ def get_project_list(self) -> ResourceList[Project]:

def get_project_by_id(self, project_id: str | UUID) -> Project | None:
r = httpx.get(
urljoin(self.base_url, f"/projects/{project_id}"),
self._format_url(f"/projects/{str(project_id)}"),
headers=self.auth_client.get_auth_bearer_header(),
)

Expand All @@ -195,7 +228,7 @@ def get_project_by_id(self, project_id: str | UUID) -> Project | None:

def create_analysis(self, name: str, project_id: str | UUID) -> Analysis:
r = httpx.post(
urljoin(self.base_url, "/analyses"),
self._format_url("/analyses"),
headers=self.auth_client.get_auth_bearer_header(),
json={
"name": name,
Expand All @@ -208,15 +241,15 @@ def create_analysis(self, name: str, project_id: str | UUID) -> Analysis:

def delete_analysis(self, analysis_id: str | UUID):
r = httpx.delete(
urljoin(self.base_url, f"/analyses/{analysis_id}"),
self._format_url(f"/analyses/{str(analysis_id)}"),
headers=self.auth_client.get_auth_bearer_header(),
)

r.raise_for_status()

def get_analysis_list(self) -> ResourceList[Analysis]:
r = httpx.get(
urljoin(self.base_url, "/analyses"),
self._format_url("/analyses"),
headers=self.auth_client.get_auth_bearer_header(),
)

Expand All @@ -225,7 +258,7 @@ def get_analysis_list(self) -> ResourceList[Analysis]:

def get_analysis_by_id(self, analysis_id: str | UUID) -> Analysis | None:
r = httpx.get(
urljoin(self.base_url, f"/analyses/{analysis_id}"),
self._format_url(f"/analyses/{str(analysis_id)}"),
headers=self.auth_client.get_auth_bearer_header(),
)

Expand All @@ -237,7 +270,7 @@ def get_analysis_by_id(self, analysis_id: str | UUID) -> Analysis | None:

def get_bucket_list(self) -> ResourceList[Bucket]:
r = httpx.get(
urljoin(self.base_url, "/storage/buckets"),
self._format_url("/storage/buckets"),
headers=self.auth_client.get_auth_bearer_header(),
)

Expand All @@ -246,7 +279,7 @@ def get_bucket_list(self) -> ResourceList[Bucket]:

def get_bucket_by_id(self, bucket_id: str | UUID) -> Bucket | None:
r = httpx.get(
urljoin(self.base_url, f"/storage/buckets/{bucket_id}"),
self._format_url(f"/storage/buckets/{bucket_id}"),
headers=self.auth_client.get_auth_bearer_header(),
)

Expand All @@ -258,7 +291,7 @@ def get_bucket_by_id(self, bucket_id: str | UUID) -> Bucket | None:

def get_bucket_file_list(self) -> ResourceList[BucketFile]:
r = httpx.get(
urljoin(self.base_url, "/storage/bucket-files"),
self._format_url("/storage/bucket-files"),
headers=self.auth_client.get_auth_bearer_header(),
)

Expand All @@ -277,7 +310,7 @@ def upload_to_bucket(
file = BytesIO(file)

r = httpx.post(
urljoin(self.base_url, f"/storage/buckets/{bucket_id_or_name}/upload"),
self._format_url(f"/storage/buckets/{bucket_id_or_name}/upload"),
headers=self.auth_client.get_auth_bearer_header(),
files={"file": (file_name, file, content_type)},
)
Expand All @@ -287,7 +320,7 @@ def upload_to_bucket(

def get_analysis_bucket_file_list(self) -> ResourceList[AnalysisBucketFile]:
r = httpx.get(
urljoin(self.base_url, "/analysis-bucket-files"),
self._format_url("/analysis-bucket-files"),
headers=self.auth_client.get_auth_bearer_header(),
)

Expand All @@ -298,12 +331,12 @@ def get_analysis_bucket(
self, analysis_id: str | UUID, bucket_type: BucketType
) -> AnalysisBucket:
r = httpx.get(
urljoin(
self.base_url,
"/analysis-buckets?filter[analysis_id]="
+ str(analysis_id)
+ "&filter[type]="
+ str(bucket_type),
self._format_url(
"/analysis-buckets",
query={
"filter[analysis_id]": str(analysis_id),
"filter[type]": str(bucket_type),
},
),
headers=self.auth_client.get_auth_bearer_header(),
)
Expand All @@ -323,7 +356,7 @@ def link_bucket_file_to_analysis(
root=True,
) -> AnalysisBucketFile:
r = httpx.post(
urljoin(self.base_url, "/analysis-bucket-files"),
self._format_url("/analysis-bucket-files"),
headers=self.auth_client.get_auth_bearer_header(),
json={
"bucket_id": str(analysis_bucket_id),
Expand All @@ -339,7 +372,7 @@ def link_bucket_file_to_analysis(
def stream_bucket_file(self, bucket_file_id: str | UUID, chunk_size=1024):
with httpx.stream(
"GET",
urljoin(self.base_url, f"/storage/bucket-files/{bucket_file_id}/stream"),
self._format_url(f"/storage/bucket-files/{bucket_file_id}/stream"),
headers=self.auth_client.get_auth_bearer_header(),
) as r:
for b in r.iter_bytes(chunk_size=chunk_size):
Expand Down
8 changes: 8 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from project.common import build_url


def test_build_url():
assert (
build_url("http", "privateaim.de", "analysis", {"foo": "bar"}, "baz")
== "http://privateaim.de/analysis?foo=bar#baz"
)

0 comments on commit 6af73a8

Please sign in to comment.