@@ -65,13 +65,13 @@ unsafe fn avx_interleave(a: __m256, b: __m256) -> (__m256, __m256) {
6565}
6666
6767#[ inline]
68- #[ target_feature( enable = "avx2" ) ]
69- unsafe fn sse_unpacklo_ps ( a : __m128i ) -> ( __m128 , __m128 ) {
70- let v2 = _mm_unpacklo_epi32 ( a , _mm_setzero_si128 ( ) ) ; // a0 a2 b0 b2
71-
72- let a = _mm_unpacklo_epi32 ( v2 , _mm_setzero_si128 ( ) ) ; // a0 a1 a2 a3
73- let b = _mm_unpackhi_epi32 ( v2 , _mm_setzero_si128 ( ) ) ; // b0 b1 ab b3
74- ( _mm_castsi128_ps ( a ) , _mm_castsi128_ps ( b ) )
68+ #[ target_feature( enable = "avx2" , enable = "fma" ) ]
69+ unsafe fn complex_mul_fma ( a : __m128 , b : __m128 ) -> __m128 {
70+ let temp1 = _mm_shuffle_ps :: < 0xA0 > ( b , b ) ;
71+ let temp2 = _mm_shuffle_ps :: < 0xF5 > ( b , b ) ;
72+ let mul2 = _mm_mul_ps ( a , temp2 ) ;
73+ let mul2 = _mm_shuffle_ps :: < 0xB1 > ( mul2 , mul2 ) ;
74+ _mm_fmaddsub_ps ( a , temp1 , mul2 )
7575}
7676
7777#[ target_feature( enable = "avx2" , enable = "fma" ) ]
@@ -89,15 +89,15 @@ unsafe fn mul_spectrum_in_place_f32_impl(
8989 let other = & other[ ..complex_size] ;
9090
9191 for ( dst, kernel) in value1. chunks_exact_mut ( 16 ) . zip ( other. chunks_exact ( 16 ) ) {
92- let vd0 = _mm256_loadu_ps ( dst. as_ptr ( ) as * const f32 ) ;
93- let vd1 = _mm256_loadu_ps ( dst. get_unchecked ( 4 ..) . as_ptr ( ) as * const f32 ) ;
94- let vd2 = _mm256_loadu_ps ( dst. get_unchecked ( 8 ..) . as_ptr ( ) as * const f32 ) ;
95- let vd3 = _mm256_loadu_ps ( dst. get_unchecked ( 12 ..) . as_ptr ( ) as * const f32 ) ;
92+ let vd0 = _mm256_loadu_ps ( dst. as_ptr ( ) . cast ( ) ) ;
93+ let vd1 = _mm256_loadu_ps ( dst. get_unchecked ( 4 ..) . as_ptr ( ) . cast ( ) ) ;
94+ let vd2 = _mm256_loadu_ps ( dst. get_unchecked ( 8 ..) . as_ptr ( ) . cast ( ) ) ;
95+ let vd3 = _mm256_loadu_ps ( dst. get_unchecked ( 12 ..) . as_ptr ( ) . cast ( ) ) ;
9696
97- let vk0 = _mm256_loadu_ps ( kernel. as_ptr ( ) as * const f32 ) ;
98- let vk1 = _mm256_loadu_ps ( kernel. get_unchecked ( 4 ..) . as_ptr ( ) as * const f32 ) ;
99- let vk2 = _mm256_loadu_ps ( kernel. get_unchecked ( 8 ..) . as_ptr ( ) as * const f32 ) ;
100- let vk3 = _mm256_loadu_ps ( kernel. get_unchecked ( 12 ..) . as_ptr ( ) as * const f32 ) ;
97+ let vk0 = _mm256_loadu_ps ( kernel. as_ptr ( ) . cast ( ) ) ;
98+ let vk1 = _mm256_loadu_ps ( kernel. get_unchecked ( 4 ..) . as_ptr ( ) . cast ( ) ) ;
99+ let vk2 = _mm256_loadu_ps ( kernel. get_unchecked ( 8 ..) . as_ptr ( ) . cast ( ) ) ;
100+ let vk3 = _mm256_loadu_ps ( kernel. get_unchecked ( 12 ..) . as_ptr ( ) . cast ( ) ) ;
101101
102102 let ( ar0, ai0) = avx_deinterleave ( vd0, vd1) ;
103103 let ( ar1, ai1) = avx_deinterleave ( vd2, vd3) ;
@@ -123,18 +123,18 @@ unsafe fn mul_spectrum_in_place_f32_impl(
123123 let ( d0, d1) = avx_interleave ( prod_r0, prod_i0) ;
124124 let ( d2, d3) = avx_interleave ( prod_r1, prod_i1) ;
125125
126- _mm256_storeu_ps ( dst. as_mut_ptr ( ) as * mut f32 , d0) ;
127- _mm256_storeu_ps ( dst. get_unchecked_mut ( 4 ..) . as_mut_ptr ( ) as * mut f32 , d1) ;
128- _mm256_storeu_ps ( dst. get_unchecked_mut ( 8 ..) . as_mut_ptr ( ) as * mut f32 , d2) ;
129- _mm256_storeu_ps ( dst. get_unchecked_mut ( 12 ..) . as_mut_ptr ( ) as * mut f32 , d3) ;
126+ _mm256_storeu_ps ( dst. as_mut_ptr ( ) . cast ( ) , d0) ;
127+ _mm256_storeu_ps ( dst. get_unchecked_mut ( 4 ..) . as_mut_ptr ( ) . cast ( ) , d1) ;
128+ _mm256_storeu_ps ( dst. get_unchecked_mut ( 8 ..) . as_mut_ptr ( ) . cast ( ) , d2) ;
129+ _mm256_storeu_ps ( dst. get_unchecked_mut ( 12 ..) . as_mut_ptr ( ) . cast ( ) , d3) ;
130130 }
131131
132132 let dst_rem = value1. chunks_exact_mut ( 16 ) . into_remainder ( ) ;
133133 let src_rem = other. chunks_exact ( 16 ) . remainder ( ) ;
134134
135135 for ( dst, kernel) in dst_rem. chunks_exact_mut ( 4 ) . zip ( src_rem. chunks_exact ( 4 ) ) {
136- let a0 = _mm256_loadu_ps ( dst. as_ptr ( ) as * const f32 ) ;
137- let b0 = _mm256_loadu_ps ( kernel. as_ptr ( ) as * const f32 ) ;
136+ let a0 = _mm256_loadu_ps ( dst. as_ptr ( ) . cast ( ) ) ;
137+ let b0 = _mm256_loadu_ps ( kernel. as_ptr ( ) . cast ( ) ) ;
138138
139139 let ( ar0, ai0) = avx_deinterleave ( a0, _mm256_setzero_ps ( ) ) ;
140140 let ( br0, bi0) = avx_deinterleave ( b0, _mm256_setzero_ps ( ) ) ;
@@ -149,7 +149,7 @@ unsafe fn mul_spectrum_in_place_f32_impl(
149149
150150 let ( d0, _) = avx_interleave ( prod_r0, prod_i0) ;
151151
152- _mm256_storeu_ps ( dst. as_mut_ptr ( ) as * mut f32 , d0) ;
152+ _mm256_storeu_ps ( dst. as_mut_ptr ( ) . cast ( ) , d0) ;
153153 }
154154
155155 let dst_rem = dst_rem. chunks_exact_mut ( 4 ) . into_remainder ( ) ;
@@ -159,18 +159,7 @@ unsafe fn mul_spectrum_in_place_f32_impl(
159159 let v0 = _mm_loadu_si64 ( dst as * const Complex < f32 > as * const _ ) ;
160160 let v1 = _mm_loadu_si64 ( kernel as * const Complex < f32 > as * const _ ) ;
161161
162- let ( ar0, ai0) = sse_unpacklo_ps ( v0) ;
163- let ( br0, bi0) = sse_unpacklo_ps ( v1) ;
164-
165- let mut prod_r0 = _mm_mul_ps ( ar0, br0) ;
166- let mut prod_i0 = _mm_mul_ps ( ar0, bi0) ;
167- prod_r0 = _mm_fnmadd_ps ( ai0, bi0, prod_r0) ;
168- prod_i0 = _mm_fmadd_ps ( ai0, br0, prod_i0) ;
169-
170- prod_r0 = _mm_mul_ps ( prod_r0, _mm256_castps256_ps128 ( v_norm_factor) ) ;
171- prod_i0 = _mm_mul_ps ( prod_i0, _mm256_castps256_ps128 ( v_norm_factor) ) ;
172-
173- let lo = _mm_unpacklo_ps ( prod_r0, prod_i0) ;
162+ let lo = complex_mul_fma ( _mm_castsi128_ps ( v0) , _mm_castsi128_ps ( v1) ) ;
174163
175164 _mm_storeu_si64 ( dst as * mut Complex < f32 > as * mut _ , _mm_castps_si128 ( lo) ) ;
176165 }
0 commit comments