diff --git a/src/nnet.c b/src/nnet.c
index 9fcb808..e2735ae 100644
--- a/src/nnet.c
+++ b/src/nnet.c
@@ -41,8 +41,11 @@
 
 #define SOFTMAX_HACK
 
-#ifdef __AVX2__
+#ifdef __AVX__
 #include <immintrin.h>
+
+
+#ifdef __AVX2__
 static __m256 exp8_approx(__m256 X)
 {
    const __m256 K0 = _mm256_set1_ps(0.99992522f);
@@ -65,7 +68,44 @@ static __m256 exp8_approx(__m256 X)
    Y = _mm256_castsi256_ps(_mm256_and_si256(mask, _mm256_add_epi32(I, _mm256_castps_si256(Y))));
    return Y;
 }
-
+#else
+#define _mm256_fmadd_ps(a,b,c) _mm256_add_ps(_mm256_mul_ps(a, b), c)
+#define _mm_fmadd_ps(a,b,c) _mm_add_ps(_mm_mul_ps(a, b), c)
+static __m128 exp4_approx(__m128 X)
+{
+   const __m128 K0 = _mm_set1_ps(0.99992522f);
+   const __m128 K1 = _mm_set1_ps(0.69583354f);
+   const __m128 K2 = _mm_set1_ps(0.22606716f);
+   const __m128 K3 = _mm_set1_ps(0.078024523f);
+   const __m128 log2_E = _mm_set1_ps(1.44269504);
+   const __m128 max_in = _mm_set1_ps(50.f);
+   const __m128 min_in = _mm_set1_ps(-50.f);
+   const __m128i mask = _mm_set1_epi32(0x7fffffff);
+   __m128 XF, Y;
+   __m128i I;
+   X = _mm_mul_ps(X, log2_E);
+   X = _mm_max_ps(min_in, _mm_min_ps(max_in, X));
+   XF = _mm_floor_ps(X);
+   I = _mm_cvtps_epi32(XF);
+   X = _mm_sub_ps(X, XF);
+   Y = _mm_fmadd_ps(_mm_fmadd_ps(_mm_fmadd_ps(K3, X, K2), X, K1), X, K0);
+   I = _mm_slli_epi32(I, 23);
+   Y = _mm_castsi128_ps(_mm_and_si128(mask, _mm_add_epi32(I, _mm_castps_si128(Y))));
+   return Y;
+}
+static __m256 exp8_approx(__m256 X)
+{
+   __m256 Y;
+   __m128 Xhi, Xlo, Yhi, Ylo;
+   Xhi = _mm256_extractf128_ps(X, 1);
+   Xlo = _mm256_extractf128_ps(X, 0);
+   Yhi = exp4_approx(Xhi);
+   Ylo = exp4_approx(Xlo);
+   Y = _mm256_insertf128_ps(_mm256_setzero_ps(), Yhi, 1);
+   Y = _mm256_insertf128_ps(Y, Ylo, 0);
+   return Y;
+}
+#endif
 
 static float celt_exp(float x)
 {
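
Below is a small standalone sanity check, not part of the patch: it rebuilds the same AVX-without-AVX2 path added above (exp8_approx() composed of two SSE exp4_approx() halves, with FMA emulated by a multiply plus an add) and compares the result against libm expf(). The file name avx_exp_check.c, the fmadd4_emu macro and the test inputs are illustrative choices, not from the repository.

/* avx_exp_check.c -- illustrative sanity check for the AVX-only path above.
 * Build with e.g.:  gcc -O2 -mavx avx_exp_check.c -o avx_exp_check -lm
 */
#include <immintrin.h>
#include <math.h>
#include <stdio.h>

/* The patch #defines _mm_fmadd_ps as mul+add when FMA is unavailable; a
 * differently named macro is used here to avoid clashing with the real
 * intrinsic when compiling with -mfma. */
#define fmadd4_emu(a, b, c) _mm_add_ps(_mm_mul_ps(a, b), c)

static __m128 exp4_approx(__m128 X)
{
   /* exp(x) = 2^(x*log2(e)); the integer part of the exponent is pushed into
    * the float exponent bits, the fractional part is approximated by a
    * degree-3 polynomial (same coefficients as the patch). */
   const __m128 K0 = _mm_set1_ps(0.99992522f);
   const __m128 K1 = _mm_set1_ps(0.69583354f);
   const __m128 K2 = _mm_set1_ps(0.22606716f);
   const __m128 K3 = _mm_set1_ps(0.078024523f);
   const __m128 log2_E = _mm_set1_ps(1.44269504f);
   const __m128 max_in = _mm_set1_ps(50.f);
   const __m128 min_in = _mm_set1_ps(-50.f);
   const __m128i mask = _mm_set1_epi32(0x7fffffff);
   __m128 XF, Y;
   __m128i I;
   X = _mm_mul_ps(X, log2_E);
   X = _mm_max_ps(min_in, _mm_min_ps(max_in, X));
   XF = _mm_floor_ps(X);
   I = _mm_cvtps_epi32(XF);
   X = _mm_sub_ps(X, XF);
   Y = fmadd4_emu(fmadd4_emu(fmadd4_emu(K3, X, K2), X, K1), X, K0);
   I = _mm_slli_epi32(I, 23);
   Y = _mm_castsi128_ps(_mm_and_si128(mask, _mm_add_epi32(I, _mm_castps_si128(Y))));
   return Y;
}

static __m256 exp8_approx(__m256 X)
{
   /* AVX-only: split the 256-bit vector into its two 128-bit lanes,
    * approximate each half with SSE, then reassemble. */
   __m128 Yhi = exp4_approx(_mm256_extractf128_ps(X, 1));
   __m128 Ylo = exp4_approx(_mm256_extractf128_ps(X, 0));
   __m256 Y = _mm256_insertf128_ps(_mm256_setzero_ps(), Yhi, 1);
   return _mm256_insertf128_ps(Y, Ylo, 0);
}

int main(void)
{
   float in[8], out[8], max_rel = 0.f;
   int i;
   for (i = 0; i < 8; i++) in[i] = -3.5f + i;   /* sample inputs in [-3.5, 3.5] */
   _mm256_storeu_ps(out, exp8_approx(_mm256_loadu_ps(in)));
   for (i = 0; i < 8; i++) {
      float rel = fabsf(out[i] - expf(in[i])) / expf(in[i]);
      if (rel > max_rel) max_rel = rel;
   }
   printf("max relative error vs expf(): %g\n", max_rel);
   return 0;
}

On the design: the approximation itself is unchanged from the AVX2 version; only the FMA intrinsics are replaced by mul+add macros and the 256-bit work is done as two 128-bit halves, so the fallback only needs AVX plus SSE4.1 (for _mm_floor_ps).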