Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

In-context document support for Anthropic and Google models #5130

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
31 changes: 22 additions & 9 deletions api/app/clients/AnthropicClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ const {
getResponseSender,
validateVisionModel,
} = require('librechat-data-provider');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { encodeAndFormat } = require('~/server/services/Files/encode');
const {
truncateText,
formatMessage,
Expand Down Expand Up @@ -281,13 +281,13 @@ class AnthropicClient extends BaseClient {
return Math.ceil((width * height) / 750);
}

async addImageURLs(message, attachments) {
const { files, image_urls } = await encodeAndFormat(
async addFileURLs(message, attachments) {
const { files, file_urls } = await encodeAndFormat(
this.options.req,
attachments,
EModelEndpoint.anthropic,
);
message.image_urls = image_urls.length ? image_urls : undefined;
message.image_urls = file_urls.length ? file_urls : undefined;
return files;
}

Expand Down Expand Up @@ -346,10 +346,16 @@ class AnthropicClient extends BaseClient {
if (this.options.attachments) {
const attachments = await this.options.attachments;
const images = attachments.filter((file) => file.type.includes('image'));
const documents = attachments.filter((file) => file.type == 'application/pdf');

if (images.length && !this.isVisionModel) {
throw new Error('Images are only supported with the Claude 3 family of models');
}
if (documents.length && !this.modelOptions.model.includes('3-5-sonnet')) {
throw new Error(
'PDF documents are only supported with the Claude 3.5 Sonnet family of models',
);
}

const latestMessage = orderedMessages[orderedMessages.length - 1];

Expand All @@ -361,7 +367,7 @@ class AnthropicClient extends BaseClient {
};
}

const files = await this.addImageURLs(latestMessage, attachments);
const files = await this.addFileURLs(latestMessage, attachments);

this.options.attachments = files;
}
Expand Down Expand Up @@ -399,10 +405,17 @@ class AnthropicClient extends BaseClient {
continue;
}

orderedMessages[i].tokenCount += this.calculateImageTokenCost({
width: file.width,
height: file.height,
});
if (file.type.includes('image')) {
orderedMessages[i].tokenCount += this.calculateImageTokenCost({
width: file.width,
height: file.height,
});
} else {
// File is a pdf.
// A reasonable estimate is 1500-3000 tokens per page
// without parsing the pdf to get the page count, assume it has one.
orderedMessages[i].tokenCount += 2000;
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion api/app/clients/BaseClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -1031,7 +1031,7 @@ class BaseClient {
file_id: { $in: fileIds },
});

await this.addImageURLs(message, files);
await this.addFileURLs(message, files);

this.message_file_map[message.messageId] = files;
return message;
Expand Down
2 changes: 1 addition & 1 deletion api/app/clients/ChatGPTClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ ${botMessage.message}
};
}

const files = await this.addImageURLs(lastMessage, attachments);
const files = await this.addFileURLs(lastMessage, attachments);
this.options.attachments = files;

this.contextHandlers = createContextHandlers(this.options.req, lastMessage.text);
Expand Down
31 changes: 15 additions & 16 deletions api/app/clients/GoogleClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ const {
Constants,
AuthKeys,
} = require('librechat-data-provider');
const { encodeAndFormat } = require('~/server/services/Files/images');
const { encodeAndFormat } = require('~/server/services/Files/encode');
const Tokenizer = require('~/server/services/Tokenizer');
const { getModelMaxTokens } = require('~/utils');
const { sleep } = require('~/server/utils');
Expand Down Expand Up @@ -250,25 +250,24 @@ class GoogleClient extends BaseClient {
const formattedMessages = [];
const attachments = await this.options.attachments;
const latestMessage = { ...messages[messages.length - 1] };
const files = await this.addImageURLs(latestMessage, attachments, VisionModes.generative);
const files = await this.addFileURLs(latestMessage, attachments, VisionModes.generative);
this.options.attachments = files;
messages[messages.length - 1] = latestMessage;

for (const _message of messages) {
const role = _message.isCreatedByUser ? this.userLabel : this.modelLabel;
const parts = [];
parts.push({ text: _message.text });
if (!_message.image_urls?.length) {
formattedMessages.push({ role, parts });
continue;
}

for (const images of _message.image_urls) {
if (images.inlineData) {
parts.push({ inlineData: images.inlineData });
if (_message.file_urls?.length) {
for (const fileParts of _message.file_urls) {
if (fileParts.inlineData) {
parts.push({ inlineData: fileParts.inlineData });
}
}
}

parts.push({ text: _message.text });

formattedMessages.push({ role, parts });
}

Expand All @@ -277,20 +276,20 @@ class GoogleClient extends BaseClient {

/**
*
* Adds image URLs to the message object and returns the files
* Adds file URLs to the message object and returns the files
*
* @param {TMessage[]} messages
* @param {MongoFile[]} files
* @returns {Promise<MongoFile[]>}
*/
async addImageURLs(message, attachments, mode = '') {
const { files, image_urls } = await encodeAndFormat(
async addFileURLs(message, attachments, mode = '') {
const { files, file_urls } = await encodeAndFormat(
this.options.req,
attachments,
EModelEndpoint.google,
mode,
);
message.image_urls = image_urls.length ? image_urls : undefined;
message.file_urls = file_urls.length ? file_urls : undefined;
return files;
}

Expand Down Expand Up @@ -324,7 +323,7 @@ class GoogleClient extends BaseClient {

const { prompt } = await this.buildMessagesPrompt(messages, parentMessageId);

const files = await this.addImageURLs(latestMessage, attachments);
const files = await this.addFileURLs(latestMessage, attachments);

this.options.attachments = files;

Expand Down Expand Up @@ -833,7 +832,7 @@ class GoogleClient extends BaseClient {
text: `Please generate ${titleInstruction}

${convo}

||>Title:`,
isCreatedByUser: true,
author: this.userLabel,
Expand Down
10 changes: 5 additions & 5 deletions api/app/clients/OpenAIClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ const {
titleInstruction,
createContextHandlers,
} = require('./prompts');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { encodeAndFormat } = require('~/server/services/Files/encode');
const Tokenizer = require('~/server/services/Tokenizer');
const { spendTokens } = require('~/models/spendTokens');
const { isEnabled, sleep } = require('~/server/utils');
Expand Down Expand Up @@ -369,13 +369,13 @@ class OpenAIClient extends BaseClient {
* @param {MongoFile[]} files
* @returns {Promise<MongoFile[]>}
*/
async addImageURLs(message, attachments) {
const { files, image_urls } = await encodeAndFormat(
async addFileURLs(message, attachments) {
const { files, file_urls } = await encodeAndFormat(
this.options.req,
attachments,
this.options.endpoint,
);
message.image_urls = image_urls.length ? image_urls : undefined;
message.image_urls = file_urls.length ? file_urls : undefined;
return files;
}

Expand Down Expand Up @@ -418,7 +418,7 @@ class OpenAIClient extends BaseClient {
};
}

const files = await this.addImageURLs(
const files = await this.addFileURLs(
orderedMessages[orderedMessages.length - 1],
attachments,
);
Expand Down
1 change: 1 addition & 0 deletions api/server/controllers/AskController.js
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
userMessage.files = client.options.attachments;
conversation.model = endpointOption.modelOptions.model;
delete userMessage.image_urls;
delete userMessage.file_urls;
}

if (!abortController.signal.aborted) {
Expand Down
10 changes: 5 additions & 5 deletions api/server/controllers/agents/client.js
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const {
formatContentStrings,
createContextHandlers,
} = require('~/app/clients/prompts');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { encodeAndFormat } = require('~/server/services/Files/encode');
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
const Tokenizer = require('~/server/services/Tokenizer');
const { spendTokens } = require('~/models/spendTokens');
Expand Down Expand Up @@ -208,14 +208,14 @@ class AgentClient extends BaseClient {
};
}

async addImageURLs(message, attachments) {
const { files, image_urls } = await encodeAndFormat(
async addFileURLs(message, attachments) {
const { files, file_urls } = await encodeAndFormat(
this.options.req,
attachments,
this.options.agent.provider,
VisionModes.agents,
);
message.image_urls = image_urls.length ? image_urls : undefined;
message.image_urls = file_urls.length ? file_urls : undefined;
return files;
}

Expand Down Expand Up @@ -270,7 +270,7 @@ class AgentClient extends BaseClient {
};
}

const files = await this.addImageURLs(
const files = await this.addFileURLs(
orderedMessages[orderedMessages.length - 1],
attachments,
);
Expand Down
1 change: 1 addition & 0 deletions api/server/controllers/agents/request.js
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
}
}
delete userMessage.image_urls;
delete userMessage.file_urls;
}

if (!abortController.signal.aborted) {
Expand Down
2 changes: 1 addition & 1 deletion api/server/controllers/assistants/chatV1.js
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ const chatV1 = async (req, res) => {
role: 'user',
content: '',
};
const files = await client.addImageURLs(visionMessage, attachments);
const files = await client.addFileURLs(visionMessage, attachments);
if (!visionMessage.image_urls?.length) {
return;
}
Expand Down
24 changes: 14 additions & 10 deletions api/server/services/Config/loadAsyncEndpoints.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
const { EModelEndpoint } = require('librechat-data-provider');
const { EModelEndpoint, BaseCapabilities } = require('librechat-data-provider');
const { addOpenAPISpecs } = require('~/app/clients/tools/util/addOpenAPISpecs');
const { availableTools } = require('~/app/clients/tools');
const { isUserProvided } = require('~/server/utils');
Expand Down Expand Up @@ -37,21 +37,25 @@ async function loadAsyncEndpoints(req) {
}
const plugins = transformToolsToMap(tools);

const google = serviceKey || googleKey ? { userProvide: googleUserProvides } : false;
const google =
serviceKey || googleKey
? { userProvide: googleUserProvides, capabilities: [BaseCapabilities.file_search] }
: false;

const useAzure = req.app.locals[EModelEndpoint.azureOpenAI]?.plugins;
const gptPlugins =
useAzure || openAIApiKey || azureOpenAIApiKey
? {
plugins,
availableAgents: ['classic', 'functions'],
userProvide: useAzure ? false : userProvidedOpenAI,
userProvideURL: useAzure
? false
: config[EModelEndpoint.openAI]?.userProvideURL ||
plugins,
availableAgents: ['classic', 'functions'],
userProvide: useAzure ? false : userProvidedOpenAI,
userProvideURL: useAzure
? false
: config[EModelEndpoint.openAI]?.userProvideURL ||
config[EModelEndpoint.azureOpenAI]?.userProvideURL,
azure: useAzurePlugins || useAzure,
}
azure: useAzurePlugins || useAzure,
capabilities: [BaseCapabilities.file_search],
}
: false;

return { google, gptPlugins };
Expand Down
10 changes: 4 additions & 6 deletions api/server/services/Files/Local/images.js
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,11 @@ function encodeImage(imagePath) {
* @returns {Promise<[MongoFile, string]>} - A promise that resolves to an array of results from updateFile and encodeImage.
*/
async function prepareImagesLocal(req, file) {
const { publicPath, imageOutput } = req.app.locals.paths;
const userPath = path.join(imageOutput, req.user.id);
const { publicPath, root } = req.app.locals.paths;

if (!fs.existsSync(userPath)) {
fs.mkdirSync(userPath, { recursive: true });
}
const filepath = path.join(publicPath, file.filepath);
const startPath = file.width ? publicPath : root;

const filepath = path.join(startPath, file.filepath);

const promises = [];
promises.push(updateFile({ file_id: file.file_id }));
Expand Down
Loading
Loading