Skip to content

Commit

Permalink
Update typing to higher standard (facebookresearch#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrapin authored Feb 13, 2019
1 parent b8ebcb2 commit eee8eb1
Show file tree
Hide file tree
Showing 27 changed files with 57 additions and 56 deletions.
2 changes: 1 addition & 1 deletion nevergrad/benchmark/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self) -> None:
self._order = 0
self._time = 0.

def submit(self, function: Callable, *args: Any, **kwargs: Any) -> MockedSteadyJob:
def submit(self, function: Callable[..., Any], *args: Any, **kwargs: Any) -> MockedSteadyJob:
if self.priority_queue: # new job may come before the current "next" job
self.priority_queue[0].job._done = False
value = function(*args, **kwargs)
Expand Down
12 changes: 6 additions & 6 deletions nevergrad/benchmark/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import argparse
import itertools
from pathlib import Path
from typing import Iterator, List, Optional, Any
from typing import Iterator, List, Optional, Any, Dict
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
Expand All @@ -29,7 +29,7 @@ def _make_style_generator() -> Iterator[str]:
return (l + m + c for l, m, c in zip(lines, markers, colors))


class NameStyle(dict):
class NameStyle(Dict[str, Any]):
"""Provides a style for each name, and keeps to it
"""

Expand All @@ -39,7 +39,7 @@ def __init__(self) -> None:

def __getitem__(self, name: str) -> Any:
if name not in self:
self[name] = next(self._gen)
super().__setitem__(name, next(self._gen))
return super().__getitem__(name)


Expand Down Expand Up @@ -154,7 +154,7 @@ def create_plots(df: pd.DataFrame, output_folder: PathLike, max_combsize: int =


def make_xpresults_plot(df: pd.DataFrame, title: str, output_filepath: Optional[PathLike] = None,
name_style: Optional[dict] = None) -> None:
name_style: Optional[Dict[str, Any]] = None) -> None:
"""Creates a xp result plot out of the given dataframe: regret with respect to budget for
each optimizer after averaging on all experiments (it is good practice to use a df
which is filtered out for one set of input parameters)
Expand Down Expand Up @@ -300,12 +300,12 @@ def winrates_from_selection(df: tools.Selector, categories: List[str], num_rows:
best_names = [(f"{name} ({100 * val:2.1f}\%)").replace("Search", "") for name, val in zip(mean_win.index[: num_rows], mean_win)]
return pd.DataFrame(index=best_names, columns=sorted_names, data=data)

def save(self, *args, **kwargs):
def save(self, *args: Any, **kwargs: Any) -> None:
"""Shortcut to the figure savefig method
"""
self._fig.savefig(*args, **kwargs)

def __del__(self):
def __del__(self) -> None:
plt.close(self._fig)


Expand Down
2 changes: 1 addition & 1 deletion nevergrad/benchmark/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class ExecutorTest(TestCase):
simple=(add, list(range(10))),
delayed=(Function(), [5, 6, 7, 8, 9, 4, 3, 2, 1, 0])
)
def test_mocked_steady_executor(self, func: Callable, expected: List[int]) -> None:
def test_mocked_steady_executor(self, func: Callable[..., Any], expected: List[int]) -> None:
executor = execution.MockedSteadyExecutor()
jobs: List[execution.MockedSteadyJob] = []
for k in range(10):
Expand Down
2 changes: 1 addition & 1 deletion nevergrad/benchmark/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_name_style() -> None:
np.testing.assert_equal(nstyle["blublu"], "-ob")


def test_split_long_title():
def test_split_long_title() -> None:
title = "abcd,efgh"
np.testing.assert_equal(plotting.split_long_title(title), title)
title = ",".join(["a" * 25, "b" * 25, "c" * 25, "d" * 15])
Expand Down
2 changes: 1 addition & 1 deletion nevergrad/benchmark/xpbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class CallCounter(execution.PostponedObject):
the callable to wrap
"""

def __init__(self, func: Callable) -> None:
def __init__(self, func: Callable[..., Any]) -> None:
self.func = func
self.num_calls = 0

Expand Down
8 changes: 4 additions & 4 deletions nevergrad/common/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
import functools


class Registry(dict):
class Registry(dict): # type: ignore
"""Registers function or classes as a dict.
"""

def __init__(self) -> None:
super().__init__()
self._information: Dict[str, dict] = {}
self._information: Dict[str, Dict[Any, Any]] = {}

