diff --git a/.github/workflows/build_publish.yaml b/.github/workflows/build_publish.yaml
index 089d06f..62a8b8c 100644
--- a/.github/workflows/build_publish.yaml
+++ b/.github/workflows/build_publish.yaml
@@ -15,7 +15,7 @@ on:
 
 jobs:
-  lint:
+  lint_and_test:
     runs-on: ubuntu-latest
 
     steps:
 
@@ -34,13 +34,18 @@ jobs:
         run: |
           npm install
 
-      # Ruff
+      # ESLint
       - name: ESlint check
        run: |
-          ./node_modules/.bin/eslint .
+          npm run lint
+
+      # Run tests
+      - name: Run tests
+        run: |
+          npm test
 
   publish:
-    needs: lint
+    needs: lint_and_test
     runs-on: ubuntu-latest
     if: startsWith(github.ref, 'refs/tags')
 
diff --git a/package.json b/package.json
index 0109ab6..a7652c1 100644
--- a/package.json
+++ b/package.json
@@ -7,6 +7,7 @@
   "type": "module",
   "main": "src/client.js",
   "scripts": {
+    "lint": "eslint .",
     "test": "node --experimental-vm-modules node_modules/.bin/jest"
   },
   "jest": {
diff --git a/tests/client.test.js b/tests/client.test.js
index 563c351..6032c97 100644
--- a/tests/client.test.js
+++ b/tests/client.test.js
@@ -1,5 +1,12 @@
 import MistralClient from '../src/client';
-import jest from 'jest-mock';
+import {
+  mockListModels,
+  mockFetch,
+  mockChatResponseStreamingPayload,
+  mockEmbeddingResponsePayload,
+  mockChatResponsePayload,
+  mockFetchStream,
+} from './utils';
 
 // Test the list models endpoint
 describe('Mistral Client', () => {
@@ -8,37 +15,86 @@ describe('Mistral Client', () => {
     client = new MistralClient();
   });
 
-  const mockFetch = (status, payload) => {
-    return jest.fn(() =>
-      Promise.resolve({
-        json: () => Promise.resolve(payload),
-        text: () => Promise.resolve(JSON.stringify(payload)),
-        status,
-      }),
-    );
-  };
-
-  describe('listModels()', () => {
-    it('should return a list of models', async() => {
+  describe('chat()', () => {
+    it('should return a chat response object', async() => {
       // Mock the fetch function
-      globalThis.fetch = mockFetch(200, {
-        models: [
-          'mistral-tiny',
-          'mistral-small',
-          'mistral-large',
-          'mistral-mega',
+      const mockResponse = mockChatResponsePayload();
+      globalThis.fetch = mockFetch(200, mockResponse);
+
+      const response = await client.chat({
+        model: 'mistral-small',
+        messages: [
+          {
+            role: 'user',
+            content: 'What is the best French cheese?',
+          },
         ],
       });
+      expect(response).toEqual(mockResponse);
+    });
+  });
+
+  describe('chatStream()', () => {
+    it('should return parsed, streamed response', async() => {
+      // Mock the fetch function
+      const mockResponse = mockChatResponseStreamingPayload();
+      globalThis.fetch = mockFetchStream(200, mockResponse);
 
-      const models = await client.listModels();
-      expect(models).toEqual({
-        models: [
-          'mistral-tiny',
-          'mistral-small',
-          'mistral-large',
-          'mistral-mega',
+      const response = await client.chatStream({
+        model: 'mistral-small',
+        messages: [
+          {
+            role: 'user',
+            content: 'What is the best French cheese?',
+          },
         ],
       });
+
+      const parsedResponse = [];
+      for await (const r of response) {
+        parsedResponse.push(r);
+      }
+
+      expect(parsedResponse.length).toEqual(11);
+    });
+  });
+
+  describe('embeddings()', () => {
+    it('should return embeddings', async() => {
+      // Mock the fetch function
+      const mockResponse = mockEmbeddingResponsePayload();
+      globalThis.fetch = mockFetch(200, mockResponse);
+
+      const response = await client.embeddings({
+        model: 'mistral-embed',
+        input: ['embed this'],
+      });
+      expect(response).toEqual(mockResponse);
+    });
+  });
+
+  describe('embeddings() batched', () => {
+    it('should return batched embeddings', async() => {
+      // Mock the fetch function
+      const mockResponse = mockEmbeddingResponsePayload(10);
+      globalThis.fetch = mockFetch(200, mockResponse);
+
+      const response = await client.embeddings({
+        model: 'mistral-embed',
+        input: Array(10).fill('embed this'),
+      });
+      expect(response).toEqual(mockResponse);
+    });
+  });
+
+  describe('listModels()', () => {
+    it('should return a list of models', async() => {
+      // Mock the fetch function
+      const mockResponse = mockListModels();
+      globalThis.fetch = mockFetch(200, mockResponse);
+
+      const response = await client.listModels();
+      expect(response).toEqual(mockResponse);
     });
   });
 });
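A note on the `chatStream()` expectation above: `mockChatResponseStreamingPayload()` (added in `tests/utils.js` below) emits 12 server-sent-events style chunks (one role delta, ten content deltas, and a closing `data: [DONE]` sentinel), so the client is expected to yield 11 parsed objects. A rough sketch of that kind of parsing, using a hypothetical `parseSseChunks` helper rather than the client's actual implementation; it assumes each chunk carries complete `data: ...\n\n` events, which holds for the mock but not for arbitrary network chunking:

```js
// Illustrative sketch only, not part of the patch. Parses SSE-style chunks
// such as those produced by mockChatResponseStreamingPayload().
async function* parseSseChunks(body) {
  for await (const chunk of body) {
    for (const line of chunk.split('\n')) {
      if (!line.startsWith('data: ')) continue;
      const data = line.slice('data: '.length);
      // The [DONE] sentinel ends the stream and is not yielded,
      // which is why the test expects 11 objects from 12 chunks
      if (data === '[DONE]') return;
      yield JSON.parse(data);
    }
  }
}
```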
diff --git a/tests/utils.js b/tests/utils.js
new file mode 100644
index 0000000..28bb468
--- /dev/null
+++ b/tests/utils.js
@@ -0,0 +1,243 @@
+import jest from 'jest-mock';
+
+/**
+ * Mock the fetch function
+ * @param {number} status
+ * @param {Object} payload
+ * @return {Function} jest mock for a JSON fetch response
+ */
+export function mockFetch(status, payload) {
+  return jest.fn(() =>
+    Promise.resolve({
+      json: () => Promise.resolve(payload),
+      text: () => Promise.resolve(JSON.stringify(payload)),
+      status,
+      ok: status >= 200 && status < 300,
+    }),
+  );
+}
+
+/**
+ * Mock fetch stream
+ * @param {number} status
+ * @param {Array<string>} payload
+ * @return {Function} jest mock for a streaming fetch response
+ */
+export function mockFetchStream(status, payload) {
+  const asyncIterator = async function* () {
+    while (true) {
+      // Read from the stream
+      const value = payload.shift();
+      // Exit if we're done
+      if (!value) return;
+      // Else yield the chunk
+      yield value;
+    }
+  };
+
+  return jest.fn(() =>
+    Promise.resolve({
+      // body mimics a ReadableStream: an async iterator over the chunks
+      body: asyncIterator(),
+      status,
+      ok: status >= 200 && status < 300,
+    }),
+  );
+}
+
+/**
+ * Mock models list
+ * @return {Object}
+ */
+export function mockListModels() {
+  return {
+    object: 'list',
+    data: [
+      {
+        id: 'mistral-medium',
+        object: 'model',
+        created: 1703186988,
+        owned_by: 'mistralai',
+        root: null,
+        parent: null,
+        permission: [
+          {
+            id: 'modelperm-15bebaf316264adb84b891bf06a84933',
+            object: 'model_permission',
+            created: 1703186988,
+            allow_create_engine: false,
+            allow_sampling: true,
+            allow_logprobs: false,
+            allow_search_indices: false,
+            allow_view: true,
+            allow_fine_tuning: false,
+            organization: '*',
+            group: null,
+            is_blocking: false,
+          },
+        ],
+      },
+      {
+        id: 'mistral-small',
+        object: 'model',
+        created: 1703186988,
+        owned_by: 'mistralai',
+        root: null,
+        parent: null,
+        permission: [
+          {
+            id: 'modelperm-d0dced5c703242fa862f4ca3f241c00e',
+            object: 'model_permission',
+            created: 1703186988,
+            allow_create_engine: false,
+            allow_sampling: true,
+            allow_logprobs: false,
+            allow_search_indices: false,
+            allow_view: true,
+            allow_fine_tuning: false,
+            organization: '*',
+            group: null,
+            is_blocking: false,
+          },
+        ],
+      },
+      {
+        id: 'mistral-tiny',
+        object: 'model',
+        created: 1703186988,
+        owned_by: 'mistralai',
+        root: null,
+        parent: null,
+        permission: [
+          {
+            id: 'modelperm-0e64e727c3a94f17b29f8895d4be2910',
+            object: 'model_permission',
+            created: 1703186988,
+            allow_create_engine: false,
+            allow_sampling: true,
+            allow_logprobs: false,
+            allow_search_indices: false,
+            allow_view: true,
+            allow_fine_tuning: false,
+            organization: '*',
+            group: null,
+            is_blocking: false,
+          },
+        ],
+      },
+      {
+        id: 'mistral-embed',
+        object: 'model',
+        created: 1703186988,
+        owned_by: 'mistralai',
+        root: null,
+        parent: null,
+        permission: [
+          {
+            id: 'modelperm-ebdff9046f524e628059447b5932e3ad',
+            object: 'model_permission',
+            created: 1703186988,
+            allow_create_engine: false,
+            allow_sampling: true,
+            allow_logprobs: false,
+            allow_search_indices: false,
+            allow_view: true,
+            allow_fine_tuning: false,
+            organization: '*',
+            group: null,
+            is_blocking: false,
+          },
+        ],
+      },
+    ],
+  };
+}
+
+/**
+ * Mock chat completion object
+ * @return {Object}
+ */
+export function mockChatResponsePayload() {
+  return {
+    id: 'chat-98c8c60e3fbf4fc49658eddaf447357c',
+    object: 'chat.completion',
+    created: 1703165682,
+    choices: [
+      {
+        finish_reason: 'stop',
+        message: {
+          role: 'assistant',
+          content: 'What is the best French cheese?',
+        },
+        index: 0,
+      },
+    ],
+    model: 'mistral-small',
+    usage: {prompt_tokens: 90, total_tokens: 90, completion_tokens: 0},
+  };
+}
+
+/**
+ * Mock chat completion stream
+ * @return {Array<string>}
+ */
+export function mockChatResponseStreamingPayload() {
+  const firstMessage =
+    ['data: ' +
+    JSON.stringify({
+      id: 'cmpl-8cd9019d21ba490aa6b9740f5d0a883e',
+      model: 'mistral-small',
+      choices: [
+        {
+          index: 0,
+          delta: {role: 'assistant'},
+          finish_reason: null,
+        },
+      ],
+    }) +
+    '\n\n'];
+  const lastMessage = ['data: [DONE]\n\n'];
+
+  const dataMessages = [];
+  for (let i = 0; i < 10; i++) {
+    dataMessages.push(
+      'data: ' +
+      JSON.stringify({
+        id: 'cmpl-8cd9019d21ba490aa6b9740f5d0a883e',
+        object: 'chat.completion.chunk',
+        created: 1703168544,
+        model: 'mistral-small',
+        choices: [
+          {
+            index: i,
+            delta: {content: `stream response ${i}`},
+            finish_reason: null,
+          },
+        ],
+      }) +
+      '\n\n',
+    );
+  }
+
+  return firstMessage.concat(dataMessages).concat(lastMessage);
+}
+
+/**
+ * Mock embeddings response
+ * @param {number} batchSize
+ * @return {Object}
+ */
+export function mockEmbeddingResponsePayload(batchSize = 1) {
+  return {
+    id: 'embd-98c8c60e3fbf4fc49658eddaf447357c',
+    object: 'list',
+    // JS has no Python-style list repetition, so build the batch explicitly
+    data: Array.from({length: batchSize}, (_, index) => ({
+      object: 'embedding',
+      embedding: [-0.018585205078125, 0.027099609375, 0.02587890625],
+      index,
+    })),
+    model: 'mistral-embed',
+    usage: {prompt_tokens: 90, total_tokens: 90, completion_tokens: 0},
+  };
+}
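For reference, here is how the helpers compose in a standalone Jest test. This is a sketch, not part of the patch: the `embeddings({model, input})` call shape mirrors the object-style `chat()` call in `tests/client.test.js` and is an assumption about the client API, not a documented signature.

```js
import MistralClient from '../src/client';
import {mockFetch, mockEmbeddingResponsePayload} from './utils';

// Sketch: stub the global fetch with a canned three-item embeddings payload,
// then check the client hands the payload back unchanged.
test('embeddings() returns the mocked batch', async() => {
  const mockResponse = mockEmbeddingResponsePayload(3);
  globalThis.fetch = mockFetch(200, mockResponse);

  const client = new MistralClient();
  const response = await client.embeddings({
    model: 'mistral-embed', // model name taken from the mocks above
    input: ['one', 'two', 'three'], // assumed request shape
  });
  expect(response).toEqual(mockResponse);
});
```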