diff --git a/electron/main/vector-database/ipcHandlers.ts b/electron/main/vector-database/ipcHandlers.ts index b8154ceb..b0619441 100644 --- a/electron/main/vector-database/ipcHandlers.ts +++ b/electron/main/vector-database/ipcHandlers.ts @@ -10,7 +10,7 @@ import { StoreSchema } from '../electron-store/storeConfig' import { startWatchingDirectory, updateFileListForRenderer } from '../filesystem/filesystem' import { rerankSearchedEmbeddings } from './embeddings' -import { DBEntry, DatabaseFields } from './schema' +import { DBEntry, DatabaseFields, DBQueryResult } from './schema' import { RepopulateTableWithMissingItems } from './tableHelperFunctions' export interface PromptWithRagResults { @@ -37,6 +37,24 @@ export const registerDBSessionHandlers = (store: Store, _windowMana return searchResults }) + ipcMain.handle( + 'multi-modal-search', + async ( + event, + query: string, + limit: number, + searchType: 'vector' | 'text' | 'hybrid', + filter?: string, + ): Promise<{ vectorResults: DBQueryResult[]; textResults: DBQueryResult[] }> => { + const windowInfo = windowManager.getWindowInfoForContents(event.sender) + if (!windowInfo) { + throw new Error('Window info not found.') + } + const searchResults = await windowInfo.dbTableClient.multiModalSearch(query, limit, searchType, filter) + return searchResults + }, + ) + ipcMain.handle('index-files-in-directory', async (event) => { const windowInfo = windowManager.getWindowInfoForContents(event.sender) if (!windowInfo) { diff --git a/electron/main/vector-database/lanceTableWrapper.ts b/electron/main/vector-database/lanceTableWrapper.ts index b57d17e2..c18b78cb 100644 --- a/electron/main/vector-database/lanceTableWrapper.ts +++ b/electron/main/vector-database/lanceTableWrapper.ts @@ -118,6 +118,40 @@ class LanceDBTableWrapper { const mapped = rawResults.map(convertRecordToDBType) return mapped as DBEntry[] } + + async multiModalSearch( + query: string, + limit: number, + searchType: 'vector' | 'text' | 'hybrid' = 'vector', + filter?: string, + ): Promise<{ vectorResults: DBQueryResult[]; textResults: DBQueryResult[] }> { + let vectorResults: DBQueryResult[] = [] + let textResults: DBQueryResult[] = [] + + if (searchType === 'vector' || searchType === 'hybrid') { + const vectorQuery = await this.lanceTable.search(query).metricType(MetricType.Cosine).limit(limit) + if (filter) { + vectorQuery.prefilter(true).filter(filter) + } + const rawVectorResults = await vectorQuery.execute() + vectorResults = rawVectorResults + .map(convertRecordToDBType) + .filter((r): r is DBQueryResult => r !== null) + } + + if (searchType === 'text' || searchType === 'hybrid') { + const sanitizedTextQuery = sanitizePathForDatabase(query) + const textFilter = filter + ? `${filter} AND content LIKE '%${sanitizedTextQuery}%'` + : `content LIKE '%${sanitizedTextQuery}%'` + const rawTextResults = await this.lanceTable.filter(textFilter).limit(limit).execute() + textResults = rawTextResults + .map(convertRecordToDBType) + .filter((r): r is DBQueryResult => r !== null) + } + + return { vectorResults, textResults } + } } export default LanceDBTableWrapper diff --git a/electron/preload/index.ts b/electron/preload/index.ts index 383eb77c..e43a1265 100644 --- a/electron/preload/index.ts +++ b/electron/preload/index.ts @@ -22,6 +22,15 @@ function createIPCHandler any>(channel: string): I const database = { search: createIPCHandler<(query: string, limit: number, filter?: string) => Promise>('search'), + multiModalSearch: + createIPCHandler< + ( + query: string, + limit: number, + searchType: 'vector' | 'text' | 'hybrid', + filter?: string, + ) => Promise<{ vectorResults: DBQueryResult[]; textResults: DBQueryResult[] }> + >('multi-modal-search'), deleteLanceDBEntriesByFilePath: createIPCHandler<(filePath: string) => Promise>( 'delete-lance-db-entries-by-filepath', ), diff --git a/src/components/File/DBResultPreview.tsx b/src/components/File/DBResultPreview.tsx index 0f3f5d78..edba8602 100644 --- a/src/components/File/DBResultPreview.tsx +++ b/src/components/File/DBResultPreview.tsx @@ -62,10 +62,15 @@ export const DBSearchPreview: React.FC = ({ dbResult: entr
- {fileName && {fileName} } | Similarity:{' '} + {fileName && {fileName} } {/* eslint-disable-next-line no-underscore-dangle */} - {cosineDistanceToPercentage(entry._distance)}% |{' '} - {modified && Modified {modified}} + {entry._distance != null && ( + <> + {/* eslint-disable-next-line no-underscore-dangle */}| Similarity:{' '} + {cosineDistanceToPercentage(entry._distance)}%{' '} + + )}{' '} + | {modified && Modified {modified}}
) diff --git a/src/components/MainPage.tsx b/src/components/MainPage.tsx index 64bf019e..dae977cd 100644 --- a/src/components/MainPage.tsx +++ b/src/components/MainPage.tsx @@ -52,7 +52,7 @@ const MainPageContent: React.FC = () => { /> - +
diff --git a/src/components/Sidebars/MainSidebar.tsx b/src/components/Sidebars/MainSidebar.tsx index b919c428..544f7fc4 100644 --- a/src/components/Sidebars/MainSidebar.tsx +++ b/src/components/Sidebars/MainSidebar.tsx @@ -14,7 +14,10 @@ const SidebarManager: React.FC = () => { const { sidebarShowing } = useChatContext() const [searchQuery, setSearchQuery] = useState('') - const [searchResults, setSearchResults] = useState([]) + const [searchResults, setSearchResults] = useState<{ vectorResults: DBQueryResult[]; textResults: DBQueryResult[] }>({ + vectorResults: [], + textResults: [], + }) return (
diff --git a/src/components/Sidebars/SearchComponent.tsx b/src/components/Sidebars/SearchComponent.tsx index 5dfed0d1..c0b4896d 100644 --- a/src/components/Sidebars/SearchComponent.tsx +++ b/src/components/Sidebars/SearchComponent.tsx @@ -1,4 +1,4 @@ -import React, { useEffect, useRef, useCallback } from 'react' +import React, { useEffect, useRef, useCallback, useState } from 'react' import { DBQueryResult } from 'electron/main/vector-database/schema' import posthog from 'posthog-js' import { FaSearch } from 'react-icons/fa' @@ -6,11 +6,13 @@ import { debounce } from 'lodash' import { DBSearchPreview } from '../File/DBResultPreview' import { useContentContext } from '@/contexts/ContentContext' +type SearchType = 'vector' | 'text' | 'hybrid' + interface SearchComponentProps { searchQuery: string setSearchQuery: (query: string) => void - searchResults: DBQueryResult[] - setSearchResults: (results: DBQueryResult[]) => void + searchResults: { vectorResults: DBQueryResult[]; textResults: DBQueryResult[] } + setSearchResults: (results: { vectorResults: DBQueryResult[]; textResults: DBQueryResult[] }) => void } const SearchComponent: React.FC = ({ @@ -21,13 +23,14 @@ const SearchComponent: React.FC = ({ }) => { const { openContent: openTabContent } = useContentContext() const searchInputRef = useRef(null) + const [searchType, setSearchType] = useState('vector') const handleSearch = useCallback( async (query: string) => { - const results: DBQueryResult[] = await window.database.search(query, 50) + const results = await window.database.multiModalSearch(query, 50, searchType) setSearchResults(results) }, - [setSearchResults], + [setSearchResults, searchType], ) const debouncedSearch = useCallback( @@ -46,7 +49,7 @@ const SearchComponent: React.FC = ({ if (searchQuery) { debouncedSearch(searchQuery) } - }, [searchQuery, debouncedSearch]) + }, [searchQuery, debouncedSearch, searchType]) const openFileSelectSearch = useCallback( (path: string) => { @@ -70,13 +73,32 @@ const SearchComponent: React.FC = ({ onChange={(e) => setSearchQuery(e.target.value)} placeholder="Semantic search..." /> +
- {searchResults.length > 0 && ( + {searchResults?.textResults?.length > 0 && ( +
+

Text Search Results

+ {searchResults.textResults.map((result, index) => ( + // eslint-disable-next-line react/no-array-index-key + + ))} +
+ )} + {searchResults?.vectorResults?.length > 0 && (
- {searchResults.map((result, index) => ( +

Vector Search Results

+ {searchResults.vectorResults.map((result, index) => ( // eslint-disable-next-line react/no-array-index-key - + ))}
)}