Java中32位值的8x8矩阵的SIMD转置



我在C++中找到了以下代码,用于快速转置由32位值组成的8x8矩阵:https://stackoverflow.com/a/51887176/1915854

inline void Transpose8x8Shuff(unsigned long *in)
{
__m256 *inI = reinterpret_cast<__m256 *>(in);
__m256 rI[8];
rI[0] = _mm256_unpacklo_ps(inI[0], inI[1]); 
rI[1] = _mm256_unpackhi_ps(inI[0], inI[1]); 
rI[2] = _mm256_unpacklo_ps(inI[2], inI[3]); 
rI[3] = _mm256_unpackhi_ps(inI[2], inI[3]); 
rI[4] = _mm256_unpacklo_ps(inI[4], inI[5]); 
rI[5] = _mm256_unpackhi_ps(inI[4], inI[5]); 
rI[6] = _mm256_unpacklo_ps(inI[6], inI[7]); 
rI[7] = _mm256_unpackhi_ps(inI[6], inI[7]); 
__m256 rrF[8];
__m256 *rF = reinterpret_cast<__m256 *>(rI);
rrF[0] = _mm256_shuffle_ps(rF[0], rF[2], _MM_SHUFFLE(1,0,1,0));
rrF[1] = _mm256_shuffle_ps(rF[0], rF[2], _MM_SHUFFLE(3,2,3,2));
rrF[2] = _mm256_shuffle_ps(rF[1], rF[3], _MM_SHUFFLE(1,0,1,0)); 
rrF[3] = _mm256_shuffle_ps(rF[1], rF[3], _MM_SHUFFLE(3,2,3,2));
rrF[4] = _mm256_shuffle_ps(rF[4], rF[6], _MM_SHUFFLE(1,0,1,0));
rrF[5] = _mm256_shuffle_ps(rF[4], rF[6], _MM_SHUFFLE(3,2,3,2));
rrF[6] = _mm256_shuffle_ps(rF[5], rF[7], _MM_SHUFFLE(1,0,1,0));
rrF[7] = _mm256_shuffle_ps(rF[5], rF[7], _MM_SHUFFLE(3,2,3,2));
rF = reinterpret_cast<__m256 *>(in);
rF[0] = _mm256_permute2f128_ps(rrF[0], rrF[4], 0x20);
rF[1] = _mm256_permute2f128_ps(rrF[1], rrF[5], 0x20);
rF[2] = _mm256_permute2f128_ps(rrF[2], rrF[6], 0x20);
rF[3] = _mm256_permute2f128_ps(rrF[3], rrF[7], 0x20);
rF[4] = _mm256_permute2f128_ps(rrF[0], rrF[4], 0x31);
rF[5] = _mm256_permute2f128_ps(rrF[1], rrF[5], 0x31);
rF[6] = _mm256_permute2f128_ps(rrF[2], rrF[6], 0x31);
rF[7] = _mm256_permute2f128_ps(rrF[3], rrF[7], 0x31);
}

