Skip to content

Commit

Permalink
end module and test
Browse files Browse the repository at this point in the history
  • Loading branch information
Tokisakix committed Apr 8, 2024
1 parent 2525432 commit 49c812a
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 6 deletions.
32 changes: 30 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
Terox is a tiny Python package that provides some features:
- [x] Support automatic differentiation.
- [ ] Provides convenient tensor calculation.
- [ ] Control the parameters and the model.
- [x] Control the parameters and the model.
- [ ] Provides common computing functions for deep learning.
- [ ] Provides common deep learning components.
- [ ] Provides deep learning model optimizer.
Expand Down Expand Up @@ -42,4 +42,32 @@ Make sure that everything is installed by running python and then checking. If y

```Python
import terox
print(terox.__version__) # expect output: "Terox v0.1 by Tokisakix."
print(terox.__version__) # expect output: "Terox v0.1 by Tokisakix."
```

## Test

You can test the correctness of the project by running `pytest` in the root directory of the project:

```Shell
python -m pytest
```

Pytest tests all modules by default, but you can also run the following commands to do some testing:

```Shell
python -m pytest -m <test-name>
```

Where `<test-name>` can select the following test module name:

```Shell
# autodiff test
test_function
test_scalar
test_scalar_opts
test_scalar_overload

# module test
test_module
```
29 changes: 28 additions & 1 deletion README_cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
Terox 是一个很精简的 Python 包,它提供了一些特性:
- [x] 支持自动微分。
- [ ] 提供方便的张量计算。
- [ ] 便捷控制参数和模型。
- [x] 便捷控制参数和模型。
- [ ] 提供深度学习常用的计算函数。
- [ ] 提供常用的深度学习组件。
- [ ] 提供深度学习模型优化器。
Expand Down Expand Up @@ -43,4 +43,31 @@ python -m pip install -Ue .
```Python
import terox
print(terox.__version__) # 期望输出: "Terox v0.1 by Tokisakix."
```

## 测试

你可以在项目根目录下运行 `pytest` 来测试此项目的正确性:

```Shell
python -m pytest
```

默认情况下,Pytest 会测试所有的模块,你也可以运行下列命令来进行部分测试:

```Shell
python -m pytest -m <test-name>
```

其中 `<test-name>` 可以选择如下测试模块名:

```Shell
# autodiff test
test_function
test_scalar
test_scalar_opts
test_scalar_overload

# module test
test_module
```
4 changes: 3 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,6 @@ markers =
test_function
test_scalar
test_scalar_opts
test_scalar_overload
test_scalar_overload

test_module
4 changes: 2 additions & 2 deletions terox/autodiff/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .function import add, sub, mul, div, neg, max, min, eq, lt, gt, abs, exp, log, relu
from .variable import VarFunction, VarHistory, Variable
from .variable import Variable
from .scalar import Scalar
from .scalar_opts import ScalarOptsBackend, Add, Sub, Mul, Div, Max, Min, Eq, Lt, Gt, Abs, Exp, Log, Relu, Sigmoid, Tanh
from .scalar_opts import ScalarOptsBackend
2 changes: 2 additions & 0 deletions terox/module/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .module import Module
from .parameter import Parameter
37 changes: 37 additions & 0 deletions terox/module/module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Any, List, Tuple

from .parameter import Parameter

module_count:int = 0

class Module():

_id: int

def __init__(self) -> None:
global module_count
self._id = module_count
module_count += 1
return

def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.forward(*args, **kwds)

def getParmeterDict(self) -> dict:
def getParmeterList(module:Module) -> List[Tuple[str, Parameter]]:
parmeters = []
for key in module.__dict__:
value = module.__dict__[key]
if isinstance(value, Module):
subparmeters = []
for name, parmeter in getParmeterList(value):
subparmeters.append((f"{key}.{name}", parmeter))
parmeters += subparmeters
if isinstance(value, Parameter):
parmeters.append((key, value))
return parmeters
parmeters = dict(getParmeterList(self))
return parmeters

def forward(self, *args: Any, **kwds: Any) -> Any:
raise NotImplementedError
18 changes: 18 additions & 0 deletions terox/module/parameter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from ..autodiff.variable import Variable

parameter_count:int = 0

class Parameter():

_id: int
_value: Variable

def __init__(self, _value:Variable) -> None:
global parameter_count
self._id = parameter_count
parameter_count += 1
self._value = _value
return

def value(self) -> Variable:
return self._value
31 changes: 31 additions & 0 deletions test/module/test_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest

from terox.autodiff.scalar import Scalar
from terox.module.module import Module
from terox.module.parameter import Parameter

class M1(Module):
def __init__(self, p1:Parameter, p2:Parameter) -> None:
super().__init__()
self.p1 = p1
self.p2 = p2
return

class M2(Module):
def __init__(self, m1:M1, p3:Parameter) -> None:
super().__init__()
self.m1 = m1
self.p3 = p3
return

p1 = Parameter(Scalar(1.0))
p2 = Parameter(Scalar(2.0))
p3 = Parameter(Scalar(3.0))
m1 = M1(p1, p2)
m2 = M2(m1, p3)

@pytest.mark.test_module
def test_module_dict() -> None:
assert m1.getParmeterDict() == {"p1":p1, "p2":p2}
assert m2.getParmeterDict() == {"m1.p1":p1, "m1.p2":p2, "p3":p3}
return

0 comments on commit 49c812a

Please sign in to comment.