From a4d6d6d7f8bddcea9458e99ebd58c9b8e6d0ce9d Mon Sep 17 00:00:00 2001
From: Bam4d
Date: Mon, 15 Jan 2024 13:21:51 +0100
Subject: [PATCH] deprecating safeMode in favour of safePrompt

---
 src/client.d.ts      | 14 +++++++-
 src/client.js        | 16 ++++++---
 tests/client.test.js | 84 ++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 109 insertions(+), 5 deletions(-)

diff --git a/src/client.d.ts b/src/client.d.ts
index a0f9df2..5a82b60 100644
--- a/src/client.d.ts
+++ b/src/client.d.ts
@@ -101,7 +101,11 @@ declare module '@mistralai/mistralai' {
       topP?: number,
       randomSeed?: number,
       stream?: boolean,
-      safeMode?: boolean
+      /**
+       * @deprecated use safePrompt instead
+       */
+      safeMode?: boolean,
+      safePrompt?: boolean
     ): object;
 
     listModels(): Promise;
@@ -113,7 +117,11 @@ declare module '@mistralai/mistralai' {
       maxTokens?: number;
       topP?: number;
       randomSeed?: number;
+      /**
+       * @deprecated use safePrompt instead
+       */
       safeMode?: boolean;
+      safePrompt?: boolean;
     }): Promise;
 
     chatStream(options: {
@@ -123,7 +131,11 @@ declare module '@mistralai/mistralai' {
       maxTokens?: number;
       topP?: number;
       randomSeed?: number;
+      /**
+       * @deprecated use safePrompt instead
+       */
       safeMode?: boolean;
+      safePrompt?: boolean;
     }): AsyncGenerator;
 
     embeddings(options: {
diff --git a/src/client.js b/src/client.js
index e8e9bd2..fa2a990 100644
--- a/src/client.js
+++ b/src/client.js
@@ -153,7 +153,8 @@ class MistralClient {
    * @param {*} topP
    * @param {*} randomSeed
    * @param {*} stream
-   * @param {*} safeMode
+   * @param {*} safeMode deprecated use safePrompt instead
+   * @param {*} safePrompt
    * @return {Promise}
    */
   _makeChatCompletionRequest = function(
@@ -165,6 +166,7 @@ class MistralClient {
     randomSeed,
     stream,
     safeMode,
+    safePrompt,
   ) {
     return {
       model: model,
@@ -174,7 +176,7 @@ class MistralClient {
       top_p: topP ?? undefined,
       random_seed: randomSeed ?? undefined,
       stream: stream ?? undefined,
-      safe_prompt: safeMode ?? undefined,
+      safe_prompt: (safeMode || safePrompt) ?? undefined,
     };
   };
 
@@ -196,7 +198,8 @@ class MistralClient {
    * @param {*} maxTokens the maximum number of tokens to generate, e.g. 100
    * @param {*} topP the cumulative probability of tokens to generate, e.g. 0.9
    * @param {*} randomSeed the random seed to use for sampling, e.g. 42
-   * @param {*} safeMode whether to use safe mode, e.g. true
+   * @param {*} safeMode deprecated use safePrompt instead
+   * @param {*} safePrompt whether to use safe mode, e.g. true
    * @return {Promise}
    */
   chat = async function({
@@ -207,6 +210,7 @@ class MistralClient {
     topP,
     randomSeed,
     safeMode,
+    safePrompt,
   }) {
     const request = this._makeChatCompletionRequest(
       model,
@@ -217,6 +221,7 @@ class MistralClient {
       randomSeed,
       false,
       safeMode,
+      safePrompt,
     );
     const response = await this._request(
       'post',
@@ -235,7 +240,8 @@ class MistralClient {
    * @param {*} maxTokens the maximum number of tokens to generate, e.g. 100
    * @param {*} topP the cumulative probability of tokens to generate, e.g. 0.9
    * @param {*} randomSeed the random seed to use for sampling, e.g. 42
-   * @param {*} safeMode whether to use safe mode, e.g. true
+   * @param {*} safeMode deprecated use safePrompt instead
+   * @param {*} safePrompt whether to use safe mode, e.g. true
    * @return {Promise}
    */
   chatStream = async function* ({
@@ -246,6 +252,7 @@ class MistralClient {
     topP,
     randomSeed,
     safeMode,
+    safePrompt,
   }) {
     const request = this._makeChatCompletionRequest(
       model,
@@ -256,6 +263,7 @@ class MistralClient {
       randomSeed,
       true,
       safeMode,
+      safePrompt,
     );
     const response = await this._request(
       'post',
diff --git a/tests/client.test.js b/tests/client.test.js
index 8c3ffb5..30d85ec 100644
--- a/tests/client.test.js
+++ b/tests/client.test.js
@@ -33,6 +33,42 @@ describe('Mistral Client', () => {
       });
       expect(response).toEqual(mockResponse);
     });
+
+    it('should return a chat response object if safeMode is set', async() => {
+      // Mock the fetch function
+      const mockResponse = mockChatResponsePayload();
+      globalThis.fetch = mockFetch(200, mockResponse);
+
+      const response = await client.chat({
+        model: 'mistral-small',
+        messages: [
+          {
+            role: 'user',
+            content: 'What is the best French cheese?',
+          },
+        ],
+        safeMode: true,
+      });
+      expect(response).toEqual(mockResponse);
+    });
+
+    it('should return a chat response object if safePrompt is set', async() => {
+      // Mock the fetch function
+      const mockResponse = mockChatResponsePayload();
+      globalThis.fetch = mockFetch(200, mockResponse);
+
+      const response = await client.chat({
+        model: 'mistral-small',
+        messages: [
+          {
+            role: 'user',
+            content: 'What is the best French cheese?',
+          },
+        ],
+        safePrompt: true,
+      });
+      expect(response).toEqual(mockResponse);
+    });
   });
 
   describe('chatStream()', () => {
@@ -58,6 +94,54 @@ describe('Mistral Client', () => {
 
       expect(parsedResponse.length).toEqual(11);
     });
+
+    it('should return parsed, streamed response with safeMode', async() => {
+      // Mock the fetch function
+      const mockResponse = mockChatResponseStreamingPayload();
+      globalThis.fetch = mockFetchStream(200, mockResponse);
+
+      const response = await client.chatStream({
+        model: 'mistral-small',
+        messages: [
+          {
+            role: 'user',
+            content: 'What is the best French cheese?',
+          },
+        ],
+        safeMode: true,
+      });
+
+      const parsedResponse = [];
+      for await (const r of response) {
+        parsedResponse.push(r);
+      }
+
+      expect(parsedResponse.length).toEqual(11);
+    });
+
+    it('should return parsed, streamed response with safePrompt', async() => {
+      // Mock the fetch function
+      const mockResponse = mockChatResponseStreamingPayload();
+      globalThis.fetch = mockFetchStream(200, mockResponse);
+
+      const response = await client.chatStream({
+        model: 'mistral-small',
+        messages: [
+          {
+            role: 'user',
+            content: 'What is the best French cheese?',
+          },
+        ],
+        safePrompt: true,
+      });
+
+      const parsedResponse = [];
+      for await (const r of response) {
+        parsedResponse.push(r);
+      }
+
+      expect(parsedResponse.length).toEqual(11);
+    });
   });
 
   describe('embeddings()', () => {
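
Reviewer note (illustrative, not part of the patch): the new tests show the intended migration. Callers pass safePrompt instead of the deprecated safeMode; both flags still map to the safe_prompt field of the request body, so behaviour is unchanged. A minimal usage sketch, assuming an API key in the MISTRAL_API_KEY environment variable and the same mistral-small model used in the tests:

    import MistralClient from '@mistralai/mistralai';

    const client = new MistralClient(process.env.MISTRAL_API_KEY);

    // Deprecated spelling, still accepted by this patch:
    //   await client.chat({model, messages, safeMode: true});
    // Preferred spelling going forward:
    const response = await client.chat({
      model: 'mistral-small',
      messages: [{role: 'user', content: 'What is the best French cheese?'}],
      safePrompt: true, // serialized as safe_prompt in the request body
    });
    console.log(response.choices[0].message.content);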