From 4dd418dd05a233659ba4f3cec4d7fc12f250575a Mon Sep 17 00:00:00 2001 From: Kim Pepper Date: Wed, 8 Jan 2025 17:26:15 +1100 Subject: [PATCH] Fix ML tests Signed-off-by: Kim Pepper --- ...gNamespaceTest.php => MlNamespaceTest.php} | 233 ++++++++---------- util/EndpointProxies/ml/getConnectorProxy.php | 2 + util/EndpointProxies/ml/predictProxy.php | 11 +- .../ml/updateModelGroupProxy.php | 2 + 4 files changed, 114 insertions(+), 134 deletions(-) rename tests/Namespaces/{MachineLearningNamespaceTest.php => MlNamespaceTest.php} (63%) diff --git a/tests/Namespaces/MachineLearningNamespaceTest.php b/tests/Namespaces/MlNamespaceTest.php similarity index 63% rename from tests/Namespaces/MachineLearningNamespaceTest.php rename to tests/Namespaces/MlNamespaceTest.php index 9001bfe9e..85cfbebac 100644 --- a/tests/Namespaces/MachineLearningNamespaceTest.php +++ b/tests/Namespaces/MlNamespaceTest.php @@ -1,29 +1,28 @@ method('getEndpoint') ->willReturn(new CreateConnector()); - $transport = $this->createMock(Transport::class); - - $transport->method('performRequest') - ->with('POST', '/_plugins/_ml/connectors/_create', [], [ - 'foo' => 'bar', - ]); + $transport = $this->createMock(TransportInterface::class); - $transport->method('resultOrFuture') - ->willReturn([]); + $transport->method('sendRequest') + ->with('POST', '/_plugins/_ml/connectors/_create', [], [ + 'foo' => 'bar', + ]) + ->willReturn([]); (new MlNamespace($transport, $endpointFactory))->createConnector([ 'body' => [ @@ -67,16 +64,15 @@ public function testGetConnector(): void $endpointFactory->method('getEndpoint') ->willReturn(new GetConnector()); - $transport = $this->createMock(Transport::class); + $transport = $this->createMock(TransportInterface::class); - $transport->method('performRequest') - ->with('GET', '/_plugins/_ml/connectors/foobar', [], null); - - $transport->method('resultOrFuture') - ->willReturn([]); + $transport->method('sendRequest') + ->with('GET', '/_plugins/_ml/connectors/foobar', [], null) + ->willReturn([]); (new MlNamespace($transport, $endpointFactory))->getConnector([ - 'id' => 'foobar' + 'id' => 'foobar', + 'connector_id' => 'foobar' ]); } @@ -87,18 +83,16 @@ public function testGetConnectors(): void $endpointFactory->method('getEndpoint') ->willReturn(new GetConnectors()); - $transport = $this->createMock(Transport::class); + $transport = $this->createMock(TransportInterface::class); - $transport->method('performRequest') + $transport->method('sendRequest') ->with('POST', '/_plugins/_ml/connectors/_search', [], [ 'query' => [ 'match_all' => new \StdClass(), ], 'size' => 1000, - ]); - - $transport->method('resultOrFuture') - ->willReturn([]); + ]) + ->willReturn([]); (new MlNamespace($transport, $endpointFactory))->getConnectors([ 'body' => [ @@ -117,13 +111,11 @@ public function testDeleteConnector(): void $endpointFactory->method('getEndpoint') ->willReturn(new DeleteConnector()); - $transport = $this->createMock(Transport::class); + $transport = $this->createMock(TransportInterface::class); - $transport->method('performRequest') - ->with('DELETE', '/_plugins/_ml/connectors/foobar', [], null); - - $transport->method('resultOrFuture') - ->willReturn([]); + $transport->method('sendRequest') + ->with('DELETE', '/_plugins/_ml/connectors/foobar', [], null) + ->willReturn([]); (new MlNamespace($transport, $endpointFactory))->deleteConnector([ 'connector_id' => 'foobar' @@ -137,15 +129,13 @@ public function testRegisterModelGroup(): void $endpointFactory->method('getEndpoint') ->willReturn(new RegisterModelGroup()); - $transport = $this->createMock(Transport::class); - - $transport->method('performRequest') - ->with('POST', '/_plugins/_ml/model_groups/_register', [], [ - 'foo' => 'bar', - ]); + $transport = $this->createMock(TransportInterface::class); - $transport->method('resultOrFuture') - ->willReturn([]); + $transport->method('sendRequest') + ->with('POST', '/_plugins/_ml/model_groups/_register', [], [ + 'foo' => 'bar', + ]) + ->willReturn([]); (new MlNamespace($transport, $endpointFactory))->registerModelGroup([ 'body' => [ @@ -161,18 +151,16 @@ public function testGetModelGroups(): void $endpointFactory->method('getEndpoint') ->willReturn(new GetModelGroups()); - $transport = $this->createMock(Transport::class); + $transport = $this->createMock(TransportInterface::class); - $transport->method('performRequest') - ->with('POST', '/_plugins/_ml/model_groups/_search', [], [ - 'query' => [ - 'match_all' => new \StdClass(), - ], - 'size' => 1000, - ]); - - $transport->method('resultOrFuture') - ->willReturn([]); + $transport->method('sendRequest') + ->with('POST', '/_plugins/_ml/model_groups/_search', [], [ + 'query' => [ + 'match_all' => new \StdClass(), + ], + 'size' => 1000, + ]) + ->willReturn([]); (new MlNamespace($transport, $endpointFactory))->getModelGroups([ 'body' => [ @@ -191,21 +179,20 @@ public function testUpdateModelGroup(): void $endpointFactory->method('getEndpoint') ->willReturn(new UpdateModelGroup()); - $transport = $this->createMock(Transport::class); + $transport = $this->createMock(TransportInterface::class); - $transport->method('performRequest') - ->with('PUT', '/_plugins/_ml/model_groups/foobar', [], [ - 'query' => [ - 'match_all' => new \StdClass(), - ], - 'size' => 1000, - ]); - - $transport->method('resultOrFuture') - ->willReturn([]); + $transport->method('sendRequest') + ->with('PUT', '/_plugins/_ml/model_groups/foobar', [], [ + 'query' => [ + 'match_all' => new \StdClass(), + ], + 'size' => 1000, + ]) + ->willReturn([]); (new MlNamespace($transport, $endpointFactory))->updateModelGroup([ 'id' => 'foobar', + 'model_group_id' => 'foobar', 'body' => [ 'query' => [ 'match_all' => new \StdClass(), @@ -222,13 +209,11 @@ public function testDeleteModelGroup(): void $endpointFactory->method('getEndpoint') ->willReturn(new DeleteModelGroup()); - $transport = $this->createMock(Transport::class); + $transport = $this->createMock(TransportInterface::class); - $transport->method('performRequest') - ->with('DELETE', '/_plugins/_ml/model_groups/foobar', [], null); - - $transport->method('resultOrFuture') - ->willReturn([]); + $transport->method('sendRequest') + ->with('DELETE', '/_plugins/_ml/model_groups/foobar', [], null) + ->willReturn([]); (new MlNamespace($transport, $endpointFactory))->deleteModelGroup([ 'id' => 'foobar' @@ -242,15 +227,13 @@ public function testRegisterModel(): void $endpointFactory->method('getEndpoint') ->willReturn(new RegisterModel()); - $transport = $this->createMock(Transport::class); - - $transport->method('performRequest') - ->with('POST', '/_plugins/_ml/models/_register', [], [ - 'foo' => 'bar', - ]); + $transport = $this->createMock(TransportInterface::class); - $transport->method('resultOrFuture') - ->willReturn([]); + $transport->method('sendRequest') + ->with('POST', '/_plugins/_ml/models/_register', [], [ + 'foo' => 'bar', + ]) + ->willReturn([]); (new MlNamespace($transport, $endpointFactory))->registerModel([ 'body' => [ @@ -265,13 +248,11 @@ public function testGetModel(): void $endpointFactory->method('getEndpoint') ->willReturn(new GetModel()); - $transport = $this->createMock(Transport::class); + $transport = $this->createMock(TransportInterface::class); - $transport->method('performRequest') - ->with('GET', '/_plugins/_ml/models/foobar', [], null); - - $transport->method('resultOrFuture') - ->willReturn([]); + $transport->method('sendRequest') + ->with('GET', '/_plugins/_ml/models/foobar', [], null) + ->willReturn([]); (new MlNamespace($transport, $endpointFactory))->getModel([ 'id' => 'foobar', @@ -284,18 +265,16 @@ public function testSearchModels(): void $endpointFactory->method('getEndpoint') ->willReturn(new SearchModels()); - $transport = $this->createMock(Transport::class); + $transport = $this->createMock(TransportInterface::class); - $transport->method('performRequest') + $transport->method('sendRequest') ->with('GET', '/_plugins/_ml/models/_search', [], [ 'query' => [ 'match_all' => new \StdClass(), ], 'size' => 1000, - ]); - - $transport->method('resultOrFuture') - ->willReturn([]); + ]) + ->willReturn([]); (new MlNamespace($transport, $endpointFactory))->searchModels([ 'body' => [ @@ -313,13 +292,11 @@ public function testDeployModel(): void $endpointFactory->method('getEndpoint') ->willReturn(new DeployModel()); - $transport = $this->createMock(Transport::class); - - $transport->method('performRequest') - ->with('POST', '/_plugins/_ml/models/foobar/_deploy', [], null); + $transport = $this->createMock(TransportInterface::class); - $transport->method('resultOrFuture') - ->willReturn([]); + $transport->method('sendRequest') + ->with('POST', '/_plugins/_ml/models/foobar/_deploy', [], null) + ->willReturn([]); (new MlNamespace($transport, $endpointFactory))->deployModel([ 'model_id' => 'foobar', @@ -332,13 +309,11 @@ public function testUnDeployModel(): void $endpointFactory->method('getEndpoint') ->willReturn(new UndeployModel()); - $transport = $this->createMock(Transport::class); + $transport = $this->createMock(TransportInterface::class); - $transport->method('performRequest') - ->with('POST', '/_plugins/_ml/models/foobar/_undeploy', [], null); - - $transport->method('resultOrFuture') - ->willReturn([]); + $transport->method('sendRequest') + ->with('POST', '/_plugins/_ml/models/foobar/_undeploy', [], null) + ->willReturn([]); (new MlNamespace($transport, $endpointFactory))->undeployModel([ 'model_id' => 'foobar', @@ -351,13 +326,11 @@ public function testDeleteModel(): void $endpointFactory->method('getEndpoint') ->willReturn(new DeleteModel()); - $transport = $this->createMock(Transport::class); - - $transport->method('performRequest') - ->with('DELETE', '/_plugins/_ml/models/foobar', [], null); + $transport = $this->createMock(TransportInterface::class); - $transport->method('resultOrFuture') - ->willReturn([]); + $transport->method('sendRequest') + ->with('DELETE', '/_plugins/_ml/models/foobar', [], null) + ->willReturn([]); (new MlNamespace($transport, $endpointFactory))->deleteModel([ 'id' => 'foobar', @@ -370,21 +343,21 @@ public function testPredict(): void $endpointFactory->method('getEndpoint') ->willReturn(new Predict()); - $transport = $this->createMock(Transport::class); - - $transport->method('performRequest') - ->with('POST', '/_plugins/_ml/models/foobar/_predict', [], [ - 'foo' => 'bar', - ]); + $transport = $this->createMock(TransportInterface::class); - $transport->method('resultOrFuture') - ->willReturn([]); + $transport->method('sendRequest') + ->with('POST', '/_plugins/_ml/_predict/algo/model', [], [ + 'foo' => 'bar', + ]) + ->willReturn([]); (new MlNamespace($transport, $endpointFactory))->predict([ - 'id' => 'foobar', - 'body' => [ - 'foo' => 'bar', - ] + 'id' => 'foobar', + 'body' => [ + 'foo' => 'bar', + ], + 'algorithm_name' => 'algo', + 'model_id' => 'model', ]); } @@ -394,13 +367,11 @@ public function testGetTask(): void $endpointFactory->method('getEndpoint') ->willReturn(new GetTask()); - $transport = $this->createMock(Transport::class); - - $transport->method('performRequest') - ->with('GET', '/_plugins/_ml/tasks/foobar', [], null); + $transport = $this->createMock(TransportInterface::class); - $transport->method('resultOrFuture') - ->willReturn([]); + $transport->method('sendRequest') + ->with('GET', '/_plugins/_ml/tasks/foobar', [], null) + ->willReturn([]); (new MlNamespace($transport, $endpointFactory))->getTask([ 'id' => 'foobar', diff --git a/util/EndpointProxies/ml/getConnectorProxy.php b/util/EndpointProxies/ml/getConnectorProxy.php index 739adfc50..a89d3034a 100644 --- a/util/EndpointProxies/ml/getConnectorProxy.php +++ b/util/EndpointProxies/ml/getConnectorProxy.php @@ -13,9 +13,11 @@ public function getConnector(array $params = []): array { $id = $this->extractArgument($params, 'id'); + $connector_id = $this->extractArgument($params, 'connector_id'); $endpoint = $this->endpointFactory->getEndpoint(\OpenSearch\Endpoints\Ml\GetConnector::class); $endpoint->setParams($params); $endpoint->setId($id); + $endpoint->setConnectorId($connector_id); return $this->performRequest($endpoint); } diff --git a/util/EndpointProxies/ml/predictProxy.php b/util/EndpointProxies/ml/predictProxy.php index cc0e8f7d0..2fedfe1f2 100644 --- a/util/EndpointProxies/ml/predictProxy.php +++ b/util/EndpointProxies/ml/predictProxy.php @@ -15,10 +15,15 @@ public function predict(array $params = []): array { $id = $this->extractArgument($params, 'id'); $body = $this->extractArgument($params, 'body'); + $algorithm_name = $this->extractArgument($params, 'algorithm_name'); + $model_id = $this->extractArgument($params, 'model_id'); + $endpoint = $this->endpointFactory->getEndpoint(\OpenSearch\Endpoints\Ml\Predict::class); - $endpoint->setParams($params); - $endpoint->setId($id); - $endpoint->setBody($body); + $endpoint->setParams($params) + ->setId($id) + ->setBody($body) + ->setAlgorithmName($algorithm_name) + ->setModelId($model_id); return $this->performRequest($endpoint); } diff --git a/util/EndpointProxies/ml/updateModelGroupProxy.php b/util/EndpointProxies/ml/updateModelGroupProxy.php index aad4c72ad..e3eb65770 100644 --- a/util/EndpointProxies/ml/updateModelGroupProxy.php +++ b/util/EndpointProxies/ml/updateModelGroupProxy.php @@ -14,11 +14,13 @@ public function updateModelGroup(array $params = []): array { $id = $this->extractArgument($params, 'id'); + $model_group_id = $this->extractArgument($params, 'model_group_id'); $body = $this->extractArgument($params, 'body'); $endpoint = $this->endpointFactory->getEndpoint(\OpenSearch\Endpoints\Ml\UpdateModelGroup::class); $endpoint->setParams($params); $endpoint->setBody($body); $endpoint->setId($id); + $endpoint->setModelGroupId($model_group_id); return $this->performRequest($endpoint); }