Skip to content

Commit

Permalink
Merge pull request #135 from ORNL/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
renan-souza authored Sep 19, 2024
2 parents d5446b5 + a949285 commit 1a26aa0
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
11 changes: 10 additions & 1 deletion flowcept/flowcept_api/db_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,13 @@ def query(
f"collections for task and workflow."
)

def save_torch_model(self, model, custom_metadata: dict = None) -> str:
def save_torch_model(
self,
model,
task_id=None,
workflow_id=None,
custom_metadata: dict = None,
) -> str:
"""
Save the PyTorch model's state_dict to a MongoDB collection as binary data.
Expand All @@ -153,6 +159,7 @@ def save_torch_model(self, model, custom_metadata: dict = None) -> str:
Returns:
str: The object ID of the saved model in the database.
"""
import torch
import io
Expand All @@ -169,6 +176,8 @@ def save_torch_model(self, model, custom_metadata: dict = None) -> str:
obj_id = self.save_object(
object=binary_data,
type="ml_model",
task_id=task_id,
workflow_id=workflow_id,
custom_metadata=cm,
)

Expand Down
2 changes: 1 addition & 1 deletion flowcept/instrumentation/decorators/responsible_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def wrapper(*args, **kwargs):
"torch"
].get("save_models", False):
obj_id = DBAPI().save_torch_model(
model, ret["responsible_ai_metadata"]
model, custom_metadata=ret["responsible_ai_metadata"]
)
ret["object_id"] = obj_id
return ret
Expand Down

0 comments on commit 1a26aa0

Please sign in to comment.