Skip to content

Commit

Permalink
🔐 feat: Implement Allowed Action Domains (danny-avila#4964)
Browse files Browse the repository at this point in the history
* chore: RequestExecutor typing

* feat: allowed action domains

* fix: rename TAgentsEndpoint to TAssistantEndpoint in typedefs

* chore: update librechat-data-provider version to 0.7.62
  • Loading branch information
danny-avila authored and Tsounguinzo committed Dec 14, 2024
1 parent c384d76 commit d69d79f
Show file tree
Hide file tree
Showing 18 changed files with 364 additions and 97 deletions.
4 changes: 2 additions & 2 deletions api/server/middleware/checkDomainAllowed.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
const { isDomainAllowed } = require('~/server/services/AuthService');
const { isEmailDomainAllowed } = require('~/server/services/domains');
const { logger } = require('~/config');

/**
Expand All @@ -14,7 +14,7 @@ const { logger } = require('~/config');
*/
const checkDomainAllowed = async (req, res, next = () => {}) => {
const email = req?.user?.email;
if (email && !(await isDomainAllowed(email))) {
if (email && !(await isEmailDomainAllowed(email))) {
logger.error(`[Social Login] [Social Login not allowed] [Email: ${email}]`);
return res.redirect('/login');
} else {
Expand Down
5 changes: 5 additions & 0 deletions api/server/routes/agents/actions.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ const { nanoid } = require('nanoid');
const { actionDelimiter } = require('librechat-data-provider');
const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
const { updateAction, getActions, deleteAction } = require('~/models/Action');
const { isActionDomainAllowed } = require('~/server/services/domains');
const { getAgent, updateAgent } = require('~/models/Agent');
const { logger } = require('~/config');

Expand Down Expand Up @@ -42,6 +43,10 @@ router.post('/:agent_id', async (req, res) => {
}

let metadata = await encryptMetadata(_metadata);
const isDomainAllowed = await isActionDomainAllowed(metadata.domain);
if (!isDomainAllowed) {
return res.status(400).json({ message: 'Domain not allowed' });
}

let { domain } = metadata;
domain = await domainParser(req, domain, true);
Expand Down
7 changes: 6 additions & 1 deletion api/server/routes/assistants/actions.js
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
const express = require('express');
const { nanoid } = require('nanoid');
const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
const { actionDelimiter, EModelEndpoint } = require('librechat-data-provider');
const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
const { updateAction, getActions, deleteAction } = require('~/models/Action');
const { updateAssistantDoc, getAssistant } = require('~/models/Assistant');
const { isActionDomainAllowed } = require('~/server/services/domains');
const { logger } = require('~/config');

const router = express.Router();
Expand All @@ -29,6 +30,10 @@ router.post('/:assistant_id', async (req, res) => {
}

let metadata = await encryptMetadata(_metadata);
const isDomainAllowed = await isActionDomainAllowed(metadata.domain);
if (!isDomainAllowed) {
return res.status(400).json({ message: 'Domain not allowed' });
}

let { domain } = metadata;
domain = await domainParser(req, domain, true);
Expand Down
5 changes: 5 additions & 0 deletions api/server/services/ActionService.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ const {
actionDomainSeparator,
} = require('librechat-data-provider');
const { tool } = require('@langchain/core/tools');
const { isActionDomainAllowed } = require('~/server/services/domains');
const { encryptV2, decryptV2 } = require('~/server/utils/crypto');
const { getActions, deleteActions } = require('~/models/Action');
const { deleteAssistant } = require('~/models/Assistant');
Expand Down Expand Up @@ -122,6 +123,10 @@ async function loadActionSets(searchParams) {
*/
async function createActionTool({ action, requestBuilder, zodSchema, name, description }) {
action.metadata = await decryptMetadata(action.metadata);
const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain);
if (!isDomainAllowed) {
return null;
}
/** @type {(toolInput: Object | string) => Promise<unknown>} */
const _call = async (toolInput) => {
try {
Expand Down
3 changes: 3 additions & 0 deletions api/server/services/ActionService.spec.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ const { Constants, EModelEndpoint, actionDomainSeparator } = require('librechat-
const { domainParser } = require('./ActionService');

jest.mock('keyv');
jest.mock('~/server/services/Config', () => ({
getCustomConfig: jest.fn(),
}));

const globalCache = {};
jest.mock('~/cache/getLogStores', () => {
Expand Down
5 changes: 2 additions & 3 deletions api/server/services/AuthService.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ const {
} = require('~/models/userMethods');
const { createToken, findToken, deleteTokens, Session } = require('~/models');
const { isEnabled, checkEmailConfig, sendEmail } = require('~/server/utils');
const { isEmailDomainAllowed } = require('~/server/services/domains');
const { registerSchema } = require('~/strategies/validators');
const { hashToken } = require('~/server/utils/crypto');
const isDomainAllowed = require('./isDomainAllowed');
const { logger } = require('~/config');

const domains = {
Expand Down Expand Up @@ -165,7 +165,7 @@ const registerUser = async (user, additionalData = {}) => {
return { status: 200, message: genericVerificationMessage };
}

if (!(await isDomainAllowed(email))) {
if (!(await isEmailDomainAllowed(email))) {
const errorMessage =
'The email address provided cannot be used. Please use a different email address.';
logger.error(`[registerUser] [Registration not allowed] [Email: ${user.email}]`);
Expand Down Expand Up @@ -422,7 +422,6 @@ module.exports = {
registerUser,
setAuthTokens,
resetPassword,
isDomainAllowed,
requestPasswordReset,
resendVerificationEmail,
};
13 changes: 13 additions & 0 deletions api/server/services/ToolService.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ const { tool: toolFn, Tool } = require('@langchain/core/tools');
const { Calculator } = require('@langchain/community/tools/calculator');
const {
Tools,
ErrorTypes,
ContentTypes,
imageGenTools,
actionDelimiter,
Expand Down Expand Up @@ -327,6 +328,12 @@ async function processRequiredActions(client, requiredActions) {
}

tool = await createActionTool({ action: actionSet, requestBuilder });
if (!tool) {
logger.warn(
`Invalid action: user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id} | toolName: ${currentAction.tool}`,
);
throw new Error(`{"type":"${ErrorTypes.INVALID_ACTION}"}`);
}
isActionTool = !!tool;
ActionToolMap[currentAction.tool] = tool;
}
Expand Down Expand Up @@ -464,6 +471,12 @@ async function loadAgentTools({ req, agent_id, tools, tool_resources, openAIApiK
name: toolName,
description: functionSig.description,
});
if (!tool) {
logger.warn(
`Invalid action: user: ${req.user.id} | agent_id: ${agent_id} | toolName: ${toolName}`,
);
throw new Error(`{"type":"${ErrorTypes.INVALID_ACTION}"}`);
}
agentTools.push(tool);
ActionToolMap[toolName] = tool;
}
Expand Down
109 changes: 109 additions & 0 deletions api/server/services/domains.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
const { getCustomConfig } = require('~/server/services/Config');

/**
* @param {string} email
* @returns {Promise<boolean>}
*/
async function isEmailDomainAllowed(email) {
if (!email) {
return false;
}

const domain = email.split('@')[1];

if (!domain) {
return false;
}

const customConfig = await getCustomConfig();
if (!customConfig) {
return true;
} else if (!customConfig?.registration?.allowedDomains) {
return true;
}

return customConfig.registration.allowedDomains.includes(domain);
}

/**
* Normalizes a domain string
* @param {string} domain
* @returns {string|null}
*/
/**
* Normalizes a domain string. If the domain is invalid, returns null.
* Normalized === lowercase, trimmed, and protocol added if missing.
* @param {string} domain
* @returns {string|null}
*/
function normalizeDomain(domain) {
try {
let normalizedDomain = domain.toLowerCase().trim();

// Early return for obviously invalid formats
if (normalizedDomain === 'http://' || normalizedDomain === 'https://') {
return null;
}

// If it's not already a URL, make it one
if (!normalizedDomain.startsWith('http://') && !normalizedDomain.startsWith('https://')) {
normalizedDomain = `https://${normalizedDomain}`;
}

const url = new URL(normalizedDomain);
// Additional validation that hostname isn't just protocol
if (!url.hostname || url.hostname === 'http:' || url.hostname === 'https:') {
return null;
}

return url.hostname.replace(/^www\./i, '');
} catch {
return null;
}
}

/**
* Checks if the given domain is allowed. If no restrictions are set, allows all domains.
* @param {string} [domain]
* @returns {Promise<boolean>}
*/
async function isActionDomainAllowed(domain) {
if (!domain || typeof domain !== 'string') {
return false;
}

const customConfig = await getCustomConfig();
const allowedDomains = customConfig?.actions?.allowedDomains;

if (!Array.isArray(allowedDomains) || !allowedDomains.length) {
return true;
}

const normalizedInputDomain = normalizeDomain(domain);
if (!normalizedInputDomain) {
return false;
}

for (const allowedDomain of allowedDomains) {
const normalizedAllowedDomain = normalizeDomain(allowedDomain);
if (!normalizedAllowedDomain) {
continue;
}

if (normalizedAllowedDomain.startsWith('*.')) {
const baseDomain = normalizedAllowedDomain.slice(2);
if (
normalizedInputDomain === baseDomain ||
normalizedInputDomain.endsWith(`.${baseDomain}`)
) {
return true;
}
} else if (normalizedInputDomain === normalizedAllowedDomain) {
return true;
}
}

return false;
}

module.exports = { isEmailDomainAllowed, isActionDomainAllowed };
Loading

0 comments on commit d69d79f

Please sign in to comment.