Skip to content

Commit

Permalink
Added ML train and predict API, fixed test for predict and train API,…
Browse files Browse the repository at this point in the history
… defined ML status as an enum, updated ML task state enum, updated CHANGELOG.md.

Signed-off-by: Nathalie Jonathan <[email protected]>
  • Loading branch information
nathaliellenaa committed Dec 26, 2024
1 parent ce7a47e commit 3d7f92e
Show file tree
Hide file tree
Showing 6 changed files with 241 additions and 7 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Added support for combining output variables ([#737](https://github.com/opensearch-project/opensearch-api-specification/pull/737))
- Added 404 response to `/_search/scroll` ([#749](https://github.com/opensearch-project/opensearch-api-specification/pull/749))
- Added `node_failures` to `DELETE /_search/scroll` and `DELETE /_search/scroll/{scroll_id}` ([#749](https://github.com/opensearch-project/opensearch-api-specification/pull/749))
- Added `POST /_plugins/_ml/_train/{algorithm_name}` and `POST /_plugins/_ml/_predict/{algorithm_name}/{model_id}` ([#x](https://github.com/opensearch-project/opensearch-api-specification/pull/x))
- Added `POST /_plugins/_ml/_train/{algorithm_name}`, `_predict/{algorithm_name}/{model_id}`, and `_train_predict/{algorithm_name}` ([#755](https://github.com/opensearch-project/opensearch-api-specification/pull/755))

### Removed
- Removed unsupported `_common.mapping:SourceField`'s `mode` field and associated `_common.mapping:SourceFieldMode` enum ([#652](https://github.com/opensearch-project/opensearch-api-specification/pull/652))
Expand Down
40 changes: 40 additions & 0 deletions spec/namespaces/ml.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,18 @@ paths:
responses:
'200':
$ref: '#/components/responses/ml.train@200'
/_plugins/_ml/_train_predict/{algorithm_name}:
post:
operationId: ml.train_predict.0
x-operation-group: ml.train_predict
description: Trains a model and predicts against the same training dataset.
parameters:
- $ref: '#/components/parameters/ml.train_predict::path.algorithm_name'
requestBody:
$ref: '#/components/requestBodies/ml.train_predict'
responses:
'200':
$ref: '#/components/responses/ml.train_predict@200'
/_plugins/_ml/connectors/_create:
post:
operationId: ml.create_connector.0
Expand Down Expand Up @@ -275,9 +287,26 @@ components:
$ref: '../schemas/ml._common.yaml#/components/schemas/InputQuery'
input_index:
type: array
description: The input index.
items:
type: string
ml.train_predict:
content:
application/json:
schema:
type: object
properties:
parameters:
$ref: '../schemas/ml._common.yaml#/components/schemas/TrainParameters'
input_query:
$ref: '../schemas/ml._common.yaml#/components/schemas/InputQuery'
input_index:
type: array
description: The input index.
items:
type: string
input_data:
$ref: '../schemas/ml._common.yaml#/components/schemas/PredictionResult'
ml.create_connector:
content:
application/json:
Expand Down Expand Up @@ -416,6 +445,11 @@ components:
application/json:
schema:
$ref: '../schemas/ml._common.yaml#/components/schemas/TrainResponse'
ml.train_predict@200:
content:
application/json:
schema:
$ref: '../schemas/ml._common.yaml#/components/schemas/TrainPredictResponse'
ml.create_connector@200:
content:
application/json:
Expand Down Expand Up @@ -497,6 +531,12 @@ components:
required: true
schema:
type: string
ml.train_predict::path.algorithm_name:
name: algorithm_name
in: path
required: true
schema:
type: string
ml.delete_connector::path.connector_id:
name: connector_id
in: path
Expand Down
79 changes: 73 additions & 6 deletions spec/schemas/ml._common.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,16 @@ components:
framework_type:
type: string
description: The framework type.
Status:
type: string
description: The status.
enum:
- CANCELLED
- COMPLETED
- COMPLETED_WITH_ERROR
- CREATED
- FAILED
- RUNNING
PredictResponse:
type: object
properties:
Expand All @@ -196,8 +206,7 @@ components:
items:
$ref: '#/components/schemas/InferenceResults'
status:
type: string
description: The status.
$ref: '#/components/schemas/Status'
prediction_result:
$ref: '#/components/schemas/PredictionResult'
InferenceResults:
Expand Down Expand Up @@ -298,8 +307,7 @@ components:
- INTEGER
- STRING
value:
type: integer
format: int64
type: number
description: The value.
InputQuery:
type: object
Expand All @@ -313,6 +321,55 @@ components:
type: integer
format: int64
description: The size of the query.
query:
$ref: '#/components/schemas/Query'
Query:
type: object
description: The query.
properties:
bool:
$ref: '#/components/schemas/BoolQuery'
BoolQuery:
type: object
description: The boolean query.

Check failure on line 334 in spec/schemas/ml._common.yaml

View workflow job for this annotation

GitHub Actions / check

[vale] reported by reviewdog 🐶 [Vale.Terms] Use 'Boolean' instead of 'boolean'. Raw Output: {"message": "[Vale.Terms] Use 'Boolean' instead of 'boolean'.", "location": {"path": "spec/schemas/ml._common.yaml", "range": {"start": {"line": 334, "column": 24}}}, "severity": "ERROR"}

Check failure on line 334 in spec/schemas/ml._common.yaml

View workflow job for this annotation

GitHub Actions / check

[vale] reported by reviewdog 🐶 [OpenSearch.Spelling] Error: boolean. If you are referencing a setting, variable, format, function, or repository, surround it with tic marks. Raw Output: {"message": "[OpenSearch.Spelling] Error: boolean. If you are referencing a setting, variable, format, function, or repository, surround it with tic marks.", "location": {"path": "spec/schemas/ml._common.yaml", "range": {"start": {"line": 334, "column": 24}}}, "severity": "ERROR"}
properties:
filter:
type: array
description: Filter query that appears in matching documents.
items:
$ref: '#/components/schemas/Filter'
Filter:
type: object
description: The filter element.
properties:
range:
$ref: '#/components/schemas/Range'
Range:
type: object
description: The filter range.
properties:
k1:
$ref: '#/components/schemas/Key'
k2:
$ref: '#/components/schemas/Key'
k3:
$ref: '#/components/schemas/Key'
Key:
type: object
description: The key.
properties:
gte:
type: number
description: Greater than or equal to.
lte:
type: number
description: Less than or equal to.
gt:
type: number
description: Greater than.
lt:
type: number
description: Less than.
TrainParameters:
type: object
properties:
Expand All @@ -337,8 +394,16 @@ components:
model_id:
$ref: '_common.yaml#/components/schemas/Name'
status:
type: string
description: The status.
$ref: '#/components/schemas/Status'
required:
- status
TrainPredictResponse:
type: object
properties:
status:
$ref: '#/components/schemas/Status'
prediction_result:
$ref: '#/components/schemas/PredictionResult'
required:
- status
ModelGroupRegistration:
Expand Down Expand Up @@ -390,9 +455,11 @@ components:
type: string
enum:
- CANCELLED
- CANCELLING
- COMPLETED
- COMPLETED_WITH_ERROR
- CREATED
- EXPIRED
- FAILED
- RUNNING
task_type:
Expand Down
17 changes: 17 additions & 0 deletions tests/plugins/ml/train_predict/predict.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,20 @@ distributions:
- amazon-managed
- amazon-serverless
prologues:
- path: /_bulk
method: POST
request:
content_type: application/x-ndjson
payload:
- {index: {_index: iris_data}}
- {sepal_length_in_cm: 5.1, sepal_width_in_cm: 3.5, petal_length_in_cm: 1.4, petal_width_in_cm: 0.2, species: setosa}
- {index: {_index: iris_data}}
- {sepal_length_in_cm: 4.9, sepal_width_in_cm: 3.1, petal_length_in_cm: 1.4, petal_width_in_cm: 0.2, species: setosa}
- {index: {_index: iris_data}}
- {sepal_length_in_cm: 4.7, sepal_width_in_cm: 3.2, petal_length_in_cm: 1.3, petal_width_in_cm: 0.2, species: setosa}
- {index: {_index: iris_data}}
- path: /iris_data/_refresh
method: POST
- path: _plugins/_ml/_train/{algorithm_name}
id: train_model
method: POST
Expand All @@ -32,6 +46,9 @@ epilogues:
model_id: ${train_model.model_id}
method: DELETE
status: [200, 404]
- path: /iris_data
method: DELETE
status: [200, 404]
chapters:
- synopsis: Predict trained model.
id: predict_trained_model
Expand Down
18 changes: 18 additions & 0 deletions tests/plugins/ml/train_predict/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,30 @@ distributions:
excluded:
- amazon-managed
- amazon-serverless
prologues:
- path: /_bulk
method: POST
request:
content_type: application/x-ndjson
payload:
- {index: {_index: iris_data}}
- {sepal_length_in_cm: 5.1, sepal_width_in_cm: 3.5, petal_length_in_cm: 1.4, petal_width_in_cm: 0.2, species: setosa}
- {index: {_index: iris_data}}
- {sepal_length_in_cm: 4.9, sepal_width_in_cm: 3.1, petal_length_in_cm: 1.4, petal_width_in_cm: 0.2, species: setosa}
- {index: {_index: iris_data}}
- {sepal_length_in_cm: 4.7, sepal_width_in_cm: 3.2, petal_length_in_cm: 1.3, petal_width_in_cm: 0.2, species: setosa}
- {index: {_index: iris_data}}
- path: /iris_data/_refresh
method: POST
epilogues:
- path: /_plugins/_ml/models/{model_id}
parameters:
model_id: ${train_model.model_id}
method: DELETE
status: [200, 404]
- path: /iris_data
method: DELETE
status: [200, 404]
chapters:
- synopsis: Train model synchronously.
id: train_model
Expand Down
92 changes: 92 additions & 0 deletions tests/plugins/ml/train_predict/train_and_predict.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
$schema: ../../../../json_schemas/test_story.schema.yaml

description: Test training a model, then immediately predict against the same training dataset.
distributions:
excluded:
- amazon-managed
- amazon-serverless
prologues:
- path: /_bulk
method: POST
request:
content_type: application/x-ndjson
payload:
- {index: {_index: test_data}}
- {k1: 5.1, k2: 3.5, k3: 1.4}
- {index: {_index: test_data}}
- {k1: 4.9, k2: 3.1, k3: 1.4}
- {index: {_index: test_data}}
- {k1: 4.7, k2: 3.2, k3: 1.3}
- {index: {_index: test_data}}
- path: /test_data/_refresh
method: POST
epilogues:
- path: /test_data
method: DELETE
status: [200, 404]
chapters:
- synopsis: Train and predict with indexed data.
id: train_predict_model
path: /_plugins/_ml/_train_predict/{algorithm_name}
method: POST
parameters:
algorithm_name: KMEANS
request:
payload:
parameters:
centroids: 3
iterations: 10
distance_type: COSINE
input_query:
query:
bool:
filter:
- range:
k1:
gte: 3.9
lte: 4.8
size: 10
input_index:
- test_data
response:
status: 200
payload:
status: COMPLETED
output:
prediction: payload.prediction_result
- synopsis: Train and predict with data directly.
id: train_predict_model
path: /_plugins/_ml/_train_predict/{algorithm_name}
method: POST
parameters:
algorithm_name: KMEANS
request:
payload:
parameters:
centroids: 3
iterations: 10
distance_type: COSINE
input_data:
column_metas:
- name: k1
column_type: DOUBLE
- name: k2
column_type: DOUBLE
rows:
- values:
- column_type: DOUBLE
value: 1.01
- column_type: DOUBLE
value: 2.01
- values:
- column_type: DOUBLE
value: 3.01
- column_type: DOUBLE
value: 4.01
response:
status: 200
payload:
status: COMPLETED
output:
prediction: payload.prediction_result

0 comments on commit 3d7f92e

Please sign in to comment.