diff --git a/api/server/middleware/checkDomainAllowed.js b/api/server/middleware/checkDomainAllowed.js index 895ce99a567..f9af7558cb2 100644 --- a/api/server/middleware/checkDomainAllowed.js +++ b/api/server/middleware/checkDomainAllowed.js @@ -1,4 +1,4 @@ -const { isDomainAllowed } = require('~/server/services/AuthService'); +const { isEmailDomainAllowed } = require('~/server/services/domains'); const { logger } = require('~/config'); /** @@ -14,7 +14,7 @@ const { logger } = require('~/config'); */ const checkDomainAllowed = async (req, res, next = () => {}) => { const email = req?.user?.email; - if (email && !(await isDomainAllowed(email))) { + if (email && !(await isEmailDomainAllowed(email))) { logger.error(`[Social Login] [Social Login not allowed] [Email: ${email}]`); return res.redirect('/login'); } else { diff --git a/api/server/routes/agents/actions.js b/api/server/routes/agents/actions.js index dde3293b42a..398481b6aa2 100644 --- a/api/server/routes/agents/actions.js +++ b/api/server/routes/agents/actions.js @@ -3,6 +3,7 @@ const { nanoid } = require('nanoid'); const { actionDelimiter } = require('librechat-data-provider'); const { encryptMetadata, domainParser } = require('~/server/services/ActionService'); const { updateAction, getActions, deleteAction } = require('~/models/Action'); +const { isActionDomainAllowed } = require('~/server/services/domains'); const { getAgent, updateAgent } = require('~/models/Agent'); const { logger } = require('~/config'); @@ -42,6 +43,10 @@ router.post('/:agent_id', async (req, res) => { } let metadata = await encryptMetadata(_metadata); + const isDomainAllowed = await isActionDomainAllowed(metadata.domain); + if (!isDomainAllowed) { + return res.status(400).json({ message: 'Domain not allowed' }); + } let { domain } = metadata; domain = await domainParser(req, domain, true); diff --git a/api/server/routes/assistants/actions.js b/api/server/routes/assistants/actions.js index 1646ac0a965..c3941e91776 100644 --- a/api/server/routes/assistants/actions.js +++ b/api/server/routes/assistants/actions.js @@ -1,10 +1,11 @@ const express = require('express'); const { nanoid } = require('nanoid'); -const { encryptMetadata, domainParser } = require('~/server/services/ActionService'); const { actionDelimiter, EModelEndpoint } = require('librechat-data-provider'); +const { encryptMetadata, domainParser } = require('~/server/services/ActionService'); const { getOpenAIClient } = require('~/server/controllers/assistants/helpers'); const { updateAction, getActions, deleteAction } = require('~/models/Action'); const { updateAssistantDoc, getAssistant } = require('~/models/Assistant'); +const { isActionDomainAllowed } = require('~/server/services/domains'); const { logger } = require('~/config'); const router = express.Router(); @@ -29,6 +30,10 @@ router.post('/:assistant_id', async (req, res) => { } let metadata = await encryptMetadata(_metadata); + const isDomainAllowed = await isActionDomainAllowed(metadata.domain); + if (!isDomainAllowed) { + return res.status(400).json({ message: 'Domain not allowed' }); + } let { domain } = metadata; domain = await domainParser(req, domain, true); diff --git a/api/server/services/ActionService.js b/api/server/services/ActionService.js index ea1bcc4d23c..068e96948a9 100644 --- a/api/server/services/ActionService.js +++ b/api/server/services/ActionService.js @@ -7,6 +7,7 @@ const { actionDomainSeparator, } = require('librechat-data-provider'); const { tool } = require('@langchain/core/tools'); +const { isActionDomainAllowed } = require('~/server/services/domains'); const { encryptV2, decryptV2 } = require('~/server/utils/crypto'); const { getActions, deleteActions } = require('~/models/Action'); const { deleteAssistant } = require('~/models/Assistant'); @@ -122,6 +123,10 @@ async function loadActionSets(searchParams) { */ async function createActionTool({ action, requestBuilder, zodSchema, name, description }) { action.metadata = await decryptMetadata(action.metadata); + const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain); + if (!isDomainAllowed) { + return null; + } /** @type {(toolInput: Object | string) => Promise} */ const _call = async (toolInput) => { try { diff --git a/api/server/services/ActionService.spec.js b/api/server/services/ActionService.spec.js index a9650d60302..8f9d67a9d18 100644 --- a/api/server/services/ActionService.spec.js +++ b/api/server/services/ActionService.spec.js @@ -2,6 +2,9 @@ const { Constants, EModelEndpoint, actionDomainSeparator } = require('librechat- const { domainParser } = require('./ActionService'); jest.mock('keyv'); +jest.mock('~/server/services/Config', () => ({ + getCustomConfig: jest.fn(), +})); const globalCache = {}; jest.mock('~/cache/getLogStores', () => { diff --git a/api/server/services/AuthService.js b/api/server/services/AuthService.js index 5812dd26f99..383f00cde75 100644 --- a/api/server/services/AuthService.js +++ b/api/server/services/AuthService.js @@ -12,9 +12,9 @@ const { } = require('~/models/userMethods'); const { createToken, findToken, deleteTokens, Session } = require('~/models'); const { isEnabled, checkEmailConfig, sendEmail } = require('~/server/utils'); +const { isEmailDomainAllowed } = require('~/server/services/domains'); const { registerSchema } = require('~/strategies/validators'); const { hashToken } = require('~/server/utils/crypto'); -const isDomainAllowed = require('./isDomainAllowed'); const { logger } = require('~/config'); const domains = { @@ -165,7 +165,7 @@ const registerUser = async (user, additionalData = {}) => { return { status: 200, message: genericVerificationMessage }; } - if (!(await isDomainAllowed(email))) { + if (!(await isEmailDomainAllowed(email))) { const errorMessage = 'The email address provided cannot be used. Please use a different email address.'; logger.error(`[registerUser] [Registration not allowed] [Email: ${user.email}]`); @@ -422,7 +422,6 @@ module.exports = { registerUser, setAuthTokens, resetPassword, - isDomainAllowed, requestPasswordReset, resendVerificationEmail, }; diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index 91a5e7a6cfe..52118662443 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -5,6 +5,7 @@ const { tool: toolFn, Tool } = require('@langchain/core/tools'); const { Calculator } = require('@langchain/community/tools/calculator'); const { Tools, + ErrorTypes, ContentTypes, imageGenTools, actionDelimiter, @@ -327,6 +328,12 @@ async function processRequiredActions(client, requiredActions) { } tool = await createActionTool({ action: actionSet, requestBuilder }); + if (!tool) { + logger.warn( + `Invalid action: user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id} | toolName: ${currentAction.tool}`, + ); + throw new Error(`{"type":"${ErrorTypes.INVALID_ACTION}"}`); + } isActionTool = !!tool; ActionToolMap[currentAction.tool] = tool; } @@ -464,6 +471,12 @@ async function loadAgentTools({ req, agent_id, tools, tool_resources, openAIApiK name: toolName, description: functionSig.description, }); + if (!tool) { + logger.warn( + `Invalid action: user: ${req.user.id} | agent_id: ${agent_id} | toolName: ${toolName}`, + ); + throw new Error(`{"type":"${ErrorTypes.INVALID_ACTION}"}`); + } agentTools.push(tool); ActionToolMap[toolName] = tool; } diff --git a/api/server/services/domains.js b/api/server/services/domains.js new file mode 100644 index 00000000000..50e625c3d63 --- /dev/null +++ b/api/server/services/domains.js @@ -0,0 +1,109 @@ +const { getCustomConfig } = require('~/server/services/Config'); + +/** + * @param {string} email + * @returns {Promise} + */ +async function isEmailDomainAllowed(email) { + if (!email) { + return false; + } + + const domain = email.split('@')[1]; + + if (!domain) { + return false; + } + + const customConfig = await getCustomConfig(); + if (!customConfig) { + return true; + } else if (!customConfig?.registration?.allowedDomains) { + return true; + } + + return customConfig.registration.allowedDomains.includes(domain); +} + +/** + * Normalizes a domain string + * @param {string} domain + * @returns {string|null} + */ +/** + * Normalizes a domain string. If the domain is invalid, returns null. + * Normalized === lowercase, trimmed, and protocol added if missing. + * @param {string} domain + * @returns {string|null} + */ +function normalizeDomain(domain) { + try { + let normalizedDomain = domain.toLowerCase().trim(); + + // Early return for obviously invalid formats + if (normalizedDomain === 'http://' || normalizedDomain === 'https://') { + return null; + } + + // If it's not already a URL, make it one + if (!normalizedDomain.startsWith('http://') && !normalizedDomain.startsWith('https://')) { + normalizedDomain = `https://${normalizedDomain}`; + } + + const url = new URL(normalizedDomain); + // Additional validation that hostname isn't just protocol + if (!url.hostname || url.hostname === 'http:' || url.hostname === 'https:') { + return null; + } + + return url.hostname.replace(/^www\./i, ''); + } catch { + return null; + } +} + +/** + * Checks if the given domain is allowed. If no restrictions are set, allows all domains. + * @param {string} [domain] + * @returns {Promise} + */ +async function isActionDomainAllowed(domain) { + if (!domain || typeof domain !== 'string') { + return false; + } + + const customConfig = await getCustomConfig(); + const allowedDomains = customConfig?.actions?.allowedDomains; + + if (!Array.isArray(allowedDomains) || !allowedDomains.length) { + return true; + } + + const normalizedInputDomain = normalizeDomain(domain); + if (!normalizedInputDomain) { + return false; + } + + for (const allowedDomain of allowedDomains) { + const normalizedAllowedDomain = normalizeDomain(allowedDomain); + if (!normalizedAllowedDomain) { + continue; + } + + if (normalizedAllowedDomain.startsWith('*.')) { + const baseDomain = normalizedAllowedDomain.slice(2); + if ( + normalizedInputDomain === baseDomain || + normalizedInputDomain.endsWith(`.${baseDomain}`) + ) { + return true; + } + } else if (normalizedInputDomain === normalizedAllowedDomain) { + return true; + } + } + + return false; +} + +module.exports = { isEmailDomainAllowed, isActionDomainAllowed }; diff --git a/api/server/services/domains.spec.js b/api/server/services/domains.spec.js new file mode 100644 index 00000000000..b4537dd3753 --- /dev/null +++ b/api/server/services/domains.spec.js @@ -0,0 +1,193 @@ +const { isEmailDomainAllowed, isActionDomainAllowed } = require('~/server/services/domains'); +const { getCustomConfig } = require('~/server/services/Config'); + +jest.mock('~/server/services/Config', () => ({ + getCustomConfig: jest.fn(), +})); + +describe('isEmailDomainAllowed', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should return false if email is falsy', async () => { + const email = ''; + const result = await isEmailDomainAllowed(email); + expect(result).toBe(false); + }); + + it('should return false if domain is not present in the email', async () => { + const email = 'test'; + const result = await isEmailDomainAllowed(email); + expect(result).toBe(false); + }); + + it('should return true if customConfig is not available', async () => { + const email = 'test@domain1.com'; + getCustomConfig.mockResolvedValue(null); + const result = await isEmailDomainAllowed(email); + expect(result).toBe(true); + }); + + it('should return true if allowedDomains is not defined in customConfig', async () => { + const email = 'test@domain1.com'; + getCustomConfig.mockResolvedValue({}); + const result = await isEmailDomainAllowed(email); + expect(result).toBe(true); + }); + + it('should return true if domain is included in the allowedDomains', async () => { + const email = 'user@domain1.com'; + getCustomConfig.mockResolvedValue({ + registration: { + allowedDomains: ['domain1.com', 'domain2.com'], + }, + }); + const result = await isEmailDomainAllowed(email); + expect(result).toBe(true); + }); + + it('should return false if domain is not included in the allowedDomains', async () => { + const email = 'user@domain3.com'; + getCustomConfig.mockResolvedValue({ + registration: { + allowedDomains: ['domain1.com', 'domain2.com'], + }, + }); + const result = await isEmailDomainAllowed(email); + expect(result).toBe(false); + }); +}); + +describe('isActionDomainAllowed', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + // Basic Input Validation Tests + describe('input validation', () => { + it('should return false for falsy values', async () => { + expect(await isActionDomainAllowed()).toBe(false); + expect(await isActionDomainAllowed(null)).toBe(false); + expect(await isActionDomainAllowed('')).toBe(false); + expect(await isActionDomainAllowed(undefined)).toBe(false); + }); + + it('should return false for non-string inputs', async () => { + expect(await isActionDomainAllowed(123)).toBe(false); + expect(await isActionDomainAllowed({})).toBe(false); + expect(await isActionDomainAllowed([])).toBe(false); + }); + + it('should return false for invalid domain formats', async () => { + getCustomConfig.mockResolvedValue({ + actions: { allowedDomains: ['http://', 'https://'] }, + }); + expect(await isActionDomainAllowed('http://')).toBe(false); + expect(await isActionDomainAllowed('https://')).toBe(false); + }); + }); + + // Configuration Tests + describe('configuration handling', () => { + it('should return true if customConfig is null', async () => { + getCustomConfig.mockResolvedValue(null); + expect(await isActionDomainAllowed('example.com')).toBe(true); + }); + + it('should return true if actions.allowedDomains is not defined', async () => { + getCustomConfig.mockResolvedValue({}); + expect(await isActionDomainAllowed('example.com')).toBe(true); + }); + + it('should return true if allowedDomains is empty array', async () => { + getCustomConfig.mockResolvedValue({ + actions: { allowedDomains: [] }, + }); + expect(await isActionDomainAllowed('example.com')).toBe(true); + }); + }); + + // Domain Matching Tests + describe('domain matching', () => { + beforeEach(() => { + getCustomConfig.mockResolvedValue({ + actions: { + allowedDomains: [ + 'example.com', + '*.subdomain.com', + 'specific.domain.com', + 'www.withprefix.com', + 'swapi.dev', + ], + }, + }); + }); + + it('should match exact domains', async () => { + expect(await isActionDomainAllowed('example.com')).toBe(true); + expect(await isActionDomainAllowed('other.com')).toBe(false); + expect(await isActionDomainAllowed('swapi.dev')).toBe(true); + }); + + it('should handle domains with www prefix', async () => { + expect(await isActionDomainAllowed('www.example.com')).toBe(true); + expect(await isActionDomainAllowed('www.withprefix.com')).toBe(true); + }); + + it('should handle full URLs', async () => { + expect(await isActionDomainAllowed('https://example.com')).toBe(true); + expect(await isActionDomainAllowed('http://example.com')).toBe(true); + expect(await isActionDomainAllowed('https://example.com/path')).toBe(true); + }); + + it('should handle wildcard subdomains', async () => { + expect(await isActionDomainAllowed('test.subdomain.com')).toBe(true); + expect(await isActionDomainAllowed('any.subdomain.com')).toBe(true); + expect(await isActionDomainAllowed('subdomain.com')).toBe(true); + }); + + it('should handle specific subdomains', async () => { + expect(await isActionDomainAllowed('specific.domain.com')).toBe(true); + expect(await isActionDomainAllowed('other.domain.com')).toBe(false); + }); + }); + + // Edge Cases + describe('edge cases', () => { + beforeEach(() => { + getCustomConfig.mockResolvedValue({ + actions: { + allowedDomains: ['example.com', '*.test.com'], + }, + }); + }); + + it('should handle domains with query parameters', async () => { + expect(await isActionDomainAllowed('example.com?param=value')).toBe(true); + }); + + it('should handle domains with ports', async () => { + expect(await isActionDomainAllowed('example.com:8080')).toBe(true); + }); + + it('should handle domains with trailing slashes', async () => { + expect(await isActionDomainAllowed('example.com/')).toBe(true); + }); + + it('should handle case insensitivity', async () => { + expect(await isActionDomainAllowed('EXAMPLE.COM')).toBe(true); + expect(await isActionDomainAllowed('Example.Com')).toBe(true); + }); + + it('should handle invalid entries in allowedDomains', async () => { + getCustomConfig.mockResolvedValue({ + actions: { + allowedDomains: ['example.com', null, undefined, '', 'test.com'], + }, + }); + expect(await isActionDomainAllowed('example.com')).toBe(true); + expect(await isActionDomainAllowed('test.com')).toBe(true); + }); + }); +}); diff --git a/api/server/services/isDomainAllowed.js b/api/server/services/isDomainAllowed.js deleted file mode 100644 index 2eb6c0db247..00000000000 --- a/api/server/services/isDomainAllowed.js +++ /dev/null @@ -1,24 +0,0 @@ -const { getCustomConfig } = require('~/server/services/Config'); - -async function isDomainAllowed(email) { - if (!email) { - return false; - } - - const domain = email.split('@')[1]; - - if (!domain) { - return false; - } - - const customConfig = await getCustomConfig(); - if (!customConfig) { - return true; - } else if (!customConfig?.registration?.allowedDomains) { - return true; - } - - return customConfig.registration.allowedDomains.includes(domain); -} - -module.exports = isDomainAllowed; diff --git a/api/server/services/isDomainAllowed.spec.js b/api/server/services/isDomainAllowed.spec.js deleted file mode 100644 index 216b7d58113..00000000000 --- a/api/server/services/isDomainAllowed.spec.js +++ /dev/null @@ -1,60 +0,0 @@ -const { getCustomConfig } = require('~/server/services/Config'); -const isDomainAllowed = require('./isDomainAllowed'); - -jest.mock('~/server/services/Config', () => ({ - getCustomConfig: jest.fn(), -})); - -describe('isDomainAllowed', () => { - afterEach(() => { - jest.clearAllMocks(); - }); - - it('should return false if email is falsy', async () => { - const email = ''; - const result = await isDomainAllowed(email); - expect(result).toBe(false); - }); - - it('should return false if domain is not present in the email', async () => { - const email = 'test'; - const result = await isDomainAllowed(email); - expect(result).toBe(false); - }); - - it('should return true if customConfig is not available', async () => { - const email = 'test@domain1.com'; - getCustomConfig.mockResolvedValue(null); - const result = await isDomainAllowed(email); - expect(result).toBe(true); - }); - - it('should return true if allowedDomains is not defined in customConfig', async () => { - const email = 'test@domain1.com'; - getCustomConfig.mockResolvedValue({}); - const result = await isDomainAllowed(email); - expect(result).toBe(true); - }); - - it('should return true if domain is included in the allowedDomains', async () => { - const email = 'user@domain1.com'; - getCustomConfig.mockResolvedValue({ - registration: { - allowedDomains: ['domain1.com', 'domain2.com'], - }, - }); - const result = await isDomainAllowed(email); - expect(result).toBe(true); - }); - - it('should return false if domain is not included in the allowedDomains', async () => { - const email = 'user@domain3.com'; - getCustomConfig.mockResolvedValue({ - registration: { - allowedDomains: ['domain1.com', 'domain2.com'], - }, - }); - const result = await isDomainAllowed(email); - expect(result).toBe(false); - }); -}); diff --git a/api/typedefs.js b/api/typedefs.js index 2c799585ae2..907568f5f2d 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -819,7 +819,7 @@ */ /** - * @exports TAgentsEndpoint + * @exports TAssistantEndpoint * @typedef {import('librechat-data-provider').TAssistantEndpoint} TAssistantEndpoint * @memberof typedefs */ diff --git a/client/src/components/Messages/Content/Error.tsx b/client/src/components/Messages/Content/Error.tsx index b1ab7980404..b33169813f2 100644 --- a/client/src/components/Messages/Content/Error.tsx +++ b/client/src/components/Messages/Content/Error.tsx @@ -42,6 +42,7 @@ const errorMessages = { [ErrorTypes.NO_USER_KEY]: 'com_error_no_user_key', [ErrorTypes.INVALID_USER_KEY]: 'com_error_invalid_user_key', [ErrorTypes.NO_BASE_URL]: 'com_error_no_base_url', + [ErrorTypes.INVALID_ACTION]: `com_error_${ErrorTypes.INVALID_ACTION}`, [ErrorTypes.INVALID_REQUEST]: `com_error_${ErrorTypes.INVALID_REQUEST}`, [ErrorTypes.NO_SYSTEM_MESSAGES]: `com_error_${ErrorTypes.NO_SYSTEM_MESSAGES}`, [ErrorTypes.EXPIRED_USER_KEY]: (json: TExpiredKey, localize: LocalizeFunction) => { diff --git a/client/src/localization/languages/Eng.ts b/client/src/localization/languages/Eng.ts index d29d8b6a75b..545e926ada1 100644 --- a/client/src/localization/languages/Eng.ts +++ b/client/src/localization/languages/Eng.ts @@ -30,6 +30,7 @@ export default { 'Resubmitting the AI message is not supported for this endpoint.', com_error_invalid_request_error: 'The AI service rejected the request due to an error. This could be caused by an invalid API key or an improperly formatted request.', + com_error_invalid_action_error: 'Request denied: The specified action domain is not allowed.', com_error_no_system_messages: 'The selected AI service or model does not support system messages. Try using prompts instead of custom instructions.', com_error_invalid_user_key: 'Invalid key provided. Please provide a valid key and try again.', diff --git a/package-lock.json b/package-lock.json index 7beb44c32e2..758ec39823c 100644 --- a/package-lock.json +++ b/package-lock.json @@ -36153,7 +36153,7 @@ }, "packages/data-provider": { "name": "librechat-data-provider", - "version": "0.7.61", + "version": "0.7.62", "license": "ISC", "dependencies": { "@types/js-yaml": "^4.0.9", diff --git a/packages/data-provider/package.json b/packages/data-provider/package.json index 65c45ebfcc6..ba379268445 100644 --- a/packages/data-provider/package.json +++ b/packages/data-provider/package.json @@ -1,6 +1,6 @@ { "name": "librechat-data-provider", - "version": "0.7.61", + "version": "0.7.62", "description": "data services for librechat apps", "main": "dist/index.js", "module": "dist/index.es.js", diff --git a/packages/data-provider/src/actions.ts b/packages/data-provider/src/actions.ts index 386ca34e74f..d4fafdb718e 100644 --- a/packages/data-provider/src/actions.ts +++ b/packages/data-provider/src/actions.ts @@ -201,15 +201,21 @@ class RequestExecutor { oauth_client_secret, } = metadata; - const isApiKey = api_key && type === AuthTypeEnum.ServiceHttp; - const isOAuth = + const isApiKey = api_key != null && api_key.length > 0 && type === AuthTypeEnum.ServiceHttp; + const isOAuth = !!( + oauth_client_id != null && oauth_client_id && + oauth_client_secret != null && oauth_client_secret && type === AuthTypeEnum.OAuth && + authorization_url != null && authorization_url && + client_url != null && client_url && + scope != null && scope && - token_exchange_method; + token_exchange_method + ); if (isApiKey && authorization_type === AuthorizationTypeEnum.Basic) { const basicToken = Buffer.from(api_key).toString('base64'); @@ -219,11 +225,13 @@ class RequestExecutor { } else if ( isApiKey && authorization_type === AuthorizationTypeEnum.Custom && + custom_auth_header != null && custom_auth_header ) { this.authHeaders[custom_auth_header] = api_key; } else if (isOAuth) { - if (!this.authToken) { + const authToken = this.authToken ?? ''; + if (!authToken) { const tokenResponse = await axios.post( client_url, { diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts index b3c2cb136f2..506fe6b3fd6 100644 --- a/packages/data-provider/src/config.ts +++ b/packages/data-provider/src/config.ts @@ -471,6 +471,11 @@ export const configSchema = z.object({ agents: true, }), fileStrategy: fileSourceSchema.default(FileSources.local), + actions: z + .object({ + allowedDomains: z.array(z.string()).optional(), + }) + .optional(), registration: z .object({ socialLogins: z.array(z.string()).optional(), @@ -962,6 +967,10 @@ export enum ErrorTypes { * Invalid request error, API rejected request */ INVALID_REQUEST = 'invalid_request_error', + /** + * Invalid action request error, likely not on list of allowed domains + */ + INVALID_ACTION = 'invalid_action_error', /** * Invalid request error, API rejected request */