#version 450

// clang-format off

// The number of threads spawned per-workgroup; these placeholders are
// substituted by the program pre-compilation step.
layout(local_size_x = __lcsize_x__,
       local_size_y = __lcsize_y__,
       local_size_z = __lcsize_z__) in;

// The buffers are provided via the tensors. All three are indexed below as
// square N x N matrices stored row-major (element [row][col] at row * N + col).
// A and B are only read and C is only written by this shader.
layout(binding = 0) readonly buffer tensorA { float matA[]; };
layout(binding = 1) readonly buffer tensorB { float matB[]; };
layout(binding = 2) writeonly buffer tensorC { float matC[]; };

// Specialization constant: the matrix dimension N, converted to uint in main().
// NOTE(review): supplied as a float, presumably because the host side passes
// float specialization constants — confirm before changing its type.
layout(constant_id = 0) const float tensor_size_f = 0;

// Each invocation computes exactly one output element, matC[id.y][id.x]:
//   C[y][x] = sum over k of A[y][k] * B[k][x]
void main()
{
    uint tensor_size_u = uint(tensor_size_f);
    uvec3 id = gl_GlobalInvocationID;

    // If the dispatch covers more invocations than the matrix has elements
    // (e.g. N is not a multiple of the workgroup size), the surplus
    // invocations must not read or write past the ends of the buffers.
    if (id.x >= tensor_size_u || id.y >= tensor_size_u) {
        return;
    }

    uint y = id.y * tensor_size_u;  // offset of row id.y in A and C
    uint x = id.x;                  // column index in B and C

    float acc = 0.0;
    for (uint k = 0u; k < tensor_size_u; k++) {
        // A walks along row id.y; B walks down column id.x (stride N).
        acc += matA[y + k] * matB[x + k * tensor_size_u];
    }
    matC[y + x] = acc;
}