Skip to content

Commit

Permalink
Add type stub for cppcore (#358)
Browse files Browse the repository at this point in the history
* create tox lint environment with mypy

* Fix: type error in units module

* Fix: remove unused convert.py in tests

* add stub for cppcore

* add efel/*.pyi to MANIFEST.in

* merge get_cpp_feature and _get_cpp_feature

* type annotate  _get_cpp_data

* update CHANGELOG.md
  • Loading branch information
anilbey authored Jan 16, 2024
1 parent 328153e commit 160db25
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 69 deletions.
15 changes: 15 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,18 @@ jobs:
pip install tox tox-gh-actions
- name: Run tox
run: tox -e docs

lint:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install dependencies
run: |
python -m pip install --upgrade pip setuptools
pip install tox tox-gh-actions
- name: Run tox
run: tox -e lint
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [5.5.3] - 2024-01

- Add type stub for cppcore module to make Python recognise the C++ functions' arguments and return values.

## [5.5.0] - 2024-01

### C++ changes
Expand Down
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ include LICENSE.txt
include COPYING
include COPYING.lesser
include AUTHORS.txt
include efel/*.pyi
18 changes: 1 addition & 17 deletions efel/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import efel.cppcore as cppcore

import efel.pyfeatures as pyfeatures
from efel.pyfeatures.pyfeatures import get_cpp_feature

"""
Disabling cppcore importerror override, it confuses users in case the error
Expand Down Expand Up @@ -434,23 +435,6 @@ def _get_feature_values_serial(trace_featurenames):
return featureDict


def get_cpp_feature(featureName, raise_warnings=None):
"""Return value of feature implemented in cpp"""
cppcoreFeatureValues = list()
exitCode = cppcore.getFeature(featureName, cppcoreFeatureValues)

if exitCode < 0:
if raise_warnings:
import warnings
warnings.warn(
"Error while calculating feature %s: %s" %
(featureName, cppcore.getgError()),
RuntimeWarning)
return None
else:
return numpy.array(cppcoreFeatureValues)


def getMeanFeatureValues(traces, featureNames, raise_warnings=True):
"""Convenience function that returns mean values from getFeatureValues()
Expand Down
12 changes: 12 additions & 0 deletions efel/cppcore.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
def Initialize(depfilename: str, outfilename: str) -> int: ...
def getFeature(feature_name: str, values: list) -> int: ...
def getFeatureInt(feature_name: str, values: list[int]) -> int: ...
def getFeatureDouble(feature_name: str, values: list[float]) -> int: ...
def getMapIntData(data_name: str) -> list[int]: ...
def getMapDoubleData(data_name: str) -> list[float]: ...
def setFeatureInt(feature_name: str, values: list[int]) -> int: ...
def setFeatureDouble(feature_name: str, values: list[float]) -> float: ...
def setFeatureString(feature_name: str, value: str) -> int: ...
def featuretype(feature_name: str) -> str: ...
def getgError() -> str: ...
def getFeatureNames(feature_names: list[str]) -> None: ...
63 changes: 34 additions & 29 deletions efel/pyfeatures/pyfeatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from typing_extensions import deprecated

import numpy
import efel.cppcore
from efel import cppcore
from numpy.fft import *


Expand Down Expand Up @@ -60,12 +60,12 @@

def voltage():
"""Get voltage trace"""
return _get_cpp_feature("voltage")
return get_cpp_feature("voltage")


def time():
"""Get time trace"""
return _get_cpp_feature("time")
return get_cpp_feature("time")


@deprecated("Use spike_count instead.")
Expand All @@ -75,7 +75,7 @@ def Spikecount() -> numpy.ndarray:

def spike_count() -> numpy.ndarray:
"""Get spike count."""
peak_indices = _get_cpp_feature("peak_indices")
peak_indices = get_cpp_feature("peak_indices")
if peak_indices is None:
return numpy.array([0])
return numpy.array([peak_indices.size])
Expand All @@ -90,7 +90,7 @@ def spike_count_stimint() -> numpy.ndarray:
"""Get spike count within stimulus interval."""
stim_start = _get_cpp_data("stim_start")
stim_end = _get_cpp_data("stim_end")
peak_times = _get_cpp_feature("peak_time")
peak_times = get_cpp_feature("peak_time")
if peak_times is None:
return numpy.array([0])

Expand All @@ -105,7 +105,7 @@ def trace_check() -> numpy.ndarray | None:
"""
stim_start = _get_cpp_data("stim_start")
stim_end = _get_cpp_data("stim_end")
peak_times = _get_cpp_feature("peak_time")
peak_times = get_cpp_feature("peak_time")
if peak_times is None: # If no spikes, then no problem
return numpy.array([0])
# Check if there are no spikes or if all spikes are within the stimulus interval
Expand All @@ -117,7 +117,7 @@ def trace_check() -> numpy.ndarray | None:

def burst_number() -> numpy.ndarray:
"""The number of bursts."""
burst_mean_freq = _get_cpp_feature("burst_mean_freq")
burst_mean_freq = get_cpp_feature("burst_mean_freq")
if burst_mean_freq is None:
return numpy.array([0])
return numpy.array([burst_mean_freq.size])
Expand All @@ -132,7 +132,7 @@ def strict_burst_number() -> numpy.ndarray:
The burst detection can be fine-tuned by changing the setting
strict_burst_factor. Default value is 2.0."""
burst_mean_freq = _get_cpp_feature("strict_burst_mean_freq")
burst_mean_freq = get_cpp_feature("strict_burst_mean_freq")
if burst_mean_freq is None:
return numpy.array([0])
return numpy.array([burst_mean_freq.size])
Expand All @@ -144,11 +144,11 @@ def impedance():
dt = _get_cpp_data("interp_step")
Z_max_freq = _get_cpp_data("impedance_max_freq")
voltage_trace = voltage()
holding_voltage = _get_cpp_feature("voltage_base")
holding_voltage = get_cpp_feature("voltage_base")
normalized_voltage = voltage_trace - holding_voltage
current_trace = current()
if current_trace is not None:
holding_current = _get_cpp_feature("current_base")
holding_current = get_cpp_feature("current_base")
normalized_current = current_trace - holding_current
n_spikes = spike_count()
if n_spikes < 1: # if there is no spikes in ZAP
Expand All @@ -174,12 +174,12 @@ def impedance():

def current():
"""Get current trace"""
return _get_cpp_feature("current")
return get_cpp_feature("current")


def ISIs():
"""Get all ISIs."""
peak_times = _get_cpp_feature("peak_time")
peak_times = get_cpp_feature("peak_time")
if peak_times is None:
return None
else:
Expand All @@ -191,7 +191,7 @@ def initburst_sahp_vb():

# Required cpp features
initburst_sahp_value = initburst_sahp()
voltage_base = _get_cpp_feature("voltage_base")
voltage_base = get_cpp_feature("voltage_base")

if initburst_sahp_value is None or voltage_base is None or \
len(initburst_sahp_value) != 1 or len(voltage_base) != 1:
Expand All @@ -205,7 +205,7 @@ def initburst_sahp_ssse():

# Required cpp features
initburst_sahp_value = initburst_sahp()
ssse = _get_cpp_feature("steady_state_voltage_stimend")
ssse = get_cpp_feature("steady_state_voltage_stimend")

if initburst_sahp_value is None or ssse is None or \
len(initburst_sahp_value) != 1 or len(ssse) != 1:
Expand All @@ -218,10 +218,10 @@ def initburst_sahp():
"""SlowAHP voltage after initial burst"""

# Required cpp features
voltage = _get_cpp_feature("voltage")
time = _get_cpp_feature("time")
voltage = get_cpp_feature("voltage")
time = get_cpp_feature("time")
time = time[:len(voltage)]
peak_times = _get_cpp_feature("peak_time")
peak_times = get_cpp_feature("peak_time")

# Required python features
all_isis = ISIs()
Expand Down Expand Up @@ -310,9 +310,9 @@ def depol_block():
stim_end = _get_cpp_data("stim_end")

# Required cpp features
voltage = _get_cpp_feature("voltage")
time = _get_cpp_feature("time")
AP_begin_voltage = _get_cpp_feature("AP_begin_voltage")
voltage = get_cpp_feature("voltage")
time = get_cpp_feature("time")
AP_begin_voltage = get_cpp_feature("AP_begin_voltage")
stim_start_idx = numpy.flatnonzero(time >= stim_start)[0]
stim_end_idx = numpy.flatnonzero(time >= stim_end)[0]

Expand Down Expand Up @@ -380,8 +380,8 @@ def depol_block_bool():
def spikes_per_burst():
"""Calculate the number of spikes per burst"""

burst_begin_indices = _get_cpp_feature("burst_begin_indices")
burst_end_indices = _get_cpp_feature("burst_end_indices")
burst_begin_indices = get_cpp_feature("burst_begin_indices")
burst_end_indices = get_cpp_feature("burst_end_indices")

if burst_begin_indices is None:
return None
Expand Down Expand Up @@ -424,18 +424,23 @@ def spikes_in_burst1_burstlast_diff():
])


def _get_cpp_feature(feature_name):
"""Get cpp feature"""
def get_cpp_feature(featureName, raise_warnings=None):
"""Return value of feature implemented in cpp"""
cppcoreFeatureValues = list()
exitCode = efel.cppcore.getFeature(feature_name, cppcoreFeatureValues)
exitCode = cppcore.getFeature(featureName, cppcoreFeatureValues)

if exitCode < 0:
if raise_warnings:
import warnings
warnings.warn(
"Error while calculating feature %s: %s" %
(featureName, cppcore.getgError()),
RuntimeWarning)
return None
else:
return numpy.array(cppcoreFeatureValues)


def _get_cpp_data(data_name):
"""Get cpp data value"""

return efel.cppcore.getMapDoubleData(data_name)[0]
def _get_cpp_data(data_name: str) -> float:
"""Get cpp data value."""
return cppcore.getMapDoubleData(data_name)[0]
4 changes: 4 additions & 0 deletions efel/units/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@


_units_raw = pkgutil.get_data(__name__, "units.json")

if _units_raw is None:
raise ValueError("Failed to load units.json")

_units = json.loads(_units_raw)


Expand Down
16 changes: 15 additions & 1 deletion tests/test_units.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Unit tests for units module."""


import importlib
import pytest
from efel.units import get_unit
from unittest.mock import patch


def test_get_unit():
Expand All @@ -11,3 +13,15 @@ def test_get_unit():
assert get_unit("AP1_amp") != "wrong unit"
assert get_unit("AP1_amp") == "mV"
assert get_unit("ohmic_input_resistance") == "MΩ"


@patch('efel.units.pkgutil.get_data')
def test_get_data_failure(mock_get_data):
"""Test for handling failure in loading units.json."""
mock_get_data.return_value = None

with pytest.raises(ValueError) as excinfo:
# Dynamically reload the module to simulate the import with mock
importlib.reload(importlib.import_module('efel.units'))

assert str(excinfo.value) == "Failed to load units.json"
19 changes: 0 additions & 19 deletions tests/testdata/allfeatures/convert.py

This file was deleted.

14 changes: 11 additions & 3 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
[tox]
envlist = docs,py3-{test}
envlist = docs,lint,py3-{test}
minversion = 4
[gh-actions]
python =
3.8: py3
3.9: py3
3.10: py3
3.11: py3
3.12: py3


[testenv]
envdir = {toxworkdir}/py3-test
deps =
pytest>=7.3.1
scipy>=1.10.1
pycodestyle>=2.11.0
pytest-xdist>=3.3.1
extras =
neo
usedevelop=True
commands =
pycodestyle --ignore=E402,W503,W504 --exclude=_version.py --max-line-length=88 efel tests
pytest -sx -n auto tests


Expand Down Expand Up @@ -78,3 +77,12 @@ commands =
make html SPHINXOPTS=-W
# make sure the feature names and units are up-to-date
pytest test_feature_units_in_docs.py

[testenv:lint]
envdir = {toxworkdir}/lint
deps =
pycodestyle>=2.11.0
mypy>=1.8.0
commands =
pycodestyle --ignore=E402,W503,W504 --exclude=_version.py --max-line-length=88 efel tests
mypy efel tests --ignore-missing-imports

0 comments on commit 160db25

Please sign in to comment.