Skip to content
This repository has been archived by the owner on Oct 10, 2024. It is now read-only.

feat: update fetch configuration logic #65

Closed
wants to merge 14 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 18 additions & 20 deletions src/client.js
Original file line number Diff line number Diff line change
@@ -1,25 +1,11 @@
let isNode = false;

const VERSION = '0.0.3';
const RETRY_STATUS_CODES = [429, 500, 502, 503, 504];
const ENDPOINT = 'https://api.mistral.ai';

/**
* Initialize fetch
* @return {Promise<void>}
*/
async function initializeFetch() {
if (typeof window === 'undefined' ||
typeof globalThis.fetch === 'undefined') {
const nodeFetch = await import('node-fetch');
fetch = nodeFetch.default;
isNode = true;
} else {
fetch = globalThis.fetch;
}
}

initializeFetch();
// We can't use a top level await if eventually this is to be converted
// to typescript and compiled to commonjs, or similarly using babel.
const configuredFetch = Promise.resolve(
globalThis.fetch ?? import('node-fetch').then((m) => m.default));

/**
* MistralAPIError
Expand Down Expand Up @@ -67,6 +53,17 @@ class MistralClient {
}
}

/**
* @return {Promise}
* @private
* @param {...*} args - fetch args
* hook point for non-global fetch override
*/
async _fetch(...args) {
const fetchFunc = await configuredFetch;
return fetchFunc(...args);
}

/**
*
* @param {*} method
Expand All @@ -90,11 +87,12 @@ class MistralClient {

for (let attempts = 0; attempts < this.maxRetries; attempts++) {
try {
const response = await fetch(url, options);
const response = await this._fetch(url, options);

if (response.ok) {
if (request?.stream) {
if (isNode) {
// When using node-fetch or test mocks, getReader is not defined
if (typeof response.body.getReader === 'undefined') {
return response.body;
} else {
const reader = response.body.getReader();
Expand Down
18 changes: 9 additions & 9 deletions tests/client.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ describe('Mistral Client', () => {
it('should return a chat response object', async() => {
// Mock the fetch function
const mockResponse = mockChatResponsePayload();
globalThis.fetch = mockFetch(200, mockResponse);
client._fetch = mockFetch(200, mockResponse);

const response = await client.chat({
model: 'mistral-small',
Expand All @@ -37,7 +37,7 @@ describe('Mistral Client', () => {
it('should return a chat response object if safeMode is set', async() => {
// Mock the fetch function
const mockResponse = mockChatResponsePayload();
globalThis.fetch = mockFetch(200, mockResponse);
client._fetch = mockFetch(200, mockResponse);

const response = await client.chat({
model: 'mistral-small',
Expand All @@ -55,7 +55,7 @@ describe('Mistral Client', () => {
it('should return a chat response object if safePrompt is set', async() => {
// Mock the fetch function
const mockResponse = mockChatResponsePayload();
globalThis.fetch = mockFetch(200, mockResponse);
client._fetch = mockFetch(200, mockResponse);

const response = await client.chat({
model: 'mistral-small',
Expand All @@ -75,7 +75,7 @@ describe('Mistral Client', () => {
it('should return parsed, streamed response', async() => {
// Mock the fetch function
const mockResponse = mockChatResponseStreamingPayload();
globalThis.fetch = mockFetchStream(200, mockResponse);
client._fetch = mockFetchStream(200, mockResponse);

const response = await client.chatStream({
model: 'mistral-small',
Expand All @@ -98,7 +98,7 @@ describe('Mistral Client', () => {
it('should return parsed, streamed response with safeMode', async() => {
// Mock the fetch function
const mockResponse = mockChatResponseStreamingPayload();
globalThis.fetch = mockFetchStream(200, mockResponse);
client._fetch = mockFetchStream(200, mockResponse);

const response = await client.chatStream({
model: 'mistral-small',
Expand All @@ -122,7 +122,7 @@ describe('Mistral Client', () => {
it('should return parsed, streamed response with safePrompt', async() => {
// Mock the fetch function
const mockResponse = mockChatResponseStreamingPayload();
globalThis.fetch = mockFetchStream(200, mockResponse);
client._fetch = mockFetchStream(200, mockResponse);

const response = await client.chatStream({
model: 'mistral-small',
Expand All @@ -148,7 +148,7 @@ describe('Mistral Client', () => {
it('should return embeddings', async() => {
// Mock the fetch function
const mockResponse = mockEmbeddingResponsePayload();
globalThis.fetch = mockFetch(200, mockResponse);
client._fetch = mockFetch(200, mockResponse);

const response = await client.embeddings(mockEmbeddingRequest);
expect(response).toEqual(mockResponse);
Expand All @@ -159,7 +159,7 @@ describe('Mistral Client', () => {
it('should return batched embeddings', async() => {
// Mock the fetch function
const mockResponse = mockEmbeddingResponsePayload(10);
globalThis.fetch = mockFetch(200, mockResponse);
client._fetch = mockFetch(200, mockResponse);

const response = await client.embeddings(mockEmbeddingRequest);
expect(response).toEqual(mockResponse);
Expand All @@ -170,7 +170,7 @@ describe('Mistral Client', () => {
it('should return a list of models', async() => {
// Mock the fetch function
const mockResponse = mockListModels();
globalThis.fetch = mockFetch(200, mockResponse);
client._fetch = mockFetch(200, mockResponse);

const response = await client.listModels();
expect(response).toEqual(mockResponse);
Expand Down