Skip to content

Commit

Permalink
Merge pull request #35 from ORNL/dev
Browse files Browse the repository at this point in the history
Main < Dev
  • Loading branch information
renan-souza authored Feb 7, 2023
2 parents b36fa1b + 97bbe91 commit 65ded87
Show file tree
Hide file tree
Showing 23 changed files with 7,508 additions and 43 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ mlflow.db
tests/plugins/mnist_train
tests/plugins/tensorboard_events/
tensorboard_events/
notebooks/*
**/*.DS_Store*
**/*.log*
notebooks/.ipynb_checkpoints
1 change: 1 addition & 0 deletions extra_requirements/api-requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
requests
12 changes: 8 additions & 4 deletions extra_requirements/dev-requirements-mac.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
tensorboard==2.10.1
tensorflow-macos==2.10.0
pandas==1.5.1
tbparse==0.0.7
pytest==6.2.4
flake8==5.0.4
black==23.1.0
numpy==1.23.4
tensorboard==2.11.0
tensorflow-macos==2.11.0
bokeh==2.4.2
jupyterlab
2 changes: 2 additions & 0 deletions extra_requirements/dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ black==23.1.0
numpy==1.23.4
tensorboard==2.11.0
tensorflow==2.11.0
bokeh==2.4.2
jupyterlab==3.6.1
23 changes: 21 additions & 2 deletions flowcept/commons/doc_db/document_db_dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,25 @@ def __init__(self):
db = client[MONGO_DB]
self._collection = db[MONGO_COLLECTION]

def find(self, filter_: Dict) -> List[Dict]:
def find(
self,
filter: dict,
projection=None,
limit=None,
sort=None,
remove_json_unserializables=True,
) -> List[Dict]:
if limit is None:
limit = 0

if remove_json_unserializables:
projection = {"_id": 0, "timestamp": 0}

try:
lst = list()
for doc in self._collection.find(filter_):
for doc in self._collection.find(
filter=filter, projection=projection, limit=limit, sort=sort
):
lst.append(doc)
return lst
except Exception as e:
Expand Down Expand Up @@ -69,12 +84,16 @@ def insert_and_update_many(
return False

def delete_ids(self, ids_list: List[ObjectId]):
if type(ids_list) != list:
ids_list = [ids_list]
try:
self._collection.delete_many({"_id": {"$in": ids_list}})
except Exception as e:
print("Error when deleting documents.", e)

def delete_keys(self, key_name, keys_list: List[ObjectId]):
if type(keys_list) != list:
keys_list = [keys_list]
try:
self._collection.delete_many({key_name: {"$in": keys_list}})
except Exception as e:
Expand Down
4 changes: 3 additions & 1 deletion flowcept/commons/flowcept_data_classes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Dict, AnyStr, Any, Union
from typing import Dict, AnyStr, Any, Union, List


class Status(str, Enum): # inheriting from str here for JSON serialization
Expand Down Expand Up @@ -40,6 +40,8 @@ class TaskMessage:
private_ip: AnyStr = None
sys_name: AnyStr = None
address: AnyStr = None
dependencies: List = None
dependents: List = None

# def __init__(self,
# task_id: AnyStr = None, # Any way to identify a task
Expand Down
6 changes: 6 additions & 0 deletions flowcept/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,9 @@
PUBLIC_IP = os.getenv("PUBLIC_IP", external_ip)

PRIVATE_IP = os.getenv("PRIVATE_IP", socket.gethostbyname(socket.getfqdn()))


#### Web Server

WEBSERVER_HOST = os.getenv("WEBSERVER_HOST", "0.0.0.0")
WEBSERVER_PORT = int(os.getenv("WEBSERVER_PORT", "5000"))
67 changes: 67 additions & 0 deletions flowcept/flowcept_api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from threading import Thread
from time import sleep

from flowcept.commons.doc_db.document_inserter import DocumentInserter

import json
from typing import List, Dict
import requests

from flowcept.configs import WEBSERVER_PORT, WEBSERVER_HOST
from flowcept.flowcept_webserver.app import BASE_ROUTE
from flowcept.flowcept_webserver.resources.query_rsrc import TaskQuery


class FlowceptConsumerAPI(object):
def __init__(self):
self._consumer_thread: Thread = None

def start(self):
self._consumer_thread = Thread(target=DocumentInserter().main)
self._consumer_thread.start()
print("Flowcept Consumer starting...")
sleep(2)
print("Ok, we're consuming messages!")

# def close(self):
# self._consumer_thread.join()


class TaskQueryAPI(object):
def __init__(
self,
host: str = WEBSERVER_HOST,
port: int = WEBSERVER_PORT,
auth=None,
):
self._host = host
self._port = port
self._url = (
f"http://{self._host}:{self._port}{BASE_ROUTE}{TaskQuery.ROUTE}"
)

def query(
self,
filter: dict,
projection: dict = None,
limit: int = 0,
sort: dict = None,
remove_json_unserializables=True,
) -> List[Dict]:
request_data = {"filter": json.dumps(filter)}
if projection:
request_data["projection"] = json.dumps(projection)
if limit:
request_data["limit"] = limit
if sort:
request_data["sort"] = json.dumps(sort)
if remove_json_unserializables:
request_data[
"remove_json_unserializables"
] = remove_json_unserializables

r = requests.post(self._url, json=request_data)
if 200 <= r.status_code < 300:
return r.json()
else:
raise Exception(r.text)
1 change: 1 addition & 0 deletions flowcept/flowcept_api/task_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

2 changes: 1 addition & 1 deletion flowcept/flowcept_consumer/consumer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def curate_task_msg(task_msg_dict: dict):
field_val_dict[f"arg{i}"] = arg
i += 1
else: # Scalar value
field_val_dict["arg1"] = field_val
field_val_dict["arg0"] = field_val
task_msg_dict[field] = field_val_dict


Expand Down
10 changes: 7 additions & 3 deletions flowcept/flowcept_webserver/app.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from flask_restful import Api
from flask import Flask, request, jsonify

from flowcept.configs import WEBSERVER_HOST, WEBSERVER_PORT
from flowcept.flowcept_webserver.resources.query_rsrc import TaskQuery
from flowcept.flowcept_webserver.resources.task_messages_rsrc import (
TaskMessages,
)


QUERY_ROUTE = "/query"
BASE_ROUTE = "/api"
app = Flask(__name__)
api = Api(app)
api.add_resource(TaskMessages, f"{QUERY_ROUTE}/{TaskMessages.ROUTE}")

api.add_resource(TaskMessages, f"{BASE_ROUTE}/{TaskMessages.ROUTE}")
api.add_resource(TaskQuery, f"{BASE_ROUTE}/{TaskQuery.ROUTE}")


@app.route("/")
Expand All @@ -18,4 +22,4 @@ def liveness():


if __name__ == "__main__":
app.run()
app.run(host=WEBSERVER_HOST, port=WEBSERVER_PORT)
32 changes: 32 additions & 0 deletions flowcept/flowcept_webserver/resources/query_rsrc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import json
from flask_restful import Resource, reqparse

from flowcept.commons.doc_db.document_db_dao import DocumentDBDao


class TaskQuery(Resource):
ROUTE = "/task_query"

def post(self):
parser = reqparse.RequestParser()
req_args = ["filter", "projection", "sort", "limit"]
for arg in req_args:
parser.add_argument(arg, type=str, required=False, help="")
args = parser.parse_args()

doc_args = {}
for arg in args:
if args[arg] is None:
continue
try:
doc_args[arg] = json.loads(args[arg])
except Exception as e:
return (f"Could not parse {arg} argument: {e}"), 400

dao = DocumentDBDao()
docs = dao.find(**doc_args)

if docs is not None and len(docs):
return docs, 201
else:
return (f"Could not find matching docs"), 404
19 changes: 10 additions & 9 deletions flowcept/flowcept_webserver/resources/task_messages_rsrc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from flask import Flask
from flask_restful import Resource, Api, reqparse
from flask import jsonify, request
from flask_restful import Resource

from flowcept.commons.doc_db.document_db_dao import DocumentDBDao

Expand All @@ -8,14 +8,15 @@ class TaskMessages(Resource):
ROUTE = "/task_messages"

def get(self):
parser = reqparse.RequestParser()
parser.add_argument("task_id", type=str, required=False) # add args
args = parser.parse_args()

args = request.args
task_id = args.get("task_id", None)
filter = {}
if "task_id" in args["task_id"]:
filter = {"task_id": args["task_id"]}
if task_id is not None:
filter = {"task_id": task_id}

dao = DocumentDBDao()
docs = dao.find(filter)
return {docs}, 200
if len(docs):
return jsonify(docs), 201
else:
return "No tasks found.", 404
15 changes: 14 additions & 1 deletion flowcept/flowceptor/plugins/dask/dask_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,19 @@ def _get_arg(arg_name):
arg_val = _get_arg("kwargs")
if arg_val is not None:
picked_kwargs = pickle.loads(arg_val)
if "workflow_id" in picked_kwargs:
task_msg.workflow_id = picked_kwargs.pop("workflow_id")
if len(picked_kwargs):
task_msg.used.update(picked_kwargs)


def get_task_deps(task_state, task_msg: TaskMessage):
if len(task_state.dependencies):
task_msg.dependencies = [t.key for t in task_state.dependencies]
if len(task_state.dependents):
task_msg.dependents = [t.key for t in task_state.dependents]


class DaskSchedulerInterceptor(BaseInterceptor):
def __init__(self, scheduler, plugin_key="dask"):
self._scheduler = scheduler
Expand Down Expand Up @@ -59,12 +68,16 @@ def callback(self, task_id, start, finish, *args, **kwargs):
task_msg = TaskMessage()
task_msg.task_id = task_id
task_msg.custom_metadata = {
"scheduler": self._scheduler.address_safe
"scheduler": self._scheduler.address_safe,
"scheduler_id": self._scheduler.id,
"scheduler_pid": self._scheduler.proc.pid,
}
task_msg.status = Status.SUBMITTED
if self.settings.scheduler_create_timestamps:
task_msg.utc_timestamp = get_utc_now()

get_task_deps(ts, task_msg)

if hasattr(ts, "group_key"):
task_msg.activity_id = ts.group_key

Expand Down
1 change: 1 addition & 0 deletions flowcept/flowceptor/plugins/zambeze/zambeze_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def prepare_task_msg(self, zambeze_msg: Dict) -> TaskMessage:
task_msg.experiment_id = zambeze_msg.get("campaign_id")
task_msg.task_id = zambeze_msg.get("activity_id")
task_msg.activity_id = zambeze_msg.get("name")
task_msg.dependencies = zambeze_msg.get("depends_on")
task_msg.custom_metadata = {"command": zambeze_msg.get("command")}
task_msg.status = get_status_from_str(
zambeze_msg.get("activity_status")
Expand Down
Loading

0 comments on commit 65ded87

Please sign in to comment.