Skip to content

Commit

Permalink
Replace Flask with FastAPI
Browse files Browse the repository at this point in the history
  • Loading branch information
frode-aarstad committed Nov 20, 2024
1 parent 3b603d4 commit 128f7e7
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 62 deletions.
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ types = [
everest = [
"progressbar2",
"ruamel.yaml",
"flask",
"fastapi",
"decorator",
"resdata",
"colorama",
Expand Down Expand Up @@ -235,6 +235,11 @@ allowed-confusables = ["–"]
[tool.ruff.lint.pylint]
max-args = 20

[tool.ruff.lint.flake8-bugbear]
extend-immutable-calls = [
"fastapi.Depends",
]

[tool.pyright]
include = ["src"]
exclude = ["tests"]
Expand Down
134 changes: 73 additions & 61 deletions src/everest/detached/jobs/everserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,27 @@
import traceback
from base64 import b64encode
from datetime import datetime, timedelta
from functools import partial, wraps
from functools import partial

# from flask import Flask, Response, jsonify, request
import uvicorn
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID
from dns import resolver, reversename
from flask import Flask, Response, jsonify, request
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import (
JSONResponse,
PlainTextResponse,
Response,
)
from fastapi.security import (
HTTPBasic,
HTTPBasicCredentials,
)
from ropt.enums import OptimizerExitCode

from ert.config import QueueSystem
Expand All @@ -37,7 +49,7 @@
from everest.util import configure_logger, makedirs_if_needed, version_info


def get_machine_name():
def _get_machine_name() -> str:
"""Returns a name that can be used to identify this machine in a network
A fully qualified domain name is returned if available. Otherwise returns
Expand Down Expand Up @@ -88,71 +100,70 @@ def _opt_monitor(shared_data=None):
return "stop_optimization"


def _everserver_thread(shared_data, server_config):
app = Flask(__name__)

def check_user(password):
return password == server_config["authentication"]
def _everserver_thread(shared_data, server_config) -> None:
app = FastAPI()
security = HTTPBasic()

def requires_authenticated(f):
@wraps(f)
def decorated(*args, **kwargs):
auth = request.authorization
if not auth or not check_user(auth.password):
return "unauthorized", 401
return f(*args, **kwargs)

return decorated

def log(f):
@wraps(f)
def decorated(*args, **kwargs):
url = request.path
method = request.method
ip = request.environ.get("HTTP_X_REAL_IP", request.remote_addr)
logging.getLogger("everserver").info(
"{} entered from {} with HTTP {}".format(url, ip, method)
def _check_user(credentials: HTTPBasicCredentials) -> None:
if credentials.password != server_config["authentication"]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid credentials",
headers={"WWW-Authenticate": "Basic"},
)
return f(*args, **kwargs)

return decorated

@app.route("/")
@requires_authenticated
@log
def get_home():
return "Everest is running"
def _log(request: Request) -> None:
logging.getLogger("everserver").info(
f"{request.scope['path']} entered from {request.client.host if request.client else 'unknown host'} with HTTP {request.method}"
)

@app.route("/" + STOP_ENDPOINT, methods=["POST"])
@requires_authenticated
@log
def stop():
@app.get("/")
def get_status(
request: Request, credentials: HTTPBasicCredentials = Depends(security)
) -> PlainTextResponse:
_log(request)
_check_user(credentials)
return PlainTextResponse("Everest is running")

@app.post("/" + STOP_ENDPOINT)
def stop(
request: Request, credentials: HTTPBasicCredentials = Depends(security)
) -> Response:
_log(request)
_check_user(credentials)
shared_data[STOP_ENDPOINT] = True
return Response("Raise STOP flag succeeded. Everest initiates shutdown..", 200)

@app.route("/" + SIM_PROGRESS_ENDPOINT)
@requires_authenticated
@log
def get_sim_progress():
return jsonify(shared_data[SIM_PROGRESS_ENDPOINT])

@app.route("/" + OPT_PROGRESS_ENDPOINT)
@requires_authenticated
@log
def get_opt_progress():
@app.get("/" + SIM_PROGRESS_ENDPOINT)
def get_sim_progress(
request: Request, credentials: HTTPBasicCredentials = Depends(security)
) -> JSONResponse:
_log(request)
_check_user(credentials)
progress = shared_data[SIM_PROGRESS_ENDPOINT]
return JSONResponse(jsonable_encoder(progress))

@app.get("/" + OPT_PROGRESS_ENDPOINT)
def get_opt_progress(
request: Request, credentials: HTTPBasicCredentials = Depends(security)
) -> JSONResponse:
_log(request)
_check_user(credentials)
progress = get_opt_status(server_config["optimization_output_dir"])
return jsonify(progress)

ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
ctx.load_cert_chain(
server_config["cert_path"],
server_config["key_path"],
server_config["key_passwd"],
return JSONResponse(jsonable_encoder(progress))

uvicorn.run(
app,
host="0.0.0.0",
port=server_config["port"],
ssl_keyfile=server_config["key_path"],
ssl_certfile=server_config["cert_path"],
ssl_version=ssl.PROTOCOL_SSLv23,
ssl_keyfile_password=server_config["key_passwd"],
)
app.run(host="0.0.0.0", port=server_config["port"], ssl_context=ctx)


def _find_open_port(host, lower, upper):
def _find_open_port(host, lower, upper) -> int:
for port in range(lower, upper):
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
Expand All @@ -168,7 +179,7 @@ def _find_open_port(host, lower, upper):
raise Exception(msg)


def _write_hostfile(host_file_path, host, port, cert, auth):
def _write_hostfile(host_file_path, host, port, cert, auth) -> None:
if not os.path.exists(os.path.dirname(host_file_path)):
os.makedirs(os.path.dirname(host_file_path))
data = {
Expand All @@ -187,7 +198,7 @@ def _configure_loggers(
detached_node_dir: str,
everest_logs_dir: str,
logging_level: int,
):
) -> None:
configure_logger(
name="res",
file_path=os.path.join(detached_node_dir, "simulations.log"),
Expand Down Expand Up @@ -249,7 +260,7 @@ def main():
cert_path, key_path, key_pw = _generate_certificate(
ServerConfig.get_certificate_dir(config.output_dir)
)
host = get_machine_name()
host = _get_machine_name()
port = _find_open_port(host, lower=5000, upper=5800)
_write_hostfile(host_file, host, port, cert_path, authentication)

Expand Down Expand Up @@ -344,6 +355,7 @@ def main():
message=traceback.format_exc(),
)
return

update_everserver_status(status_path, ServerStatus.completed, message=message)


Expand Down Expand Up @@ -404,7 +416,7 @@ def _generate_certificate(cert_folder: str):
)

# Generate the certificate and sign it with the private key
cert_name = get_machine_name()
cert_name = _get_machine_name()
subject = issuer = x509.Name(
[
x509.NameAttribute(NameOID.COUNTRY_NAME, "NO"),
Expand Down

0 comments on commit 128f7e7

Please sign in to comment.