def register(self, obj: Any, info: Optional[Dict[Any, Any]] = None) -> Any:
"""Decorator method for registering functions/classes
Expand All @@ -36,12 +36,12 @@ def unregister(self, name: str) -> None:
if name in self:
del self[name]

def register_with_info(self, **info: Any) -> Callable:
def register_with_info(self, **info: Any) -> Callable[..., Any]:
"""Decorator for registering a function and information about it
"""
return functools.partial(self.register, info=info)

def get_info(self, name: str) -> dict:
def get_info(self, name: str) -> Dict[Any, Any]:
if name not in self:
raise ValueError(f'"{name}" is not registered.')
return self._information.setdefault(name, {})
3 changes: 2 additions & 1 deletion nevergrad/common/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

from pathlib import Path
from typing import Iterable
from unittest import TestCase
import genty
import numpy as np
Expand All @@ -19,7 +20,7 @@ class UtilsTests(TestCase):
additional=((1, 4, 3, 2), [" - additional element(s): {4}."]),
both=((1, 2, 4), [" - additional element(s): {4}.", " - missing element(s): {3}."]),
)
def test_assert_set_equal(self, estimate: testing.Iterable, message: str) -> None:
def test_assert_set_equal(self, estimate: Iterable[int], message: str) -> None:
reference = {1, 2, 3}
try:
testing.assert_set_equal(estimate, reference)
Expand Down
6 changes: 3 additions & 3 deletions nevergrad/common/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.

import itertools
from typing import Iterable, List, Any
from typing import Iterable, List, Any, Tuple
from unittest import TestCase
import genty
import numpy as np
Expand All @@ -21,7 +21,7 @@ class ToolsTests(TestCase):
two=([1, 2], [(1, 2)]),
three=([1, 2, 3], [(1, 2), (2, 3)]),
)
def test_pairwise(self, iterator: Iterable, expected: List) -> None:
def test_pairwise(self, iterator: Iterable[Any], expected: List[Tuple[Any, ...]]) -> None:
output = list(tools.pairwise(iterator))
testing.printed_assert_equal(output, expected)

Expand All @@ -32,7 +32,7 @@ def test_pairwise(self, iterator: Iterable, expected: List) -> None:
values=({"c1": ["i3-c1", "i2-c1"]}, ["i2", "i3"]),
conditions=({"c1": ["i3-c1", "i2-c1"], "c2": "i3-c2"}, ["i3"]),
)
def test_selector(self, criteria: Any, expected: List) -> None:
def test_selector(self, criteria: Any, expected: List[str]) -> None:
df = tools.Selector(index=["i1", "i2", "i3"], columns=["c1", "c2"])
for i, c in itertools.product(df.index, df.columns):
df.loc[i, c] = f"{i}-{c}"
Expand Down
2 changes: 1 addition & 1 deletion nevergrad/common/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np


def assert_set_equal(estimate: Iterable, reference: Iterable, err_msg: str = "") -> None:
def assert_set_equal(estimate: Iterable[Any], reference: Iterable[Any], err_msg: str = "") -> None:
"""Asserts that both sets are equals, with comprehensive error message.
This function should only be used in tests.
Parameters
Expand Down
10 changes: 5 additions & 5 deletions nevergrad/common/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from . import testing


def pairwise(iterable: Iterable) -> Iterator[Tuple[Any, Any]]:
def pairwise(iterable: Iterable[Any]) -> Iterator[Tuple[Any, Any]]:
"""Returns an iterator over sliding pairs of the input iterator
s -> (s0,s1), (s1,s2), (s2, s3), ...
Expand All @@ -29,7 +29,7 @@ def pairwise(iterable: Iterable) -> Iterator[Tuple[Any, Any]]:
return zip(a, b)


def grouper(iterable: Iterable, n: int, fillvalue: Optional[Any] = None) -> Iterator[List]:
def grouper(iterable: Iterable[Any], n: int, fillvalue: Optional[Any] = None) -> Iterator[List[Any]]:
"""Collect data into fixed-length chunks or blocks
Copied from itertools recipe documentation
Example: grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx"
Expand All @@ -38,7 +38,7 @@ def grouper(iterable: Iterable, n: int, fillvalue: Optional[Any] = None) -> Iter
return itertools.zip_longest(*args, fillvalue=fillvalue)


def roundrobin(*iterables: Iterable) -> Iterator[Any]:
def roundrobin(*iterables: Iterable[Any]) -> Iterator[Any]:
"""roundrobin('ABC', 'D', 'EF') --> A D E B F C
"""
# Recipe credited to George Sakkis
Expand All @@ -58,7 +58,7 @@ class Selector(pd.DataFrame): # type: ignore
"""Pandas dataframe class with a simplified selection function
"""

