diff --git a/nbs/src/nixtla_client.ipynb b/nbs/src/nixtla_client.ipynb index 8794da15..2e2f324b 100644 --- a/nbs/src/nixtla_client.ipynb +++ b/nbs/src/nixtla_client.ipynb @@ -58,6 +58,7 @@ "import orjson\n", "import pandas as pd\n", "import utilsforecast.processing as ufp\n", + "import zstandard as zstd\n", "from tenacity import (\n", " RetryCallState,\n", " retry,\n", @@ -727,7 +728,13 @@ " else:\n", " self.supported_models = ['timegpt-1', 'timegpt-1-long-horizon']\n", "\n", - " def _make_request(self, client: httpx.Client, endpoint: str, payload: dict[str, Any]) -> dict[str, Any]:\n", + " def _make_request(\n", + " self,\n", + " client: httpx.Client,\n", + " endpoint: str,\n", + " payload: dict[str, Any],\n", + " multithreaded_compress: bool,\n", + " ) -> dict[str, Any]:\n", " def ensure_contiguous_arrays(d: dict[str, Any]) -> None:\n", " for k, v in d.items():\n", " if isinstance(v, np.ndarray):\n", @@ -747,10 +754,15 @@ "\n", " ensure_contiguous_arrays(payload)\n", " content = orjson.dumps(payload, option=orjson.OPT_SERIALIZE_NUMPY)\n", - " content_size_mb = len(content) / (1024*1024)\n", + " content_size_mb = len(content) / 2**20\n", " if content_size_mb > 200:\n", " raise ValueError(f'The payload is too large. Set num_partitions={math.ceil(content_size_mb / 200)}')\n", - " resp = client.post(url=endpoint, content=content)\n", + " headers = {}\n", + " if content_size_mb > 10:\n", + " threads = -1 if multithreaded_compress else 0\n", + " content = zstd.ZstdCompressor(level=1, threads=threads).compress(content)\n", + " headers['content-encoding'] = 'zstd'\n", + " resp = client.post(url=endpoint, content=content, headers=headers)\n", " try:\n", " resp_body = orjson.loads(resp.content)\n", " except orjson.JSONDecodeError:\n", @@ -769,11 +781,13 @@ " client: httpx.Client,\n", " endpoint: str,\n", " payload: dict[str, Any],\n", + " multithreaded_compress: bool = True,\n", " ) -> dict[str, Any]:\n", " return self._retry_strategy(self._make_request)(\n", " client=client,\n", " endpoint=endpoint,\n", " payload=payload,\n", + " multithreaded_compress=multithreaded_compress,\n", " )\n", "\n", " def _make_partitioned_requests(\n", @@ -790,7 +804,11 @@ " with ThreadPoolExecutor(max_workers) as executor:\n", " future2pos = {\n", " executor.submit(\n", - " self._make_request_with_retries, client, endpoint, payload\n", + " self._make_request_with_retries,\n", + " client=client,\n", + " endpoint=endpoint,\n", + " payload=payload,\n", + " multithreaded_compress=False,\n", " ): i\n", " for i, payload in enumerate(payloads)\n", " }\n", @@ -2249,6 +2267,48 @@ "nixtla_client.validate_api_key()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "# test compression\n", + "captured_request = None\n", + "\n", + "class CapturingClient(httpx.Client): \n", + " def post(self, *args, **kwargs):\n", + " request = self.build_request('POST', *args, **kwargs)\n", + " global captured_request\n", + " captured_request = {\n", + " 'headers': dict(request.headers),\n", + " 'content': request.content,\n", + " 'method': request.method,\n", + " 'url': str(request.url)\n", + " }\n", + " return super().post(*args, **kwargs)\n", + "\n", + "@contextmanager\n", + "def capture_request():\n", + " original_client = httpx.Client\n", + " httpx.Client = CapturingClient\n", + " try:\n", + " yield\n", + " finally:\n", + " httpx.Client = original_client\n", + "\n", + "# this produces a 12MB payload\n", + "series = generate_series(2_500, n_static_features=2)\n", + "with capture_request():\n", + " nixtla_client.forecast(df=series, freq='D', h=1, hist_exog_list=['static_0', 'static_1'])\n", + "\n", + "assert captured_request['headers']['content-encoding'] == 'zstd'\n", + "content = captured_request['content']\n", + "assert len(content) < 12 * 2**20\n", + "assert len(zstd.ZstdDecompressor().decompress(content)) > 12 * 2**20" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/nixtla/nixtla_client.py b/nixtla/nixtla_client.py index ce064209..fe40e50c 100644 --- a/nixtla/nixtla_client.py +++ b/nixtla/nixtla_client.py @@ -28,6 +28,7 @@ import orjson import pandas as pd import utilsforecast.processing as ufp +import zstandard as zstd from tenacity import ( RetryCallState, retry, @@ -652,7 +653,11 @@ def __init__( self.supported_models = ["timegpt-1", "timegpt-1-long-horizon"] def _make_request( - self, client: httpx.Client, endpoint: str, payload: dict[str, Any] + self, + client: httpx.Client, + endpoint: str, + payload: dict[str, Any], + multithreaded_compress: bool, ) -> dict[str, Any]: def ensure_contiguous_arrays(d: dict[str, Any]) -> None: for k, v in d.items(): @@ -673,12 +678,17 @@ def ensure_contiguous_arrays(d: dict[str, Any]) -> None: ensure_contiguous_arrays(payload) content = orjson.dumps(payload, option=orjson.OPT_SERIALIZE_NUMPY) - content_size_mb = len(content) / (1024 * 1024) + content_size_mb = len(content) / 2**20 if content_size_mb > 200: raise ValueError( f"The payload is too large. Set num_partitions={math.ceil(content_size_mb / 200)}" ) - resp = client.post(url=endpoint, content=content) + headers = {} + if content_size_mb > 10: + threads = -1 if multithreaded_compress else 0 + content = zstd.ZstdCompressor(level=1, threads=threads).compress(content) + headers["content-encoding"] = "zstd" + resp = client.post(url=endpoint, content=content, headers=headers) try: resp_body = orjson.loads(resp.content) except orjson.JSONDecodeError: @@ -697,11 +707,13 @@ def _make_request_with_retries( client: httpx.Client, endpoint: str, payload: dict[str, Any], + multithreaded_compress: bool = True, ) -> dict[str, Any]: return self._retry_strategy(self._make_request)( client=client, endpoint=endpoint, payload=payload, + multithreaded_compress=multithreaded_compress, ) def _make_partitioned_requests( @@ -718,7 +730,11 @@ def _make_partitioned_requests( with ThreadPoolExecutor(max_workers) as executor: future2pos = { executor.submit( - self._make_request_with_retries, client, endpoint, payload + self._make_request_with_retries, + client=client, + endpoint=endpoint, + payload=payload, + multithreaded_compress=False, ): i for i, payload in enumerate(payloads) } diff --git a/setup.py b/setup.py index adf0cdd6..bf18ab4d 100644 --- a/setup.py +++ b/setup.py @@ -48,6 +48,7 @@ "tenacity", "tqdm", "utilsforecast>=0.2.8", + "zstandard", ], extras_require={ "dev": dev + plotting + date_extras,