Skip to content

Commit

Permalink
Use instances of PointRangeQuery where possible
Browse files Browse the repository at this point in the history
Signed-off-by: Harsha Vamsi Kalluri <[email protected]>
  • Loading branch information
harshavamsi committed Jul 19, 2024
1 parent 7e3945e commit 2a0f9dc
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ public abstract class ApproximatePointRangeQuery extends Query {

private long[] docCount = { 0 };

private final PointRangeQuery pointRangeQuery;

protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, int numDims) {
this(field, lowerPoint, upperPoint, numDims, 10_000, SortOrder.ASC);
}
Expand Down Expand Up @@ -84,6 +86,12 @@ protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upp
this.upperPoint = upperPoint;
this.size = size;
this.sortOrder = sortOrder;
this.pointRangeQuery = new PointRangeQuery(field, lowerPoint, upperPoint, numDims) {
@Override
protected String toString(int dimension, byte[] value) {
return ApproximatePointRangeQuery.this.toString();
}
};
}

public int getSize() {
Expand All @@ -104,13 +112,12 @@ public void setSortOrder(SortOrder sortOrder) {

@Override
public void visit(QueryVisitor visitor) {
if (visitor.acceptField(field)) {
visitor.visitLeaf(this);
}
pointRangeQuery.visit(visitor);
}

@Override
public final Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
Weight pointRangeQueryWeight = pointRangeQuery.createWeight(searcher, scoreMode, boost);

// We don't use RandomAccessWeight here: it's no good to approximate with "match all docs".
// This is an inverted structure and should be used in the first pass:
Expand Down Expand Up @@ -454,35 +461,12 @@ public long cost() {

@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
ScorerSupplier scorerSupplier = scorerSupplier(context);
if (scorerSupplier == null) {
return null;
}
return scorerSupplier.get(Long.MAX_VALUE);
return pointRangeQueryWeight.scorer(context);
}

@Override
public int count(LeafReaderContext context) throws IOException {
LeafReader reader = context.reader();

PointValues values = (PointValues) reader.getPointValues(field);
if (checkValidPointValues(values) == false) {
return 0;
}

if (reader.hasDeletions() == false) {
if (relate(values.getMinPackedValue(), values.getMaxPackedValue()) == PointValues.Relation.CELL_INSIDE_QUERY) {
return values.getDocCount();
}
// only 1D: we have the guarantee that it will actually run fast since there are at most 2
// crossing leaves.
// docCount == size : counting according number of points in leaf node, so must be
// single-valued.
if (numDims == 1 && values.getDocCount() == values.size()) {
return (int) pointCount((PointValues.PointTree) values.getPointTree(), this::relate, this::matches);
}
}
return super.count(context);
return pointRangeQueryWeight.count(context);
}

/**
Expand Down Expand Up @@ -565,7 +549,7 @@ private void pointCount(PointValues.IntersectVisitor visitor, PointValues.PointT

@Override
public boolean isCacheable(LeafReaderContext ctx) {
return true;
return pointRangeQueryWeight.isCacheable(ctx);
}
};
}
Expand All @@ -592,18 +576,12 @@ public byte[] getUpperPoint() {

@Override
public final int hashCode() {
int hash = classHash();
hash = 31 * hash + field.hashCode();
hash = 31 * hash + Arrays.hashCode(lowerPoint);
hash = 31 * hash + Arrays.hashCode(upperPoint);
hash = 31 * hash + numDims;
hash = 31 * hash + Objects.hashCode(bytesPerDim);
return hash;
return pointRangeQuery.hashCode();
}

@Override
public final boolean equals(Object o) {
return sameClassAs(o) && equalsTo(getClass().cast(o));
return pointRangeQuery.equals(o);
}

private boolean equalsTo(ApproximatePointRangeQuery other) {
Expand All @@ -616,29 +594,6 @@ private boolean equalsTo(ApproximatePointRangeQuery other) {

@Override
public final String toString(String field) {
final StringBuilder sb = new StringBuilder();
if (this.field.equals(field) == false) {
sb.append(this.field);
sb.append(':');
}

// print ourselves as "range per dimension"
for (int i = 0; i < numDims; i++) {
if (i > 0) {
sb.append(',');
}

int startOffset = bytesPerDim * i;

sb.append('[');
sb.append(toString(i, ArrayUtil.copyOfSubArray(lowerPoint, startOffset, startOffset + bytesPerDim)));
sb.append(" TO ");
sb.append(toString(i, ArrayUtil.copyOfSubArray(upperPoint, startOffset, startOffset + bytesPerDim)));
sb.append(']');
}

return sb.toString();
return pointRangeQuery.toString(field);
}

protected abstract String toString(int dimension, byte[] value);
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,6 @@ public ApproximateableQuery(
this.approximationQueryWeight = approximationQueryWeight;
}

public void setOriginalQueryWeight(Weight originalQueryWeight) {
this.originalQueryWeight = originalQueryWeight;
}

public void setApproximationQueryWeight(Weight approximationQueryWeight) {
this.approximationQueryWeight = approximationQueryWeight;
}

public Weight getOriginalQueryWeight() {
return originalQueryWeight;
}

public Weight getApproximationQueryWeight() {
return approximationQueryWeight;
Expand All @@ -73,10 +62,8 @@ public Query getApproximationQuery() {

@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
final Weight originalQueryWeight = originalQuery.createWeight(searcher, scoreMode, boost);
setOriginalQueryWeight(originalQueryWeight);
final Weight approximationQueryWeight = approximationQuery.createWeight(searcher, scoreMode, boost);
setApproximationQueryWeight(approximationQueryWeight);
this.originalQueryWeight = originalQuery.createWeight(searcher, scoreMode, boost);
this.approximationQueryWeight = approximationQuery.createWeight(searcher, scoreMode, boost);

return new Weight(this) {
@Override
Expand Down

0 comments on commit 2a0f9dc

Please sign in to comment.