Skip to content

Commit

Permalink
add query support
Browse files Browse the repository at this point in the history
  • Loading branch information
marevol committed Jun 22, 2024
1 parent febbcfe commit 849027b
Show file tree
Hide file tree
Showing 14 changed files with 688 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,19 @@
package org.codelibs.fess.multimodal;

public class MultiModalConstants {

/** Common prefix shared by the keys built in this class. */
private static final String PREFIX = "fess.multimodal.";

/**
 * Key {@code fess.multimodal.min_score}.
 * NOTE(review): presumably names the minimum-score setting for vector search; the consumer is not visible here - confirm.
 */
public static final String MIN_SCORE = PREFIX + "min_score";

/**
 * Name of the content vector field. Resolved once, when this class is initialized, from the
 * {@code fess.multimodal.content.field} system property; falls back to {@code content_vector}.
 */
public static final String CONTENT_VECTOR_FIELD = System.getProperty(PREFIX + "content.field", "content_vector");

/** Metadata name used for the embedding value ({@code X-FESS-Embedding}). */
public static final String X_FESS_EMBEDDING = "X-FESS-Embedding";

/** Component name {@code multiModalSearcher}. */
public static final String SEARCHER = "multiModalSearcher";

/** Component name {@code casClient}. */
public static final String CAS_CLIENT = "casClient";

/** Private constructor: the visible members are all static constants, so no instance is needed. */
private MultiModalConstants() {
// intentionally empty
}
Expand Down
21 changes: 21 additions & 0 deletions src/main/java/org/codelibs/fess/multimodal/client/CasClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,25 @@ protected String encodeImage(final InputStream in) {
throw new CasAccessException("Failed to read an image.", e);
}
}

/**
 * Asks the CLIP server for the embedding of a text query.
 *
 * <p>The query is JSON-escaped and posted to {@code clipEndpoint + "/post"}. The embedding is read from
 * {@code data[0].embedding} of the parsed response.</p>
 *
 * @param query the text to embed
 * @return the embedding values converted to {@code float}
 * @throws CasAccessException if the request fails, the response does not contain an embedding list,
 *         or the embedding contains a value that is not a number
 */
