diff --git a/API/api_worker.py b/API/api_worker.py index 9a8ba6a6..162d47be 100644 --- a/API/api_worker.py +++ b/API/api_worker.py @@ -137,6 +137,32 @@ def zip_binding( return upload_file_path, inside_file_size +class BaseclassTask(celery.Task): + """Base class for celery tasks + + Args: + celery (_type_): _description_ + """ + + def on_failure(self, exc, task_id, args, kwargs, einfo): + """Logic when task fails + + Args: + exc (_type_): _description_ + task_id (_type_): _description_ + args (_type_): _description_ + kwargs (_type_): _description_ + einfo (_type_): _description_ + """ + # exc (Exception) - The exception raised by the task. + # args (Tuple) - Original arguments for the task that failed. + # kwargs (Dict) - Original keyword arguments for the task that failed. + print("{0!r} failed: {1!r}".format(task_id, exc)) + clean_dir = os.path.join(EXPORT_PATH, task_id) + if os.path.exists(clean_dir): + shutil.rmtree(clean_dir) + + @celery.task( bind=True, name="process_raw_data", @@ -165,7 +191,7 @@ def process_raw_data(self, params, user=None): params.use_st_within = True params.file_name = ( - format_file_name_str(params.file_name) if params.file_name else "Export" + format_file_name_str(params.file_name) if params.file_name else "RawExport" ) exportname = f"{params.file_name}_{params.output_type}{f'_uid_{str(self.request.id)}' if params.uuid else ''}" @@ -180,9 +206,9 @@ def process_raw_data(self, params, user=None): file_parts, ) - geom_area, geom_dump, working_dir = RawData(params).extract_current_data( - file_parts - ) + geom_area, geom_dump, working_dir = RawData( + params, str(self.request.id) + ).extract_current_data(file_parts) inside_file_size = 0 polygon_stats = None if "include_stats" in params.dict(): @@ -271,35 +297,11 @@ def process_raw_data(self, params, user=None): return final_response except Exception as ex: + if os.path.exists(os.path.join(EXPORT_PATH, str(self.request.id))): + shutil.rmtree(os.path.join(EXPORT_PATH, str(self.request.id))) raise ex -class BaseclassTask(celery.Task): - """Base class for celery tasks - - Args: - celery (_type_): _description_ - """ - - def on_failure(self, exc, task_id, args, kwargs, einfo): - """Logic when task fails - - Args: - exc (_type_): _description_ - task_id (_type_): _description_ - args (_type_): _description_ - kwargs (_type_): _description_ - einfo (_type_): _description_ - """ - # exc (Exception) - The exception raised by the task. - # args (Tuple) - Original arguments for the task that failed. - # kwargs (Dict) - Original keyword arguments for the task that failed. - print("{0!r} failed: {1!r}".format(task_id, exc)) - clean_dir = os.path.join(EXPORT_PATH, task_id) - if os.path.exists(clean_dir): - shutil.rmtree(clean_dir) - - @celery.task( bind=True, name="process_custom_request", diff --git a/src/app.py b/src/app.py index ece19b49..23c1f61e 100644 --- a/src/app.py +++ b/src/app.py @@ -449,16 +449,9 @@ class RawData: -Osm element type (Optional) """ - def __init__(self, parameters=None, dbdict=None): - if parameters: - # validation for the parameters if it is already validated with - # pydantic model or not , people coming from package they - # will not have api valdiation so to make sure they will be validated - # before accessing the class - # if isinstance(parameters, RawDataCurrentParams) is False: - # self.params = RawDataCurrentParams(**parameters) - # else: - self.params = parameters + def __init__(self, parameters, request_uid="raw-data-api", dbdict=None): + + self.params = parameters # only use connection pooling if it is configured in config file if use_connection_pooling: # if database credentials directly from class is not passed grab from pool @@ -471,6 +464,8 @@ def __init__(self, parameters=None, dbdict=None): self.d_b = Database(dict(dbdict)) self.con, self.cur = self.d_b.connect() + self.base_export_working_dir = os.path.join(export_path, request_uid) + @staticmethod def close_con(con): """Closes connection if exists""" @@ -728,7 +723,7 @@ def extract_current_data(self, exportname): ) = RawData.get_grid_id(self.params.geometry, self.cur) output_type = self.params.output_type # Check whether the export path exists or not - working_dir = os.path.join(export_path, exportname) + working_dir = os.path.join(self.base_export_working_dir, exportname) if not os.path.exists(working_dir): # Create a exports directory because it does not exist os.makedirs(working_dir) @@ -880,6 +875,16 @@ def get_osm_feature(self, osm_id): self.cur.close() return FeatureCollection(features=features) + def cleanup(self): + """ + Cleans up temporary resources. + """ + + if os.path.exists(self.base_export_working_dir): + shutil.rmtree(self.base_export_working_dir) + return True + return False + class S3FileTransfer: """Responsible for the file transfer to s3 from API maachine"""