-
Notifications
You must be signed in to change notification settings - Fork 1
/
tensor.hpp
79 lines (65 loc) · 2.59 KB
/
tensor.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
/**
* @file tensor.hpp
* @author Wuqiong Zhao ([email protected]), et al.
* @brief Tensor (3D Array) for FLAMES
* @version 0.1.0
* @date 2023-07-15
*
* @copyright Copyright (c) 2023 Wuqiong Zhao
*
*/
#ifndef _FLAMES_TENSOR_HPP_
#define _FLAMES_TENSOR_HPP_
#ifndef _FLAMES_CORE_HPP_
# include "core.hpp"
#endif
#ifndef FLAMES_TENSOR_PARTITION_COMPLETE
# ifdef FLAMES_MAT_PARTITION_COMPLETE
# define FLAMES_TENSOR_PARTITION_COMPLETE
# endif
#endif
namespace flames {
/**
 * @brief Fixed-size stack of `n_slices` matrices ("tensor"), kept in one flat array.
 *
 * Every slice occupies matSize() consecutive elements of `_data`, so slice `i`
 * starts at offset `i * matSize()`. Slices are accessed as `View` objects
 * (a `MatView` constructed from a pointer into that storage).
 *
 * @tparam T        Element type.
 * @tparam n_rows   Number of rows of each slice.
 * @tparam n_cols   Number of columns of each slice. Only MatType::NORMAL uses it
 *                  for sizing; all other storage formats are sized from n_rows
 *                  alone (presumably they assume square slices — TODO confirm).
 * @tparam n_slices Number of slices.
 * @tparam type     Storage format shared by every slice (see matSize()).
 */
template <typename T, size_t n_rows, size_t n_cols, size_t n_slices, MatType type>
class Tensor {
  public:
    using element_type = T;
    using value_type = T;
    /// Matrix view type returned for a single slice.
    using View = MatView<T, n_rows, n_cols, type>;

    /**
     * @brief Default constructor. Its only action is the array-partition directive for `_data`.
     *
     * The partitioning is complete when FLAMES_TENSOR_PARTITION_COMPLETE is
     * defined, otherwise it is block partitioning with FLAMES_MAT_PARTITION_FACTOR.
     * Elements are left uninitialized.
     * NOTE(review): FLAMES_PRAGMA presumably expands to an HLS `#pragma`; confirm in core.hpp.
     */
    Tensor() {
#ifdef FLAMES_TENSOR_PARTITION_COMPLETE
        FLAMES_PRAGMA(ARRAY_PARTITION variable = _data type = complete)
#else
        FLAMES_PRAGMA(ARRAY_PARTITION variable = _data type = block factor = FLAMES_MAT_PARTITION_FACTOR)
#endif
    }

    /**
     * @brief Number of elements stored for ONE slice, as dictated by `type`.
     *
     * - NORMAL:              n_rows * n_cols
     * - DIAGONAL:            n_rows
     * - SCALAR:              1
     * - SUPPER/SLOWER/ASYM:  (n_rows - 1) * n_rows / 2
     * - any other type:      (n_rows + 1) * n_rows / 2
     */
    inline static constexpr size_t matSize() noexcept { return _mat_size; }

    /**
     * @brief Total number of stored elements across all slices (the length of `_data`).
     *
     * This multiplies the per-slice size by the slice count. The previous body
     * called size() recursively.
     */
    inline static constexpr size_t size() noexcept { return n_slices * matSize(); }

    /**
     * @brief View of the `index`-th slice.
     * @param index Slice index; must be < n_slices (checked by assert only).
     * @return A View over `_data + index * matSize()`.
     * @note Even on a const Tensor the returned View can mutate the data,
     *       because constness is cast away to construct it.
     */
    inline View slice(size_t index) const {
        assert(index < n_slices && "Index should be within range for Tensor::slice(index).");
        return const_cast<T*>(_data + index * matSize());
    }

    /**
     * @brief View of the `index`-th slice.
     * @param index Slice index; must be < n_slices (checked by assert only).
     * @return A View over `_data + index * matSize()`.
     */
    inline View slice(size_t index) {
        assert(index < n_slices && "Index should be within range for Tensor::slice(index).");
        return _data + index * matSize();
    }

    /// Same as slice(index).
    inline View operator[](size_t index) const { return slice(index); }
    /// Same as slice(index).
    inline View operator[](size_t index) { return slice(index); }

  private:
    // Single source of truth for the per-slice element count. Both matSize()
    // and the `_data` bound use it. It is a static data member rather than a
    // call to matSize() because a member function cannot appear in a constant
    // expression inside its own, still-incomplete, class definition.
    static constexpr size_t _mat_size = type == MatType::NORMAL     ? n_rows * n_cols
                                        : type == MatType::DIAGONAL ? n_rows
                                        : type == MatType::SCALAR   ? 1
                                        : type == MatType::SUPPER   ? (n_rows - 1) * n_rows / 2
                                        : type == MatType::SLOWER   ? (n_rows - 1) * n_rows / 2
                                        : type == MatType::ASYM     ? (n_rows - 1) * n_rows / 2
                                                                    : (1 + n_rows) * n_rows / 2;

    /// Flat backing storage. Slice i occupies [i * matSize(), (i + 1) * matSize()).
    T _data[n_slices * _mat_size];
};
} // namespace flames
#endif