Skip to content

Commit

Permalink
Hack proto serde into base classes
Browse files Browse the repository at this point in the history
Signed-off-by: Finn Carroll <[email protected]>
  • Loading branch information
finnegancarroll committed Aug 29, 2024
1 parent be1ad00 commit db6c296
Show file tree
Hide file tree
Showing 5 changed files with 312 additions and 175 deletions.
291 changes: 169 additions & 122 deletions server/src/main/java/org/opensearch/search/SearchHit.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

package org.opensearch.search;

import com.google.protobuf.ByteString;
import org.apache.lucene.search.Explanation;
import org.opensearch.OpenSearchParseException;
import org.opensearch.Version;
Expand Down Expand Up @@ -67,7 +68,9 @@
import org.opensearch.rest.action.search.RestSearchAction;
import org.opensearch.search.fetch.subphase.highlight.HighlightField;
import org.opensearch.search.lookup.SourceLookup;
import org.opensearch.serde.proto.SearchHitsTransportProto;
import org.opensearch.transport.RemoteClusterAware;
import org.opensearch.transport.protobuf.SearchHitsProtobuf;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -89,6 +92,16 @@
import static org.opensearch.core.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureFieldName;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.documentFieldFromProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.documentFieldToProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.explanationFromProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.explanationToProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.highlightFieldFromProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.highlightFieldToProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchShardTargetFromProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchShardTargetToProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchSortValuesFromProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchSortValuesToProto;

