Skip to content

Commit

Permalink
Add MCRTensor (#174)
Browse files Browse the repository at this point in the history
* Add MCRTensor

* Add MCRTensor header to doc files

* Simplify the implementation of bundle
  • Loading branch information
milad2073 authored Oct 8, 2024
1 parent 132ee17 commit 54b6465
Show file tree
Hide file tree
Showing 14 changed files with 570 additions and 110 deletions.
1 change: 1 addition & 0 deletions docs/torchhd.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ VSA Models
HRRTensor
FHRRTensor
BSBCTensor
MCRTensor
VTBTensor


Expand Down
2 changes: 2 additions & 0 deletions torchhd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from torchhd.tensors.fhrr import FHRRTensor
from torchhd.tensors.bsbc import BSBCTensor
from torchhd.tensors.vtb import VTBTensor
from torchhd.tensors.mcr import MCRTensor

from torchhd.functional import (
ensure_vsa_tensor,
Expand Down Expand Up @@ -90,6 +91,7 @@
"FHRRTensor",
"BSBCTensor",
"VTBTensor",
"MCRTensor",
"functional",
"embeddings",
"structures",
Expand Down
7 changes: 5 additions & 2 deletions torchhd/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from torchhd.tensors.fhrr import FHRRTensor
from torchhd.tensors.bsbc import BSBCTensor
from torchhd.tensors.vtb import VTBTensor
from torchhd.tensors.mcr import MCRTensor
from torchhd.types import VSAOptions


Expand Down Expand Up @@ -90,6 +91,8 @@ def get_vsa_tensor_class(vsa: VSAOptions) -> Type[VSATensor]:
return BSBCTensor
elif vsa == "VTB":
return VTBTensor
elif vsa == "MCR":
return MCRTensor

raise ValueError(f"Provided VSA model is not supported, specified: {vsa}")

Expand Down Expand Up @@ -358,7 +361,7 @@ def level(
device=span_hv.device,
).as_subclass(vsa_tensor)

if vsa == "BSBC":
if vsa == "BSBC" or vsa == "MCR":
hv.block_size = span_hv.block_size

for i in range(num_vectors):
Expand Down Expand Up @@ -585,7 +588,7 @@ def circular(
device=span_hv.device,
).as_subclass(vsa_tensor)

if vsa == "BSBC":
if vsa == "BSBC" or vsa == "MCR":
hv.block_size = span_hv.block_size

mutation_history = deque()
Expand Down
Loading

0 comments on commit 54b6465

Please sign in to comment.