From a6a0310e839c7841478a5098110fbdb5fc2aa4b7 Mon Sep 17 00:00:00 2001 From: Nicholas Dudfield Date: Tue, 30 Apr 2024 11:53:06 +0700 Subject: [PATCH] feat: update fetch configuration logic --- src/client.js | 38 ++++++++++++++++++-------------------- tests/client.test.js | 18 +++++++++--------- 2 files changed, 27 insertions(+), 29 deletions(-) diff --git a/src/client.js b/src/client.js index 10def72..6828820 100644 --- a/src/client.js +++ b/src/client.js @@ -1,25 +1,11 @@ -let isNode = false; - const VERSION = '0.0.3'; const RETRY_STATUS_CODES = [429, 500, 502, 503, 504]; const ENDPOINT = 'https://api.mistral.ai'; -/** - * Initialize fetch - * @return {Promise} - */ -async function initializeFetch() { - if (typeof window === 'undefined' || - typeof globalThis.fetch === 'undefined') { - const nodeFetch = await import('node-fetch'); - fetch = nodeFetch.default; - isNode = true; - } else { - fetch = globalThis.fetch; - } -} - -initializeFetch(); +// We can't use a top level await if eventually this is to be converted +// to typescript and compiled to commonjs, or similarly using babel. +const configuredFetch = Promise.resolve( + globalThis.fetch ??
import('node-fetch').then((m) => m.default)); /** * MistralAPIError @@ -67,6 +53,17 @@ class MistralClient { } } + /** + * @return {Promise} + * @private + * @param {...*} args - fetch args + * hook point for non-global fetch override + */ + async _fetch(...args) { + const fetchFunc = await configuredFetch; + return fetchFunc(...args); + } + /** * * @param {*} method @@ -90,11 +87,12 @@ class MistralClient { for (let attempts = 0; attempts < this.maxRetries; attempts++) { try { - const response = await fetch(url, options); + const response = await this._fetch(url, options); if (response.ok) { if (request?.stream) { - if (isNode) { + // When using node-fetch or test mocks, getReader is not defined + if (typeof response.body.getReader === 'undefined') { return response.body; } else { const reader = response.body.getReader(); diff --git a/tests/client.test.js b/tests/client.test.js index 30d85ec..54b52b6 100644 --- a/tests/client.test.js +++ b/tests/client.test.js @@ -20,7 +20,7 @@ describe('Mistral Client', () => { it('should return a chat response object', async() => { // Mock the fetch function const mockResponse = mockChatResponsePayload(); - globalThis.fetch = mockFetch(200, mockResponse); + client._fetch = mockFetch(200, mockResponse); const response = await client.chat({ model: 'mistral-small', @@ -37,7 +37,7 @@ describe('Mistral Client', () => { it('should return a chat response object if safeMode is set', async() => { // Mock the fetch function const mockResponse = mockChatResponsePayload(); - globalThis.fetch = mockFetch(200, mockResponse); + client._fetch = mockFetch(200, mockResponse); const response = await client.chat({ model: 'mistral-small', @@ -55,7 +55,7 @@ describe('Mistral Client', () => { it('should return a chat response object if safePrompt is set', async() => { // Mock the fetch function const mockResponse = mockChatResponsePayload(); - globalThis.fetch = mockFetch(200, mockResponse); + client._fetch = mockFetch(200, mockResponse); const 
response = await client.chat({ model: 'mistral-small', @@ -75,7 +75,7 @@ describe('Mistral Client', () => { it('should return parsed, streamed response', async() => { // Mock the fetch function const mockResponse = mockChatResponseStreamingPayload(); - globalThis.fetch = mockFetchStream(200, mockResponse); + client._fetch = mockFetchStream(200, mockResponse); const response = await client.chatStream({ model: 'mistral-small', @@ -98,7 +98,7 @@ describe('Mistral Client', () => { it('should return parsed, streamed response with safeMode', async() => { // Mock the fetch function const mockResponse = mockChatResponseStreamingPayload(); - globalThis.fetch = mockFetchStream(200, mockResponse); + client._fetch = mockFetchStream(200, mockResponse); const response = await client.chatStream({ model: 'mistral-small', @@ -122,7 +122,7 @@ describe('Mistral Client', () => { it('should return parsed, streamed response with safePrompt', async() => { // Mock the fetch function const mockResponse = mockChatResponseStreamingPayload(); - globalThis.fetch = mockFetchStream(200, mockResponse); + client._fetch = mockFetchStream(200, mockResponse); const response = await client.chatStream({ model: 'mistral-small', @@ -148,7 +148,7 @@ describe('Mistral Client', () => { it('should return embeddings', async() => { // Mock the fetch function const mockResponse = mockEmbeddingResponsePayload(); - globalThis.fetch = mockFetch(200, mockResponse); + client._fetch = mockFetch(200, mockResponse); const response = await client.embeddings(mockEmbeddingRequest); expect(response).toEqual(mockResponse); @@ -159,7 +159,7 @@ describe('Mistral Client', () => { it('should return batched embeddings', async() => { // Mock the fetch function const mockResponse = mockEmbeddingResponsePayload(10); - globalThis.fetch = mockFetch(200, mockResponse); + client._fetch = mockFetch(200, mockResponse); const response = await client.embeddings(mockEmbeddingRequest); expect(response).toEqual(mockResponse); @@ -170,7 +170,7 
@@ describe('Mistral Client', () => { it('should return a list of models', async() => { // Mock the fetch function const mockResponse = mockListModels(); - globalThis.fetch = mockFetch(200, mockResponse); + client._fetch = mockFetch(200, mockResponse); const response = await client.listModels(); expect(response).toEqual(mockResponse);