#include "example_utils.hpp"
#include <CL/sycl.hpp>
#include <cassert>
#include <iostream>
#include <numeric>
// Incomplete type that is never defined in the visible source; it is used
// purely as the kernel-name template argument of
// cgh.parallel_for<kernel_tag>(...).
class kernel_tag;
// NOTE(review): this is a fragment. The enclosing function header and the
// declarations of `mem`, `strm` and `relu` are not visible here; the brace at
// the end of this fragment closes a scope that is opened outside this view.
// Judging by the call in main(), this is presumably the body of
// sycl_interop_tutorial() -- confirm against the full file.

// Dimensions whose product determines the number of elements processed.
memory::dims tz_dims = {2, 3, 4, 5};
// Total element count: the product of all entries in tz_dims (2*3*4*5 = 120).
const size_t N = std::accumulate(tz_dims.begin(), tz_dims.end(), (size_t)1,
std::multiplies<size_t>());
// Obtain a float-typed SYCL buffer from `mem` and a SYCL queue from `strm`
// (both objects are declared outside this fragment).
auto sycl_buf = mem.get_sycl_buffer<float>();
queue q = strm.get_sycl_queue();
// Submit a command group that fills the buffer: element idx receives idx when
// idx is even and -idx when idx is odd, so every odd position holds a
// non-positive value before the ReLU runs.
q.submit([&](handler &cgh) {
// Write accessor bound to this command group.
auto a = sycl_buf.get_access<access::mode::write>(cgh);
// One work-item per element, N items in a 1-D range.
cgh.parallel_for<kernel_tag>(range<1>(N), [=](id<1> i) {
int idx = (int)i[0];
a[idx] = (idx % 2) ? -idx : idx;
});
});
// Run `relu` on `strm`; the same `mem` object is passed as both the source
// and the destination argument.
relu.execute(strm, {{DNNL_ARG_SRC, mem}, {DNNL_ARG_DST, mem}});
// Wait on the stream before inspecting the results.
strm.wait();
// Read accessor requested without a command-group handler, used below to
// inspect the buffer contents element by element.
auto host_acc = sycl_buf.get_access<access::mode::read>();
// Validate the output: odd indices (initialised to -i) must now be 0, even
// indices (initialised to i) must be unchanged.
for (size_t i = 0; i < N; i++) {
float exp_value = (i % 2) ? 0.0f : i;
// Any mismatch is reported by throwing a std::string, which main() catches.
// NOTE(review): the check is an exact equality against the expected value,
// so it also fires for wrong non-negative results, although the message
// only mentions negative values.
if (host_acc[i] != (float)exp_value)
throw std::string(
"Unexpected output, find a negative value after the ReLU "
"execution");
}
}
int main(int argc, char **argv) {
try {
sycl_interop_tutorial(engine_kind);
std::cerr <<
"DNNL error: " << e.
what() << std::endl
<<
"Error status: " << dnnl_status2str(e.
status) << std::endl;
return 1;
} catch (std::string &e) {
std::cerr << "Error in the example: " << e << std::endl;
return 2;
}
std::cout << "Example passes" << std::endl;
return 0;
}