ES|QL: Add initial support for semantic_text field type (elastic#113920)
* Add initial support for semantic_text field type

* Update docs/changelog/113920.yaml

* More tests and fixes

* Use mock inference service

* Fix tests

* Spotless

* Fix mixed-cluster and multi-clusters tests

* sort

* Attempt another fix for bwc tests

* Spotless

* Fix merge

* Attempt another fix

* Don't load the inference-service-test plugin for mixed versions/clusters

* Add more tests, address review comments

* trivial

* revert

* post-merge fix block loader

* post-merge fix compile

* add mixed version testing

* whitespace

* fix MultiClusterSpecIT

* add more fields to mapping

* Revert mixed version testing

* whitespace

---------

Co-authored-by: ChrisHegarty <[email protected]>
Co-authored-by: Elastic Machine <[email protected]>
3 people authored Oct 21, 2024
1 parent 183ad88 commit fd43adc
Showing 26 changed files with 490 additions and 35 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/113920.yaml
@@ -0,0 +1,5 @@
pr: 113920
summary: Add initial support for `semantic_text` field type
area: Search
type: enhancement
issues: []
@@ -14,4 +14,5 @@
public class EsqlCorePlugin extends Plugin implements ExtensiblePlugin {
public static final FeatureFlag DATE_NANOS_FEATURE_FLAG = new FeatureFlag("esql_date_nanos");

public static final FeatureFlag SEMANTIC_TEXT_FEATURE_FLAG = new FeatureFlag("esql_semantic_text");
}
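Assuming the standard Elasticsearch FeatureFlag convention (not shown in this diff), a flag like esql_semantic_text is enabled by default only in snapshot builds; release builds would need the corresponding es.esql_semantic_text_feature_flag_enabled system property to turn the type on.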
@@ -194,7 +194,14 @@ public enum DataType {
* inside alongside time-series aggregations. These fields are not parsable from the
* mapping and should be hidden from users.
*/
PARTIAL_AGG(builder().esType("partial_agg").unknownSize());
PARTIAL_AGG(builder().esType("partial_agg").unknownSize()),
/**
* String fields that are split into chunks, where each chunk has attached embeddings
* used for semantic search. Generally, ESQL only sees {@code semantic_text} fields when
* they are loaded from the index, and it loads them as strings without their attached
* chunks or embeddings.
*/
SEMANTIC_TEXT(builder().esType("semantic_text").unknownSize());

/**
* Types that are actively being built. These types are not returned
@@ -203,7 +210,8 @@
* check that sending them to a function produces a sane error message.
*/
public static final Map<DataType, FeatureFlag> UNDER_CONSTRUCTION = Map.ofEntries(
Map.entry(DATE_NANOS, EsqlCorePlugin.DATE_NANOS_FEATURE_FLAG)
Map.entry(DATE_NANOS, EsqlCorePlugin.DATE_NANOS_FEATURE_FLAG),
Map.entry(SEMANTIC_TEXT, EsqlCorePlugin.SEMANTIC_TEXT_FEATURE_FLAG)
);

private final String typeName;
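The UNDER_CONSTRUCTION map is how the rest of ESQL can hide a flagged type until its feature flag is switched on. A minimal sketch of a consumer-side gate, assuming FeatureFlag exposes isEnabled(); the checkUnderConstruction helper below is hypothetical and not part of this commit:

    // Hypothetical helper, for illustration only: reject types whose flag is still off.
    static void checkUnderConstruction(DataType type) {
        FeatureFlag flag = DataType.UNDER_CONSTRUCTION.get(type);
        if (flag != null && flag.isEnabled() == false) {
            throw new IllegalArgumentException("data type [" + type + "] is under construction and disabled");
        }
    }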
@@ -86,4 +86,9 @@ protected boolean supportsAsync() {
protected boolean enableRoundingDoubleValuesOnAsserting() {
return true;
}

@Override
protected boolean supportsInferenceTestService() {
return false;
}
}
@@ -261,4 +261,9 @@ static boolean hasIndexMetadata(String query) {
protected boolean enableRoundingDoubleValuesOnAsserting() {
return true;
}

@Override
protected boolean supportsInferenceTestService() {
return false;
}
}
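Both backwards-compatibility suites above return false here: mixed-version and multi-cluster runs can include older nodes that cannot load the inference-service-test plugin (per the commit bullet "Don't load the inference-service-test plugin for mixed versions/clusters"), so specs that need semantic_text are skipped on those clusters rather than failed.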
1 change: 1 addition & 0 deletions x-pack/plugin/esql/qa/server/multi-node/build.gradle
@@ -11,6 +11,7 @@ dependencies {

clusterPlugins project(':plugins:mapper-size')
clusterPlugins project(':plugins:mapper-murmur3')
clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
}

GradleUtils.extendSourceSet(project, "javaRestTest", "yamlRestTest")
@@ -14,7 +14,7 @@

public class EsqlSpecIT extends EsqlSpecTestCase {
@ClassRule
public static ElasticsearchCluster cluster = Clusters.testCluster(spec -> {});
public static ElasticsearchCluster cluster = Clusters.testCluster(spec -> spec.plugin("inference-service-test"));

@Override
protected String getTestRestCluster() {
1 change: 1 addition & 0 deletions x-pack/plugin/esql/qa/server/single-node/build.gradle
@@ -22,6 +22,7 @@ dependencies {

clusterPlugins project(':plugins:mapper-size')
clusterPlugins project(':plugins:mapper-murmur3')
clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
}

restResources {
@@ -18,7 +18,7 @@
@ThreadLeakFilters(filters = TestClustersThreadFilter.class)
public class EsqlSpecIT extends EsqlSpecTestCase {
@ClassRule
public static ElasticsearchCluster cluster = Clusters.testCluster();
public static ElasticsearchCluster cluster = Clusters.testCluster(spec -> spec.plugin("inference-service-test"));

@Override
protected String getTestRestCluster() {
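The single-node and multi-node EsqlSpecIT suites, by contrast, do install the test service: the clusterPlugins entries in the two build.gradle files above make the inference-service-test artifact available, and spec.plugin("inference-service-test") loads it into the test cluster.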
@@ -65,7 +65,10 @@
import static org.elasticsearch.xpack.esql.CsvTestUtils.ExpectedResults;
import static org.elasticsearch.xpack.esql.CsvTestUtils.isEnabled;
import static org.elasticsearch.xpack.esql.CsvTestUtils.loadCsvSpecValues;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.CSV_DATASET_MAP;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.availableDatasetsForEs;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.clusterHasInferenceEndpoint;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.createInferenceEndpoint;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.deleteInferenceEndpoint;
import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.loadDataSetIntoEs;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.classpathResources;

@@ -129,7 +132,11 @@ protected EsqlSpecTestCase(

@Before
public void setup() throws IOException {
if (indexExists(CSV_DATASET_MAP.keySet().iterator().next()) == false) {
if (supportsInferenceTestService() && clusterHasInferenceEndpoint(client()) == false) {
createInferenceEndpoint(client());
}

if (indexExists(availableDatasetsForEs(client()).iterator().next().indexName()) == false) {
loadDataSetIntoEs(client());
}
}
@@ -148,6 +155,8 @@ public static void wipeTestData() throws IOException {
throw e;
}
}

deleteInferenceEndpoint(client());
}

public boolean logResults() {
@@ -164,6 +173,9 @@ public final void test() throws Throwable {
}

protected void shouldSkipTest(String testName) throws IOException {
if (testCase.requiredCapabilities.contains("semantic_text_type")) {
assumeTrue("Inference test service needs to be supported for semantic_text", supportsInferenceTestService());
}
checkCapabilities(adminClient(), testFeatureService, testName, testCase);
assumeTrue("Test " + testName + " is not enabled", isEnabled(testName, instructions, Version.CURRENT));
}
@@ -207,6 +219,10 @@ protected static void checkCapabilities(RestClient client, TestFeatureService te
}
}

protected boolean supportsInferenceTestService() {
return true;
}

protected final void doTest() throws Throwable {
RequestObjectBuilder builder = new RequestObjectBuilder(randomFrom(XContentType.values()));

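Net effect on the spec-test lifecycle: the @Before hook creates the test endpoint (when the suite supports it) before any CSV dataset is loaded, shouldSkipTest downgrades semantic_text specs to assumption failures where the service is unavailable, and the class-level wipe deletes the endpoint after the test data is removed, so repeated runs against a long-lived cluster start clean.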
@@ -134,7 +134,11 @@ private static void assertMetadata(
|| expectedType == UNSIGNED_LONG)) {
continue;
}
if (blockType == Type.KEYWORD && (expectedType == Type.IP || expectedType == Type.VERSION || expectedType == Type.TEXT)) {
if (blockType == Type.KEYWORD
&& (expectedType == Type.IP
|| expectedType == Type.VERSION
|| expectedType == Type.TEXT
|| expectedType == Type.SEMANTIC_TEXT)) {
// Type.asType translates all bytes references into keywords
continue;
}
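The extra branch reflects how values arrive on the wire: semantic_text columns come back as keyword blocks (both are BytesRef-backed, as the Type enum below shows), so a KEYWORD block is accepted wherever a SEMANTIC_TEXT column is expected, mirroring the existing handling of TEXT, IP, and VERSION.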
@@ -447,6 +447,7 @@ public enum Type {
SCALED_FLOAT(s -> s == null ? null : scaledFloat(s, "100"), Double.class),
KEYWORD(Object::toString, BytesRef.class),
TEXT(Object::toString, BytesRef.class),
SEMANTIC_TEXT(Object::toString, BytesRef.class),
IP(
StringUtils::parseIP,
(l, r) -> l instanceof String maybeIP
@@ -19,6 +19,7 @@
import org.apache.logging.log4j.core.config.plugins.util.PluginManager;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestClientBuilder;
import org.elasticsearch.common.Strings;
@@ -36,9 +37,11 @@
import java.net.URI;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
import static org.elasticsearch.xpack.esql.CsvTestUtils.COMMA_ESCAPING_REGEX;
@@ -81,6 +84,7 @@ public class CsvTestsDataLoader {
private static final TestsDataset K8S = new TestsDataset("k8s", "k8s-mappings.json", "k8s.csv").withSetting("k8s-settings.json");
private static final TestsDataset ADDRESSES = new TestsDataset("addresses");
private static final TestsDataset BOOKS = new TestsDataset("books");
private static final TestsDataset SEMANTIC_TEXT = new TestsDataset("semantic_text").withInferenceEndpoint(true);

public static final Map<String, TestsDataset> CSV_DATASET_MAP = Map.ofEntries(
Map.entry(EMPLOYEES.indexName, EMPLOYEES),
@@ -112,7 +116,8 @@ public class CsvTestsDataLoader {
Map.entry(K8S.indexName, K8S),
Map.entry(DISTANCES.indexName, DISTANCES),
Map.entry(ADDRESSES.indexName, ADDRESSES),
Map.entry(BOOKS.indexName, BOOKS)
Map.entry(BOOKS.indexName, BOOKS),
Map.entry(SEMANTIC_TEXT.indexName, SEMANTIC_TEXT)
);

private static final EnrichConfig LANGUAGES_ENRICH = new EnrichConfig("languages_policy", "enrich-policy-languages.json");
@@ -219,8 +224,13 @@ public static void main(String[] args) throws IOException {
}
}

private static void loadDataSetIntoEs(RestClient client, IndexCreator indexCreator) throws IOException {
loadDataSetIntoEs(client, LogManager.getLogger(CsvTestsDataLoader.class), indexCreator);
public static Set<TestsDataset> availableDatasetsForEs(RestClient client) throws IOException {
boolean inferenceEnabled = clusterHasInferenceEndpoint(client);

return CSV_DATASET_MAP.values()
.stream()
.filter(d -> d.requiresInferenceEndpoint == false || inferenceEnabled)
.collect(Collectors.toCollection(HashSet::new));
}

public static void loadDataSetIntoEs(RestClient client) throws IOException {
@@ -229,22 +239,61 @@ public static void loadDataSetIntoEs(RestClient client) throws IOException {
});
}

public static void loadDataSetIntoEs(RestClient client, Logger logger) throws IOException {
loadDataSetIntoEs(client, logger, (restClient, indexName, indexMapping, indexSettings) -> {
ESRestTestCase.createIndex(restClient, indexName, indexSettings, indexMapping, null);
});
}
private static void loadDataSetIntoEs(RestClient client, IndexCreator indexCreator) throws IOException {
Logger logger = LogManager.getLogger(CsvTestsDataLoader.class);

private static void loadDataSetIntoEs(RestClient client, Logger logger, IndexCreator indexCreator) throws IOException {
for (var dataset : CSV_DATASET_MAP.values()) {
Set<String> loadedDatasets = new HashSet<>();
for (var dataset : availableDatasetsForEs(client)) {
load(client, dataset, logger, indexCreator);
loadedDatasets.add(dataset.indexName);
}
forceMerge(client, CSV_DATASET_MAP.keySet(), logger);
forceMerge(client, loadedDatasets, logger);
for (var policy : ENRICH_POLICIES) {
loadEnrichPolicy(client, policy.policyName, policy.policyFileName, logger);
}
}

/** The semantic_text mapping type requires an inference endpoint that needs to be set up before creating the index. */
public static void createInferenceEndpoint(RestClient client) throws IOException {
Request request = new Request("PUT", "_inference/sparse_embedding/test_sparse_inference");
request.setJsonEntity("""
{
"service": "test_service",
"service_settings": {
"model": "my_model",
"api_key": "abc64"
},
"task_settings": {
}
}
""");
client.performRequest(request);
}
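// Illustrative only, not part of this commit: a mapping that uses the endpoint
// above would reference it by inference ID, roughly like this (the index and
// field names here are made up):
//
//   Request createIndex = new Request("PUT", "/my-semantic-index");
//   createIndex.setJsonEntity("""
//       {
//         "mappings": {
//           "properties": {
//             "content": { "type": "semantic_text", "inference_id": "test_sparse_inference" }
//           }
//         }
//       }
//       """);
//   client.performRequest(createIndex);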

public static void deleteInferenceEndpoint(RestClient client) throws IOException {
try {
client.performRequest(new Request("DELETE", "_inference/test_sparse_inference"));
} catch (ResponseException e) {
// 404 here means the endpoint was not created
if (e.getResponse().getStatusLine().getStatusCode() != 404) {
throw e;
}
}
}

public static boolean clusterHasInferenceEndpoint(RestClient client) throws IOException {
Request request = new Request("GET", "_inference/sparse_embedding/test_sparse_inference");
try {
client.performRequest(request);
} catch (ResponseException e) {
if (e.getResponse().getStatusLine().getStatusCode() == 404) {
return false;
}
throw e;
}
return true;
}

private static void loadEnrichPolicy(RestClient client, String policyName, String policyFileName, Logger logger) throws IOException {
URL policyMapping = CsvTestsDataLoader.class.getResource("/" + policyFileName);
if (policyMapping == null) {
@@ -511,34 +560,79 @@ public record TestsDataset(
String dataFileName,
String settingFileName,
boolean allowSubFields,
Map<String, String> typeMapping
Map<String, String> typeMapping,
boolean requiresInferenceEndpoint
) {
public TestsDataset(String indexName, String mappingFileName, String dataFileName) {
this(indexName, mappingFileName, dataFileName, null, true, null);
this(indexName, mappingFileName, dataFileName, null, true, null, false);
}

public TestsDataset(String indexName) {
this(indexName, "mapping-" + indexName + ".json", indexName + ".csv", null, true, null);
this(indexName, "mapping-" + indexName + ".json", indexName + ".csv", null, true, null, false);
}

public TestsDataset withIndex(String indexName) {
return new TestsDataset(indexName, mappingFileName, dataFileName, settingFileName, allowSubFields, typeMapping);
return new TestsDataset(
indexName,
mappingFileName,
dataFileName,
settingFileName,
allowSubFields,
typeMapping,
requiresInferenceEndpoint
);
}

public TestsDataset withData(String dataFileName) {
return new TestsDataset(indexName, mappingFileName, dataFileName, settingFileName, allowSubFields, typeMapping);
return new TestsDataset(
indexName,
mappingFileName,
dataFileName,
settingFileName,
allowSubFields,
typeMapping,
requiresInferenceEndpoint
);
}

public TestsDataset withSetting(String settingFileName) {
return new TestsDataset(indexName, mappingFileName, dataFileName, settingFileName, allowSubFields, typeMapping);
return new TestsDataset(
indexName,
mappingFileName,
dataFileName,
settingFileName,
allowSubFields,
typeMapping,
requiresInferenceEndpoint
);
}

public TestsDataset noSubfields() {
return new TestsDataset(indexName, mappingFileName, dataFileName, settingFileName, false, typeMapping);
return new TestsDataset(
indexName,
mappingFileName,
dataFileName,
settingFileName,
false,
typeMapping,
requiresInferenceEndpoint
);
}

public TestsDataset withTypeMapping(Map<String, String> typeMapping) {
return new TestsDataset(indexName, mappingFileName, dataFileName, settingFileName, allowSubFields, typeMapping);
return new TestsDataset(
indexName,
mappingFileName,
dataFileName,
settingFileName,
allowSubFields,
typeMapping,
requiresInferenceEndpoint
);
}

public TestsDataset withInferenceEndpoint(boolean needsInference) {
return new TestsDataset(indexName, mappingFileName, dataFileName, settingFileName, allowSubFields, typeMapping, needsInference);
}
}

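Taken together, the loader now works endpoint-first. A condensed sketch of the resulting call sequence, assuming a reachable test cluster whose nodes carry the inference-service-test plugin; the loadEverything wrapper is illustrative, not part of the commit:

    // Illustrative wrapper over the loader methods above (not in the commit).
    static void loadEverything(RestClient client) throws IOException {
        if (CsvTestsDataLoader.clusterHasInferenceEndpoint(client) == false) {
            CsvTestsDataLoader.createInferenceEndpoint(client); // PUT _inference/sparse_embedding/test_sparse_inference
        }
        // Loads only availableDatasetsForEs(client); semantic_text is filtered out when the endpoint is absent.
        CsvTestsDataLoader.loadDataSetIntoEs(client);
    }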
@@ -648,7 +648,7 @@ public static Literal randomLiteral(DataType type) {
case KEYWORD -> new BytesRef(randomAlphaOfLength(5));
case IP -> new BytesRef(InetAddressPoint.encode(randomIp(randomBoolean())));
case TIME_DURATION -> Duration.ofMillis(randomLongBetween(-604800000L, 604800000L)); // plus/minus 7 days
case TEXT -> new BytesRef(randomAlphaOfLength(50));
case TEXT, SEMANTIC_TEXT -> new BytesRef(randomAlphaOfLength(50));
case VERSION -> randomVersion().toBytesRef();
case GEO_POINT -> GEO.asWkb(GeometryTestUtils.randomPoint());
case CARTESIAN_POINT -> CARTESIAN.asWkb(ShapeTestUtils.randomPoint());
(Diff truncated; the remaining changed files are not shown.)
