Skip to content

Commit

Permalink
add Range and new Sampler types
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella committed Apr 3, 2023
1 parent 4aa9d9e commit 5bb5515
Showing 1 changed file with 80 additions and 1 deletion.
81 changes: 80 additions & 1 deletion src/gluonts/transform/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,93 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Tuple
from dataclasses import dataclass
from typing import Tuple, Optional, Union

import numpy as np
import pandas as pd
from pydantic import BaseModel

from gluonts.dataset.stat import ScaleHistogram


def clip(value, low, high):
    """
    Restrict ``value`` to the closed interval ``[low, high]``.

    Values below ``low`` map to ``low`` and values above ``high`` map to
    ``high``. If the bounds are inverted (``low > high``), ``low`` is
    returned.
    """
    # First cap from above, then from below, so the lower bound takes
    # precedence when the two bounds conflict.
    capped = value if value < high else high
    return capped if capped > low else low


@dataclass
class Range:
    """
    Slice-like description of a span of positions within a time series.

    ``start`` and ``stop`` may each be ``None`` (open-ended), an integer
    position (negative values are counted back from the end of the series),
    or a ``pd.Period``, which is converted to a position relative to the
    first period of the series. ``step`` is the stride of the produced
    ``range``.
    """

    start: Optional[Union[int, pd.Period]] = None
    stop: Optional[Union[int, pd.Period]] = None
    step: int = 1

    @staticmethod
    def _as_int(bound, default: int, start: pd.Period, length: int) -> int:
        """
        Convert a single bound into an (unclipped) integer position.

        ``default`` is returned when ``bound`` is ``None``; ``start`` is the
        first period of the series and ``length`` its number of entries.
        """
        if bound is None:
            return default
        if isinstance(bound, pd.Period):
            # ``Period - Period`` is an offset whose ``n`` counts base
            # frequency units. Use the integer counts rather than dividing
            # the offsets themselves: offset division is only available for
            # fixed-duration frequencies and fails for e.g. monthly or
            # weekly data. ``int`` truncates toward zero.
            elapsed = bound - start
            return int(elapsed.n / start.freq.n)
        if bound < 0:
            # Negative positions count back from the end, as in slicing.
            return length + bound
        return bound

    def _start_as_int(self, start: pd.Period, length: int) -> int:
        """
        Return ``self.start`` as an integer position (0 if unset).
        """
        return self._as_int(self.start, 0, start, length)

    def _stop_as_int(self, start: pd.Period, length: int) -> int:
        """
        Return ``self.stop`` as an integer position (``length`` if unset).
        """
        return self._as_int(self.stop, length, start, length)

    def get(self, start: pd.Period, length: int) -> range:
        """
        Build the concrete ``range`` for a series beginning at ``start``
        and containing ``length`` entries.

        Both bounds are clipped to ``[0, length]`` so the result never
        refers to positions outside of the series.
        """
        return range(
            clip(self._start_as_int(start, length), 0, length),
            clip(self._stop_as_int(start, length), 0, length),
            self.step,
        )


@dataclass
class Sampler:
    """
    Base class for objects that select positions out of a ``Range``.

    Subclasses implement ``sample``; calling the sampler resolves
    ``range_`` against a concrete series and forwards the resulting
    ``range`` to ``sample``.
    """

    range_: Range

    def sample(self, rge: range) -> list:
        """
        Select positions from ``rge``. Must be overridden by subclasses.
        """
        raise NotImplementedError()

    def __call__(self, start: pd.Period, length: int) -> list:
        """
        Sample positions for a series beginning at ``start`` and holding
        ``length`` entries.
        """
        candidates = self.range_.get(start, length)
        return self.sample(candidates)


@dataclass
class SampleAll(Sampler):
    """
    Sampler that keeps every position of the range it is given.
    """

    def sample(self, rge: range) -> list:
        """
        Return all elements of ``rge``, in order, as a list.
        """
        return [*rge]


@dataclass
class SampleOnAverage(Sampler):
    """
    Sampler that independently keeps each element of the range with a
    probability chosen so that, on average over the ranges seen so far,
    ``average_num_samples`` elements are kept per call.

    The sampler is stateful: it tracks the running mean length of the
    non-empty ranges it has been given.
    """

    average_num_samples: float = 1.0

    def __post_init__(self):
        # Running statistics over the non-empty ranges seen so far.
        self.average_length = 0
        self.count = 0

    def sample(self, rge: range) -> list:
        """
        Return a random subset of the elements of ``rge`` as a list.

        Empty ranges yield an empty list and do not affect the running
        statistics.
        """
        num_candidates = len(rge)
        if num_candidates == 0:
            return []

        # Update the running mean of the range lengths.
        self.average_length = (
            self.count * self.average_length + num_candidates
        ) / (self.count + 1)
        self.count += 1

        # One independent draw per element of the range.
        p = self.average_num_samples / self.average_length
        (positions,) = np.where(np.random.random_sample(num_candidates) < p)

        # Map positions within the range back to the actual range values;
        # this honors ``rge.step`` instead of assuming a stride of one.
        return (rge.start + rge.step * positions).tolist()


class InstanceSampler(BaseModel):
"""
An InstanceSampler is called with the time series ``ts``, and returns a set
Expand Down

0 comments on commit 5bb5515

Please sign in to comment.