diff --git a/api/cache/getLogStores.js b/api/cache/getLogStores.js index b1385f10878..9a7282e25ae 100644 --- a/api/cache/getLogStores.js +++ b/api/cache/getLogStores.js @@ -25,6 +25,10 @@ const config = isEnabled(USE_REDIS) ? new Keyv({ store: keyvRedis }) : new Keyv({ namespace: CacheKeys.CONFIG_STORE }); +const roles = isEnabled(USE_REDIS) + ? new Keyv({ store: keyvRedis }) + : new Keyv({ namespace: CacheKeys.ROLES }); + const audioRuns = isEnabled(USE_REDIS) // ttl: 30 minutes ? new Keyv({ store: keyvRedis, ttl: TEN_MINUTES }) : new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: TEN_MINUTES }); @@ -46,6 +50,7 @@ const abortKeys = isEnabled(USE_REDIS) : new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: 600000 }); const namespaces = { + [CacheKeys.ROLES]: roles, [CacheKeys.CONFIG_STORE]: config, pending_req, [ViolationTypes.BAN]: new Keyv({ store: keyvMongo, namespace: CacheKeys.BANS, ttl: duration }), diff --git a/api/models/Categories.js b/api/models/Categories.js new file mode 100644 index 00000000000..fc2cbdd98b0 --- /dev/null +++ b/api/models/Categories.js @@ -0,0 +1,61 @@ +const { logger } = require('~/config'); +// const { Categories } = require('./schema/categories'); +const options = [ + { + label: '', + value: '', + }, + { + label: 'idea', + value: 'idea', + }, + { + label: 'travel', + value: 'travel', + }, + { + label: 'teach_or_explain', + value: 'teach_or_explain', + }, + { + label: 'write', + value: 'write', + }, + { + label: 'shop', + value: 'shop', + }, + { + label: 'code', + value: 'code', + }, + { + label: 'misc', + value: 'misc', + }, + { + label: 'roleplay', + value: 'roleplay', + }, + { + label: 'finance', + value: 'finance', + }, +]; + +module.exports = { + /** + * Retrieves the categories asynchronously. + * @returns {Promise} An array of category objects. + * @throws {Error} If there is an error retrieving the categories. + */ + getCategories: async () => { + try { + // const categories = await Categories.find(); + return options; + } catch (error) { + logger.error('Error getting categories', error); + return []; + } + }, +}; diff --git a/api/models/Project.js b/api/models/Project.js new file mode 100644 index 00000000000..e982e34b5d6 --- /dev/null +++ b/api/models/Project.js @@ -0,0 +1,90 @@ +const { model } = require('mongoose'); +const projectSchema = require('~/models/schema/projectSchema'); + +const Project = model('Project', projectSchema); + +/** + * Retrieve a project by ID and convert the found project document to a plain object. + * + * @param {string} projectId - The ID of the project to find and return as a plain object. + * @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document. + * @returns {Promise} A plain object representing the project document, or `null` if no project is found. + */ +const getProjectById = async function (projectId, fieldsToSelect = null) { + const query = Project.findById(projectId); + + if (fieldsToSelect) { + query.select(fieldsToSelect); + } + + return await query.lean(); +}; + +/** + * Retrieve a project by name and convert the found project document to a plain object. + * If the project with the given name doesn't exist and the name is "instance", create it and return the lean version. + * + * @param {string} projectName - The name of the project to find or create. + * @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document. + * @returns {Promise} A plain object representing the project document. + */ +const getProjectByName = async function (projectName, fieldsToSelect = null) { + const query = { name: projectName }; + const update = { $setOnInsert: { name: projectName } }; + const options = { + new: true, + upsert: projectName === 'instance', + lean: true, + select: fieldsToSelect, + }; + + return await Project.findOneAndUpdate(query, update, options); +}; + +/** + * Add an array of prompt group IDs to a project's promptGroupIds array, ensuring uniqueness. + * + * @param {string} projectId - The ID of the project to update. + * @param {string[]} promptGroupIds - The array of prompt group IDs to add to the project. + * @returns {Promise} The updated project document. + */ +const addGroupIdsToProject = async function (projectId, promptGroupIds) { + return await Project.findByIdAndUpdate( + projectId, + { $addToSet: { promptGroupIds: { $each: promptGroupIds } } }, + { new: true }, + ); +}; + +/** + * Remove an array of prompt group IDs from a project's promptGroupIds array. + * + * @param {string} projectId - The ID of the project to update. + * @param {string[]} promptGroupIds - The array of prompt group IDs to remove from the project. + * @returns {Promise} The updated project document. + */ +const removeGroupIdsFromProject = async function (projectId, promptGroupIds) { + return await Project.findByIdAndUpdate( + projectId, + { $pull: { promptGroupIds: { $in: promptGroupIds } } }, + { new: true }, + ); +}; + +/** + * Remove a prompt group ID from all projects. + * + * @param {string} promptGroupId - The ID of the prompt group to remove from projects. + * @returns {Promise} + */ +const removeGroupFromAllProjects = async (promptGroupId) => { + await Project.updateMany({}, { $pull: { promptGroupIds: promptGroupId } }); +}; + +module.exports = { + getProjectById, + getProjectByName, + addGroupIdsToProject, + removeGroupIdsFromProject, + removeGroupFromAllProjects, +}; diff --git a/api/models/Prompt.js b/api/models/Prompt.js index f2759472b66..fbc661addf2 100644 --- a/api/models/Prompt.js +++ b/api/models/Prompt.js @@ -1,52 +1,435 @@ -const mongoose = require('mongoose'); +const { ObjectId } = require('mongodb'); +const { SystemRoles, SystemCategories } = require('librechat-data-provider'); +const { + getProjectByName, + addGroupIdsToProject, + removeGroupIdsFromProject, + removeGroupFromAllProjects, +} = require('./Project'); +const { Prompt, PromptGroup } = require('./schema/promptSchema'); const { logger } = require('~/config'); -const promptSchema = mongoose.Schema( - { - title: { - type: String, - required: true, +/** + * Create a pipeline for the aggregation to get prompt groups + * @param {Object} query + * @param {number} skip + * @param {number} limit + * @returns {[Object]} - The pipeline for the aggregation + */ +const createGroupPipeline = (query, skip, limit) => { + return [ + { $match: query }, + { $sort: { createdAt: -1 } }, + { $skip: skip }, + { $limit: limit }, + { + $lookup: { + from: 'prompts', + localField: 'productionId', + foreignField: '_id', + as: 'productionPrompt', + }, }, - prompt: { - type: String, - required: true, + { $unwind: { path: '$productionPrompt', preserveNullAndEmptyArrays: true } }, + { + $project: { + name: 1, + numberOfGenerations: 1, + oneliner: 1, + category: 1, + projectIds: 1, + productionId: 1, + author: 1, + authorName: 1, + createdAt: 1, + updatedAt: 1, + 'productionPrompt.prompt': 1, + // 'productionPrompt._id': 1, + // 'productionPrompt.type': 1, + }, }, - category: { - type: String, - }, - }, - { timestamps: true }, -); + ]; +}; + +/** + * Get prompt groups with filters + * @param {Object} req + * @param {TPromptGroupsWithFilterRequest} filter + * @returns {Promise} + */ +const getPromptGroups = async (req, filter) => { + try { + const { pageNumber = 1, pageSize = 10, name, ...query } = filter; + if (!query.author) { + throw new Error('Author is required'); + } -const Prompt = mongoose.models.Prompt || mongoose.model('Prompt', promptSchema); + let searchShared = true; + let searchSharedOnly = false; + if (name) { + query.name = new RegExp(name, 'i'); + } + if (!query.category) { + delete query.category; + } else if (query.category === SystemCategories.MY_PROMPTS) { + searchShared = false; + delete query.category; + } else if (query.category === SystemCategories.NO_CATEGORY) { + query.category = ''; + } else if (query.category === SystemCategories.SHARED_PROMPTS) { + searchSharedOnly = true; + delete query.category; + } + + let combinedQuery = query; + + if (searchShared) { + // const projects = req.user.projects || []; // TODO: handle multiple projects + const project = await getProjectByName('instance', 'promptGroupIds'); + if (project && project.promptGroupIds.length > 0) { + const projectQuery = { _id: { $in: project.promptGroupIds }, ...query }; + delete projectQuery.author; + combinedQuery = searchSharedOnly ? projectQuery : { $or: [projectQuery, query] }; + } + } + + const skip = (parseInt(pageNumber, 10) - 1) * parseInt(pageSize, 10); + const limit = parseInt(pageSize, 10); + + const promptGroupsPipeline = createGroupPipeline(combinedQuery, skip, limit); + const totalPromptGroupsPipeline = [{ $match: combinedQuery }, { $count: 'total' }]; + + const [promptGroupsResults, totalPromptGroupsResults] = await Promise.all([ + PromptGroup.aggregate(promptGroupsPipeline).exec(), + PromptGroup.aggregate(totalPromptGroupsPipeline).exec(), + ]); + + const promptGroups = promptGroupsResults; + const totalPromptGroups = + totalPromptGroupsResults.length > 0 ? totalPromptGroupsResults[0].total : 0; + + return { + promptGroups, + pageNumber: pageNumber.toString(), + pageSize: pageSize.toString(), + pages: Math.ceil(totalPromptGroups / pageSize).toString(), + }; + } catch (error) { + console.error('Error getting prompt groups', error); + return { message: 'Error getting prompt groups' }; + } +}; module.exports = { - savePrompt: async ({ title, prompt }) => { + getPromptGroups, + /** + * Create a prompt and its respective group + * @param {TCreatePromptRecord} saveData + * @returns {Promise} + */ + createPromptGroup: async (saveData) => { try { - await Prompt.create({ - title, - prompt, - }); - return { title, prompt }; + const { prompt, group, author, authorName } = saveData; + + let newPromptGroup = await PromptGroup.findOneAndUpdate( + { ...group, author, authorName, productionId: null }, + { $setOnInsert: { ...group, author, authorName, productionId: null } }, + { new: true, upsert: true }, + ) + .lean() + .select('-__v') + .exec(); + + const newPrompt = await Prompt.findOneAndUpdate( + { ...prompt, author, groupId: newPromptGroup._id }, + { $setOnInsert: { ...prompt, author, groupId: newPromptGroup._id } }, + { new: true, upsert: true }, + ) + .lean() + .select('-__v') + .exec(); + + newPromptGroup = await PromptGroup.findByIdAndUpdate( + newPromptGroup._id, + { productionId: newPrompt._id }, + { new: true }, + ) + .lean() + .select('-__v') + .exec(); + + return { + prompt: newPrompt, + group: { + ...newPromptGroup, + productionPrompt: { prompt: newPrompt.prompt }, + }, + }; + } catch (error) { + logger.error('Error saving prompt group', error); + throw new Error('Error saving prompt group'); + } + }, + /** + * Save a prompt + * @param {TCreatePromptRecord} saveData + * @returns {Promise} + */ + savePrompt: async (saveData) => { + try { + const { prompt, author } = saveData; + const newPromptData = { + ...prompt, + author, + }; + + /** @type {TPrompt} */ + let newPrompt; + try { + newPrompt = await Prompt.create(newPromptData); + } catch (error) { + if (error?.message?.includes('groupId_1_version_1')) { + await Prompt.db.collection('prompts').dropIndex('groupId_1_version_1'); + } else { + throw error; + } + newPrompt = await Prompt.create(newPromptData); + } + + return { prompt: newPrompt }; } catch (error) { logger.error('Error saving prompt', error); - return { prompt: 'Error saving prompt' }; + return { message: 'Error saving prompt' }; } }, getPrompts: async (filter) => { try { - return await Prompt.find(filter).lean(); + return await Prompt.find(filter).sort({ createdAt: -1 }).lean(); } catch (error) { logger.error('Error getting prompts', error); - return { prompt: 'Error getting prompts' }; + return { message: 'Error getting prompts' }; } }, - deletePrompts: async (filter) => { + getPrompt: async (filter) => { try { - return await Prompt.deleteMany(filter); + if (filter.groupId) { + filter.groupId = new ObjectId(filter.groupId); + } + return await Prompt.findOne(filter).lean(); + } catch (error) { + logger.error('Error getting prompt', error); + return { message: 'Error getting prompt' }; + } + }, + /** + * Get prompt groups with filters + * @param {TGetRandomPromptsRequest} filter + * @returns {Promise} + */ + getRandomPromptGroups: async (filter) => { + try { + const result = await PromptGroup.aggregate([ + { + $match: { + category: { $ne: '' }, + }, + }, + { + $group: { + _id: '$category', + promptGroup: { $first: '$$ROOT' }, + }, + }, + { + $replaceRoot: { newRoot: '$promptGroup' }, + }, + { + $sample: { size: +filter.limit + +filter.skip }, + }, + { + $skip: +filter.skip, + }, + { + $limit: +filter.limit, + }, + ]); + return { prompts: result }; + } catch (error) { + logger.error('Error getting prompt groups', error); + return { message: 'Error getting prompt groups' }; + } + }, + getPromptGroupsWithPrompts: async (filter) => { + try { + return await PromptGroup.findOne(filter) + .populate({ + path: 'prompts', + select: '-_id -__v -user', + }) + .select('-_id -__v -user') + .lean(); + } catch (error) { + logger.error('Error getting prompt groups', error); + return { message: 'Error getting prompt groups' }; + } + }, + getPromptGroup: async (filter) => { + try { + return await PromptGroup.findOne(filter).lean(); + } catch (error) { + logger.error('Error getting prompt group', error); + return { message: 'Error getting prompt group' }; + } + }, + /** + * Deletes a prompt and its corresponding prompt group if it is the last prompt in the group. + * + * @param {Object} options - The options for deleting the prompt. + * @param {ObjectId|string} options.promptId - The ID of the prompt to delete. + * @param {ObjectId|string} options.groupId - The ID of the prompt's group. + * @param {ObjectId|string} options.author - The ID of the prompt's author. + * @param {string} options.role - The role of the prompt's author. + * @return {Promise} An object containing the result of the deletion. + * If the prompt was deleted successfully, the object will have a property 'prompt' with the value 'Prompt deleted successfully'. + * If the prompt group was deleted successfully, the object will have a property 'promptGroup' with the message 'Prompt group deleted successfully' and id of the deleted group. + * If there was an error deleting the prompt, the object will have a property 'message' with the value 'Error deleting prompt'. + */ + deletePrompt: async ({ promptId, groupId, author, role }) => { + const query = { _id: promptId, groupId, author }; + if (role === SystemRoles.ADMIN) { + delete query.author; + } + const { deletedCount } = await Prompt.deleteOne(query); + if (deletedCount === 0) { + throw new Error('Failed to delete the prompt'); + } + + const remainingPrompts = await Prompt.find({ groupId }) + .select('_id') + .sort({ createdAt: 1 }) + .lean(); + + if (remainingPrompts.length === 0) { + await PromptGroup.deleteOne({ _id: groupId }); + await removeGroupFromAllProjects(groupId); + + return { + prompt: 'Prompt deleted successfully', + promptGroup: { + message: 'Prompt group deleted successfully', + id: groupId, + }, + }; + } else { + const promptGroup = await PromptGroup.findById(groupId).lean(); + if (promptGroup.productionId.toString() === promptId.toString()) { + await PromptGroup.updateOne( + { _id: groupId }, + { productionId: remainingPrompts[remainingPrompts.length - 1]._id }, + ); + } + + return { prompt: 'Prompt deleted successfully' }; + } + }, + /** + * Update prompt group + * @param {Partial} filter - Filter to find prompt group + * @param {Partial} data - Data to update + * @returns {Promise} + */ + updatePromptGroup: async (filter, data) => { + try { + const updateOps = {}; + if (data.removeProjectIds) { + for (const projectId of data.removeProjectIds) { + await removeGroupIdsFromProject(projectId, [filter._id]); + } + + updateOps.$pull = { projectIds: { $in: data.removeProjectIds } }; + delete data.removeProjectIds; + } + + if (data.projectIds) { + for (const projectId of data.projectIds) { + await addGroupIdsToProject(projectId, [filter._id]); + } + + updateOps.$addToSet = { projectIds: { $each: data.projectIds } }; + delete data.projectIds; + } + + const updateData = { ...data, ...updateOps }; + const updatedDoc = await PromptGroup.findOneAndUpdate(filter, updateData, { + new: true, + upsert: false, + }); + + if (!updatedDoc) { + throw new Error('Prompt group not found'); + } + + return updatedDoc; + } catch (error) { + logger.error('Error updating prompt group', error); + return { message: 'Error updating prompt group' }; + } + }, + /** + * Function to make a prompt production based on its ID. + * @param {String} promptId - The ID of the prompt to make production. + * @returns {Object} The result of the production operation. + */ + makePromptProduction: async (promptId) => { + try { + const prompt = await Prompt.findById(promptId).lean(); + + if (!prompt) { + throw new Error('Prompt not found'); + } + + await PromptGroup.findByIdAndUpdate( + prompt.groupId, + { productionId: prompt._id }, + { new: true }, + ) + .lean() + .exec(); + + return { + message: 'Prompt production made successfully', + }; + } catch (error) { + logger.error('Error making prompt production', error); + return { message: 'Error making prompt production' }; + } + }, + updatePromptLabels: async (_id, labels) => { + try { + const response = await Prompt.updateOne({ _id }, { $set: { labels } }); + if (response.matchedCount === 0) { + return { message: 'Prompt not found' }; + } + return { message: 'Prompt labels updated successfully' }; + } catch (error) { + logger.error('Error updating prompt labels', error); + return { message: 'Error updating prompt labels' }; + } + }, + deletePromptGroup: async (_id) => { + try { + const response = await PromptGroup.deleteOne({ _id }); + + if (response.deletedCount === 0) { + return { promptGroup: 'Prompt group not found' }; + } + + await Prompt.deleteMany({ groupId: new ObjectId(_id) }); + await removeGroupFromAllProjects(_id); + return { promptGroup: 'Prompt group deleted successfully' }; } catch (error) { - logger.error('Error deleting prompts', error); - return { prompt: 'Error deleting prompts' }; + logger.error('Error deleting prompt group', error); + return { message: 'Error deleting prompt group' }; } }, }; diff --git a/api/models/Role.js b/api/models/Role.js new file mode 100644 index 00000000000..af02e5cac40 --- /dev/null +++ b/api/models/Role.js @@ -0,0 +1,86 @@ +const { SystemRoles, CacheKeys, roleDefaults } = require('librechat-data-provider'); +const getLogStores = require('~/cache/getLogStores'); +const Role = require('~/models/schema/roleSchema'); + +/** + * Retrieve a role by name and convert the found role document to a plain object. + * If the role with the given name doesn't exist and the name is a system defined role, create it and return the lean version. + * + * @param {string} roleName - The name of the role to find or create. + * @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document. + * @returns {Promise} A plain object representing the role document. + */ +const getRoleByName = async function (roleName, fieldsToSelect = null) { + try { + const cache = getLogStores(CacheKeys.ROLES); + const cachedRole = await cache.get(roleName); + if (cachedRole) { + return cachedRole; + } + let query = Role.findOne({ name: roleName }); + if (fieldsToSelect) { + query = query.select(fieldsToSelect); + } + let role = await query.lean().exec(); + + if (!role && SystemRoles[roleName]) { + role = roleDefaults[roleName]; + role = await new Role(role).save(); + await cache.set(roleName, role); + return role.toObject(); + } + await cache.set(roleName, role); + return role; + } catch (error) { + throw new Error(`Failed to retrieve or create role: ${error.message}`); + } +}; + +/** + * Update role values by name. + * + * @param {string} roleName - The name of the role to update. + * @param {Partial} updates - The fields to update. + * @returns {Promise} Updated role document. + */ +const updateRoleByName = async function (roleName, updates) { + try { + const cache = getLogStores(CacheKeys.ROLES); + const role = await Role.findOneAndUpdate( + { name: roleName }, + { $set: updates }, + { new: true, lean: true }, + ) + .select('-__v') + .lean() + .exec(); + await cache.set(roleName, role); + return role; + } catch (error) { + throw new Error(`Failed to update role: ${error.message}`); + } +}; + +/** + * Initialize default roles in the system. + * Creates the default roles (ADMIN, USER) if they don't exist in the database. + * + * @returns {Promise} + */ +const initializeRoles = async function () { + const defaultRoles = [SystemRoles.ADMIN, SystemRoles.USER]; + + for (const roleName of defaultRoles) { + let role = await Role.findOne({ name: roleName }).select('name').lean(); + if (!role) { + role = new Role(roleDefaults[roleName]); + await role.save(); + } + } +}; + +module.exports = { + getRoleByName, + initializeRoles, + updateRoleByName, +}; diff --git a/api/models/schema/categories.js b/api/models/schema/categories.js new file mode 100644 index 00000000000..31676856670 --- /dev/null +++ b/api/models/schema/categories.js @@ -0,0 +1,19 @@ +const mongoose = require('mongoose'); +const Schema = mongoose.Schema; + +const categoriesSchema = new Schema({ + label: { + type: String, + required: true, + unique: true, + }, + value: { + type: String, + required: true, + unique: true, + }, +}); + +const categories = mongoose.model('categories', categoriesSchema); + +module.exports = { Categories: categories }; diff --git a/api/models/schema/projectSchema.js b/api/models/schema/projectSchema.js new file mode 100644 index 00000000000..0e27c6a8f9f --- /dev/null +++ b/api/models/schema/projectSchema.js @@ -0,0 +1,30 @@ +const { Schema } = require('mongoose'); + +/** + * @typedef {Object} MongoProject + * @property {ObjectId} [_id] - MongoDB Document ID + * @property {string} name - The name of the project + * @property {ObjectId[]} promptGroupIds - Array of PromptGroup IDs associated with the project + * @property {Date} [createdAt] - Date when the project was created (added by timestamps) + * @property {Date} [updatedAt] - Date when the project was last updated (added by timestamps) + */ + +const projectSchema = new Schema( + { + name: { + type: String, + required: true, + index: true, + }, + promptGroupIds: { + type: [Schema.Types.ObjectId], + ref: 'PromptGroup', + default: [], + }, + }, + { + timestamps: true, + }, +); + +module.exports = projectSchema; diff --git a/api/models/schema/promptSchema.js b/api/models/schema/promptSchema.js new file mode 100644 index 00000000000..4aeb1deb280 --- /dev/null +++ b/api/models/schema/promptSchema.js @@ -0,0 +1,101 @@ +const mongoose = require('mongoose'); +const Schema = mongoose.Schema; + +/** + * @typedef {Object} MongoPromptGroup + * @property {ObjectId} [_id] - MongoDB Document ID + * @property {string} name - The name of the prompt group + * @property {ObjectId} author - The author of the prompt group + * @property {ObjectId} [projectId=null] - The project ID of the prompt group + * @property {ObjectId} [productionId=null] - The project ID of the prompt group + * @property {string} authorName - The name of the author of the prompt group + * @property {number} [numberOfGenerations=0] - Number of generations the prompt group has + * @property {string} [oneliner=''] - Oneliner description of the prompt group + * @property {string} [category=''] - Category of the prompt group + * @property {Date} [createdAt] - Date when the prompt group was created (added by timestamps) + * @property {Date} [updatedAt] - Date when the prompt group was last updated (added by timestamps) + */ + +const promptGroupSchema = new Schema( + { + name: { + type: String, + required: true, + index: true, + }, + numberOfGenerations: { + type: Number, + default: 0, + }, + oneliner: { + type: String, + default: '', + }, + category: { + type: String, + default: '', + index: true, + }, + projectIds: { + type: [Schema.Types.ObjectId], + ref: 'Project', + index: true, + }, + productionId: { + type: Schema.Types.ObjectId, + ref: 'Prompt', + required: true, + index: true, + }, + author: { + type: Schema.Types.ObjectId, + ref: 'User', + required: true, + index: true, + }, + authorName: { + type: String, + required: true, + }, + }, + { + timestamps: true, + }, +); + +const PromptGroup = mongoose.model('PromptGroup', promptGroupSchema); + +const promptSchema = new Schema( + { + groupId: { + type: Schema.Types.ObjectId, + ref: 'PromptGroup', + required: true, + index: true, + }, + author: { + type: Schema.Types.ObjectId, + ref: 'User', + required: true, + }, + prompt: { + type: String, + required: true, + }, + type: { + type: String, + enum: ['text', 'chat'], + required: true, + }, + }, + { + timestamps: true, + }, +); + +const Prompt = mongoose.model('Prompt', promptSchema); + +promptSchema.index({ createdAt: 1, updatedAt: 1 }); +promptGroupSchema.index({ createdAt: 1, updatedAt: 1 }); + +module.exports = { Prompt, PromptGroup }; diff --git a/api/models/schema/roleSchema.js b/api/models/schema/roleSchema.js new file mode 100644 index 00000000000..0387f44ad36 --- /dev/null +++ b/api/models/schema/roleSchema.js @@ -0,0 +1,29 @@ +const { PermissionTypes, Permissions } = require('librechat-data-provider'); +const mongoose = require('mongoose'); + +const roleSchema = new mongoose.Schema({ + name: { + type: String, + required: true, + unique: true, + index: true, + }, + [PermissionTypes.PROMPTS]: { + [Permissions.SHARED_GLOBAL]: { + type: Boolean, + default: false, + }, + [Permissions.USE]: { + type: Boolean, + default: true, + }, + [Permissions.CREATE]: { + type: Boolean, + default: true, + }, + }, +}); + +const Role = mongoose.model('Role', roleSchema); + +module.exports = Role; diff --git a/api/models/schema/userSchema.js b/api/models/schema/userSchema.js index f32da48cc97..715d8235164 100644 --- a/api/models/schema/userSchema.js +++ b/api/models/schema/userSchema.js @@ -1,4 +1,5 @@ const mongoose = require('mongoose'); +const { SystemRoles } = require('librechat-data-provider'); /** * @typedef {Object} MongoSession @@ -78,7 +79,7 @@ const userSchema = mongoose.Schema( }, role: { type: String, - default: 'USER', + default: SystemRoles.USER, }, googleId: { type: String, diff --git a/api/package.json b/api/package.json index f06384e5942..52413f7ac26 100644 --- a/api/package.json +++ b/api/package.json @@ -1,6 +1,6 @@ { "name": "@librechat/backend", - "version": "0.7.3", + "version": "0.7.4-rc1", "description": "", "scripts": { "start": "echo 'please run this from the root directory'", diff --git a/api/server/controllers/assistants/helpers.js b/api/server/controllers/assistants/helpers.js index e91a3dc528c..715bb02ed20 100644 --- a/api/server/controllers/assistants/helpers.js +++ b/api/server/controllers/assistants/helpers.js @@ -1,8 +1,9 @@ const { - EModelEndpoint, CacheKeys, - defaultAssistantsVersion, + SystemRoles, + EModelEndpoint, defaultOrderQuery, + defaultAssistantsVersion, } = require('librechat-data-provider'); const { initializeClient: initAzureClient, @@ -227,7 +228,7 @@ const fetchAssistants = async ({ req, res, overrideEndpoint }) => { body = await listAssistantsForAzure({ req, res, version, azureConfig, query }); } - if (req.user.role === 'ADMIN') { + if (req.user.role === SystemRoles.ADMIN) { return body; } else if (!req.app.locals[endpoint]) { return body; diff --git a/api/server/index.js b/api/server/index.js index 0aa8cb27607..d99340a1fbc 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -81,6 +81,7 @@ const startServer = async () => { app.use('/api/convos', routes.convos); app.use('/api/presets', routes.presets); app.use('/api/prompts', routes.prompts); + app.use('/api/categories', routes.categories); app.use('/api/tokenizer', routes.tokenizer); app.use('/api/endpoints', routes.endpoints); app.use('/api/balance', routes.balance); @@ -91,6 +92,7 @@ const startServer = async () => { app.use('/api/files', await routes.files.initialize()); app.use('/images/', validateImageRequest, routes.staticRoute); app.use('/api/share', routes.share); + app.use('/api/roles', routes.roles); app.use((req, res) => { res.sendFile(path.join(app.locals.paths.dist, 'index.html')); diff --git a/api/server/middleware/assistants/validateAuthor.js b/api/server/middleware/assistants/validateAuthor.js index 749b309cbe2..a17448211e7 100644 --- a/api/server/middleware/assistants/validateAuthor.js +++ b/api/server/middleware/assistants/validateAuthor.js @@ -1,3 +1,4 @@ +const { SystemRoles } = require('librechat-data-provider'); const { getAssistant } = require('~/models/Assistant'); /** @@ -11,7 +12,7 @@ const { getAssistant } = require('~/models/Assistant'); * @returns {Promise} */ const validateAuthor = async ({ req, openai, overrideEndpoint, overrideAssistantId }) => { - if (req.user.role === 'ADMIN') { + if (req.user.role === SystemRoles.ADMIN) { return; } diff --git a/api/server/middleware/canDeleteAccount.js b/api/server/middleware/canDeleteAccount.js index 1abfbc9f8c7..5f2479fb542 100644 --- a/api/server/middleware/canDeleteAccount.js +++ b/api/server/middleware/canDeleteAccount.js @@ -1,3 +1,4 @@ +const { SystemRoles } = require('librechat-data-provider'); const { isEnabled } = require('~/server/utils'); const { logger } = require('~/config'); @@ -16,7 +17,7 @@ const { logger } = require('~/config'); const canDeleteAccount = async (req, res, next = () => {}) => { const { user } = req; const { ALLOW_ACCOUNT_DELETION = true } = process.env; - if (user?.role === 'ADMIN' || isEnabled(ALLOW_ACCOUNT_DELETION)) { + if (user?.role === SystemRoles.ADMIN || isEnabled(ALLOW_ACCOUNT_DELETION)) { return next(); } else { logger.error(`[User] [Delete Account] [User cannot delete account] [User: ${user?.id}]`); diff --git a/api/server/middleware/index.js b/api/server/middleware/index.js index 8d3455af341..75aab961b59 100644 --- a/api/server/middleware/index.js +++ b/api/server/middleware/index.js @@ -18,10 +18,12 @@ const limiters = require('./limiters'); const uaParser = require('./uaParser'); const checkBan = require('./checkBan'); const noIndex = require('./noIndex'); +const roles = require('./roles'); module.exports = { ...abortMiddleware, ...limiters, + ...roles, noIndex, checkBan, uaParser, diff --git a/api/server/middleware/roles/checkAdmin.js b/api/server/middleware/roles/checkAdmin.js new file mode 100644 index 00000000000..3cb93fab536 --- /dev/null +++ b/api/server/middleware/roles/checkAdmin.js @@ -0,0 +1,14 @@ +const { SystemRoles } = require('librechat-data-provider'); + +function checkAdmin(req, res, next) { + try { + if (req.user.role !== SystemRoles.ADMIN) { + return res.status(403).json({ message: 'Forbidden' }); + } + next(); + } catch (error) { + res.status(500).json({ message: 'Internal Server Error' }); + } +} + +module.exports = checkAdmin; diff --git a/api/server/middleware/roles/generateCheckAccess.js b/api/server/middleware/roles/generateCheckAccess.js new file mode 100644 index 00000000000..900921ef80d --- /dev/null +++ b/api/server/middleware/roles/generateCheckAccess.js @@ -0,0 +1,52 @@ +const { SystemRoles } = require('librechat-data-provider'); +const { getRoleByName } = require('~/models/Role'); + +/** + * Middleware to check if a user has one or more required permissions, optionally based on `req.body` properties. + * + * @param {PermissionTypes} permissionType - The type of permission to check. + * @param {Permissions[]} permissions - The list of specific permissions to check. + * @param {Record} [bodyProps] - An optional object where keys are permissions and values are arrays of `req.body` properties to check. + * @returns {Function} Express middleware function. + */ +const generateCheckAccess = (permissionType, permissions, bodyProps = {}) => { + return async (req, res, next) => { + try { + const { user } = req; + if (!user) { + return res.status(401).json({ message: 'Authorization required' }); + } + + if (user.role === SystemRoles.ADMIN) { + return next(); + } + + const role = await getRoleByName(user.role); + if (role && role[permissionType]) { + const hasAnyPermission = permissions.some((permission) => { + if (role[permissionType][permission]) { + return true; + } + + if (bodyProps[permission] && req.body) { + return bodyProps[permission].some((prop) => + Object.prototype.hasOwnProperty.call(req.body, prop), + ); + } + + return false; + }); + + if (hasAnyPermission) { + return next(); + } + } + + return res.status(403).json({ message: 'Forbidden: Insufficient permissions' }); + } catch (error) { + return res.status(500).json({ message: `Server error: ${error.message}` }); + } + }; +}; + +module.exports = generateCheckAccess; diff --git a/api/server/middleware/roles/index.js b/api/server/middleware/roles/index.js new file mode 100644 index 00000000000..999c36481e0 --- /dev/null +++ b/api/server/middleware/roles/index.js @@ -0,0 +1,7 @@ +const checkAdmin = require('./checkAdmin'); +const generateCheckAccess = require('./generateCheckAccess'); + +module.exports = { + checkAdmin, + generateCheckAccess, +}; diff --git a/api/server/routes/categories.js b/api/server/routes/categories.js new file mode 100644 index 00000000000..da1828b3ce7 --- /dev/null +++ b/api/server/routes/categories.js @@ -0,0 +1,15 @@ +const express = require('express'); +const router = express.Router(); +const { requireJwtAuth } = require('~/server/middleware'); +const { getCategories } = require('~/models/Categories'); + +router.get('/', requireJwtAuth, async (req, res) => { + try { + const categories = await getCategories(); + res.status(200).send(categories); + } catch (error) { + res.status(500).send({ message: 'Failed to retrieve categories', error: error.message }); + } +}); + +module.exports = router; diff --git a/api/server/routes/config.js b/api/server/routes/config.js index de3c0d89c9d..4c2fd28d242 100644 --- a/api/server/routes/config.js +++ b/api/server/routes/config.js @@ -1,6 +1,8 @@ const express = require('express'); -const { defaultSocialLogins } = require('librechat-data-provider'); +const { CacheKeys, defaultSocialLogins } = require('librechat-data-provider'); +const { getProjectByName } = require('~/models/Project'); const { isEnabled } = require('~/server/utils'); +const { getLogStores } = require('~/cache'); const { logger } = require('~/config'); const router = express.Router(); @@ -17,11 +19,20 @@ const publicSharedLinksEnabled = isEnabled(process.env.ALLOW_SHARED_LINKS_PUBLIC)); router.get('/', async function (req, res) { + const cache = getLogStores(CacheKeys.CONFIG_STORE); + const cachedStartupConfig = await cache.get(CacheKeys.STARTUP_CONFIG); + if (cachedStartupConfig) { + res.send(cachedStartupConfig); + return; + } + const isBirthday = () => { const today = new Date(); return today.getMonth() === 1 && today.getDate() === 11; }; + const instanceProject = await getProjectByName('instance', '_id'); + const ldapLoginEnabled = !!process.env.LDAP_URL && !!process.env.LDAP_BIND_DN && !!process.env.LDAP_USER_SEARCH_BASE; try { @@ -63,12 +74,14 @@ router.get('/', async function (req, res) { sharedLinksEnabled, publicSharedLinksEnabled, analyticsGtmId: process.env.ANALYTICS_GTM_ID, + instanceProjectId: instanceProject._id.toString(), }; if (typeof process.env.CUSTOM_FOOTER === 'string') { payload.customFooter = process.env.CUSTOM_FOOTER; } + await cache.set(CacheKeys.STARTUP_CONFIG, payload); return res.status(200).send(payload); } catch (err) { logger.error('Error in startup config', err); diff --git a/api/server/routes/index.js b/api/server/routes/index.js index 958cf0aed52..f8a3d258485 100644 --- a/api/server/routes/index.js +++ b/api/server/routes/index.js @@ -19,6 +19,8 @@ const assistants = require('./assistants'); const files = require('./files'); const staticRoute = require('./static'); const share = require('./share'); +const categories = require('./categories'); +const roles = require('./roles'); module.exports = { search, @@ -42,4 +44,6 @@ module.exports = { files, staticRoute, share, + categories, + roles, }; diff --git a/api/server/routes/prompts.js b/api/server/routes/prompts.js index 753feb262a3..38a9e51ba10 100644 --- a/api/server/routes/prompts.js +++ b/api/server/routes/prompts.js @@ -1,14 +1,218 @@ const express = require('express'); +const { PermissionTypes, Permissions, SystemRoles } = require('librechat-data-provider'); +const { + getPrompt, + getPrompts, + savePrompt, + deletePrompt, + getPromptGroup, + getPromptGroups, + updatePromptGroup, + deletePromptGroup, + createPromptGroup, + // updatePromptLabels, + makePromptProduction, +} = require('~/models/Prompt'); +const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware'); +const { logger } = require('~/config'); + const router = express.Router(); -const { getPrompts } = require('../../models/Prompt'); + +const checkPromptAccess = generateCheckAccess(PermissionTypes.PROMPTS, [Permissions.USE]); +const checkPromptCreate = generateCheckAccess(PermissionTypes.PROMPTS, [ + Permissions.USE, + Permissions.CREATE, +]); +const checkGlobalPromptShare = generateCheckAccess( + PermissionTypes.PROMPTS, + [Permissions.USE, Permissions.CREATE], + { + [Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'], + }, +); + +router.use(requireJwtAuth); +router.use(checkPromptAccess); + +/** + * Route to get single prompt group by its ID + * GET /groups/:groupId + */ +router.get('/groups/:groupId', async (req, res) => { + let groupId = req.params.groupId; + const author = req.user.id; + + const query = { + _id: groupId, + $or: [{ projectIds: { $exists: true, $ne: [], $not: { $size: 0 } } }, { author }], + }; + + if (req.user.role === SystemRoles.ADMIN) { + delete query.$or; + } + + try { + const group = await getPromptGroup(query); + + if (!group) { + return res.status(404).send({ message: 'Prompt group not found' }); + } + + res.status(200).send(group); + } catch (error) { + logger.error('Error getting prompt group', error); + res.status(500).send({ message: 'Error getting prompt group' }); + } +}); + +/** + * Route to fetch paginated prompt groups with filters + * GET /groups + */ +router.get('/groups', async (req, res) => { + try { + const filter = req.query; + /* Note: The aggregation requires an ObjectId */ + filter.author = req.user._id; + const groups = await getPromptGroups(req, filter); + res.status(200).send(groups); + } catch (error) { + logger.error(error); + res.status(500).send({ error: 'Error getting prompt groups' }); + } +}); + +/** + * Updates or creates a prompt + promptGroup + * @param {object} req + * @param {TCreatePrompt} req.body + * @param {Express.Response} res + */ +const createPrompt = async (req, res) => { + try { + const { prompt, group } = req.body; + if (!prompt) { + return res.status(400).send({ error: 'Prompt is required' }); + } + + const saveData = { + prompt, + group, + author: req.user.id, + authorName: req.user.name, + }; + + /** @type {TCreatePromptResponse} */ + let result; + if (group && group.name) { + result = await createPromptGroup(saveData); + } else { + result = await savePrompt(saveData); + } + res.status(200).send(result); + } catch (error) { + logger.error(error); + res.status(500).send({ error: 'Error saving prompt' }); + } +}; + +router.post('/', createPrompt); + +/** + * Updates a prompt group + * @param {object} req + * @param {object} req.params - The request parameters + * @param {string} req.params.groupId - The group ID + * @param {TUpdatePromptGroupPayload} req.body - The request body + * @param {Express.Response} res + */ +const patchPromptGroup = async (req, res) => { + try { + const { groupId } = req.params; + const author = req.user.id; + const filter = { _id: groupId, author }; + if (req.user.role === SystemRoles.ADMIN) { + delete filter.author; + } + const promptGroup = await updatePromptGroup(filter, req.body); + res.status(200).send(promptGroup); + } catch (error) { + logger.error(error); + res.status(500).send({ error: 'Error updating prompt group' }); + } +}; + +router.patch('/groups/:groupId', checkGlobalPromptShare, patchPromptGroup); + +router.patch('/:promptId/tags/production', checkPromptCreate, async (req, res) => { + try { + const { promptId } = req.params; + const result = await makePromptProduction(promptId); + res.status(200).send(result); + } catch (error) { + logger.error(error); + res.status(500).send({ error: 'Error updating prompt production' }); + } +}); + +router.get('/:promptId', async (req, res) => { + const { promptId } = req.params; + const author = req.user.id; + const query = { _id: promptId, author }; + if (req.user.role === SystemRoles.ADMIN) { + delete query.author; + } + const prompt = await getPrompt(query); + res.status(200).send(prompt); +}); router.get('/', async (req, res) => { - let filter = {}; - // const { search } = req.body.arg; - // if (!!search) { - // filter = { conversationId }; - // } - res.status(200).send(await getPrompts(filter)); + try { + const author = req.user.id; + const { groupId } = req.query; + const query = { groupId, author }; + if (req.user.role === SystemRoles.ADMIN) { + delete query.author; + } + const prompts = await getPrompts(query); + res.status(200).send(prompts); + } catch (error) { + logger.error(error); + res.status(500).send({ error: 'Error getting prompts' }); + } +}); + +/** + * Deletes a prompt + * + * @param {Express.Request} req - The request object. + * @param {TDeletePromptVariables} req.params - The request parameters + * @param {import('mongoose').ObjectId} req.params.promptId - The prompt ID + * @param {Express.Response} res - The response object. + * @return {TDeletePromptResponse} A promise that resolves when the prompt is deleted. + */ +const deletePromptController = async (req, res) => { + try { + const { promptId } = req.params; + const { groupId } = req.query; + const author = req.user.id; + const query = { promptId, groupId, author, role: req.user.role }; + if (req.user.role === SystemRoles.ADMIN) { + delete query.author; + } + const result = await deletePrompt(query); + res.status(200).send(result); + } catch (error) { + logger.error(error); + res.status(500).send({ error: 'Error deleting prompt' }); + } +}; + +router.delete('/:promptId', checkPromptCreate, deletePromptController); + +router.delete('/groups/:groupId', checkPromptCreate, async (req, res) => { + const { groupId } = req.params; + res.status(200).send(await deletePromptGroup(groupId)); }); module.exports = router; diff --git a/api/server/routes/roles.js b/api/server/routes/roles.js new file mode 100644 index 00000000000..06005ad40e8 --- /dev/null +++ b/api/server/routes/roles.js @@ -0,0 +1,72 @@ +const express = require('express'); +const { + promptPermissionsSchema, + PermissionTypes, + roleDefaults, + SystemRoles, +} = require('librechat-data-provider'); +const { checkAdmin, requireJwtAuth } = require('~/server/middleware'); +const { updateRoleByName, getRoleByName } = require('~/models/Role'); + +const router = express.Router(); +router.use(requireJwtAuth); + +/** + * GET /api/roles/:roleName + * Get a specific role by name + */ +router.get('/:roleName', async (req, res) => { + const { roleName: _r } = req.params; + // TODO: TEMP, use a better parsing for roleName + const roleName = _r.toUpperCase(); + + if (req.user.role !== SystemRoles.ADMIN && !roleDefaults[roleName]) { + return res.status(403).send({ message: 'Unauthorized' }); + } + + try { + const role = await getRoleByName(roleName, '-_id -__v'); + if (!role) { + return res.status(404).send({ message: 'Role not found' }); + } + + res.status(200).send(role); + } catch (error) { + return res.status(500).send({ message: 'Failed to retrieve role', error: error.message }); + } +}); + +/** + * PUT /api/roles/:roleName/prompts + * Update prompt permissions for a specific role + */ +router.put('/:roleName/prompts', checkAdmin, async (req, res) => { + const { roleName: _r } = req.params; + // TODO: TEMP, use a better parsing for roleName + const roleName = _r.toUpperCase(); + /** @type {TRole['PROMPTS']} */ + const updates = req.body; + + try { + const parsedUpdates = promptPermissionsSchema.partial().parse(updates); + + const role = await getRoleByName(roleName); + if (!role) { + return res.status(404).send({ message: 'Role not found' }); + } + + const mergedUpdates = { + [PermissionTypes.PROMPTS]: { + ...role[PermissionTypes.PROMPTS], + ...parsedUpdates, + }, + }; + + const updatedRole = await updateRoleByName(roleName, mergedUpdates); + res.status(200).send(updatedRole); + } catch (error) { + return res.status(400).send({ message: 'Invalid prompt permissions.', error: error.errors }); + } +}); + +module.exports = router; diff --git a/api/server/services/AppService.js b/api/server/services/AppService.js index bbb3880b313..e416d5f6e70 100644 --- a/api/server/services/AppService.js +++ b/api/server/services/AppService.js @@ -7,6 +7,7 @@ const handleRateLimits = require('./Config/handleRateLimits'); const { loadDefaultInterface } = require('./start/interface'); const { azureConfigSetup } = require('./start/azureOpenAI'); const { loadAndFormatTools } = require('./ToolService'); +const { initializeRoles } = require('~/models/Role'); const paths = require('~/config/paths'); /** @@ -16,6 +17,7 @@ const paths = require('~/config/paths'); * @param {Express.Application} app - The Express application object. */ const AppService = async (app) => { + await initializeRoles(); /** @type {TCustomConfig}*/ const config = (await loadCustomConfig()) ?? {}; const configDefaults = getConfigDefaults(); diff --git a/api/server/services/AppService.spec.js b/api/server/services/AppService.spec.js index 6bcf4754e15..cab6d8e2a46 100644 --- a/api/server/services/AppService.spec.js +++ b/api/server/services/AppService.spec.js @@ -21,6 +21,9 @@ jest.mock('./Config/loadCustomConfig', () => { jest.mock('./Files/Firebase/initialize', () => ({ initializeFirebase: jest.fn(), })); +jest.mock('~/models/Role', () => ({ + initializeRoles: jest.fn(), +})); jest.mock('./ToolService', () => ({ loadAndFormatTools: jest.fn().mockReturnValue({ ExampleTool: { diff --git a/api/server/services/AuthService.js b/api/server/services/AuthService.js index 06dd0d0e729..9efc42b5ce9 100644 --- a/api/server/services/AuthService.js +++ b/api/server/services/AuthService.js @@ -1,6 +1,6 @@ const crypto = require('crypto'); const bcrypt = require('bcryptjs'); -const { errorsToString } = require('librechat-data-provider'); +const { SystemRoles, errorsToString } = require('librechat-data-provider'); const { findUser, countUsers, @@ -169,7 +169,7 @@ const registerUser = async (user) => { username, name, avatar: null, - role: isFirstRegisteredUser ? 'ADMIN' : 'USER', + role: isFirstRegisteredUser ? SystemRoles.ADMIN : SystemRoles.USER, password: bcrypt.hashSync(password, salt), }; diff --git a/api/strategies/jwtStrategy.js b/api/strategies/jwtStrategy.js index 01eb8da2ca7..e65b2849501 100644 --- a/api/strategies/jwtStrategy.js +++ b/api/strategies/jwtStrategy.js @@ -1,5 +1,6 @@ +const { SystemRoles } = require('librechat-data-provider'); const { Strategy: JwtStrategy, ExtractJwt } = require('passport-jwt'); -const { getUserById } = require('~/models'); +const { getUserById, updateUser } = require('~/models'); const { logger } = require('~/config'); // JWT strategy @@ -14,6 +15,10 @@ const jwtLogin = async () => const user = await getUserById(payload?.id, '-password -__v'); if (user) { user.id = user._id.toString(); + if (!user.role) { + user.role = SystemRoles.USER; + await updateUser(user.id, { role: user.role }); + } done(null, user); } else { logger.warn('[jwtLogin] JwtStrategy => no user found: ' + payload?.id); diff --git a/api/typedefs.js b/api/typedefs.js index 445ab4f9040..cdb2c531f2d 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -248,6 +248,110 @@ * @memberof typedefs */ +/** Prompts */ +/** + * @exports TPrompt + * @typedef {import('librechat-data-provider').TPrompt} TPrompt + * @memberof typedefs + */ + +/** + * @exports TPromptGroup + * @typedef {import('librechat-data-provider').TPromptGroup} TPromptGroup + * @memberof typedefs + */ + +/** + * @exports TCreatePrompt + * @typedef {import('librechat-data-provider').TCreatePrompt} TCreatePrompt + * @memberof typedefs + */ + +/** + * @exports TCreatePromptRecord + * @typedef {import('librechat-data-provider').TCreatePromptRecord} TCreatePromptRecord + * @memberof typedefs + */ +/** + * @exports TCreatePromptResponse + * @typedef {import('librechat-data-provider').TCreatePromptResponse} TCreatePromptResponse + * @memberof typedefs + */ +/** + * @exports TUpdatePromptGroupResponse + * @typedef {import('librechat-data-provider').TUpdatePromptGroupResponse} TUpdatePromptGroupResponse + * @memberof typedefs + */ + +/** + * @exports TPromptGroupsWithFilterRequest + * @typedef {import('librechat-data-provider').TPromptGroupsWithFilterRequest } TPromptGroupsWithFilterRequest + * @memberof typedefs + */ + +/** + * @exports PromptGroupListResponse + * @typedef {import('librechat-data-provider').PromptGroupListResponse } PromptGroupListResponse + * @memberof typedefs + */ + +/** + * @exports TGetCategoriesResponse + * @typedef {import('librechat-data-provider').TGetCategoriesResponse } TGetCategoriesResponse + * @memberof typedefs + */ + +/** + * @exports TGetRandomPromptsResponse + * @typedef {import('librechat-data-provider').TGetRandomPromptsResponse } TGetRandomPromptsResponse + * @memberof typedefs + */ + +/** + * @exports TGetRandomPromptsRequest + * @typedef {import('librechat-data-provider').TGetRandomPromptsRequest } TGetRandomPromptsRequest + * @memberof typedefs + */ + +/** + * @exports TUpdatePromptGroupPayload + * @typedef {import('librechat-data-provider').TUpdatePromptGroupPayload } TUpdatePromptGroupPayload + * @memberof typedefs + */ + +/** + * @exports TDeletePromptVariables + * @typedef {import('librechat-data-provider').TDeletePromptVariables } TDeletePromptVariables + * @memberof typedefs + */ + +/** + * @exports TDeletePromptResponse + * @typedef {import('librechat-data-provider').TDeletePromptResponse } TDeletePromptResponse + * @memberof typedefs + */ + +/* Roles */ + +/** + * @exports TRole + * @typedef {import('librechat-data-provider').TRole } TRole + * @memberof typedefs + */ + +/** + * @exports PermissionTypes + * @typedef {import('librechat-data-provider').PermissionTypes } PermissionTypes + * @memberof typedefs + */ + +/** + * @exports Permissions + * @typedef {import('librechat-data-provider').Permissions } Permissions + * @memberof typedefs + */ + +/** Assistants */ /** * @exports Assistant * @typedef {import('librechat-data-provider').Assistant} Assistant @@ -500,6 +604,18 @@ * @memberof typedefs */ +/** + * @exports MongoProject + * @typedef {import('~/models/schema/projectSchema.js').MongoProject} MongoProject + * @memberof typedefs + */ + +/** + * @exports MongoPromptGroup + * @typedef {import('~/models/schema/promptSchema.js').MongoPromptGroup} MongoPromptGroup + * @memberof typedefs + */ + /** * @exports uploadImageBuffer * @typedef {import('~/server/services/Files/process').uploadImageBuffer} uploadImageBuffer diff --git a/client/package.json b/client/package.json index 3571d94bf8e..80c33d85b06 100644 --- a/client/package.json +++ b/client/package.json @@ -1,6 +1,6 @@ { "name": "@librechat/frontend", - "version": "0.7.3", + "version": "0.7.4-rc1", "description": "", "type": "module", "scripts": { @@ -66,7 +66,7 @@ "image-blob-reduce": "^4.1.0", "librechat-data-provider": "*", "lodash": "^4.17.21", - "lucide-react": "^0.220.0", + "lucide-react": "^0.394.0", "match-sorter": "^6.3.4", "rc-input-number": "^7.4.2", "react": "^18.2.0", diff --git a/client/src/Providers/ChatFormContext.tsx b/client/src/Providers/ChatFormContext.tsx new file mode 100644 index 00000000000..33940077bb8 --- /dev/null +++ b/client/src/Providers/ChatFormContext.tsx @@ -0,0 +1,6 @@ +import { createFormContext } from './CustomFormContext'; +import type { ChatFormValues } from '~/common'; + +const { CustomFormProvider, useCustomFormContext } = createFormContext(); + +export { CustomFormProvider as ChatFormProvider, useCustomFormContext as useChatFormContext }; diff --git a/client/src/Providers/CustomFormContext.tsx b/client/src/Providers/CustomFormContext.tsx new file mode 100644 index 00000000000..cb62b0d4021 --- /dev/null +++ b/client/src/Providers/CustomFormContext.tsx @@ -0,0 +1,56 @@ +import React, { createContext, PropsWithChildren, ReactElement, useContext, useMemo } from 'react'; +import type { + Control, + // FieldErrors, + FieldValues, + UseFormReset, + UseFormRegister, + UseFormGetValues, + UseFormHandleSubmit, + UseFormSetValue, +} from 'react-hook-form'; + +interface FormContextValue { + register: UseFormRegister; + control: Control; + // errors: FieldErrors; + getValues: UseFormGetValues; + setValue: UseFormSetValue; + handleSubmit: UseFormHandleSubmit; + reset: UseFormReset; +} + +function createFormContext() { + const context = createContext | undefined>(undefined); + + const useCustomFormContext = (): FormContextValue => { + const value = useContext(context); + if (!value) { + throw new Error('useCustomFormContext must be used within a CustomFormProvider'); + } + return value; + }; + + const CustomFormProvider = ({ + register, + control, + setValue, + // errors, + getValues, + handleSubmit, + reset, + children, + }: PropsWithChildren>): ReactElement => { + const value = useMemo( + () => ({ register, control, getValues, setValue, handleSubmit, reset }), + [register, control, setValue, getValues, handleSubmit, reset], + ); + + return {children}; + }; + + return { CustomFormProvider, useCustomFormContext }; +} + +export type { FormContextValue }; +export { createFormContext }; diff --git a/client/src/Providers/DashboardContext.tsx b/client/src/Providers/DashboardContext.tsx new file mode 100644 index 00000000000..f33a240d001 --- /dev/null +++ b/client/src/Providers/DashboardContext.tsx @@ -0,0 +1,7 @@ +import { createContext, useContext } from 'react'; +type TDashboardContext = { + prevLocationPath: string; +}; + +export const DashboardContext = createContext({} as TDashboardContext); +export const useDashboardContext = () => useContext(DashboardContext); diff --git a/client/src/Providers/index.ts b/client/src/Providers/index.ts index 4085113e1a8..836bbde90b4 100644 --- a/client/src/Providers/index.ts +++ b/client/src/Providers/index.ts @@ -5,5 +5,7 @@ export * from './ShareContext'; export * from './ToastContext'; export * from './SearchContext'; export * from './FileMapContext'; +export * from './ChatFormContext'; +export * from './DashboardContext'; export * from './AssistantsContext'; export * from './AssistantsMapContext'; diff --git a/client/src/common/types.ts b/client/src/common/types.ts index d0cc345615a..fc515f138a3 100644 --- a/client/src/common/types.ts +++ b/client/src/common/types.ts @@ -1,8 +1,10 @@ -import { FileSources } from 'librechat-data-provider'; +import React from 'react'; +import { FileSources, SystemRoles } from 'librechat-data-provider'; import type * as InputNumberPrimitive from 'rc-input-number'; import type { ColumnDef } from '@tanstack/react-table'; import type { SetterOrUpdater } from 'recoil'; import type { + TRole, TUser, Action, TPreset, @@ -52,6 +54,8 @@ export type LastSelectedModels = Record; export type LocalizeFunction = (phraseKey: string, ...values: string[]) => string; +export type ChatFormValues = { text: string }; + export const mainTextareaId = 'prompt-textarea'; export const globalAudioId = 'global-audio'; @@ -75,7 +79,7 @@ export type IconMapProps = { export type NavLink = { title: string; label?: string; - icon: LucideIcon; + icon: LucideIcon | React.FC; Component?: React.ComponentType; onClick?: () => void; variant?: 'default' | 'ghost'; @@ -325,6 +329,7 @@ export type TAuthContext = { login: (data: TLoginUser) => void; logout: () => void; setError: React.Dispatch>; + roles?: Record; }; export type TUserContext = { @@ -394,7 +399,6 @@ export interface SwitcherProps { endpointKeyProvided: boolean; isCollapsed: boolean; } - export type TLoginLayoutContext = { startupConfig: TStartupConfig | null; startupConfigError: unknown; @@ -404,3 +408,19 @@ export type TLoginLayoutContext = { headerText: string; setHeaderText: React.Dispatch>; }; +export type TVectorStore = { + _id: string; + object: 'vector_store'; + created_at: string | Date; + name: string; + bytes?: number; + file_counts?: { + in_progress: number; + completed: number; + failed: number; + cancelled: number; + total: number; + }; +}; + +export type TThread = { id: string; createdAt: string }; diff --git a/client/src/components/Auth/AuthLayout.tsx b/client/src/components/Auth/AuthLayout.tsx index c2d4a6303c2..6d99a004899 100644 --- a/client/src/components/Auth/AuthLayout.tsx +++ b/client/src/components/Auth/AuthLayout.tsx @@ -1,8 +1,8 @@ -import { ThemeSelector } from '~/components/ui'; import { useLocalize } from '~/hooks'; import { BlinkAnimation } from './BlinkAnimation'; import { TStartupConfig } from 'librechat-data-provider'; import SocialLoginRender from './SocialLoginRender'; +import { ThemeSelector } from '~/components/ui'; import Footer from './Footer'; const ErrorRender = ({ children }: { children: React.ReactNode }) => ( diff --git a/client/src/components/Chat/ChatView.tsx b/client/src/components/Chat/ChatView.tsx index 604c8f1e78a..dfae014ea02 100644 --- a/client/src/components/Chat/ChatView.tsx +++ b/client/src/components/Chat/ChatView.tsx @@ -1,8 +1,10 @@ import { memo } from 'react'; import { useRecoilValue } from 'recoil'; +import { useForm } from 'react-hook-form'; import { useParams } from 'react-router-dom'; import { useGetMessagesByConvoId } from 'librechat-data-provider/react-query'; -import { ChatContext, useFileMapContext } from '~/Providers'; +import type { ChatFormValues } from '~/common'; +import { ChatContext, useFileMapContext, ChatFormProvider } from '~/Providers'; import MessagesView from './Messages/MessagesView'; import { useChatHelpers, useSSE } from '~/hooks'; import { Spinner } from '~/components/svg'; @@ -30,25 +32,37 @@ function ChatView({ index = 0 }: { index?: number }) { }); const chatHelpers = useChatHelpers(index, conversationId); + const methods = useForm({ + defaultValues: { text: '' }, + }); return ( - - - {isLoading && conversationId !== 'new' ? ( -
- + + + + {isLoading && conversationId !== 'new' ? ( +
+ +
+ ) : messagesTree && messagesTree.length !== 0 ? ( + } /> + ) : ( + } /> + )} +
+ +
- ) : messagesTree && messagesTree.length !== 0 ? ( - } /> - ) : ( - } /> - )} -
- -
-
-
-
+ + +
); } diff --git a/client/src/components/Chat/Input/AudioRecorder.tsx b/client/src/components/Chat/Input/AudioRecorder.tsx index d4ea2c4a8e8..48d89c2c3fa 100644 --- a/client/src/components/Chat/Input/AudioRecorder.tsx +++ b/client/src/components/Chat/Input/AudioRecorder.tsx @@ -1,8 +1,8 @@ import { useEffect } from 'react'; -import type { UseFormReturn } from 'react-hook-form'; -import { TooltipProvider, Tooltip, TooltipTrigger, TooltipContent } from '~/components/ui/'; +import { TooltipProvider, Tooltip, TooltipTrigger, TooltipContent } from '~/components/ui'; import { ListeningIcon, Spinner } from '~/components/svg'; import { useLocalize, useSpeechToText } from '~/hooks'; +import { useChatFormContext } from '~/Providers'; import { globalAudioId } from '~/common'; export default function AudioRecorder({ @@ -12,7 +12,7 @@ export default function AudioRecorder({ disabled, }: { textAreaRef: React.RefObject; - methods: UseFormReturn<{ text: string }>; + methods: ReturnType; ask: (data: { text: string }) => void; disabled: boolean; }) { diff --git a/client/src/components/Chat/Input/ChatForm.tsx b/client/src/components/Chat/Input/ChatForm.tsx index 4cdd0f0c2c5..ea88457ce21 100644 --- a/client/src/components/Chat/Input/ChatForm.tsx +++ b/client/src/components/Chat/Input/ChatForm.tsx @@ -1,15 +1,14 @@ -import { useForm } from 'react-hook-form'; +import { memo, useRef, useMemo } from 'react'; import { useRecoilState, useRecoilValue } from 'recoil'; -import { memo, useCallback, useRef, useMemo, useState, useEffect } from 'react'; import { supportsFiles, mergeFileConfig, isAssistantsEndpoint, fileConfig as defaultFileConfig, } from 'librechat-data-provider'; -import { useChatContext, useAssistantsMapContext } from '~/Providers'; +import { useChatContext, useAssistantsMapContext, useChatFormContext } from '~/Providers'; +import { useRequiresKey, useTextarea, useSubmitMessage } from '~/hooks'; import { useAutoSave } from '~/hooks/Input/useAutoSave'; -import { useRequiresKey, useTextarea } from '~/hooks'; import { TextareaAutosize } from '~/components/ui'; import { useGetFileConfig } from '~/data-provider'; import { cn, removeFocusRings } from '~/utils'; @@ -35,10 +34,6 @@ const ChatForm = ({ index = 0 }) => { ); const { requiresKey } = useRequiresKey(); - const methods = useForm<{ text: string }>({ - defaultValues: { text: '' }, - }); - const { handlePaste, handleKeyDown, handleKeyUp, handleCompositionStart, handleCompositionEnd } = useTextarea({ textAreaRef, @@ -47,7 +42,6 @@ const ChatForm = ({ index = 0 }) => { }); const { - ask, files, setFiles, conversation, @@ -56,28 +50,17 @@ const ChatForm = ({ index = 0 }) => { setFilesLoading, handleStopGenerating, } = useChatContext(); + const methods = useChatFormContext(); const { clearDraft } = useAutoSave({ conversationId: useMemo(() => conversation?.conversationId, [conversation]), textAreaRef, - setValue: methods.setValue, files, setFiles, }); const assistantMap = useAssistantsMapContext(); - - const submitMessage = useCallback( - (data?: { text: string }) => { - if (!data) { - return console.warn('No data provided to submitMessage'); - } - ask({ text: data.text }); - methods.reset(); - clearDraft(); - }, - [ask, methods, clearDraft], - ); + const { submitMessage } = useSubmitMessage({ clearDraft }); const { endpoint: _endpoint, endpointType } = conversation ?? { endpoint: null }; const endpoint = endpointType ?? _endpoint; diff --git a/client/src/components/Chat/Input/Files/FileUpload.tsx b/client/src/components/Chat/Input/Files/FileUpload.tsx index a16d953ece9..506f50c01de 100644 --- a/client/src/components/Chat/Input/Files/FileUpload.tsx +++ b/client/src/components/Chat/Input/Files/FileUpload.tsx @@ -66,7 +66,7 @@ const FileUpload: React.FC = ({