Skip to content

Commit

Permalink
Move SearchHitS serialization to SerDe class
Browse files Browse the repository at this point in the history
Signed-off-by: Finn Carroll <[email protected]>
  • Loading branch information
finnegancarroll committed Aug 22, 2024
1 parent ea630c8 commit a4dfcdf
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 38 deletions.
4 changes: 2 additions & 2 deletions server/src/main/java/org/opensearch/search/SearchHit.java
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ public SearchHit(SearchHit hit) throws IOException {

/**
* Preserving this constructor for compatibility.
* Going forward deserialize with dedicated SearchHitSerDe object.
* Going forward deserialize with dedicated SearchHitSerDe.
*/
public SearchHit(StreamInput in) throws IOException {
this(new SearchHitSerDe().deserialize(in));
Expand Down Expand Up @@ -335,7 +335,7 @@ public Map<String, SearchHits> getInnerHits() {

/**
* Preserving for compatibility.
* Going forward serialize with dedicated SearchHitSerDe object.
* Going forward serialize with dedicated SearchHitSerDe.
*/
@Override
public void writeTo(StreamOutput out) throws IOException {
Expand Down
74 changes: 40 additions & 34 deletions server/src/main/java/org/opensearch/search/SearchHits.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.rest.action.search.RestSearchAction;
import org.opensearch.search.fetch.serde.SearchHitSerDe;
import org.opensearch.search.fetch.serde.SearchHitsSerDe;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -103,6 +105,15 @@ public SearchHits(
this.collapseValues = collapseValues;
}

public SearchHits(SearchHits SrchHits) {
this(SrchHits.hits,
SrchHits.totalHits,
SrchHits.maxScore,
SrchHits.sortFields,
SrchHits.collapseField,
SrchHits.collapseValues);
}

/**
* Internal access for serialization interface.
* @opensearch.api
Expand All @@ -114,6 +125,12 @@ public interface SerializationAccess {
float getMaxScore();

SearchHit[] getHits();

SortField[] getSortFields();

String getCollapseField();

Object[] getCollapseValues();
}

public SearchHits.SerializationAccess getSerAccess() {
Expand All @@ -129,48 +146,37 @@ public float getMaxScore() {
public SearchHit[] getHits() {
return hits;
}

public SortField[] getSortFields() {
return sortFields;
}

public String getCollapseField() {
return collapseField;
}

public Object[] getCollapseValues() {
return collapseValues;
}
};
}

/**
* Preserving for compatibility.
* Going forward serialize with dedicated SearchHitsSerDe.
*/
public SearchHits(StreamInput in) throws IOException {
if (in.readBoolean()) {
totalHits = Lucene.readTotalHits(in);
} else {
// track_total_hits is false
totalHits = null;
}
maxScore = in.readFloat();
int size = in.readVInt();
if (size == 0) {
hits = EMPTY;
} else {
hits = new SearchHit[size];
for (int i = 0; i < hits.length; i++) {
hits[i] = new SearchHit(in);
}
}
sortFields = in.readOptionalArray(Lucene::readSortField, SortField[]::new);
collapseField = in.readOptionalString();
collapseValues = in.readOptionalArray(Lucene::readSortValue, Object[]::new);
this(new SearchHitsSerDe().deserialize(in));
}

/**
* Preserving for compatibility.
* Going forward deserialize with dedicated SearchHitSerDe.
*/
@Override
public void writeTo(StreamOutput out) throws IOException {
final boolean hasTotalHits = totalHits != null;
out.writeBoolean(hasTotalHits);
if (hasTotalHits) {
Lucene.writeTotalHits(out, totalHits);
}
out.writeFloat(maxScore);
out.writeVInt(hits.length);
if (hits.length > 0) {
for (SearchHit hit : hits) {
hit.writeTo(out);
}
}
out.writeOptionalArray(Lucene::writeSortField, sortFields);
out.writeOptionalString(collapseField);
out.writeOptionalArray(Lucene::writeSortValue, collapseValues);
SearchHitsSerDe serDe = new SearchHitsSerDe();
serDe.serialize(this, out);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,25 @@

package org.opensearch.search.fetch.serde;

import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TotalHits;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;

import java.io.IOException;

import static org.opensearch.search.SearchHits.EMPTY;

public class SearchHitsSerDe implements SerDe.StreamSerializer<SearchHits>, SerDe.StreamDeserializer<SearchHits> {
SearchHitSerDe searchHitSerDe;

@Override
public SearchHits deserialize(StreamInput in) {
try {
return new SearchHits(in);
return fromStream(in);
} catch (IOException e) {
throw new SerDe.SerializationException("Failed to deserialize FetchSearchResult", e);
}
Expand All @@ -29,9 +35,66 @@ public SearchHits deserialize(StreamInput in) {
@Override
public void serialize(SearchHits object, StreamOutput out) throws SerDe.SerializationException {
try {
object.writeTo(out);
toStream(object, out);
} catch (IOException e) {
throw new SerDe.SerializationException("Failed to serialize FetchSearchResult", e);
}
}

private SearchHits fromStream(StreamInput in) throws IOException {
SearchHit[] hits;
TotalHits totalHits;
float maxScore;
SortField[] sortFields;
String collapseField;
Object[] collapseValues;

if (in.readBoolean()) {
totalHits = Lucene.readTotalHits(in);
} else {
// track_total_hits is false
totalHits = null;
}
maxScore = in.readFloat();
int size = in.readVInt();
if (size == 0) {
hits = EMPTY;
} else {
hits = new SearchHit[size];
for (int i = 0; i < hits.length; i++) {
hits[i] = new SearchHit(in);
}
}
sortFields = in.readOptionalArray(Lucene::readSortField, SortField[]::new);
collapseField = in.readOptionalString();
collapseValues = in.readOptionalArray(Lucene::readSortValue, Object[]::new);

return new SearchHits(hits, totalHits, maxScore, sortFields, collapseField, collapseValues);
}

private void toStream(SearchHits object, StreamOutput out) throws IOException {
SearchHits.SerializationAccess serI = object.getSerAccess();
SearchHit[] hits = serI.getHits();
TotalHits totalHits = serI.getTotalHits();
float maxScore = serI.getMaxScore();
SortField[] sortFields = serI.getSortFields();
String collapseField = serI.getCollapseField();
Object[] collapseValues = serI.getCollapseValues();

final boolean hasTotalHits = totalHits != null;
out.writeBoolean(hasTotalHits);
if (hasTotalHits) {
Lucene.writeTotalHits(out, totalHits);
}
out.writeFloat(maxScore);
out.writeVInt(hits.length);
if (hits.length > 0) {
for (SearchHit hit : hits) {
hit.writeTo(out);
}
}
out.writeOptionalArray(Lucene::writeSortField, sortFields);
out.writeOptionalString(collapseField);
out.writeOptionalArray(Lucene::writeSortValue, collapseValues);
}
}

0 comments on commit a4dfcdf

Please sign in to comment.