Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add nrTransformPrecode() and nrTransformDeprecode() #75

Merged
merged 3 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions py3gpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
from .nrDLSCHInfo import nrDLSCHInfo
from .nrPDSCHIndices import nrPDSCHIndices
from .nrPDSCHMCSTables import nrPDSCHMCSTables
from .nrTransformPrecode import nrTransformPrecode
from .nrTransformDeprecode import nrTransformDeprecode

from .configs.nrCarrierConfig import nrCarrierConfig
from .configs.nrNumerologyConfig import nrNumerologyConfig
Expand Down
6 changes: 6 additions & 0 deletions py3gpp/nrTransformDeprecode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import numpy as np

def nrTransformDeprecode(modSym, mrb):
mrb = int(mrb)
assert modSym.shape[0] % (mrb * 12) == 0, "input number of rows must be an integer multiple of mrb * 12"
return (np.fft.ifft(modSym.reshape(int(modSym.shape[0] / (mrb * 12)), mrb * 12)) * np.sqrt(mrb * 12)).ravel()
6 changes: 6 additions & 0 deletions py3gpp/nrTransformPrecode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import numpy as np

def nrTransformPrecode(modSym, mrb):
mrb = int(mrb)
assert modSym.shape[0] % (mrb * 12) == 0, "input number of rows must be an integer multiple of mrb * 12"
return (np.fft.fft(modSym.reshape(int(modSym.shape[0] / (mrb * 12)), mrb * 12)) * 1/np.sqrt(mrb * 12)).ravel()
6 changes: 6 additions & 0 deletions tests/test_data/transformPrecode.py

Large diffs are not rendered by default.

24 changes: 24 additions & 0 deletions tests/test_nrTransformPrecode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import sys
import numpy as np
import pytest

from py3gpp.nrTransformPrecode import nrTransformPrecode
from py3gpp.nrSymbolModulate import nrSymbolModulate

sys.path.append("test_data")

from test_data.transformPrecode import cw, desired_result_2, desired_result_40

def test_run_nr_transform_precode_2():
modSym = nrSymbolModulate(cw, 'QPSK')
result_2 = nrTransformPrecode(modSym, 2)
assert np.array_equal(np.round(result_2, 8), np.round(desired_result_2, 8))

def test_run_nr_transform_precode_40():
modSym = nrSymbolModulate(cw, 'QPSK')
result_40 = nrTransformPrecode(modSym, 40)
assert np.array_equal(np.round(result_40, 8), np.round(desired_result_40, 8))

if __name__ == '__main__':
test_run_nr_transform_precode_2()
test_run_nr_transform_precode_40()
29 changes: 29 additions & 0 deletions tests/test_nrTransformPrecode_nrTransformDeprecode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import sys
import numpy as np
import pytest

from py3gpp.nrTransformPrecode import nrTransformPrecode
from py3gpp.nrTransformDeprecode import nrTransformDeprecode
from py3gpp.nrSymbolModulate import nrSymbolModulate

sys.path.append("test_data")

from test_data.transformPrecode import cw, desired_result_2, desired_result_40

def test_run_nrTransformPrecode_nrTransformDeprecode_2():
modSym = nrSymbolModulate(cw, 'QPSK')
result = nrTransformPrecode(modSym, 2)
assert np.array_equal(np.round(result, 8), np.round(desired_result_2, 8))
x = nrTransformDeprecode(result, 2)
assert np.array_equal(np.round(x, 8), np.round(modSym, 8))

def test_run_nrTransformPrecode_nrTransformDeprecode_40():
modSym = nrSymbolModulate(cw, 'QPSK')
result = nrTransformPrecode(modSym, 40)
assert np.array_equal(np.round(result, 8), np.round(desired_result_40, 8))
x = nrTransformDeprecode(result, 40)
assert np.array_equal(np.round(x, 8), np.round(modSym, 8))

if __name__ == '__main__':
test_run_nrTransformPrecode_nrTransformDeprecode_2()
test_run_nrTransformPrecode_nrTransformDeprecode_40()
Loading