Standardize data_location param
- Update debug_progress_bar from data_source
- Apply CLI corrections
- Fix error throwing and handling
- Write results directly in file
panos-span committed Jun 9, 2024
1 parent 41bac53 commit 03b0beb
Showing 4 changed files with 49 additions and 88 deletions.
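In API terms, the commit replaces the loosely used data_source parameter with a uniformly named data_location wherever a file path or URL is meant. A before/after sketch of the download call site in __main__.py (argument values are illustrative, not from the commit):

    # Before: the output path travelled positionally as "data_source".
    data_source_instance.download(args.database, args.sql_query, args.data_source)

    # After: explicit keywords with the standardized name.
    data_source_instance.download(
        database=args.database,
        sql_query=args.sql_query,
        data_location=args.data_location,
    )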
src/alexandria3k/__main__.py (14 additions, 31 deletions)

@@ -104,21 +104,18 @@ def download(args):
"""Download data using the specified data source."""
args.validate_args(args)
data_source_instance = get_data_source_instance(args)
if not hasattr(data_source_instance, "download"):
raise Alexandria3kError(
f"The data source {args.data_name} does not support downloading"
)
data_source_instance.download(
args.database, args.sql_query, args.data_source, *args.extra_args
datbase=args.database,
sql_query=args.sql_query,
data_location=args.data_location,
*args.extra_args,
)
perf.log(f"Data downloaded and saved to {args.data_source}")
perf.log(f"Data downloaded and saved to {args.data_location}")


def validate_args(args):
"""Validate that both database and sql_query are either both provided or both omitted."""
if (args.database and not args.sql_query) or (
args.sql_query and not args.database
):
if bool(args.database) != bool(args.sql_query):
raise argparse.ArgumentTypeError(
"Both --database and --sql-query must be provided together or not at all."
)
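The rewritten condition is the usual XOR idiom for "both or neither": bool(a) != bool(b) is true exactly when one of the two values is set without the other. A minimal standalone illustration (the helper name is hypothetical):

    def invalid_combination(database, sql_query):
        # True when exactly one of the two is provided, i.e. the rejected case.
        return bool(database) != bool(sql_query)

    assert invalid_combination("works.db", None)
    assert not invalid_combination(None, None)
    assert not invalid_combination("works.db", "SELECT issn FROM works")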
@@ -146,7 +143,7 @@ def add_subcommand_download(subparsers):
         help="SQL query to retrieve the data for downloading",
     )
     parser.add_argument(
-        "data_source",
+        "data_location",
         type=str,
         nargs=1,
         help="File path to save the downloaded data",
@@ -165,16 +162,6 @@ def populate(args):
     """Populate the specified database from the specified data source."""

     data_source_instance = get_data_source_instance(args)
-    if hasattr(data_source_instance, "download"):
-        # Check if the ouput_path attribute is not None
-        if (
-            data_source_instance.data_source is None
-            and args.data_source is None
-        ):
-            raise Alexandria3kError(
-                "Data Source is not set. Please ensure the download"
-                "method has been called and data_source is set."
-            )

     if args.row_selection_file:
         with open(args.row_selection_file, encoding="utf-8") as file:
@@ -189,11 +176,11 @@
     populate_args = [args.database, args.columns, args.row_selection]

     # Check if the method accepts the input_file_path parameter
-    if "data_source" in parameters and args.data_source:
-        populate_args.append(args.data_source)
-    elif "data_source" not in parameters and args.data_source:
+    if "data_location" in parameters and args.data_location:
+        populate_args.append(args.data_location)
+    elif "data_location" not in parameters and args.data_location:
         raise Alexandria3kError(
-            f"Method {populate_method} does not accept the data_source parameter."
+            f"Method {populate_method} does not accept the data_location parameter."
             f"for the data source {args.data_name}"
         )
@@ -216,14 +203,10 @@ def add_subcommand_populate(subparsers):
         help="Name of the data source to use",
     )
     parser.add_argument(
-        "data_location", nargs="?", help="Path or URL of the source's data"
-    )
-    # Add optional argument input_file_path
-    parser.add_argument(
-        "data_source",
-        type=str,
-        nargs="?",
-        help="File path to the downloaded data",
+        "data_location",
+        nargs="?",
+        type=str,
+        help="Path or URL of the source's data",
     )
     parser.add_argument(
         "-a",
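The populate() hunk above tests membership in a parameters collection that is built outside the visible context, presumably by inspecting the data source's populate method. A minimal sketch of that dispatch pattern, assuming inspect.signature is the mechanism (the helper name is hypothetical):

    import inspect

    def build_populate_args(populate_method, database, columns, row_selection,
                            data_location=None):
        # Forward data_location only when the target populate() declares it.
        parameters = inspect.signature(populate_method).parameters
        populate_args = [database, columns, row_selection]
        if "data_location" in parameters and data_location:
            populate_args.append(data_location)
        elif "data_location" not in parameters and data_location:
            raise TypeError("this populate() does not accept data_location")
        return populate_args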
src/alexandria3k/data_source.py (7 additions, 6 deletions)

@@ -282,10 +282,11 @@ def __init__(self, table, get_file_cache):
         super().__init__(table)
         self.get_file_cache = get_file_cache

-    def debug_progress_bar(self):
+    def debug_progress_bar(self, current_progress=None, total_length=None):
         """Print a progress bar"""
-        total_length = len(self.table.data_source)
-        current_progress = self.file_index + 1
+        if current_progress is None:
+            total_length = len(self.table.data_source)
+            current_progress = self.file_index + 1

         percent = current_progress / total_length * 100
         progress_marker = int(
@@ -605,15 +606,15 @@ def get_query_column_names(self):
         """Return the column names associated with an executing query"""
         return [description[0] for description in self.cursor.description]

-    def download(self, database, data_source, sql_query=None):
+    def download(self, database, data_location, sql_query=None):
         """
         Download the data source to the specified SQLite database.

         :param database: The path specifying the SQLite database to populate.
         :type database: str

-        :param data_source: The data source to use for populating the database.
-        :type data_source: object
+        :param data_location: The data location to use for populating the database.
+        :type data_location: object

         :param sql_query: An SQL `SELECT` query specifying the required data,
             defaults to `None`.
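With the new optional parameters, debug_progress_bar() keeps its old cursor-driven behaviour (deriving progress from self.file_index) while also accepting an explicit counter, which is how issn_subject_codes.py reuses it below. A trimmed sketch of the shared formatting arithmetic; the constant's value is assumed here, the real PROGRESS_BAR_LENGTH lives in data_source.py:

    PROGRESS_BAR_LENGTH = 50  # assumed value, for illustration only

    def render_progress(current_progress, total_length):
        # Same arithmetic as debug_progress_bar: a percentage plus a #/- bar.
        percent = current_progress / total_length * 100
        marker = int(PROGRESS_BAR_LENGTH * current_progress / total_length)
        bar = "#" * marker + "-" * (PROGRESS_BAR_LENGTH - marker)
        return f"[{bar}] {percent:.2f}%"

    print(render_progress(3, 12))  # prints a quarter-filled bar with "25.00%"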
src/alexandria3k/data_sources/issn_subject_codes.py (27 additions, 50 deletions)

@@ -32,7 +32,9 @@
 from alexandria3k.db_schema import ColumnMeta, TableMeta
 from alexandria3k import perf, debug
 from alexandria3k.common import ensure_table_exists, get_string_resource
-from alexandria3k.data_source import PROGRESS_BAR_LENGTH
+from alexandria3k.common import Alexandria3kError
+from alexandria3k.data_source import FilesCursor
+from alexandria3k.common import warn

 issn_subject_codes_table = TableMeta(
     "issn_subject_codes",
@@ -90,10 +92,6 @@ def get_config_path(config_path):
     if config_path:
         return config_path
     config_path = os.path.expanduser("~/.config/pybliometrics.cfg")
-    if not os.path.exists(config_path):
-        raise FileNotFoundError(
-            f"Configuration file not found at {config_path}"
-        )
     debug.log("config-file", f"Using config file at {config_path}")
     return config_path

@@ -114,76 +112,55 @@ def execute_sql_query(self, cursor, script):
         issns = [row[0] for row in cursor.fetchall()]
         return issns

-    def debug_progress_bar(self, current_progress, total_length):
-        """Print a progress bar"""
-        percent = current_progress / total_length * 100
-        progress_marker = int(
-            PROGRESS_BAR_LENGTH * current_progress / total_length
-        )
-        progress_bar = "#" * progress_marker + "-" * (
-            PROGRESS_BAR_LENGTH - progress_marker
-        )
-        debug.log(
-            "progress_bar",
-            f"\r[{progress_bar}] {percent:.2f}% | "
-            f"Downloaded {current_progress} out of {total_length} ISSNs",
-            end="",
-        )
-
     def fetch_subject_codes(self, writer, issns):
         """Fetch the subject codes for the specified ISSNs."""
         total_issns = len(issns)
         for index, issn in enumerate(issns):
-            self.debug_progress_bar(index + 1, total_issns)
+            FilesCursor.debug_progress_bar(
+                self, current_progress=index + 1, total_length=total_issns
+            )
             query = {"issn": issn}
             try:
                 serial_search = SerialSearch(query=query, view="STANDARD")
                 results = list(serial_search.results)

-                subject_area_codes_set = set()
-
                 for result in results:
-                    if (
-                        "subject_area_codes" in result
-                        and result["subject_area_codes"]
-                    ):
-                        subject_area_codes = result[
-                            "subject_area_codes"
-                        ].split(";")
-                        subject_area_codes_set.update(subject_area_codes)
-
-                for code in subject_area_codes_set:
-                    writer.writerow([issn, code])
-
-            except (KeyError, ValueError) as e:
-                print(f"Error processing ISSN {issn}: {e}")
+                    for code in (result.get("subject_area_codes")).split(";"):
+                        writer.writerow([issn, code])

-    def download(self, database, data_source, sql_query=None):
+            except (KeyError, ValueError):
+                warn(f"Error processing ISSN {issn}")
+
+    def download(self, database, data_location, sql_query=None):
         """
         Create a CSV file with ISSNs and their corresponding ASJC subject codes from API calls.

         :param database: The path specifying the SQLite database to use in order to fetch the ISSNs.
         :type database: str

-        :param data_source: The path specifying the CSV file to use for the population.
-        :type data_source: str
+        :param data_location: The path specifying the CSV file to use for the population.
+        :type data_location: str

         :param sql_query: The SQL query to use in order to fetch the ISSNs, defaults to `None`.
             The default query fetches all unique ISSNs from the database. The query should
             return a single column with the name `issn`.
         :type sql_query: str, optional
         """
-        data_source = os.path.join(data_source, "issn_subject_codes.csv")
-
         try:
             pybliometrics.scopus.init()
         except RuntimeError as e:
-            raise e
+            raise Alexandria3kError(
+                "Error in downloading data with pybliometrics"
+            ) from e
         connection = sqlite3.connect(database)
         ensure_table_exists(connection, "works")
         cursor = connection.cursor()
         script = sql_query or get_string_resource("sql/get-issns.sql")
         issns = self.execute_sql_query(cursor, script)

-        with open(data_source, mode="w", newline="", encoding="utf-8") as file:
+        with open(
+            data_location, mode="w", newline="", encoding="utf-8"
+        ) as file:
             writer = csv.writer(file)
             writer.writerow(["issn", "subject_code"])
             self.fetch_subject_codes(writer, issns)
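The tightened loop writes one CSV row per semicolon-separated code and assumes subject_area_codes is present on every result; with plain dict results, result.get(...) returns None for a missing field and .split then raises AttributeError, which the except clause above does not cover. A small standalone sketch of the parsing with that guard added (the sample payload is hypothetical):

    # Hypothetical SerialSearch-style results.
    results = [{"subject_area_codes": "1700;2600"}, {"title": "entry without codes"}]

    for result in results:
        codes = result.get("subject_area_codes")
        if codes:  # guard against a missing or empty field
            for code in codes.split(";"):
                print(code)  # the real code emits [issn, code] rows via csv.writer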
@@ -192,7 +169,7 @@ def download(self, database, data_source, sql_query=None):
         perf.log("create_csv_from_api")

     def populate(
-        self, database_path, columns=None, condition=None, data_source=None
+        self, database_path, columns=None, condition=None, data_location=None
     ):
         """
         Populate the SQLite database with data from the CSV file.
@@ -207,13 +184,13 @@ def populate(
             population, defaults to `None`.
         :type condition: str, optional

-        :param data_source: The path specifying the CSV file to use for the population,
+        :param data_location: The path specifying the CSV file to use for the population,
             defaults to `None`.
-        :type data_source: str, optional
+        :type data_location: str, optional
         """
         # Update the data source to use the CSV file
         new_data_source = VTSource(
-            issn_subject_codes_table, data_source, self.sample
+            issn_subject_codes_table, data_location, self.sample
         )
         self.data_source = new_data_source
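After these changes the data source exposes a two-step flow: download() materializes the ISSN-to-code CSV at data_location, and populate() later ingests the same file. A hedged usage sketch; the class name, constructor, and paths are assumptions, since none of them appear in this diff:

    from alexandria3k.data_sources import issn_subject_codes

    ds = issn_subject_codes.IssnSubjectCodes()  # hypothetical constructor
    ds.download(database="works.db", data_location="issn_subject_codes.csv")
    ds.populate("annotated.db", data_location="issn_subject_codes.csv")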

tests/data_sources/test_issn_subject_codes.py (1 addition, 1 deletion)

@@ -77,7 +77,7 @@ def mock_serial_search_results(self, query, view):

         with patch.object(SerialSearch, '__init__', mock_serial_search_init), \
              patch.object(SerialSearch, 'results', new_callable=MagicMock, return_value=mock_serial_search_results):
-            cls.issn_subject_codes.populate(DATABASE_PATH, data_source=INPUT_FILE_PATH)
+            cls.issn_subject_codes.populate(DATABASE_PATH, data_location=INPUT_FILE_PATH)

     @classmethod
     def tearDownClass(cls):
def tearDownClass(cls):
Expand Down
