Skip to content

Commit

Permalink
add train calendar and train detail by id methods
Browse files Browse the repository at this point in the history
  • Loading branch information
lzgirlcat committed Oct 13, 2024
1 parent cc48c58 commit bc32c82
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 5 deletions.
51 changes: 47 additions & 4 deletions koleo/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from .api import KoleoAPI
from .storage import DEFAULT_CONFIG_PATH, Storage
from .types import ExtendedBaseStationInfo, TrainDetailResponse, TrainOnStationInfo
from .types import ExtendedBaseStationInfo, TrainDetailResponse, TrainOnStationInfo, TrainCalendar
from .utils import RemainderString, arr_dep_to_dt, convert_platform_number, name_to_slug, parse_datetime


Expand Down Expand Up @@ -101,7 +101,7 @@ def find_station(self, query: str | None):
f"[bold blue][link=https://koleo.pl/dworzec-pkp/{st["name_slug"]}]{st["name"]}[/bold blue] ID: {st["id"]}[/link]"
)

def train_info(self, brand: str, name: str, date: datetime):
def get_train_calendars(self, brand: str, name: str) -> list[TrainCalendar]:
brand = brand.upper().strip()
name_parts = name.split(" ")
if len(name_parts) == 1 and name_parts[0].isnumeric():
Expand All @@ -126,21 +126,47 @@ def train_info(self, brand: str, name: str, date: datetime):
except self.client.errors.KoleoNotFound:
self.print(f'[bold red]Train not found: nr={number}, name="{train_name}"[/bold red]')
exit(2)
train_id = train_calendars["train_calendars"][0]["date_train_map"][date.strftime("%Y-%m-%d")]
return train_calendars["train_calendars"]

def train_calendar(self, brand: str, name: str):
train_calendars = self.get_train_calendars(brand, name)
brands = self.storage.get_cache("brands") or self.storage.set_cache("brands", self.client.get_brands())
for calendar in train_calendars:
brand = next(iter(i for i in brands if i["id"] == calendar["trainBrand"]), {}).get("logo_text", "")
parts = [f"[red]{brand}[/red] [bold blue]{calendar['train_nr']}{" "+ v if (v:=calendar.get("train_name")) else ""}[/bold blue]:"]
for k, v in calendar["date_train_map"].items():
parts.append(f" [bold green]{k}[/bold green]: [purple]{v}[/purple]")
self.print("\n".join(parts))

def train_info(self, brand: str, name: str, date: datetime):
train_calendars = self.get_train_calendars(brand, name)
if not (train_id:=train_calendars[0]["date_train_map"].get(date.strftime("%Y-%m-%d"))):
self.print(f"[bold red]This train doesn't run on the selected date: {date.strftime("%Y-%m-%d")}[/bold red]")
exit(2)
self.train_detail(train_id)

def train_detail(self, train_id: int):
train_details = self.client.get_train(train_id)
brands = self.storage.get_cache("brands") or self.storage.set_cache("brands", self.client.get_brands())
brand = next(iter(i for i in brands if i["id"] == train_details["train"]["brand_id"]), {}).get("logo_text", "")

parts = [f"[red]{brand}[/red] [bold blue]{train_details["train"]["train_full_name"]}[/bold blue]"]
parts.append(f" {train_details["train"]["run_desc"]}")

route_start = arr_dep_to_dt(train_details["stops"][0]["departure"])
route_end = arr_dep_to_dt(train_details["stops"][-1]["arrival"])

if route_end.hour < route_start.hour or (
route_end.hour == route_start.hour and route_end.minute < route_end.minute
):
route_end += timedelta(days=1)

travel_time = route_end - route_start
speed = train_details["stops"][-1]["distance"] / 1000 / travel_time.seconds * 3600
parts.append(
f"[white] {travel_time.seconds//3600}h{(travel_time.seconds % 3600)/60:.0f}m {speed:^4.1f}km/h [/white]"
)

vehicle_types: dict[str, str] = {
stop["station_display_name"]: stop["vehicle_type"]
for stop in train_details["stops"]
Expand Down Expand Up @@ -333,7 +359,7 @@ def main():
train_route = subparsers.add_parser(
"trainroute",
aliases=["r", "tr", "t", "poc", "pociąg"],
help="Allows you to show the train's route",
help="Allows you to check the train's route",
)
train_route.add_argument("brand", help="The brand name", type=str)
train_route.add_argument("name", help="The train name", nargs="+", action=RemainderString)
Expand All @@ -346,6 +372,23 @@ def main():
)
train_route.set_defaults(func=cli.train_info, pass_=["brand", "name", "date"])

train_calendar = subparsers.add_parser(
"traincalendar",
aliases=["kursowanie", "tc", "k"],
help="Allows you to check what days the train runs on",
)
train_calendar.add_argument("brand", help="The brand name", type=str)
train_calendar.add_argument("name", help="The train name", nargs="+", action=RemainderString)
train_calendar.set_defaults(func=cli.train_calendar, pass_=["brand", "name"])

train_detail = subparsers.add_parser(
"traindetail",
aliases=["td", "tid", "id", "idpoc"],
help="Allows you to show the train's route given it's koleo ID",
)
train_detail.add_argument("train_id", help="The koleo ID", type=int)
train_detail.set_defaults(func=cli.train_detail, pass_=["train_id"])

stations = subparsers.add_parser(
"stations", aliases=["s", "find", "f", "stacje", "ls", "q"], help="Allows you to find stations by their name"
)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def parse_requirements_file(path):

setuptools.setup(
name="koleo-cli",
version="0.2.137.8",
version="0.2.137.9",
description="Koleo CLI",
long_description=long_description(),
long_description_content_type="text/markdown",
Expand Down

0 comments on commit bc32c82

Please sign in to comment.