Skip to content

Commit

Permalink
community[patch]: Add support for inner product and Euclidean distance to `PGVector` (langchain-ai#4781)
Browse files Browse the repository at this point in the history

* Add support for inner product and Euclidean distance to PGVector.

* Fix formatting errors.

* Rename SupportedVectorTypes to DistanceStrategy. Update PGVector docs.
  • Loading branch information
andrewnguonly authored Mar 15, 2024
1 parent 231e475 commit fdbc858
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import { OpenAIEmbeddings } from "@langchain/openai";
import { PGVectorStore } from "@langchain/community/vectorstores/pgvector";
import {
DistanceStrategy,
PGVectorStore,
} from "@langchain/community/vectorstores/pgvector";
import { PoolConfig } from "pg";

// First, follow set-up instructions at
Expand All @@ -21,6 +24,8 @@ const config = {
contentColumnName: "content",
metadataColumnName: "metadata",
},
// supported distance strategies: cosine (default), innerProduct, or euclidean
distanceStrategy: "cosine" as DistanceStrategy,
};

const pgvectorStore = await PGVectorStore.initialize(
Expand Down
34 changes: 28 additions & 6 deletions libs/langchain-community/src/vectorstores/pgvector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import { getEnvironmentVariable } from "@langchain/core/utils/env";

type Metadata = Record<string, unknown>;

/**
 * Names of the distance strategies a `PGVectorStore` can be configured with.
 * The store translates each name into a pgvector comparison operator:
 * `"cosine"` -> `<=>`, `"innerProduct"` -> `<#>`, `"euclidean"` -> `<->`.
 */
export type DistanceStrategy = "cosine" | "innerProduct" | "euclidean";

/**
* Interface that defines the arguments required to create a
* `PGVectorStore` instance. It includes Postgres connection options,
Expand Down Expand Up @@ -35,6 +37,7 @@ export interface PGVectorStoreArgs {
*/
chunkSize?: number;
ids?: string[];
distanceStrategy?: DistanceStrategy;
}

/**
Expand Down Expand Up @@ -76,6 +79,8 @@ export class PGVectorStore extends VectorStore {

chunkSize = 500;

distanceStrategy?: DistanceStrategy = "cosine";

/**
 * Identifies this vector store implementation.
 * @returns The literal string `"pgvector"`.
 */
_vectorstoreType(): string {
return "pgvector";
}
Expand Down Expand Up @@ -112,6 +117,7 @@ export class PGVectorStore extends VectorStore {
const pool = config.pool ?? new pg.Pool(config.postgresConnectionOptions);
this.pool = pool;
this.chunkSize = config.chunkSize ?? 500;
this.distanceStrategy = config.distanceStrategy ?? this.distanceStrategy;

this._verbose =
getEnvironmentVariable("LANGCHAIN_VERBOSE") === "true" ??
Expand All @@ -130,6 +136,27 @@ export class PGVectorStore extends VectorStore {
: `"${this.schemaName}"."${this.collectionTableName}"`;
}

/**
 * SQL operator text for the currently configured `distanceStrategy`.
 *
 * Mapping: `"cosine"` -> `<=>`, `"innerProduct"` -> `<#>`,
 * `"euclidean"` -> `<->`. When `extensionSchemaName` is not `null`, the
 * operator is schema-qualified as `OPERATOR(<schema>.<op>)`; otherwise the
 * bare operator is returned.
 *
 * @returns The operator string to interpolate into a query.
 * @throws Error if `distanceStrategy` is not one of the three known values.
 */
get computedOperatorString(): string {
  const strategy = this.distanceStrategy;

  let symbol: string;
  if (strategy === "cosine") {
    symbol = "<=>";
  } else if (strategy === "innerProduct") {
    symbol = "<#>";
  } else if (strategy === "euclidean") {
    symbol = "<->";
  } else {
    throw new Error(`Unknown distance strategy: ${strategy}`);
  }

  // No extension schema configured: the bare operator is enough.
  if (this.extensionSchemaName === null) {
    return symbol;
  }

  return `OPERATOR(${this.extensionSchemaName}.${symbol})`;
}

/**
* Static method to create a new `PGVectorStore` instance from a
* connection. It creates a table if one does not exist, and calls
Expand Down Expand Up @@ -487,13 +514,8 @@ export class PGVectorStore extends VectorStore {
? `WHERE ${whereClauses.join(" AND ")}`
: "";

const operatorString =
this.extensionSchemaName !== null
? `OPERATOR(${this.extensionSchemaName}.<=>)`
: "<=>";

const queryString = `
SELECT *, "${this.vectorColumnName}" ${operatorString} $1 as "_distance"
SELECT *, "${this.vectorColumnName}" ${this.computedOperatorString} $1 as "_distance"
FROM ${this.computedTableName}
${whereClause}
ORDER BY "_distance" ASC
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,37 @@ describe("PGVectorStore", () => {
throw e;
}
});

test("PGvector supports different vector types", async () => {
  // The store instance is declared outside this test. Remember the two fields
  // mutated below and restore them afterwards, so tests that run later on the
  // same instance do not inherit a different operator or a fake schema name.
  const initialStrategy = pgvectorVectorStore.distanceStrategy;
  const initialExtensionSchema = pgvectorVectorStore.extensionSchemaName;

  try {
    // Verify that each distance strategy selects the matching pgvector
    // operator. These bare-operator expectations only hold while
    // extensionSchemaName is null.
    pgvectorVectorStore.distanceStrategy = "cosine";
    expect(pgvectorVectorStore.computedOperatorString).toEqual("<=>");

    pgvectorVectorStore.distanceStrategy = "innerProduct";
    expect(pgvectorVectorStore.computedOperatorString).toEqual("<#>");

    pgvectorVectorStore.distanceStrategy = "euclidean";
    expect(pgvectorVectorStore.computedOperatorString).toEqual("<->");

    // With an extension schema set, the operator must be schema-qualified.
    pgvectorVectorStore.distanceStrategy = "cosine";
    pgvectorVectorStore.extensionSchemaName = "schema1";
    expect(pgvectorVectorStore.computedOperatorString).toEqual(
      "OPERATOR(schema1.<=>)"
    );

    pgvectorVectorStore.distanceStrategy = "innerProduct";
    pgvectorVectorStore.extensionSchemaName = "schema2";
    expect(pgvectorVectorStore.computedOperatorString).toEqual(
      "OPERATOR(schema2.<#>)"
    );

    pgvectorVectorStore.distanceStrategy = "euclidean";
    pgvectorVectorStore.extensionSchemaName = "schema3";
    expect(pgvectorVectorStore.computedOperatorString).toEqual(
      "OPERATOR(schema3.<->)"
    );
  } finally {
    // Runs even when an expectation above throws.
    pgvectorVectorStore.distanceStrategy = initialStrategy;
    pgvectorVectorStore.extensionSchemaName = initialExtensionSchema;
  }
});
});

describe.skip("PGVectorStore with collection", () => {
Expand Down

0 comments on commit fdbc858

Please sign in to comment.