Skip to content

Commit

Permalink
benchdnn: graph: use plain copy and paste if reorder fails
Browse files Browse the repository at this point in the history
  • Loading branch information
wzt1997 committed Dec 17, 2024
1 parent f4fcbb6 commit 141de68
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 33 deletions.
54 changes: 37 additions & 17 deletions tests/benchdnn/dnnl_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,26 +137,46 @@ int execute_reorder(const dnn_mem_t &src, dnn_mem_t &dst,
}
#endif

if (!r_pd_) {
DNN_SAFE(dnnl_reorder_primitive_desc_create(&r_pd_, src.md_,
src.engine(), dst.md_, dst.engine(), attr),
WARN);
while (!r_pd_) {
// Fallback to GPU reorder.
auto status = dnnl_reorder_primitive_desc_create(
&r_pd_, src.md_, src.engine(), dst.md_, dst.engine(), attr);
if (status != dnnl_success) break;

auto r_pd = make_benchdnn_dnnl_wrapper(r_pd_);
const auto &scratchpad_md = query_md(r_pd, DNNL_ARG_SCRATCHPAD);
const auto &scratchpad_engine
= dst.engine_kind() == dnnl_gpu ? dst.engine() : src.engine();
dnn_mem_t scratchpad(scratchpad_md, scratchpad_engine);

DNN_SAFE(dnnl_primitive_create(&prim_, r_pd), CRIT);
auto prim = make_benchdnn_dnnl_wrapper(prim_);

args_t args;
args.set(DNNL_ARG_FROM, *r_src);
args.set(DNNL_ARG_TO, *r_dst);
args.set(DNNL_ARG_SCRATCHPAD, scratchpad);

return execute_and_wait(prim, args);
}
auto r_pd = make_benchdnn_dnnl_wrapper(r_pd_);
const auto &scratchpad_md = query_md(r_pd, DNNL_ARG_SCRATCHPAD);
const auto &scratchpad_engine
= dst.engine_kind() == dnnl_gpu ? dst.engine() : src.engine();
dnn_mem_t scratchpad(scratchpad_md, scratchpad_engine);

DNN_SAFE(dnnl_primitive_create(&prim_, r_pd), CRIT);
auto prim = make_benchdnn_dnnl_wrapper(prim_);

args_t args;
args.set(DNNL_ARG_FROM, *r_src);
args.set(DNNL_ARG_TO, *r_dst);
args.set(DNNL_ARG_SCRATCHPAD, scratchpad);
if (dnnl_memory_desc_equal(src.md_, dst.md_)) {
        // If reorder pd creation fails, use plain data copy for identical mds.
BENCHDNN_PRINT(2, "%s\n", "[REORDER] Fallback to plain copy.");
const int64_t chunk_size = 64;
const int64_t n_chunks = div_up(src.nelems(), chunk_size);
benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) {
int64_t idx_start = idx_chunk * chunk_size;
int64_t idx_end = MIN2(idx_start + chunk_size, src.nelems());
for (int64_t idx = idx_start; idx < idx_end; ++idx) {
float e = src.get_elem(idx);
dst.set_elem(idx, e);
}
});
return OK;
}

return execute_and_wait(prim, args);
return FAIL;
}

// `swap_dt` changes `this` data type which may be needed for
Expand Down
17 changes: 1 addition & 16 deletions tests/benchdnn/graph/input_displacer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,22 +297,7 @@ int partition_data_displacer_t::displace_input_data(
}
dnnl_memory_desc_destroy(mem_replace.md_);
dnnl_memory_desc_clone(&mem_replace.md_, md);

// As int4->int4 reorder is not supported, use naive data copy and paste.
if (md->data_type == dnnl_s4 || md->data_type == dnnl_u4) {
const int64_t chunk_size = 64;
const int64_t n_chunks = div_up(mem.nelems(), chunk_size);
benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) {
int64_t idx_start = idx_chunk * chunk_size;
int64_t idx_end = MIN2(idx_start + chunk_size, mem.nelems());

for (int64_t idx = idx_start; idx < idx_end; ++idx) {
float e = mem_replace.get_elem(idx);
mem.set_elem(idx, e);
}
});
} else
SAFE(mem.reorder(mem_replace), WARN);
SAFE(mem.reorder(mem_replace), WARN);

if (is_reshaped_dims) dnnl_memory_desc_destroy(md);
return OK;
Expand Down

0 comments on commit 141de68

Please sign in to comment.