
graph: backend: kernels: sdp verbose enhancement #2273

Open · wants to merge 1 commit into main
11 changes: 11 additions & 0 deletions src/graph/backend/dnnl/kernels/sdp.hpp
@@ -30,6 +30,9 @@

#include "graph/backend/dnnl/dnnl_partition_impl.hpp"

#define VDISPATCH_GRAPH_SDP(msg, ...) \
VINFO(graph, create, dispatch, compile, msg, ##__VA_ARGS__)

namespace dnnl {
namespace impl {
namespace graph {
@@ -65,17 +68,25 @@ struct sdp_base_t : public kernel_base_t {
if (enable_ukernel) {
kernel = std::make_shared<sdp_primitive_kernel_t<quantized>>();
ret = kernel->compile_impl(part, g_engine, inputs, outputs);
if (ret == status::success)
VDISPATCH_GRAPH_SDP("dispatch to sdp_primitive_kernel");
Contributor

Suggested change:
-            VDISPATCH_GRAPH_SDP("dispatch to sdp_primitive_kernel");
+            VDISPATCH_GRAPH_SDP("dispatch to sdp_primitive_kernel_t");

Just in case your intention is to show the class name.

}

if (ret != status::success && enable_decomp) {
kernel = std::make_shared<sdp_decomp_kernel_t<quantized, dt>>();
ret = kernel->compile_impl(part, g_engine, inputs, outputs);
if (ret == status::success)
VDISPATCH_GRAPH_SDP("dispatch to sdp_decomp_kernel");
Contributor

Similarly, sdp_decomp_kernel_t

}

if (ret != status::success) {
kernel = std::make_shared<larger_partition_kernel_t>();
ret = kernel->compile_impl(part, g_engine, inputs, outputs);
if (ret == status::success)
VDISPATCH_GRAPH_SDP("dispatch to larger_partition_kernel");
}
if (ret != status::success)
VDISPATCH_GRAPH_SDP("fail to dispatch to sdp kernel");
return ret;
}

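The hunk above implements a tiered dispatch: try sdp_primitive_kernel_t first, fall back to sdp_decomp_kernel_t, then to larger_partition_kernel_t, and log which tier won via the new VDISPATCH_GRAPH_SDP macro. Below is a minimal standalone sketch of that pattern; the kernel stubs and the logging macro body are hypothetical stand-ins, not oneDNN's actual VINFO machinery.

    #include <cstdio>
    #include <functional>
    #include <utility>
    #include <vector>

    // Hypothetical stand-in for the VINFO-based verbose logging in the diff.
    #define VDISPATCH_GRAPH_SDP(msg) std::printf("graph,create:dispatch,sdp,%s\n", msg)

    enum class status { success, unimplemented };

    int main() {
        // Candidate kernels are tried in priority order; a tier that cannot
        // compile the partition reports unimplemented and the next one is
        // attempted, mirroring sdp_base_t::compile in the hunk above.
        std::vector<std::pair<const char *, std::function<status()>>> tiers = {
                {"dispatch to sdp_primitive_kernel_t",
                        [] { return status::unimplemented; }}, // stub: fails
                {"dispatch to sdp_decomp_kernel_t",
                        [] { return status::success; }}, // stub: succeeds
                {"dispatch to larger_partition_kernel_t",
                        [] { return status::success; }},
        };
        status ret = status::unimplemented;
        for (auto &tier : tiers) {
            ret = tier.second();
            if (ret == status::success) {
                VDISPATCH_GRAPH_SDP(tier.first);
                break;
            }
        }
        if (ret != status::success)
            VDISPATCH_GRAPH_SDP("fail to dispatch to sdp kernel");
        return 0;
    }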
5 changes: 3 additions & 2 deletions src/graph/backend/dnnl/kernels/sdp_decomp.cpp
@@ -51,8 +51,9 @@ status_t sdp_decomp_kernel_t<quantized, dt>::compile_impl(
BACKEND_DNNL_CHECK(set_given_inputs_outputs(subgraph_, inputs, outputs));

// Check if it's supported by decomposition kernel
if (!sdp_cfg_.initial_check(subgraph_, inputs))
return status::unimplemented;
VCONDCHECK(graph, create, check, sdp_decomp,
sdp_cfg_.initial_check(subgraph_, inputs), status::unimplemented,
"sdp_decomp_kernel_t: initial check failed");
Contributor

I think we don't need this as we already do checks in the initial_check function.


subgraph_visualizer_t vis(part->id(), [this](const value_t *val) {
return this->memory_planner_.get_memory_info(val);
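The VCONDCHECK conversion above (and the VCHECK_SDP_DECOMP / VCHECK_SDP_PRIMITIVE wrappers added in the files below) replaces a silent early return with a check that logs a reason before returning. The sketch below approximates that behavior; it is not oneDNN's actual macro definition, which also takes component/phase tags and routes through the library's verbose infrastructure.

    #include <cstdio>

    // Rough approximation of a VCONDCHECK-style macro: on failure, print a
    // verbose line and return the given status from the enclosing function.
    #define CONDCHECK_SKETCH(cond, ret_status, msg, ...) \
        do { \
            if (!(cond)) { \
                std::printf("graph,create:check," msg "\n", ##__VA_ARGS__); \
                return ret_status; \
            } \
        } while (0)

    enum class status { success, unimplemented };

    status compile_stub(bool supported) {
        // One line both logs the failure reason and short-circuits,
        // instead of a bare `if (!supported) return ...;`.
        CONDCHECK_SKETCH(supported, status::unimplemented,
                "sdp_decomp_kernel_t: initial check failed");
        return status::success;
    }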
27 changes: 16 additions & 11 deletions src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp
@@ -16,6 +16,10 @@

#include "graph/backend/dnnl/kernels/sdp_decomp_config.hpp"

#define VCHECK_SDP_DECOMP(cond, status, msg, ...) \
VCONDCHECK(graph, create, check, sdp_decomp, (cond), status, msg, \
##__VA_ARGS__);

namespace dnnl {
namespace impl {
namespace graph {
@@ -25,8 +29,8 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr<subgraph_t> &sg,
const std::vector<logical_tensor_t> &inputs) {
// The order of input logical tensors in inputs is not certain, we need
// to record the input offset in a certain order of ops.
auto op_status = record_input_offset(sg, inputs);
if (op_status != status::success) return false;
VCHECK_SDP_DECOMP(record_input_offset(sg, inputs) == status::success, false,
"Failed to record input offset");
dims src1_user_dims = ltw(inputs[graph_inport[0]]).vdims();
if (src1_user_dims.size() != 4) return false;

@@ -41,9 +45,9 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr<subgraph_t> &sg,

// Check batch size compatibility.
dims wei2_user_dims = ltw(inputs[graph_inport[4]]).vdims();
if (batch_size != wei1_user_dims[0] || batch_size != wei2_user_dims[0]) {
return false;
}
VCHECK_SDP_DECOMP(
batch_size == wei1_user_dims[0] && batch_size == wei2_user_dims[0],
false, "Batch size mismatch");

// Check scale size
if (graph_inport[2] != -1) {
@@ -451,9 +455,9 @@ impl::status_t sdp_decomp_config_t::record_input_offset(
// TODO(xxx): Currently, p2 is not supported by decomp kernel.
// p1: [matmul] --> [scale] --> [select] --> [mask] --> ...
// p2: [matmul] --> [select] --> [scale] --> [mask] --> ...
if (post_op->get_kind() == graph::op_kind::Select) {
return status::unimplemented;
}
VCHECK_SDP_DECOMP(post_op->get_kind() != graph::op_kind::Select,
status::unimplemented,
"Not support select between mm1 and scale");
// find scale
if (post_op->get_kind() == graph::op_kind::Divide
|| post_op->get_kind() == graph::op_kind::Multiply) {
@@ -478,8 +482,8 @@ impl::status_t sdp_decomp_config_t::record_input_offset(
mm2 = cur_op;
}
}
if (impl::utils::one_of(nullptr, mm1, mm2)) return status::invalid_graph;

VCHECK_SDP_DECOMP(mm1 != nullptr && mm2 != nullptr, status::invalid_graph,
"Failed to find mm1 or mm2");
int src1_id = find_graph_inport(mm1->get_input_value(0));
graph_inport.emplace_back(src1_id);
int wei1_id = find_graph_inport(mm1->get_input_value(1));
@@ -534,7 +538,8 @@ impl::status_t sdp_decomp_config_t::record_sdp_ops(
auto post_op = get_post_op(cur_op);
if (!post_op || post_op->get_kind() != op_kind::dnnl_softmax) continue;
auto ppost_op = get_post_op(post_op);
if (!ppost_op) return status::invalid_graph;
VCHECK_SDP_DECOMP(ppost_op != nullptr, status::invalid_graph,
"Failed to find post post op");

op_ptr reorder1;
op_ptr reorder2;
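One detail worth noting in this file: VCHECK_SDP_DECOMP is used both in initial_check, which returns bool (the failure value passed is false), and in record_input_offset/record_sdp_ops, which return status_t (the failure value is a status code). The macro serves both because its second argument is simply whatever the enclosing function should return on failure, as the sketch below illustrates (names and the numeric status are hypothetical):

    // Sketch only: the same check-and-return macro works in functions with
    // different return types, because the failure value is a parameter.
    #define CHECK_OR_RETURN(cond, fail_value) \
        do { \
            if (!(cond)) return fail_value; \
        } while (0)

    bool initial_check_stub(int ndims) {
        CHECK_OR_RETURN(ndims == 4, false); // bool-returning caller
        return true;
    }

    int record_input_offset_stub(bool found) {
        CHECK_OR_RETURN(found, 2); // status-returning caller (2 stands in for invalid_graph)
        return 0;
    }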
4 changes: 3 additions & 1 deletion src/graph/backend/dnnl/kernels/sdp_primitive.cpp
@@ -181,7 +181,9 @@ status_t sdp_primitive_kernel_t<quantized>::get_prim_exec_args(
&& res->find_value_mem_map(
cfg_.v_zero_points_.get(), mem_storage[9]);

if (!ok) return status::runtime_error;
VCONDCHECK(graph, exec, check, sdp_primitive_kernel, ok,
status::runtime_error,
"sdp_primitive_kernel get_prim_exec_args failed");

memory_arg_t mem_arg_q = {mem_storage[0].get(), true};
memory_arg_t mem_arg_k = {mem_storage[1].get(), true};
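Note that unlike the other checks in this PR, which are tagged with the create phase, this VCONDCHECK passes exec: the failure message is emitted at execution time, when binding the primitive's runtime arguments fails, rather than during compilation or dispatch.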
60 changes: 33 additions & 27 deletions src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp
@@ -19,6 +19,10 @@

#include "common/compiler_workarounds.hpp"

#define VCHECK_SDP_PRIMITIVE(cond, status, msg, ...) \
VCONDCHECK(graph, create, check, sdp_primitive, (cond), status, msg, \
##__VA_ARGS__);

namespace dnnl {
namespace impl {
namespace graph {
@@ -63,7 +67,8 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,
if (post_op && mm1_post_op_kind.count(post_op->get_kind())) {
// Locate mm1 and all post ops(scale and mask) here.
// 1. locate mm1
if (mm1) return status::unimplemented;
VCHECK_SDP_PRIMITIVE(mm1 == nullptr, status::unimplemented,
"Multiple mm1 found");
mm1 = cur_op;
// At least one of scale and mask exists
if (post_op->get_kind() == op_kind::dnnl_binary) {
@@ -84,15 +89,18 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,
}
}
} else {
if (mm2) return status::unimplemented;
VCHECK_SDP_PRIMITIVE(mm2 == nullptr, status::unimplemented,
"Multiple mm2 found");
mm2 = cur_op;
}
}

// Locate input/outputs: Q, K, V, dst, scale, mask
mm1_ = mm1;
mm2_ = mm2;
if (!mm1 || !mm2 || !final_op) return status::unimplemented;
VCHECK_SDP_PRIMITIVE((mm1 && mm2 && final_op), status::unimplemented,
"Not all ops are found");

q_ = mm1->get_input_value(0);
k_ = mm1->get_input_value(1);
v_ = mm2->get_input_value(1);
@@ -136,7 +144,8 @@ status_t sdp_primitive_config_t::initial_check(
const std::shared_ptr<subgraph_t> &sg,
const std::vector<logical_tensor_t> &inputs) {
// At least 3 inputs: Q, K, V
if (inputs.size() < 3) return status::invalid_arguments;
VCHECK_SDP_PRIMITIVE(inputs.size() >= 3, status::invalid_arguments,
"At least 3 inputs are required");

// step1(pattern check): Not support sdpa variants with select as mask
// We already have a pattern matcher to ensure that the sdpa patterns
@@ -175,9 +184,9 @@ status_t sdp_primitive_config_t::initial_check(
mm1 = cur_op;
// Not support select between mm1 and scale(optional)
// GPT-J:[mm1] --> [select] --> [scale]* --> [mask]* --> ...
if (post_op->get_kind() == graph::op_kind::Select) {
return status::unimplemented;
}
VCHECK_SDP_PRIMITIVE(post_op->get_kind() != graph::op_kind::Select,
status::unimplemented,
"Not support select between mm1 and scale(optional)");
// scale
if (post_op->get_kind() == graph::op_kind::Divide
|| post_op->get_kind() == graph::op_kind::Multiply) {
@@ -193,9 +202,10 @@

// Not support select after scale(optional) and mask(optional)
// Distill-Bert:[mm1] --> [scale]* --> [mask]* --> [select] --> ...
if (post_op->get_kind() == graph::op_kind::Select) {
return status::unimplemented;
}
VCHECK_SDP_PRIMITIVE(post_op->get_kind() != graph::op_kind::Select,
status::unimplemented,
"Not support select after scale(optional) and "
"mask(optional)");
} else {
mm2 = cur_op;
}
@@ -214,27 +224,29 @@ status_t sdp_primitive_config_t::initial_check(
return -1;
};

if (impl::utils::one_of(nullptr, mm1, mm2)) return status::invalid_graph;
VCHECK_SDP_PRIMITIVE(
mm1 && mm2, status::invalid_graph, "mm1 or mm2 is not found");

// step3(dims check): only support 4-dims now.
int q_id = find_graph_inport(mm1->get_input_value(0));
int k_id = find_graph_inport(mm1->get_input_value(1));
int v_id = find_graph_inport(mm2->get_input_value(1));

bool ok = true;
ok = ok && (q_id != -1) && (k_id != -1) && (v_id != -1);
if (!ok) return status::unimplemented;
ok = ok && ltw(inputs[q_id]).vdims().size() == 4
&& ltw(inputs[k_id]).vdims().size() == 4
&& ltw(inputs[v_id]).vdims().size() == 4;
VCHECK_SDP_PRIMITIVE(q_id != -1 && k_id != -1 && v_id != -1,
status::unimplemented, "Q, K, V are not found");
VCHECK_SDP_PRIMITIVE(ltw(inputs[q_id]).vdims().size() == 4
&& ltw(inputs[k_id]).vdims().size() == 4
&& ltw(inputs[v_id]).vdims().size() == 4,
status::unimplemented, "Q, K, V should be 4-dims");

// sdp_primitive only supports single scale value.
if (scale) {
const auto &s = scale->get_input_value(1)->get_logical_tensor();
if (ltw(s).nelems() != 1) return status::unimplemented;
VCHECK_SDP_PRIMITIVE(ltw(s).nelems() == 1, status::unimplemented,
"Scale should be single value");
}

return ok ? status::success : status::unimplemented;
return status::success;
}

status_t sdp_primitive_config_t::init(std::shared_ptr<subgraph_t> &sg,
@@ -281,14 +293,8 @@ status_t sdp_primitive_config_t::init(std::shared_ptr<subgraph_t> &sg,

auto status = sdpa_pd_->create_primitive(sdpa_prim_, p_engine.get());

if (status != status::success) {
if (get_verbose(verbose_t::create_dispatch, component_t::graph)) {
verbose_printf(
"graph,create:dispatch,sdpa,could not create primitive, "
"falling back\n");
}
}

VCONDCHECK(graph, create, dispatch, sdp, status == status::success, status,
"could not create primitive, falling back\n");
return status;
}

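To surface the messages this PR adds, dispatch-level verbose tracing has to be enabled at runtime. Assuming the verbose controls of recent oneDNN releases, setting ONEDNN_VERBOSE=dispatch (or ONEDNN_VERBOSE=all) in the environment before running a Graph API workload should print the new graph,create:dispatch lines.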