Skip to content
This repository has been archived by the owner on Oct 10, 2024. It is now read-only.

feat: update fetch configuration logic #65

Closed
wants to merge 14 commits into from
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/client.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ declare module '@mistralai/mistralai' {
message: {
role: string;
content: string;
tool_calls: null | ToolCalls[];
};
finish_reason: string;
}
Expand Down Expand Up @@ -105,6 +106,7 @@ declare module '@mistralai/mistralai' {
created: number;
model: string;
choices: ChatCompletionResponseChunkChoice[];
usage: TokenUsage | null;
}

export interface Embedding {
Expand Down Expand Up @@ -133,7 +135,7 @@ declare module '@mistralai/mistralai' {
private _makeChatCompletionRequest(
model: string,
messages: Array<{ role: string; name?: string, content: string | string[], tool_calls?: ToolCalls[]; }>,
tools?: Array<{ type: string; function:Function; }>,
tools?: Array<{ type: string; function: Function; }>,
temperature?: number,
maxTokens?: number,
topP?: number,
Expand All @@ -153,7 +155,7 @@ declare module '@mistralai/mistralai' {
chat(options: {
model: string;
messages: Array<{ role: string; name?: string, content: string | string[], tool_calls?: ToolCalls[]; }>;
tools?: Array<{ type: string; function:Function; }>;
tools?: Array<{ type: string; function: Function; }>;
temperature?: number;
maxTokens?: number;
topP?: number;
Expand All @@ -164,13 +166,13 @@ declare module '@mistralai/mistralai' {
safeMode?: boolean;
safePrompt?: boolean;
toolChoice?: ToolChoice;
responseFormat?: ResponseFormat;
responseFormat?: ResponseFormat;
}): Promise<ChatCompletionResponse>;

chatStream(options: {
model: string;
messages: Array<{ role: string; name?: string, content: string | string[], tool_calls?: ToolCalls[]; }>;
tools?: Array<{ type: string; function:Function; }>;
tools?: Array<{ type: string; function: Function; }>;
temperature?: number;
maxTokens?: number;
topP?: number;
Expand Down
39 changes: 20 additions & 19 deletions src/client.js
Original file line number Diff line number Diff line change
@@ -1,25 +1,14 @@
let isNode = false;

const VERSION = '0.0.3';
const RETRY_STATUS_CODES = [429, 500, 502, 503, 504];
const ENDPOINT = 'https://api.mistral.ai';

/**
 * Initialize the module-level `fetch` binding.
 *
 * Uses pure feature detection: a native `globalThis.fetch` is preferred
 * whenever it exists, and node-fetch is dynamically imported only as a
 * fallback (Node < 18). The previous check
 * (`typeof window === 'undefined' || typeof globalThis.fetch === 'undefined'`)
 * wrongly loaded node-fetch — and flagged `isNode` — in any non-browser
 * runtime that already shipped a native fetch.
 *
 * Side effects: assigns the global `fetch`, and sets `isNode` to true
 * only when the node-fetch fallback (whose response body has no
 * `getReader`) is in use, preserving the downstream streaming check.
 * @return {Promise<void>}
 */
async function initializeFetch() {
  if (typeof globalThis.fetch === 'undefined') {
    // No native fetch available: fall back to node-fetch.
    const nodeFetch = await import('node-fetch');
    fetch = nodeFetch.default;
    isNode = true;
  } else {
    fetch = globalThis.fetch;
  }
}
// True when running under Node.js (still consumed by the streaming
// code path that decides between getReader() and the raw body).
const isNode = typeof process !== 'undefined' &&
  process.versions != null &&
  process.versions.node != null;
// True when the runtime ships a WHATWG fetch (browsers, Node >= 18).
const haveNativeFetch = typeof globalThis.fetch !== 'undefined';

// Pure feature detection: prefer the native fetch and fall back to
// node-fetch otherwise. The previous `isNode && !haveNativeFetch`
// ternary left configuredFetch undefined in a non-Node runtime that
// lacks a native fetch; `??` covers every environment.
const configuredFetch = globalThis.fetch ??
  (await import('node-fetch')).default;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think here we should only use the haveNativeFetch condition no? Is there a case where isNode would be false and haveNativeFetch be false as well?

Copy link
Contributor Author

@sublimator sublimator Apr 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Back on laptop :) Thanks for the prodding

const configuredFetch = globalThis.fetch ??
  (await import('node-fetch')).default;

Went with this and killed off haveNativeFetch/isNode entirely by just using feature detection, and adding a comment

// When using node-fetch or test mocks, getReader is not defined
if (typeof response.body.getReader === 'undefined') {


/**
* MistralAPIError
Expand Down Expand Up @@ -67,6 +56,16 @@ class MistralClient {
}
}

/**
* @return {Promise}
* @private
* @param {...*} args - fetch args
* hook point for non-global fetch override
*/
async _fetch(...args) {
return configuredFetch(...args);
}

/**
*
* @param {*} method
Expand All @@ -90,11 +89,13 @@ class MistralClient {

for (let attempts = 0; attempts < this.maxRetries; attempts++) {
try {
const response = await fetch(url, options);
const response = await this._fetch(url, options);

if (response.ok) {
if (request?.stream) {
if (isNode) {
if (isNode && !haveNativeFetch ||
// The test mocks do not return a body with getReader
typeof response.body.getReader === 'undefined') {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think here we can remove the condition on isNode and haveNativeFetch as the second part of the condition will catch this anyway

return response.body;
} else {
const reader = response.body.getReader();
Expand Down
18 changes: 9 additions & 9 deletions tests/client.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ describe('Mistral Client', () => {
it('should return a chat response object', async() => {
// Mock the fetch function
const mockResponse = mockChatResponsePayload();
globalThis.fetch = mockFetch(200, mockResponse);
client._fetch = mockFetch(200, mockResponse);

const response = await client.chat({
model: 'mistral-small',
Expand All @@ -37,7 +37,7 @@ describe('Mistral Client', () => {
it('should return a chat response object if safeMode is set', async() => {
// Mock the fetch function
const mockResponse = mockChatResponsePayload();
globalThis.fetch = mockFetch(200, mockResponse);
client._fetch = mockFetch(200, mockResponse);

const response = await client.chat({
model: 'mistral-small',
Expand All @@ -55,7 +55,7 @@ describe('Mistral Client', () => {
it('should return a chat response object if safePrompt is set', async() => {
// Mock the fetch function
const mockResponse = mockChatResponsePayload();
globalThis.fetch = mockFetch(200, mockResponse);
client._fetch = mockFetch(200, mockResponse);

const response = await client.chat({
model: 'mistral-small',
Expand All @@ -75,7 +75,7 @@ describe('Mistral Client', () => {
it('should return parsed, streamed response', async() => {
// Mock the fetch function
const mockResponse = mockChatResponseStreamingPayload();
globalThis.fetch = mockFetchStream(200, mockResponse);
client._fetch = mockFetchStream(200, mockResponse);

const response = await client.chatStream({
model: 'mistral-small',
Expand All @@ -98,7 +98,7 @@ describe('Mistral Client', () => {
it('should return parsed, streamed response with safeMode', async() => {
// Mock the fetch function
const mockResponse = mockChatResponseStreamingPayload();
globalThis.fetch = mockFetchStream(200, mockResponse);
client._fetch = mockFetchStream(200, mockResponse);

const response = await client.chatStream({
model: 'mistral-small',
Expand All @@ -122,7 +122,7 @@ describe('Mistral Client', () => {
it('should return parsed, streamed response with safePrompt', async() => {
// Mock the fetch function
const mockResponse = mockChatResponseStreamingPayload();
globalThis.fetch = mockFetchStream(200, mockResponse);
client._fetch = mockFetchStream(200, mockResponse);

const response = await client.chatStream({
model: 'mistral-small',
Expand All @@ -148,7 +148,7 @@ describe('Mistral Client', () => {
it('should return embeddings', async() => {
// Mock the fetch function
const mockResponse = mockEmbeddingResponsePayload();
globalThis.fetch = mockFetch(200, mockResponse);
client._fetch = mockFetch(200, mockResponse);

const response = await client.embeddings(mockEmbeddingRequest);
expect(response).toEqual(mockResponse);
Expand All @@ -159,7 +159,7 @@ describe('Mistral Client', () => {
it('should return batched embeddings', async() => {
// Mock the fetch function
const mockResponse = mockEmbeddingResponsePayload(10);
globalThis.fetch = mockFetch(200, mockResponse);
client._fetch = mockFetch(200, mockResponse);

const response = await client.embeddings(mockEmbeddingRequest);
expect(response).toEqual(mockResponse);
Expand All @@ -170,7 +170,7 @@ describe('Mistral Client', () => {
it('should return a list of models', async() => {
// Mock the fetch function
const mockResponse = mockListModels();
globalThis.fetch = mockFetch(200, mockResponse);
client._fetch = mockFetch(200, mockResponse);

const response = await client.listModels();
expect(response).toEqual(mockResponse);
Expand Down