diff --git a/win32/Lib/win32timezone.py b/win32/Lib/win32timezone.py index fbbe722915..81d95d3b11 100644 --- a/win32/Lib/win32timezone.py +++ b/win32/Lib/win32timezone.py @@ -1,5 +1,3 @@ -# -*- coding: UTF-8 -*- - """ win32timezone: Module for handling datetime.tzinfo time zones using the windows @@ -240,27 +238,50 @@ from __future__ import annotations import datetime +import functools import logging import operator import re import struct import winreg from itertools import count -from typing import Dict +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + Generator, + Iterable, + Mapping, + TypeVar, + overload, +) import win32api +if TYPE_CHECKING: + from _operator import _SupportsComparison + + from _typeshed import SupportsKeysAndGetItem + from typing_extensions import Self + __author__ = "Jason R. Coombs " +_RangeMapKT = TypeVar("_RangeMapKT", bound="_SupportsComparison") + +_T = TypeVar("_T") +_VT = TypeVar("_VT") + log = logging.getLogger(__file__) # A couple of objects for working with objects as if they were native C-type # structures. class _SimpleStruct: - _fields_: list[tuple[str, type]] = [] # must be overridden by subclasses + _fields_: ClassVar[list[tuple[str, type]]] = [] # must be overridden by subclasses - def __init__(self, *args, **kw): + def __init__(self, *args, **kw) -> None: for i, (name, typ) in enumerate(self._fields_): def_arg = None if i < len(args): @@ -280,10 +301,10 @@ def __init__(self, *args, **kw): def_val = typ(*def_arg) setattr(self, name, def_val) - def field_names(self): + def field_names(self) -> list[str]: return [f[0] for f in self._fields_] - def __eq__(self, other): + def __eq__(self, other) -> bool: if not hasattr(other, "_fields_"): return False if self._fields_ != other._fields_: @@ -293,7 +314,7 @@ def __eq__(self, other): return False return True - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) @@ -338,11 +359,11 @@ class TimeZoneDefinition(DYNAMIC_TIME_ZONE_INFORMATION): additional bias applies (standard_bias and daylight_bias). """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: """ >>> test_args = [1] * 44 - Try to construct a TimeZoneDefinition from + Try to construct a TimeZoneDefinition from: a) [DYNAMIC_]TIME_ZONE_INFORMATION args >>> TimeZoneDefinition(*test_args).bias @@ -405,7 +426,7 @@ def __init_from_bytes( daylight_disabled, ) - def __init_from_other(self, other): + def __init_from_other(self, other: TIME_ZONE_INFORMATION) -> None: if not isinstance(other, TIME_ZONE_INFORMATION): raise TypeError("Not a TIME_ZONE_INFORMATION") for name in other.field_names(): @@ -416,34 +437,44 @@ def __init_from_other(self, other): # size = max(ctypes.sizeof(DYNAMIC_TIME_ZONE_INFO), ctypes.sizeof(other)) # ctypes.memmove(ctypes.addressof(self), other, size) - def __getattribute__(self, attr): + if TYPE_CHECKING: + # TIME_ZONE_INFORMATION fields as obtained by __getattribute__ + bias: datetime.timedelta + standard_name: str + standard_start: SYSTEMTIME + standard_bias: datetime.timedelta + daylight_name: str + daylight_start: SYSTEMTIME + daylight_bias: datetime.timedelta + + def __getattribute__(self, attr: str) -> Any: value = super().__getattribute__(attr) if "bias" in attr: value = datetime.timedelta(minutes=value) return value @classmethod - def current(class_): + def current(cls): "Windows Platform SDK GetTimeZoneInformation" code, tzi = win32api.GetTimeZoneInformation(True) - return code, class_(*tzi) + return code, cls(*tzi) - def set(self): + def set(self) -> None: tzi = tuple(getattr(self, n) for n, t in self._fields_) win32api.SetTimeZoneInformation(tzi) - def copy(self): + def copy(self) -> Self: # XXX - this is no longer a copy! return self.__class__(self) - def locate_daylight_start(self, year): + def locate_daylight_start(self, year) -> datetime.datetime: return self._locate_day(year, self.daylight_start) - def locate_standard_start(self, year): + def locate_standard_start(self, year) -> datetime.datetime: return self._locate_day(year, self.standard_start) @staticmethod - def _locate_day(year, cutoff): + def _locate_day(year, cutoff) -> datetime.datetime: """ Takes a SYSTEMTIME object, such as retrieved from a TIME_ZONE_INFORMATION structure or call to GetTimeZoneInformation and interprets it based on the given @@ -545,7 +576,7 @@ def __init__( def __getinitargs__(self) -> tuple[TimeZoneDefinition, bool]: return (self.staticInfo, self.fixedStandardTime) - def _FindTimeZoneKey(self): + def _FindTimeZoneKey(self) -> _RegKeyDict: """Find the registry key for the time zone name (self.timeZoneName).""" # for multi-language compatability, match the time zone name in the # "Std" key of the time zone key. @@ -561,7 +592,7 @@ def _FindTimeZoneKey(self): raise ValueError(f"Timezone Name {timeZoneName!r} not found") return result - def _LoadInfoFromKey(self): + def _LoadInfoFromKey(self) -> None: """Loads the information from an opened time zone registry key into relevant fields of this TZI object""" key = self._FindTimeZoneKey() @@ -571,20 +602,20 @@ def _LoadInfoFromKey(self): self.staticInfo = TimeZoneDefinition(key["TZI"]) self._LoadDynamicInfoFromKey(key) - def _LoadFromTZI(self, tzi): + def _LoadFromTZI(self, tzi: TimeZoneDefinition): self.timeZoneName = tzi.standard_name self.displayName = "Unknown" self.standardName = tzi.standard_name self.daylightName = tzi.daylight_name self.staticInfo = tzi - def _LoadDynamicInfoFromKey(self, key): + def _LoadDynamicInfoFromKey(self, key) -> None: """ >>> tzi = TimeZoneInfo('Central Standard Time') Here's how the RangeMap is supposed to work: >>> m = RangeMap(zip([2006,2007], 'BC'), - ... sort_params = dict(reverse=True), + ... sort_params = {"reverse": True}, ... key_match_comparator=operator.ge) >>> m.get(2000, 'A') 'A' @@ -615,7 +646,7 @@ def _LoadDynamicInfoFromKey(self, key): """ try: info = key.subkey("Dynamic DST") - except OSError: + except FileNotFoundError: return del info["FirstEntry"] del info["LastEntry"] @@ -631,46 +662,57 @@ def _LoadDynamicInfoFromKey(self, key): key_match_comparator=operator.ge, ) - def __repr__(self): + def __repr__(self) -> str: result = f"{self.__class__.__name__}({self.timeZoneName!r}" if self.fixedStandardTime: result += ", True" result += ")" return result - def __str__(self): + def __str__(self) -> str: return self.displayName - def tzname(self, dt): + @overload # type: ignore[override] # Split definition into overrides + def tzname(self, dt: datetime.datetime) -> str: ... + @overload + def tzname(self, dt: None) -> None: ... + def tzname(self, dt: datetime.datetime | None) -> str | None: """ >>> MST = TimeZoneInfo('Mountain Standard Time') >>> MST.tzname(datetime.datetime(2003, 8, 2)) 'Mountain Daylight Time' >>> MST.tzname(datetime.datetime(2003, 11, 25)) 'Mountain Standard Time' + >>> MST.tzname(None) + """ + # https://docs.python.org/3/library/datetime.html#datetime.tzinfo.tzname + # > [...] returning `None` is appropriate if the class wishes to say + # > that `time` objects don’t participate in the `tzinfo` protocols. + if dt is None: + return None + dst = self.dst(dt) winInfo = self.getWinInfo(dt.year) - if self.dst(dt) == -winInfo.daylight_bias: - result = self.daylightName - elif self.dst(dt) == -winInfo.standard_bias: - result = self.standardName - else: - raise ValueError( - "Unexpected daylight bias", - dt, - self.dst(dt), - winInfo.daylight_bias, - winInfo.standard_bias, - ) - return result + if dst == -winInfo.daylight_bias: + return self.daylightName + elif dst == -winInfo.standard_bias: + return self.standardName + + raise ValueError( + "Unexpected daylight bias", + dt, + dst, + winInfo.daylight_bias, + winInfo.standard_bias, + ) - def getWinInfo(self, targetYear): + def getWinInfo(self, targetYear: int) -> TimeZoneDefinition: """ Return the most relevant "info" for this time zone in the target year. """ - if not hasattr(self, "dynamicInfo") or not self.dynamicInfo: + if not getattr(self, "dynamicInfo", {}): return self.staticInfo # Find the greatest year entry in self.dynamicInfo which is for # a year greater than or equal to our targetYear. If not found, @@ -685,20 +727,28 @@ def _getDaylightBias(self, dt): winInfo = self.getWinInfo(dt.year) return winInfo.bias + winInfo.daylight_bias - def utcoffset(self, dt): + @overload # type: ignore[override] # False-positive, our overload covers all base types + def utcoffset(self, dt: None) -> None: ... + @overload + def utcoffset(self, dt: datetime.datetime) -> datetime.timedelta: ... + def utcoffset(self, dt: datetime.datetime | None) -> datetime.timedelta | None: "Calculates the utcoffset according to the datetime.tzinfo spec" if dt is None: - return + return None winInfo = self.getWinInfo(dt.year) return -winInfo.bias + self.dst(dt) - def dst(self, dt): + @overload # type: ignore[override] # False-positive, our overload covers all base types + def dst(self, dt: None) -> None: ... + @overload + def dst(self, dt: datetime.datetime) -> datetime.timedelta: ... + def dst(self, dt: datetime.datetime | None) -> datetime.timedelta | None: """ Calculate the daylight savings offset according to the datetime.tzinfo spec. """ if dt is None: - return + return None winInfo = self.getWinInfo(dt.year) if not self.fixedStandardTime and self._inDaylightSavings(dt): result = winInfo.daylight_bias @@ -735,25 +785,22 @@ def _inDaylightSavings(self, dt): return in_dst - def GetDSTStartTime(self, year): + def GetDSTStartTime(self, year: int) -> datetime.datetime: "Given a year, determines the time when daylight savings time starts" return self.getWinInfo(year).locate_daylight_start(year) - def GetDSTEndTime(self, year): + def GetDSTEndTime(self, year: int) -> datetime.datetime: "Given a year, determines the time when daylight savings ends." return self.getWinInfo(year).locate_standard_start(year) - def __le__(self, other): - return self.__dict__ < other.__dict__ - - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return self.__dict__ == other.__dict__ - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return self.__dict__ != other.__dict__ @classmethod - def local(class_): + def local(cls) -> Self: """Returns the local time zone as defined by the operating system in the registry. >>> localTZ = TimeZoneInfo.local() @@ -784,10 +831,12 @@ def local(class_): # not sufficient to represent the time zone in which # the current user is operating due # to dynamic time zones. - return class_(info, fix_standard_time) + return cls(info, fix_standard_time) + + _tzutc: ClassVar[Self | None] = None @classmethod - def utc(class_): + def utc(cls) -> Self: """Returns a time-zone representing UTC. Same as TimeZoneInfo('GMT Standard Time', True) but caches the result @@ -796,9 +845,9 @@ def utc(class_): >>> isinstance(TimeZoneInfo.utc(), TimeZoneInfo) True """ - if "_tzutc" not in class_.__dict__: - setattr(class_, "_tzutc", class_("GMT Standard Time", True)) - return class_._tzutc + if not cls._tzutc: + cls._tzutc = cls("GMT Standard Time", True) + return cls._tzutc # helper methods for accessing the timezone info from the registry @staticmethod @@ -833,7 +882,7 @@ def get_index_value(key_name): ) @staticmethod - def get_sorted_time_zone_names(): + def get_sorted_time_zone_names() -> list[str]: """ Return a list of time zone names that can be used to initialize TimeZoneInfo instances. @@ -842,11 +891,11 @@ def get_sorted_time_zone_names(): return [tz.standardName for tz in tzs] @staticmethod - def get_all_time_zones(): + def get_all_time_zones() -> list[TimeZoneInfo]: return [TimeZoneInfo(n) for n in TimeZoneInfo._get_time_zone_key_names()] @staticmethod - def get_sorted_time_zones(key=None): + def get_sorted_time_zones(key=None) -> list[TimeZoneInfo]: """ Return the time zones sorted by some key. key must be a function that takes a TimeZoneInfo object and returns @@ -860,17 +909,19 @@ def get_sorted_time_zones(key=None): return zones -class _RegKeyDict(Dict[str, int]): - def __init__(self, key: winreg.HKEYType): +class _RegKeyDict(Dict[str, str]): + def __init__(self, key: winreg._KeyType): dict.__init__(self) self.key = key self.__load_values() @classmethod - def open(cls, *args, **kargs): - return _RegKeyDict(winreg.OpenKeyEx(*args, **kargs)) + def open( + cls, key: winreg._KeyType, sub_key: str, reserved: int = 0, access: int = 131097 + ) -> _RegKeyDict: + return _RegKeyDict(winreg.OpenKeyEx(key, sub_key, reserved, access)) - def subkey(self, name): + def subkey(self, name: str) -> _RegKeyDict: if not name: raise ValueError("subkey name cannot be empty") return _RegKeyDict(winreg.OpenKeyEx(self.key, name)) @@ -891,7 +942,9 @@ def _enumerate_reg_keys(key): return _RegKeyDict._enumerate_reg(key, winreg.EnumKey) @staticmethod - def _enumerate_reg(key, func): + def _enumerate_reg( + key: _T, func: Callable[[_T, int], _VT] + ) -> Generator[_VT, None, None]: "Enumerates an open registry key as an iterable generator" try: for index in count(): @@ -928,7 +981,7 @@ def now() -> datetime.datetime: return datetime.datetime.now(TimeZoneInfo.local()) -def GetTZCapabilities(): +def GetTZCapabilities() -> dict[str, bool]: """ Run a few known tests to determine the capabilities of the time zone database on this machine. @@ -955,10 +1008,10 @@ def GetTZCapabilities(): class DLLHandleCache: - def __init__(self): - self.__cache = {} + def __init__(self) -> None: + self.__cache: dict[str, int] = {} - def __getitem__(self, filename): + def __getitem__(self, filename: str) -> int: key = filename.lower() return self.__cache.setdefault(key, win32api.LoadLibrary(key)) @@ -966,27 +1019,33 @@ def __getitem__(self, filename): DLLCache = DLLHandleCache() -def resolveMUITimeZone(spec): +def resolveMUITimeZone(spec: str) -> str | None: """Resolve a multilingual user interface resource for the time zone name spec should be of the format @path,-stringID[;comment] see http://msdn2.microsoft.com/en-us/library/ms725481.aspx for details + + >>> import sys + >>> result = resolveMUITimeZone('@tzres.dll,-110') + >>> expectedResultType = [type(None),str][sys.getwindowsversion() >= (6,)] + >>> type(result) is expectedResultType + True """ pattern = re.compile(r"@(?P.*),-(?P\d+)(?:;(?P.*))?") matcher = pattern.match(spec) assert matcher, "Could not parse MUI spec" + groupdict = matcher.groupdict() try: - handle = DLLCache[matcher.groupdict()["dllname"]] - result = win32api.LoadString(handle, int(matcher.groupdict()["index"])) + handle = DLLCache[groupdict["dllname"]] + result: str | None = win32api.LoadString(handle, int(groupdict["index"])) except win32api.error: result = None return result -# from jaraco.util.dictlib 5.3.1 -# TODO: Update to implementation in jaraco.collections -class RangeMap(dict): # type: ignore[type-arg] # Source code is untyped :/ TODO: Add generics! +# from jaraco.collections 5.1 +class RangeMap(Dict[_RangeMapKT, _VT]): """ A dictionary-like object that uses the keys as bounds for a range. Inclusion of the value for that range is determined by the @@ -995,25 +1054,30 @@ class RangeMap(dict): # type: ignore[type-arg] # Source code is untyped :/ TODO the sorted list of keys. One may supply keyword parameters to be passed to the sort function used - to sort keys (i.e. keys, reverse) as sort_params. + to sort keys (i.e. key, reverse) as sort_params. + + Create a map that maps 1-3 -> 'a', 4-6 -> 'b' - Let's create a map that maps 1-3 -> 'a', 4-6 -> 'b' >>> r = RangeMap({3: 'a', 6: 'b'}) # boy, that was easy >>> r[1], r[2], r[3], r[4], r[5], r[6] ('a', 'a', 'a', 'b', 'b', 'b') Even float values should work so long as the comparison operator supports it. + >>> r[4.5] 'b' - But you'll notice that the way rangemap is defined, it must be open-ended on one side. + Notice that the way rangemap is defined, it must be open-ended + on one side. + >>> r[0] 'a' >>> r[-1] 'a' One can close the open-end of the RangeMap by using undefined_value + >>> r = RangeMap({0: RangeMap.undefined_value, 3: 'a', 6: 'b'}) >>> r[0] Traceback (most recent call last): @@ -1021,33 +1085,71 @@ class RangeMap(dict): # type: ignore[type-arg] # Source code is untyped :/ TODO KeyError: 0 One can get the first or last elements in the range by using RangeMap.Item + >>> last_item = RangeMap.Item(-1) >>> r[last_item] 'b' .last_item is a shortcut for Item(-1) + >>> r[RangeMap.last_item] 'b' Sometimes it's useful to find the bounds for a RangeMap + >>> r.bounds() (0, 6) RangeMap supports .get(key, default) + >>> r.get(0, 'not found') 'not found' >>> r.get(7, 'not found') 'not found' + One often wishes to define the ranges by their left-most values, + which requires use of sort params and a key_match_comparator. + + >>> r = RangeMap({1: 'a', 4: 'b'}, + ... sort_params=dict(reverse=True), + ... key_match_comparator=operator.ge) + >>> r[1], r[2], r[3], r[4], r[5], r[6] + ('a', 'a', 'a', 'b', 'b', 'b') + + That wasn't nearly as easy as before, so an alternate constructor + is provided: + + >>> r = RangeMap.left({1: 'a', 4: 'b', 7: RangeMap.undefined_value}) + >>> r[1], r[2], r[3], r[4], r[5], r[6] + ('a', 'a', 'a', 'b', 'b', 'b') + """ - def __init__(self, source, sort_params={}, key_match_comparator=operator.le): + def __init__( + self, + source: ( + SupportsKeysAndGetItem[_RangeMapKT, _VT] | Iterable[tuple[_RangeMapKT, _VT]] + ), + sort_params: Mapping[str, Any] = {}, + key_match_comparator: Callable[[_RangeMapKT, _RangeMapKT], bool] = operator.le, + ) -> None: dict.__init__(self, source) self.sort_params = sort_params self.match = key_match_comparator - def __getitem__(self, item): + @classmethod + def left( + cls, + source: ( + SupportsKeysAndGetItem[_RangeMapKT, _VT] | Iterable[tuple[_RangeMapKT, _VT]] + ), + ) -> Self: + return cls( + source, sort_params={"reverse": True}, key_match_comparator=operator.ge + ) + + def __getitem__(self, item: _RangeMapKT) -> _VT: sorted_keys = sorted(self, **self.sort_params) if isinstance(item, RangeMap.Item): result = self.__getitem__(sorted_keys[item]) @@ -1058,7 +1160,11 @@ def __getitem__(self, item): raise KeyError(key) return result - def get(self, key, default=None): + @overload # type: ignore[override] # Signature simplified over dict and Mapping + def get(self, key: _RangeMapKT, default: _T) -> _VT | _T: ... + @overload + def get(self, key: _RangeMapKT, default: None = None) -> _VT | None: ... + def get(self, key: _RangeMapKT, default: _T | None = None) -> _VT | _T | None: """ Return the value for key if key is in the dictionary, else default. If default is not given, it defaults to None, so that this method @@ -1070,27 +1176,25 @@ def get(self, key, default=None): except KeyError: return default - def _find_first_match_(self, keys, item): - def is_match(k): - return self.match(item, k) - - matches = list(filter(is_match, keys)) - if matches: - return matches[0] - raise KeyError(item) + def _find_first_match_( + self, keys: Iterable[_RangeMapKT], item: _RangeMapKT + ) -> _RangeMapKT: + is_match = functools.partial(self.match, item) + matches = filter(is_match, keys) + try: + return next(matches) + except StopIteration: + raise KeyError(item) from None - def bounds(self): + def bounds(self) -> tuple[_RangeMapKT, _RangeMapKT]: sorted_keys = sorted(self, **self.sort_params) - return ( - sorted_keys[RangeMap.first_item], - sorted_keys[RangeMap.last_item], - ) + return (sorted_keys[RangeMap.first_item], sorted_keys[RangeMap.last_item]) # some special values for the RangeMap - undefined_value = type("RangeValueUndefined", (object,), {})() + undefined_value = type("RangeValueUndefined", (), {})() class Item(int): - pass + """RangeMap Item""" first_item = Item(0) last_item = Item(-1)