Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🤖 feat: Support Google Agents, fix Various Provider Configurations #5126

Merged
merged 15 commits into from
Dec 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions api/app/clients/BaseClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -649,15 +649,17 @@ class BaseClient {

this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
this.savedMessageIds.add(responseMessage.messageId);
const messageCache = getLogStores(CacheKeys.MESSAGES);
messageCache.set(
responseMessageId,
{
text: responseMessage.text,
complete: true,
},
Time.FIVE_MINUTES,
);
if (responseMessage.text) {
const messageCache = getLogStores(CacheKeys.MESSAGES);
messageCache.set(
responseMessageId,
{
text: responseMessage.text,
complete: true,
},
Time.FIVE_MINUTES,
);
}
delete responseMessage.tokenCount;
return responseMessage;
}
Expand Down
20 changes: 11 additions & 9 deletions api/app/clients/PluginsClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -256,15 +256,17 @@ class PluginsClient extends OpenAIClient {
}

this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
const messageCache = getLogStores(CacheKeys.MESSAGES);
messageCache.set(
responseMessage.messageId,
{
text: responseMessage.text,
complete: true,
},
Time.FIVE_MINUTES,
);
if (responseMessage.text) {
const messageCache = getLogStores(CacheKeys.MESSAGES);
messageCache.set(
responseMessage.messageId,
{
text: responseMessage.text,
complete: true,
},
Time.FIVE_MINUTES,
);
}
delete responseMessage.tokenCount;
return { ...responseMessage, ...result };
}
Expand Down
3 changes: 3 additions & 0 deletions api/models/Transaction.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ transactionSchema.methods.calculateTokenValue = function () {
*/
transactionSchema.statics.create = async function (txData) {
const Transaction = this;
if (txData.rawAmount != null && isNaN(txData.rawAmount)) {
return;
}

const transaction = new Transaction(txData);
transaction.endpointTokenConfig = txData.endpointTokenConfig;
Expand Down
26 changes: 26 additions & 0 deletions api/models/Transaction.spec.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { Transaction } = require('./Transaction');
const Balance = require('./Balance');
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
const { getMultiplier, getCacheMultiplier } = require('./tx');
Expand Down Expand Up @@ -346,3 +347,28 @@ describe('Structured Token Spending Tests', () => {
expect(result.completion.completion).toBeCloseTo(-50 * 15 * 1.15, 0); // Assuming multiplier is 15 and cancelRate is 1.15
});
});

describe('NaN Handling Tests', () => {
test('should skip transaction creation when rawAmount is NaN', async () => {
const userId = new mongoose.Types.ObjectId();
const initialBalance = 10000000;
await Balance.create({ user: userId, tokenCredits: initialBalance });

const model = 'gpt-3.5-turbo';
const txData = {
user: userId,
conversationId: 'test-conversation-id',
model,
context: 'test',
endpointTokenConfig: null,
rawAmount: NaN,
tokenType: 'prompt',
};

const result = await Transaction.create(txData);
expect(result).toBeUndefined();

const balance = await Balance.findOne({ user: userId });
expect(balance.tokenCredits).toBe(initialBalance);
});
});
2 changes: 1 addition & 1 deletion api/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
"@langchain/google-genai": "^0.1.4",
"@langchain/google-vertexai": "^0.1.4",
"@langchain/textsplitters": "^0.1.0",
"@librechat/agents": "^1.8.8",
"@librechat/agents": "^1.9.7",
"axios": "^1.7.7",
"bcryptjs": "^2.4.3",
"cheerio": "^1.0.0-rc.12",
Expand Down
21 changes: 16 additions & 5 deletions api/server/controllers/agents/callbacks.js
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
const { Tools, StepTypes, imageGenTools, FileContext } = require('librechat-data-provider');
const {
EnvVar,
Providers,
GraphEvents,
ToolEndHandler,
handleToolCalls,
ChatModelStreamHandler,
} = require('@librechat/agents');
const { processCodeOutput } = require('~/server/services/Files/Code/process');
Expand Down Expand Up @@ -57,13 +59,22 @@ class ModelEndHandler {
return;
}

