-
Notifications
You must be signed in to change notification settings - Fork 0
/
test-packa.cpp
126 lines (110 loc) · 3.42 KB
/
test-packa.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
/*
* Copyright (C) 2022 Xiao Song.
* All Rights Reserved.
* Content of this file is not for commercial use.
*/
#include <cerrno>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <immintrin.h>
// Blocking sizes used by this packing test: A is processed in m_b x k_b blocks,
// and each block is packed as (m_b / m_r) panels of m_r rows.
#define n_r 8
#define k_b 24 // NOTE: this is a different number from the actual k_b
#define m_r 31 // NOTE: this is a different number from the actual m_r
#define m_b 31 // NOTE: this is a different number from the actual m_b
// pack_a only copies whole m_r-row panels; remainder rows of an m_b block would
// be silently dropped (and later printed as garbage), so reject such settings.
static_assert( m_b % m_r == 0, "m_b must be a multiple of m_r" );
#define TO_STRING_HELPER(X) #X
#define TO_STRING(X) TO_STRING_HELPER(X)
// UNROLL_LOOP(n): hint the compiler to unroll the immediately following loop
// by a factor of n. Expands to nothing on compilers without a known pragma.
#if defined(__ICC) || defined(__ICL) || defined(__clang__)
#define UNROLL_LOOP(n) _Pragma(TO_STRING(unroll (n)))
#elif defined(__GNUC__)
// Forward n so each call site gets the unroll factor it requests, rather than
// a single hard-coded factor for every loop.
#define UNROLL_LOOP(n) _Pragma(TO_STRING(GCC unroll (n)))
#elif defined(_MSC_VER)
#pragma message ("Microsoft Visual C++ (MSVC) detected: Loop unrolling not supported!")
#define UNROLL_LOOP(n)
#else
#warning "Unknown compiler: Loop unrolling not supported!"
#define UNROLL_LOOP(n)
#endif
void pack_a( double* src_a, double* pak_a, int lda )
{
for ( int m_r_i = 0; m_r_i < (m_b / m_r); ++m_r_i )
{
double* src_a_row_m_r_i = src_a + m_r_i * m_r * lda;
double* pak_a_row_m_r_i = pak_a + m_r_i * m_r * k_b;
UNROLL_LOOP( 4 )
for ( int row_i = 0; row_i < m_r; ++row_i )
{
double* src_a_row_i = src_a_row_m_r_i + row_i * lda;
double* pak_a_row_i = pak_a_row_m_r_i + row_i;
UNROLL_LOOP( 8 * 4 )
for ( int col_i = 0; col_i < k_b; ++col_i )
{
*(pak_a_row_i + col_i * m_r) = *(src_a_row_i + col_i);
}
}
}
}
/*
 * Parse a command-line block count.
 *
 * Returns the value of text when it is a complete base-10 integer in
 * [1, max_count], and 0 otherwise (empty, trailing characters, non-positive,
 * or out of range).
 */
static long parse_block_count( const char* text, long max_count )
{
    char* end = NULL;
    errno = 0;
    long value = strtol( text, &end, 10 );
    if ( end == text || *end != '\0' || errno == ERANGE || value < 1 || value > max_count )
    {
        return 0;
    }
    return value;
}

/*
 * Print the m x k row-major matrix A (row stride lda), with a blank line before
 * every m_r-row band and an extra space before every k_b-column block.
 */
static void print_src_a( const double* src_a, int m, int k, int lda )
{
    printf("\nSRC A\n");
    for ( int row_i = 0; row_i < m; ++row_i )
    {
        if ( row_i % m_r == 0 )
        {
            printf("\n");
        }
        for ( int col_i = 0; col_i < k; ++col_i )
        {
            if ( col_i % k_b == 0 )
            {
                printf(" ");
            }
            printf("%4.0f ", src_a[ row_i * lda + col_i ]);
        }
        printf("\n");
    }
}

/*
 * Print one block packed by pack_a: (m_b / m_r) panels of k_b columns, each
 * column being m_r contiguous doubles (column major). A blank line precedes
 * each panel.
 */
static void print_hat_a( const double* pak_a )
{
    printf("\nHAT A\n");
    // Same column count pack_a writes; k_b * m_b / m_r would over-read
    // unpacked data whenever m_b is not a multiple of m_r.
    const int packed_cols = ( m_b / m_r ) * k_b;
    for ( int col_i = 0; col_i < packed_cols; ++col_i )
    {
        if ( col_i % k_b == 0 )
        {
            printf("\n");
        }
        for ( int row_i = 0; row_i < m_r; ++row_i )
        {
            printf("%4.0f ", pak_a[ col_i * m_r + row_i ]);
        }
        printf("\n");
    }
}

/*
 * Packing test driver.
 *
 * Usage: <prog> M_BLOCKS K_BLOCKS N_BLOCKS
 * Builds an (m_b * M_BLOCKS) x (k_b * K_BLOCKS) row-major matrix A whose
 * element at storage index i equals i, and prints it. It then packs and prints
 * every m_b x k_b block, iterating over k blocks in the outer loop.
 * N_BLOCKS is validated but otherwise unused: only A is packed here.
 * Returns 0 on success, -1 on invalid arguments or allocation failure.
 */
int main( int argc, char** argv )
{
    if ( argc != 4 )
    {
        printf("Invalid argv\n");
        return -1;
    }
    // Reject garbage, non-positive, and oversized counts instead of letting
    // atoi silently produce 0 or an overflowing dimension.
    const long m_blocks = parse_block_count( argv[1], INT_MAX / m_b );
    const long k_blocks = parse_block_count( argv[2], INT_MAX / k_b );
    const long n_blocks = parse_block_count( argv[3], INT_MAX / n_r );
    if ( m_blocks == 0 || k_blocks == 0 || n_blocks == 0 )
    {
        printf("Invalid argv\n");
        return -1;
    }
    const int m = m_b * (int)m_blocks;
    const int k = k_b * (int)k_blocks;
    const int lda = k;
    // pack_a and the block offsets below use int index arithmetic, so every
    // element of A must be addressable with an int; the byte count must also
    // fit in size_t.
    const long long a_elems = (long long)m * k;
    if ( a_elems > INT_MAX || (size_t)a_elems > (size_t)-1 / sizeof( double ) )
    {
        printf("Invalid argv\n");
        return -1;
    }
    double* src_a = (double*)_mm_malloc( (size_t)a_elems * sizeof( double ), 64 );
    if ( src_a == NULL )
    {
        fprintf( stderr, "Allocation of A failed\n" );
        return -1;
    }
    double* pak_a = (double*)_mm_malloc( (size_t)m_b * k_b * sizeof( double ), 64 );
    if ( pak_a == NULL )
    {
        fprintf( stderr, "Allocation of packed A failed\n" );
        _mm_free( src_a );
        return -1;
    }
    for ( int i = 0; i < (int)a_elems; ++i )
    {
        src_a[ i ] = i;
    }
    print_src_a( src_a, m, k, lda );
    for ( int k_b_i = 0; k_b_i < k / k_b; ++k_b_i )
    {
        for ( int m_b_i = 0; m_b_i < m / m_b; ++m_b_i )
        {
            // Pack \tilde a: the (m_b_i, k_b_i) block of A.
            pack_a( src_a + m_b_i * m_b * lda + k_b_i * k_b, pak_a, lda );
            print_hat_a( pak_a );
        }
    }
    _mm_free( src_a );
    _mm_free( pak_a );
    return 0;
}