def select(self, **kwargs: Union[str, Sequence[str], Callable]) -> 'Selector':
def select(self, **kwargs: Union[str, Sequence[str], Callable[[Any], bool]]) -> 'Selector':
"""Select rows based on a value, a sequence of values or a discriminating function
Parameters
Expand All @@ -83,7 +83,7 @@ def select(self, **kwargs: Union[str, Sequence[str], Callable]) -> 'Selector':
df = df.loc[selected, :]
return Selector(df)

def select_and_drop(self, **kwargs: Union[str, Sequence[str], Callable]) -> 'Selector':
def select_and_drop(self, **kwargs: Union[str, Sequence[str], Callable[[Any], bool]]) -> 'Selector':
"""Same as select, but drops the columns used for selection
"""
df = self.select(**kwargs)
Expand Down
2 changes: 1 addition & 1 deletion nevergrad/common/typetools.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ def result(self) -> Any:
class ExecutorLike(Protocol):
# pylint: disable=pointless-statement, unused-argument

def submit(self, function: Callable, *args: Any, **kwargs: Any) -> JobLike:
def submit(self, function: Callable[..., Any], *args: Any, **kwargs: Any) -> JobLike:
...
2 changes: 1 addition & 1 deletion nevergrad/functions/functionlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def oracle_call(self, x: np.ndarray) -> float:
def duplicate(self) -> "ArtificialFunction":
"""Create an equivalent instance, initialized with the same settings
"""
return self.__class__(**self._parameters) # type: ignore
return self.__class__(**self._parameters)

def get_postponing_delay(self, arguments: Tuple[Tuple[Any, ...], Dict[str, Any]], value: float) -> float:
"""Delay before returning results in steady state mode benchmarks (fake execution time)
Expand Down
2 changes: 1 addition & 1 deletion nevergrad/functions/photonics/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from ...instrumentation.utils import CommandFunction


def tanh_crop(x, min_val, max_val):
def tanh_crop(x: ArrayLike, min_val: float, max_val: float) -> np.ndarray:
return .5 * (max_val + min_val) + .5 * (max_val - min_val) * np.tanh(x)


Expand Down
4 changes: 2 additions & 2 deletions nevergrad/functions/test_corefuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.

from unittest import TestCase
from typing import Callable
from typing import Callable, Any
import numpy as np
import genty
from . import corefuncs
Expand All @@ -14,7 +14,7 @@
class CoreFuncsTests(TestCase):

@genty.genty_dataset(**{name: (name, func) for name, func in corefuncs.registry.items()}) # type: ignore
def test_core_function(self, name: str, func: Callable) -> None:
def test_core_function(self, name: str, func: Callable[..., Any]) -> None:
x = np.random.normal(0, 1, 100)
outputs = []
for _ in range(2):
Expand Down
10 changes: 5 additions & 5 deletions nevergrad/instrumentation/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def with_name(self, name: str) -> "Instrumentation":
self._name = name
return self