const usage = data?.output?.usage_metadata;
if (metadata?.model) {
usage.model = metadata.model;
}
try {
if (metadata.provider === Providers.GOOGLE) {
handleToolCalls(data?.output?.tool_calls, metadata, graph);
}

const usage = data?.output?.usage_metadata;
if (!usage) {
return;
}
if (metadata?.model) {
usage.model = metadata.model;
}

if (usage) {
this.collectedUsage.push(usage);
} catch (error) {
logger.error('Error handling model end event:', error);
}
}
}
Expand Down
20 changes: 11 additions & 9 deletions api/server/controllers/assistants/chatV2.js
Original file line number Diff line number Diff line change
Expand Up @@ -398,15 +398,17 @@ const chatV2 = async (req, res) => {
response = streamRunManager;
response.text = streamRunManager.intermediateText;

const messageCache = getLogStores(CacheKeys.MESSAGES);
messageCache.set(
responseMessageId,
{
complete: true,
text: response.text,
},
Time.FIVE_MINUTES,
);
if (response.text) {
const messageCache = getLogStores(CacheKeys.MESSAGES);
messageCache.set(
responseMessageId,
{
complete: true,
text: response.text,
},
Time.FIVE_MINUTES,
);
}
};

await processRun();
Expand Down
6 changes: 6 additions & 0 deletions api/server/services/Endpoints/agents/initialize.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ const initAnthropic = require('~/server/services/Endpoints/anthropic/initialize'
const getBedrockOptions = require('~/server/services/Endpoints/bedrock/options');
const initOpenAI = require('~/server/services/Endpoints/openAI/initialize');
const initCustom = require('~/server/services/Endpoints/custom/initialize');
const initGoogle = require('~/server/services/Endpoints/google/initialize');
const { getCustomEndpointConfig } = require('~/server/services/Config');
const { loadAgentTools } = require('~/server/services/ToolService');
const AgentClient = require('~/server/controllers/agents/client');
Expand All @@ -24,6 +25,7 @@ const providerConfigMap = {
[EModelEndpoint.azureOpenAI]: initOpenAI,
[EModelEndpoint.anthropic]: initAnthropic,
[EModelEndpoint.bedrock]: getBedrockOptions,
[EModelEndpoint.google]: initGoogle,
[Providers.OLLAMA]: initCustom,
};

Expand Down Expand Up @@ -116,6 +118,10 @@ const initializeAgentOptions = async ({
endpointOption: _endpointOption,
});

if (options.provider != null) {
agent.provider = options.provider;
}

agent.model_parameters = Object.assign(model_parameters, options.llmConfig);
if (options.configOptions) {
agent.model_parameters.configuration = options.configOptions;
Expand Down
8 changes: 4 additions & 4 deletions api/server/services/Endpoints/anthropic/initialize.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
checkUserKeyExpiry(expiresAt, EModelEndpoint.anthropic);
}

const clientOptions = {};
let clientOptions = {};

/** @type {undefined | TBaseEndpoint} */
const anthropicConfig = req.app.locals[EModelEndpoint.anthropic];
Expand All @@ -36,7 +36,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
}

