Skip to content

Commit

Permalink
adding to linearmap operation
Browse files Browse the repository at this point in the history
  • Loading branch information
yaoyic committed Mar 20, 2024
1 parent 91f8a00 commit eef4987
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions src/aggforce/map/jaxlinearmap.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Extends LinearMaps for Jax operations."""

from typing import overload, TypeVar
from jax import Array
import jax.numpy as jnp
Expand Down Expand Up @@ -26,12 +27,10 @@ def jax_standard_matrix(self) -> Array:
return jnp.asarray(self.standard_matrix)

@overload
def __call__(self, points: NDArray) -> NDArray:
...
def __call__(self, points: NDArray) -> NDArray: ...

@overload
def __call__(self, points: Array) -> Array:
...
def __call__(self, points: Array) -> Array: ...

def __call__(self, points: ArrT) -> ArrT:
r"""Apply map to a particular form of 3-dim array.
Expand Down Expand Up @@ -65,12 +64,10 @@ def __call__(self, points: ArrT) -> ArrT:
return transformed

@overload
def flat_call(self, flattened: NDArray) -> NDArray:
...
def flat_call(self, flattened: NDArray) -> NDArray: ...

@overload
def flat_call(self, flattened: Array) -> Array:
...
def flat_call(self, flattened: Array) -> Array: ...

def flat_call(self, flattened: ArrT) -> ArrT:
"""Apply map to pre-flattened array.
Expand Down Expand Up @@ -129,3 +126,7 @@ def __add__(self, lm: "LinearMap", /) -> "JLinearMap":
def from_linearmap(cls, lm: LinearMap, /) -> "JLinearMap":
"""Create JLinearMap from LinearMap."""
return JLinearMap(mapping=lm.standard_matrix)

def to_linearmap(self) -> LinearMap:
"""Create normal LinearMap from the current object."""
return LinearMap(mapping=self.standard_matrix)

0 comments on commit eef4987

Please sign in to comment.