diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 822bb67924..7f76d998d3 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -56,6 +56,11 @@ jobs: env: VERSION: ${{ matrix.version }} + # Run neural query integration tests separately as they use a significant amount of memory on their own + - run: "./build.sh integrate ${{ matrix.version }} neuralquery random:test_only_one --report" + name: Neural Query Integration Tests + working-directory: client + - name: Upload test report if: failure() uses: actions/upload-artifact@v3 diff --git a/CHANGELOG.md b/CHANGELOG.md index ef8130fc83..b3058ca56f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ## [Unreleased] ### Added +- Added support for the `neural` query type and `text_embedding` ingest processor type ([#636](https://github.com/opensearch-project/opensearch-net/pull/636)) - Added support for the `Cat.PitSegments` and `Cat.SegmentReplication` APIs ([#527](https://github.com/opensearch-project/opensearch-net/pull/527)) - Added support for `MinScore` on `ScriptScoreQuery` ([#624](https://github.com/opensearch-project/opensearch-net/pull/624)) - Added support for serializing the `DateOnly` and `TimeOnly` types ([#734](https://github.com/opensearch-project/opensearch-net/pull/734)) diff --git a/abstractions/src/OpenSearch.OpenSearch.Ephemeral/EphemeralClusterConfiguration.cs b/abstractions/src/OpenSearch.OpenSearch.Ephemeral/EphemeralClusterConfiguration.cs index 644d1844c8..3e7cec92f8 100644 --- a/abstractions/src/OpenSearch.OpenSearch.Ephemeral/EphemeralClusterConfiguration.cs +++ b/abstractions/src/OpenSearch.OpenSearch.Ephemeral/EphemeralClusterConfiguration.cs @@ -75,7 +75,7 @@ public EphemeralClusterConfiguration(OpenSearchVersion version, ClusterFeatures /// This can be useful to fail early when subsequent operations are relying on installation /// succeeding. /// - public bool ValidatePluginsToInstall { get; } = true; + public bool ValidatePluginsToInstall { get; set; } = true; public bool EnableSsl => Features.HasFlag(ClusterFeatures.SSL); diff --git a/abstractions/src/OpenSearch.OpenSearch.Ephemeral/Tasks/InstallationTasks/InstallPlugins.cs b/abstractions/src/OpenSearch.OpenSearch.Ephemeral/Tasks/InstallationTasks/InstallPlugins.cs index 3225476353..57a6982541 100644 --- a/abstractions/src/OpenSearch.OpenSearch.Ephemeral/Tasks/InstallationTasks/InstallPlugins.cs +++ b/abstractions/src/OpenSearch.OpenSearch.Ephemeral/Tasks/InstallationTasks/InstallPlugins.cs @@ -61,30 +61,29 @@ public override void Run(IEphemeralCluster cluste .Where(p => !p.IsValid(v)) .Select(p => p.SubProductName).ToList(); if (invalidPlugins.Any()) - throw new OpenSearchCleanExitException( - $"Can not install the following plugins for version {v}: {string.Join(", ", invalidPlugins)} "); - } + { + throw new OpenSearchCleanExitException( + $"Can not install the following plugins for version {v}: {string.Join(", ", invalidPlugins)} "); + } + } foreach (var plugin in requiredPlugins) { - var includedByDefault = plugin.IsIncludedOutOfTheBox(v); - if (includedByDefault) + if (plugin.IsIncludedOutOfTheBox(v)) { cluster.Writer?.WriteDiagnostic( $"{{{nameof(InstallPlugins)}}} SKIP plugin [{plugin.SubProductName}] shipped OOTB as of: {{{plugin.ShippedByDefaultAsOf}}}"); continue; } - var validForCurrentVersion = plugin.IsValid(v); - if (!validForCurrentVersion) + if (!plugin.IsValid(v)) { cluster.Writer?.WriteDiagnostic( $"{{{nameof(InstallPlugins)}}} SKIP plugin [{plugin.SubProductName}] not valid for version: {{{v}}}"); continue; } - var alreadyInstalled = AlreadyInstalled(fs, plugin.SubProductName); - if (alreadyInstalled) + if (AlreadyInstalled(fs, plugin.SubProductName)) { cluster.Writer?.WriteDiagnostic( $"{{{nameof(InstallPlugins)}}} SKIP plugin [{plugin.SubProductName}] already installed"); @@ -92,7 +91,7 @@ public override void Run(IEphemeralCluster cluste } cluster.Writer?.WriteDiagnostic( - $"{{{nameof(InstallPlugins)}}} attempting install [{plugin.SubProductName}] as it's not OOTB: {{{plugin.ShippedByDefaultAsOf}}} and valid for {v}: {{{plugin.IsValid(v)}}}"); + $"{{{nameof(InstallPlugins)}}} attempting install [{plugin.SubProductName}] as it's not OOTB: {{{plugin.ShippedByDefaultAsOf}}} and valid for {v}"); var homeConfigPath = Path.Combine(fs.OpenSearchHome, "config"); diff --git a/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/IntegrationTestDiscoverer.cs b/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/IntegrationTestDiscoverer.cs index 9e84b0ee01..e7ea9bdf87 100644 --- a/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/IntegrationTestDiscoverer.cs +++ b/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/IntegrationTestDiscoverer.cs @@ -35,102 +35,109 @@ using Xunit; using Xunit.Abstractions; using Xunit.Sdk; -using Enumerable = System.Linq.Enumerable; -namespace OpenSearch.OpenSearch.Xunit.XunitPlumbing +namespace OpenSearch.OpenSearch.Xunit.XunitPlumbing; + +/// +/// A Xunit test that should be skipped, and a reason why. +/// +public abstract class SkipTestAttributeBase : Attribute { - /// - /// An Xunit test that should be skipped, and a reason why. - /// - public abstract class SkipTestAttributeBase : Attribute - { - /// - /// Whether the test should be skipped - /// - public abstract bool Skip { get; } + /// + /// Whether the test should be skipped + /// + public abstract bool Skip { get; } + + /// + /// The reason why the test should be skipped + /// + public abstract string Reason { get; } +} - /// - /// The reason why the test should be skipped - /// - public abstract string Reason { get; } - } +/// +/// An Xunit integration test +/// +[XunitTestCaseDiscoverer("OpenSearch.OpenSearch.Xunit.XunitPlumbing.IntegrationTestDiscoverer", + "OpenSearch.OpenSearch.Xunit")] +public class I : FactAttribute +{ +} - /// - /// An Xunit integration test - /// - [XunitTestCaseDiscoverer("OpenSearch.OpenSearch.Xunit.XunitPlumbing.IntegrationTestDiscoverer", - "OpenSearch.OpenSearch.Xunit")] - public class I : FactAttribute - { - } +/// +/// A test discoverer used to discover integration tests cases attached +/// to test methods that are attributed with attribute +/// +public class IntegrationTestDiscoverer : OpenSearchTestCaseDiscoverer +{ + public IntegrationTestDiscoverer(IMessageSink diagnosticMessageSink) : base(diagnosticMessageSink) + { + } - /// - /// A test discoverer used to discover integration tests cases attached - /// to test methods that are attributed with attribute - /// - public class IntegrationTestDiscoverer : OpenSearchTestCaseDiscoverer - { - public IntegrationTestDiscoverer(IMessageSink diagnosticMessageSink) : base(diagnosticMessageSink) - { - } + /// + protected override bool SkipMethod(ITestFrameworkDiscoveryOptions discoveryOptions, ITestMethod testMethod, + out string skipReason) + { + skipReason = null; + var runIntegrationTests = + discoveryOptions.GetValue(nameof(OpenSearchXunitRunOptions.RunIntegrationTests)); + if (!runIntegrationTests) return true; - /// - protected override bool SkipMethod(ITestFrameworkDiscoveryOptions discoveryOptions, ITestMethod testMethod, - out string skipReason) - { - skipReason = null; - var runIntegrationTests = - discoveryOptions.GetValue(nameof(OpenSearchXunitRunOptions.RunIntegrationTests)); - if (!runIntegrationTests) return true; + var cluster = TestAssemblyRunner.GetClusterForClass(testMethod.TestClass.Class); + if (cluster == null) + { + skipReason += + $"{testMethod.TestClass.Class.Name} does not define a cluster through IClusterFixture or {nameof(IntegrationTestClusterAttribute)}"; + return true; + } - var cluster = TestAssemblyRunner.GetClusterForClass(testMethod.TestClass.Class); - if (cluster == null) - { - skipReason += - $"{testMethod.TestClass.Class.Name} does not define a cluster through IClusterFixture or {nameof(IntegrationTestClusterAttribute)}"; - return true; - } + var openSearchVersion = + discoveryOptions.GetValue(nameof(OpenSearchXunitRunOptions.Version)); - var openSearchVersion = - discoveryOptions.GetValue(nameof(OpenSearchXunitRunOptions.Version)); + // Skip if the version we are testing against is attributed to be skipped do not run the test nameof(SkipVersionAttribute.Ranges) + var skipVersionAttribute = GetAttributes(testMethod).FirstOrDefault(); + if (skipVersionAttribute != null) + { + var skipVersionRanges = + skipVersionAttribute.GetNamedArgument>(nameof(SkipVersionAttribute.Ranges)) ?? + new List(); + if (openSearchVersion == null && skipVersionRanges.Count > 0) + { + skipReason = $"{nameof(SkipVersionAttribute)} has ranges defined for this test but " + + $"no {nameof(OpenSearchXunitRunOptions.Version)} has been provided to {nameof(OpenSearchXunitRunOptions)}"; + return true; + } - // Skip if the version we are testing against is attributed to be skipped do not run the test nameof(SkipVersionAttribute.Ranges) - var skipVersionAttribute = Enumerable.FirstOrDefault(GetAttributes(testMethod)); - if (skipVersionAttribute != null) - { - var skipVersionRanges = - skipVersionAttribute.GetNamedArgument>(nameof(SkipVersionAttribute.Ranges)) ?? - new List(); - if (openSearchVersion == null && skipVersionRanges.Count > 0) - { - skipReason = $"{nameof(SkipVersionAttribute)} has ranges defined for this test but " + - $"no {nameof(OpenSearchXunitRunOptions.Version)} has been provided to {nameof(OpenSearchXunitRunOptions)}"; - return true; - } + if (openSearchVersion != null) + { + var reason = skipVersionAttribute.GetNamedArgument(nameof(SkipVersionAttribute.Reason)); + foreach (var range in skipVersionRanges) + { + // inrange takes prereleases into account + if (!openSearchVersion.InRange(range)) continue; + skipReason = + $"{nameof(SkipVersionAttribute)} has range {range} that {openSearchVersion} satisfies"; + if (!string.IsNullOrWhiteSpace(reason)) skipReason += $": {reason}"; + return true; + } + } + } - if (openSearchVersion != null) - { - var reason = skipVersionAttribute.GetNamedArgument(nameof(SkipVersionAttribute.Reason)); - for (var index = 0; index < skipVersionRanges.Count; index++) - { - var range = skipVersionRanges[index]; - // inrange takes prereleases into account - if (!openSearchVersion.InRange(range)) continue; - skipReason = - $"{nameof(SkipVersionAttribute)} has range {range} that {openSearchVersion} satisfies"; - if (!string.IsNullOrWhiteSpace(reason)) skipReason += $": {reason}"; - return true; - } - } - } + // Skip if a prerelease version and has SkipPrereleaseVersionsAttribute + var skipPrerelease = GetAttributes(testMethod).FirstOrDefault(); + if (openSearchVersion != null && openSearchVersion.IsPreRelease && skipPrerelease != null) + { + skipReason = $"{nameof(SkipPrereleaseVersionsAttribute)} has been applied to this test"; + var reason = skipPrerelease.GetNamedArgument(nameof(SkipVersionAttribute.Reason)); + if (!string.IsNullOrWhiteSpace(reason)) skipReason += $": {reason}"; + return true; + } - var skipTests = GetAttributes(testMethod) - .FirstOrDefault(a => a.GetNamedArgument(nameof(SkipTestAttributeBase.Skip))); + var skipTests = GetAttributes(testMethod) + .FirstOrDefault(a => a.GetNamedArgument(nameof(SkipTestAttributeBase.Skip))); - if (skipTests == null) return false; + if (skipTests == null) return false; - skipReason = skipTests.GetNamedArgument(nameof(SkipTestAttributeBase.Reason)); - return true; - } - } + skipReason = skipTests.GetNamedArgument(nameof(SkipTestAttributeBase.Reason)); + return true; + } } diff --git a/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/SkipPrereleaseVersionsAttribute.cs b/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/SkipPrereleaseVersionsAttribute.cs new file mode 100644 index 0000000000..dee8646e6f --- /dev/null +++ b/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/SkipPrereleaseVersionsAttribute.cs @@ -0,0 +1,23 @@ +/* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. +*/ + +using System; + +namespace OpenSearch.OpenSearch.Xunit.XunitPlumbing; + +/// +/// A Xunit test that should be skipped for prerelease OpenSearch versions, and a reason why. +/// +public class SkipPrereleaseVersionsAttribute : Attribute +{ + public SkipPrereleaseVersionsAttribute(string reason) => Reason = reason; + + /// + /// The reason why the test should be skipped + /// + public string Reason { get; } +} diff --git a/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/SkipVersionAttribute.cs b/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/SkipVersionAttribute.cs index cfeec7b8da..e885718fe2 100644 --- a/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/SkipVersionAttribute.cs +++ b/abstractions/src/OpenSearch.OpenSearch.Xunit/XunitPlumbing/SkipVersionAttribute.cs @@ -31,35 +31,34 @@ using System.Linq; using SemanticVersioning; -namespace OpenSearch.OpenSearch.Xunit.XunitPlumbing +namespace OpenSearch.OpenSearch.Xunit.XunitPlumbing; + +/// +/// A Xunit test that should be skipped for given OpenSearch versions, and a reason why. +/// +public class SkipVersionAttribute : Attribute { - /// - /// An Xunit test that should be skipped for given OpenSearch versions, and a reason why. - /// - public class SkipVersionAttribute : Attribute - { - // ReSharper disable once UnusedParameter.Local - // reason is used to allow the test its used on to self document why its been put in place - public SkipVersionAttribute(string skipVersionRangesSeparatedByComma, string reason) - { - Reason = reason; - Ranges = string.IsNullOrEmpty(skipVersionRangesSeparatedByComma) - ? new List() - : skipVersionRangesSeparatedByComma.Split(',') - .Select(r => r.Trim()) - .Where(r => !string.IsNullOrWhiteSpace(r)) - .Select(r => new Range(r)) - .ToList(); - } + // ReSharper disable once UnusedParameter.Local + // reason is used to allow the test its used on to self document why its been put in place + public SkipVersionAttribute(string skipVersionRangesSeparatedByComma, string reason) + { + Reason = reason; + Ranges = string.IsNullOrEmpty(skipVersionRangesSeparatedByComma) + ? new List() + : skipVersionRangesSeparatedByComma.Split(',') + .Select(r => r.Trim()) + .Where(r => !string.IsNullOrWhiteSpace(r)) + .Select(r => new Range(r)) + .ToList(); + } - /// - /// The reason why the test should be skipped - /// - public string Reason { get; } + /// + /// The reason why the test should be skipped + /// + public string Reason { get; } - /// - /// The version ranges for which the test should be skipped - /// - public IList Ranges { get; } - } + /// + /// The version ranges for which the test should be skipped + /// + public IList Ranges { get; } } diff --git a/abstractions/src/OpenSearch.Stack.ArtifactsApi/Products/OpenSearchPlugin.cs b/abstractions/src/OpenSearch.Stack.ArtifactsApi/Products/OpenSearchPlugin.cs index 536cebb3ee..f2ac62ab0e 100644 --- a/abstractions/src/OpenSearch.Stack.ArtifactsApi/Products/OpenSearchPlugin.cs +++ b/abstractions/src/OpenSearch.Stack.ArtifactsApi/Products/OpenSearchPlugin.cs @@ -27,6 +27,7 @@ */ using System; +using Version = SemanticVersioning.Version; namespace OpenSearch.Stack.ArtifactsApi.Products { @@ -81,5 +82,9 @@ public OpenSearchPlugin(string plugin, Func isValid = n public static OpenSearchPlugin DeleteByQuery { get; } = new("delete-by-query", version => version < "1.0.0"); public static OpenSearchPlugin Knn { get; } = new("opensearch-knn"); - } + + public static OpenSearchPlugin MachineLearning { get; } = new("opensearch-ml", v => v.BaseVersion() >= new Version("1.3.0") && !v.IsPreRelease); + + public static OpenSearchPlugin NeuralSearch { get; } = new("opensearch-neural-search", v => v.BaseVersion() >= new Version("2.4.0") && !v.IsPreRelease); + } } diff --git a/samples/Samples/NeuralSearch/NeuralSearchSample.cs b/samples/Samples/NeuralSearch/NeuralSearchSample.cs index dc459d5418..aeb2c28a80 100644 --- a/samples/Samples/NeuralSearch/NeuralSearchSample.cs +++ b/samples/Samples/NeuralSearch/NeuralSearchSample.cs @@ -5,7 +5,6 @@ * compatible open source license. */ -using System.Diagnostics; using OpenSearch.Client; using OpenSearch.Net; @@ -46,7 +45,7 @@ protected override async Task Run(IOpenSearchClient client) .Add("plugins.ml_commons.only_run_on_ml_node", false) .Add("plugins.ml_commons.model_access_control_enabled", true) .Add("plugins.ml_commons.native_memory_threshold", 99))); - Debug.Assert(putSettingsResp.IsValid, putSettingsResp.DebugInformation); + Assert(putSettingsResp, r => r.IsValid); Console.WriteLine("Configured cluster to allow local execution of the ML model"); // Register an ML model group @@ -58,7 +57,7 @@ protected override async Task Run(IOpenSearchClient client) description = $"A model group for the opensearch-net {SampleName} sample", access_mode = "public" })); - Debug.Assert(registerModelGroupResp.Success && (string) registerModelGroupResp.Body.status == "CREATED", registerModelGroupResp.DebugInformation); + AssertCreatedStatus(registerModelGroupResp); Console.WriteLine($"Model group named {MlModelGroupName} {registerModelGroupResp.Body.status}: {registerModelGroupResp.Body.model_group_id}"); _modelGroupId = (string) registerModelGroupResp.Body.model_group_id; @@ -72,7 +71,7 @@ protected override async Task Run(IOpenSearchClient client) model_group_id = _modelGroupId, model_format = "TORCH_SCRIPT" })); - Debug.Assert(registerModelResp.Success && (string) registerModelResp.Body.status == "CREATED", registerModelResp.DebugInformation); + AssertCreatedStatus(registerModelResp); Console.WriteLine($"Model registration task {registerModelResp.Body.status}: {registerModelResp.Body.task_id}"); _modelRegistrationTaskId = (string) registerModelResp.Body.task_id; @@ -81,7 +80,7 @@ protected override async Task Run(IOpenSearchClient client) { var getTaskResp = await client.Http.GetAsync($"/_plugins/_ml/tasks/{_modelRegistrationTaskId}"); Console.WriteLine($"Model registration: {getTaskResp.Body.state}"); - Debug.Assert(getTaskResp.Success && (string) getTaskResp.Body.state != "FAILED", getTaskResp.DebugInformation); + AssertNotFailedState(getTaskResp); if (((string)getTaskResp.Body.state).StartsWith("COMPLETED")) { _modelId = getTaskResp.Body.model_id; @@ -93,7 +92,7 @@ protected override async Task Run(IOpenSearchClient client) // Deploy the ML model var deployModelResp = await client.Http.PostAsync($"/_plugins/_ml/models/{_modelId}/_deploy"); - Debug.Assert(deployModelResp.Success && (string) deployModelResp.Body.status == "CREATED", deployModelResp.DebugInformation); + AssertCreatedStatus(deployModelResp); Console.WriteLine($"Model deployment task {deployModelResp.Body.status}: {deployModelResp.Body.task_id}"); _modelDeployTaskId = (string) deployModelResp.Body.task_id; @@ -102,35 +101,21 @@ protected override async Task Run(IOpenSearchClient client) { var getTaskResp = await client.Http.GetAsync($"/_plugins/_ml/tasks/{_modelDeployTaskId}"); Console.WriteLine($"Model deployment: {getTaskResp.Body.state}"); - Debug.Assert(getTaskResp.Success && (string) getTaskResp.Body.state != "FAILED", getTaskResp.DebugInformation); + AssertNotFailedState(getTaskResp); if (((string)getTaskResp.Body.state).StartsWith("COMPLETED")) break; await Task.Delay(10000); } Console.WriteLine($"Model deployed: {_modelId}"); // Create the text_embedding ingest pipeline - // TODO: Client does not yet contain typings for the text_embedding processor - var putIngestPipelineResp = await client.Http.PutAsync( - $"/_ingest/pipeline/{IngestPipelineName}", - r => r.SerializableBody(new - { - description = $"A text_embedding ingest pipeline for the opensearch-net {SampleName} sample", - processors = new[] - { - new - { - text_embedding = new - { - model_id = _modelId, - field_map = new - { - text = "passage_embedding" - } - } - } - } - })); - Debug.Assert(putIngestPipelineResp.IsValid, putIngestPipelineResp.DebugInformation); + var putIngestPipelineResp = await client.Ingest.PutPipelineAsync(IngestPipelineName, p => p + .Description($"A text_embedding ingest pipeline for the opensearch-net {SampleName} sample") + .Processors(pp => pp + .TextEmbedding(te => te + .ModelId(_modelId) + .FieldMap(fm => fm + .Map(d => d.Text, d => d.PassageEmbedding))))); + AssertValid(putIngestPipelineResp); Console.WriteLine($"Put ingest pipeline {IngestPipelineName}: {putIngestPipelineResp.Acknowledged}"); _putIngestPipeline = true; @@ -152,7 +137,7 @@ protected override async Task Run(IOpenSearchClient client) .Engine("lucene") .SpaceType("l2") .Name("hnsw")))))); - Debug.Assert(createIndexResp.IsValid, createIndexResp.DebugInformation); + AssertValid(createIndexResp); Console.WriteLine($"Created index {IndexName}: {createIndexResp.Acknowledged}"); _createdIndex = true; @@ -169,31 +154,23 @@ protected override async Task Run(IOpenSearchClient client) .Index(IndexName) .IndexMany(documents) .Refresh(Refresh.WaitFor)); - Debug.Assert(bulkResp.IsValid, bulkResp.DebugInformation); + AssertValid(bulkResp); Console.WriteLine($"Indexed {documents.Length} documents"); // Perform the neural search - // TODO: Client does not yet contain typings for neural query type Console.WriteLine("Performing neural search for text 'wild west'"); - var searchResp = await client.Http.PostAsync>( - $"/{IndexName}/_search", - r => r.SerializableBody(new - { - _source = new { excludes = new[] { "passage_embedding" } }, - query = new - { - neural = new - { - passage_embedding = new - { - query_text = "wild west", - model_id = _modelId, - k = 5 - } - } - } - })); - Debug.Assert(searchResp.IsValid, searchResp.DebugInformation); + var searchResp = await client.SearchAsync(s => s + .Index(IndexName) + .Source(sf => sf + .Excludes(f => f + .Field(d => d.PassageEmbedding))) + .Query(q => q + .Neural(n => n + .Field(f => f.PassageEmbedding) + .QueryText("wild west") + .ModelId(_modelId) + .K(5)))); + AssertValid(searchResp); Console.WriteLine($"Found {searchResp.Hits.Count} documents"); foreach (var hit in searchResp.Hits) Console.WriteLine($"- Document id: {hit.Source.Id}, score: {hit.Score}, text: {hit.Source.Text}"); } @@ -205,7 +182,7 @@ protected override async Task Cleanup(IOpenSearchClient client) { // Cleanup the index var deleteIndexResp = await client.Indices.DeleteAsync(IndexName); - Debug.Assert(deleteIndexResp.IsValid, deleteIndexResp.DebugInformation); + AssertValid(deleteIndexResp); Console.WriteLine($"Deleted index: {deleteIndexResp.Acknowledged}"); } @@ -213,7 +190,7 @@ protected override async Task Cleanup(IOpenSearchClient client) { // Cleanup the ingest pipeline var deleteIngestPipelineResp = await client.Ingest.DeletePipelineAsync(IngestPipelineName); - Debug.Assert(deleteIngestPipelineResp.IsValid, deleteIngestPipelineResp.DebugInformation); + AssertValid(deleteIngestPipelineResp); Console.WriteLine($"Deleted ingest pipeline: {deleteIngestPipelineResp.Acknowledged}"); } @@ -221,7 +198,7 @@ protected override async Task Cleanup(IOpenSearchClient client) { // Cleanup the model deployment task var deleteModelDeployTaskResp = await client.Http.DeleteAsync($"/_plugins/_ml/tasks/{_modelDeployTaskId}"); - Debug.Assert(deleteModelDeployTaskResp.Success && (string) deleteModelDeployTaskResp.Body.result == "deleted", deleteModelDeployTaskResp.DebugInformation); + AssertDeletedResult(deleteModelDeployTaskResp); Console.WriteLine($"Deleted model deployment task: {deleteModelDeployTaskResp.Body.result}"); } @@ -237,11 +214,11 @@ protected override async Task Cleanup(IOpenSearchClient client) break; } - Debug.Assert(((string?)deleteModelResp.Body.error?.reason)?.Contains("Try undeploy") ?? false, deleteModelResp.DebugInformation); + Assert(deleteModelResp, r => ((string?) r.Body.error?.reason)?.Contains("Try undeploy") ?? false); // Undeploy the ML model var undeployModelResp = await client.Http.PostAsync($"/_plugins/_ml/models/{_modelId}/_undeploy"); - Debug.Assert(undeployModelResp.Success, undeployModelResp.DebugInformation); + Assert(undeployModelResp, r => r.Success); Console.WriteLine("Undeployed model"); await Task.Delay(10000); } @@ -251,7 +228,7 @@ protected override async Task Cleanup(IOpenSearchClient client) { // Cleanup the model registration task var deleteModelRegistrationTaskResp = await client.Http.DeleteAsync($"/_plugins/_ml/tasks/{_modelRegistrationTaskId}"); - Debug.Assert(deleteModelRegistrationTaskResp.Success && (string) deleteModelRegistrationTaskResp.Body.result == "deleted", deleteModelRegistrationTaskResp.DebugInformation); + AssertDeletedResult(deleteModelRegistrationTaskResp); Console.WriteLine($"Deleted model registration task: {deleteModelRegistrationTaskResp.Body.result}"); } @@ -259,8 +236,17 @@ protected override async Task Cleanup(IOpenSearchClient client) { // Cleanup the model group var deleteModelGroupResp = await client.Http.DeleteAsync($"/_plugins/_ml/model_groups/{_modelGroupId}"); - Debug.Assert(deleteModelGroupResp.Success && (string) deleteModelGroupResp.Body.result == "deleted", deleteModelGroupResp.DebugInformation); + AssertDeletedResult(deleteModelGroupResp); Console.WriteLine($"Deleted model group: {deleteModelGroupResp.Body.result}"); } } + + private static void AssertCreatedStatus(DynamicResponse response) => + Assert(response, r => r.Success && (string)r.Body.status == "CREATED"); + + private static void AssertNotFailedState(DynamicResponse response) => + Assert(response, r => r.Success && (string) r.Body.state != "FAILED"); + + private static void AssertDeletedResult(DynamicResponse response) => + Assert(response, r => r.Success && (string) r.Body.result == "deleted"); } diff --git a/samples/Samples/Sample.cs b/samples/Samples/Sample.cs index e683b98935..4057486f11 100644 --- a/samples/Samples/Sample.cs +++ b/samples/Samples/Sample.cs @@ -8,6 +8,7 @@ using System.CommandLine; using System.CommandLine.Binding; using OpenSearch.Client; +using OpenSearch.Net; namespace Samples; @@ -58,4 +59,13 @@ public Command AsCommand(IValueDescriptor clientDescriptor) protected abstract Task Run(IOpenSearchClient client); protected virtual Task Cleanup(IOpenSearchClient client) => Task.CompletedTask; + + protected static void Assert(T response, Func condition) where T : IOpenSearchResponse + { + if (condition(response)) return; + + throw new Exception($"Assertion failed:\n{response.ApiCall?.DebugInformation}"); + } + + protected static void AssertValid(IResponse response) => Assert(response, r => r.IsValid); } diff --git a/src/OpenSearch.Client/CommonAbstractions/Infer/Indices/Indices.cs b/src/OpenSearch.Client/CommonAbstractions/Infer/Indices/Indices.cs index 714f43b474..a6b83805b1 100644 --- a/src/OpenSearch.Client/CommonAbstractions/Infer/Indices/Indices.cs +++ b/src/OpenSearch.Client/CommonAbstractions/Infer/Indices/Indices.cs @@ -46,12 +46,12 @@ internal Indices(ManyIndices indices) : base(indices) { } internal Indices(IEnumerable indices) : base(new ManyIndices(indices)) { } /// All indices. Represents _all - public static Indices All { get; } = new Indices(new AllIndicesMarker()); + public static Indices All { get; } = new(new AllIndicesMarker()); /// - public static Indices AllIndices { get; } = All; + public static Indices AllIndices => All; - private string DebugDisplay => Match( + private string DebugDisplay => Match( all => "_all", types => $"Count: {types.Indices.Count} [" + string.Join(",", types.Indices.Select((t, i) => $"({i + 1}: {t.DebugDisplay})")) + "]" ); @@ -62,11 +62,13 @@ string IUrlParameter.GetString(IConnectionConfigurationValues settings) => Match all => "_all", many => { - if (!(settings is IConnectionSettingsValues oscSettings)) - throw new Exception( - "Tried to pass index names on querysting but it could not be resolved because no OpenSearch.Client settings are available"); + if (settings is not IConnectionSettingsValues oscSettings) + { + throw new Exception( + "Tried to pass index names on querysting but it could not be resolved because no OpenSearch.Client settings are available"); + } - var infer = oscSettings.Inferrer; + var infer = oscSettings.Inferrer; var indices = many.Indices.Select(i => infer.IndexName(i)).Distinct(); return string.Join(",", indices); } diff --git a/src/OpenSearch.Client/Ingest/ProcessorFormatter.cs b/src/OpenSearch.Client/Ingest/ProcessorFormatter.cs index 7e260ff2c3..447d5042c8 100644 --- a/src/OpenSearch.Client/Ingest/ProcessorFormatter.cs +++ b/src/OpenSearch.Client/Ingest/ProcessorFormatter.cs @@ -70,7 +70,8 @@ internal class ProcessorFormatter : IJsonFormatter { "uri_parts", 30 }, { "fingerprint", 31 }, { "community_id", 32 }, - { "network_direction", 33 } + { "network_direction", 33 }, + { "text_embedding", 34 } }; public IProcessor Deserialize(ref JsonReader reader, IJsonFormatterResolver formatterResolver) @@ -193,6 +194,9 @@ public IProcessor Deserialize(ref JsonReader reader, IJsonFormatterResolver form case 33: processor = Deserialize(ref reader, formatterResolver); break; + case 34: + processor = Deserialize(ref reader, formatterResolver); + break; } } else @@ -316,6 +320,9 @@ public void Serialize(ref JsonWriter writer, IProcessor value, IJsonFormatterRes case "network_direction": Serialize(ref writer, value, formatterResolver); break; + case "text_embedding": + Serialize(ref writer, value, formatterResolver); + break; default: var formatter = DynamicObjectResolver.ExcludeNullCamelCase.GetFormatter(); formatter.Serialize(ref writer, value, formatterResolver); diff --git a/src/OpenSearch.Client/Ingest/Processors/Plugins/NeuralSearch/InferenceProcessorBase.cs b/src/OpenSearch.Client/Ingest/Processors/Plugins/NeuralSearch/InferenceProcessorBase.cs new file mode 100644 index 0000000000..8384b34cb7 --- /dev/null +++ b/src/OpenSearch.Client/Ingest/Processors/Plugins/NeuralSearch/InferenceProcessorBase.cs @@ -0,0 +1,119 @@ +/* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. +*/ + +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Runtime.Serialization; +using OpenSearch.Net.Utf8Json; + +namespace OpenSearch.Client; + +[JsonFormatter(typeof(VerbatimDictionaryKeysFormatter))] +public interface IInferenceFieldMap : IIsADictionary { } + +[InterfaceDataContract] +public interface IInferenceProcessor : IProcessor +{ + /// + /// The ID of the model that will be used to generate the embeddings. + /// The model must be deployed in OpenSearch before it can be used in neural search. + /// + /// + /// For more information, + /// see Using custom models within OpenSearch + /// and Semantic search. + /// + [DataMember(Name = "model_id")] + string ModelId { get; set; } + + /// + /// Contains key-value pairs that specify the mapping of a text field to a vector field. + ///
    + ///
  • Key being the name of the field from which to generate embeddings.
  • + ///
  • Value being the name of the vector field in which to store the generated embeddings.
  • + ///