def _set_args_kwargs(self, args: Tuple[Any, ...], kwargs: Dict) -> None:
def _set_args_kwargs(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
self.names, arguments = self._make_argument_names_and_list(args, kwargs)
self.instruments: List[utils.Variable] = [variables._Constant.convert_non_instrument(a) for a in arguments]
num_instru = len(set(id(i) for i in self.instruments))
Expand All @@ -58,15 +58,15 @@ def kwargs(self) -> Dict[str, utils.Variable]:
return {name: arg for name, arg in zip(self.names, self.instruments) if name is not None}

@staticmethod
def _make_argument_names_and_list(args: Tuple[Any, ...], kwargs: Dict) -> Tuple[Tuple[Optional[str], ...], Tuple[Any, ...]]:
def _make_argument_names_and_list(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[Tuple[Optional[str], ...], Tuple[Any, ...]]:
"""Converts *args and **kwargs to a tuple of names (with None for positional),
and the corresponding tuple of values.
Eg:
_make_argument_names_and_list(3, z="blublu", machin="truc")
>>> (None, "machin", "z"), (3, "truc", "blublu")
"""
names: Tuple[Optional[str], ...] = tuple([None] * len(args) + sorted(kwargs))
names: Tuple[Optional[str], ...] = tuple([None] * len(args) + sorted(kwargs)) # type: ignore
arguments: Tuple[Any, ...] = args + tuple(kwargs[x] for x in names if x is not None)
return names, arguments

Expand Down Expand Up @@ -94,7 +94,7 @@ def arguments_to_data(self, *args: Any, **kwargs: Any) -> np.ndarray:
data = list(itertools.chain.from_iterable([instrument.process_arg(arg) for instrument, arg in zip(self.instruments, arguments)]))
return np.array(data)

def instrument(self, function: Callable) -> "InstrumentedFunction":
def instrument(self, function: Callable[..., Any]) -> "InstrumentedFunction":
return InstrumentedFunction(function, *self.args, **self.kwargs)

def __format__(self, format_spec: str) -> str:
Expand Down Expand Up @@ -146,7 +146,7 @@ class InstrumentedFunction(base.BaseFunction):
"""

def __init__(self, function: Callable, *args: Any, **kwargs: Any) -> None:
def __init__(self, function: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
assert callable(function)
self.instrumentation = Instrumentation(*args, **kwargs)
super().__init__(dimension=self.instrumentation.dimension)
Expand Down
2 changes: 1 addition & 1 deletion nevergrad/instrumentation/discretization.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def softmax_discretization(x: ArrayLike, arity: int = 2, deterministic: bool = F
warnings.warn("Encountered NaN values for discretization")
data[np.isnan(data)] = -np.inf
if deterministic:
output: list = np.argmax(data, axis=1).tolist()
output: List[float] = np.argmax(data, axis=1).tolist()
return output
return [np.random.choice(arity, p=softmax_probas(d)) for d in data]

Expand Down
4 changes: 2 additions & 2 deletions nevergrad/instrumentation/instantiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ def __eq__(self, other: Any) -> bool:
return False

@classmethod
def sub(cls, text: str, extension: str, replacers: Dict) ->str:
def sub(cls, text: str, extension: str, replacers: Dict[str, Any]) -> str:
found: Set[str] = set()
kwargs = {x: _convert_to_string(y, extension) for x, y in replacers.items()}

def _replacer(regex: Match) -> str:
def _replacer(regex: Match[str]) -> str:
name = regex.group("name")
if name in found:
raise RuntimeError(f'Trying to remplace a second time placeholder "{name}"')
Expand Down
2 changes: 1 addition & 1 deletion nevergrad/instrumentation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class UtilsTests(TestCase):
@genty.genty_dataset( # type: ignore
empty=([], [], [])
)
def test_split_data(self, tokens: List, data: List, expected: List) -> None:
def test_split_data(self, tokens: List[utils.Variable], data: List[float], expected: List[List[float]]) -> None:
output = utils.split_data(data, tokens)
testing.printed_assert_equal(output, expected)

Expand Down
2 changes: 1 addition & 1 deletion nevergrad/instrumentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def process_instruments(instruments: Iterable[Variable], data: List[float],
return tuple([instrument.process(d, deterministic=deterministic) for instrument, d in zip(instruments, splitted_data)])


class TemporaryDirectoryCopy(tempfile.TemporaryDirectory):
class TemporaryDirectoryCopy(tempfile.TemporaryDirectory): # type: ignore
"""Creates a full copy of a directory inside a temporary directory
This class can be used as TemporaryDirectory but:
- the created copy path is available through the copyname attribute
Expand Down
2 changes: 1 addition & 1 deletion nevergrad/instrumentation/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __init__(self, mean: float, std: float, shape: Optional[List[int]] = None) -
self.shape = shape

@classmethod
def from_regex(cls, regex: Match) -> utils.Variable:
def from_regex(cls, regex: Match[str]) -> utils.Variable:
return cls(float(regex.group("mean")), float(regex.group("std")))

@property
Expand Down
2 changes: 1 addition & 1 deletion nevergrad/optimization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self, dimension: int, budget: Optional[int] = None, num_workers: in
# instance state
self._num_suggestions = 0
self._num_evaluations = 0
self._callbacks: Dict[str, List[Callable]] = {}
self._callbacks: Dict[str, List[Any]] = {}

@property
def num_suggestions(self) -> int:
Expand Down
4 changes: 2 additions & 2 deletions nevergrad/optimization/mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Any
from typing import Optional, Any, Dict
import numpy as np
from ..common.typetools import ArrayLike

Expand Down Expand Up @@ -65,7 +65,7 @@ def crossover(parent: ArrayLike, donor: ArrayLike) -> ArrayLike:
return discrete_mutation(mix)


def get_roulette(archive: dict, num: Optional[int] = None) -> Any:
def get_roulette(archive: Dict[Any, Any], num: Optional[int] = None) -> Any:
"""Apply a roulette tournament selection.
"""
if num is None:
Expand Down
Loading

0 comments on commit eee8eb1

Please sign in to comment.