Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔐 feat: Implement Allowed Action Domains #4964

Merged
merged 4 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions api/server/middleware/checkDomainAllowed.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
const { isDomainAllowed } = require('~/server/services/AuthService');
const { isEmailDomainAllowed } = require('~/server/services/domains');
const { logger } = require('~/config');

/**
Expand All @@ -14,7 +14,7 @@ const { logger } = require('~/config');
*/
const checkDomainAllowed = async (req, res, next = () => {}) => {
const email = req?.user?.email;
if (email && !(await isDomainAllowed(email))) {
if (email && !(await isEmailDomainAllowed(email))) {
logger.error(`[Social Login] [Social Login not allowed] [Email: ${email}]`);
return res.redirect('/login');
} else {
Expand Down
5 changes: 5 additions & 0 deletions api/server/routes/agents/actions.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ const { nanoid } = require('nanoid');
const { actionDelimiter } = require('librechat-data-provider');
const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
const { updateAction, getActions, deleteAction } = require('~/models/Action');
const { isActionDomainAllowed } = require('~/server/services/domains');
const { getAgent, updateAgent } = require('~/models/Agent');
const { logger } = require('~/config');

Expand Down Expand Up @@ -42,6 +43,10 @@ router.post('/:agent_id', async (req, res) => {
}

let metadata = await encryptMetadata(_metadata);
const isDomainAllowed = await isActionDomainAllowed(metadata.domain);
if (!isDomainAllowed) {
return res.status(400).json({ message: 'Domain not allowed' });
}

let { domain } = metadata;
domain = await domainParser(req, domain, true);
Expand Down
7 changes: 6 additions & 1 deletion api/server/routes/assistants/actions.js
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
const express = require('express');
const { nanoid } = require('nanoid');
const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
const { actionDelimiter, EModelEndpoint } = require('librechat-data-provider');
const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
const { updateAction, getActions, deleteAction } = require('~/models/Action');
const { updateAssistantDoc, getAssistant } = require('~/models/Assistant');
const { isActionDomainAllowed } = require('~/server/services/domains');
const { logger } = require('~/config');

const router = express.Router();
Expand All @@ -29,6 +30,10 @@ router.post('/:assistant_id', async (req, res) => {
}

let metadata = await encryptMetadata(_metadata);
const isDomainAllowed = await isActionDomainAllowed(metadata.domain);
if (!isDomainAllowed) {
return res.status(400).json({ message: 'Domain not allowed' });
}

let { domain } = metadata;
domain = await domainParser(req, domain, true);
Expand Down
5 changes: 5 additions & 0 deletions api/server/services/ActionService.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ const {
actionDomainSeparator,
} = require('librechat-data-provider');
const { tool } = require('@langchain/core/tools');
const { isActionDomainAllowed } = require('~/server/services/domains');
const { encryptV2, decryptV2 } = require('~/server/utils/crypto');
const { getActions, deleteActions } = require('~/models/Action');
const { deleteAssistant } = require('~/models/Assistant');
Expand Down Expand Up @@ -122,6 +123,10 @@ async function loadActionSets(searchParams) {
*/
async function createActionTool({ action, requestBuilder, zodSchema, name, description }) {
action.metadata = await decryptMetadata(action.metadata);
const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain);
if (!isDomainAllowed) {
return null;
}
/** @type {(toolInput: Object | string) => Promise<unknown>} */
const _call = async (toolInput) => {
try {
Expand Down
3 changes: 3 additions & 0 deletions api/server/services/ActionService.spec.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ const { Constants, EModelEndpoint, actionDomainSeparator } = require('librechat-
const { domainParser } = require('./ActionService');

jest.mock('keyv');
jest.mock('~/server/services/Config', () => ({
getCustomConfig: jest.fn(),
}));

const globalCache = {};
jest.mock('~/cache/getLogStores', () => {
Expand Down
5 changes: 2 additions & 3 deletions api/server/services/AuthService.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ const {
} = require('~/models/userMethods');
const { createToken, findToken, deleteTokens, Session } = require('~/models');
const { isEnabled, checkEmailConfig, sendEmail } = require('~/server/utils');
const { isEmailDomainAllowed } = require('~/server/services/domains');
const { registerSchema } = require('~/strategies/validators');
const { hashToken } = require('~/server/utils/crypto');
const isDomainAllowed = require('./isDomainAllowed');
const { logger } = require('~/config');

const domains = {
Expand Down Expand Up @@ -165,7 +165,7 @@ const registerUser = async (user, additionalData = {}) => {
return { status: 200, message: genericVerificationMessage };
}

if (!(await isDomainAllowed(email))) {
if (!(await isEmailDomainAllowed(email))) {
const errorMessage =
'The email address provided cannot be used. Please use a different email address.';
logger.error(`[registerUser] [Registration not allowed] [Email: ${user.email}]`);
Expand Down Expand Up @@ -422,7 +422,6 @@ module.exports = {
registerUser,
setAuthTokens,
resetPassword,
isDomainAllowed,
requestPasswordReset,
resendVerificationEmail,
};
13 changes: 13 additions & 0 deletions api/server/services/ToolService.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ const { tool: toolFn, Tool } = require('@langchain/core/tools');
const { Calculator } = require('@langchain/community/tools/calculator');
const {
Tools,
ErrorTypes,
ContentTypes,
imageGenTools,
actionDelimiter,
Expand Down Expand Up @@ -327,6 +328,12 @@ async function processRequiredActions(client, requiredActions) {
}

tool = await createActionTool({ action: actionSet, requestBuilder });
if (!tool) {
logger.warn(
`Invalid action: user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id} | toolName: ${currentAction.tool}`,
);
throw new Error(`{"type":"${ErrorTypes.INVALID_ACTION}"}`);
}
isActionTool = !!tool;
ActionToolMap[currentAction.tool] = tool;
}
Expand Down Expand Up @@ -464,6 +471,12 @@ async function loadAgentTools({ req, agent_id, tools, tool_resources, openAIApiK
name: toolName,
description: functionSig.description,
});
if (!tool) {
logger.warn(
`Invalid action: user: ${req.user.id} | agent_id: ${agent_id} | toolName: ${toolName}`,
);
throw new Error(`{"type":"${ErrorTypes.INVALID_ACTION}"}`);
}
agentTools.push(tool);
ActionToolMap[toolName] = tool;
}
Expand Down
109 changes: 109 additions & 0 deletions api/server/services/domains.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
const { getCustomConfig } = require('~/server/services/Config');

/**
* @param {string} email
* @returns {Promise<boolean>}
*/
async function isEmailDomainAllowed(email) {
if (!email) {
return false;
}

const domain = email.split('@')[1];

if (!domain) {
return false;
}

const customConfig = await getCustomConfig();
if (!customConfig) {
return true;
} else if (!customConfig?.registration?.allowedDomains) {
return true;
}

return customConfig.registration.allowedDomains.includes(domain);
}

/**
* Normalizes a domain string
* @param {string} domain
* @returns {string|null}
*/
/**
* Normalizes a domain string. If the domain is invalid, returns null.
* Normalized === lowercase, trimmed, and protocol added if missing.
* @param {string} domain
* @returns {string|null}
*/
function normalizeDomain(domain) {
try {
let normalizedDomain = domain.toLowerCase().trim();

// Early return for obviously invalid formats
if (normalizedDomain === 'http://' || normalizedDomain === 'https://') {
return null;
}

// If it's not already a URL, make it one
if (!normalizedDomain.startsWith('http://') && !normalizedDomain.startsWith('https://')) {
normalizedDomain = `https://${normalizedDomain}`;
}

const url = new URL(normalizedDomain);
// Additional validation that hostname isn't just protocol
if (!url.hostname || url.hostname === 'http:' || url.hostname === 'https:') {
return null;
}

return url.hostname.replace(/^www\./i, '');
} catch {
return null;
}
}

/**
* Checks if the given domain is allowed. If no restrictions are set, allows all domains.
* @param {string} [domain]
* @returns {Promise<boolean>}
*/
async function isActionDomainAllowed(domain) {
if (!domain || typeof domain !== 'string') {
return false;
}

const customConfig = await getCustomConfig();
const allowedDomains = customConfig?.actions?.allowedDomains;

if (!Array.isArray(allowedDomains) || !allowedDomains.length) {
return true;
}

const normalizedInputDomain = normalizeDomain(domain);
if (!normalizedInputDomain) {
return false;
}

for (const allowedDomain of allowedDomains) {
const normalizedAllowedDomain = normalizeDomain(allowedDomain);
if (!normalizedAllowedDomain) {
continue;
}

if (normalizedAllowedDomain.startsWith('*.')) {
const baseDomain = normalizedAllowedDomain.slice(2);
if (
normalizedInputDomain === baseDomain ||
normalizedInputDomain.endsWith(`.${baseDomain}`)
) {
return true;
}
} else if (normalizedInputDomain === normalizedAllowedDomain) {
return true;
}
}

return false;
}

module.exports = { isEmailDomainAllowed, isActionDomainAllowed };
Loading
Loading