if (optionsOnly) {
const requestOptions = Object.assign(
clientOptions = Object.assign(
{
reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null,
proxy: PROXY ?? null,
Expand All @@ -45,9 +45,9 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
clientOptions,
);
if (overrideModel) {
requestOptions.modelOptions.model = overrideModel;
clientOptions.modelOptions.model = overrideModel;
}
return getLLMConfig(anthropicApiKey, requestOptions);
return getLLMConfig(anthropicApiKey, clientOptions);
}

const client = new AnthropicClient(anthropicApiKey, {
Expand Down
20 changes: 12 additions & 8 deletions api/server/services/Endpoints/anthropic/llm.js
Original file line number Diff line number Diff line change
Expand Up @@ -28,28 +28,32 @@ function getLLMConfig(apiKey, options = {}) {

const mergedOptions = Object.assign(defaultOptions, options.modelOptions);

/** @type {AnthropicClientOptions} */
const requestOptions = {
apiKey,
model: mergedOptions.model,
stream: mergedOptions.stream,
temperature: mergedOptions.temperature,
top_p: mergedOptions.topP,
top_k: mergedOptions.topK,
stop_sequences: mergedOptions.stop,
max_tokens:
topP: mergedOptions.topP,
topK: mergedOptions.topK,
stopSequences: mergedOptions.stop,
maxTokens:
mergedOptions.maxOutputTokens || anthropicSettings.maxOutputTokens.reset(mergedOptions.model),
clientOptions: {},
};

const configOptions = {};
if (options.proxy) {
configOptions.httpAgent = new HttpsProxyAgent(options.proxy);
requestOptions.clientOptions.httpAgent = new HttpsProxyAgent(options.proxy);
}

if (options.reverseProxyUrl) {
configOptions.baseURL = options.reverseProxyUrl;
requestOptions.clientOptions.baseURL = options.reverseProxyUrl;
}

return { llmConfig: removeNullishValues(requestOptions), configOptions };
return {
/** @type {AnthropicClientOptions} */
llmConfig: removeNullishValues(requestOptions),
};
}

module.exports = { getLLMConfig };
47 changes: 23 additions & 24 deletions api/server/services/Endpoints/bedrock/options.js
Original file line number Diff line number Diff line change
Expand Up @@ -60,42 +60,41 @@ const getOptions = async ({ req, endpointOption }) => {
streamRate = allConfig.streamRate;
}

/** @type {import('@librechat/agents').BedrockConverseClientOptions} */
const requestOptions = Object.assign(
{
model: endpointOption.model,
region: BEDROCK_AWS_DEFAULT_REGION,
streaming: true,
streamUsage: true,
callbacks: [
{
handleLLMNewToken: async () => {
if (!streamRate) {
return;
}
await sleep(streamRate);
},
/** @type {BedrockClientOptions} */
const requestOptions = {
model: endpointOption.model,
region: BEDROCK_AWS_DEFAULT_REGION,
streaming: true,
streamUsage: true,
callbacks: [
{
handleLLMNewToken: async () => {
if (!streamRate) {
return;
}
await sleep(streamRate);
},
],
},
endpointOption.model_parameters,
);
},
],
};

if (credentials) {
requestOptions.credentials = credentials;
}

if (BEDROCK_REVERSE_PROXY) {
requestOptions.endpointHost = BEDROCK_REVERSE_PROXY;
}

const configOptions = {};
if (PROXY) {
/** NOTE: NOT SUPPORTED BY BEDROCK */
configOptions.httpAgent = new HttpsProxyAgent(PROXY);
}

if (BEDROCK_REVERSE_PROXY) {
configOptions.endpointHost = BEDROCK_REVERSE_PROXY;
}

return {
llmConfig: removeNullishValues(requestOptions),
/** @type {BedrockClientOptions} */
llmConfig: removeNullishValues(Object.assign(requestOptions, endpointOption.model_parameters)),
configOptions,
};
};
Expand Down
6 changes: 3 additions & 3 deletions api/server/services/Endpoints/custom/initialize.js
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
customOptions.streamRate = allConfig.streamRate;
}

const clientOptions = {
let clientOptions = {
reverseProxyUrl: baseURL ?? null,
proxy: PROXY ?? null,
req,
Expand All @@ -135,13 +135,13 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
if (optionsOnly) {
const modelOptions = endpointOption.model_parameters;
if (endpoint !== Providers.OLLAMA) {
const requestOptions = Object.assign(
clientOptions = Object.assign(
{
modelOptions,
},
clientOptions,
);
const options = getLLMConfig(apiKey, requestOptions);
const options = getLLMConfig(apiKey, clientOptions);
if (!customOptions.streamRate) {
return options;
}
Expand Down
Loading
Loading