Skip to content

Commit

Permalink
benchdnn: graph: use plain copy and paste if reorder fails
Browse files Browse the repository at this point in the history
  • Loading branch information
wzt1997 committed Dec 16, 2024
1 parent c852fdc commit 807e034
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 19 deletions.
29 changes: 26 additions & 3 deletions tests/benchdnn/dnnl_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,34 @@ int execute_reorder(const dnn_mem_t &src, dnn_mem_t &dst,
}
#endif

dnnl_status_t status;
if (!r_pd_) {
DNN_SAFE(dnnl_reorder_primitive_desc_create(&r_pd_, src.md_,
src.engine(), dst.md_, dst.engine(), attr),
WARN);
status = dnnl_reorder_primitive_desc_create(
&r_pd_, src.md_, src.engine(), dst.md_, dst.engine(), attr);
if (status != dnnl_success) {
if (dnnl_memory_desc_equal(src.md_, dst.md_)) {
// If creating the reorder pd fails, fall back to a plain data copy.
const int64_t chunk_size = 64;
const int64_t n_chunks = div_up(src.nelems(), chunk_size);
benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) {
int64_t idx_start = idx_chunk * chunk_size;
int64_t idx_end
= MIN2(idx_start + chunk_size, src.nelems());
for (int64_t idx = idx_start; idx < idx_end; ++idx) {
float e = src.get_elem(idx);
dst.set_elem(idx, e);
}
});
return OK;
} else {
BENCHDNN_PRINT(0,
"Error: Function '%s' at (%s:%d) returned '%s'\n",
__FUNCTION__, __FILE__, __LINE__, status2str(status));
return FAIL;
}
}
}

auto r_pd = make_benchdnn_dnnl_wrapper(r_pd_);
const auto &scratchpad_md = query_md(r_pd, DNNL_ARG_SCRATCHPAD);
const auto &scratchpad_engine
Expand Down
17 changes: 1 addition & 16 deletions tests/benchdnn/graph/input_displacer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,22 +297,7 @@ int partition_data_displacer_t::displace_input_data(
}
dnnl_memory_desc_destroy(mem_replace.md_);
dnnl_memory_desc_clone(&mem_replace.md_, md);

// As int4->int4 reorder is not supported, use naive data copy and paste.
if (md->data_type == dnnl_s4 || md->data_type == dnnl_u4) {
const int64_t chunk_size = 64;
const int64_t n_chunks = div_up(mem.nelems(), chunk_size);
benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) {
int64_t idx_start = idx_chunk * chunk_size;
int64_t idx_end = MIN2(idx_start + chunk_size, mem.nelems());

for (int64_t idx = idx_start; idx < idx_end; ++idx) {
float e = mem_replace.get_elem(idx);
mem.set_elem(idx, e);
}
});
} else
SAFE(mem.reorder(mem_replace), WARN);
SAFE(mem.reorder(mem_replace), WARN);

if (is_reshaped_dims) dnnl_memory_desc_destroy(md);
return OK;
Expand Down

0 comments on commit 807e034

Please sign in to comment.