diff --git a/fastf1/core.py b/fastf1/core.py index eab4846c3..5699144c1 100644 --- a/fastf1/core.py +++ b/fastf1/core.py @@ -44,7 +44,7 @@ from functools import cached_property import warnings import typing -from typing import Optional, List, Literal, Iterable, Union, Tuple, Any +from typing import Optional, List, Literal, Iterable, Union, Tuple, Any, Type import numpy as np import pandas as pd @@ -52,6 +52,7 @@ import fastf1 from fastf1 import _api as api from fastf1 import ergast +from fastf1.internals.pandas_base import BaseDataFrame, BaseSeries from fastf1.livetiming.data import LiveTimingData from fastf1.mvapi import get_circuit_info, CircuitInfo from fastf1.logger import get_logger, soft_exceptions @@ -212,10 +213,7 @@ def __init__(self, @property def _constructor(self): - def _new(*args, **kwargs): - return Telemetry(*args, **kwargs).__finalize__(self) - - return _new + return Telemetry @property def base_class_view(self): @@ -1689,7 +1687,8 @@ def _fix_missing_laps_retired_on_track(self): }) # add generated laps at the end and fix sorting at the end - self._laps = pd.concat([self._laps, new_last]) + self._laps = (pd.concat([self._laps, new_last]) + .__finalize__(self._laps)) any_new = True if any_new: @@ -2394,7 +2393,7 @@ def _calculate_t0_date(self, *tel_data_sets: dict): self._t0_date = date_offset.round('ms') -class Laps(pd.DataFrame): +class Laps(BaseDataFrame): """Object for accessing lap (timing) data of multiple laps. Args: @@ -2548,8 +2547,7 @@ class Laps(pd.DataFrame): } _metadata = ['session'] - _internal_names = pd.DataFrame._internal_names \ - + ['base_class_view', 'telemetry'] + _internal_names = BaseDataFrame._internal_names + ['telemetry'] _internal_names_set = set(_internal_names) QUICKLAP_THRESHOLD = 1.07 @@ -2588,30 +2586,8 @@ def __init__(self, self.session = session @property - def _constructor(self): - def _new(*args, **kwargs): - return Laps(*args, **kwargs).__finalize__(self) - - return _new - - @property - def _constructor_sliced(self): - def _new(*args, **kwargs): - name = kwargs.get('name') - if name and (name in self.columns): - # vertical slice - return pd.Series(*args, **kwargs).__finalize__(self) - - # horizontal slice - return Lap(*args, **kwargs).__finalize__(self) - - return _new - - @property - def base_class_view(self): - """For a nicer debugging experience; can now view as - dataframe in various IDEs""" - return pd.DataFrame(self) + def _constructor_sliced_horizontal(self) -> Type["Lap"]: + return Lap @cached_property def telemetry(self) -> Telemetry: @@ -3247,7 +3223,7 @@ def iterlaps(self, require: Optional[Iterable] = None) \ yield index, lap -class Lap(pd.Series): +class Lap(BaseSeries): """ Object for accessing lap (timing) data of a single lap. @@ -3256,19 +3232,9 @@ class Lap(pd.Series): telemetry data. """ _metadata = ['session'] - _internal_names = pd.Series._internal_names + ['telemetry'] + _internal_names = BaseSeries._internal_names + ['telemetry'] _internal_names_set = set(_internal_names) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - @property - def _constructor(self): - def _new(*args, **kwargs): - return Lap(*args, **kwargs).__finalize__(self) - - return _new - @cached_property def telemetry(self) -> Telemetry: """Telemetry data for this lap @@ -3429,7 +3395,7 @@ def get_weather_data(self) -> pd.Series: return pd.Series(index=self.session.weather_data.columns) -class SessionResults(pd.DataFrame): +class SessionResults(BaseDataFrame): """This class provides driver and result information for all drivers that participated in a session. @@ -3565,9 +3531,6 @@ class SessionResults(pd.DataFrame): 'Points': 'float64' } - _internal_names = pd.DataFrame._internal_names + ['base_class_view'] - _internal_names_set = set(_internal_names) - def __init__(self, *args, force_default_cols: bool = False, **kwargs): if force_default_cols: kwargs['columns'] = list(self._COL_TYPES.keys()) @@ -3586,37 +3549,12 @@ def __init__(self, *args, force_default_cols: bool = False, **kwargs): self[col] = self[col].astype(_type) - def __repr__(self): - return self.base_class_view.__repr__() - - @property - def _constructor(self): - def _new(*args, **kwargs): - return SessionResults(*args, **kwargs).__finalize__(self) - - return _new - @property - def _constructor_sliced(self): - def _new(*args, **kwargs): - name = kwargs.get('name') - if name and (name in self.columns): - # vertical slice - return pd.Series(*args, **kwargs).__finalize__(self) - - # horizontal slice - return DriverResult(*args, **kwargs).__finalize__(self) + def _constructor_sliced_horizontal(self) -> Type["DriverResult"]: + return DriverResult - return _new - @property - def base_class_view(self): - """For a nicer debugging experience; can view DataFrame through - this property in various IDEs""" - return pd.DataFrame(self) - - -class DriverResult(pd.Series): +class DriverResult(BaseSeries): """This class provides driver and result information for a single driver. This class subclasses a :class:`pandas.Series` and the usual methods @@ -3635,19 +3573,9 @@ class DriverResult(pd.Series): .. versionadded:: 2.2 """ - _internal_names = pd.DataFrame._internal_names + ['dnf'] + _internal_names = BaseSeries._internal_names + ['dnf'] _internal_names_set = set(_internal_names) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - @property - def _constructor(self): - def _new(*args, **kwargs): - return DriverResult(*args, **kwargs).__finalize__(self) - - return _new - @property def dnf(self) -> bool: """True if driver did not finish""" diff --git a/fastf1/ergast/interface.py b/fastf1/ergast/interface.py index aa0a17286..1cd6d1432 100644 --- a/fastf1/ergast/interface.py +++ b/fastf1/ergast/interface.py @@ -1,15 +1,13 @@ import copy import json -from typing import List, Literal, Optional, Union +from typing import List, Literal, Optional, Type, Union from fastf1.req import Cache import fastf1.ergast.structure as API +from fastf1.internals.pandas_base import BaseDataFrame, BaseSeries from fastf1.version import __version__ -import pandas as pd - - BASE_URL = 'https://ergast.com/api/f1' HEADERS = {'User-Agent': f'FastF1/{__version__}'} @@ -101,7 +99,7 @@ def get_prev_result_page(self) -> Union['ErgastSimpleResponse', ) -class ErgastResultFrame(pd.DataFrame): +class ErgastResultFrame(BaseDataFrame): """ Wraps a Pandas ``DataFrame``. Additionally, this class can be initialized from Ergast response data with automatic flattening and type @@ -117,7 +115,7 @@ class ErgastResultFrame(pd.DataFrame): auto_cast: Determines if values are automatically cast to the most appropriate data type from their original string representation """ - _internal_names = pd.DataFrame._internal_names + ['base_class_view'] + _internal_names = BaseDataFrame._internal_names + ['base_class_view'] _internal_names_set = set(_internal_names) def __init__(self, data=None, *, @@ -164,48 +162,17 @@ def _flatten_element(cls, nested: dict, category: dict, cast: bool): return nested, flat @property - def _constructor(self): - def _new(*args, **kwargs): - return ErgastResultFrame(*args, **kwargs).__finalize__(self) - - return _new - - @property - def _constructor_sliced(self): - def _new(*args, **kwargs): - name = kwargs.get('name') - if name and (name in self.columns): - # vertical slice - return pd.Series(*args, **kwargs).__finalize__(self) - - # horizontal slice - return ErgastResultSeries(*args, **kwargs).__finalize__(self) - - return _new - - @property - def base_class_view(self): - """For a nicer debugging experience; can view DataFrame through - this property in various IDEs""" - return pd.DataFrame(self) + def _constructor_sliced_horizontal(self): + return ErgastResultSeries -class ErgastResultSeries(pd.Series): +class ErgastResultSeries(BaseSeries): """ Wraps a Pandas ``Series``. Currently, no extra functionality is implemented. """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - @property - def _constructor(self): - def _new(*args, **kwargs): - return ErgastResultSeries(*args, **kwargs).__finalize__(self) - - return _new + pass class ErgastRawResponse(ErgastResponseMixin, list): @@ -274,6 +241,13 @@ class ErgastSimpleResponse(ErgastResponseMixin, ErgastResultFrame): + ErgastResponseMixin._internal_names _internal_names_set = set(_internal_names) + @property + def _constructor(self) -> Type["ErgastResultFrame"]: + # drop from ErgastSimpleResponse to ErgastResultFrame, removing the + # ErgastResponseMixin because a slice of the data is no longer a full + # response and pagination, ... is therefore not supported anymore + return ErgastResultFrame + class ErgastMultiResponse(ErgastResponseMixin): """ diff --git a/fastf1/events.py b/fastf1/events.py index 473bd633a..0cec722d6 100644 --- a/fastf1/events.py +++ b/fastf1/events.py @@ -163,7 +163,7 @@ import datetime import json import warnings -from typing import Literal, Union, Optional +from typing import Callable, Literal, Union, Optional import dateutil.parser @@ -180,6 +180,7 @@ import fastf1._api from fastf1.core import Session import fastf1.ergast +from fastf1.internals.pandas_base import BaseSeries, BaseDataFrame from fastf1.logger import get_logger, soft_exceptions from fastf1.req import Cache from fastf1.utils import recursive_dict_get, to_datetime, to_timedelta @@ -735,7 +736,7 @@ def _get_schedule_from_ergast(year) -> "EventSchedule": return schedule -class EventSchedule(pd.DataFrame): +class EventSchedule(BaseDataFrame): """This class implements a per-season event schedule. For detailed information about the information that is available for each @@ -784,9 +785,6 @@ class EventSchedule(pd.DataFrame): _metadata = ['year'] - _internal_names = pd.DataFrame._internal_names + ['base_class_view'] - _internal_names_set = set(_internal_names) - def __init__(self, *args, year: int = 0, force_default_cols: bool = False, **kwargs): if force_default_cols: @@ -807,28 +805,9 @@ def __init__(self, *args, year: int = 0, self[col] = _type() self[col] = self[col].astype(_type) - def __repr__(self): - return self.base_class_view.__repr__() - - @property - def _constructor(self): - def _new(*args, **kwargs): - return EventSchedule(*args, **kwargs).__finalize__(self) - - return _new - @property - def _constructor_sliced(self): - def _new(*args, **kwargs): - return Event(*args, **kwargs).__finalize__(self) - - return _new - - @property - def base_class_view(self): - """For a nicer debugging experience; can view DataFrame through - this property in various IDEs""" - return pd.DataFrame(self) + def _constructor_sliced_horizontal(self) -> Callable[..., "Event"]: + return Event def is_testing(self): """Return `True` or `False`, depending on whether each event is a @@ -934,7 +913,7 @@ def get_event_by_name( return self._fuzzy_event_search(name) -class Event(pd.Series): +class Event(BaseSeries): """This class represents a single event (race weekend or testing event). Each event consists of one or multiple sessions, depending on the type @@ -955,13 +934,6 @@ def __init__(self, *args, year: int = None, **kwargs): super().__init__(*args, **kwargs) self.year = year - @property - def _constructor(self): - def _new(*args, **kwargs): - return Event(*args, **kwargs).__finalize__(self) - - return _new - def is_testing(self) -> bool: """Return `True` or `False`, depending on whether this event is a testing event.""" diff --git a/fastf1/internals/pandas_base.py b/fastf1/internals/pandas_base.py new file mode 100644 index 000000000..c1a16899d --- /dev/null +++ b/fastf1/internals/pandas_base.py @@ -0,0 +1,105 @@ +"""Base classes for objects that inherit form Pandas Series or DataFrame.""" +from typing import final, Optional, Type + +import pandas as pd + + +class BaseDataFrame(pd.DataFrame): + """Base class for objects that inherit from Pandas DataFrame. + + A same-dimensional slice of an object that inherits from this class will + be of equivalent type (instead of being a Pandas DataFrame). + + A one-dimensional slice of an object that inherits from this class can + be of different type, depending on whether the DataFrame-like object was + sliced vertically or horizontally. For this, the additional properties + ``_constructor_sliced_horizontal`` and ``_constructor_sliced_vertical`` are + introduced to extend the functionality that is provided by Pandas' + ``_constructor_sliced`` property. Both properties are set to + ``pandas.Series`` by default and only need to be overwritten when + necessary. + """ + _internal_names = pd.DataFrame._internal_names + ['base_class_view'] + _internal_names_set = set(_internal_names) + + def __repr__(self) -> str: + return self.base_class_view.__repr__() + + @property + def _constructor(self) -> Type["BaseDataFrame"]: + # by default, use the customized class as a constructor, i.e. all + # classes that inherit from this base class will always use themselves + # as a constructor + return self.__class__ + + @final + @property + def _constructor_sliced(self) -> Type[pd.Series]: + # dynamically create a subclass of _BaseSeriesConstructor that + # has a reference to this self (i.e. the object from which the slice + # is created) as a class property + # type(...) returns a new subclass of a Series + return type('_DynamicBaseSeriesConstructor', # noqa: return type + (_BaseSeriesConstructor,), + {'__meta_created_from': self}) + + @property + def _constructor_sliced_horizontal(self) -> Type[pd.Series]: + return pd.Series + + @property + def _constructor_sliced_vertical(self) -> Type[pd.Series]: + return pd.Series + + @property + def base_class_view(self) -> pd.DataFrame: + """For a nicer debugging experience; can view DataFrame through + this property in various IDEs""" + return pd.DataFrame(self) + + +class _BaseSeriesConstructor(pd.Series): + """ + Base class for an intermediary and dynamically defined constructor + class that implements horizontal and vertical slicing of Pandas DataFrames + with different result objects types. + + This class is never seen by the user. It is never fully instantiated + because it always returns an instance of a class that does not derive + from this class in its __new__ method. + """ + + __meta_created_from: Optional[BaseDataFrame] + + def __new__(cls, data=None, index=None, *args, **kwargs) -> pd.Series: + parent = getattr(cls, '__meta_created_from') + + if index is None: + # no index is explicitly given, try to get an index from the + # data itself (for example, if `data` is a BlockManager) + index = getattr(data, 'index', None) + + if (parent is None) or (index is None): + # do "conventional" slicing and return a pd.Series + constructor = pd.Series + + elif parent.index is index: + # our index matches the parent index, therefore, the data is + # a column of the parent DataFrame + constructor = parent._constructor_sliced_vertical + else: + # the data is a row of the parent DataFrame + constructor = parent._constructor_sliced_horizontal + + return constructor(data=data, index=index, *args, **kwargs) + + +class BaseSeries(pd.Series): + """Base class for objects that inherit from Pandas Series. + + A same-dimensional slice of an object that inherits from this class will + be of equivalent type (instead of being a Pandas Series). + """ + @property + def _constructor(self) -> Type[pd.Series]: + return self.__class__ diff --git a/fastf1/tests/test_events.py b/fastf1/tests/test_events.py index 010d79121..5969184f8 100644 --- a/fastf1/tests/test_events.py +++ b/fastf1/tests/test_events.py @@ -238,3 +238,23 @@ def test_event_get_nonexistent_session_date(): event = fastf1.get_event(2020, 13) with pytest.raises(ValueError, match="does not exist"): event.get_session_date('FP2') + + +def test_events_constructors(): + frame = fastf1.events.EventSchedule({'RoundNumber': [1, 2, 3], + 'Country': ['a', 'b', 'c']}) + + # test slicing to frame + assert isinstance(frame.iloc[1:], fastf1.events.EventSchedule) + + # test horizontal slicing + assert isinstance(frame.iloc[0], fastf1.events.Event) + assert isinstance(frame.iloc[0], pd.Series) + + # test vertical slicing + assert not isinstance(frame.loc[:, 'Country'], fastf1.events.Event) + assert isinstance(frame.loc[:, 'Country'], pd.Series) + + # test base class view + assert isinstance(frame.base_class_view, pd.DataFrame) + assert not isinstance(frame.base_class_view, fastf1.events.EventSchedule) diff --git a/fastf1/tests/test_internals.py b/fastf1/tests/test_internals.py index 184944fcc..70aab9231 100644 --- a/fastf1/tests/test_internals.py +++ b/fastf1/tests/test_internals.py @@ -1,3 +1,4 @@ +from fastf1.internals.pandas_base import BaseDataFrame, BaseSeries from fastf1.internals.pandas_extensions import _unsafe_create_df_fast import numpy as np @@ -18,3 +19,109 @@ def test_fast_df_creation(): ) pd.testing.assert_frame_equal(df_safe, df_fast) + + +def test_base_frame_slicing(): + class TestSeriesVertical(pd.Series): + pass + + class TestSeriesHorizontal(BaseSeries): + pass + + class TestDataFrame(BaseDataFrame): + @property + def _constructor_sliced_vertical(self): + return TestSeriesVertical + + @property + def _constructor_sliced_horizontal(self): + return TestSeriesHorizontal + + df = TestDataFrame({'A': [10, 11, 12], 'B': [20, 21, 22]}) + assert isinstance(df, TestDataFrame) + assert isinstance(df, pd.DataFrame) + + df_sliced = df.iloc[0:2] + assert isinstance(df_sliced, TestDataFrame) + assert isinstance(df_sliced, pd.DataFrame) + assert (df_sliced + == pd.DataFrame({'A': [10, 11], 'B': [20, 21]}) + ).all().all() + + vert_ser = df.loc[:, 'A'] + assert isinstance(vert_ser, TestSeriesVertical) + assert isinstance(vert_ser, pd.Series) + assert (vert_ser == pd.Series([10, 11, 12])).all() + + hor_ser = df.iloc[0] + assert isinstance(hor_ser, TestSeriesHorizontal) + assert isinstance(hor_ser, pd.Series) + assert (hor_ser == pd.Series({'A': 10, 'B': 20})).all() + + # iterrows initializes row series from ndarray not blockmanager + for _, row in df.iterrows(): + assert isinstance(row, TestSeriesHorizontal) + assert isinstance(row, pd.Series) + + +def test_base_series_slicing(): + class TestSeries(BaseSeries): + pass + + series = TestSeries([0, 1, 2, 3]) + ser_sliced = series.iloc[0:2] + assert (ser_sliced == pd.Series([0, 1])).all() + assert isinstance(ser_sliced, BaseSeries) + assert isinstance(ser_sliced, pd.Series) + + +def test_base_frame_metadata_propagation(): + class TestSeriesHorizontal(BaseSeries): + _metadata = ['some_value'] + + class TestSeriesVertical(BaseSeries): + pass + + class TestDataFrame(BaseDataFrame): + _metadata = ['some_value'] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.some_value = None + + @property + def _constructor_sliced_horizontal(self): + return TestSeriesHorizontal + + @property + def _constructor_sliced_vertical(self): + return TestSeriesVertical + + df = TestDataFrame({'A': [10, 11, 12], 'B': [20, 21, 22]}) + df.some_value = 100 + + # propagation to dataframe slice + df_sliced = df.iloc[0:2] + assert df_sliced.some_value == 100 + + # no propagation to a series object that does not define this metadata + vert_slice = df.loc[:, 'A'] + assert not hasattr(vert_slice, 'some_value') + + # propagation to a series object that does define the same metadata + hor_slice = df.iloc[0] + assert hor_slice.some_value == 100 + + # iterrows initializes row series from ndarray not blockmanager + for _, row in df.iterrows(): + assert row.some_value == 100 + + +def test_base_series_metadata_propagation(): + class TestSeries(BaseSeries): + _metadata = ['some_value'] + + series = TestSeries([0, 1, 2, 3]) + series.some_value = 100 + ser_sliced = series.iloc[0:2] + assert ser_sliced.some_value == 100