Skip to content
This repository has been archived by the owner on Oct 10, 2024. It is now read-only.

Commit

Permalink
Merge pull request #66 from sublimator/nd-update-fetch-configuration-…
Browse files Browse the repository at this point in the history
…logic-2024-04-30

feat: update fetch configuration logic
  • Loading branch information
lerela authored Apr 30, 2024
2 parents 0d5af64 + a6a0310 commit 06dabd3
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 29 deletions.
38 changes: 18 additions & 20 deletions src/client.js
Original file line number Diff line number Diff line change
@@ -1,25 +1,11 @@
let isNode = false;

const VERSION = '0.0.3';
const RETRY_STATUS_CODES = [429, 500, 502, 503, 504];
const ENDPOINT = 'https://api.mistral.ai';

/**
* Initialize fetch
* @return {Promise<void>}
*/
async function initializeFetch() {
if (typeof window === 'undefined' ||
typeof globalThis.fetch === 'undefined') {
const nodeFetch = await import('node-fetch');
fetch = nodeFetch.default;
isNode = true;
} else {
fetch = globalThis.fetch;
}
}

initializeFetch();
// We can't use a top level await if eventually this is to be converted
// to typescript and compiled to commonjs, or similarly using babel.
const configuredFetch = Promise.resolve(
globalThis.fetch ?? import('node-fetch').then((m) => m.default));

/**
* MistralAPIError
Expand Down Expand Up @@ -67,6 +53,17 @@ class MistralClient {
}
}

/**
* @return {Promise}
* @private
* @param {...*} args - fetch args
* hook point for non-global fetch override
*/
async _fetch(...args) {
const fetchFunc = await configuredFetch;
return fetchFunc(...args);
}

/**
*
* @param {*} method
Expand All @@ -90,11 +87,12 @@ class MistralClient {

for (let attempts = 0; attempts < this.maxRetries; attempts++) {
try {
const response = await fetch(url, options);
const response = await this._fetch(url, options);

if (response.ok) {
if (request?.stream) {
if (isNode) {
// When using node-fetch or test mocks, getReader is not defined
if (typeof response.body.getReader === 'undefined') {
return response.body;
} else {
const reader = response.body.getReader();
Expand Down
18 changes: 9 additions & 9 deletions tests/client.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ describe('Mistral Client', () => {
it('should return a chat response object', async() => {
// Mock the fetch function
const mockResponse = mockChatResponsePayload();
globalThis.fetch = mockFetch(200, mockResponse);
client._fetch = mockFetch(200, mockResponse);

const response = await client.chat({
model: 'mistral-small',
Expand All @@ -37,7 +37,7 @@ describe('Mistral Client', () => {
it('should return a chat response object if safeMode is set', async() => {
// Mock the fetch function
const mockResponse = mockChatResponsePayload();
globalThis.fetch = mockFetch(200, mockResponse);
client._fetch = mockFetch(200, mockResponse);

const response = await client.chat({
model: 'mistral-small',
Expand All @@ -55,7 +55,7 @@ describe('Mistral Client', () => {
it('should return a chat response object if safePrompt is set', async() => {
// Mock the fetch function
const mockResponse = mockChatResponsePayload();
globalThis.fetch = mockFetch(200, mockResponse);
client._fetch = mockFetch(200, mockResponse);

const response = await client.chat({
model: 'mistral-small',
Expand All @@ -75,7 +75,7 @@ describe('Mistral Client', () => {
it('should return parsed, streamed response', async() => {
// Mock the fetch function
const mockResponse = mockChatResponseStreamingPayload();
globalThis.fetch = mockFetchStream(200, mockResponse);
client._fetch = mockFetchStream(200, mockResponse);

const response = await client.chatStream({
model: 'mistral-small',
Expand All @@ -98,7 +98,7 @@ describe('Mistral Client', () => {
it('should return parsed, streamed response with safeMode', async() => {
// Mock the fetch function
const mockResponse = mockChatResponseStreamingPayload();
globalThis.fetch = mockFetchStream(200, mockResponse);
client._fetch = mockFetchStream(200, mockResponse);

const response = await client.chatStream({
model: 'mistral-small',
Expand All @@ -122,7 +122,7 @@ describe('Mistral Client', () => {
it('should return parsed, streamed response with safePrompt', async() => {
// Mock the fetch function
const mockResponse = mockChatResponseStreamingPayload();
globalThis.fetch = mockFetchStream(200, mockResponse);
client._fetch = mockFetchStream(200, mockResponse);

const response = await client.chatStream({
model: 'mistral-small',
Expand All @@ -148,7 +148,7 @@ describe('Mistral Client', () => {
it('should return embeddings', async() => {
// Mock the fetch function
const mockResponse = mockEmbeddingResponsePayload();
globalThis.fetch = mockFetch(200, mockResponse);
client._fetch = mockFetch(200, mockResponse);

const response = await client.embeddings(mockEmbeddingRequest);
expect(response).toEqual(mockResponse);
Expand All @@ -159,7 +159,7 @@ describe('Mistral Client', () => {
it('should return batched embeddings', async() => {
// Mock the fetch function
const mockResponse = mockEmbeddingResponsePayload(10);
globalThis.fetch = mockFetch(200, mockResponse);
client._fetch = mockFetch(200, mockResponse);

const response = await client.embeddings(mockEmbeddingRequest);
expect(response).toEqual(mockResponse);
Expand All @@ -170,7 +170,7 @@ describe('Mistral Client', () => {
it('should return a list of models', async() => {
// Mock the fetch function
const mockResponse = mockListModels();
globalThis.fetch = mockFetch(200, mockResponse);
client._fetch = mockFetch(200, mockResponse);

const response = await client.listModels();
expect(response).toEqual(mockResponse);
Expand Down

0 comments on commit 06dabd3

Please sign in to comment.