Commit: Checkpoint
dkling-reply committed Dec 7, 2023
1 parent b4b5a20 commit 13af693
Showing 46 changed files with 783 additions and 358 deletions.
7 changes: 4 additions & 3 deletions README.md
@@ -227,7 +227,8 @@ Black reformats the code, isort orders the imports and flake8 checks for remaining
Example usage:

```bash
-isort -rc -sl .
-autoflake --remove-all-unused-imports -i -r --exclude alembic .
-isort -rc -m 3 .
+isort --force-single-line-imports .
+autoflake --remove-all-unused-imports -i -r --exclude ./alembic .
+# Note: '3' means 3-vert-hanging multiline imports
+isort --multi-line 3 .
```
17 changes: 14 additions & 3 deletions charging_stations_pipelines/models/__init__.py
@@ -1,4 +1,5 @@
-from typing import Any
+"""This module contains the base class for all models."""
+from typing import Any, Optional

from sqlalchemy import MetaData
from sqlalchemy.ext.declarative import declarative_base
@@ -7,6 +8,11 @@


class BaseWithSafeSetProperty:
"""This class is a base class for all models that prevents accidental usage of non-existing class attributes."""

def __init__(self):
self.id: Optional[int] = None

def __setattr__(self, name: str, value: Any) -> None:
"""This method sets the value of the specified attribute and prevents accidental usage of non-existing class
attributes.
@@ -18,11 +24,16 @@ def __setattr__(self, name: str, value: Any) -> None:
:return: None.
"""
        if not (name.startswith("_") or hasattr(self, name)):
-            raise AttributeError(f"Cannot set non-existing attribute '{name}' on class '{self.__class__.__name__}'.")
+            raise AttributeError(
+                f"Cannot set non-existing attribute '{name}' on class '{self.__class__.__name__}'."
+            )
        super().__setattr__(name, value)

+    def __repr__(self):
+        return f"<{self.__class__.__name__} with id: {self.id}>"


-Base = declarative_base(cls=BaseWithSafeSetProperty, metadata=MetaData(schema=settings.db_schema))
+Base = declarative_base(
+    cls=BaseWithSafeSetProperty, metadata=MetaData(schema=settings.db_schema)
+)
"""The base class for all models."""
4 changes: 2 additions & 2 deletions charging_stations_pipelines/pipelines/__init__.py
@@ -12,14 +12,14 @@
class Pipeline:
"""Base class for data processing pipelines."""

-    def __init__(self, config: configparser, session: Session, online=False):
+    def __init__(self, config: configparser, session: Optional[Session], online=False):
self.config = config
self.session = session
self.online = online

self.data: Optional[Union[pd.DataFrame, JSON]] = None

-    def _retrieve_data(self):
+    def retrieve_data(self):
"""Retrieves the data from the data source."""
raise NotImplementedError

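
For orientation, a subclass is expected to fill `self.data` in `retrieve_data` (public as of this commit; previously `_retrieve_data`) and drive the mapping in `run`. The following is a hypothetical sketch; `CsvPipeline`, the CSV path, and the mapping stub are invented for illustration:

```python
import configparser
from typing import Optional

import pandas as pd
from sqlalchemy.orm import Session

from charging_stations_pipelines.pipelines import Pipeline


class CsvPipeline(Pipeline):
    """Illustrative pipeline that reads its data from a local CSV file."""

    def __init__(self, config: configparser.ConfigParser, session: Optional[Session], path: str):
        super().__init__(config, session, online=False)
        self._path = path

    def retrieve_data(self):
        # Fills self.data, mirroring what the real pipelines do.
        self.data = pd.read_csv(self._path)

    def run(self):
        self.retrieve_data()
        for _, row in self.data.iterrows():
            ...  # map rows to Station/Address/Charging and write them via self.session
```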
7 changes: 5 additions & 2 deletions charging_stations_pipelines/pipelines/at/__init__.py
@@ -1,5 +1,8 @@
"""The E-Control data source provides information about charging stations in Austria."""

"""The AT package contains the pipelines for the Austrian data source."""
from typing import Final

DATA_SOURCE_KEY: Final[str] = 'AT_ECONTROL'
"""The data source key for the e-control data source."""

+SCOPE_COUNTRIES: Final[list[str]] = ['AT']
+"""The list of country codes covered by the e-control data source."""
88 changes: 69 additions & 19 deletions charging_stations_pipelines/pipelines/at/econtrol.py
@@ -4,7 +4,7 @@
- 'Address': Contains address information for the charging station.
- 'Charging': Provides the charging details for a particular charging station.
"""

+import collections
import configparser
import logging
import os
@@ -15,10 +15,16 @@
from tqdm import tqdm

from charging_stations_pipelines.pipelines import Pipeline
-from charging_stations_pipelines.pipelines.at import DATA_SOURCE_KEY
+from charging_stations_pipelines.pipelines.at import DATA_SOURCE_KEY, SCOPE_COUNTRIES
from charging_stations_pipelines.pipelines.at.econtrol_crawler import get_data
-from charging_stations_pipelines.pipelines.at.econtrol_mapper import map_address, map_charging, map_station
-from charging_stations_pipelines.pipelines.station_table_updater import StationTableUpdater
+from charging_stations_pipelines.pipelines.at.econtrol_mapper import (
+    map_address,
+    map_charging,
+    map_station,
+)
+from charging_stations_pipelines.pipelines.station_table_updater import (
+    StationTableUpdater,
+)

logger = logging.getLogger(__name__)

@@ -41,11 +47,15 @@ def __init__(self, config: configparser, session: Session, online: bool = False):
super().__init__(config, session, online)

relative_dir = os.path.join("../../..", "data")
-        self.data_dir = os.path.join(pathlib.Path(__file__).parent.resolve(), relative_dir)
+        self.data_dir = os.path.join(
+            pathlib.Path(__file__).parent.resolve(), relative_dir
+        )

def _retrieve_data(self):
pathlib.Path(self.data_dir).mkdir(parents=True, exist_ok=True)
-        tmp_data_path = os.path.join(self.data_dir, self.config[DATA_SOURCE_KEY]["filename"])
+        tmp_data_path = os.path.join(
+            self.data_dir, self.config[DATA_SOURCE_KEY]["filename"]
+        )
if self.online:
logger.info("Retrieving Online Data")
get_data(tmp_data_path)
@@ -54,33 +64,73 @@ def _retrieve_data(self):
self.data = pd.read_json(tmp_data_path, lines=True) # pd.DataFrame

def run(self):
""" Runs the pipeline for a data source.
"""Runs the pipeline for a data source.
Retrieves data, processes it, and updates the Station table (as well as Address and Charging tables).
:return: None
"""
logger.info(f"Running {DATA_SOURCE_KEY} Pipeline...")
self._retrieve_data()
station_updater = StationTableUpdater(session=self.session, logger=logger)
-        count_imported_stations, count_empty_stations, count_invalid_stations = 0, 0, 0
-        for _, datapoint in tqdm(iterable=self.data.iterrows(), total=self.data.shape[0]):  # type: _, pd.Series
+
+        stats = collections.defaultdict(int)
+        datapoint: pd.Series
+        for _, datapoint in tqdm(
+            iterable=self.data.iterrows(), total=self.data.shape[0]
+        ):
try:
station = map_station(datapoint)
if not station or not station.source_id or not station.point:
count_empty_stations += 1
+                # Filter out stations with country codes that are not in the scope of the pipeline
+                if station.country_code not in SCOPE_COUNTRIES:
+                    stats['count_country_mismatch_stations'] += 1
+                    logger.debug(
+                        f"Skipping {DATA_SOURCE_KEY} entry due to invalid country code in Station:"
+                        f" {station.country_code}.\n"
+                        f"Row:\n----\n{datapoint}\n----\n"
+                    )
+                    continue
+
+                # Address mapping
                station.address = map_address(datapoint, None)
+                # Filter out stations which have an invalid address
+                if (
+                    station.address
+                    and station.address.country
+                    and station.address.country not in SCOPE_COUNTRIES
+                ):
+                    stats['count_country_mismatch_stations'] += 1
+                    logger.debug(
+                        f"Skipping {DATA_SOURCE_KEY} entry due to invalid country code in Address: "
+                        f"{station.address.country}.\n"
+                        f"Row:\n----\n{datapoint}\n----\n"
+                    )
+                    continue
+
+                # Filter out stations which have a mismatching country code between Station and Address
+                if station.country_code != station.address.country:
+                    stats['count_country_mismatch_stations'] += 1
+                    logger.debug(
+                        f"Skipping {DATA_SOURCE_KEY} entry due to "
+                        f"mismatching country codes between Station and Address: "
+                        f"{station.country_code} != {station.address.country}.\n"
+                        f"Row:\n----\n{datapoint}\n----\n"
+                    )
+                    continue
+
                station.charging = map_charging(datapoint, None)

-                count_imported_stations += 1
+                stats['count_valid_stations'] += 1
+                station_updater.update_station(station, DATA_SOURCE_KEY)
except Exception as e:
-                count_invalid_stations += 1
+                stats['count_parse_error'] += 1
                logger.debug(
-                    f"{DATA_SOURCE_KEY} entry could not be mapped! Error:\n{e}\nRow:\n----\n{datapoint}\n----\n")
-                continue
-            station_updater.update_station(station, DATA_SOURCE_KEY)
-        logger.info(f"Finished {DATA_SOURCE_KEY} Pipeline, "
-                    f"new stations imported: {count_imported_stations}, empty stations: {count_empty_stations}, "
-                    f"stations which could not be parsed: {count_invalid_stations}.")
+                    f"{DATA_SOURCE_KEY} entry could not be parsed, error:\n{e}\n"
+                    f"Row:\n----\n{datapoint}\n----\n"
+                )
+        logger.info(
+            f"Finished {DATA_SOURCE_KEY} Pipeline:\n"
+            f"1. New stations imported: {stats['count_valid_stations']}\n"
+            f"2. Not parseable: {stats['count_parse_error']}\n"
+            f"3. Wrong country code stations: {stats['count_country_mismatch_stations']}."
+        )
station_updater.log_update_station_counts()
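
The replacement of three hand-maintained counters with `collections.defaultdict(int)` means new outcome categories can be counted without declaring them first; unseen keys start at zero. A toy illustration with made-up outcome names:

```python
import collections

stats = collections.defaultdict(int)
for outcome in ["valid", "valid", "country_mismatch", "parse_error"]:
    stats[f"count_{outcome}_stations"] += 1  # no initialization needed

print(dict(stats))
# {'count_valid_stations': 2, 'count_country_mismatch_stations': 1, 'count_parse_error_stations': 1}
```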
10 changes: 5 additions & 5 deletions charging_stations_pipelines/pipelines/at/econtrol_crawler.py
@@ -7,7 +7,7 @@

import requests

-from charging_stations_pipelines.pipelines.at import DATA_SOURCE_KEY
+from charging_stations_pipelines.pipelines import at

logger = logging.getLogger(__name__)

@@ -37,9 +37,9 @@ def _get_paginated_stations(url: str, headers: dict[str, str] = None) -> Generator

if total_count <= page_size:
# No pagination needed
-        return
+        yield

-    num_pages = total_count // page_size + total_count % page_size
+    num_pages = total_count // page_size + (1 if total_count % page_size else 0)
for page_num in range(2, num_pages + 1):
idx_start = page_size * (page_num - 1)
idx_end = min(page_size * page_num - 1, total_count - 1)
@@ -66,7 +66,7 @@ def get_data(tmp_data_path):
headers = {'Authorization': f"Basic {os.getenv('ECONTROL_AT_AUTH')}", 'User-Agent': 'Mozilla/5.0'}
logger.debug(f'Using HTTP headers:\n{headers}')

logger.info(f"Downloading {DATA_SOURCE_KEY} data from {url}...")
logger.info(f"Downloading {at.DATA_SOURCE_KEY} data from {url}...")
with open(tmp_data_path, 'w') as f:
for page in _get_paginated_stations(url, headers):
logger.debug(f"Getting data: {page['fromIndex']}..{page['endIndex']}")
Expand All @@ -75,5 +75,5 @@ def get_data(tmp_data_path):
for station in page['stations']:
json.dump(station, f, ensure_ascii=False)
f.write('\n')
logger.info(f"Downloaded {DATA_SOURCE_KEY} data to: {tmp_data_path}")
logger.info(f"Downloaded {at.DATA_SOURCE_KEY} data to: {tmp_data_path}")
logger.info(f"Downloaded file size: {os.path.getsize(tmp_data_path)} bytes")
9 changes: 4 additions & 5 deletions charging_stations_pipelines/pipelines/at/econtrol_mapper.py
@@ -11,7 +11,7 @@
from charging_stations_pipelines.models.address import Address
from charging_stations_pipelines.models.charging import Charging
from charging_stations_pipelines.models.station import Station
-from charging_stations_pipelines.pipelines.at import DATA_SOURCE_KEY
+from charging_stations_pipelines.pipelines import at
from charging_stations_pipelines.shared import check_coordinates, lst_expand, lst_flatten, \
str_strip_whitespace, str_to_float, try_remove_dupes

@@ -23,7 +23,7 @@
def _aggregate_attribute(points: pd.Series, attr: str) -> list[list[T]]:
attr_list_agg: Final[list[list[T]]] = []
for p in points:
-        attr_vals: list = p.get(attr, [])
+        attr_vals: list[str] = p.get(attr, [])
attr_list_agg.append(attr_vals)

return attr_list_agg
@@ -75,6 +75,7 @@ def map_station(row: pd.Series) -> Station:
station = Station()

station.country_code = country_id # should be always 'AT'

# Using combination of evseCountryId, evseStationId and evseOperatorId as source_id,
# since evseStationId alone is not unique enough
station.source_id = f'{country_id}*{operator_id}*{station_id}'
@@ -85,7 +86,7 @@
station.operator = str_strip_whitespace(row.get('contactName')) or None
station.payment = None
station.authentication = _extract_auth_modes(row.get("points")) or None
-    station.data_source = DATA_SOURCE_KEY
+    station.data_source = at.DATA_SOURCE_KEY
station.point = _extract_location(row.get("location"))
station.raw_data = row.to_json() # Note: is stored in DB as native type 'JSON'
station.date_created = None # Note: is not available in the data source
@@ -108,8 +109,6 @@ def map_address(row: pd.Series, station_id: Optional[int]) -> Address:
address.town = str_strip_whitespace(row.get("city")) or None
address.postcode = str_strip_whitespace(row.get("postCode")) or None
    address.district = None # Note: is not available in the data source
-    address.district = None # Note: is not available in the data source
    address.state = None # Note: is not available in the data source
-    address.state = None # Note: is not available in the data source
address.country = str_strip_whitespace(row.get("evseCountryId")) or None

Expand Down
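
To make `_aggregate_attribute` concrete: it collects one list per charging point and falls back to an empty list when the attribute is missing. A small sketch follows; the `points` payload and the `authenticationModes` key are invented for illustration:

```python
import pandas as pd

points = pd.Series([
    {"authenticationModes": ["RFID", "APP"]},
    {"authenticationModes": ["RFID"]},
    {},  # a point without the attribute contributes an empty list
])

def aggregate_attribute(points: pd.Series, attr: str) -> list[list[str]]:
    # One list per charging point, preserving order.
    return [p.get(attr, []) for p in points]

assert aggregate_attribute(points, "authenticationModes") == [
    ["RFID", "APP"],
    ["RFID"],
    [],
]
```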
3 changes: 2 additions & 1 deletion charging_stations_pipelines/pipelines/de/__init__.py
@@ -2,4 +2,5 @@

from typing import Final

-DATA_SOURCE_KEY: Final[str] = 'BNA'
+DATA_SOURCE_KEY: Final[str] = "BNA"
+"""The data source key for the BNA (Bundesnetzagentur) data source."""
33 changes: 24 additions & 9 deletions charging_stations_pipelines/pipelines/de/bna.py
@@ -5,14 +5,21 @@
import os
import pathlib

+import pandas as pd
from sqlalchemy.orm import Session
from tqdm import tqdm

+import charging_stations_pipelines.pipelines.de as de
from charging_stations_pipelines.pipelines import Pipeline
-from charging_stations_pipelines.pipelines.de import DATA_SOURCE_KEY
from charging_stations_pipelines.pipelines.de.bna_crawler import get_bna_data
-from charging_stations_pipelines.pipelines.de.bna_mapper import map_address_bna, map_charging_bna, map_station_bna
-from charging_stations_pipelines.pipelines.station_table_updater import StationTableUpdater
+from charging_stations_pipelines.pipelines.de.bna_mapper import (
+    map_address_bna,
+    map_charging_bna,
+    map_station_bna,
+)
+from charging_stations_pipelines.pipelines.station_table_updater import (
+    StationTableUpdater,
+)
from charging_stations_pipelines.shared import load_excel_file

logger = logging.getLogger(__name__)
@@ -23,29 +30,37 @@ def __init__(self, config: configparser, session: Session, online: bool = False):
super().__init__(config, session, online)

relative_dir = os.path.join("../../..", "data")
-        self.data_dir = os.path.join(pathlib.Path(__file__).parent.resolve(), relative_dir)
+        self.data_dir = os.path.join(
+            pathlib.Path(__file__).parent.resolve(), relative_dir
+        )

-    def _retrieve_data(self):
+    def retrieve_data(self):
pathlib.Path(self.data_dir).mkdir(parents=True, exist_ok=True)
-        tmp_data_path = os.path.join(self.data_dir, self.config[DATA_SOURCE_KEY]["filename"])
+        tmp_data_path = os.path.join(
+            self.data_dir, self.config[de.DATA_SOURCE_KEY]["filename"]
+        )
if self.online:
logger.info("Retrieving Online Data")
get_bna_data(tmp_data_path)
self.data = load_excel_file(tmp_data_path)

def run(self):
logger.info("Running DE GOV Pipeline...")
-        self._retrieve_data()
+        self.retrieve_data()
station_updater = StationTableUpdater(session=self.session, logger=logger)

+        row: pd.Series
for _, row in tqdm(iterable=self.data.iterrows(), total=self.data.shape[0]):
try:
mapped_address = map_address_bna(row, None)
mapped_charging = map_charging_bna(row, None)

mapped_station = map_station_bna(row)

mapped_station.address = mapped_address
mapped_station.charging = mapped_charging
except Exception as e:
-                logger.error(f"{DATA_SOURCE_KEY} entry could not be mapped! Error: {e}")
+                logger.error(f"{de.DATA_SOURCE_KEY} entry could not be mapped! Error: {e}")
continue
-            station_updater.update_station(mapped_station, DATA_SOURCE_KEY)
+            station_updater.update_station(mapped_station, de.DATA_SOURCE_KEY)
station_updater.log_update_station_counts()
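
End to end, a pipeline of this shape is driven roughly as follows. This is a hedged sketch: the `BnaPipeline` class name, the config file, and the database DSN are assumptions, and the `[BNA]` section with a `filename` key mirrors the config access visible above:

```python
import configparser

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from charging_stations_pipelines.pipelines.de.bna import BnaPipeline  # class name assumed

config = configparser.ConfigParser()
config.read("config.ini")  # expected to contain a [BNA] section with a 'filename' key

engine = create_engine("postgresql+psycopg2://user:pass@localhost/stations")  # illustrative DSN
session = sessionmaker(bind=engine)()

pipeline = BnaPipeline(config, session, online=True)  # online=True triggers the download
pipeline.run()
```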