public float[] getTextEmbedding(final String query) {
    final String body = "{\"data\":[{\"text\":\"" + StringEscapeUtils.escapeJson(query) + "\"}],\"execEndpoint\":\"/\"}";
    logger.debug("request body: {}", body);
    try (CurlResponse response = Curl.post(clipEndpoint + "/post").header("Content-Type", "application/json").body(body).execute()) {
        final Map<String, Object> contentMap = response.getContent(PARSER);
        // Walk data[0].embedding defensively; any shape mismatch falls through to the generic failure below.
        if (contentMap != null && contentMap.get("data") instanceof final List<?> dataList && !dataList.isEmpty()
                && dataList.get(0) instanceof final Map<?, ?> data && data.get("embedding") instanceof final List<?> embeddingList) {
            logger.debug("embedding: {}", embeddingList);
            final float[] embedding = new float[embeddingList.size()];
            for (int i = 0; i < embedding.length; i++) {
                // A null or non-numeric element must surface as CasAccessException,
                // not as a NullPointerException/ClassCastException from a blind cast.
                if (!(embeddingList.get(i) instanceof final Number value)) {
                    throw new CasAccessException("Clip server returned a non-numeric embedding value at index " + i);
                }
                embedding[i] = value.floatValue();
            }
            return embedding;
        }
    } catch (final IOException e) {
        throw new CasAccessException("Clip server failed to generate an embedding.", e);
    }
    throw new CasAccessException("Clip server cannot generate an embedding");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
*/
package org.codelibs.fess.multimodal.crawler.extractor;

import static org.codelibs.fess.multimodal.MultiModalConstants.CAS_CLIENT;
import static org.codelibs.fess.multimodal.MultiModalConstants.X_FESS_EMBEDDING;

import java.io.InputStream;
import java.util.Map;

Expand All @@ -24,7 +27,6 @@
import org.apache.logging.log4j.Logger;
import org.codelibs.fess.crawler.entity.ExtractData;
import org.codelibs.fess.crawler.extractor.impl.TikaExtractor;
import org.codelibs.fess.multimodal.MultiModalConstants;
import org.codelibs.fess.multimodal.client.CasClient;
import org.codelibs.fess.multimodal.ingest.EmbeddingIngester;
import org.codelibs.fess.multimodal.util.EmbeddingUtil;
Expand All @@ -45,14 +47,14 @@ public int getWeight() {
public void init() {
    super.init();

    // Resolve the client once, through the shared component-name constant.
    // The earlier lookup with the hard-coded "casClient" literal was a dead store:
    // its result was overwritten immediately by this assignment.
    client = crawlerContainer.getComponent(CAS_CLIENT);
}

@Override
public ExtractData getText(final InputStream inputStream, final Map<String, String> params) {
return getText(inputStream, params, (data, in) -> {
try {
data.putValue(MultiModalConstants.X_FESS_EMBEDDING, EmbeddingUtil.encodeFloatArray(client.getImageEmbedding(in)));
data.putValue(X_FESS_EMBEDDING, EmbeddingUtil.encodeFloatArray(client.getImageEmbedding(in)));
} catch (final Exception e) {
logger.warn("Failed to convert an image to a vector.", e);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
/*
* Copyright 2012-2024 CodeLibs Project and the Others.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the specific language
* governing permissions and limitations under the License.
*/
package org.codelibs.fess.multimodal.index.query;

import java.io.IOException;

import org.apache.commons.lang3.builder.EqualsBuilder;
import org.apache.commons.lang3.builder.HashCodeBuilder;
import org.apache.lucene.search.Query;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;

/**
 * Query builder that renders an OpenSearch {@code knn} query clause.
 *
 * <p>This class only serializes itself, either to a transport stream or to XContent. It cannot be
 * converted into a Lucene query locally, so {@link #doToQuery(QueryShardContext)} always throws.</p>
 *
 * <p>Instances are created through {@link Builder} or read back from a {@link StreamInput}.
 * The XContent form is:</p>
 *
 * <pre>
 * "knn": { "&lt;fieldName&gt;": { "vector": [...], "k": n, "filter": {...}, "ignore_unmapped": b,
 *                           "max_distance": f, "min_score": f, &lt;boost and query name&gt; } }
 * </pre>
 *
 * <p>{@code filter}, {@code max_distance} and {@code min_score} are written only when they are set.</p>
 */
public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {

    /** Name of the query clause; also returned as the writeable name. */
    private static final String NAME = "knn";

    private static final ParseField VECTOR_FIELD = new ParseField("vector");
    private static final ParseField K_FIELD = new ParseField("k");
    private static final ParseField FILTER_FIELD = new ParseField("filter");
    private static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped");
    private static final ParseField MAX_DISTANCE_FIELD = new ParseField("max_distance");
    private static final ParseField MIN_SCORE_FIELD = new ParseField("min_score");

    /** Value of {@code k} used by {@link Builder} when none is given. */
    private static final int DEFAULT_K = 10;

    /** Name of the field the query targets; becomes the object key under {@code knn}. */
    protected String fieldName;

    /** Query vector. */
    protected float[] vector;

    /** Value written as {@code k}. */
    protected int k;

    /** Optional filter query; omitted from the output when {@code null}. */
    protected QueryBuilder filter;

    /** Value written as {@code ignore_unmapped}. */
    protected boolean ignoreUnmapped;

    /** Optional {@code max_distance}; omitted from the output when {@code null}. */
    protected Float maxDistance;

    /** Optional {@code min_score}; omitted from the output when {@code null}. */
    protected Float minScore;

    /**
     * Reads a query from a stream, in the same order {@link #doWriteTo(StreamOutput)} writes it.
     *
     * @param in the stream to read from
     * @throws IOException if reading fails
     */
    public KNNQueryBuilder(final StreamInput in) throws IOException {
        super(in);
        this.fieldName = in.readString();
        this.vector = in.readFloatArray();
        this.k = in.readInt();
        this.filter = in.readOptionalNamedWriteable(QueryBuilder.class);
        this.ignoreUnmapped = in.readBoolean();
        this.maxDistance = in.readOptionalFloat();
        this.minScore = in.readOptionalFloat();
    }

    /** Used by {@link Builder#build()} only. */
    private KNNQueryBuilder() {
    }

    /**
     * Fluent builder for {@link KNNQueryBuilder}. A field name and a query vector are mandatory;
     * everything else has a default ({@code k} = {@value KNNQueryBuilder#DEFAULT_K},
     * {@code ignoreUnmapped} = {@code false}, no filter, no distance/score threshold).
     */
    public static class Builder {
        private String fieldName;
        private float[] vector;
        private int k = DEFAULT_K;
        private QueryBuilder filter;
        private boolean ignoreUnmapped = false;
        private Float maxDistance = null;
        private Float minScore = null;

        /**
         * @param fieldName name of the vector field to search
         * @return this builder
         */
        public Builder field(final String fieldName) {
            this.fieldName = fieldName;
            return this;
        }

        /**
         * @param vector the query vector
         * @return this builder
         */
        public Builder vector(final float[] vector) {
            this.vector = vector;
            return this;
        }

        /**
         * @param k value written as {@code k}
         * @return this builder
         */
        public Builder k(final int k) {
            this.k = k;
            return this;
        }

        /**
         * @param filter optional filter query, may be {@code null}
         * @return this builder
         */
        public Builder filter(final QueryBuilder filter) {
            this.filter = filter;
            return this;
        }

        /**
         * @param ignoreUnmapped value written as {@code ignore_unmapped}
         * @return this builder
         */
        public Builder ignoreUnmapped(final boolean ignoreUnmapped) {
            this.ignoreUnmapped = ignoreUnmapped;
            return this;
        }

        /**
         * @param maxDistance optional {@code max_distance}, may be {@code null}
         * @return this builder
         */
        public Builder maxDistance(final Float maxDistance) {
            this.maxDistance = maxDistance;
            return this;
        }

        /**
         * @param minScore optional {@code min_score}, may be {@code null}
         * @return this builder
         */
        public Builder minScore(final Float minScore) {
            this.minScore = minScore;
            return this;
        }

        /**
         * Creates the query from the values collected so far.
         *
         * @return a new query builder
         * @throws IllegalArgumentException if the field name or the query vector is missing or empty
         */
        public KNNQueryBuilder build() {
            // Fail fast: without these values the query cannot be written to a stream
            // (writeString/writeFloatArray reject null) and the rendered clause would be invalid.
            if (fieldName == null || fieldName.isEmpty()) {
                throw new IllegalArgumentException("[" + NAME + "] requires fieldName");
            }
            if (vector == null) {
                throw new IllegalArgumentException("[" + NAME + "] requires query vector");
            }
            if (vector.length == 0) {
                throw new IllegalArgumentException("[" + NAME + "] query vector is empty");
            }

            final KNNQueryBuilder query = new KNNQueryBuilder();
            query.fieldName = fieldName;
            query.vector = vector;
            query.k = k;
            query.filter = filter;
            query.ignoreUnmapped = ignoreUnmapped;
            query.maxDistance = maxDistance;
            query.minScore = minScore;
            return query;
        }
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /**
     * Writes the fields in the order the stream constructor reads them.
     */
    @Override
    protected void doWriteTo(final StreamOutput out) throws IOException {
        out.writeString(this.fieldName);
        out.writeFloatArray(this.vector);
        out.writeInt(this.k);
        out.writeOptionalNamedWriteable(this.filter);
        out.writeBoolean(this.ignoreUnmapped);
        out.writeOptionalFloat(this.maxDistance);
        out.writeOptionalFloat(this.minScore);
    }

    /**
     * Renders the clause as {@code "knn": {"<fieldName>": {...}}}; optional values are skipped when unset.
     */
    @Override
    protected void doXContent(final XContentBuilder xContentBuilder, final Params params) throws IOException {
        xContentBuilder.startObject(NAME);
        xContentBuilder.startObject(fieldName);
        xContentBuilder.field(VECTOR_FIELD.getPreferredName(), vector);
        xContentBuilder.field(K_FIELD.getPreferredName(), k);
        if (filter != null) {
            xContentBuilder.field(FILTER_FIELD.getPreferredName(), filter);
        }
        xContentBuilder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped);
        if (maxDistance != null) {
            xContentBuilder.field(MAX_DISTANCE_FIELD.getPreferredName(), maxDistance);
        }
        if (minScore != null) {
            xContentBuilder.field(MIN_SCORE_FIELD.getPreferredName(), minScore);
        }
        // Boost and query name are written inside the per-field object.
        printBoostAndQueryName(xContentBuilder);
        xContentBuilder.endObject();
        xContentBuilder.endObject();
    }

    /**
     * Not supported: this builder is only serialized, never executed locally.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    protected Query doToQuery(final QueryShardContext context) throws IOException {
        throw new UnsupportedOperationException("doToQuery is not supported.");
    }

    @Override
    protected boolean doEquals(final KNNQueryBuilder obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        final EqualsBuilder equalsBuilder = new EqualsBuilder();
        equalsBuilder.append(fieldName, obj.fieldName);
        equalsBuilder.append(vector, obj.vector);
        equalsBuilder.append(k, obj.k);
        equalsBuilder.append(filter, obj.filter);
        equalsBuilder.append(ignoreUnmapped, obj.ignoreUnmapped);
        equalsBuilder.append(maxDistance, obj.maxDistance);
        equalsBuilder.append(minScore, obj.minScore);
        return equalsBuilder.isEquals();
    }

    @Override
    protected int doHashCode() {
        return new HashCodeBuilder().append(fieldName).append(vector).append(k).append(filter).append(ignoreUnmapped).append(maxDistance)
                .append(minScore).toHashCode();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,42 +15,40 @@
*/
package org.codelibs.fess.multimodal.ingest;

import static org.codelibs.core.lang.StringUtil.EMPTY;
import static org.codelibs.fess.Constants.MAPPING_TYPE_ARRAY;
import static org.codelibs.fess.multimodal.MultiModalConstants.CONTENT_VECTOR_FIELD;
import static org.codelibs.fess.multimodal.MultiModalConstants.X_FESS_EMBEDDING;

import java.util.Map;

import javax.annotation.PostConstruct;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.codelibs.core.lang.StringUtil;
import org.codelibs.fess.Constants;
import org.codelibs.fess.ingest.Ingester;
import org.codelibs.fess.multimodal.MultiModalConstants;
import org.codelibs.fess.multimodal.util.EmbeddingUtil;
import org.codelibs.fess.util.ComponentUtil;

public class EmbeddingIngester extends Ingester {
/** Logger for this ingester. */
private static final Logger logger = LogManager.getLogger(EmbeddingIngester.class);

// Name of the document field holding the embedding; assigned in init().
// NOTE(review): the surrounding lines look like a rendered diff in which this field is being
// replaced by the CONTENT_VECTOR_FIELD constant - confirm against the committed file.
protected String embeddingField;

/**
 * Registers the crawler metadata name mapping for the embedding metadata.
 *
 * NOTE(review): this span appears to be a rendered diff with the pre-change and post-change lines
 * interleaved. As written it registers the X_FESS_EMBEDDING mapping twice, once per target field
 * name - confirm which lines belong to the committed file.
 */
@PostConstruct
public void init() {
// Field name taken from the "clip.index.embedding_field" system property, defaulting to "content_vector".
embeddingField = System.getProperty("clip.index.embedding_field", "content_vector");

// Maps X_FESS_EMBEDDING to embeddingField, passing MAPPING_TYPE_ARRAY and an empty string.
ComponentUtil.getFessConfig().addCrawlerMetadataNameMapping(MultiModalConstants.X_FESS_EMBEDDING, embeddingField,
Constants.MAPPING_TYPE_ARRAY, StringUtil.EMPTY);
// Same registration using the statically imported constants, with CONTENT_VECTOR_FIELD as the target.
ComponentUtil.getFessConfig().addCrawlerMetadataNameMapping(X_FESS_EMBEDDING, CONTENT_VECTOR_FIELD, MAPPING_TYPE_ARRAY, EMPTY);
}

@Override
protected Map<String, Object> process(final Map<String, Object> target) {
if (target.containsKey(embeddingField)) {
logger.debug("[{}] : {}", embeddingField, target);
if (target.get(embeddingField) instanceof final String[] encodedEmbeddings) {
if (target.containsKey(CONTENT_VECTOR_FIELD)) {
logger.debug("[{}] : {}", CONTENT_VECTOR_FIELD, target);
if (target.get(CONTENT_VECTOR_FIELD) instanceof final String[] encodedEmbeddings) {
final float[] embedding = EmbeddingUtil.decodeFloatArray(encodedEmbeddings[0]);
logger.debug("embedding:{}", embedding);
target.put(embeddingField, embedding);
target.put(CONTENT_VECTOR_FIELD, embedding);
} else {
logger.warn("{} is not an array.", embeddingField);
logger.warn("{} is not an array.", CONTENT_VECTOR_FIELD);
}
}
return target;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright 2012-2024 CodeLibs Project and the Others.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the specific language
* governing permissions and limitations under the License.
*/
package org.codelibs.fess.multimodal.query;

import static org.codelibs.fess.Constants.DEFAULT_FIELD;
import static org.codelibs.fess.multimodal.MultiModalConstants.SEARCHER;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.PhraseQuery;
import org.codelibs.fess.entity.QueryContext;
import org.codelibs.fess.multimodal.rank.fusion.MultiModalSearcher;
import org.codelibs.fess.multimodal.rank.fusion.MultiModalSearcher.SearchContext;
import org.codelibs.fess.mylasta.direction.FessConfig;
import org.codelibs.fess.query.PhraseQueryCommand;
import org.codelibs.fess.util.ComponentUtil;
import org.opensearch.index.query.QueryBuilder;

/**
 * Phrase query command that replaces phrase queries on the default field with a multimodal query
 * whenever the multimodal searcher provides a search context. All other phrase queries are handled
 * by {@link PhraseQueryCommand}.
 */
public class MultiModalPhraseQueryCommand extends PhraseQueryCommand {

    private static final Logger logger = LogManager.getLogger(MultiModalPhraseQueryCommand.class);

    /**
     * Converts a phrase query.
     *
     * @return a multimodal query for the default field when a search context exists; otherwise the
     *         result of the parent implementation
     */
    @Override
    protected QueryBuilder convertPhraseQuery(final FessConfig fessConfig, final QueryContext context, final PhraseQuery phraseQuery,
            final float boost, final String field, final String[] texts) {
        // The context is always looked up first, exactly once per conversion.
        final SearchContext multiModalContext = getSearchContext();

        final boolean useMultiModal = DEFAULT_FIELD.equals(field) && multiModalContext != null;
        if (useMultiModal) {
            return convertToMultiModalQuery(context, field, texts, multiModalContext);
        }
        return super.convertPhraseQuery(fessConfig, context, phraseQuery, boost, field, texts);
    }

    /**
     * Builds the multimodal query from the space-joined phrase terms and records the phrase on the
     * query context as field log and highlighted query.
     */
    private QueryBuilder convertToMultiModalQuery(final QueryContext context, final String field, final String[] texts,
            final SearchContext multiModalContext) {
        final String phrase = String.join(" ", texts);
        final QueryBuilder multiModalQuery = new MultiModalQueryBuilder.Builder() //
                .query(phrase) //
                .minScore(multiModalContext.getParams().getMinScore()) //
                .build() //
                .toQueryBuilder();
        context.addFieldLog(field, phrase);
        context.addHighlightedQuery(phrase);
        if (logger.isDebugEnabled()) {
            logger.debug("KNNQueryBuilder: {}", multiModalQuery);
        }
        return multiModalQuery;
    }

    /**
     * Returns the current context of the multimodal searcher component.
     *
     * @return the searcher's context; callers treat {@code null} as "no multimodal search active"
     */
    protected SearchContext getSearchContext() {
        final MultiModalSearcher multiModalSearcher = ComponentUtil.getComponent(SEARCHER);
        return multiModalSearcher.getContext();
    }
}
Loading

0 comments on commit 849027b

Please sign in to comment.