Skip to content

Commit

Permalink
fix: fixes embeddings generation via plugin-embeddings (#849) (#852)
Browse files Browse the repository at this point in the history
  • Loading branch information
micheleriva authored Dec 10, 2024
1 parent 32f7f71 commit 53546f8
Show file tree
Hide file tree
Showing 8 changed files with 374 additions and 56 deletions.
1 change: 0 additions & 1 deletion packages/orama/src/components/hooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ export function runBeforeSearch<T extends AnyOrama>(
language: string | undefined
): Promise<void> | void {
const needAsync = hooks.some(isAsyncFunction)

if (needAsync) {
return (async () => {
for (const hook of hooks) {
Expand Down
1 change: 1 addition & 0 deletions packages/orama/src/methods/search-hybrid.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ export function hybridSearch<T extends AnyOrama, ResultDocument = TypedDocument<
}

const asyncNeeded = orama.beforeSearch?.length || orama.afterSearch?.length

if (asyncNeeded) {
return executeSearchAsync()
}
Expand Down
2 changes: 2 additions & 0 deletions packages/orama/src/methods/search-vector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ export function innerVectorSearch<T extends AnyOrama, ResultDocument = TypedDocu

const vectorIndex = orama.data.index.vectorIndexes[vector!.property]
const vectorSize = vectorIndex.node.size

if (vector?.value.length !== vectorSize) {
if (vector?.property === undefined || vector?.value.length === undefined) {
throw createError('INVALID_INPUT_VECTOR', 'undefined', vectorSize, 'undefined')
Expand Down Expand Up @@ -121,6 +122,7 @@ export function searchVector<T extends AnyOrama, ResultDocument = TypedDocument<
}

const asyncNeeded = orama.beforeSearch?.length || orama.afterSearch?.length

if (asyncNeeded) {
return executeSearchAsync()
}
Expand Down
1 change: 1 addition & 0 deletions packages/orama/src/trees/vector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ export function findSimilarVectors(
const similarVectors: SimilarVector[] = []

const base = keys ? keys : vectors.keys()

for (const vectorId of base) {
const entry = vectors.get(vectorId)
if (!entry) {
Expand Down
2 changes: 1 addition & 1 deletion packages/orama/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,8 @@ export function isAsyncFunction(func: any): boolean {
return func?.constructor?.name === 'AsyncFunction'
}


const withIntersection = 'intersection' in (new Set());

export function setIntersection<V>(...sets: Set<V>[]): Set<V> {
// Fast path 1
if (sets.length === 0) {
Expand Down
68 changes: 45 additions & 23 deletions packages/orama/tests/utils.test.ts
Original file line number Diff line number Diff line change
@@ -1,33 +1,33 @@
import t from 'tap'
import { formatBytes, formatNanoseconds, getOwnProperty, getNested, flattenObject, setUnion, setIntersection } from '../src/utils.js'
import { formatBytes, formatNanoseconds, getOwnProperty, getNested, flattenObject, setUnion, setIntersection, isAsyncFunction } from '../src/utils.js'

t.test('utils', async (t) => {
t.test('should correctly format bytes', async (t) => {
t.equal(await formatBytes(0), '0 Bytes')
t.equal(await formatBytes(1), '1 Bytes')
t.equal(await formatBytes(1024), '1 KB')
t.equal(await formatBytes(1024 ** 2), '1 MB')
t.equal(await formatBytes(1024 ** 3), '1 GB')
t.equal(await formatBytes(1024 ** 4), '1 TB')
t.equal(await formatBytes(1024 ** 5), '1 PB')
t.equal(await formatBytes(1024 ** 6), '1 EB')
t.equal(await formatBytes(1024 ** 7), '1 ZB')
t.equal(formatBytes(0), '0 Bytes')
t.equal(formatBytes(1), '1 Bytes')
t.equal(formatBytes(1024), '1 KB')
t.equal(formatBytes(1024 ** 2), '1 MB')
t.equal(formatBytes(1024 ** 3), '1 GB')
t.equal(formatBytes(1024 ** 4), '1 TB')
t.equal(formatBytes(1024 ** 5), '1 PB')
t.equal(formatBytes(1024 ** 6), '1 EB')
t.equal(formatBytes(1024 ** 7), '1 ZB')
})

t.test('should correctly format nanoseconds', async (t) => {
t.equal(await formatNanoseconds(1n), '1ns')
t.equal(await formatNanoseconds(10n), '10ns')
t.equal(await formatNanoseconds(100n), '100ns')
t.equal(await formatNanoseconds(1_000n), '1μs')
t.equal(await formatNanoseconds(10_000n), '10μs')
t.equal(await formatNanoseconds(100_000n), '100μs')
t.equal(await formatNanoseconds(1_000_000n), '1ms')
t.equal(await formatNanoseconds(10_000_000n), '10ms')
t.equal(await formatNanoseconds(100_000_000n), '100ms')
t.equal(await formatNanoseconds(1000_000_000n), '1s')
t.equal(await formatNanoseconds(10_000_000_000n), '10s')
t.equal(await formatNanoseconds(100_000_000_000n), '100s')
t.equal(await formatNanoseconds(1000_000_000_000n), '1000s')
t.equal(formatNanoseconds(1n), '1ns')
t.equal(formatNanoseconds(10n), '10ns')
t.equal(formatNanoseconds(100n), '100ns')
t.equal(formatNanoseconds(1_000n), '1μs')
t.equal(formatNanoseconds(10_000n), '10μs')
t.equal(formatNanoseconds(100_000n), '100μs')
t.equal(formatNanoseconds(1_000_000n), '1ms')
t.equal(formatNanoseconds(10_000_000n), '10ms')
t.equal(formatNanoseconds(100_000_000n), '100ms')
t.equal(formatNanoseconds(1000_000_000n), '1s')
t.equal(formatNanoseconds(10_000_000_000n), '10s')
t.equal(formatNanoseconds(100_000_000_000n), '100s')
t.equal(formatNanoseconds(1000_000_000_000n), '1000s')
})

t.test('should check object properties', async (t) => {
Expand Down Expand Up @@ -95,6 +95,28 @@ t.test('utils', async (t) => {
t.equal((flattened as Record<string, string>).foo, 'bar')
t.equal(flattened['nested.nested2.nested3.bar'], 'baz')
})

// This test is skipped because the implementation of isAsyncFunction is temporary and will be
// removed in a future version of Orama.
t.skip('should correctly detect an async function', t => {
async function asyncFunction() {
return 'async'
}

function returnPromise() {
return new Promise((resolve) => {
resolve('promise')
})
}

function syncFunction() {
return 'sync'
}

t.equal(isAsyncFunction(asyncFunction), true)
t.equal(isAsyncFunction(returnPromise), false) // Returing a promise is not async, JS cannot detect it as async
t.equal(isAsyncFunction(syncFunction), false)
})
})

t.test('setUnion', async t => {
Expand Down
21 changes: 15 additions & 6 deletions packages/plugin-embeddings/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ function getPropertiesValues(schema: object, properties: string[]) {
.join('. ')
}

function normalizeVector(v: number[]): number[] {
const norm = Math.sqrt(v.reduce((sum, val) => sum + val * val, 0));
return v.map(val => val / norm);
}

export const embeddingsType = 'vector[512]'

export async function pluginEmbeddings(pluginParams: PluginEmbeddingsParams): Promise<OramaPluginAsync> {
Expand All @@ -49,9 +54,9 @@ export async function pluginEmbeddings(pluginParams: PluginEmbeddingsParams): Pr
console.log(`Generating embeddings for properties "${properties.join(', ')}": "${values}"`)
}

const embeddings = await model.embed(values)
const embeddings = Array.from(await (await model.embed(values)).data())

params[pluginParams.embeddings.defaultProperty] = (await embeddings.data()) as unknown as number[]
params[pluginParams.embeddings.defaultProperty] = normalizeVector(embeddings)
},

async beforeSearch<T extends AnyOrama>(_db: AnyOrama, params: SearchParams<T, TypedDocument<any>>) {
Expand All @@ -64,21 +69,25 @@ export async function pluginEmbeddings(pluginParams: PluginEmbeddingsParams): Pr
}

if (!params.term) {
throw new Error('Neither "term" nor "vector" parameters were provided')
throw new Error('No "term" or "vector" parameters were provided')
}

const embeddings = await model.embed(params.term) as unknown as number[]
const embeddings = Array.from(await (await model.embed(params.term)).data()) as unknown as number[]

if (!params.vector) {
params.vector = {
// eslint-disable-next-line
// @ts-ignore
property: params?.vector?.property ?? pluginParams.embeddings.defaultProperty,
value: embeddings
value: normalizeVector(embeddings)
}
}

console.log({
vector: normalizeVector(embeddings)
})

params.vector.value = embeddings
params.vector.value = normalizeVector(embeddings)
}
}
}
Loading

0 comments on commit 53546f8

Please sign in to comment.