Skip to content

Commit

Permalink
feat: zstd-compress requests over 10MB
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez committed Dec 5, 2024
1 parent 3045ac1 commit 0ab46bb
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 8 deletions.
68 changes: 64 additions & 4 deletions nbs/src/nixtla_client.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
"import orjson\n",
"import pandas as pd\n",
"import utilsforecast.processing as ufp\n",
"import zstandard as zstd\n",
"from tenacity import (\n",
" RetryCallState,\n",
" retry,\n",
Expand Down Expand Up @@ -727,7 +728,13 @@
" else:\n",
" self.supported_models = ['timegpt-1', 'timegpt-1-long-horizon']\n",
"\n",
" def _make_request(self, client: httpx.Client, endpoint: str, payload: dict[str, Any]) -> dict[str, Any]:\n",
" def _make_request(\n",
" self,\n",
" client: httpx.Client,\n",
" endpoint: str,\n",
" payload: dict[str, Any],\n",
" multithreaded_compress: bool,\n",
" ) -> dict[str, Any]:\n",
" def ensure_contiguous_arrays(d: dict[str, Any]) -> None:\n",
" for k, v in d.items():\n",
" if isinstance(v, np.ndarray):\n",
Expand All @@ -747,10 +754,15 @@
"\n",
" ensure_contiguous_arrays(payload)\n",
" content = orjson.dumps(payload, option=orjson.OPT_SERIALIZE_NUMPY)\n",
" content_size_mb = len(content) / (1024*1024)\n",
" content_size_mb = len(content) / 2**20\n",
" if content_size_mb > 200:\n",
" raise ValueError(f'The payload is too large. Set num_partitions={math.ceil(content_size_mb / 200)}')\n",
" resp = client.post(url=endpoint, content=content)\n",
" headers = {}\n",
" if content_size_mb > 10:\n",
" threads = -1 if multithreaded_compress else 0\n",
" content = zstd.ZstdCompressor(level=1, threads=threads).compress(content)\n",
" headers['content-encoding'] = 'zstd'\n",
" resp = client.post(url=endpoint, content=content, headers=headers)\n",
" try:\n",
" resp_body = orjson.loads(resp.content)\n",
" except orjson.JSONDecodeError:\n",
Expand All @@ -769,11 +781,13 @@
" client: httpx.Client,\n",
" endpoint: str,\n",
" payload: dict[str, Any],\n",
" multithreaded_compress: bool = True,\n",
" ) -> dict[str, Any]:\n",
" return self._retry_strategy(self._make_request)(\n",
" client=client,\n",
" endpoint=endpoint,\n",
" payload=payload,\n",
" multithreaded_compress=multithreaded_compress,\n",
" )\n",
"\n",
" def _make_partitioned_requests(\n",
Expand All @@ -790,7 +804,11 @@
" with ThreadPoolExecutor(max_workers) as executor:\n",
" future2pos = {\n",
" executor.submit(\n",
" self._make_request_with_retries, client, endpoint, payload\n",
" self._make_request_with_retries,\n",
" client=client,\n",
" endpoint=endpoint,\n",
" payload=payload,\n",
" multithreaded_compress=False,\n",
" ): i\n",
" for i, payload in enumerate(payloads)\n",
" }\n",
Expand Down Expand Up @@ -2249,6 +2267,48 @@
"nixtla_client.validate_api_key()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test compression\n",
"captured_request = None\n",
"\n",
"class CapturingClient(httpx.Client): \n",
" def post(self, *args, **kwargs):\n",
" request = self.build_request('POST', *args, **kwargs)\n",
" global captured_request\n",
" captured_request = {\n",
" 'headers': dict(request.headers),\n",
" 'content': request.content,\n",
" 'method': request.method,\n",
" 'url': str(request.url)\n",
" }\n",
" return super().post(*args, **kwargs)\n",
"\n",
"@contextmanager\n",
"def capture_request():\n",
" original_client = httpx.Client\n",
" httpx.Client = CapturingClient\n",
" try:\n",
" yield\n",
" finally:\n",
" httpx.Client = original_client\n",
"\n",
"# this produces a 12MB payload\n",
"series = generate_series(2_500, n_static_features=2)\n",
"with capture_request():\n",
" nixtla_client.forecast(df=series, freq='D', h=1, hist_exog_list=['static_0', 'static_1'])\n",
"\n",
"assert captured_request['headers']['content-encoding'] == 'zstd'\n",
"content = captured_request['content']\n",
"assert len(content) < 12 * 2**20\n",
"assert len(zstd.ZstdDecompressor().decompress(content)) > 12 * 2**20"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
24 changes: 20 additions & 4 deletions nixtla/nixtla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import orjson
import pandas as pd
import utilsforecast.processing as ufp
import zstandard as zstd
from tenacity import (
RetryCallState,
retry,
Expand Down Expand Up @@ -652,7 +653,11 @@ def __init__(
self.supported_models = ["timegpt-1", "timegpt-1-long-horizon"]

def _make_request(
self, client: httpx.Client, endpoint: str, payload: dict[str, Any]
self,
client: httpx.Client,
endpoint: str,
payload: dict[str, Any],
multithreaded_compress: bool,
) -> dict[str, Any]:
def ensure_contiguous_arrays(d: dict[str, Any]) -> None:
for k, v in d.items():
Expand All @@ -673,12 +678,17 @@ def ensure_contiguous_arrays(d: dict[str, Any]) -> None:

ensure_contiguous_arrays(payload)
content = orjson.dumps(payload, option=orjson.OPT_SERIALIZE_NUMPY)
content_size_mb = len(content) / (1024 * 1024)
content_size_mb = len(content) / 2**20
if content_size_mb > 200:
raise ValueError(
f"The payload is too large. Set num_partitions={math.ceil(content_size_mb / 200)}"
)
resp = client.post(url=endpoint, content=content)
headers = {}
if content_size_mb > 10:
threads = -1 if multithreaded_compress else 0
content = zstd.ZstdCompressor(level=1, threads=threads).compress(content)
headers["content-encoding"] = "zstd"
resp = client.post(url=endpoint, content=content, headers=headers)
try:
resp_body = orjson.loads(resp.content)
except orjson.JSONDecodeError:
Expand All @@ -697,11 +707,13 @@ def _make_request_with_retries(
client: httpx.Client,
endpoint: str,
payload: dict[str, Any],
multithreaded_compress: bool = True,
) -> dict[str, Any]:
return self._retry_strategy(self._make_request)(
client=client,
endpoint=endpoint,
payload=payload,
multithreaded_compress=multithreaded_compress,
)

def _make_partitioned_requests(
Expand All @@ -718,7 +730,11 @@ def _make_partitioned_requests(
with ThreadPoolExecutor(max_workers) as executor:
future2pos = {
executor.submit(
self._make_request_with_retries, client, endpoint, payload
self._make_request_with_retries,
client=client,
endpoint=endpoint,
payload=payload,
multithreaded_compress=False,
): i
for i, payload in enumerate(payloads)
}
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
"tenacity",
"tqdm",
"utilsforecast>=0.2.8",
"zstandard",
],
extras_require={
"dev": dev + plotting + date_extras,
Expand Down

0 comments on commit 0ab46bb

Please sign in to comment.