+ ///
+ [DataMember(Name = "field_map")] + IInferenceFieldMap FieldMap { get; set; } +} + +public class InferenceFieldMap : IsADictionaryBase, IInferenceFieldMap +{ + public InferenceFieldMap() { } + public InferenceFieldMap(IDictionary container) : base(container) { } + + public void Add(Field source, Field target) => BackingDictionary.Add(source, target); +} + +/// +public abstract class InferenceProcessorBase : ProcessorBase, IInferenceProcessor +{ + /// + public string ModelId { get; set; } + /// + public IInferenceFieldMap FieldMap { get; set; } +} + +public class InferenceFieldMapDescriptor + : IsADictionaryDescriptorBase, InferenceFieldMap, Field, Field> + where TDocument : class +{ + public InferenceFieldMapDescriptor() : base(new InferenceFieldMap()) { } + + public InferenceFieldMapDescriptor Map(Field source, Field target) => + Assign(source, target); + + public InferenceFieldMapDescriptor Map( + Expression> source, + Field target + ) => + Assign(source, target); + + public InferenceFieldMapDescriptor Map( + Field source, + Expression> target + ) => + Assign(source, target); + + public InferenceFieldMapDescriptor Map( + Expression> source, + Expression> target + ) => + Assign(source, target); +} + +/// +public abstract class InferenceProcessorDescriptorBase + : ProcessorDescriptorBase, IInferenceProcessor + where T : class + where TInferenceProcessorDescriptor : InferenceProcessorDescriptorBase, TInferenceProcessorInterface + where TInferenceProcessorInterface : class, IInferenceProcessor +{ + string IInferenceProcessor.ModelId { get; set; } + IInferenceFieldMap IInferenceProcessor.FieldMap { get; set; } + + /// + public TInferenceProcessorDescriptor ModelId(string modelId) => Assign(modelId, (a, v) => a.ModelId = v); + + /// + public TInferenceProcessorDescriptor FieldMap(IDictionary fieldMap) => + Assign(fieldMap, (a, v) => a.FieldMap = v != null ? new InferenceFieldMap(v) : null); + + /// + public TInferenceProcessorDescriptor FieldMap(IInferenceFieldMap fieldMap) => + Assign(fieldMap, (a, v) => a.FieldMap = v); + + /// + public TInferenceProcessorDescriptor FieldMap(Func, IPromise> selector) => + Assign(selector, (a, v) => a.FieldMap = v?.Invoke(new InferenceFieldMapDescriptor())?.Value); + + /// + public TInferenceProcessorDescriptor FieldMap(Func, IPromise> selector) + where TDocument : class => + Assign(selector, (a, v) => a.FieldMap = v?.Invoke(new InferenceFieldMapDescriptor())?.Value); +} diff --git a/src/OpenSearch.Client/Ingest/Processors/Plugins/NeuralSearch/TextEmbeddingProcessor.cs b/src/OpenSearch.Client/Ingest/Processors/Plugins/NeuralSearch/TextEmbeddingProcessor.cs new file mode 100644 index 0000000000..5df6fc70ff --- /dev/null +++ b/src/OpenSearch.Client/Ingest/Processors/Plugins/NeuralSearch/TextEmbeddingProcessor.cs @@ -0,0 +1,32 @@ +/* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. +*/ + +using OpenSearch.Net.Utf8Json; + +namespace OpenSearch.Client; + +/// +/// The text_embedding processor is used to generate vector embeddings from text fields for semantic search. +/// +[InterfaceDataContract] +public interface ITextEmbeddingProcessor : IInferenceProcessor +{ +} + +/// +public class TextEmbeddingProcessor : InferenceProcessorBase, ITextEmbeddingProcessor +{ + protected override string Name => "text_embedding"; +} + +/// +public class TextEmbeddingProcessorDescriptor + : InferenceProcessorDescriptorBase, ITextEmbeddingProcessor>, ITextEmbeddingProcessor + where TDocument : class +{ + protected override string Name => "text_embedding"; +} diff --git a/src/OpenSearch.Client/Ingest/ProcessorsDescriptor.cs b/src/OpenSearch.Client/Ingest/ProcessorsDescriptor.cs index 7fab35d570..1fd0a29b23 100644 --- a/src/OpenSearch.Client/Ingest/ProcessorsDescriptor.cs +++ b/src/OpenSearch.Client/Ingest/ProcessorsDescriptor.cs @@ -205,5 +205,9 @@ public ProcessorsDescriptor NetworkCommunityId(Func public ProcessorsDescriptor NetworkDirection(Func, INetworkDirectionProcessor> selector) where T : class => Assign(selector, (a, v) => a.AddIfNotNull(v?.Invoke(new NetworkDirectionProcessorDescriptor()))); - } + + /// + public ProcessorsDescriptor TextEmbedding(Func, ITextEmbeddingProcessor> selector) where T : class => + Assign(selector, (a, v) => a.AddIfNotNull(v?.Invoke(new TextEmbeddingProcessorDescriptor()))); + } } diff --git a/src/OpenSearch.Client/QueryDsl/Abstractions/Container/IQueryContainer.cs b/src/OpenSearch.Client/QueryDsl/Abstractions/Container/IQueryContainer.cs index 3468f49ee4..e71268f7a7 100644 --- a/src/OpenSearch.Client/QueryDsl/Abstractions/Container/IQueryContainer.cs +++ b/src/OpenSearch.Client/QueryDsl/Abstractions/Container/IQueryContainer.cs @@ -203,6 +203,9 @@ public interface IQueryContainer [DataMember(Name = "knn")] IKnnQuery Knn { get; set; } + [DataMember(Name = "neural")] + INeuralQuery Neural { get; set; } + void Accept(IQueryVisitor visitor); } } diff --git a/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainer-Assignments.cs b/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainer-Assignments.cs index a7b9c79fdb..45cfb19d3a 100644 --- a/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainer-Assignments.cs +++ b/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainer-Assignments.cs @@ -61,6 +61,7 @@ public partial class QueryContainer : IQueryContainer, IDescriptor private IMoreLikeThisQuery _moreLikeThis; private IMultiMatchQuery _multiMatch; private INestedQuery _nested; + private INeuralQuery _neural; private IParentIdQuery _parentId; private IPercolateQuery _percolate; private IPrefixQuery _prefix; @@ -254,6 +255,12 @@ INestedQuery IQueryContainer.Nested set => _nested = Set(value); } + INeuralQuery IQueryContainer.Neural + { + get => _neural; + set => _neural = Set(value); + } + IParentIdQuery IQueryContainer.ParentId { get => _parentId; diff --git a/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainerDescriptor.cs b/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainerDescriptor.cs index 419e41d869..3ccb8529bb 100644 --- a/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainerDescriptor.cs +++ b/src/OpenSearch.Client/QueryDsl/Abstractions/Container/QueryContainerDescriptor.cs @@ -32,480 +32,485 @@ namespace OpenSearch.Client { - [DataContract] - public class QueryContainerDescriptor : QueryContainer where T : class - { - private QueryContainer WrapInContainer( - Func create, - Action assign - ) - where TQuery : class, TQueryInterface, IQuery, new() - where TQueryInterface : class, IQuery - { - // Invoke the create delegate before assigning container; the create delegate - // may mutate the current QueryContainerDescriptor instance such that it - // contains a query. See https://github.com/elastic/elasticsearch-net/issues/2875 - var query = create.InvokeOrDefault(new TQuery()); - - var container = ContainedQuery == null - ? this - : new QueryContainerDescriptor(); - - IQueryContainer c = container; - c.IsVerbatim = query.IsVerbatim; - c.IsStrict = query.IsStrict; - assign(query, container); - container.ContainedQuery = query; - - //if query is writable (not conditionless or verbatim): return a container that holds the query - if (query.IsWritable) - return container; - - //query is conditionless but marked as strict, throw exception - if (query.IsStrict) - throw new ArgumentException("Query is conditionless but strict is turned on"); - - //query is conditionless return an empty container that can later be rewritten - return null; - } - - /// - /// A query defined using a raw json string. - /// The query must be enclosed within '{' and '}' - /// - /// The query dsl json - public QueryContainer Raw(string rawJson) => - WrapInContainer((RawQueryDescriptor descriptor) => descriptor.Raw(rawJson), (query, container) => container.RawQuery = query); - - /// - /// A query that uses a query parser in order to parse its content. - /// - public QueryContainer QueryString(Func, IQueryStringQuery> selector) => - WrapInContainer(selector, (query, container) => container.QueryString = query); - - /// - /// A query that uses the SimpleQueryParser to parse its context. - /// Unlike the regular query_string query, the simple_query_string query will - /// never throw an exception, and discards invalid parts of the query. - /// - public QueryContainer SimpleQueryString(Func, ISimpleQueryStringQuery> selector) => - WrapInContainer(selector, (query, container) => container.SimpleQueryString = query); - - /// - /// A query that match on any (configurable) of the provided terms. - /// This is a simpler syntax query for using a bool query with several term queries in the should clauses. - /// - public QueryContainer Terms(Func, ITermsQuery> selector) => - WrapInContainer(selector, (query, container) => container.Terms = query); - - /// - /// A fuzzy based query that uses similarity based on Levenshtein (edit distance) algorithm. - /// Warning: this query is not very scalable with its default prefix length of 0. In this case, - /// every term will be enumerated and cause an edit score calculation or max_expansions is not set. - /// - public QueryContainer Fuzzy(Func, IFuzzyQuery> selector) => - WrapInContainer(selector, (query, container) => container.Fuzzy = query); - - public QueryContainer FuzzyNumeric(Func, IFuzzyQuery> selector) => - WrapInContainer(selector, (query, container) => container.Fuzzy = query); - - public QueryContainer FuzzyDate(Func, IFuzzyQuery> selector) => - WrapInContainer(selector, (query, container) => container.Fuzzy = query); - - /// - /// The default match query is of type boolean. It means that the text provided is analyzed and the analysis - /// process constructs a boolean query from the provided text. - /// - public QueryContainer Match(Func, IMatchQuery> selector) => - WrapInContainer(selector, (query, container) => container.Match = query); - - /// - /// The match_phrase query analyzes the match and creates a phrase query out of the analyzed text. - /// - public QueryContainer MatchPhrase(Func, IMatchPhraseQuery> selector) => - WrapInContainer(selector, (query, container) => container.MatchPhrase = query); - - /// - public QueryContainer MatchBoolPrefix(Func, IMatchBoolPrefixQuery> selector) => - WrapInContainer(selector, (query, container) => container.MatchBoolPrefix = query); - - /// - /// The match_phrase_prefix is the same as match_phrase, expect it allows for prefix matches on the last term - /// in the text - /// - public QueryContainer MatchPhrasePrefix(Func, IMatchPhrasePrefixQuery> selector) => - WrapInContainer(selector, (query, container) => container.MatchPhrasePrefix = query); - - /// - /// The multi_match query builds further on top of the match query by allowing multiple fields to be specified. - /// The idea here is to allow to more easily build a concise match type query over multiple fields instead of using a - /// relatively more expressive query by using multiple match queries within a bool query. - /// - public QueryContainer MultiMatch(Func, IMultiMatchQuery> selector) => - WrapInContainer(selector, (query, container) => container.MultiMatch = query); - - /// - /// Nested query allows to query nested objects / docs (see nested mapping). The query is executed against the - /// nested objects / docs as if they were indexed as separate docs (they are, internally) and resulting in the - /// root parent doc (or parent nested mapping). - /// - public QueryContainer Nested(Func, INestedQuery> selector) => - WrapInContainer(selector, (query, container) => container.Nested = query); - - /// - /// A thin wrapper allowing fined grained control what should happen if a query is conditionless - /// if you need to fallback to something other than a match_all query - /// - public QueryContainer Conditionless(Func, IConditionlessQuery> selector) - { - var query = selector(new ConditionlessQueryDescriptor()); - return query?.Query ?? query?.Fallback; - } - - /// - /// Matches documents with fields that have terms within a certain numeric range. - /// - public QueryContainer Range(Func, INumericRangeQuery> selector) => - WrapInContainer(selector, (query, container) => container.Range = query); - - public QueryContainer LongRange(Func, ILongRangeQuery> selector) => - WrapInContainer(selector, (query, container) => container.Range = query); - - /// - /// Matches documents with fields that have terms within a certain date range. - /// - public QueryContainer DateRange(Func, IDateRangeQuery> selector) => - WrapInContainer(selector, (query, container) => container.Range = query); - - /// - /// Matches documents with fields that have terms within a certain term range. - /// - public QueryContainer TermRange(Func, ITermRangeQuery> selector) => - WrapInContainer(selector, (query, container) => container.Range = query); - - /// - /// More like this query find documents that are like the provided text by running it against one or more fields. - /// - public QueryContainer MoreLikeThis(Func, IMoreLikeThisQuery> selector) => - WrapInContainer(selector, (query, container) => container.MoreLikeThis = query); - - /// - /// A geo_shape query that finds documents - /// that have a geometry that matches for the given spatial relation and input shape - /// - public QueryContainer GeoShape(Func, IGeoShapeQuery> selector) => - WrapInContainer(selector, (query, container) => container.GeoShape = query); - - /// - /// Finds documents with shapes that either intersect, are within, or do not intersect a specified shape. - /// - public QueryContainer Shape(Func, IShapeQuery> selector) => - WrapInContainer(selector, (query, container) => container.Shape = query); - - /// - /// Matches documents with a geo_point type field that falls within a polygon of points - /// - public QueryContainer GeoPolygon(Func, IGeoPolygonQuery> selector) => - WrapInContainer(selector, (query, container) => container.GeoPolygon = query); - - /// - /// Matches documents with a geo_point type field to include only those - /// that exist within a specific distance from a given geo_point - /// - public QueryContainer GeoDistance(Func, IGeoDistanceQuery> selector) => - WrapInContainer(selector, (query, container) => container.GeoDistance = query); - - /// - /// Matches documents with a geo_point type field to include only those that exist within a bounding box - /// - public QueryContainer GeoBoundingBox(Func, IGeoBoundingBoxQuery> selector) => - WrapInContainer(selector, (query, container) => container.GeoBoundingBox = query); - - /// - /// The has_child query works the same as the has_child filter, by automatically wrapping the filter with a - /// constant_score. - /// - /// Type of the child - public QueryContainer HasChild(Func, IHasChildQuery> selector) where TChild : class => - WrapInContainer(selector, (query, container) => container.HasChild = query); - - /// - /// The has_parent query works the same as the has_parent filter, by automatically wrapping the filter with a - /// constant_score. - /// - /// Type of the parent - public QueryContainer HasParent(Func, IHasParentQuery> selector) where TParent : class => - WrapInContainer(selector, (query, container) => container.HasParent = query); - - public QueryContainer Knn(Func, IKnnQuery> selector) => - WrapInContainer(selector, (query, container) => container.Knn = query); - - /// - /// A query that generates the union of documents produced by its subqueries, and that scores each document - /// with the maximum score for that document as produced by any subquery, plus a tie breaking increment for - /// any additional matching subqueries. - /// - public QueryContainer DisMax(Func, IDisMaxQuery> selector) => - WrapInContainer(selector, (query, container) => container.DisMax = query); - - /// - public QueryContainer DistanceFeature(Func, IDistanceFeatureQuery> selector) => - WrapInContainer(selector, (query, container) => container.DistanceFeature = query); - - /// - /// A query that wraps a filter or another query and simply returns a constant score equal to the query boost - /// for every document in the filter. Maps to Lucene ConstantScoreQuery. - /// - public QueryContainer ConstantScore(Func, IConstantScoreQuery> selector) => - WrapInContainer(selector, (query, container) => container.ConstantScore = query); - - /// - /// A query that matches documents matching boolean combinations of other queries. The bool query maps to - /// Lucene BooleanQuery. - /// It is built using one or more boolean clauses, each clause with a typed occurrence - /// - public QueryContainer Bool(Func, IBoolQuery> selector) => - WrapInContainer(selector, (query, container) => container.Bool = query); - - /// - /// A query that can be used to effectively demote results that match a given query. - /// Unlike the "must_not" clause in bool query, this still selects documents that contain - /// undesirable terms, but reduces their overall score. - /// - public QueryContainer Boosting(Func, IBoostingQuery> selector) => - WrapInContainer(selector, (query, container) => container.Boosting = query); - - /// - /// A query that matches all documents. Maps to Lucene MatchAllDocsQuery. - /// - public QueryContainer MatchAll(Func selector = null) => - WrapInContainer(selector, (query, container) => container.MatchAll = query ?? new MatchAllQuery()); - - /// - /// A query that matches no documents. This is the inverse of the match_all query. - /// - public QueryContainer MatchNone(Func selector = null) => - WrapInContainer(selector, (query, container) => container.MatchNone = query ?? new MatchNoneQuery()); - - /// - /// Matches documents that have fields that contain a term (not analyzed). - /// The term query maps to Lucene TermQuery. - /// - public QueryContainer Term(Expression> field, object value, double? boost = null, string name = null) => - Term(t => t.Field(field).Value(value).Boost(boost).Name(name)); - - /// - /// Helper method to easily filter on join relations - /// - public QueryContainer HasRelationName(Expression> field, RelationName value) => - Term(t => t.Field(field).Value(value)); - - /// Helper method to easily filter on join relations - public QueryContainer HasRelationName(Expression> field) => - Term(t => t.Field(field).Value(Infer.Relation())); - - /// - /// Matches documents that have fields that contain a term (not analyzed). - /// The term query maps to Lucene TermQuery. - /// - public QueryContainer Term(Field field, object value, double? boost = null, string name = null) => - Term(t => t.Field(field).Value(value).Boost(boost).Name(name)); - - /// - /// Matches documents that have fields that contain a term (not analyzed). - /// The term query maps to Lucene TermQuery. - /// - public QueryContainer Term(Func, ITermQuery> selector) => - WrapInContainer(selector, (query, container) => container.Term = query); - - /// - /// Matches documents that have fields matching a wildcard expression (not analyzed). - /// Supported wildcards are *, which matches any character sequence (including the empty one), and ?, - /// which matches any single character. Note this query can be slow, as it needs to iterate - /// over many terms. In order to prevent extremely slow wildcard queries, a wildcard term should - /// not start with one of the wildcards * or ?. The wildcard query maps to Lucene WildcardQuery. - /// - public QueryContainer Wildcard(Expression> field, string value, double? boost = null, MultiTermQueryRewrite rewrite = null, - string name = null - ) => - Wildcard(t => t.Field(field).Value(value).Rewrite(rewrite).Boost(boost).Name(name)); - - /// - /// Matches documents that have fields matching a wildcard expression (not analyzed). - /// Supported wildcards are *, which matches any character sequence (including the empty one), and ?, - /// which matches any single character. Note this query can be slow, as it needs to iterate over many terms. - /// In order to prevent extremely slow wildcard queries, a wildcard term should not start with - /// one of the wildcards * or ?. The wildcard query maps to Lucene WildcardQuery. - /// - public QueryContainer Wildcard(Field field, string value, double? boost = null, MultiTermQueryRewrite rewrite = null, string name = null) => - Wildcard(t => t.Field(field).Value(value).Rewrite(rewrite).Boost(boost).Name(name)); - - /// - /// Matches documents that have fields matching a wildcard expression (not analyzed). - /// Supported wildcards are *, which matches any character sequence (including the empty one), and ?, - /// which matches any single character. Note this query can be slow, as it needs to iterate over many terms. - /// In order to prevent extremely slow wildcard queries, a wildcard term should not start with - /// one of the wildcards * or ?. The wildcard query maps to Lucene WildcardQuery. - /// - public QueryContainer Wildcard(Func, IWildcardQuery> selector) => - WrapInContainer(selector, (query, container) => container.Wildcard = query); - - /// - /// Matches documents that have fields containing terms with a specified prefix (not analyzed). - /// The prefix query maps to Lucene PrefixQuery. - /// - public QueryContainer Prefix(Expression> field, string value, double? boost = null, MultiTermQueryRewrite rewrite = null, - string name = null - ) => - Prefix(t => t.Field(field).Value(value).Boost(boost).Rewrite(rewrite).Name(name)); - - /// - /// Matches documents that have fields containing terms with a specified prefix (not analyzed). - /// The prefix query maps to Lucene PrefixQuery. - /// - public QueryContainer Prefix(Field field, string value, double? boost = null, MultiTermQueryRewrite rewrite = null, string name = null) => - Prefix(t => t.Field(field).Value(value).Boost(boost).Rewrite(rewrite).Name(name)); - - /// - /// Matches documents that have fields containing terms with a specified prefix (not analyzed). - /// The prefix query maps to Lucene PrefixQuery. - /// - public QueryContainer Prefix(Func, IPrefixQuery> selector) => - WrapInContainer(selector, (query, container) => container.Prefix = query); - - /// - /// Matches documents that only have the provided ids. - /// Note, this filter does not require the _id field to be indexed since - /// it works using the _uid field. - /// - public QueryContainer Ids(Func selector) => - WrapInContainer(selector, (query, container) => container.Ids = query); - - /// - /// Allows fine-grained control over the order and proximity of matching terms. - /// Matching rules are constructed from a small set of definitions, - /// and the rules are then applied to terms from a particular field. - /// The definitions produce sequences of minimal intervals that span terms in a body of text. - /// These intervals can be further combined and filtered by parent sources. - /// - public QueryContainer Intervals(Func, IIntervalsQuery> selector) => - WrapInContainer(selector, (query, container) => container.Intervals = query); - - /// - public QueryContainer RankFeature(Func, IRankFeatureQuery> selector) => - WrapInContainer(selector, (query, container) => container.RankFeature = query); - - /// - /// Matches spans containing a term. The span term query maps to Lucene SpanTermQuery. - /// - public QueryContainer SpanTerm(Func, ISpanTermQuery> selector) => - WrapInContainer(selector, (query, container) => container.SpanTerm = query); - - /// - /// Matches spans near the beginning of a field. The span first query maps to Lucene SpanFirstQuery. - /// - public QueryContainer SpanFirst(Func, ISpanFirstQuery> selector) => - WrapInContainer(selector, (query, container) => container.SpanFirst = query); - - /// - /// Matches spans which are near one another. One can specify slop, the maximum number of - /// intervening unmatched positions, as well as whether matches are required to be in-order. - /// The span near query maps to Lucene SpanNearQuery. - /// - public QueryContainer SpanNear(Func, ISpanNearQuery> selector) => - WrapInContainer(selector, (query, container) => container.SpanNear = query); - - /// - /// Matches the union of its span clauses. - /// The span or query maps to Lucene SpanOrQuery. - /// - public QueryContainer SpanOr(Func, ISpanOrQuery> selector) => - WrapInContainer(selector, (query, container) => container.SpanOr = query); - - /// - /// Removes matches which overlap with another span query. - /// The span not query maps to Lucene SpanNotQuery. - /// - public QueryContainer SpanNot(Func, ISpanNotQuery> selector) => - WrapInContainer(selector, (query, container) => container.SpanNot = query); - - /// - /// Wrap a multi term query (one of fuzzy, prefix, term range or regexp query) - /// as a span query so it can be nested. - /// - public QueryContainer SpanMultiTerm(Func, ISpanMultiTermQuery> selector) => - WrapInContainer(selector, (query, container) => container.SpanMultiTerm = query); - - /// - /// Returns matches which enclose another span query. - /// The span containing query maps to Lucene SpanContainingQuery - /// - public QueryContainer SpanContaining(Func, ISpanContainingQuery> selector) => - WrapInContainer(selector, (query, container) => container.SpanContaining = query); - - /// - /// Returns Matches which are enclosed inside another span query. - /// The span within query maps to Lucene SpanWithinQuery - /// - public QueryContainer SpanWithin(Func, ISpanWithinQuery> selector) => - WrapInContainer(selector, (query, container) => container.SpanWithin = query); - - /// - /// Wraps span queries to allow them to participate in composite single-field Span queries by 'lying' about their search field. - /// That is, the masked span query will function as normal, but the field points back to the set field of the query. - /// This can be used to support queries like SpanNearQuery or SpanOrQuery across different fields, - /// which is not ordinarily permitted. - /// - public QueryContainer SpanFieldMasking(Func, ISpanFieldMaskingQuery> selector) => - WrapInContainer(selector, (query, container) => container.SpanFieldMasking = query); - - /// - /// Allows you to use regular expression term queries. - /// "term queries" means that OpenSearch will apply the regexp to the terms produced - /// by the tokenizer for that field, and not to the original text of the field. - /// - public QueryContainer Regexp(Func, IRegexpQuery> selector) => - WrapInContainer(selector, (query, container) => container.Regexp = query); - - /// - /// The function_score query allows you to modify the score of documents that are retrieved by a query. - /// This can be useful if, for example, a score function is computationally expensive and it is - /// sufficient to compute the score on a filtered set of documents. - /// - /// - public QueryContainer FunctionScore(Func, IFunctionScoreQuery> selector) => - WrapInContainer(selector, (query, container) => container.FunctionScore = query); - - public QueryContainer Script(Func, IScriptQuery> selector) => - WrapInContainer(selector, (query, container) => container.Script = query); - - public QueryContainer ScriptScore(Func, IScriptScoreQuery> selector) => - WrapInContainer(selector, (query, container) => container.ScriptScore = query); - - public QueryContainer Exists(Func, IExistsQuery> selector) => - WrapInContainer(selector, (query, container) => container.Exists = query); - - /// - /// Used to match queries stored in an index. - /// The percolate query itself contains the document that will be used as query - /// to match with the stored queries. - /// - public QueryContainer Percolate(Func, IPercolateQuery> selector) => - WrapInContainer(selector, (query, container) => container.Percolate = query); - - /// - /// Used to find child documents which belong to a particular parent. - /// - public QueryContainer ParentId(Func, IParentIdQuery> selector) => - WrapInContainer(selector, (query, container) => container.ParentId = query); - - /// - /// Returns any documents that match with at least one or more of the provided terms. - /// The terms are not analyzed and thus must match exactly. The number of terms that must match varies - /// per document and is either controlled by a minimum should match field or - /// computed per document in a minimum should match script. - /// - public QueryContainer TermsSet(Func, ITermsSetQuery> selector) => - WrapInContainer(selector, (query, container) => container.TermsSet = query); - } + [DataContract] + public class QueryContainerDescriptor : QueryContainer where T : class + { + private QueryContainer WrapInContainer( + Func create, + Action assign + ) + where TQuery : class, TQueryInterface, IQuery, new() + where TQueryInterface : class, IQuery + { + // Invoke the create delegate before assigning container; the create delegate + // may mutate the current QueryContainerDescriptor instance such that it + // contains a query. See https://github.com/elastic/elasticsearch-net/issues/2875 + var query = create.InvokeOrDefault(new TQuery()); + + var container = ContainedQuery == null + ? this + : new QueryContainerDescriptor(); + + IQueryContainer c = container; + c.IsVerbatim = query.IsVerbatim; + c.IsStrict = query.IsStrict; + assign(query, container); + container.ContainedQuery = query; + + //if query is writable (not conditionless or verbatim): return a container that holds the query + if (query.IsWritable) + return container; + + //query is conditionless but marked as strict, throw exception + if (query.IsStrict) + throw new ArgumentException("Query is conditionless but strict is turned on"); + + //query is conditionless return an empty container that can later be rewritten + return null; + } + + /// + /// A query defined using a raw json string. + /// The query must be enclosed within '{' and '}' + /// + /// The query dsl json + public QueryContainer Raw(string rawJson) => + WrapInContainer((RawQueryDescriptor descriptor) => descriptor.Raw(rawJson), (query, container) => container.RawQuery = query); + + /// + /// A query that uses a query parser in order to parse its content. + /// + public QueryContainer QueryString(Func, IQueryStringQuery> selector) => + WrapInContainer(selector, (query, container) => container.QueryString = query); + + /// + /// A query that uses the SimpleQueryParser to parse its context. + /// Unlike the regular query_string query, the simple_query_string query will + /// never throw an exception, and discards invalid parts of the query. + /// + public QueryContainer SimpleQueryString(Func, ISimpleQueryStringQuery> selector) => + WrapInContainer(selector, (query, container) => container.SimpleQueryString = query); + + /// + /// A query that match on any (configurable) of the provided terms. + /// This is a simpler syntax query for using a bool query with several term queries in the should clauses. + /// + public QueryContainer Terms(Func, ITermsQuery> selector) => + WrapInContainer(selector, (query, container) => container.Terms = query); + + /// + /// A fuzzy based query that uses similarity based on Levenshtein (edit distance) algorithm. + /// Warning: this query is not very scalable with its default prefix length of 0. In this case, + /// every term will be enumerated and cause an edit score calculation or max_expansions is not set. + /// + public QueryContainer Fuzzy(Func, IFuzzyQuery> selector) => + WrapInContainer(selector, (query, container) => container.Fuzzy = query); + + public QueryContainer FuzzyNumeric(Func, IFuzzyQuery> selector) => + WrapInContainer(selector, (query, container) => container.Fuzzy = query); + + public QueryContainer FuzzyDate(Func, IFuzzyQuery> selector) => + WrapInContainer(selector, (query, container) => container.Fuzzy = query); + + /// + /// The default match query is of type boolean. It means that the text provided is analyzed and the analysis + /// process constructs a boolean query from the provided text. + /// + public QueryContainer Match(Func, IMatchQuery> selector) => + WrapInContainer(selector, (query, container) => container.Match = query); + + /// + /// The match_phrase query analyzes the match and creates a phrase query out of the analyzed text. + /// + public QueryContainer MatchPhrase(Func, IMatchPhraseQuery> selector) => + WrapInContainer(selector, (query, container) => container.MatchPhrase = query); + + /// + public QueryContainer MatchBoolPrefix(Func, IMatchBoolPrefixQuery> selector) => + WrapInContainer(selector, (query, container) => container.MatchBoolPrefix = query); + + /// + /// The match_phrase_prefix is the same as match_phrase, expect it allows for prefix matches on the last term + /// in the text + /// + public QueryContainer MatchPhrasePrefix(Func, IMatchPhrasePrefixQuery> selector) => + WrapInContainer(selector, (query, container) => container.MatchPhrasePrefix = query); + + /// + /// The multi_match query builds further on top of the match query by allowing multiple fields to be specified. + /// The idea here is to allow to more easily build a concise match type query over multiple fields instead of using a + /// relatively more expressive query by using multiple match queries within a bool query. + /// + public QueryContainer MultiMatch(Func, IMultiMatchQuery> selector) => + WrapInContainer(selector, (query, container) => container.MultiMatch = query); + + /// + /// Nested query allows to query nested objects / docs (see nested mapping). The query is executed against the + /// nested objects / docs as if they were indexed as separate docs (they are, internally) and resulting in the + /// root parent doc (or parent nested mapping). + /// + public QueryContainer Nested(Func, INestedQuery> selector) => + WrapInContainer(selector, (query, container) => container.Nested = query); + + /// + /// A thin wrapper allowing fined grained control what should happen if a query is conditionless + /// if you need to fallback to something other than a match_all query + /// + public QueryContainer Conditionless(Func, IConditionlessQuery> selector) + { + var query = selector(new ConditionlessQueryDescriptor()); + return query?.Query ?? query?.Fallback; + } + + /// + /// Matches documents with fields that have terms within a certain numeric range. + /// + public QueryContainer Range(Func, INumericRangeQuery> selector) => + WrapInContainer(selector, (query, container) => container.Range = query); + + public QueryContainer LongRange(Func, ILongRangeQuery> selector) => + WrapInContainer(selector, (query, container) => container.Range = query); + + /// + /// Matches documents with fields that have terms within a certain date range. + /// + public QueryContainer DateRange(Func, IDateRangeQuery> selector) => + WrapInContainer(selector, (query, container) => container.Range = query); + + /// + /// Matches documents with fields that have terms within a certain term range. + /// + public QueryContainer TermRange(Func, ITermRangeQuery> selector) => + WrapInContainer(selector, (query, container) => container.Range = query); + + /// + /// More like this query find documents that are like the provided text by running it against one or more fields. + /// + public QueryContainer MoreLikeThis(Func, IMoreLikeThisQuery> selector) => + WrapInContainer(selector, (query, container) => container.MoreLikeThis = query); + + /// + /// A geo_shape query that finds documents + /// that have a geometry that matches for the given spatial relation and input shape + /// + public QueryContainer GeoShape(Func, IGeoShapeQuery> selector) => + WrapInContainer(selector, (query, container) => container.GeoShape = query); + + /// + /// Finds documents with shapes that either intersect, are within, or do not intersect a specified shape. + /// + public QueryContainer Shape(Func, IShapeQuery> selector) => + WrapInContainer(selector, (query, container) => container.Shape = query); + + /// + /// Matches documents with a geo_point type field that falls within a polygon of points + /// + public QueryContainer GeoPolygon(Func, IGeoPolygonQuery> selector) => + WrapInContainer(selector, (query, container) => container.GeoPolygon = query); + + /// + /// Matches documents with a geo_point type field to include only those + /// that exist within a specific distance from a given geo_point + /// + public QueryContainer GeoDistance(Func, IGeoDistanceQuery> selector) => + WrapInContainer(selector, (query, container) => container.GeoDistance = query); + + /// + /// Matches documents with a geo_point type field to include only those that exist within a bounding box + /// + public QueryContainer GeoBoundingBox(Func, IGeoBoundingBoxQuery> selector) => + WrapInContainer(selector, (query, container) => container.GeoBoundingBox = query); + + /// + /// The has_child query works the same as the has_child filter, by automatically wrapping the filter with a + /// constant_score. + /// + /// Type of the child + public QueryContainer HasChild(Func, IHasChildQuery> selector) where TChild : class => + WrapInContainer(selector, (query, container) => container.HasChild = query); + + /// + /// The has_parent query works the same as the has_parent filter, by automatically wrapping the filter with a + /// constant_score. + /// + /// Type of the parent + public QueryContainer HasParent(Func, IHasParentQuery> selector) where TParent : class => + WrapInContainer(selector, (query, container) => container.HasParent = query); + + public QueryContainer Knn(Func, IKnnQuery> selector) => + WrapInContainer(selector, (query, container) => container.Knn = query); + + /// + /// A query that generates the union of documents produced by its subqueries, and that scores each document + /// with the maximum score for that document as produced by any subquery, plus a tie breaking increment for + /// any additional matching subqueries. + /// + public QueryContainer DisMax(Func, IDisMaxQuery> selector) => + WrapInContainer(selector, (query, container) => container.DisMax = query); + + /// + public QueryContainer DistanceFeature(Func, IDistanceFeatureQuery> selector) => + WrapInContainer(selector, (query, container) => container.DistanceFeature = query); + + /// + /// A query that wraps a filter or another query and simply returns a constant score equal to the query boost + /// for every document in the filter. Maps to Lucene ConstantScoreQuery. + /// + public QueryContainer ConstantScore(Func, IConstantScoreQuery> selector) => + WrapInContainer(selector, (query, container) => container.ConstantScore = query); + + /// + /// A query that matches documents matching boolean combinations of other queries. The bool query maps to + /// Lucene BooleanQuery. + /// It is built using one or more boolean clauses, each clause with a typed occurrence + /// + public QueryContainer Bool(Func, IBoolQuery> selector) => + WrapInContainer(selector, (query, container) => container.Bool = query); + + /// + /// A query that can be used to effectively demote results that match a given query. + /// Unlike the "must_not" clause in bool query, this still selects documents that contain + /// undesirable terms, but reduces their overall score. + /// + public QueryContainer Boosting(Func, IBoostingQuery> selector) => + WrapInContainer(selector, (query, container) => container.Boosting = query); + + /// + /// A query that matches all documents. Maps to Lucene MatchAllDocsQuery. + /// + public QueryContainer MatchAll(Func selector = null) => + WrapInContainer(selector, (query, container) => container.MatchAll = query ?? new MatchAllQuery()); + + /// + /// A query that matches no documents. This is the inverse of the match_all query. + /// + public QueryContainer MatchNone(Func selector = null) => + WrapInContainer(selector, (query, container) => container.MatchNone = query ?? new MatchNoneQuery()); + + /// + /// Matches documents that have fields that contain a term (not analyzed). + /// The term query maps to Lucene TermQuery. + /// + public QueryContainer Term(Expression> field, object value, double? boost = null, string name = null) => + Term(t => t.Field(field).Value(value).Boost(boost).Name(name)); + + /// + /// Helper method to easily filter on join relations + /// + public QueryContainer HasRelationName(Expression> field, RelationName value) => + Term(t => t.Field(field).Value(value)); + + /// Helper method to easily filter on join relations + public QueryContainer HasRelationName(Expression> field) => + Term(t => t.Field(field).Value(Infer.Relation())); + + /// + /// Matches documents that have fields that contain a term (not analyzed). + /// The term query maps to Lucene TermQuery. + /// + public QueryContainer Term(Field field, object value, double? boost = null, string name = null) => + Term(t => t.Field(field).Value(value).Boost(boost).Name(name)); + + /// + /// Matches documents that have fields that contain a term (not analyzed). + /// The term query maps to Lucene TermQuery. + /// + public QueryContainer Term(Func, ITermQuery> selector) => + WrapInContainer(selector, (query, container) => container.Term = query); + + /// + /// Matches documents that have fields matching a wildcard expression (not analyzed). + /// Supported wildcards are *, which matches any character sequence (including the empty one), and ?, + /// which matches any single character. Note this query can be slow, as it needs to iterate + /// over many terms. In order to prevent extremely slow wildcard queries, a wildcard term should + /// not start with one of the wildcards * or ?. The wildcard query maps to Lucene WildcardQuery. + /// + public QueryContainer Wildcard(Expression> field, string value, double? boost = null, + MultiTermQueryRewrite rewrite = null, + string name = null + ) => + Wildcard(t => t.Field(field).Value(value).Rewrite(rewrite).Boost(boost).Name(name)); + + /// + /// Matches documents that have fields matching a wildcard expression (not analyzed). + /// Supported wildcards are *, which matches any character sequence (including the empty one), and ?, + /// which matches any single character. Note this query can be slow, as it needs to iterate over many terms. + /// In order to prevent extremely slow wildcard queries, a wildcard term should not start with + /// one of the wildcards * or ?. The wildcard query maps to Lucene WildcardQuery. + /// + public QueryContainer Wildcard(Field field, string value, double? boost = null, MultiTermQueryRewrite rewrite = null, string name = null) => + Wildcard(t => t.Field(field).Value(value).Rewrite(rewrite).Boost(boost).Name(name)); + + /// + /// Matches documents that have fields matching a wildcard expression (not analyzed). + /// Supported wildcards are *, which matches any character sequence (including the empty one), and ?, + /// which matches any single character. Note this query can be slow, as it needs to iterate over many terms. + /// In order to prevent extremely slow wildcard queries, a wildcard term should not start with + /// one of the wildcards * or ?. The wildcard query maps to Lucene WildcardQuery. + /// + public QueryContainer Wildcard(Func, IWildcardQuery> selector) => + WrapInContainer(selector, (query, container) => container.Wildcard = query); + + /// + /// Matches documents that have fields containing terms with a specified prefix (not analyzed). + /// The prefix query maps to Lucene PrefixQuery. + /// + public QueryContainer Prefix(Expression> field, string value, double? boost = null, + MultiTermQueryRewrite rewrite = null, + string name = null + ) => + Prefix(t => t.Field(field).Value(value).Boost(boost).Rewrite(rewrite).Name(name)); + + /// + /// Matches documents that have fields containing terms with a specified prefix (not analyzed). + /// The prefix query maps to Lucene PrefixQuery. + /// + public QueryContainer Prefix(Field field, string value, double? boost = null, MultiTermQueryRewrite rewrite = null, string name = null) => + Prefix(t => t.Field(field).Value(value).Boost(boost).Rewrite(rewrite).Name(name)); + + /// + /// Matches documents that have fields containing terms with a specified prefix (not analyzed). + /// The prefix query maps to Lucene PrefixQuery. + /// + public QueryContainer Prefix(Func, IPrefixQuery> selector) => + WrapInContainer(selector, (query, container) => container.Prefix = query); + + /// + /// Matches documents that only have the provided ids. + /// Note, this filter does not require the _id field to be indexed since + /// it works using the _uid field. + /// + public QueryContainer Ids(Func selector) => + WrapInContainer(selector, (query, container) => container.Ids = query); + + /// + /// Allows fine-grained control over the order and proximity of matching terms. + /// Matching rules are constructed from a small set of definitions, + /// and the rules are then applied to terms from a particular field. + /// The definitions produce sequences of minimal intervals that span terms in a body of text. + /// These intervals can be further combined and filtered by parent sources. + /// + public QueryContainer Intervals(Func, IIntervalsQuery> selector) => + WrapInContainer(selector, (query, container) => container.Intervals = query); + + /// + public QueryContainer RankFeature(Func, IRankFeatureQuery> selector) => + WrapInContainer(selector, (query, container) => container.RankFeature = query); + + /// + /// Matches spans containing a term. The span term query maps to Lucene SpanTermQuery. + /// + public QueryContainer SpanTerm(Func, ISpanTermQuery> selector) => + WrapInContainer(selector, (query, container) => container.SpanTerm = query); + + /// + /// Matches spans near the beginning of a field. The span first query maps to Lucene SpanFirstQuery. + /// + public QueryContainer SpanFirst(Func, ISpanFirstQuery> selector) => + WrapInContainer(selector, (query, container) => container.SpanFirst = query); + + /// + /// Matches spans which are near one another. One can specify slop, the maximum number of + /// intervening unmatched positions, as well as whether matches are required to be in-order. + /// The span near query maps to Lucene SpanNearQuery. + /// + public QueryContainer SpanNear(Func, ISpanNearQuery> selector) => + WrapInContainer(selector, (query, container) => container.SpanNear = query); + + /// + /// Matches the union of its span clauses. + /// The span or query maps to Lucene SpanOrQuery. + /// + public QueryContainer SpanOr(Func, ISpanOrQuery> selector) => + WrapInContainer(selector, (query, container) => container.SpanOr = query); + + /// + /// Removes matches which overlap with another span query. + /// The span not query maps to Lucene SpanNotQuery. + /// + public QueryContainer SpanNot(Func, ISpanNotQuery> selector) => + WrapInContainer(selector, (query, container) => container.SpanNot = query); + + /// + /// Wrap a multi term query (one of fuzzy, prefix, term range or regexp query) + /// as a span query so it can be nested. + /// + public QueryContainer SpanMultiTerm(Func, ISpanMultiTermQuery> selector) => + WrapInContainer(selector, (query, container) => container.SpanMultiTerm = query); + + /// + /// Returns matches which enclose another span query. + /// The span containing query maps to Lucene SpanContainingQuery + /// + public QueryContainer SpanContaining(Func, ISpanContainingQuery> selector) => + WrapInContainer(selector, (query, container) => container.SpanContaining = query); + + /// + /// Returns Matches which are enclosed inside another span query. + /// The span within query maps to Lucene SpanWithinQuery + /// + public QueryContainer SpanWithin(Func, ISpanWithinQuery> selector) => + WrapInContainer(selector, (query, container) => container.SpanWithin = query); + + /// + /// Wraps span queries to allow them to participate in composite single-field Span queries by 'lying' about their search field. + /// That is, the masked span query will function as normal, but the field points back to the set field of the query. + /// This can be used to support queries like SpanNearQuery or SpanOrQuery across different fields, + /// which is not ordinarily permitted. + /// + public QueryContainer SpanFieldMasking(Func, ISpanFieldMaskingQuery> selector) => + WrapInContainer(selector, (query, container) => container.SpanFieldMasking = query); + + /// + /// Allows you to use regular expression term queries. + /// "term queries" means that OpenSearch will apply the regexp to the terms produced + /// by the tokenizer for that field, and not to the original text of the field. + /// + public QueryContainer Regexp(Func, IRegexpQuery> selector) => + WrapInContainer(selector, (query, container) => container.Regexp = query); + + /// + /// The function_score query allows you to modify the score of documents that are retrieved by a query. + /// This can be useful if, for example, a score function is computationally expensive and it is + /// sufficient to compute the score on a filtered set of documents. + /// + /// + public QueryContainer FunctionScore(Func, IFunctionScoreQuery> selector) => + WrapInContainer(selector, (query, container) => container.FunctionScore = query); + + public QueryContainer Script(Func, IScriptQuery> selector) => + WrapInContainer(selector, (query, container) => container.Script = query); + + public QueryContainer ScriptScore(Func, IScriptScoreQuery> selector) => + WrapInContainer(selector, (query, container) => container.ScriptScore = query); + + public QueryContainer Exists(Func, IExistsQuery> selector) => + WrapInContainer(selector, (query, container) => container.Exists = query); + + /// + /// Used to match queries stored in an index. + /// The percolate query itself contains the document that will be used as query + /// to match with the stored queries. + /// + public QueryContainer Percolate(Func, IPercolateQuery> selector) => + WrapInContainer(selector, (query, container) => container.Percolate = query); + + /// + /// Used to find child documents which belong to a particular parent. + /// + public QueryContainer ParentId(Func, IParentIdQuery> selector) => + WrapInContainer(selector, (query, container) => container.ParentId = query); + + /// + /// Returns any documents that match with at least one or more of the provided terms. + /// The terms are not analyzed and thus must match exactly. The number of terms that must match varies + /// per document and is either controlled by a minimum should match field or + /// computed per document in a minimum should match script. + /// + public QueryContainer TermsSet(Func, ITermsSetQuery> selector) => + WrapInContainer(selector, (query, container) => container.TermsSet = query); + + public QueryContainer Neural(Func, INeuralQuery> selector) => + WrapInContainer(selector, (query, container) => container.Neural = query); + } } diff --git a/src/OpenSearch.Client/QueryDsl/Query.cs b/src/OpenSearch.Client/QueryDsl/Query.cs index 84796d0636..67c023b5ae 100644 --- a/src/OpenSearch.Client/QueryDsl/Query.cs +++ b/src/OpenSearch.Client/QueryDsl/Query.cs @@ -123,6 +123,9 @@ public static QueryContainer MultiMatch(Func, IMult public static QueryContainer Nested(Func, INestedQuery> selector) => new QueryContainerDescriptor().Nested(selector); + public static QueryContainer Neural(Func, INeuralQuery> selector) => + new QueryContainerDescriptor().Neural(selector); + public static QueryContainer ParentId(Func, IParentIdQuery> selector) => new QueryContainerDescriptor().ParentId(selector); diff --git a/src/OpenSearch.Client/QueryDsl/Specialized/Neural/NeuralQuery.cs b/src/OpenSearch.Client/QueryDsl/Specialized/Neural/NeuralQuery.cs new file mode 100644 index 0000000000..f97080d694 --- /dev/null +++ b/src/OpenSearch.Client/QueryDsl/Specialized/Neural/NeuralQuery.cs @@ -0,0 +1,75 @@ +/* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. +*/ + +using System.Runtime.Serialization; +using OpenSearch.Net.Utf8Json; + +namespace OpenSearch.Client; + +/// +/// A neural query. +/// +[InterfaceDataContract] +[JsonFormatter(typeof(FieldNameQueryFormatter))] +public interface INeuralQuery : IFieldNameQuery +{ + /// + /// The query text from which to produce queries. + /// + [DataMember(Name = "query_text")] + string QueryText { get; set; } + + /// + /// The number of results the k-NN search returns. + /// + [DataMember(Name = "k")] + int? K { get; set; } + + /// + /// The ID of the model that will be used in the embedding interface. + /// The model must be indexed in OpenSearch before it can be used in Neural Search. + /// + [DataMember(Name = "model_id")] + string ModelId { get; set; } +} + +[DataContract] +public class NeuralQuery : FieldNameQueryBase, INeuralQuery +{ + /// + public string QueryText { get; set; } + /// + public int? K { get; set; } + /// + public string ModelId { get; set; } + + protected override bool Conditionless => IsConditionless(this); + + internal override void InternalWrapInContainer(IQueryContainer container) => container.Neural = this; + + internal static bool IsConditionless(INeuralQuery q) => string.IsNullOrEmpty(q.QueryText) || q.K == null || q.K == 0 || string.IsNullOrEmpty(q.ModelId) || q.Field.IsConditionless(); +} + +public class NeuralQueryDescriptor + : FieldNameQueryDescriptorBase, INeuralQuery, T>, + INeuralQuery + where T : class +{ + protected override bool Conditionless => NeuralQuery.IsConditionless(this); + string INeuralQuery.QueryText { get; set; } + int? INeuralQuery.K { get; set; } + string INeuralQuery.ModelId { get; set; } + + /// + public NeuralQueryDescriptor QueryText(string queryText) => Assign(queryText, (a, v) => a.QueryText = v); + + /// + public NeuralQueryDescriptor K(int? k) => Assign(k, (a, v) => a.K = v); + + /// + public NeuralQueryDescriptor ModelId(string modelId) => Assign(modelId, (a, v) => a.ModelId = v); +} diff --git a/src/OpenSearch.Client/QueryDsl/Visitor/DslPrettyPrintVisitor.cs b/src/OpenSearch.Client/QueryDsl/Visitor/DslPrettyPrintVisitor.cs index 2608c09ac9..82a9700a61 100644 --- a/src/OpenSearch.Client/QueryDsl/Visitor/DslPrettyPrintVisitor.cs +++ b/src/OpenSearch.Client/QueryDsl/Visitor/DslPrettyPrintVisitor.cs @@ -179,6 +179,8 @@ private void WriteShape(IGeoShape shape, IFieldLookup indexedField, Field field, public virtual void Visit(INestedQuery query) => Write("nested"); + public virtual void Visit(INeuralQuery query) => Write("neural", query.Field); + public virtual void Visit(IPrefixQuery query) => Write("prefix"); public virtual void Visit(IQueryStringQuery query) => Write("query_string"); diff --git a/src/OpenSearch.Client/QueryDsl/Visitor/QueryVisitor.cs b/src/OpenSearch.Client/QueryDsl/Visitor/QueryVisitor.cs index 4440578ab7..58bbb302e7 100644 --- a/src/OpenSearch.Client/QueryDsl/Visitor/QueryVisitor.cs +++ b/src/OpenSearch.Client/QueryDsl/Visitor/QueryVisitor.cs @@ -100,6 +100,8 @@ public interface IQueryVisitor void Visit(INestedQuery query); + void Visit(INeuralQuery query); + void Visit(IPrefixQuery query); void Visit(IQueryStringQuery query); @@ -247,6 +249,8 @@ public virtual void Visit(IMultiMatchQuery query) { } public virtual void Visit(INestedQuery query) { } + public virtual void Visit(INeuralQuery query) { } + public virtual void Visit(IPrefixQuery query) { } public virtual void Visit(IQueryStringQuery query) { } diff --git a/src/OpenSearch.Client/QueryDsl/Visitor/QueryWalker.cs b/src/OpenSearch.Client/QueryDsl/Visitor/QueryWalker.cs index 2ff147331b..5a8697dd89 100644 --- a/src/OpenSearch.Client/QueryDsl/Visitor/QueryWalker.cs +++ b/src/OpenSearch.Client/QueryDsl/Visitor/QueryWalker.cs @@ -83,6 +83,7 @@ public void Walk(IQueryContainer qd, IQueryVisitor visitor) VisitQuery(qd.Percolate, visitor, (v, d) => v.Visit(d)); VisitQuery(qd.ParentId, visitor, (v, d) => v.Visit(d)); VisitQuery(qd.TermsSet, visitor, (v, d) => v.Visit(d)); + VisitQuery(qd.Neural, visitor, (v, d) => v.Visit(d)); VisitQuery(qd.Bool, visitor, (v, d) => { diff --git a/tests/Tests.Core/ManagedOpenSearch/Clusters/WritableCluster.cs b/tests/Tests.Core/ManagedOpenSearch/Clusters/WritableCluster.cs index 1f8457c82f..320aab24ea 100644 --- a/tests/Tests.Core/ManagedOpenSearch/Clusters/WritableCluster.cs +++ b/tests/Tests.Core/ManagedOpenSearch/Clusters/WritableCluster.cs @@ -47,10 +47,12 @@ private static ClientTestClusterConfiguration CreateConfiguration() => AnalysisIcu, AnalysisKuromoji, AnalysisNori, AnalysisPhonetic, IngestAttachment, IngestGeoIp, Knn, + MachineLearning, MapperMurmur3, Security) { - MaxConcurrency = 4 + MaxConcurrency = 4, + ValidatePluginsToInstall = false }; protected override void SeedNode() diff --git a/tests/Tests/Ingest/ProcessorAssertions.cs b/tests/Tests/Ingest/ProcessorAssertions.cs index 74f020a865..bf615bf766 100644 --- a/tests/Tests/Ingest/ProcessorAssertions.cs +++ b/tests/Tests/Ingest/ProcessorAssertions.cs @@ -30,6 +30,7 @@ using System.Collections.Generic; using System.Linq; using System.Reflection; +using JetBrains.Annotations; using OpenSearch.OpenSearch.Xunit.XunitPlumbing; using OpenSearch.Client; using Tests.Core.Client; @@ -62,11 +63,21 @@ public abstract class ProcessorAssertion : IProcessorAssertion public static class ProcessorAssertions { public static IEnumerable All => - from t in typeof(ProcessorAssertions).GetNestedTypes() - where typeof(IProcessorAssertion).IsAssignableFrom(t) && t.IsClass - let a = t.GetCustomAttributes(typeof(SkipVersionAttribute)).FirstOrDefault() as SkipVersionAttribute - where a == null || !a.Ranges.Any(r => r.IsSatisfied(TestClient.Configuration.OpenSearchVersion)) - select (IProcessorAssertion)Activator.CreateInstance(t); + typeof(ProcessorAssertions).GetNestedTypes() + .Where(t => + { + if (!t.IsClass || !typeof(IProcessorAssertion).IsAssignableFrom(t)) return false; + + var skipVersion = t.GetCustomAttributes().FirstOrDefault(); + if (skipVersion != null && skipVersion.Ranges.Any(r => r.IsSatisfied(TestClient.Configuration.OpenSearchVersion))) + return false; + + var skipPrereleases = t.GetCustomAttributes().FirstOrDefault(); + if (skipPrereleases != null && TestClient.Configuration.OpenSearchVersion.IsPreRelease) return false; + + return true; + }) + .Select(t => (IProcessorAssertion)Activator.CreateInstance(t)); public static IProcessor[] Initializers => All.Select(a => a.Initializer).ToArray(); @@ -592,5 +603,45 @@ public class Pipeline : ProcessorAssertion public override string Key => "pipeline"; } + + [SkipVersion("<2.4.0", "neural search plugin was released with v2.4.0")] + [SkipPrereleaseVersions("Prerelease versions of OpenSearch do not include the ML & Neural Search plugins")] + public class TextEmbedding : ProcessorAssertion + { + private class NeuralSearchDoc + { + [PropertyName("text")] + public string Text { get; set; } + + [PropertyName("passage_embedding")] + public float[] PassageEmbedding { get; set; } + } + + public override ProcFunc Fluent => d => d + .TextEmbedding(te => te + .ModelId("someModel-abcdef") + .FieldMap(f => f + .Map(doc => doc.Text, doc => doc.PassageEmbedding))); + + public override IProcessor Initializer => new TextEmbeddingProcessor + { + ModelId = "someModel-abcdef", + FieldMap = new InferenceFieldMap + { + {new Field((NeuralSearchDoc d) => d.Text), new Field((NeuralSearchDoc d) => d.PassageEmbedding)} + } + }; + + public override object Json => new + { + model_id = "someModel-abcdef", + field_map = new + { + text = "passage_embedding" + } + }; + + public override string Key => "text_embedding"; + } } } diff --git a/tests/Tests/Ingest/PutPipeline/PutPipelineApiTests.cs b/tests/Tests/Ingest/PutPipeline/PutPipelineApiTests.cs index 3d064a9d4e..53c06a61ed 100644 --- a/tests/Tests/Ingest/PutPipeline/PutPipelineApiTests.cs +++ b/tests/Tests/Ingest/PutPipeline/PutPipelineApiTests.cs @@ -55,7 +55,7 @@ public PutPipelineApiTests(WritableCluster cluster, EndpointUsage usage) : base( processors = ProcessorAssertions.AllAsJson }; -protected override int ExpectStatusCode => 200; + protected override int ExpectStatusCode => 200; protected override Func Fluent => d => d .Description("My test pipeline") diff --git a/tests/Tests/QueryDsl/QueryDslUsageTestsBase.cs b/tests/Tests/QueryDsl/QueryDslUsageTestsBase.cs index eaeaff5800..b2cd56a7d6 100644 --- a/tests/Tests/QueryDsl/QueryDslUsageTestsBase.cs +++ b/tests/Tests/QueryDsl/QueryDslUsageTestsBase.cs @@ -34,6 +34,7 @@ using FluentAssertions; using OpenSearch.Client; using Newtonsoft.Json; +using OpenSearch.OpenSearch.Ephemeral; using Tests.Core.Client; using Tests.Core.Extensions; using Tests.Core.ManagedOpenSearch.Clusters; @@ -41,105 +42,111 @@ using Tests.Framework.EndpointTests; using Tests.Framework.EndpointTests.TestState; -namespace Tests.QueryDsl +namespace Tests.QueryDsl; + +public abstract class QueryDslUsageTestsBase + : ApiTestBase, ISearchRequest, SearchDescriptor, SearchRequest> + where TCluster : IEphemeralCluster, IOpenSearchClientTestCluster, new() + where TDocument : class { - public abstract class QueryDslUsageTestsBase - : ApiTestBase, ISearchRequest, SearchDescriptor, SearchRequest> - { - protected readonly QueryContainer ConditionlessQuery = new QueryContainer(new TermQuery()); + protected QueryDslUsageTestsBase(TCluster cluster, EndpointUsage usage) : base(cluster, usage) { } + + protected abstract IndexName IndexName { get; } + protected abstract string ExpectedIndexString { get; } - protected readonly QueryContainer VerbatimQuery = new QueryContainer(new TermQuery { IsVerbatim = true }); + protected virtual ConditionlessWhen ConditionlessWhen => null; - protected byte[] ShortFormQuery => Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(new { description = "project description" })); + protected override object ExpectJson => new { query = QueryJson }; - protected QueryDslUsageTestsBase(ReadOnlyCluster cluster, EndpointUsage usage) : base(cluster, usage) { } + protected override Func, ISearchRequest> Fluent => s => s + .Index(IndexName) + .Query(QueryFluent); - protected virtual ConditionlessWhen ConditionlessWhen => null; + protected override HttpMethod HttpMethod => HttpMethod.POST; - protected override object ExpectJson => new { query = QueryJson }; + protected override SearchRequest Initializer => + new(IndexName) + { + Query = QueryInitializer + }; - protected override Func, ISearchRequest> Fluent => s => s - .Query(q => QueryFluent(q)); + protected virtual NotConditionlessWhen NotConditionlessWhen => null; - protected override HttpMethod HttpMethod => HttpMethod.POST; + protected abstract QueryContainer QueryInitializer { get; } - protected override SearchRequest Initializer => - new SearchRequest - { - Query = QueryInitializer - }; + protected abstract object QueryJson { get; } + protected override string UrlPath => $"/{ExpectedIndexString}/_search"; - protected virtual bool KnownParseException => false; + protected override LazyResponses ClientUsage() => Calls( + (client, f) => client.Search(f), + (client, f) => client.SearchAsync(f), + (client, r) => client.Search(r), + (client, r) => client.SearchAsync(r) + ); - protected virtual NotConditionlessWhen NotConditionlessWhen => null; + protected abstract QueryContainer QueryFluent(QueryContainerDescriptor q); - protected abstract QueryContainer QueryInitializer { get; } + [U] public void FluentIsNotConditionless() => + AssertIsNotConditionless(QueryFluent(new QueryContainerDescriptor())); - protected abstract object QueryJson { get; } - protected override string UrlPath => "/project/_search"; + [U] public void InitializerIsNotConditionless() => AssertIsNotConditionless(QueryInitializer); - protected override LazyResponses ClientUsage() => Calls( - (client, f) => client.Search(f), - (client, f) => client.SearchAsync(f), - (client, r) => client.Search(r), - (client, r) => client.SearchAsync(r) - ); + private void AssertIsNotConditionless(IQueryContainer c) + { + if (!c.IsVerbatim) + c.IsConditionless.Should().BeFalse(); + } - protected abstract QueryContainer QueryFluent(QueryContainerDescriptor q); + [U] public void SeenByVisitor() + { + var visitor = new DslPrettyPrintVisitor(TestClient.DefaultInMemoryClient.ConnectionSettings); + var query = QueryFluent(new QueryContainerDescriptor()); + query.Should().NotBeNull("query evaluated to null which implies it may be conditionless"); + query.Accept(visitor); + var pretty = visitor.PrettyPrint; + pretty.Should().NotBeNullOrWhiteSpace(); + } - [U] public void FluentIsNotConditionless() => - AssertIsNotConditionless(QueryFluent(new QueryContainerDescriptor())); + [U] public void ConditionlessWhenExpectedToBe() + { + if (ConditionlessWhen == null) return; - [U] public void InitializerIsNotConditionless() => AssertIsNotConditionless(QueryInitializer); + foreach (var when in ConditionlessWhen) + { + when(QueryFluent(new QueryContainerDescriptor())); + when(QueryInitializer); + } - private void AssertIsNotConditionless(IQueryContainer c) - { - if (!c.IsVerbatim) - c.IsConditionless.Should().BeFalse(); - } + ((IQueryContainer)QueryInitializer).IsConditionless.Should().BeFalse(); + } - [U] public void SeenByVisitor() - { - var visitor = new DslPrettyPrintVisitor(TestClient.DefaultInMemoryClient.ConnectionSettings); - var query = QueryFluent(new QueryContainerDescriptor()); - query.Should().NotBeNull("query evaluated to null which implies it may be conditionless"); - query.Accept(visitor); - var pretty = visitor.PrettyPrint; - pretty.Should().NotBeNullOrWhiteSpace(); - } + [U] public void NotConditionlessWhenExpectedToBe() + { + if (NotConditionlessWhen == null) return; - [U] public void ConditionlessWhenExpectedToBe() - { - if (ConditionlessWhen == null) return; + foreach (var when in NotConditionlessWhen) + { + when(QueryFluent(new QueryContainerDescriptor())); + when(QueryInitializer); + } + } - foreach (var when in ConditionlessWhen) - { - when(QueryFluent(new QueryContainerDescriptor())); - //this.JsonEquals(query, new { }); - when(QueryInitializer); - //this.JsonEquals(query, new { }); - } + [I] protected async Task AssertQueryResponse() => await AssertOnAllResponses(AssertQueryResponseValid); - ((IQueryContainer)QueryInitializer).IsConditionless.Should().BeFalse(); - } + protected virtual void AssertQueryResponseValid(ISearchResponse response) => response.ShouldBeValid(); +} + +public abstract class QueryDslUsageTestsBase + : QueryDslUsageTestsBase +{ + protected static byte[] ShortFormQuery => Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(new { description = "project description" })); - [U] public void NotConditionlessWhenExpectedToBe() - { - if (NotConditionlessWhen == null) return; + protected static readonly QueryContainer ConditionlessQuery = new(new TermQuery()); - foreach (var when in NotConditionlessWhen) - { - var query = QueryFluent(new QueryContainerDescriptor()); - when(query); + protected static readonly QueryContainer VerbatimQuery = new(new TermQuery { IsVerbatim = true }); - query = QueryInitializer; - when(query); - } - } + protected QueryDslUsageTestsBase(ReadOnlyCluster cluster, EndpointUsage usage) : base(cluster, usage) { } - [I] protected async Task AssertQueryResponse() => await AssertOnAllResponses(r => - { - r.ShouldBeValid(); - }); - } + protected override IndexName IndexName => typeof(Project); + protected override string ExpectedIndexString => "project"; } diff --git a/tests/Tests/QueryDsl/Specialized/Knn/KnnQueryUsageTests.cs b/tests/Tests/QueryDsl/Specialized/Knn/KnnQueryUsageTests.cs index 8150004a20..98d20ac6e4 100644 --- a/tests/Tests/QueryDsl/Specialized/Knn/KnnQueryUsageTests.cs +++ b/tests/Tests/QueryDsl/Specialized/Knn/KnnQueryUsageTests.cs @@ -6,7 +6,6 @@ */ using System; -using System.Linq; using System.Threading.Tasks; using FluentAssertions; using OpenSearch.Client; diff --git a/tests/Tests/QueryDsl/Specialized/Neural/NeuralQueryUsageTests.cs b/tests/Tests/QueryDsl/Specialized/Neural/NeuralQueryUsageTests.cs new file mode 100644 index 0000000000..28745bc74c --- /dev/null +++ b/tests/Tests/QueryDsl/Specialized/Neural/NeuralQueryUsageTests.cs @@ -0,0 +1,295 @@ +/* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. +*/ + +using System; +using System.Linq; +using System.Threading; +using FluentAssertions; +using OpenSearch.Client; +using OpenSearch.Net; +using OpenSearch.OpenSearch.Xunit.XunitPlumbing; +using OpenSearch.Stack.ArtifactsApi.Products; +using Tests.Core.Extensions; +using Tests.Core.ManagedOpenSearch.Clusters; +using Tests.Framework.EndpointTests.TestState; +using Version = SemanticVersioning.Version; + +namespace Tests.QueryDsl.Specialized.Neural; + +public class NeuralQueryCluster : ClientTestClusterBase +{ + public NeuralQueryCluster() : base(CreateConfiguration()) { } + + private static ClientTestClusterConfiguration CreateConfiguration() + { + var config = new ClientTestClusterConfiguration( + OpenSearchPlugin.Knn, + OpenSearchPlugin.MachineLearning, + OpenSearchPlugin.NeuralSearch, + OpenSearchPlugin.Security) + { + MaxConcurrency = 4, + ValidatePluginsToInstall = false, + }; + + config.DefaultNodeSettings.Add("plugins.ml_commons.only_run_on_ml_node", "false"); + config.DefaultNodeSettings.Add("plugins.ml_commons.native_memory_threshold", "99"); + config.DefaultNodeSettings.Add("plugins.ml_commons.model_access_control_enabled", "true", ">=2.8.0"); + + return config; + } +} + +public class NeuralSearchDoc +{ + [PropertyName("id")] public string Id { get; set; } + [PropertyName("text")] public string Text { get; set; } + [PropertyName("passage_embedding")] public float[] PassageEmbedding { get; set; } +} + +[SkipVersion("<2.6.0", "Avoid the various early permutations of the ML APIs")] +public class NeuralQueryUsageTests + : QueryDslUsageTestsBase +{ + private static readonly string TestName = nameof(NeuralQueryUsageTests).ToLowerInvariant(); + + private string _modelGroupId; + private string _modelId = "default-for-unit-tests"; + + public NeuralQueryUsageTests(NeuralQueryCluster cluster, EndpointUsage usage) : base(cluster, usage) { } + + protected override IndexName IndexName => TestName; + protected override string ExpectedIndexString => TestName; + + protected override QueryContainer QueryInitializer => new NeuralQuery + { + Field = Infer.Field(d => d.PassageEmbedding), + QueryText = "wild west", + K = 5, + ModelId = _modelId + }; + + protected override object QueryJson => new + { + neural = new + { + passage_embedding = new + { + query_text = "wild west", + k = 5, + model_id = _modelId + } + } + }; + + protected override QueryContainer QueryFluent(QueryContainerDescriptor q) => q + .Neural(n => n + .Field(f => f.PassageEmbedding) + .QueryText("wild west") + .K(5) + .ModelId(_modelId)); + + protected override ConditionlessWhen ConditionlessWhen => new ConditionlessWhen(a => a.Neural) + { + q => + { + q.Field = null; + q.QueryText = "wild west"; + q.K = 5; + q.ModelId = "aFcV879"; + }, + q => + { + q.Field = "passage_embedding"; + q.QueryText = null; + q.K = 5; + q.ModelId = "aFcV879"; + }, + q => + { + q.Field = "passage_embedding"; + q.QueryText = ""; + q.K = 5; + q.ModelId = "aFcV879"; + }, + q => + { + q.Field = "passage_embedding"; + q.QueryText = "wild west"; + q.K = null; + q.ModelId = "aFcV879"; + }, + q => + { + q.Field = "passage_embedding"; + q.QueryText = "wild west"; + q.K = 0; + q.ModelId = "aFcV879"; + }, + q => + { + q.Field = "passage_embedding"; + q.QueryText = "wild west"; + q.K = 5; + q.ModelId = null; + }, + q => + { + q.Field = "passage_embedding"; + q.QueryText = "wild west"; + q.K = 5; + q.ModelId = ""; + } + }; + + protected override void IntegrationSetup(IOpenSearchClient client, CallUniqueValues values) + { + var baseVersion = Cluster.ClusterConfiguration.Version.BaseVersion(); + var renamedToRegisterDeploy = baseVersion >= new Version("2.7.0"); + var hasModelAccessControl = baseVersion >= new Version("2.8.0"); + + if (hasModelAccessControl) + { + var registerModelGroupResp = client.Http.Post( + "/_plugins/_ml/model_groups/_register", + r => r.SerializableBody(new + { + name = TestName, + access_mode = "public", + model_access_mode = "public" + })); + registerModelGroupResp.ShouldBeCreated(); + _modelGroupId = (string)registerModelGroupResp.Body.model_group_id; + } + + var registerModelResp = client.Http.Post( + $"/_plugins/_ml/models/{(renamedToRegisterDeploy ? "_register" : "_upload")}", + r => r.SerializableBody(new + { + name = "huggingface/sentence-transformers/msmarco-distilbert-base-tas-b", + version = "1.0.1", + model_group_id = _modelGroupId, + model_format = "TORCH_SCRIPT" + })); + registerModelResp.ShouldBeCreated(); + var modelRegistrationTaskId = (string) registerModelResp.Body.task_id; + + while (true) + { + var getTaskResp = client.Http.Get($"/_plugins/_ml/tasks/{modelRegistrationTaskId}"); + getTaskResp.ShouldNotBeFailed(); + if (((string)getTaskResp.Body.state).StartsWith("COMPLETED")) + { + _modelId = getTaskResp.Body.model_id; + break; + } + Thread.Sleep(5000); + } + + var deployModelResp = client.Http.Post($"/_plugins/_ml/models/{_modelId}/{(renamedToRegisterDeploy ? "_deploy" : "_load")}"); + deployModelResp.ShouldBeCreated(); + var modelDeployTaskId = (string) deployModelResp.Body.task_id; + + while (true) + { + var getTaskResp = client.Http.Get($"/_plugins/_ml/tasks/{modelDeployTaskId}"); + getTaskResp.ShouldNotBeFailed(); + if (((string)getTaskResp.Body.state).StartsWith("COMPLETED")) break; + Thread.Sleep(5000); + } + + var putIngestPipelineResp = client.Ingest.PutPipeline(TestName, p => p + .Processors(pp => pp + .TextEmbedding(te => te + .ModelId(_modelId) + .FieldMap(fm => fm + .Map(d => d.Text, d => d.PassageEmbedding))))); + putIngestPipelineResp.ShouldBeValid(); + + var createIndexResp = client.Indices.Create( + IndexName, + i => i + .Settings(s => s + .Setting("index.knn", true) + .DefaultPipeline(TestName)) + .Map(m => m + .Properties(p => p + .Text(t => t.Name(d => d.Id)) + .Text(t => t.Name(d => d.Text)) + .KnnVector(k => k + .Name(d => d.PassageEmbedding) + .Dimension(768) + .Method(km => km + .Engine("lucene") + .SpaceType("l2") + .Name("hnsw")))))); + createIndexResp.ShouldBeValid(); + + var documents = new NeuralSearchDoc[] + { + new() { Id = "4319130149.jpg", Text = "A West Virginia university women 's basketball team , officials , and a small gathering of fans are in a West Virginia arena ." }, + new() { Id = "1775029934.jpg", Text = "A wild animal races across an uncut field with a minimal amount of trees ." }, + new() { Id = "2664027527.jpg", Text = "People line the stands which advertise Freemont 's orthopedics , a cowboy rides a light brown bucking bronco ." }, + new() { Id = "4427058951.jpg", Text = "A man who is riding a wild horse in the rodeo is very near to falling off ." }, + new() { Id = "2691147709.jpg", Text = "A rodeo cowboy , wearing a cowboy hat , is being thrown off of a wild white horse ." } + }; + var bulkResp = client.Bulk(b => b + .Index(IndexName) + .IndexMany(documents) + .Refresh(Refresh.WaitFor)); + bulkResp.ShouldBeValid(); + } + + protected override void AssertQueryResponseValid(ISearchResponse response) + { + base.AssertQueryResponseValid(response); + + response.Hits.Should().HaveCount(5); + var hit = response.Hits.First(); + + hit.Id.Should().Be("4427058951.jpg"); + hit.Score.Should().BeApproximately(0.01585195, 0.00000001); + hit.Source.Text.Should().Be("A man who is riding a wild horse in the rodeo is very near to falling off ."); + hit.Source.PassageEmbedding.Should().HaveCount(768); + } + + protected override void IntegrationTeardown(IOpenSearchClient client, CallUniqueValues values) + { + client.Indices.Delete(IndexName); + client.Ingest.DeletePipeline(TestName); + + if (_modelId != "default-for-unit-tests") + { + while (true) + { + var deleteModelResp = client.Http.Delete($"/_plugins/_ml/models/{_modelId}"); + if (deleteModelResp.Success || !(((string)deleteModelResp.Body.error?.reason)?.Contains("Try undeploy") ?? false)) break; + + client.Http.Post($"/_plugins/_ml/models/{_modelId}/_undeploy"); + Thread.Sleep(5000); + } + } + + if (_modelGroupId != null) + { + client.Http.Delete($"/_plugins/_ml/model_groups/{_modelGroupId}"); + } + } +} + +internal static class Helpers +{ + public static void ShouldBeCreated(this DynamicResponse r) + { + if (!r.Success || r.Body.status != "CREATED") throw new Exception("Expected to be created, was: " + r.DebugInformation); + } + + public static void ShouldNotBeFailed(this DynamicResponse r) + { + if (!r.Success || r.Body.state == "FAILED") throw new Exception("Expected to not be failed, was: " + r.DebugInformation); + } +}