diff --git a/app/femr/business/services/system/EncounterService.java b/app/femr/business/services/system/EncounterService.java index 7bbeeddaa..73dead914 100644 --- a/app/femr/business/services/system/EncounterService.java +++ b/app/femr/business/services/system/EncounterService.java @@ -18,12 +18,15 @@ */ package femr.business.services.system; +import femr.business.services.core.ISessionService; +import femr.common.dtos.CurrentUser; import io.ebean.ExpressionList; import io.ebean.Query; import com.google.inject.Inject; import com.google.inject.name.Named; import femr.business.helpers.QueryProvider; import femr.business.services.core.IEncounterService; +import femr.business.services.core.ISessionService; import femr.common.IItemModelMapper; import femr.common.dtos.ServiceResponse; import femr.common.models.*; @@ -46,6 +49,7 @@ public class EncounterService implements IEncounterService { private final IRepository chiefComplaintRepository; private final IPatientRepository patientRepository; private final IEncounterRepository patientEncounterRepository; + private final ISessionService sessionService; private final IRepository patientEncounterTabFieldRepository; private final IRepository tabFieldRepository; private final IUserRepository userRepository; @@ -56,6 +60,7 @@ public class EncounterService implements IEncounterService { public EncounterService(IRepository chiefComplaintRepository, IPatientRepository patientRepository, IEncounterRepository patientEncounterRepository, + ISessionService sessionService, IRepository patientEncounterTabFieldRepository, IRepository tabFieldRepository, IUserRepository userRepository, @@ -65,6 +70,7 @@ public EncounterService(IRepository chiefComplaintRepository, this.chiefComplaintRepository = chiefComplaintRepository; this.patientRepository = patientRepository; this.patientEncounterRepository = patientEncounterRepository; + this.sessionService = sessionService; this.patientEncounterTabFieldRepository = patientEncounterTabFieldRepository; this.tabFieldRepository = tabFieldRepository; this.userRepository = userRepository; @@ -77,7 +83,7 @@ public EncounterService(IRepository chiefComplaintRepository, */ @Override public ServiceResponse createPatientEncounter(int patientId, int userId, Integer tripId, String ageClassification, List chiefComplaints) { - + System.out.println("Create Patient Encounter"); ServiceResponse response = new ServiceResponse<>(); try { @@ -98,8 +104,9 @@ public ServiceResponse createPatientEncounter(int patientI if (patientAgeClassification != null) patientAgeClassificationId = patientAgeClassification.getId(); - IPatientEncounter newPatientEncounter = patientEncounterRepository.createPatientEncounter(patientId, dateUtils.getCurrentDateTime(), nurseUser.getId(), patientAgeClassificationId, tripId); - + CurrentUser currentUserSession = sessionService.retrieveCurrentUserSession(); + String languageCode = currentUserSession.getLanguageCode(); + IPatientEncounter newPatientEncounter = patientEncounterRepository.createPatientEncounter(patientId, dateUtils.getCurrentDateTime(), nurseUser.getId(), patientAgeClassificationId, tripId, languageCode); List chiefComplaintBeans = new ArrayList<>(); Integer chiefComplaintSortOrder = 0; for (String cc : chiefComplaints) { diff --git a/app/femr/data/daos/core/IEncounterRepository.java b/app/femr/data/daos/core/IEncounterRepository.java index e966d3c12..fa6dad7ac 100644 --- a/app/femr/data/daos/core/IEncounterRepository.java +++ b/app/femr/data/daos/core/IEncounterRepository.java @@ -17,7 +17,7 @@ public interface IEncounterRepository { * @param tripId id of the trip, may be null * @return the new patient encounter or null if an error happens */ - IPatientEncounter createPatientEncounter(int patientID, DateTime date, int userId, Integer patientAgeClassificationId, Integer tripId); + IPatientEncounter createPatientEncounter(int patientID, DateTime date, int userId, Integer patientAgeClassificationId, Integer tripId, String languageCode); /** * Deletes a patient's encounter. this is a soft delete diff --git a/app/femr/data/daos/system/EncounterRepository.java b/app/femr/data/daos/system/EncounterRepository.java index cd0a1dd83..a4a7ddc41 100644 --- a/app/femr/data/daos/system/EncounterRepository.java +++ b/app/femr/data/daos/system/EncounterRepository.java @@ -41,7 +41,7 @@ public EncounterRepository(Provider missionTripProvider, * {@inheritDoc} */ @Override - public IPatientEncounter createPatientEncounter(int patientID, DateTime date, int userId, Integer patientAgeClassificationId, Integer tripId){ + public IPatientEncounter createPatientEncounter(int patientID, DateTime date, int userId, Integer patientAgeClassificationId, Integer tripId, String languageCode){ IPatientEncounter patientEncounter = patientEncounterProvider.get(); @@ -57,7 +57,7 @@ public IPatientEncounter createPatientEncounter(int patientID, DateTime date, in patientEncounter.setPatientAgeClassification(Ebean.getReference(patientAgeClassificationProvider.get().getClass(), patientAgeClassificationId)); if (tripId != null) patientEncounter.setMissionTrip(Ebean.getReference(missionTripProvider.get().getClass(), tripId)); - + patientEncounter.setLanguageCode(languageCode); Ebean.save(patientEncounter); }catch (Exception ex){ diff --git a/app/femr/ui/controllers/BackEndControllerHelper.java b/app/femr/ui/controllers/BackEndControllerHelper.java index 2214444a1..dc3db7dde 100644 --- a/app/femr/ui/controllers/BackEndControllerHelper.java +++ b/app/femr/ui/controllers/BackEndControllerHelper.java @@ -1,17 +1,8 @@ package femr.ui.controllers; -import femr.util.translation.TranslationServer; -import femr.util.translation.TranslationJson; import java.io.*; -import java.net.MalformedURLException; import java.util.ArrayList; -import java.net.HttpURLConnection; -import java.net.URL; - -import com.fasterxml.jackson.databind.ObjectMapper; -import org.json.JSONArray; -import org.json.JSONObject; public class BackEndControllerHelper { @@ -52,25 +43,4 @@ public static ArrayList executeSpeedTestScript(String absPath) { return speedInfo; } - - public static String translate(String arg, String from, String to) { - String output = ""; - try { - output = TranslationServer.makeServerRequest(arg, from, to); - - //parse translation from JSON - ObjectMapper mapper = new ObjectMapper(); - TranslationJson api = mapper.readValue(output, TranslationJson.class); - output = api.translatedText; - - } catch(MalformedURLException e){ - System.out.println("Malformed URL Exception"); - System.out.println(e.getMessage()); - } catch(IOException e){ - System.out.println("IOException for parsing JSON"); - System.out.println(e.getMessage()); - } - return output; - } - } diff --git a/app/femr/ui/controllers/MedicalController.java b/app/femr/ui/controllers/MedicalController.java index ba3f82e05..d95200c7c 100644 --- a/app/femr/ui/controllers/MedicalController.java +++ b/app/femr/ui/controllers/MedicalController.java @@ -1,5 +1,6 @@ package femr.ui.controllers; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.inject.Inject; import controllers.AssetsFinder; import femr.business.services.core.*; @@ -19,6 +20,8 @@ import femr.util.DataStructure.Mapping.TabFieldMultiMap; import femr.util.DataStructure.Mapping.VitalMultiMap; import femr.util.stringhelpers.StringUtils; +import femr.util.translation.TranslationJson; +import femr.util.translation.TranslationServer; import femr.util.translation.TranslationResponseMap; import play.data.Form; import play.data.FormFactory; @@ -271,36 +274,49 @@ public Result editGet(int patientId) { // } public Result translateGet() { - String text = request().getQueryString("text"); - //Harrison Shu - CurrentUser currentUserSession = sessionService.retrieveCurrentUserSession(); - String toLanguage = currentUserSession.getLanguageCode(); + String jsonText = request().getQueryString("text"); + int patientId = Integer.parseInt(request().getQueryString("patientId")); // retrieve current patient encounter encounter - int patientId = Integer.parseInt(request().getQueryString("patientId")); ServiceResponse currentEncounterByPatientId = searchService.retrieveRecentPatientEncounterItemByPatientId(patientId); if (currentEncounterByPatientId.hasErrors()) { throw new RuntimeException(); } - PatientEncounterItem patientEncounter = currentEncounterByPatientId.getResponseObject(); - String fromLanguage = patientEncounter.getLanguageCode(); + //Harrison Shu + String toLanguage = sessionService.retrieveCurrentUserSession().getLanguageCode(); + String fromLanguage = currentEncounterByPatientId.getResponseObject().getLanguageCode(); + // Harrison Shu: Handles the creation of the response map and figures out whether or not to translate - TranslationResponseMap responseMapObject = new TranslationResponseMap(fromLanguage, toLanguage, text); + TranslationResponseMap responseMapObject = new TranslationResponseMap(fromLanguage, toLanguage, jsonText); return ok(responseMapObject.getResponseJson()); } - public String translate(String text, String fromLanguage, String toLanguage) { + +// Calls Python Script to translate + public static String translate(String jsonText, String fromLanguage, String toLanguage) { String data = ""; try { - data = BackEndControllerHelper.translate(text, fromLanguage, toLanguage); + data = TranslationServer.makeServerRequest(jsonText, fromLanguage, toLanguage); + data = parseJsonResponse(data); } catch (Exception e) { e.printStackTrace(); } return data; } + public static String parseJsonResponse(String jsonResponse){ + try{ + ObjectMapper mapper = new ObjectMapper(); + TranslationJson api = mapper.readValue(jsonResponse, TranslationJson.class); + jsonResponse = api.translatedText; + return jsonResponse; + } catch(Exception e){ + System.out.println(e.getMessage()); + throw new RuntimeException(); + } + } /** * Get the populated partial view that represents 1 row of new prescription fields diff --git a/app/femr/util/translation/TranslationResponseMap.java b/app/femr/util/translation/TranslationResponseMap.java index d67ada956..b8fd4fc8f 100644 --- a/app/femr/util/translation/TranslationResponseMap.java +++ b/app/femr/util/translation/TranslationResponseMap.java @@ -1,5 +1,5 @@ package femr.util.translation; -import femr.ui.controllers.BackEndControllerHelper; +import femr.ui.controllers.MedicalController; import java.util.*; import play.libs.Json; @@ -37,7 +37,7 @@ private void populateTranslation() { } else { String data = ""; try { - data = BackEndControllerHelper.translate(text, fromLanguage, toLanguage); + data = MedicalController.translate(text, fromLanguage, toLanguage); } catch (Exception e) { e.printStackTrace(); } diff --git a/app/femr/util/translation/TranslationServer.java b/app/femr/util/translation/TranslationServer.java index fc573ee51..3bb04972c 100644 --- a/app/femr/util/translation/TranslationServer.java +++ b/app/femr/util/translation/TranslationServer.java @@ -4,69 +4,49 @@ import javax.inject.Singleton; import java.io.*; import java.net.HttpURLConnection; -import java.net.MalformedURLException; import java.net.URL; import java.net.URLConnection; -import java.nio.file.Files; -import java.nio.file.Paths; import java.util.Scanner; -import java.net.URLEncoder; import java.nio.charset.StandardCharsets; @Singleton public class TranslationServer { @Inject public TranslationServer(){ - this.start(); + start(timeout); } private static int portNumber = -1; + private static final String timeout = "600"; - //Takes a string, a from code and a to code, and returns the translatedtext - public static String makeServerRequest(String text, String from, String to) throws MalformedURLException { - if(serverNotRunning()){ - start(); - while(serverNotRunning()); //block + public static int getPortFromLog(String logPath, Boolean wait){ + int portNumber = 0; + try{ + File log = new File(logPath); + Scanner s = new Scanner(log); + if(wait){ + while(log.length() == 0); + } + if (s.hasNext()) { + portNumber = Integer.parseInt(s.nextLine().split(": ")[1]); + } + s.close(); } - //Build GET request argument, replacing spaces and newlines - String response = ""; - try { - // Harrison Shu - // Encode the URL String parameter before creating URL to allow arabic and hebrew to be in the URL - String encodedText = URLEncoder.encode(text, StandardCharsets.UTF_8.toString()); - - //Make GET request - URL url = new URL("http://localhost:" + portNumber +"/?text=" + - encodedText + "&from=" + from + "&to=" + to); - HttpURLConnection con = (HttpURLConnection) url.openConnection(); - con.setRequestMethod("GET"); - con.connect(); - - //read response from server - BufferedReader in = new BufferedReader(new InputStreamReader(con.getInputStream())); - response = in.readLine(); - in.close(); - - con.disconnect(); - } catch(IOException e){ - return makeServerRequest(text, from, to); + catch(FileNotFoundException e){ + return portNumber; } - return response; + catch(Exception e){ + System.out.println("Problem retrieving port number from log"); + throw new RuntimeException(e); + } + return portNumber; } - public static boolean serverNotRunning(){ + public static boolean serverNotRunning(String logPath){ //initial value of portNumber if(portNumber == -1){ - File log = new File("translator/server.log"); - try { - Scanner s = new Scanner(log); - if(s.hasNext()){ - portNumber = Integer.parseInt(s.nextLine().split(": ")[1]); - } - else{ - return true; - } - } catch (FileNotFoundException e) { - throw new RuntimeException(e); + portNumber = getPortFromLog(logPath, false); + if(portNumber == 0){ + return true; } } try{ @@ -79,33 +59,63 @@ public static boolean serverNotRunning(){ return true; } } - public static void start(){ + public static String makeServerRequest(String jsonString, String from, String to) { + String logPath = "translator/server.log"; + if(serverNotRunning(logPath)){ + start(timeout); + while(serverNotRunning(logPath)); //block + } + + String response; + try { + byte[] json = jsonString.getBytes(StandardCharsets.UTF_8); + int length = json.length; + + //Make POST request + URL url = new URL("http://localhost:" + portNumber + "/?from=" + from + "&to=" + to); + HttpURLConnection con = (HttpURLConnection) url.openConnection(); + con.setRequestMethod("POST"); + con.setDoOutput(true); + con.setFixedLengthStreamingMode(length); + con.setRequestProperty("Content-Type", "application/json; charset=UTF-8"); + con.connect(); + + try(OutputStream os = con.getOutputStream()) { + os.write(json); + } + + //read response from server + BufferedReader in = new BufferedReader(new InputStreamReader(con.getInputStream())); + response = in.readLine(); + in.close(); + + con.disconnect(); + } catch(IOException e){ + System.out.println(e.getMessage()); + return "Translation Unavailable"; + } + return response; + } + public static void start(String timeout) { System.out.println("Starting translation server..."); - if(serverNotRunning()){ - File log = new File("translator/server.log"); - String absPath = "translator/server.py"; + String logPath = "translator/server.log"; + String absPath = "translator/server.py"; + if(serverNotRunning(logPath)){ + File log = new File(logPath); try { - ProcessBuilder pb = new ProcessBuilder("python", absPath); + log.createNewFile(); + ProcessBuilder pb = new ProcessBuilder("python", absPath, timeout); pb.redirectOutput(log); pb.redirectErrorStream(true); pb.start(); - } catch (IOException e) { System.out.println("An I/O error has occurred."); System.out.println(e.getMessage()); } - try { - Scanner s = new Scanner(log); - //Wait for server.log to be written to (port number) - while(log.length() == 0); - portNumber = Integer.parseInt(s.nextLine().split(": ")[1]); - s.close(); - } catch (FileNotFoundException e) { - System.out.println("A FileNotFound error has occurred."); - System.out.println(e.getMessage()); - } + portNumber = getPortFromLog(logPath, true); + } System.out.println("Translation server running!"); } diff --git a/public/js/medical/medical.js b/public/js/medical/medical.js index 33dbcbe6a..0568f5073 100644 --- a/public/js/medical/medical.js +++ b/public/js/medical/medical.js @@ -339,55 +339,60 @@ $(document).ready(function () { textToTranslate = textToTranslate + " @ " + $(jsonObj[i].id).val().replace("@","at"); jsonObj[i].text = $(jsonObj[i].id).val(); } - console.log("text:", textToTranslate); + console.log(jsonObj); // get translation $.ajax({ type: 'get', url: '/translate', - data: {text : textToTranslate, patientId: patientId}, + data: {text : JSON.stringify(jsonObj), patientId: patientId}, success: function(response){ + console.log("response:", response); console.log("translation:", response.translation); - var listTranslated = response.translation.split("@"); - - if (response.translation.split(":")[0] === "SameToSame") { - // same to same (like en to en) + if(response.translation === "SameToSame"){ $("#toggleBtn").remove(); - } else if (response.translation.split(".")[0] === "Translation Unavailable") { - $("#loading").remove(); - $("#toggleBtn").text("Unavailable"); - console.error(response.translation); - } else if (listTranslated.length !== jsonObj.length || response.toLanguageIsRtl) { - console.log("backup translation required (", listTranslated.length, "out of 16 tabs recovered)"); - for (let i = 0; i < jsonObj.length; i++) { - $.ajax({ - type: 'get', - url: '/translate', - data: {text: jsonObj[i].text, patientId: patientId}, - success: function (response) { - if (i === jsonObj.length - 1) { - // end buffering on last field - $("#loading").remove(); - $("#toggleBtn").text("Show Original"); - } - populateField(response.translation, jsonObj, response.fromLanguageIsRtl, response.toLanguageIsRtl, i); - }, - failure: function (result) { - console.error('Failed to fetch backup translation'); - } - }); + } + else{ + var listTranslated = JSON.parse(response.translation); + console.log(listTranslated); + if(listTranslated[0]["text"] === "Translation Unavailable"){ + //option 1 - end buffering + $("#loading").remove(); + $("#toggleBtn").text("Unavailable"); } - } else { - // end buffering if no backup - $("#loading").remove(); - $("#toggleBtn").text("Show Original"); - - // for each field populate them - for (let i = 0; i < jsonObj.length; i++) { - var textOut = listTranslated[i]; - populateField(textOut, jsonObj, response.fromLanguageIsRtl, response.toLanguageIsRtl, i); + //ELSE IF NOT CORRECT ATM + else if (listTranslated.length !== jsonObj.length) { + console.log("backup translation required out of 16 tabs ", listTranslated.length, " recovered"); + for (let i = 0; i < jsonObj.length; i++) { + $.ajax({ + type: 'get', + url: '/translate', + data: {text: jsonObj[i].text, patientId: patientId}, + success: function (response) { + if (i === jsonObj.length - 1) { + // end buffering on last field + $("#loading").remove(); + $("#toggleBtn").text("Show Original"); + } + populateField(response.translation, jsonObj, response.fromLanguageIsRtl, response.toLanguageIsRtl, i); + }, + failure: function (result) { + console.error('Failed to fetch backup translation'); + } + }); + } + } + else{ + $("#loading").remove(); + $("#toggleBtn").text("Show Original"); + + // for each field populate them + for (let i = 0; i < jsonObj.length; i++) { + var textOut = listTranslated[i]["text"]; + populateField(textOut, jsonObj, response.fromLanguageIsRtl, response.toLanguageIsRtl, i); + } } } }, diff --git a/test/unit/app/femr/business/services/TranslationServiceTest.java b/test/unit/app/femr/business/services/TranslationServiceTest.java new file mode 100644 index 000000000..b9cac5a87 --- /dev/null +++ b/test/unit/app/femr/business/services/TranslationServiceTest.java @@ -0,0 +1,115 @@ +package unit.app.femr.business.services; + +import controllers.AssetsFinder; +import femr.business.services.core.*; +import femr.ui.controllers.MedicalController; +import femr.util.translation.TranslationServer; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import play.data.FormFactory; + +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; + +import static org.mockito.Mockito.mock; + +public class TranslationServiceTest { + AssetsFinder assetsFinder; + FormFactory formFactory; + ITabService tabService; + IEncounterService encounterService; + IMedicationService medicationService; + IPhotoService photoService; + ISessionService sessionService; + ISearchService searchService; + IVitalService vitalService; + MedicalController medicalController; + + @Before + public void setUp(){ + assetsFinder = mock(AssetsFinder.class); + formFactory = mock(FormFactory.class); + tabService = mock(ITabService.class); + encounterService = mock(IEncounterService.class); + sessionService = mock(ISessionService.class); + searchService = mock(ISearchService.class); + medicationService = mock(IMedicationService.class); + photoService = mock(IPhotoService.class); + vitalService = mock(IVitalService.class); + medicalController = new MedicalController( + assetsFinder, + formFactory, + tabService, + encounterService, + medicationService, + photoService, + sessionService, + searchService, + vitalService); + } + + //MedicalController Tests + @Test + public void parseJsonResponseTest() { + String data = "{\"translate_data\" : \"Hello, World\"}"; + Assert.assertEquals("Hello, World", medicalController.parseJsonResponse(data)); + } + + // TranslationServer Tests + @Test + public void getPortFromLogNoFileTest(){ + String logPath = "test/server.log"; + Assert.assertEquals(0, TranslationServer.getPortFromLog(logPath, false)); + } + + @Test + public void getPortFromLogEmptyFileTest() throws IOException { + String logPath = "test/server.log"; + File log = new File(logPath); + log.createNewFile(); + Assert.assertEquals(0, TranslationServer.getPortFromLog(logPath, false)); + log.delete(); + } + + @Test + public void getPortFromLogTest() throws IOException { + String logPath = "test/server.log"; + File log = new File(logPath); + log.createNewFile(); + FileWriter writer = new FileWriter(logPath); + writer.write("Serving at port: 8000"); + writer.close(); + Assert.assertEquals(8000, TranslationServer.getPortFromLog(logPath, false)); + log.delete(); + } + + @Test + public void serverNotRunningTrue() throws IOException { + String logPath = "test/server.log"; + File log = new File(logPath); + log.createNewFile(); + FileWriter writer = new FileWriter(logPath); + writer.write("Serving at port: 8000"); + writer.close(); + Assert.assertTrue(TranslationServer.serverNotRunning("test/server.log")); + log.delete(); + } + + @Test + public void serverNotRunningFalse() { + TranslationServer.start("10"); + Assert.assertFalse(TranslationServer.serverNotRunning("translator/server.log")); + } + @Test + public void makeServerRequestTest() { + TranslationServer.start("10"); + String jsonString = "[{\"id\":\"#complaintInfo\", \"text\":\"Hello, World\"}, " + + "{\"id\":\"#onset\", \"text\":\"Hello, World\"}]"; + String outputString = + "{\"translate_data\": \"[{\\\"id\\\": \\\"#complaintInfo\\\", \\\"text\\\": \\\"Hola, Mundo\\\"}, " + + "{\\\"id\\\": \\\"#onset\\\", \\\"text\\\": \\\"Hola, Mundo\\\"}]\"}"; + Assert.assertEquals(outputString, TranslationServer.makeServerRequest(jsonString, "en", "es")); + } +} diff --git a/translator/server.py b/translator/server.py index 584efbb1a..be659fc1f 100644 --- a/translator/server.py +++ b/translator/server.py @@ -1,107 +1,135 @@ -import os -import sys -import json -import argostranslate.package -import argostranslate.translate -from functools import cached_property -from http.server import BaseHTTPRequestHandler -from urllib.parse import parse_qsl, urlparse -from http.server import HTTPServer -from pathlib import Path -from transformers import MarianMTModel, MarianTokenizer -from typing import Sequence -from libargos import install_packages -import socket -import time - -PORTS = [8000, 5000, 8001, 8002, 8003, 8004, 8005, 8006, 8007, 8008] -TIMEOUT = 3600 -PATH = os.getcwd() - - - -class MarianModel: - def __init__(self, source_lang: str, dest_lang: str) -> None: - path = f"{PATH}/translator/marian_models/opus-mt-{source_lang}-{dest_lang}" - self.model = MarianMTModel.from_pretrained(path, local_files_only = True) - self.tokenizer = MarianTokenizer.from_pretrained(path, local_files_only = True) - - def translate(self, texts: Sequence[str]) -> Sequence[str]: - tokens = self.tokenizer(list(texts), return_tensors="pt", padding=True) - translate_tokens = self.model.generate(**tokens) - return [self.tokenizer.decode(t, skip_special_tokens=True) for t in translate_tokens] - -class WebRequestHandler(BaseHTTPRequestHandler): - @cached_property - def url(self): - return urlparse(self.path) - - @cached_property - def query_data(self): - return dict(parse_qsl(self.url.query)) - - @cached_property - def translate_data(self): - text = self.query_data['text'] - from_code = self.query_data['from'] - to_code = self.query_data['to'] - - # Use Argos if Language Package Exists - if Path(f"{PATH}/translator/argos_models/translate-{from_code}_{to_code}.argosmodel").exists(): - translatedText = argostranslate.translate.translate(text, from_code, to_code) - return translatedText - # Use Marian if Language Package Exists in Marian but not Argos - elif Path(f"{PATH}/translator/marian_models/opus-mt-{from_code}-{to_code}").exists(): - marian = MarianModel(from_code, to_code) - translatedText = marian.translate([text]) - return translatedText[0] - # Use Argos "English in the Middle" if not in Argos and Marian by Default - elif (Path(f"{PATH}/translator/argos_models/translate-{from_code}_en.argosmodel").exists() and \ - Path(f"{PATH}/translator/argos_models/translate-{to_code}_en.argosmodel").exists()) or \ - (Path(f"{PATH}/translator/argos_models/translate-en_{from_code}.argosmodel").exists() and \ - Path(f"{PATH}/translator/argos_models/translate-en_{to_code}.argosmodel").exists()): - translatedText = argostranslate.translate.translate(text, from_code, to_code) - return translatedText - # If a package doesn't exist - else: - return "Translation Unavailable:" + from_code + to_code - - def do_GET(self): - self.send_response(200) - self.send_header("Content-Type", "application/json") - self.end_headers() - self.wfile.write(self.get_response().encode("utf-8")) - - def get_response(self): - return json.dumps( - { - "translate_data" : self.translate_data if self.query_data else "", - }, - ensure_ascii=False - ) - - -def port_open(port): - #connect_ex returns 0 if it connects to a socket meaning port is closed - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - return s.connect_ex(('localhost', port)) != 0 - -def start_server(port): - try: - server = HTTPServer(("127.0.0.1", port), WebRequestHandler) - server.timeout = TIMEOUT - server.handle_timeout = lambda: (_ for _ in ()).throw(TimeoutError()) - print(f"Serving at port: {port}", file=sys.stderr) - print(f"Server started at {time.strftime('%I:%M')} with timeout: {TIMEOUT} seconds", file=sys.stderr) - while(True): server.handle_request() - except TimeoutError: - print("Translation server timed out") - sys.exit() - -if __name__ == "__main__": - install_packages() - for port in PORTS: - if(port_open(port)): - start_server(port) - - +import os +import sys +import json +import argostranslate.package +import argostranslate.translate +from functools import cached_property +from http.server import BaseHTTPRequestHandler +from urllib.parse import parse_qsl, urlparse +from http.server import HTTPServer +from pathlib import Path +from transformers import MarianMTModel, MarianTokenizer +from typing import Sequence +from libargos import install_packages +import socket +import time + +PORTS = [8000, 5000, 8001, 8002, 8003, 8004, 8005, 8006, 8007, 8008] +TIMEOUT = int(sys.argv[1]) +PATH = os.getcwd() + +class MarianModel: + def __init__(self, source_lang: str, dest_lang: str) -> None: + path = f"{PATH}/translator/marian_models/opus-mt-{source_lang}-{dest_lang}" + self.model = MarianMTModel.from_pretrained(path, local_files_only = True) + self.tokenizer = MarianTokenizer.from_pretrained(path, local_files_only = True) + + def translate(self, texts: Sequence[str]) -> Sequence[str]: + tokens = self.tokenizer(list(texts), return_tensors="pt", padding=True) + translate_tokens = self.model.generate(**tokens) + return [self.tokenizer.decode(t, skip_special_tokens=True) for t in translate_tokens] + +class WebRequestHandler(BaseHTTPRequestHandler): + @cached_property + def url(self): + return urlparse(self.path) + + @cached_property + def query_data(self): + return dict(parse_qsl(self.url.query)) + + @cached_property + def translate_data(self): + #parse request + length = int(self.headers['Content-Length']) + tab_list = json.loads(self.rfile.read(length).decode('utf-8')) + from_code = self.query_data['from'] + to_code = self.query_data['to'] + + #for each tab in tab_list, build delimiter string to translate + num_tabs = len(tab_list) + translate_string = "" + for i in range(num_tabs): + if(i == 0): + translate_string = tab_list[i]["text"] + else: + if(tab_list[i]["text"] != ""): + translate_string = translate_string + " @ " + tab_list[i]["text"].replace("@", "at") + + #translate string and split into list on delimiter + translated_text = self.translate(translate_string, from_code, to_code) + translated_list = list(translated_text.split(" @ ")) + + #for each translation, place text in copy of tab_list to return + out = tab_list + for i in range(len(out)): + if(i < len(translated_list)): + out[i]["text"] = translated_list[i] + else: + out[i]["text"] = "" + return json.dumps(out) + + def translate(self, text, from_code, to_code): + # Use Argos if Language Package Exists + if Path(f"{PATH}/translator/argos_models/translate-{from_code}_{to_code}.argosmodel").exists(): + translatedText = argostranslate.translate.translate(text, from_code, to_code) + return translatedText + # Use Marian if Language Package Exists in Marian but not Argos + elif Path(f"{PATH}/translator/marian_models/opus-mt-{from_code}-{to_code}").exists(): + marian = MarianModel(from_code, to_code) + translatedText = marian.translate([text]) + return translatedText[0] + # Use Argos "English in the Middle" if not in Argos and Marian by Default + elif (Path(f"{PATH}/translator/argos_models/translate-{from_code}_en.argosmodel").exists() and \ + Path(f"{PATH}/translator/argos_models/translate-{to_code}_en.argosmodel").exists()) or \ + (Path(f"{PATH}/translator/argos_models/translate-en_{from_code}.argosmodel").exists() and \ + Path(f"{PATH}/translator/argos_models/translate-en_{to_code}.argosmodel").exists()): + translatedText = argostranslate.translate.translate(text, from_code, to_code) + return translatedText + # If a package doesn't exist + else: + return "Translation Unavailable" + + def get_response(self): + return json.dumps( + { + "translate_data" : self.translate_data if self.query_data else "", + }, + ensure_ascii=False + ) + + def do_GET(self): + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(self.get_response().encode("utf-8")) + + def do_POST(self): + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(self.get_response().encode("utf-8")) + +def port_open(port): + #connect_ex returns 0 if it connects to a socket meaning port is closed + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(('localhost', port)) != 0 + +def start_server(port): + try: + server = HTTPServer(("127.0.0.1", port), WebRequestHandler) + server.timeout = TIMEOUT + server.handle_timeout = lambda: (_ for _ in ()).throw(TimeoutError()) + print(f"Serving at port: {port}", file=sys.stderr) + print(f"Server started at {time.strftime('%I:%M')} with timeout: {TIMEOUT} seconds", file=sys.stderr) + while(True): server.handle_request() + except TimeoutError: + print("Translation server timed out") + sys.exit() + +if __name__ == "__main__": + install_packages() + for port in PORTS: + if(port_open(port)): + start_server(port) + \ No newline at end of file