Skip to content

Commit

Permalink
match axes.legend signature behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
theOehrly committed Jul 24, 2024
1 parent c1c763c commit c93beef
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions fastf1/plotting/_interface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dataclasses
import warnings
from typing import (
Any,
Dict,
Expand All @@ -10,6 +11,7 @@
)

import matplotlib.axes
import matplotlib.legend

from fastf1.core import Session
from fastf1.internals.fuzzy import fuzzy_matcher
Expand Down Expand Up @@ -729,6 +731,10 @@ def add_sorted_driver_legend(
``ax.legend()`` method. It can only be used when driver names or driver
abbreviations are used as labels for the legend.
This function supports the same ``*args`` and ``**kwargs`` as
Matplotlib's ``ax.legend()``, including the ``handles`` and ``labels``
arguments. Check the Matplotlib documentation for more information.
There is no particular need to use this function except to make the
legend more visually pleasing.
Expand All @@ -746,7 +752,16 @@ def add_sorted_driver_legend(
"""
dtm = _get_driver_team_mapping(session)
handles, labels = ax.get_legend_handles_labels()

try:
handles, labels, kwargs \
= matplotlib.legend._parse_legend_args([ax], *args, **kwargs)
except AttributeError:
warnings.warn("Failed to parse optional legend arguments correctly.",
UserWarning)
kwargs.pop('handles', None)
kwargs.pop('labels', None)
handles, labels = ax.get_legend_handles_labels()

teams_list = list(dtm.teams_by_normalized.values())
driver_list = list(dtm.drivers_by_normalized.values())
Expand All @@ -771,14 +786,11 @@ def add_sorted_driver_legend(

handles_new = list()
labels_new = list()
seen_labels = set()
for elem in ref:
if elem[3] not in seen_labels:
handles_new.append(elem[2])
labels_new.append(elem[3])
seen_labels.add(elem[3])
handles_new.append(elem[2])
labels_new.append(elem[3])

return ax.legend(handles_new, labels_new, *args, **kwargs)
return ax.legend(handles_new, labels_new, **kwargs)


def set_default_colormap(colormap: str):
Expand Down

0 comments on commit c93beef

Please sign in to comment.