然而,将其转换为Java向量API(https://download.java.net/java/early_access/panama/docs/api/jdk.incubator.vector/jdk/incubator/vector/IntVector.html)并不简单,因为Java向量API并不直接映射到CPU指令/C++内部函数。

你能分享一下Java中以下内部函数/宏的等价物是什么吗?

  1. _mm256_unpacklo_ps()
  2. _mm256_unpackhi_ps()
  3. _mm256_shuffle_ps()
  4. _MM_SHUFFLE()
  5. _mm256_permute2f128_ps()

我可以使用最新的JDK 19。

更新:根据@Soots的建议,我已经实现了以下内容,它通过了测试,但速度非常慢:

/**
 * SIMD helpers built on the incubating Java Vector API, operating on 256-bit
 * vectors.
 */
public class SimdOps {
    public static final VectorSpecies<Integer> SPECIES_INT = IntVector.SPECIES_256;
    public static final VectorSpecies<Long> SPECIES_LONG = LongVector.SPECIES_256;

    // Lane tables for the two-vector form rearrange(shuffle, other).
    // Values 0..7 pick a lane of the receiver, values -8..-1 pick a lane of
    // the other vector. Each table mirrors one AVX operation named below.

    /** Counterpart of _mm256_unpacklo_ps. */
    public static final VectorShuffle<Integer> vsUnpackLo =
            VectorShuffle.fromValues(SPECIES_INT, 0, -8, 1, -7, 4, -4, 5, -3);
    /** Counterpart of _mm256_unpackhi_ps. */
    public static final VectorShuffle<Integer> vsUnpackHi =
            VectorShuffle.fromValues(SPECIES_INT, 2, -6, 3, -5, 6, -2, 7, -1);
    /** Counterpart of _mm256_shuffle_ps with _MM_SHUFFLE(1,0,1,0). */
    public static final VectorShuffle<Integer> vsShuffle1010 =
            VectorShuffle.fromValues(SPECIES_INT, 0, 1, -8, -7, 4, 5, -4, -3);
    /** Counterpart of _mm256_shuffle_ps with _MM_SHUFFLE(3,2,3,2). */
    public static final VectorShuffle<Integer> vsShuffle3232 =
            VectorShuffle.fromValues(SPECIES_INT, 2, 3, -6, -5, 6, 7, -2, -1);
    /** Counterpart of _mm256_permute2f128_ps with immediate 0x20. */
    public static final VectorShuffle<Integer> vsPermute0x20 =
            VectorShuffle.fromValues(SPECIES_INT, 0, 1, 2, 3, -8, -7, -6, -5);
    /** Counterpart of _mm256_permute2f128_ps with immediate 0x31. */
    public static final VectorShuffle<Integer> vsPermute0x31 =
            VectorShuffle.fromValues(SPECIES_INT, 4, 5, 6, 7, -4, -3, -2, -1);

    /**
     * Transposes, in place, an 8x8 matrix of 32-bit integers whose rows are
     * held in eight 256-bit vectors.
     *
     * References:
     * https://stackoverflow.com/questions/25622745/transpose-an-8x8-float-using-avx-avx2
     * https://stackoverflow.com/questions/73977998/simd-transposition-of-8x8-matrix-of-32-bit-values-in-java
     *
     * @param inpM the eight rows of the matrix; each element is replaced by
     *             the corresponding row of the transposed matrix
     */
    public static final void transpose8x8(IntVector[] inpM) {
        assert inpM.length == Constants.INTS_PER_SIMD;

        // Take all rows up front; nothing is written back until the end.
        final IntVector row0 = inpM[0];
        final IntVector row1 = inpM[1];
        final IntVector row2 = inpM[2];
        final IntVector row3 = inpM[3];
        final IntVector row4 = inpM[4];
        final IntVector row5 = inpM[5];
        final IntVector row6 = inpM[6];
        final IntVector row7 = inpM[7];

        // Stage 1: interleave the lanes of adjacent row pairs.
        final IntVector lo01 = row0.rearrange(vsUnpackLo, row1);
        final IntVector hi01 = row0.rearrange(vsUnpackHi, row1);
        final IntVector lo23 = row2.rearrange(vsUnpackLo, row3);
        final IntVector hi23 = row2.rearrange(vsUnpackHi, row3);
        final IntVector lo45 = row4.rearrange(vsUnpackLo, row5);
        final IntVector hi45 = row4.rearrange(vsUnpackHi, row5);
        final IntVector lo67 = row6.rearrange(vsUnpackLo, row7);
        final IntVector hi67 = row6.rearrange(vsUnpackHi, row7);

        // Stage 2: merge lane pairs from two interleaved vectors.
        final IntVector quadA0 = lo01.rearrange(vsShuffle1010, lo23);
        final IntVector quadA1 = lo01.rearrange(vsShuffle3232, lo23);
        final IntVector quadA2 = hi01.rearrange(vsShuffle1010, hi23);
        final IntVector quadA3 = hi01.rearrange(vsShuffle3232, hi23);
        final IntVector quadB0 = lo45.rearrange(vsShuffle1010, lo67);
        final IntVector quadB1 = lo45.rearrange(vsShuffle3232, lo67);
        final IntVector quadB2 = hi45.rearrange(vsShuffle1010, hi67);
        final IntVector quadB3 = hi45.rearrange(vsShuffle3232, hi67);

        // Stage 3: swap 128-bit halves; lower halves form rows 0..3,
        // upper halves form rows 4..7.
        inpM[0] = quadA0.rearrange(vsPermute0x20, quadB0);
        inpM[1] = quadA1.rearrange(vsPermute0x20, quadB1);
        inpM[2] = quadA2.rearrange(vsPermute0x20, quadB2);
        inpM[3] = quadA3.rearrange(vsPermute0x20, quadB3);
        inpM[4] = quadA0.rearrange(vsPermute0x31, quadB0);
        inpM[5] = quadA1.rearrange(vsPermute0x31, quadB1);
        inpM[6] = quadA2.rearrange(vsPermute0x31, quadB2);
        inpM[7] = quadA3.rearrange(vsPermute0x31, quadB3);
    }
};

而瓶颈是jdk.incubator.vector.Int256Vector.rearrange(VectorShuffle, Vector)。它至少比标量代码慢10倍。有什么想法吗?

免责声明:我从未用Java写过类似的东西。

根据文档,rearrange似乎是唯一的办法。
唯一的问题是如何把C内部函数转换成VectorShuffle<Float>所需的整数索引。

以下C++代码可以用来找出这些整数:

void printShuffle( __m256 v, const char* name )
{
__m256i iv = _mm256_cvtps_epi32( v );
std::array<int, 8> a;
_mm256_storeu_si256( ( __m256i* )a.data(), iv );
printf( "%s: %i, %i, %i, %i, %i, %i, %i, %in", name,
a[ 0 ], a[ 1 ], a[ 2 ], a[ 3 ], a[ 4 ], a[ 5 ], a[ 6 ], a[ 7 ] );
}
// Evaluates `expr` and prints its lanes via printShuffle, labelled with the
// expression's own source text: `#expr` stringifies the argument as written,
// so the spelling of each TEST(...) argument below is part of the output.
#define TEST( expr ) printShuffle( expr, #expr )
// Prints, for each intrinsic/immediate combination listed below, which input
// lane lands in every output lane. The printed numbers have the same form as
// the arguments passed to VectorShuffle.fromValues in the Java snippet.
void printJavaRearranges()
{
// First operand: every lane holds its own index, 0..7.
const __m256 a = _mm256_setr_ps( 0, 1, 2, 3, 4, 5, 6, 7 );
// Second operand: every lane holds index - 8, i.e. -8..-1, so lanes taken
// from `b` show up as negative numbers in the output.
const __m256 b = _mm256_sub_ps( a, _mm256_set1_ps( 8 ) );
TEST( _mm256_unpacklo_ps( a, b ) );
TEST( _mm256_unpackhi_ps( a, b ) );
TEST( _mm256_shuffle_ps( a, b, _MM_SHUFFLE(1,0,1,0) ) );
TEST( _mm256_shuffle_ps( a, b, _MM_SHUFFLE(3,2,3,2) ) );
TEST( _mm256_permute2f128_ps( a, b, 0x20 ) );
TEST( _mm256_permute2f128_ps( a, b, 0x31 ) );
}

输出:

_mm256_unpacklo_ps( a, b ): 0, -8, 1, -7, 4, -4, 5, -3
_mm256_unpackhi_ps( a, b ): 2, -6, 3, -5, 6, -2, 7, -1
_mm256_shuffle_ps( a, b, _MM_SHUFFLE(1,0,1,0) ): 0, 1, -8, -7, 4, 5, -4, -3
_mm256_shuffle_ps( a, b, _MM_SHUFFLE(3,2,3,2) ): 2, 3, -6, -5, 6, 7, -2, -1
_mm256_permute2f128_ps( a, b, 0x20 ): 0, 1, 2, 3, -8, -7, -6, -5
_mm256_permute2f128_ps( a, b, 0x31 ): 4, 5, 6, 7, -4, -3, -2, -1

_mm256_permute2f128_ps指令可以选择性地将通道清零,Java的向量API可能无法做到这一点。幸运的是,源代码中使用的立即数(0x20和0x31)不会将任何部分清零。

如果幸运的话,运行时可能会将这些值映射到相应的AVX指令(前提是JIT事先知道这些值,并且它们从未更改)。

最新更新