|
|
|
#version 450
|
|
|
|
// clang-format off
|
|
|
|
|
|
|
|
// The number of threads spawned per workgroup; these placeholders are
// substituted by the program before the shader is compiled
|
|
|
|
// Workgroup dimensions; each __lcsize_*__ token is a placeholder that is
// textually replaced before this shader is compiled.
layout(local_size_x = __lcsize_x__, local_size_y = __lcsize_y__, local_size_z = __lcsize_z__) in;
|
|
|
|
|
|
|
|
// The buffers are provided via the tensors
|
|
|
|
// Left-hand operand, read by main() as matA[row * size + k].
layout(binding = 0) buffer tensorA
{
    float matA[];
};

// Right-hand operand, read by main() as matB[k * size + col].
layout(binding = 1) buffer tensorB
{
    float matB[];
};

// Result, written by main() as matC[row * size + col].
layout(binding = 2) buffer tensorC
{
    float matC[];
};
|
|
|
|
|
|
|
|
// specialization constant
|
|
|
|
// Matrix dimension, supplied as a float specialization constant and converted
// to uint in main(), where it serves as both the row stride and the length of
// the accumulation loop. With the default of 0 that loop does not run.
layout(constant_id = 0) const float tensor_size_f = 0;
|
|
|
|
|
|
|
|
// Each invocation computes a single output element, matC[id.y][id.x], of the
// product C = A * B. All three buffers are indexed as flat, row-major,
// tensor_size x tensor_size matrices.
void main()
{
    // The size arrives as a float specialization constant; convert it once.
    uint tensor_size_u = uint(tensor_size_f);

    uvec3 id = gl_GlobalInvocationID;

    // Invocations that fall outside the matrix (for example when the dispatch
    // is rounded up to a whole number of workgroups) must not touch the
    // buffers, otherwise they would read and write out of range.
    if (id.x >= tensor_size_u || id.y >= tensor_size_u) {
        return;
    }

    // Cyx = sum(k, Ayk * Bkx)
    float acc = 0.0;
    uint row = id.y * tensor_size_u; // flat offset of row id.y in A and C
    uint col = id.x;
    for (uint k = 0; k < tensor_size_u; k++) {
        acc += matA[row + k] * matB[col + k * tensor_size_u];
    }
    matC[row + col] = acc;
}
|