Skip to content

Commit

Permalink
Remove race-condition issue
Browse files Browse the repository at this point in the history
  • Loading branch information
Enkidu93 committed Oct 31, 2023
1 parent fdc7ce0 commit c944705
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 66 deletions.
5 changes: 5 additions & 0 deletions samples/ServalApp/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ def __repr__(self):
return self.__str__()


class Param(Base):
__tablename__ = "params"
param_name = Column("param_name", String, primary_key=True)


def create_db_if_not_exists():
engine = create_engine("sqlite:///builds.db")
metadata.create_all(bind=engine)
155 changes: 89 additions & 66 deletions samples/ServalApp/serval_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from time import sleep

import streamlit as st
from db import Build, State, create_db_if_not_exists
from db import Build, State, Param, create_db_if_not_exists
from serval_auth_module import ServalBearerAuth
from serval_client_module import (
PretranslateCorpusConfig,
Expand Down Expand Up @@ -33,46 +33,60 @@ def send_emails():
session = Session()
try:

def started(build: Build, email_server: ServalAppEmailServer, data=None):
def started(build: Build, data=None):
logger.info(f"Started:\n{build}")
email_server.send_build_started_email(build.email, str(build))
session.delete(build)
session.add(
Build(
build_id=build.build_id,
engine_id=build.engine_id,
email=build.email,
state=State.Active,
corpus_id=build.corpus_id,
with ServalAppEmailServer(
os.environ.get("SERVAL_APP_EMAIL_PASSWORD")
) as email_server:
email_server.send_build_started_email(build.email, str(build))
session.delete(build)
session.add(
Build(
build_id=build.build_id,
engine_id=build.engine_id,
email=build.email,
state=State.Active,
corpus_id=build.corpus_id,
)
)
)
session.commit()
logger.info("Email sent and build updated.")

def faulted(build: Build, email_server: ServalAppEmailServer, data=None):
def faulted(build: Build, data=None):
logger.warn(f"Faulted:\n{build}")
email_server.send_build_faulted_email(build.email, str(build), error=data)
session.delete(build)
with ServalAppEmailServer(
os.environ.get("SERVAL_APP_EMAIL_PASSWORD")
) as email_server:
email_server.send_build_faulted_email(
build.email, str(build), error=data
)
session.delete(build)
session.commit()
logger.info("Email sent and build deleted.")

def completed(build: Build, email_server: ServalAppEmailServer, data=None):
def completed(build: Build, data=None):
logger.info(f"Completed:\{build}")
pretranslations = client.translation_engines_get_all_pretranslations(
build.engine_id, build.corpus_id
)
email_server.send_build_completed_email(
build.email,
"\n".join(
[
f"{'|'.join(pretranslation.refs)}\t{pretranslation.translation}"
for pretranslation in pretranslations
]
),
str(build),
)
session.delete(build)
with ServalAppEmailServer(
os.environ.get("SERVAL_APP_EMAIL_PASSWORD")
) as email_server:
email_server.send_build_completed_email(
build.email,
"\n".join(
[
f"{'|'.join(pretranslation.refs)}\t{pretranslation.translation}"
for pretranslation in pretranslations
]
),
str(build),
)
session.delete(build)
session.commit()
logger.info("Email sent and build deleted.")

def default_update(build: Build, email_server: ServalAppEmailServer, data=None):
def default_update(build: Build, data=None):
logger.info(f"Updated:\n{build}")

serval_auth = ServalBearerAuth()
Expand All @@ -85,80 +99,89 @@ def default_update(build: Build, email_server: ServalAppEmailServer, data=None):
"Canceled": faulted,
}

def get_update(build: Build, email_server: ServalAppEmailServer):
def get_update(build: Build):
build_update = client.translation_engines_get_build(
id=build.engine_id, build_id=build.build_id
)
if build.state == State.Pending and build_update.state == "Active":
started(build, email_server)
started(build)
else:
responses.get(build_update.state, default_update)(
build, email_server, build_update.message
build, build_update.message
)
session.commit()

def send_updates(email_server: ServalAppEmailServer):
def send_updates():
logger.info("Checking for updates...")
with session.no_autoflush:
builds = session.query(Build).all()
for build in builds:
try:
get_update(build, email_server)
get_update(build)
except Exception as e:
logger.error(
f"Failed to update {build} because of exception {e}"
)
raise e

with ServalAppEmailServer(
os.environ.get("SERVAL_APP_EMAIL_PASSWORD")
) as email_server:
while True:
send_updates(email_server)
sleep(int(os.environ.get("SERVAL_APP_UPDATE_FREQ_SEC", 300)))
while True:
send_updates()
sleep(int(os.environ.get("SERVAL_APP_UPDATE_FREQ_SEC", 300)))

except Exception as e:
logger.exception(e)
st.session_state["background_process_has_started"] = False
session.delete(
session.query(Param)
.where(Param.param_name == "background_process_has_started")
.first()
)
session.commit()


if not st.session_state.get("background_process_has_started", False):
engine = create_engine("sqlite:///builds.db")
Session = sessionmaker(bind=engine)
session = Session()

background_process_has_started = (
session.query(Param)
.where(Param.param_name == "background_process_has_started")
.first()
)
if not background_process_has_started:
cron_thread = Thread(target=send_emails)
add_script_run_ctx(cron_thread)
cron_thread.start()
st.session_state["background_process_has_started"] = True
session.add(Param(param_name="background_process_has_started"))
session.commit()

serval_auth = None
if not st.session_state.get("authorized", False):
with st.form(key="Authorization Form"):
st.session_state["client_id"] = st.text_input(label="Client ID")
st.session_state["client_secret"] = st.text_input(
label="Client Secret", type="password"
)
logger.info("HERE_")
if st.form_submit_button("Authorize"):
st.session_state["authorized"] = True
st.rerun()
try:
st.session_state["serval_auth"] = ServalBearerAuth(
client_id=st.session_state["client_id"]
if st.session_state["client_id"] != ""
else "<invalid>",
client_secret=st.session_state["client_secret"]
if st.session_state["client_secret"] != ""
else "<invalid>",
)
st.session_state["authorized"] = True
st.rerun()
except ValueError:
st.session_state["authorized"] = False
st.session_state["authorization_failure"] = True
if st.session_state.get("authorization_failure", False):
st.error("Invalid credentials. Please check your credentials.")
else:
try:
serval_auth = ServalBearerAuth(
client_id=st.session_state["client_id"]
if st.session_state["client_id"] != ""
else "<invalid>",
client_secret=st.session_state["client_secret"]
if st.session_state["client_secret"] != ""
else "<invalid>",
)
except ValueError:
st.session_state["authorized"] = False
st.session_state["authorization_failure"] = True
st.rerun()
client = RemoteCaller(
url_prefix=os.environ.get("SERVAL_HOST_URL"), auth=serval_auth
url_prefix=os.environ.get("SERVAL_HOST_URL"),
auth=st.session_state["serval_auth"],
)
engine = create_engine("sqlite:///builds.db")
Session = sessionmaker(bind=engine)
session = Session()

def submit():
engine = json.loads(
Expand Down Expand Up @@ -278,6 +301,7 @@ def already_active_build_for(email: str, client: str):
st.subheader("Neural Machine Translation")

tried_to_submit = st.session_state.get("tried_to_submit", False)
submitted = False
with st.form(key="NmtTranslationForm"):
st.session_state["build_name"] = st.text_input(
label="Build Name", placeholder="MyBuild (Optional)"
Expand Down Expand Up @@ -339,8 +363,7 @@ def already_active_build_for(email: str, client: str):
st.session_state[
"error"
] = "There is already an a pending or active build associated with this email address and client id. \
Please wait for the previous build to finish."
st.rerun()
Please wait for the previous build to finish."
elif (
st.session_state["source_language"] != ""
and st.session_state["target_language"] != ""
Expand All @@ -353,14 +376,14 @@ def already_active_build_for(email: str, client: str):
st.toast(
"Translations are on their way! You'll receive an email when your translation job has begun."
)
st.session_state["error"] = None
sleep(4)
st.rerun()
else:
st.session_state["tried_to_submit"] = True
st.session_state[
"error"
] = "Some required fields were left blank. Please fill in all fields above"
st.rerun()
st.markdown(
f"<sub>\* Use IETF tags if possible. See [here](https://en.wikipedia.org/wiki/IETF_language_tag) \
for more information on IETF tags. For more details, see [the Serval API documentation]({os.environ.get('SERVAL_HOST_URL')}/swagger/index.html#/Translation%20Engines/TranslationEngines_Create).</sub>",
Expand Down
1 change: 1 addition & 0 deletions samples/ServalApp/serval_email_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __enter__(self):
return self

def __exit__(self, *args):
self.server.quit()
self.server.close()

def send_build_completed_email(
Expand Down

0 comments on commit c944705

Please sign in to comment.