Skip to content

Commit

Permalink
feat: Replace find() implementation with Python (#373)
Browse files Browse the repository at this point in the history
  • Loading branch information
LeeDongGeon1996 authored Apr 14, 2024
1 parent 16665c8 commit d219bce
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 14 deletions.
1 change: 0 additions & 1 deletion stock_indicators/_cslib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from Skender.Stock.Indicators import QuoteUtility as CsQuoteUtility
from Skender.Stock.Indicators import ResultBase as CsResultBase
from Skender.Stock.Indicators import Pruning as CsPruning
from Skender.Stock.Indicators import Seeking as CsSeeking

# Enums
from Skender.Stock.Indicators import BetaType as CsBetaType
Expand Down
16 changes: 6 additions & 10 deletions stock_indicators/indicators/common/results.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime as PyDateTime
from typing import Callable, Iterable, List, Type, TypeVar
from typing import Callable, Iterable, List, Optional, Type, TypeVar

from stock_indicators._cslib import CsResultBase, CsPruning, CsSeeking
from stock_indicators._cslib import CsResultBase, CsPruning
from stock_indicators._cstypes import DateTime as CsDateTime
from stock_indicators._cstypes import List as CsList
from stock_indicators._cstypes import to_pydatetime
Expand Down Expand Up @@ -77,18 +77,14 @@ def __add__(self, other: "IndicatorResults"):
def __mul__(self, value: int):
return self.__class__(list(self._csdata).__mul__(value), self._wrapper_class)

@_verify_data
def find(self, lookup_date: PyDateTime) -> _T:
"""Find indicator values on a specific date."""
def find(self, lookup_date: PyDateTime) -> Optional[_T]:
"""Find indicator values on a specific date. It returns `None` if no result found."""
if not isinstance(lookup_date, PyDateTime):
raise TypeError(
"lookup_date must be an instance of datetime.datetime."
)

result = CsSeeking.Find[CsResultBase](
CsList(self._get_csdata_type(), self._csdata), CsDateTime(lookup_date)
)
return self._wrapper_class(result)

return next((r for r in self if r.date == lookup_date), None)

@_verify_data
def remove_warmup_periods(self, remove_periods: int):
Expand Down
13 changes: 10 additions & 3 deletions tests/common/test_indicator_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,16 @@ def test_done_and_reload(self, quotes):
def test_find(self, quotes):
results = indicators.get_sma(quotes, 20)

# r[18]
r = results.find(datetime(2017, 1, 30))
assert r.sma is None
# r[19]
r = results.find(datetime(2017, 1, 31))
assert 214.5250 == round(float(r.sma), 4)

def test_not_found(self, quotes):
results = indicators.get_sma(quotes, 20)

# returns None
r = results.find(datetime(1996, 10, 12))
assert r is None

def test_remove_with_period(self, quotes):
results = indicators.get_sma(quotes, 20)
Expand Down

0 comments on commit d219bce

Please sign in to comment.