Skip to content

Commit

Permalink
graph: backend: kernels: sdp verbose enhancement
Browse files Browse the repository at this point in the history
  • Loading branch information
rongzha1 committed Dec 19, 2024
1 parent 7e450f8 commit 5a9ac0b
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 41 deletions.
11 changes: 11 additions & 0 deletions src/graph/backend/dnnl/kernels/sdp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@

#include "graph/backend/dnnl/dnnl_partition_impl.hpp"

// Verbose helper for SDP kernel dispatching: logs `msg` (printf-style,
// optional varargs) under the graph/create:dispatch/compile verbose
// categories via VINFO. `##__VA_ARGS__` (GNU/MSVC extension, used
// throughout oneDNN) swallows the preceding comma when no varargs are
// passed, so `VDISPATCH_GRAPH_SDP("plain message")` compiles cleanly.
#define VDISPATCH_GRAPH_SDP(msg, ...) \
VINFO(graph, create, dispatch, compile, msg, ##__VA_ARGS__)

namespace dnnl {
namespace impl {
namespace graph {
Expand Down Expand Up @@ -65,17 +68,25 @@ struct sdp_base_t : public kernel_base_t {
if (enable_ukernel) {
kernel = std::make_shared<sdp_primitive_kernel_t<quantized>>();
ret = kernel->compile_impl(part, g_engine, inputs, outputs);
if (ret == status::success)
VDISPATCH_GRAPH_SDP("dispatch to sdp_primitive_kernel_t");
}

if (ret != status::success && enable_decomp) {
kernel = std::make_shared<sdp_decomp_kernel_t<quantized, dt>>();
ret = kernel->compile_impl(part, g_engine, inputs, outputs);
if (ret == status::success)
VDISPATCH_GRAPH_SDP("dispatch to sdp_decomp_kernel_t");
}

if (ret != status::success) {
kernel = std::make_shared<larger_partition_kernel_t>();
ret = kernel->compile_impl(part, g_engine, inputs, outputs);
if (ret == status::success)
VDISPATCH_GRAPH_SDP("dispatch to larger_partition_kernel_t");
}
if (ret != status::success)
VDISPATCH_GRAPH_SDP("fail to dispatch to sdp kernel");
return ret;
}

Expand Down
33 changes: 20 additions & 13 deletions src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@

#include "graph/backend/dnnl/kernels/sdp_decomp_config.hpp"

// Condition-check helper for the SDP decomposition config: if `cond` is
// false, logs `msg` (printf-style, optional varargs) under the
// graph/create:check/sdp_decomp verbose categories and returns `status`
// from the enclosing function (VCONDCHECK semantics).
// Fix: no trailing semicolon in the macro body. Call sites already write
// `VCHECK_SDP_DECOMP(...);`, so a semicolon here expanded to a stray empty
// statement -- which breaks unbraced `if (...) VCHECK_SDP_DECOMP(...);
// else ...` and triggers -Wextra-semi-style warnings.
#define VCHECK_SDP_DECOMP(cond, status, msg, ...) \
VCONDCHECK(graph, create, check, sdp_decomp, (cond), status, msg, \
##__VA_ARGS__)

namespace dnnl {
namespace impl {
namespace graph {
Expand All @@ -25,10 +29,11 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr<subgraph_t> &sg,
const std::vector<logical_tensor_t> &inputs) {
// The order of input logical tensors in inputs is not certain, we need
// to record the input offset in a certain order of ops.
auto op_status = record_input_offset(sg, inputs);
if (op_status != status::success) return false;
VCHECK_SDP_DECOMP(record_input_offset(sg, inputs) == status::success, false,
"Failed to record input offset");
dims src1_user_dims = ltw(inputs[graph_inport[0]]).vdims();
if (src1_user_dims.size() != 4) return false;
VCHECK_SDP_DECOMP(src1_user_dims.size() == 4, false,
"SDP input dims should be 4, but got %d", src1_user_dims.size());

// Initialize SDP input dimension according to the src of mm1
batch_size = src1_user_dims[0];
Expand All @@ -41,14 +46,15 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr<subgraph_t> &sg,

// Check batch size compatibility.
dims wei2_user_dims = ltw(inputs[graph_inport[4]]).vdims();
if (batch_size != wei1_user_dims[0] || batch_size != wei2_user_dims[0]) {
return false;
}
VCHECK_SDP_DECOMP(
batch_size == wei1_user_dims[0] && batch_size == wei2_user_dims[0],
false, "Batch size mismatch");

// Check scale size
if (graph_inport[2] != -1) {
auto scale_sz = ltw(inputs[graph_inport[2]]).nelems();
if (scale_sz != 1) return false;
VCHECK_SDP_DECOMP(scale_sz == 1, false,
"Scale size in sdp should be 1, but got %d", scale_sz);
}

#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_OMP
Expand Down Expand Up @@ -451,9 +457,9 @@ impl::status_t sdp_decomp_config_t::record_input_offset(
// TODO(xxx): Currently, p2 is not supported by decomp kernel.
// p1: [matmul] --> [scale] --> [select] --> [mask] --> ...
// p2: [matmul] --> [select] --> [scale] --> [mask] --> ...
if (post_op->get_kind() == graph::op_kind::Select) {
return status::unimplemented;
}
VCHECK_SDP_DECOMP(post_op->get_kind() != graph::op_kind::Select,
status::unimplemented,
"Not support select between mm1 and scale");
// find scale
if (post_op->get_kind() == graph::op_kind::Divide
|| post_op->get_kind() == graph::op_kind::Multiply) {
Expand All @@ -478,8 +484,8 @@ impl::status_t sdp_decomp_config_t::record_input_offset(
mm2 = cur_op;
}
}
if (impl::utils::one_of(nullptr, mm1, mm2)) return status::invalid_graph;

VCHECK_SDP_DECOMP(mm1 != nullptr && mm2 != nullptr, status::invalid_graph,
"Failed to find mm1 or mm2");
int src1_id = find_graph_inport(mm1->get_input_value(0));
graph_inport.emplace_back(src1_id);
int wei1_id = find_graph_inport(mm1->get_input_value(1));
Expand Down Expand Up @@ -534,7 +540,8 @@ impl::status_t sdp_decomp_config_t::record_sdp_ops(
auto post_op = get_post_op(cur_op);
if (!post_op || post_op->get_kind() != op_kind::dnnl_softmax) continue;
auto ppost_op = get_post_op(post_op);
if (!ppost_op) return status::invalid_graph;
VCHECK_SDP_DECOMP(ppost_op != nullptr, status::invalid_graph,
"Failed to find post post op");

op_ptr reorder1;
op_ptr reorder2;
Expand Down
4 changes: 3 additions & 1 deletion src/graph/backend/dnnl/kernels/sdp_primitive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,9 @@ status_t sdp_primitive_kernel_t<quantized>::get_prim_exec_args(
&& res->find_value_mem_map(
cfg_.v_zero_points_.get(), mem_storage[9]);

if (!ok) return status::runtime_error;
VCONDCHECK(graph, exec, check, sdp_primitive_kernel, ok,
status::runtime_error,
"sdp_primitive_kernel get_prim_exec_args failed");

memory_arg_t mem_arg_q = {mem_storage[0].get(), true};
memory_arg_t mem_arg_k = {mem_storage[1].get(), true};
Expand Down
60 changes: 33 additions & 27 deletions src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@

#include "common/compiler_workarounds.hpp"

// Condition-check helper for the SDP primitive config: if `cond` is false,
// logs `msg` (printf-style, optional varargs) under the
// graph/create:check/sdp_primitive verbose categories and returns `status`
// from the enclosing function (VCONDCHECK semantics).
// Fix: no trailing semicolon in the macro body. Call sites already write
// `VCHECK_SDP_PRIMITIVE(...);`, so a semicolon here expanded to a stray
// empty statement -- which breaks unbraced `if (...)
// VCHECK_SDP_PRIMITIVE(...); else ...` and triggers -Wextra-semi-style
// warnings.
#define VCHECK_SDP_PRIMITIVE(cond, status, msg, ...) \
VCONDCHECK(graph, create, check, sdp_primitive, (cond), status, msg, \
##__VA_ARGS__)

namespace dnnl {
namespace impl {
namespace graph {
Expand Down Expand Up @@ -63,7 +67,8 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,
if (post_op && mm1_post_op_kind.count(post_op->get_kind())) {
// Locate mm1 and all post ops(scale and mask) here.
// 1. locate mm1
if (mm1) return status::unimplemented;
VCHECK_SDP_PRIMITIVE(mm1 == nullptr, status::unimplemented,
"Multiple mm1 found");
mm1 = cur_op;
// At least one of scale and mask exists
if (post_op->get_kind() == op_kind::dnnl_binary) {
Expand All @@ -84,15 +89,18 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,
}
}
} else {
if (mm2) return status::unimplemented;
VCHECK_SDP_PRIMITIVE(mm2 == nullptr, status::unimplemented,
"Multiple mm2 found");
mm2 = cur_op;
}
}

// Locate input/outputs: Q, K, V, dst, scale, mask
mm1_ = mm1;
mm2_ = mm2;
if (!mm1 || !mm2 || !final_op) return status::unimplemented;
VCHECK_SDP_PRIMITIVE((mm1 && mm2 && final_op), status::unimplemented,
"Not all ops are found");

q_ = mm1->get_input_value(0);
k_ = mm1->get_input_value(1);
v_ = mm2->get_input_value(1);
Expand Down Expand Up @@ -136,7 +144,8 @@ status_t sdp_primitive_config_t::initial_check(
const std::shared_ptr<subgraph_t> &sg,
const std::vector<logical_tensor_t> &inputs) {
// At least 3 inputs: Q, K, V
if (inputs.size() < 3) return status::invalid_arguments;
VCHECK_SDP_PRIMITIVE(inputs.size() >= 3, status::invalid_arguments,
"At least 3 inputs are required");

// step1(pattern check): Not support sdpa variants with select as mask
// We already have a pattern matcher to ensure that the sdpa patterns
Expand Down Expand Up @@ -175,9 +184,9 @@ status_t sdp_primitive_config_t::initial_check(
mm1 = cur_op;
// Not support select between mm1 and scale(optional)
// GPT-J:[mm1] --> [select] --> [scale]* --> [mask]* --> ...
if (post_op->get_kind() == graph::op_kind::Select) {
return status::unimplemented;
}
VCHECK_SDP_PRIMITIVE(post_op->get_kind() != graph::op_kind::Select,
status::unimplemented,
"Not support select between mm1 and scale(optional)");
// scale
if (post_op->get_kind() == graph::op_kind::Divide
|| post_op->get_kind() == graph::op_kind::Multiply) {
Expand All @@ -193,9 +202,10 @@ status_t sdp_primitive_config_t::initial_check(

// Not support select after scale(optional) and mask(optional)
// Distill-Bert:[mm1] --> [scale]* --> [mask]* --> [select] --> ...
if (post_op->get_kind() == graph::op_kind::Select) {
return status::unimplemented;
}
VCHECK_SDP_PRIMITIVE(post_op->get_kind() != graph::op_kind::Select,
status::unimplemented,
"Not support select after scale(optional) and "
"mask(optional)");
} else {
mm2 = cur_op;
}
Expand All @@ -214,27 +224,29 @@ status_t sdp_primitive_config_t::initial_check(
return -1;
};

if (impl::utils::one_of(nullptr, mm1, mm2)) return status::invalid_graph;
VCHECK_SDP_PRIMITIVE(
mm1 && mm2, status::invalid_graph, "mm1 or mm2 is not found");

// step3(dims check): only support 4-dims now.
int q_id = find_graph_inport(mm1->get_input_value(0));
int k_id = find_graph_inport(mm1->get_input_value(1));
int v_id = find_graph_inport(mm2->get_input_value(1));

bool ok = true;
ok = ok && (q_id != -1) && (k_id != -1) && (v_id != -1);
if (!ok) return status::unimplemented;
ok = ok && ltw(inputs[q_id]).vdims().size() == 4
&& ltw(inputs[k_id]).vdims().size() == 4
&& ltw(inputs[v_id]).vdims().size() == 4;
VCHECK_SDP_PRIMITIVE(q_id != -1 && k_id != -1 && v_id != -1,
status::unimplemented, "Q, K, V are not found");
VCHECK_SDP_PRIMITIVE(ltw(inputs[q_id]).vdims().size() == 4
&& ltw(inputs[k_id]).vdims().size() == 4
&& ltw(inputs[v_id]).vdims().size() == 4,
status::unimplemented, "Q, K, V should be 4-dims");

// sdp_primitive only supports single scale value.
if (scale) {
const auto &s = scale->get_input_value(1)->get_logical_tensor();
if (ltw(s).nelems() != 1) return status::unimplemented;
VCHECK_SDP_PRIMITIVE(ltw(s).nelems() == 1, status::unimplemented,
"Scale should be single value");
}

return ok ? status::success : status::unimplemented;
return status::success;
}

status_t sdp_primitive_config_t::init(std::shared_ptr<subgraph_t> &sg,
Expand Down Expand Up @@ -281,14 +293,8 @@ status_t sdp_primitive_config_t::init(std::shared_ptr<subgraph_t> &sg,

auto status = sdpa_pd_->create_primitive(sdpa_prim_, p_engine.get());

if (status != status::success) {
if (get_verbose(verbose_t::create_dispatch, component_t::graph)) {
verbose_printf(
"graph,create:dispatch,sdpa,could not create primitive, "
"falling back\n");
}
}

VCONDCHECK(graph, create, dispatch, sdp, status == status::success, status,
"could not create primitive, falling back\n");
return status;
}

Expand Down

0 comments on commit 5a9ac0b

Please sign in to comment.