diff --git a/src/main/knn/KnnGraphTester.java b/src/main/knn/KnnGraphTester.java
index df5dc6923..a1e243b9b 100644
--- a/src/main/knn/KnnGraphTester.java
+++ b/src/main/knn/KnnGraphTester.java
@@ -52,11 +52,6 @@
 import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
 import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
 import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
-import org.apache.lucene.document.Document;
-import org.apache.lucene.document.FieldType;
-import org.apache.lucene.document.KnnByteVectorField;
-import org.apache.lucene.document.KnnFloatVectorField;
-import org.apache.lucene.document.StoredField;
 import org.apache.lucene.index.CodecReader;
 import org.apache.lucene.index.DirectoryReader;
 import org.apache.lucene.index.IndexReader;
@@ -83,6 +78,7 @@
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TotalHits;
 import org.apache.lucene.search.Weight;
+import org.apache.lucene.search.join.CheckJoinIndex;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.store.FSDirectory;
 import org.apache.lucene.store.MMapDirectory;
@@ -112,6 +108,11 @@ public class KnnGraphTester {
   public static final String KNN_FIELD = "knn";
   public static final String ID_FIELD = "id";
   private static final String INDEX_DIR = "knnIndices";
+  public static final String DOCTYPE_FIELD = "docType";
+  public static final String DOCTYPE_PARENT = "_parent";
+  public static final String DOCTYPE_CHILD = "_child";
+  public static final String WIKI_ID_FIELD = "wikiID";
+  public static final String WIKI_PARA_ID_FIELD = "wikiParaID";
 
   private int numDocs;
   private int dim;
@@ -140,6 +141,8 @@ public class KnnGraphTester {
   private float selectivity;
   private boolean prefilter;
   private boolean randomCommits;
+  private boolean parentJoin = false;
+  private Path parentJoinMetaFile;
 
   private KnnGraphTester() {
     // set defaults
@@ -213,30 +216,35 @@ private void run(String... args) throws Exception {
           throw new IllegalArgumentException("-beamWidthIndex requires a following number");
         }
         beamWidth = Integer.parseInt(args[++iarg]);
+        log("beamWidth = %d", beamWidth);
         break;
       case "-maxConn":
         if (iarg == args.length - 1) {
           throw new IllegalArgumentException("-maxConn requires a following number");
         }
         maxConn = Integer.parseInt(args[++iarg]);
+        log("maxConn = %d", maxConn);
         break;
       case "-dim":
         if (iarg == args.length - 1) {
           throw new IllegalArgumentException("-dim requires a following number");
         }
         dim = Integer.parseInt(args[++iarg]);
+        log("Vector Dimensions: %d", dim);
         break;
       case "-ndoc":
         if (iarg == args.length - 1) {
          throw new IllegalArgumentException("-ndoc requires a following number");
        }
        numDocs = Integer.parseInt(args[++iarg]);
+       log("numDocs = %d", numDocs);
        break;
      case "-niter":
        if (iarg == args.length - 1) {
          throw new IllegalArgumentException("-niter requires a following number");
        }
        numIters = Integer.parseInt(args[++iarg]);
+       log("numIters = %d", numIters);
        break;
      case "-reindex":
        reindex = true;
@@ -302,6 +310,7 @@ private void run(String... args) throws Exception {
          default:
            throw new IllegalArgumentException("-metric can be 'mip', 'cosine', 'euclidean', 'angular' (or 'dot_product' -- same as 'angular') only; got: " + metric);
        }
+       log("similarity = %s", similarityFunction);
        break;
      case "-forceMerge":
        forceMerge = true;
@@ -337,6 +346,13 @@ private void run(String... args) throws Exception {
          throw new IllegalArgumentException("-numMergeThread should be >= 1");
        }
        break;
+     case "-parentJoin":
+       if (iarg == args.length - 1) {
+         throw new IllegalArgumentException("-parentJoin requires a following Path for parentJoinMetaFile");
+       }
+       parentJoinMetaFile = Paths.get(args[++iarg]);
+       parentJoin = true;
+       break;
      default:
        throw new IllegalArgumentException("unknown argument " + arg);
        // usage();
@@ -350,6 +366,11 @@ private void run(String... args) throws Exception {
    }
    if (indexPath == null) {
      indexPath = Paths.get(formatIndexPath(docVectorsPath)); // derive index path
+     log("Index Path = %s", indexPath);
+   }
+   if (parentJoin && reindex == false && isParentJoinIndex(indexPath) == false) {
+     throw new IllegalArgumentException("Provided index: [" + indexPath + "] does not have parent-child " +
+         "document relationships. Rerun with -reindex or without -parentJoin argument");
    }
    if (reindex) {
      if (docVectorsPath == null) {
@@ -364,7 +385,9 @@ private void run(String... args) throws Exception {
          similarityFunction,
          numDocs,
          0,
-         quiet
+         quiet,
+         parentJoin,
+         parentJoinMetaFile
      ).createIndex();
      System.out.println(String.format("reindex takes %.2f sec", msToSec(reindexTimeMsec)));
    }
@@ -409,11 +432,23 @@ private void run(String... args) throws Exception {
  }
 
  private String formatIndexPath(Path docsPath) {
+   List<String> suffix = new ArrayList<>();
+   suffix.add(Integer.toString(maxConn));
+   suffix.add(Integer.toString(beamWidth));
    if (quantize) {
-     return INDEX_DIR + "/" + docsPath.getFileName() + "-" + maxConn + "-" + beamWidth + "-"
-         + quantizeBits + (quantizeCompress ? "-compressed" : "" ) + ".index";
+     suffix.add(Integer.toString(quantizeBits));
+     if (quantizeCompress == true) {
+       suffix.add("compressed");
+     }
+   }
+   if (parentJoin) {
+     suffix.add("parentJoin");
    }
-   return INDEX_DIR + "/" + docsPath.getFileName() + "-" + maxConn + "-" + beamWidth + ".index";
+   return INDEX_DIR + "/" + docsPath.getFileName() + "-" + String.join("-", suffix) + ".index";
+ }
+
+ private boolean isParentJoinIndex(Path indexPath) {
+   return indexPath.toString().contains("parentJoin");
  }
 
  @SuppressForbidden(reason = "Prints stuff")
@@ -543,6 +578,7 @@ private void printHist(int[] hist, int max, int count, int nbuckets) {
 
  private void testSearch(Path indexPath, Path queryPath, Path outputPath, int[][] nn) throws IOException {
    TopDocs[] results = new TopDocs[numIters];
+   int[][] resultIds = new int[numIters][];
    long elapsed, totalCpuTimeMS, totalVisited = 0;
    ExecutorService executorService = Executors.newFixedThreadPool(8);
    try (FileChannel input = FileChannel.open(queryPath)) {
@@ -551,16 +587,13 @@ private void testSearch(Path indexPath, Path queryPath, Path outputPath, int[][]
      if (targetReader instanceof VectorReaderByte b) {
        targetReaderByte = b;
      }
-     if (quiet == false) {
-       System.out.println("running " + numIters + " targets; topK=" + topK + ", fanout=" + fanout);
-     }
+     log("running " + numIters + " targets; topK=" + topK + ", fanout=" + fanout);
      long start;
      ThreadMXBean bean = ManagementFactory.getThreadMXBean();
      long cpuTimeStartNs;
      try (MMapDirectory dir = new MMapDirectory(indexPath)) {
        dir.setPreload((x, ctx) -> x.endsWith(".vec") || x.endsWith(".veq"));
-       try (
-           DirectoryReader reader = DirectoryReader.open(dir)) {
+       try (DirectoryReader reader = DirectoryReader.open(dir)) {
          IndexSearcher searcher = new IndexSearcher(reader);
          numDocs = reader.maxDoc();
          Query bitSetQuery = prefilter ? new BitSetQuery(matchDocs) : null;
@@ -576,9 +609,9 @@ private void testSearch(Path indexPath, Path queryPath, Path outputPath, int[][]
            } else {
              float[] target = targetReader.next();
              if (prefilter) {
-               doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
+               doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery, parentJoin);
              } else {
-               doKnnVectorQuery(searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
+               doKnnVectorQuery(searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null, parentJoin);
              }
            }
          }
@@ -596,72 +629,60 @@ private void testSearch(Path indexPath, Path queryPath, Path outputPath, int[][]
          } else {
            float[] target = targetReader.next();
            if (prefilter) {
-             results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
+             results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery, parentJoin);
            } else {
              results[i] = doKnnVectorQuery(
-                 searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
+                 searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null, parentJoin);
+           }
+           if (prefilter == false && matchDocs != null) {
+             results[i].scoreDocs =
+                 Arrays.stream(results[i].scoreDocs)
+                     .filter(scoreDoc -> matchDocs.get(scoreDoc.doc))
+                     .toArray(ScoreDoc[]::new);
            }
-         }
-         if (prefilter == false && matchDocs != null) {
-           results[i].scoreDocs =
-               Arrays.stream(results[i].scoreDocs)
-                   .filter(scoreDoc -> matchDocs.get(scoreDoc.doc))
-                   .toArray(ScoreDoc[]::new);
-         }
          }
        }
        totalCpuTimeMS = TimeUnit.NANOSECONDS.toMillis(bean.getCurrentThreadCpuTime() - cpuTimeStartNs);
        elapsed = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start); // ns -> ms
+
+       // Fetch, validate and write result document ids.
        StoredFields storedFields = reader.storedFields();
        for (int i = 0; i < numIters; i++) {
          totalVisited += results[i].totalHits.value();
-         for (ScoreDoc doc : results[i].scoreDocs) {
-           if (doc.doc != NO_MORE_DOCS) {
-             // there is a bug somewhere that can result in doc=NO_MORE_DOCS! I think it happens
-             // in some degenerate case (like input query has NaN in it?) that causes no results to
-             // be returned from HNSW search?
-             doc.doc = Integer.parseInt(storedFields.document(doc.doc).get("id"));
-           } else {
-             System.out.println("NO_MORE_DOCS!");
-           }
-         }
+         resultIds[i] = KnnTesterUtils.getResultIds(results[i], storedFields);
+       }
+       if (quiet == false) {
+         System.out.println(
+             "completed "
+                 + numIters
+                 + " searches in "
+                 + elapsed
+                 + " ms: "
+                 + ((1000 * numIters) / elapsed)
+                 + " QPS "
+                 + "CPU time="
+                 + totalCpuTimeMS
+                 + "ms");
+       }
       }
-     }
-     if (quiet == false) {
-       System.out.println(
-           "completed "
-               + numIters
-               + " searches in "
-               + elapsed
-               + " ms: "
-               + ((1000 * numIters) / elapsed)
-               + " QPS "
-               + "CPU time="
-               + totalCpuTimeMS
-               + "ms");
-     }
      }
    } finally {
      executorService.shutdown();
    }
    if (outputPath != null) {
-     ByteBuffer buf = ByteBuffer.allocate(4);
-     IntBuffer ibuf = buf.order(ByteOrder.LITTLE_ENDIAN).asIntBuffer();
+     ByteBuffer tmp =
+         ByteBuffer.allocate(resultIds[0].length * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN);
      try (OutputStream out = Files.newOutputStream(outputPath)) {
        for (int i = 0; i < numIters; i++) {
-         for (ScoreDoc doc : results[i].scoreDocs) {
-           ibuf.position(0);
-           ibuf.put(doc.doc);
-           out.write(buf.array());
-         }
+         tmp.asIntBuffer().put(resultIds[i]);
+         out.write(tmp.array());
        }
      }
    } else {
-     if (quiet == false) {
-       System.out.println("checking results");
-     }
-     float recall = checkResults(results, nn);
+     log("checking results");
+     float recall = checkResults(resultIds, nn);
      totalVisited /= numIters;
      String quantizeDesc;
      if (quantize) {
@@ -707,14 +728,18 @@ private static TopDocs doKnnByteVectorQuery(
  }
 
  private static TopDocs doKnnVectorQuery(
-     IndexSearcher searcher, String field, float[] vector, int k, int fanout, Query filter)
+     IndexSearcher searcher, String field, float[] vector, int k, int fanout, Query filter, boolean isParentJoinQuery)
      throws IOException {
+   if (isParentJoinQuery) {
+     ParentJoinBenchmarkQuery parentJoinQuery = new ParentJoinBenchmarkQuery(vector, null, k);
+     return searcher.search(parentJoinQuery, k);
+   }
    ProfiledKnnFloatVectorQuery profiledQuery = new ProfiledKnnFloatVectorQuery(field, vector, k, fanout, filter);
    TopDocs docs = searcher.search(profiledQuery, k);
    return new TopDocs(new TotalHits(profiledQuery.totalVectorCount(), docs.totalHits.relation()), docs.scoreDocs);
  }
 
- private float checkResults(TopDocs[] results, int[][] nn) {
+ private float checkResults(int[][] results, int[][] nn) {
    int totalMatches = 0;
    int totalResults = results.length * topK;
    for (int i = 0; i < results.length; i++) {
@@ -725,23 +750,28 @@ private float checkResults(TopDocs[] results, int[][] nn) {
    return totalMatches / (float) totalResults;
  }
 
- private int compareNN(int[] expected, TopDocs results) {
+ private int compareNN(int[] expected, int[] results) {
    int matched = 0;
    Set<Integer> expectedSet = new HashSet<>();
    for (int i = 0; i < topK; i++) {
      expectedSet.add(expected[i]);
    }
-   for (ScoreDoc scoreDoc : results.scoreDocs) {
-     if (expectedSet.contains(scoreDoc.doc)) {
+   for (int docId : results) {
+     if (expectedSet.contains(docId)) {
        ++matched;
      }
    }
    return matched;
  }
 
- private int[][] getNN(Path docPath, Path queryPath) throws IOException {
+ /** Returns the topK nearest neighbors for each target query.
+  *
+  * The method runs "numIters" target queries and returns "topK" nearest neighbors
+  * for each of them. Nearest neighbors are computed with an exact (brute-force) search.
+  */
+ private int[][] getNN(Path docPath, Path queryPath) throws IOException, InterruptedException {
    // look in working directory for cached nn file
-   String hash = Integer.toString(Objects.hash(docPath, queryPath, numDocs, numIters, topK, similarityFunction.ordinal()), 36);
+   String hash = Integer.toString(Objects.hash(docPath, queryPath, numDocs, numIters, topK, similarityFunction.ordinal(), parentJoin), 36);
    String nnFileName = "nn-" + hash + ".bin";
    Path nnPath = Paths.get(nnFileName);
    if (Files.exists(nnPath) && isNewer(nnPath, docPath, queryPath) && selectivity == 1f) {
@@ -814,7 +844,7 @@ private static FixedBitSet generateRandomBitSet(int size, float selectivity) {
    return bitSet;
  }
 
- private int[][] computeNNByte(Path docPath, Path queryPath) throws IOException {
+ private int[][] computeNNByte(Path docPath, Path queryPath) throws IOException, InterruptedException {
    int[][] result = new int[numIters][];
    if (quiet == false) {
      System.out.println("computing true nearest neighbors of " + numIters + " target vectors");
@@ -876,20 +906,36 @@ public Void call() {
 
  /** Brute force computation of "true" nearest neighbors. */
  private int[][] computeNN(Path docPath, Path queryPath)
-     throws IOException {
+     throws IOException, InterruptedException {
    int[][] result = new int[numIters][];
-   if (quiet == false) {
-     System.out.println("computing true nearest neighbors of " + numIters + " target vectors");
-   }
-   List<Callable<Void>> tasks = new ArrayList<>();
-   try (FileChannel qIn = FileChannel.open(queryPath)) {
-     VectorReader queryReader = (VectorReader) VectorReader.create(qIn, dim, VectorEncoding.FLOAT32);
-     for (int i = 0; i < numIters; i++) {
-       float[] query = queryReader.next().clone();
-       tasks.add(new ComputeNNFloatTask(i, query, docPath, result));
+   log("computing true nearest neighbors of " + numIters + " target vectors");
+   log("parentJoin = %s", parentJoin);
+   if (parentJoin) {
+     try (Directory dir = FSDirectory.open(indexPath);
+          DirectoryReader reader = DirectoryReader.open(dir)) {
+       CheckJoinIndex.check(reader, ParentJoinBenchmarkQuery.parentsFilter);
+       List<Callable<Void>> tasks = new ArrayList<>();
+       try (FileChannel qIn = FileChannel.open(queryPath)) {
+         VectorReader queryReader = (VectorReader) VectorReader.create(qIn, dim, VectorEncoding.FLOAT32);
+         for (int i = 0; i < numIters; i++) {
+           float[] query = queryReader.next().clone();
+           tasks.add(new ComputeExactSearchNNFloatTask(i, query, docPath, result, reader));
+         }
+       }
+       ForkJoinPool.commonPool().invokeAll(tasks);
      }
+   } else {
+     // TODO: Use exactSearch here?
+     List<Callable<Void>> tasks = new ArrayList<>();
+     try (FileChannel qIn = FileChannel.open(queryPath)) {
+       VectorReader queryReader = (VectorReader) VectorReader.create(qIn, dim, VectorEncoding.FLOAT32);
+       for (int i = 0; i < numIters; i++) {
+         float[] query = queryReader.next().clone();
+         tasks.add(new ComputeNNFloatTask(i, query, docPath, result));
+       }
+     }
+     ForkJoinPool.commonPool().invokeAll(tasks);
    }
-   ForkJoinPool.commonPool().invokeAll(tasks);
    return result;
  }
 
@@ -909,30 +955,69 @@ class ComputeNNFloatTask implements Callable<Void> {
 
    @Override
    public Void call() {
-     NeighborQueue queue = new NeighborQueue(topK, false);
-     try (FileChannel in = FileChannel.open(docPath)) {
-       VectorReader docReader = (VectorReader) VectorReader.create(in, dim, VectorEncoding.FLOAT32);
-       for (int j = 0; j < numDocs; j++) {
-         float[] doc = docReader.next();
-         float d = similarityFunction.compare(query, doc);
-         if (matchDocs == null || matchDocs.get(j)) {
-           queue.insertWithOverflow(j, d);
-         }
-       }
-       docReader.reset();
-       result[queryOrd] = new int[topK];
-       for (int k = topK - 1; k >= 0; k--) {
-         result[queryOrd][k] = queue.topNode();
-         queue.pop();
-       }
-       if (quiet == false && (queryOrd + 1) % 10 == 0) {
-         System.out.print(" " + (queryOrd + 1));
-         System.out.flush();
-       }
-     } catch (IOException e) {
-       throw new RuntimeException(e);
-     }
-     return null;
+     NeighborQueue queue = new NeighborQueue(topK, false);
+     try (FileChannel in = FileChannel.open(docPath)) {
+       VectorReader docReader = (VectorReader) VectorReader.create(in, dim, VectorEncoding.FLOAT32);
+       for (int j = 0; j < numDocs; j++) {
+         float[] doc = docReader.next();
+         float d = similarityFunction.compare(query, doc);
+         if (matchDocs == null || matchDocs.get(j)) {
+           queue.insertWithOverflow(j, d);
+         }
+       }
+       docReader.reset();
+       result[queryOrd] = new int[topK];
+       for (int k = topK - 1; k >= 0; k--) {
+         result[queryOrd][k] = queue.topNode();
+         queue.pop();
+       }
+       if ((queryOrd + 1) % 10 == 0) {
+         log(" " + (queryOrd + 1));
+       }
+     } catch (IOException e) {
+       throw new RuntimeException(e);
+     }
+     return null;
+   }
+ }
+
+ /** Uses ExactSearch from Lucene queries to compute nearest neighbors.
+  */
+ class ComputeExactSearchNNFloatTask implements Callable<Void> {
+
+   private final int queryOrd;
+   private final float[] query;
+   private final Path docPath;
+   private final int[][] result;
+   private final IndexReader reader;
+
+   ComputeExactSearchNNFloatTask(int queryOrd, float[] query, Path docPath, int[][] result, IndexReader reader) {
+     this.queryOrd = queryOrd;
+     this.query = query;
+     this.docPath = docPath;
+     this.result = result;
+     this.reader = reader;
+   }
+
+   @Override
+   public Void call() {
+     // we only use this for ParentJoin benchmarks right now, TODO: extend for all computeNN needs.
+     try {
+       ParentJoinBenchmarkQuery parentJoinQuery = new ParentJoinBenchmarkQuery(query, null, topK);
+       TopDocs topHits = ParentJoinBenchmarkQuery.runExactSearch(reader, parentJoinQuery);
+       StoredFields storedFields = reader.storedFields();
+       result[queryOrd] = KnnTesterUtils.getResultIds(topHits, storedFields);
+     } catch (IOException e) {
+       throw new RuntimeException(e);
+     }
+     return null;
+   }
+ }
+
+ private void log(String msg, Object... args) {
+   if (quiet == false) {
+     System.out.printf((msg) + "%n", args);
+     System.out.flush();
+   }
+ }
 }
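NOTE: With this change, the derived index directory name records every option that affects the
index layout. As a worked example (parameter values are illustrative, not defaults from this
patch): running with -maxConn 64, -beamWidth 250 and -parentJoin over a docs file named
cohere-wikipedia-docs-768d.vec resolves to

    knnIndices/cohere-wikipedia-docs-768d.vec-64-250-parentJoin.index

and isParentJoinIndex() only has to look for the "parentJoin" token in that path. That is what
lets the argument check in run() reject reusing a plain index with -parentJoin unless -reindex
is also given.
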
diff --git a/src/main/knn/KnnIndexer.java b/src/main/knn/KnnIndexer.java
index 02744df5f..336b72d15 100644
--- a/src/main/knn/KnnIndexer.java
+++ b/src/main/knn/KnnIndexer.java
@@ -17,8 +17,6 @@
 
 package knn;
 
-import knn.KnnGraphTester;
-import knn.VectorReader;
 import org.apache.lucene.codecs.Codec;
 import org.apache.lucene.document.*;
 import org.apache.lucene.index.IndexWriter;
@@ -26,12 +24,19 @@
 import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.store.FSDirectory;
-import org.apache.lucene.util.PrintStreamInfoStream;
 
+import java.io.BufferedReader;
 import java.io.IOException;
 import java.nio.channels.FileChannel;
+import java.nio.file.Files;
 import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+
+import static knn.KnnGraphTester.DOCTYPE_CHILD;
+import static knn.KnnGraphTester.DOCTYPE_PARENT;
 
 public class KnnIndexer {
   // use smaller ram buffer so we get to merging sooner, making better use of
@@ -48,9 +53,12 @@ public class KnnIndexer {
   int numDocs;
   int docsStartIndex;
   boolean quiet;
+  boolean parentJoin;
+  Path parentJoinMetaPath;
 
   public KnnIndexer(Path docsPath, Path indexPath, Codec codec, VectorEncoding vectorEncoding, int dim,
-                    VectorSimilarityFunction similarityFunction, int numDocs, int docsStartIndex, boolean quiet) {
+                    VectorSimilarityFunction similarityFunction, int numDocs, int docsStartIndex, boolean quiet,
+                    boolean parentJoin, Path parentJoinMetaPath) {
     this.docsPath = docsPath;
     this.indexPath = indexPath;
     this.codec = codec;
@@ -60,6 +68,8 @@ public KnnIndexer(Path docsPath, Path indexPath, Codec codec, VectorEncoding vec
     this.numDocs = numDocs;
     this.docsStartIndex = docsStartIndex;
     this.quiet = quiet;
+    this.parentJoin = parentJoin;
+    this.parentJoinMetaPath = parentJoinMetaPath;
   }
 
   public int createIndex() throws IOException {
@@ -76,7 +86,7 @@ public int createIndex() throws IOException {
       case FLOAT32 -> KnnFloatVectorField.createFieldType(dim, similarityFunction);
     };
     if (quiet == false) {
-      iwc.setInfoStream(new PrintStreamInfoStream(System.out));
+//      iwc.setInfoStream(new PrintStreamInfoStream(System.out));
       System.out.println("creating index in " + indexPath);
     }
@@ -92,32 +102,93 @@ public int createIndex() throws IOException {
         seekToStartDoc(in, dim, vectorEncoding, docsStartIndex);
       }
       VectorReader vectorReader = VectorReader.create(in, dim, vectorEncoding);
-      for (int i = 0; i < numDocs; i++) {
-        Document doc = new Document();
-        switch (vectorEncoding) {
-          case BYTE -> doc.add(
-              new KnnByteVectorField(
-                  KnnGraphTester.KNN_FIELD, ((VectorReaderByte) vectorReader).nextBytes(), fieldType));
-          case FLOAT32 -> doc.add(
-              new KnnFloatVectorField(KnnGraphTester.KNN_FIELD, vectorReader.next(), fieldType));
-        }
-        doc.add(new StoredField(KnnGraphTester.ID_FIELD, i));
-        iw.addDocument(doc);
-
-        if ((i+1) % 25000 == 0) {
-          System.out.println("Done indexing " + (i + 1) + " documents.");
-        }
-      }
-      if (quiet == false) {
-        System.out.println("Done indexing " + numDocs + " documents; now flush");
-      }
+      log("parentJoin=%s", parentJoin);
+      if (parentJoin == false) {
+        for (int i = 0; i < numDocs; i++) {
+          Document doc = new Document();
+          switch (vectorEncoding) {
+            case BYTE -> doc.add(
+                new KnnByteVectorField(
+                    KnnGraphTester.KNN_FIELD, ((VectorReaderByte) vectorReader).nextBytes(), fieldType));
+            case FLOAT32 -> doc.add(
+                new KnnFloatVectorField(KnnGraphTester.KNN_FIELD, vectorReader.next(), fieldType));
+          }
+          doc.add(new StoredField(KnnGraphTester.ID_FIELD, i));
+          iw.addDocument(doc);
+
+          if ((i + 1) % 25000 == 0) {
+            System.out.println("Done indexing " + (i + 1) + " documents.");
+          }
+        }
+      } else {
+        // create parent-block join documents
+        try (BufferedReader br = Files.newBufferedReader(parentJoinMetaPath)) {
+          String[] headers = br.readLine().trim().split(",");
+          if (headers.length != 2) {
+            throw new IllegalStateException("Expected two columns in parentJoinMetadata csv. Found: " + headers.length);
+          }
+          log("Parent join metaFile columns: %s | %s", headers[0], headers[1]);
+          int childDocs = 0;
+          int parentDocs = 0;
+          int docIds = 0;
+          String prevWikiId = "null";
+          String currWikiId;
+          List<Document> block = new ArrayList<>();
+          do {
+            String[] line = br.readLine().trim().split(",");
+            currWikiId = line[0];
+            String currParaId = line[1];
+            Document doc = new Document();
+            switch (vectorEncoding) {
+              case BYTE -> doc.add(
+                  new KnnByteVectorField(
+                      KnnGraphTester.KNN_FIELD, ((VectorReaderByte) vectorReader).nextBytes(), fieldType));
+              case FLOAT32 -> doc.add(
+                  new KnnFloatVectorField(KnnGraphTester.KNN_FIELD, vectorReader.next(), fieldType));
+            }
+            doc.add(new StoredField(KnnGraphTester.ID_FIELD, docIds++));
+            doc.add(new StringField(KnnGraphTester.WIKI_ID_FIELD, currWikiId, Field.Store.YES));
+            doc.add(new StringField(KnnGraphTester.WIKI_PARA_ID_FIELD, currParaId, Field.Store.YES));
+            doc.add(new StringField(KnnGraphTester.DOCTYPE_FIELD, DOCTYPE_CHILD, Field.Store.NO));
+            childDocs++;
+
+            // Close block and create a new one when wiki article changes.
+            if (!currWikiId.equals(prevWikiId) && !"null".equals(prevWikiId)) {
+              Document parent = new Document();
+              parent.add(new StoredField(KnnGraphTester.ID_FIELD, docIds++));
+              parent.add(new StringField(KnnGraphTester.DOCTYPE_FIELD, DOCTYPE_PARENT, Field.Store.NO));
+              parent.add(new StringField(KnnGraphTester.WIKI_ID_FIELD, prevWikiId, Field.Store.YES));
+              parent.add(new StringField(KnnGraphTester.WIKI_PARA_ID_FIELD, "_", Field.Store.YES));
+              block.add(parent);
+              iw.addDocuments(block);
+              parentDocs++;
+              // create new block for the next article
+              block = new ArrayList<>();
+              block.add(doc);
+            } else {
+              block.add(doc);
+            }
+            prevWikiId = currWikiId;
+            if (childDocs % 25000 == 0) {
+              log("indexed %d child documents, with %d parents", childDocs, parentDocs);
+            }
+          } while (childDocs < numDocs);
+          if (!block.isEmpty()) {
+            Document parent = new Document();
+            parent.add(new StoredField(KnnGraphTester.ID_FIELD, docIds++));
+            parent.add(new StringField(KnnGraphTester.DOCTYPE_FIELD, DOCTYPE_PARENT, Field.Store.NO));
+            parent.add(new StringField(KnnGraphTester.WIKI_ID_FIELD, prevWikiId, Field.Store.YES));
+            parent.add(new StringField(KnnGraphTester.WIKI_PARA_ID_FIELD, "_", Field.Store.YES));
+            block.add(parent);
+            iw.addDocuments(block);
+          }
+          log("Indexed %d documents with %d parent docs. now flush", childDocs, parentDocs);
+        }
+      }
     }
     long elapsed = System.nanoTime() - start;
-    if (quiet == false) {
-      System.out.println(
-          "Indexed " + numDocs + " documents in " + TimeUnit.NANOSECONDS.toSeconds(elapsed) + "s");
-    }
+    log("Indexed %d docs in %d seconds", numDocs, TimeUnit.NANOSECONDS.toSeconds(elapsed));
     return (int) TimeUnit.NANOSECONDS.toMillis(elapsed);
   }
 
@@ -125,4 +196,10 @@ private void seekToStartDoc(FileChannel in, int dim, VectorEncoding vectorEncodi
     int startByte = docsStartIndex * dim * vectorEncoding.byteSize;
     in.position(startByte);
   }
+
+  private void log(String msg, Object... args) {
+    if (quiet == false) {
+      System.out.printf((msg) + "%n", args);
+    }
+  }
 }
diff --git a/src/main/knn/KnnIndexerMain.java b/src/main/knn/KnnIndexerMain.java
index 66dc01ce6..3e136b744 100644
--- a/src/main/knn/KnnIndexerMain.java
+++ b/src/main/knn/KnnIndexerMain.java
@@ -19,6 +19,7 @@
 
 import java.io.IOException;
 import java.nio.file.Path;
+import java.nio.file.Paths;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 
@@ -38,6 +39,8 @@ public class KnnIndexerMain {
   public int docStartIndex = 0;
 
   boolean quiet = false;
+  boolean parentJoin = false;
+  Path parentJoinMetaFile = null;
 
   @Override
   public String toString() {
@@ -73,6 +76,10 @@ public static void main(String[] args) throws IOException {
         case "-docstartindex" -> inputs.docStartIndex = Integer.parseInt(args[++i]);
         case "-dimension" -> inputs.dimension = Integer.parseInt(args[++i]);
         case "-quiet" -> inputs.quiet = true;
+        case "-parentjoin" -> {
+          inputs.parentJoin = true;
+          inputs.parentJoinMetaFile = Paths.get(args[++i]);
+        }
         default -> throw new IllegalArgumentException("Cannot recognize the option " + args[i]);
       }
       i++;
@@ -96,7 +103,8 @@ public static void main(String[] args) throws IOException {
     new KnnIndexer(inputs.docVectorsPath, inputs.indexPath,
        KnnGraphTester.getCodec(inputs.maxConn, inputs.beamWidth, exec, numMergeWorker, quantize, quantizeBits, quantizeCompress),
        inputs.vectorEncoding,
-       inputs.dimension, inputs.similarityFunction, inputs.numDocs, inputs.docStartIndex, inputs.quiet).createIndex();
+       inputs.dimension, inputs.similarityFunction, inputs.numDocs, inputs.docStartIndex, inputs.quiet,
+       inputs.parentJoin, inputs.parentJoinMetaFile).createIndex();
 
     if (!inputs.quiet) {
       System.out.println("Successfully created index.");
@@ -114,6 +122,7 @@ public String usage() {
       "\t -similarityFunction : similarity function for vector comparison. One of ( EUCLIDEAN, DOT_PRODUCT, COSINE, MAXIMUM_INNER_PRODUCT )\n" +
       "\t -numDocs : number of document vectors to be used from the file\n" +
       "\t -docStartIndex : Start index of first document vector. This can be helpful when we want to run with a different set of documents from within the same file.\n" +
-      "\t -quiet : don't print anything on console if mentioned.\n";
+      "\t -quiet : don't print anything on console if mentioned.\n" +
+      "\t -parentJoin : create parentJoin index. Requires '*-metadata.csv'\n";
   }
 }
diff --git a/src/main/knn/KnnTesterUtils.java b/src/main/knn/KnnTesterUtils.java
new file mode 100644
index 000000000..d43d25acc
--- /dev/null
+++ b/src/main/knn/KnnTesterUtils.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package knn;
+
+import org.apache.lucene.index.StoredFields;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+
+import java.io.IOException;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+public class KnnTesterUtils {
+
+  /** Fetches values for the "id" field from search results
+   */
+  public static int[] getResultIds(TopDocs topDocs, StoredFields storedFields) throws IOException {
+    int[] resultIds = new int[topDocs.scoreDocs.length];
+    int i = 0;
+    for (ScoreDoc doc : topDocs.scoreDocs) {
+      if (doc.doc != NO_MORE_DOCS) {
+        // there is a bug somewhere that can result in doc=NO_MORE_DOCS! I think it happens
+        // in some degenerate case (like input query has NaN in it?) that causes no results to
+        // be returned from HNSW search?
+        resultIds[i++] = Integer.parseInt(storedFields.document(doc.doc).get(KnnGraphTester.ID_FIELD));
+      } else {
+        System.out.println("NO_MORE_DOCS!");
+      }
+    }
+    return resultIds;
+  }
+}
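NOTE: A minimal sketch of how the new helper is meant to be called; the wrapper method here is
hypothetical, while everything it calls (getResultIds, the stored "id" field) comes from this
patch:

    import java.io.IOException;

    import org.apache.lucene.index.DirectoryReader;
    import org.apache.lucene.search.IndexSearcher;
    import org.apache.lucene.search.Query;
    import org.apache.lucene.search.TopDocs;

    // Map Lucene's internal, segment-ordered doc numbers back to the stable
    // stored "id" values that the ground-truth files are written in terms of.
    static int[] searchAndResolveIds(DirectoryReader reader, Query query, int topK) throws IOException {
      IndexSearcher searcher = new IndexSearcher(reader);
      TopDocs topDocs = searcher.search(query, topK);
      return KnnTesterUtils.getResultIds(topDocs, reader.storedFields());
    }
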
diff --git a/src/main/knn/ParentJoinBenchmarkQuery.java b/src/main/knn/ParentJoinBenchmarkQuery.java
new file mode 100644
index 000000000..65b1334c7
--- /dev/null
+++ b/src/main/knn/ParentJoinBenchmarkQuery.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package knn;
+
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.QueryTimeout;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.Weight;
+import org.apache.lucene.search.join.BitSetProducer;
+import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
+import org.apache.lucene.search.join.QueryBitSetProducer;
+
+import java.io.IOException;
+import java.util.List;
+
+import static knn.KnnGraphTester.*;
+
+/** Exposes functions to directly invoke {@link DiversifyingChildrenFloatKnnVectorQuery#exactSearch}
+ */
+public class ParentJoinBenchmarkQuery extends DiversifyingChildrenFloatKnnVectorQuery {
+
+  public static final BitSetProducer parentsFilter =
+      new QueryBitSetProducer(new TermQuery(new Term(DOCTYPE_FIELD, DOCTYPE_PARENT)));
+
+  private static final TermQuery childDocQuery = new TermQuery(new Term(DOCTYPE_FIELD, DOCTYPE_CHILD));
+
+  ParentJoinBenchmarkQuery(float[] queryVector, Query childFilter, int k) throws IOException {
+    super(KNN_FIELD, queryVector, childFilter, k, parentsFilter);
+  }
+
+  // expose for benchmarking
+  @Override
+  public TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) throws IOException {
+    return super.exactSearch(context, acceptIterator, queryTimeout);
+  }
+
+  public static TopDocs runExactSearch(IndexReader reader, ParentJoinBenchmarkQuery query) throws IOException {
+    IndexSearcher searcher = new IndexSearcher(reader);
+    List<LeafReaderContext> leafReaderContexts = reader.leaves();
+    TopDocs[] perLeafResults = new TopDocs[leafReaderContexts.size()];
+    int leaf = 0;
+    for (LeafReaderContext ctx : leafReaderContexts) {
+      Weight childrenWeight = childDocQuery.createWeight(searcher, ScoreMode.COMPLETE_NO_SCORES, 1f);
+      DocIdSetIterator acceptDocs = childrenWeight.scorer(ctx).iterator();
+      perLeafResults[leaf] = query.exactSearch(ctx, acceptDocs, null);
+      if (ctx.docBase > 0) {
+        for (ScoreDoc scoreDoc : perLeafResults[leaf].scoreDocs) {
+          scoreDoc.doc += ctx.docBase;
+        }
+      }
+      leaf++;
+    }
+    return query.mergeLeafResults(perLeafResults);
+  }
+}
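NOTE: A minimal sketch of the ground-truth flow this class enables; it mirrors
ComputeExactSearchNNFloatTask in KnnGraphTester, but the driver method itself is hypothetical:

    import java.io.IOException;
    import java.nio.file.Path;

    import org.apache.lucene.index.DirectoryReader;
    import org.apache.lucene.search.TopDocs;
    import org.apache.lucene.search.join.CheckJoinIndex;
    import org.apache.lucene.store.Directory;
    import org.apache.lucene.store.FSDirectory;

    static int[] exactParentJoinNeighbors(Path indexPath, float[] queryVector, int topK) throws IOException {
      try (Directory dir = FSDirectory.open(indexPath);
           DirectoryReader reader = DirectoryReader.open(dir)) {
        // fail fast unless every block ends with its parent document
        CheckJoinIndex.check(reader, ParentJoinBenchmarkQuery.parentsFilter);
        ParentJoinBenchmarkQuery query = new ParentJoinBenchmarkQuery(queryVector, null, topK);
        TopDocs truth = ParentJoinBenchmarkQuery.runExactSearch(reader, query);
        return KnnTesterUtils.getResultIds(truth, reader.storedFields());
      }
    }

Note the package-private constructor: a caller like this has to live in the knn package, as the
benchmark tasks do.
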
diff --git a/src/python/knnPerfTest.py b/src/python/knnPerfTest.py
index 52aedcd58..ccd0b7244 100644
--- a/src/python/knnPerfTest.py
+++ b/src/python/knnPerfTest.py
@@ -103,8 +103,9 @@ def run_knn_benchmark(checkout, values):
 
     # Cohere dataset
     dim = 768
-    doc_vectors = '%s/data/cohere-wikipedia-768.vec' % constants.BASE_DIR
-    query_vectors = '%s/data/cohere-wikipedia-queries-768.vec' % constants.BASE_DIR
+    doc_vectors = f"{constants.BASE_DIR}/data/cohere-wikipedia-docs-{dim}d.vec"
+    query_vectors = f"{constants.BASE_DIR}/data/cohere-wikipedia-queries-{dim}d.vec"
+    parentJoin_meta_file = f"{constants.BASE_DIR}/data/cohere-wikipedia-metadata.csv"
 
     cp = benchUtil.classPathToString(benchUtil.getClassPath(checkout))
     cmd = constants.JAVA_EXE.split(' ') + ['-cp', cp,
                                            #'--add-modules', 'jdk.incubator.vector', # no need to add these flags -- they are on by default now?
@@ -143,10 +144,12 @@ def run_knn_benchmark(checkout, values):
                  '-reindex',
                  '-search-and-stats', query_vectors,
                  #'-metric', 'euclidean',
+                 # '-parentJoin', parentJoin_meta_file,
                  # '-numMergeThread', '8', '-numMergeWorker', '8',
                  '-forceMerge',
                  #'-stats',
-                 '-quiet']
+                 '-quiet'
+                 ]
     print(f'  cmd: {this_cmd}')
     job = subprocess.Popen(this_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding='utf-8')
     re_summary = re.compile(r'^SUMMARY: (.*?)$', re.MULTILINE)
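NOTE: The Python driver wires the metadata file through but leaves '-parentJoin' commented out, so
default benchmark runs are unchanged. To benchmark parent-join search, uncomment that line and
rerun with reindexing: a parent-join index has a different document layout and a
'parentJoin'-suffixed index path, so KnnGraphTester will reject an existing plain index.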