You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
70 lines
2.0 KiB
70 lines
2.0 KiB
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <kompute/Kompute.hpp>
|
|
|
|
|
|
// GLSL compute shader source computing o[i] = a[i] * b[i], where i is the
// invocation's gl_GlobalInvocationID.x. It is compiled to SPIR-V at runtime by
// compileSource(). The three storage buffers use bindings 0, 1 and 2; presumably
// Kompute binds them in the order of the tensor list passed to mgr.algorithm()
// (inA, inB, out) -- confirm against the Kompute docs.
// NOTE: everything between R"( and )" is passed verbatim to glslangValidator,
// so only valid GLSL may appear inside the literal.
const std::string shader_code(R"(#version 450

// The execution structure
layout (local_size_x = 1) in;

// The buffers are provided via the tensors
layout(binding = 0) buffer bufA { float a[]; };
layout(binding = 1) buffer bufB { float b[]; };
layout(binding = 2) buffer bufOut { float o[]; };

void main() {
    uint index = gl_GlobalInvocationID.x;
    o[index] = a[index] * b[index];
})");
|
|
|
|
|
|
/**
 * Compiles GLSL compute-shader source to SPIR-V by running the external
 * `glslangValidator` tool, and returns the SPIR-V as 32-bit words.
 *
 * The source is written to `tmp_kp_shader.comp` in the current working
 * directory and compiled to `tmp_kp_shader.comp.spv`. Both temporary files
 * are deleted when the function returns or throws.
 *
 * @param source GLSL source text.
 * @return The SPIR-V binary, reinterpreted as host-endian 32-bit words.
 * @throws std::runtime_error if the temporary source file cannot be written,
 *         glslangValidator cannot be run or exits with a non-zero status,
 *         the output file cannot be opened, or the output is empty or not a
 *         whole number of 32-bit words.
 */
static std::vector<uint32_t> compileSource(const std::string& source)
{
    const char* const sourcePath = "tmp_kp_shader.comp";
    const char* const spirvPath = "tmp_kp_shader.comp.spv";

    // Deletes both temporary files on every exit path (return or throw).
    struct TempFileCleanup {
        const char* sourceFile;
        const char* spirvFile;
        ~TempFileCleanup()
        {
            (void)std::remove(sourceFile);
            (void)std::remove(spirvFile);
        }
    } cleanup{ sourcePath, spirvPath };

    std::ofstream fileOut(sourcePath);
    fileOut << source;
    fileOut.close(); // flush before the external tool reads the file
    if (!fileOut)
        throw std::runtime_error("Error writing tmp_kp_shader.comp");

    // Non-zero covers both "tool reported an error" and "shell/tool not found".
    if (std::system("glslangValidator -V tmp_kp_shader.comp -o tmp_kp_shader.comp.spv") != 0)
        throw std::runtime_error("Error running glslangValidator command");

    std::ifstream fileStream(spirvPath, std::ios::binary);
    if (!fileStream)
        throw std::runtime_error("Error opening tmp_kp_shader.comp.spv");

    std::istreambuf_iterator<char> first(fileStream);
    std::istreambuf_iterator<char> last;
    const std::vector<char> bytes(first, last);

    // A SPIR-V module is a non-empty sequence of 32-bit words.
    if (bytes.empty() || bytes.size() % sizeof(uint32_t) != 0)
        throw std::runtime_error("glslangValidator produced an invalid SPIR-V binary");

    // Copy bytes into words with memcpy rather than reading char storage through
    // a uint32_t pointer (which would violate strict aliasing).
    std::vector<uint32_t> words(bytes.size() / sizeof(uint32_t));
    std::memcpy(words.data(), bytes.data(), bytes.size());
    return words;
}
|
|
|
|
|
|
int main()
|
|
{
|
|
kp::Manager mgr;
|
|
|
|
std::shared_ptr<kp::TensorT<float>> tensorInA = mgr.tensor({ 2.0, 4.0, 6.0 });
|
|
std::shared_ptr<kp::TensorT<float>> tensorInB = mgr.tensor({ 0.0, 1.0, 2.0 });
|
|
std::shared_ptr<kp::TensorT<float>> tensorOut = mgr.tensor({ 0.0, 0.0, 0.0 });
|
|
|
|
const std::vector<std::shared_ptr<kp::Tensor>> params = {
|
|
tensorInA, tensorInB, tensorOut
|
|
};
|
|
|
|
const std::vector<uint32_t> shader = compileSource(shader_code);
|
|
std::shared_ptr<kp::Algorithm> algo = mgr.algorithm(params, shader);
|
|
|
|
mgr.sequence()
|
|
->record<kp::OpTensorSyncDevice>(params)
|
|
->record<kp::OpAlgoDispatch>(algo)
|
|
->record<kp::OpTensorSyncLocal>(params)
|
|
->eval();
|
|
|
|
// prints "Output { 0 4 12 }"
|
|
std::cout << "Output: { ";
|
|
for (const float& elem : tensorOut->vector()) {
|
|
std::cout << elem << " ";
|
|
}
|
|
std::cout << "}" << std::endl;
|
|
|
|
if (tensorOut->vector() != std::vector<float>{ 0, 4, 12 }) {
|
|
throw std::runtime_error("Result does not match");
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|