/**
* A single search hit.
Expand Down Expand Up @@ -140,6 +153,162 @@ public class SearchHit implements Writeable, ToXContentObject, Iterable<Document

protected Map<String, SearchHits> innerHits;

//////////////////////////////////////////////////////////////////////////////////////////////////////////////
/////////////// DIRTY HACK ///////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////////////////////

public SearchHit(StreamInput in) throws IOException {
fromProtobufStream(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
toProtobufStream(out);
}

public void toProtobufStream(StreamOutput out) throws IOException {
toProto().writeTo(out);
}

public void fromProtobufStream(StreamInput in) throws IOException {
SearchHitsTransportProto.SearchHitProto proto = SearchHitsTransportProto.SearchHitProto.parseFrom(in);
fromProto(proto);
}

private SearchHitsTransportProto.SearchHitProto toProto() {
SearchHitsTransportProto.SearchHitProto.Builder builder = SearchHitsTransportProto.SearchHitProto.newBuilder()
.setScore(score)
.setId(id.string())
.setVersion(version)
.setSeqNo(seqNo)
.setPrimaryTerm(primaryTerm);

if (nestedIdentity != null) {
builder.setNestedIdentity(nestedIdentityToProto(nestedIdentity));
}

if (source != null) {
builder.setSource(ByteString.copyFrom(source.toBytesRef().bytes));
}

if (explanation != null) {
builder.setExplanation(explanationToProto(explanation));
}

builder.setSortValues(searchSortValuesToProto(sortValues));

documentFields.forEach((key, value) -> builder.putDocumentFields(key, documentFieldToProto(value)));

metaFields.forEach((key, value) -> builder.putMetaFields(key, documentFieldToProto(value)));

if (highlightFields != null) {
highlightFields.forEach((key, value) -> builder.putHighlightFields(key, highlightFieldToProto(value)));
}

matchedQueries.forEach(builder::putMatchedQueries);

if (shard != null) {
builder.setShard(searchShardTargetToProto(shard));
}

if (innerHits != null) {
innerHits.forEach((key, value) -> builder.putInnerHits(key, new SearchHitsProtobuf(value).toProto()));
}

return builder.build();
}

private void fromProto(SearchHitsTransportProto.SearchHitProto proto) {
docId = -1;
score = proto.getScore();
seqNo = proto.getSeqNo();
version = proto.getVersion();
primaryTerm = proto.getPrimaryTerm();
id = new Text(proto.getId());
sortValues = searchSortValuesFromProto(proto.getSortValues());
matchedQueries = proto.getMatchedQueriesMap();

if (proto.hasNestedIdentity()) {
nestedIdentity = nestedIdentityFromProto(proto.getNestedIdentity());
} else {
nestedIdentity = null;
}

if (proto.hasSource()) {
source = BytesReference.fromByteBuffer(proto.getSource().asReadOnlyByteBuffer());
} else {
source = null;
}

if (proto.hasExplanation()) {
explanation = explanationFromProto(proto.getExplanation());
} else {
explanation = null;
}

if (proto.hasShard()) {
shard = searchShardTargetFromProto(proto.getShard());
index = shard.getIndex();
clusterAlias = shard.getClusterAlias();
} else {
shard = null;
index = null;
clusterAlias = null;
}

Map<String, SearchHitsTransportProto.SearchHitsProto> innerHitsProto = proto.getInnerHitsMap();
if (!innerHitsProto.isEmpty()) {
innerHits = new HashMap<>();
innerHitsProto.forEach((key, value) -> innerHits.put(key, new SearchHitsProtobuf(value)));
}

documentFields = new HashMap<>();
Map<String, SearchHitsTransportProto.DocumentFieldProto> documentFieldProtoMap = proto.getDocumentFieldsMap();
if (!documentFieldProtoMap.isEmpty()) {
documentFieldProtoMap.forEach((key, value) -> documentFields.put(key, documentFieldFromProto(value)));
}

metaFields = new HashMap<>();
Map<String, SearchHitsTransportProto.DocumentFieldProto> metaFieldProtoMap = proto.getMetaFieldsMap();
if (!metaFieldProtoMap.isEmpty()) {
metaFieldProtoMap.forEach((key, value) -> metaFields.put(key, documentFieldFromProto(value)));
}

highlightFields = new HashMap<>();
Map<String, SearchHitsTransportProto.HighlightFieldProto> highlightFieldProtoMap = proto.getHighlightFieldsMap();
if (!highlightFieldProtoMap.isEmpty()) {
highlightFieldProtoMap.forEach((key, value) -> highlightFields.put(key, highlightFieldFromProto(value)));
}
}

static SearchHitsTransportProto.NestedIdentityProto nestedIdentityToProto(SearchHit.NestedIdentity nestedIdentity) {
SearchHitsTransportProto.NestedIdentityProto.Builder builder = SearchHitsTransportProto.NestedIdentityProto.newBuilder()
.setField(nestedIdentity.getField().string())
.setOffset(nestedIdentity.getOffset());

if (nestedIdentity.getChild() != null) {
builder.setChild(nestedIdentityToProto(nestedIdentity.getChild()));
}

return builder.build();
}

static SearchHit.NestedIdentity nestedIdentityFromProto(SearchHitsTransportProto.NestedIdentityProto proto) {
String field = proto.getField();
int offset = proto.getOffset();

SearchHit.NestedIdentity child = null;
if (proto.hasChild()) {
child = nestedIdentityFromProto(proto.getChild());
}

return new SearchHit.NestedIdentity(field, offset, child);
}

//////////////////////////////////////////////////////////////////////////////////////////////////////////////
/////////////// DIRTY HACK ///////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////////////////////

public SearchHit(SearchHit hit) {
this.docId = hit.docId;
this.id = hit.id;
Expand Down Expand Up @@ -189,132 +358,10 @@ public SearchHit(
this.metaFields = metaFields == null ? emptyMap() : metaFields;
}

public SearchHit(StreamInput in) throws IOException {
docId = -1;
score = in.readFloat();
id = in.readOptionalText();
if (in.getVersion().before(Version.V_2_0_0)) {
in.readOptionalText();
}
nestedIdentity = in.readOptionalWriteable(NestedIdentity::new);
version = in.readLong();
seqNo = in.readZLong();
primaryTerm = in.readVLong();
source = in.readBytesReference();
if (source.length() == 0) {
source = null;
}
if (in.readBoolean()) {
explanation = readExplanation(in);
}
documentFields = in.readMap(StreamInput::readString, DocumentField::new);
metaFields = in.readMap(StreamInput::readString, DocumentField::new);

int size = in.readVInt();
if (size == 0) {
highlightFields = emptyMap();
} else if (size == 1) {
HighlightField field = new HighlightField(in);
highlightFields = singletonMap(field.name(), field);
} else {
Map<String, HighlightField> highlightFields = new HashMap<>();
for (int i = 0; i < size; i++) {
HighlightField field = new HighlightField(in);
highlightFields.put(field.name(), field);
}
this.highlightFields = unmodifiableMap(highlightFields);
}

sortValues = new SearchSortValues(in);

size = in.readVInt();
if (in.getVersion().onOrAfter(Version.V_2_13_0)) {
if (size > 0) {
Map<String, Float> tempMap = in.readMap(StreamInput::readString, StreamInput::readFloat);
matchedQueries = tempMap.entrySet()
.stream()
.sorted(Map.Entry.comparingByKey())
.collect(
Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (oldValue, newValue) -> oldValue, LinkedHashMap::new)
);
}
} else {
matchedQueries = new LinkedHashMap<>(size);
for (int i = 0; i < size; i++) {
matchedQueries.put(in.readString(), Float.NaN);
}
}
// we call the setter here because that also sets the local index parameter
shard(in.readOptionalWriteable(SearchShardTarget::new));
size = in.readVInt();
if (size > 0) {
innerHits = new HashMap<>(size);
for (int i = 0; i < size; i++) {
String key = in.readString();
SearchHits value = new SearchHits(in);
innerHits.put(key, value);
}
} else {
innerHits = null;
}
}

protected SearchHit() {}

protected static final Text SINGLE_MAPPING_TYPE = new Text(MapperService.SINGLE_MAPPING_NAME);

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeFloat(score);
out.writeOptionalText(id);
if (out.getVersion().before(Version.V_2_0_0)) {
out.writeOptionalText(SINGLE_MAPPING_TYPE);
}
out.writeOptionalWriteable(nestedIdentity);
out.writeLong(version);
out.writeZLong(seqNo);
out.writeVLong(primaryTerm);
out.writeBytesReference(source);
if (explanation == null) {
out.writeBoolean(false);
} else {
out.writeBoolean(true);
writeExplanation(out, explanation);
}
out.writeMap(documentFields, StreamOutput::writeString, (stream, documentField) -> documentField.writeTo(stream));
out.writeMap(metaFields, StreamOutput::writeString, (stream, documentField) -> documentField.writeTo(stream));
if (highlightFields == null) {
out.writeVInt(0);
} else {
out.writeVInt(highlightFields.size());
for (HighlightField highlightField : highlightFields.values()) {
highlightField.writeTo(out);
}
}
sortValues.writeTo(out);

out.writeVInt(matchedQueries.size());
if (out.getVersion().onOrAfter(Version.V_2_13_0)) {
if (!matchedQueries.isEmpty()) {
out.writeMap(matchedQueries, StreamOutput::writeString, StreamOutput::writeFloat);
}
} else {
for (String matchedFilter : matchedQueries.keySet()) {
out.writeString(matchedFilter);
}
}
out.writeOptionalWriteable(shard);
if (innerHits == null) {
out.writeVInt(0);
} else {
out.writeVInt(innerHits.size());
for (Map.Entry<String, SearchHits> entry : innerHits.entrySet()) {
out.writeString(entry.getKey());
entry.getValue().writeTo(out);
}
}
}

public int docId() {
return this.docId;
}
Expand Down
Loading

0 comments on commit db6c296

Please sign in to comment.