From f9937a24b0d3664faf7d789a3f8972ba5d631c7a Mon Sep 17 00:00:00 2001
From: carlushuang
Date: Tue, 25 May 2021 22:01:53 +0800
Subject: [PATCH] [gmap] nhwc bwd (#96)

* [gmap] bwd nhwc

* fix bug for group conv
---
 driver/perf/gmap.cpp | 480 +++++++++++++++++++++++++++++++++++++------
 1 file changed, 422 insertions(+), 58 deletions(-)

diff --git a/driver/perf/gmap.cpp b/driver/perf/gmap.cpp
index a472ab81..94c4b943 100644
--- a/driver/perf/gmap.cpp
+++ b/driver/perf/gmap.cpp
@@ -74,7 +74,9 @@ void serialize_block_req(const block_req_t * block_req, FILE* fp, std::vector<bool>* record)
     for(index_t i = 0; i < block_req->req.size(); i++){
         const auto & thread_req = block_req->req[i];
-        assert(thread_req.tid == i);
+        //assert(thread_req.tid == i);
+        if(thread_req.tid != i)
+            printf("tid mismatch, tid:%zu, i:%zu\n", thread_req.tid, i);
         ss<<"t"<<thread_req.tid;
@@ ... @@
+void gmap_serialize_and_valid(const args_t *conv_args, const std::string &tensor_layout,
+    std::vector<block_req_t> &inp_block_req, std::vector<block_req_t> &wei_block_req, std::vector<block_req_t> &out_block_req,
+    std::vector<bool> &record_inp, std::vector<bool> &record_wei, std::vector<bool> &record_out,
+    linear_tensor_t &tensor_inp, linear_tensor_t &tensor_wei, linear_tensor_t &tensor_out,
+    FILE *fp_inp, FILE *fp_wei, FILE *fp_out)
+{
+    // serialize block request
+    for(auto itr_ibr = inp_block_req.begin(); itr_ibr != inp_block_req.end(); itr_ibr++)
+        serialize_block_req(&(*itr_ibr), fp_inp, &record_inp);
+
+    for(auto itr_ibr = wei_block_req.begin(); itr_ibr != wei_block_req.end(); itr_ibr++)
+        serialize_block_req(&(*itr_ibr), fp_wei, &record_wei);
+
+    for(auto itr_ibr = out_block_req.begin(); itr_ibr != out_block_req.end(); itr_ibr++)
+        serialize_block_req(&(*itr_ibr), fp_out, &record_out);
+
+    // validate all records
+    std::vector<bool> valid_hi, valid_wi;
+    std::tie(valid_hi, valid_wi) = gmap_get_input_access_map(conv_args);
+    for(auto it = record_inp.begin(); it != record_inp.end(); it++){
+        index_t idx = std::distance(record_inp.begin(), it);
+        std::vector<index_t> inp_position = tensor_inp.get(idx);
+        index_t ihi = tensor_layout == "nhwc" ? inp_position[1] : inp_position[2];
+        index_t iwi = tensor_layout == "nhwc" ? inp_position[2] : inp_position[3];
+        if(valid_hi[ihi] && valid_wi[iwi]){
+            if(!(*it)){
+                printf("WARNING! input not touched pixel at %zu\n", idx);
+            }
+        }
+        else{
+            if(*it){
+                printf("WARNING! input touched unused pixel at %zu\n", idx);
+            }
+        }
+    }
+
+    for(auto it = record_wei.begin(); it != record_wei.end(); it++){
+        index_t idx = std::distance(record_wei.begin(), it);
+        if(!(*it)){
+            printf("WARNING! weight not touched pixel at %zu\n", idx);
+        }
+    }
+
+    for(auto it = record_out.begin(); it != record_out.end(); it++){
+        index_t idx = std::distance(record_out.begin(), it);
+        if(!(*it)){
+            printf("WARNING! output not touched pixel at %zu\n", idx);
+        }
+    }
+}
+
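Editor's note: the validation above relies on gmap_get_input_access_map() to know which input rows and columns the convolution can legally touch. A minimal sketch (not part of the patch) of how such a map can be built; the function name and return type here are assumptions, and only the marking rule follows from the forward convolution definition:

#include <utility>
#include <vector>

// mark every input row/column reachable from some (output position, filter tap)
std::pair<std::vector<bool>, std::vector<bool>>
make_input_access_map(int hi, int wi, int ho, int wo,
                      int stride_h, int stride_w, int dilation_h, int dilation_w,
                      int pad_h, int pad_w, int y, int x)
{
    std::vector<bool> valid_hi(hi, false), valid_wi(wi, false);
    // a pixel ihi is touched iff some (iho, iy) maps onto it:
    //   ihi = iho * stride_h + iy * dilation_h - pad_h
    for(int iho = 0; iho < ho; iho++)
        for(int iy = 0; iy < y; iy++){
            int ihi = iho * stride_h + iy * dilation_h - pad_h;
            if(ihi >= 0 && ihi < hi) valid_hi[ihi] = true;
        }
    for(int iwo = 0; iwo < wo; iwo++)
        for(int ix = 0; ix < x; ix++){
            int iwi = iwo * stride_w + ix * dilation_w - pad_w;
            if(iwi >= 0 && iwi < wi) valid_wi[iwi] = true;
        }
    return {valid_hi, valid_wi};
}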
+void gmap_dump_bwd_nhwc(const args_t *conv_args, const igemm_gtc_tunable_t * tunable, int gks, FILE *fp_inp, FILE *fp_wei, FILE *fp_out)
+{
+    index_t hi = conv_args->get_int("in_h");
+    index_t wi = conv_args->get_int("in_w");
+    index_t n = conv_args->get_int("batchsize");
+    index_t k = conv_args->get_int("out_channels");
+    index_t c = conv_args->get_int("in_channels");
+
+    index_t stride_h = conv_args->get_int("conv_stride_h");
+    index_t stride_w = conv_args->get_int("conv_stride_w");
+    index_t dilation_h = conv_args->get_int("dilation_h");
+    index_t dilation_w = conv_args->get_int("dilation_w");
+    index_t pad_h = conv_args->get_int("pad_h");
+    index_t pad_w = conv_args->get_int("pad_w");
+    index_t y = conv_args->get_int("fil_h");
+    index_t x = conv_args->get_int("fil_w");
+    index_t ho = gmap_conv_out_size(hi, pad_h, dilation_h, y, stride_h);
+    index_t wo = gmap_conv_out_size(wi, pad_w, dilation_w, x, stride_w);
+    index_t group = conv_args->get_int("group_count");
+
+    std::string precision = tunable->precision;
+    index_t data_byte = utility_string_to_data_byte(tunable->precision);
+
+    index_t gcd_stride_dilation_h = utility_gcd(stride_h, dilation_h);
+    index_t gcd_stride_dilation_w = utility_gcd(stride_w, dilation_w);
+
+    index_t y_tilda = stride_h / gcd_stride_dilation_h;
+    index_t x_tilda = stride_w / gcd_stride_dilation_w;
+
+    index_t y_dot = utility_integer_divide_ceil(y, y_tilda);
+    index_t x_dot = utility_integer_divide_ceil(x, x_tilda);
+
+    index_t h_tilda = ho + utility_integer_divide_ceil(dilation_h * (y - 1), stride_h);
+    index_t w_tilda = wo + utility_integer_divide_ceil(dilation_w * (x - 1), stride_w);
+
+    index_t h_tilda_left = utility_integer_divide_floor(
+        (index_t)utility_max((int64_t)0, static_cast<int64_t>(pad_h - dilation_h * (y_tilda - 1))), stride_h);
+    index_t w_tilda_left = utility_integer_divide_floor(
+        (index_t)utility_max((int64_t)0, static_cast<int64_t>(pad_w - dilation_w * (x_tilda - 1))), stride_w);
+
+    index_t h_tilda_right = utility_min(
+        h_tilda, utility_integer_divide_ceil(pad_h + hi - 1, stride_h) + 1);
+    index_t w_tilda_right = utility_min(
+        w_tilda, utility_integer_divide_ceil(pad_w + wi - 1, stride_w) + 1);
+
+    index_t h_tilda_slice = h_tilda_right - h_tilda_left;
+    index_t w_tilda_slice = w_tilda_right - w_tilda_left;
+    index_t num_of_gemm = y_tilda * x_tilda;
+
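Editor's note: the block above is the backward-data "tilda" decomposition: the kernel is split into y_tilda * x_tilda sub-GEMMs so each sub-GEMM sees a uniform stride pattern, and h_tilda_left/right trim the part of the padded tile that can never touch real input. A small standalone sketch, assuming the same formulas, that evaluates one concrete config:

#include <algorithm>
#include <cstdio>
#include <numeric>

int main(){
    // one concrete config: hi=28, pad_h=1, y=3, stride_h=2, dilation_h=1 -> ho=14
    int hi = 28, pad_h = 1, y = 3, stride_h = 2, dilation_h = 1;
    int ho = (hi + 2 * pad_h - (dilation_h * (y - 1) + 1)) / stride_h + 1;              // 14
    int y_tilda = stride_h / std::gcd(stride_h, dilation_h);                            // 2 sub-GEMMs along h
    int y_dot   = (y + y_tilda - 1) / y_tilda;                                          // at most 2 taps per sub-GEMM
    int h_tilda = ho + (dilation_h * (y - 1) + stride_h - 1) / stride_h;                // 15
    int h_tilda_left  = std::max(0, pad_h - dilation_h * (y_tilda - 1)) / stride_h;     // 0
    int h_tilda_right = std::min(h_tilda, (pad_h + hi - 1 + stride_h - 1) / stride_h + 1); // 15
    printf("y_tilda=%d y_dot=%d h_tilda_slice=%d\n", y_tilda, y_dot, h_tilda_right - h_tilda_left);
    return 0;
}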
+    index_t num_global_splits = tunable->gemm_k_global_split ? (1 << gks) : 1;
+    index_t gemm_m_per_block = tunable->gemm_m_per_block;
+    index_t gemm_n_per_block = tunable->gemm_n_per_block;
+    index_t gemm_k_per_block = tunable->gemm_k_per_block;
+    index_t gemm_m = ((n * h_tilda_slice * w_tilda_slice + gemm_m_per_block - 1) / gemm_m_per_block) * gemm_m_per_block;
+    index_t gemm_n = ((c / group) + gemm_n_per_block - 1) / gemm_n_per_block * gemm_n_per_block;
+    //index_t gemm_k = ((k / group) * y * x) / num_global_splits;
+
+    index_t ta_e = tunable->tensor_a_thread_lengths[0];
+    index_t ta_k = tunable->tensor_a_thread_lengths[1];
+    index_t ta_nb0 = tunable->tensor_a_thread_lengths[2];
+    index_t ta_nb1 = tunable->tensor_a_thread_lengths[3];
+
+    index_t tb_e = tunable->tensor_b_thread_lengths[0];
+    index_t tb_k = tunable->tensor_b_thread_lengths[1];
+    index_t tb_c0 = tunable->tensor_b_thread_lengths[2];
+    index_t tb_c1 = tunable->tensor_b_thread_lengths[3];
+
+    index_t ca_e = tunable->tensor_a_cluster_lengths[0];
+    index_t ca_k = tunable->tensor_a_cluster_lengths[1];
+    index_t ca_nb0 = tunable->tensor_a_cluster_lengths[2];
+    index_t ca_nb1 = tunable->tensor_a_cluster_lengths[3];
+
+    index_t cb_e = tunable->tensor_b_cluster_lengths[0];
+    index_t cb_k = tunable->tensor_b_cluster_lengths[1];
+    index_t cb_c0 = tunable->tensor_b_cluster_lengths[2];
+    index_t cb_c1 = tunable->tensor_b_cluster_lengths[3];
+
+    index_t block_size = ca_e * ca_k * ca_nb0 * ca_nb1;
+    assert(block_size == (cb_e * cb_k * cb_c0 * cb_c1));
+    assert((gemm_m % gemm_m_per_block == 0) && (gemm_n % gemm_n_per_block == 0));
+    index_t grid_size = num_global_splits * num_of_gemm * group * (gemm_m / gemm_m_per_block) * (gemm_n / gemm_n_per_block);
+    linear_tensor_t block_mapping({num_global_splits, num_of_gemm, group, (gemm_m / gemm_m_per_block), (gemm_n / gemm_n_per_block)});
+    linear_tensor_t gemm_m_transform({n, h_tilda_slice, w_tilda_slice});
+    //linear_tensor_t gemm_k_transform({y, x, k / group});
+
+    linear_tensor_t tensor_inp({n, hi, wi, group, c/group});
+    linear_tensor_t tensor_wei({group, k/group, y, x, c/group});
+    linear_tensor_t tensor_out({n, ho, wo, group, k/group});
+    std::vector<bool> record_inp(n*hi*wi*group*(c/group), false);
+    std::vector<bool> record_wei(group*(k/group)*y*x*(c/group), false);
+    std::vector<bool> record_out(n*ho*wo*group*(k/group), false);
+
+    index_t ta_nb_per_thread = ta_nb0 != 1 ? ta_nb0 : ta_nb1;
+    index_t ta_vector_k = utility_gcd(ta_k, 4 * (4 / data_byte));
+    index_t ta_nk_per_thread = ta_k / ta_vector_k;
+    index_t ta_nb_thread_stride = tunable->tensor_a_pass_through ? ca_nb0 * ca_nb1 :
+        (ta_nb0 != 1 ? ca_nb1 * ta_nb1 : 1);
+
+    index_t tb_vector_c = utility_gcd(tb_c1, 4 * (4 / data_byte));
+    index_t tb_nc_per_thread = tb_c0 != 1 ? tb_c0 : tb_c1 / tb_vector_c;
+    index_t tb_nk_per_thread = tb_k;
+    index_t tb_nk_thread_stride = 1;
+    index_t tb_nc_thread_stride = tb_c0 != 1 ? tb_c1 * cb_c1 : tb_vector_c;
+
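Editor's note: ta_vector_k and tb_vector_c above cap each per-thread global access at a dwordx4 (16-byte) load. A sketch of the rule, with std::gcd standing in for utility_gcd (assumed to behave the same):

#include <numeric>

// 4 * (4 / data_byte) is the element count of a 16-byte load:
//   fp32: 4*(4/4) = 4 elems, fp16: 4*(4/2) = 8 elems, int8: 4*(4/1) = 16 elems.
// gcd then shrinks it so it divides the thread's slice length.
inline int vector_elems(int thread_len, int data_byte){
    return std::gcd(thread_len, 4 * (4 / data_byte)); // elements per global load
}
// e.g. vector_elems(8, 2) == 8 -> one dwordx4 load covers a thread's 8 fp16 values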
+    // check get_vector_write_out()
+    index_t tc_vector_c = 1;    // 1 when fp32
+    if(tunable->precision == "fp16"){
+        if(tunable->gemm_k_global_split)
+            tc_vector_c = 2;
+        else{
+            if(tb_c1 == 1)
+                tc_vector_c = 1;
+            else
+                tc_vector_c = utility_gcd(gemm_n_per_block, static_cast<index_t>(tunable->vector_store == 0 ? 8 : tunable->vector_store));
+        }
+    }
+    else if(tunable->precision == "int8")
+    {
+        if(tb_c1 == 1)
+            tc_vector_c = 1;
+        else
+            tc_vector_c = utility_gcd(gemm_n_per_block, static_cast<index_t>(tunable->vector_store == 0 ? 16 : tunable->vector_store));
+    }
+
+    assert(gemm_n_per_block % tc_vector_c == 0);
+    index_t cc_c = gemm_n_per_block / tc_vector_c;
+    assert(block_size % cc_c == 0);
+    index_t cc_nb = block_size / cc_c;
+    assert(gemm_m_per_block % cc_nb == 0);
+    index_t tc_nb_per_thread = gemm_m_per_block / cc_nb;
+    index_t tc_nb_thread_stride = cc_nb;
+
+    std::vector<index_t> ta_block_req_idx(grid_size, 0);
+    std::vector<index_t> tb_block_req_idx(grid_size, 0);
+    std::vector<index_t> tc_block_req_idx(grid_size, 0);
+
+    std::vector<block_req_t> out_block_req;
+    std::vector<block_req_t> wei_block_req;
+    std::vector<block_req_t> inp_block_req;
+
+    auto cur_block = [&](index_t bid, index_t cur_gks, index_t cur_gemm_id, index_t cur_group, index_t cur_gemm_m, index_t cur_gemm_n, index_t cur_gemm_k){
+        index_t i_y_tilda = cur_gemm_id / x_tilda;
+        index_t i_x_tilda = cur_gemm_id % x_tilda;
+        index_t y_dot_slice = utility_integer_divide_ceil(y - i_y_tilda, y_tilda);
+        index_t x_dot_slice = utility_integer_divide_ceil(x - i_x_tilda, x_tilda);
+
+        index_t gemm_k = (k / group) * y_dot_slice * x_dot_slice / num_global_splits;
+
+        linear_tensor_t gemm_k_transform({y_dot_slice, x_dot_slice, k / group});
+
+        index_t dtile_iy = i_y_tilda;
+        index_t dtile_ix = i_x_tilda;
+        index_t dtile_dy = dilation_h / gcd_stride_dilation_h;
+        index_t dtile_dx = dilation_w / gcd_stride_dilation_w;
+        index_t dtile_y = y_tilda;
+        index_t dtile_x = x_tilda;
+        index_t dtile_h = h_tilda;
+        index_t dtile_w = w_tilda;
+        index_t dslice_y = y_dot_slice;
+        index_t dslice_x = x_dot_slice;
+        index_t dslice_h = h_tilda_slice;
+        index_t dslice_w = w_tilda_slice;
+        index_t dslice_h_left = h_tilda_left;
+        index_t dslice_w_left = w_tilda_left;
+
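Editor's note: the dtile_*/dslice_* values implement the index relations quoted in comments further down (iho = dslice_ih + dslice_h_left - dtile_dy * dslice_iy; iy = dslice_iy * dtile_y + dtile_iy; ihi = (dslice_ih + dslice_h_left) * stride_h + dtile_iy * dilation_h - pad_h). A small checker showing why these are consistent: with dtile_dy = dilation_h / gcd and y_tilda = stride_h / gcd, the dslice_iy terms cancel, so the input row depends only on the sub-GEMM id i_y_tilda:

#include <cassert>
#include <numeric>

int main(){
    int stride_h = 2, dilation_h = 3, pad_h = 2, i_y_tilda = 1, h_left = 0;
    int gcd = std::gcd(stride_h, dilation_h);
    int y_tilda = stride_h / gcd, dtile_dy = dilation_h / gcd;
    for(int dslice_ih = 0; dslice_ih < 8; dslice_ih++)
        for(int dslice_iy = 0; dslice_iy < 4; dslice_iy++){
            int iho = dslice_ih + h_left - dtile_dy * dslice_iy;     // output row of this sub-GEMM
            int iy  = dslice_iy * y_tilda + i_y_tilda;               // filter tap of this sub-GEMM
            int ihi = iho * stride_h + iy * dilation_h - pad_h;      // input row via fwd-conv relation
            // dslice_iy cancels: dtile_dy*stride_h == y_tilda*dilation_h
            assert(ihi == (dslice_ih + h_left) * stride_h + i_y_tilda * dilation_h - pad_h);
        }
    return 0;
}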
+        auto cur_block_out = [&](){
+            for(index_t t_inb = 0; t_inb < ta_nb_per_thread; t_inb++){
+                for(index_t t_ik = 0; t_ik < ta_nk_per_thread; t_ik++){
+                    //index_t i_b_req = out_block_req_desc.offset({cur_gks, cur_gemm_id, cur_gemm_m / gemm_m_per_block, cur_gemm_k / gemm_k_per_block, t_inb, t_ik});
+                    //block_req_t & b_req = out_block_req[i_b_req];
+                    block_req_t b_req;
+                    b_req.block_size = block_size;
+                    b_req.bid.push_back(bid);
+                    b_req.req_idx = ta_block_req_idx[bid];
+                    ta_block_req_idx[bid]++;
+
+                    if(cur_gemm_n == 0){
+                        for(index_t tid = 0; tid < block_size; tid++){
+                            index_t out_inb, out_ik;
+                            if(tunable->tensor_a_pass_through){
+                                index_t tmp = tid; index_t tmp1;
+                                out_inb = (tmp % ca_nb1) * ta_nb1; tmp /= ca_nb1;
+                                out_ik = (tmp % ca_k) * ta_vector_k; tmp /= ca_k;
+                                tmp1 = (tmp % ca_nb0) * ta_nb0;
+                                out_inb = tmp1 * (ca_nb1 * ta_nb1) + out_inb;
+                            }else{
+                                out_ik = (tid % ca_k) * ta_k;
+                                out_inb = (tid / ca_k) * ta_nb1;
+                            }
+                            index_t cur_out_inb = cur_gemm_m + out_inb + t_inb * ta_nb_thread_stride;
+
+                            auto out_gemm_m_trans = gemm_m_transform.get(cur_out_inb);
+                            auto out_gemm_k_trans = gemm_k_transform.get(cur_gemm_k + cur_gks * gemm_k);
+
+                            index_t cur_out_dslice_iy = out_gemm_k_trans[0];
+                            index_t cur_out_dslice_ix = out_gemm_k_trans[1];
+                            index_t cur_out_ik = out_gemm_k_trans[2] + out_ik + t_ik * ta_vector_k * (tunable->tensor_a_pass_through ? ca_k : 1);
+
+                            index_t cur_out_in = out_gemm_m_trans[0];
+                            index_t cur_out_dslice_ih = out_gemm_m_trans[1];
+                            index_t cur_out_dslice_iw = out_gemm_m_trans[2];
+
+                            // iho = out_dslice_ih + dslice_h_left - dtile_dy * dslice_iy
+                            // iwo = out_dslice_iw + dslice_w_left - dtile_dx * dslice_ix
+                            index_t cur_out_iho = cur_out_dslice_ih + dslice_h_left - dtile_dy * cur_out_dslice_iy;
+                            index_t cur_out_iwo = cur_out_dslice_iw + dslice_w_left - dtile_dx * cur_out_dslice_ix;
+
+                            auto cur_out_idx = {cur_out_in, cur_out_iho, cur_out_iwo, cur_group, cur_out_ik};
+                            bool cur_out_valid = tensor_out.range_check(cur_out_idx);
+                            //printf("out bid:%zu tid:%zu, n:%zu,ho:%zu,wo:%zu,g:%zu,k:%zu, %s dslice_ih:%zu,dslice_iw:%zu,dslice_iy:%zu,dslice_ix:%zu,dtile_dy:%zu,dtile_dx:%zu,dslice_y:%zu,dslice_x:%zu,dslice_h_left:%zu,dslice_w_left:%zu\n",
+                            //    bid, tid, cur_out_in, cur_out_iho, cur_out_iwo, cur_group, cur_out_ik, cur_out_valid?"y":"n",
+                            //    cur_out_dslice_ih, cur_out_dslice_iw, cur_out_dslice_iy, cur_out_dslice_ix, dtile_dy, dtile_dx, dslice_y, dslice_x, dslice_h_left, dslice_w_left);
+                            index_t cur_out_offset = tensor_out.offset(cur_out_idx) * data_byte;
+                            b_req.req.emplace_back(req_t({tid, data_byte, ta_vector_k, cur_out_offset, cur_out_valid}));
+                        }
+                        out_block_req.push_back(b_req);
+                    }
+                }
+            }
+        };
+        auto cur_block_wei = [&](){
+            for(index_t t_ik = 0; t_ik < tb_nk_per_thread; t_ik++){
+                for(index_t t_ic = 0; t_ic < tb_nc_per_thread; t_ic++){
+                    //index_t i_b_req = wei_block_req_desc.offset({cur_gks, cur_gemm_id, cur_gemm_n / gemm_n_per_block, cur_gemm_k / gemm_k_per_block, t_ik, t_ic});
+                    //block_req_t & b_req = wei_block_req[i_b_req];
+                    block_req_t b_req;
+                    b_req.block_size = block_size;
+                    b_req.bid.push_back(bid);
+                    b_req.req_idx = tb_block_req_idx[bid];
+                    tb_block_req_idx[bid]++;
+
+                    if(cur_gemm_m == 0){
+                        for(index_t tid = 0; tid < block_size; tid++){
+                            index_t wei_ik, wei_ic;
+
+                            wei_ic = (tid % cb_c1) * tb_c1;
+                            wei_ik = (tid / cb_c1) * tb_k;
+
+                            index_t cur_wei_ic = cur_gemm_n + wei_ic + t_ic * tb_nc_thread_stride;
+
+                            auto wei_gemm_k_trans = gemm_k_transform.get(cur_gemm_k + cur_gks * gemm_k);
+
+                            index_t cur_wei_dslice_iy = wei_gemm_k_trans[0];
+                            index_t cur_wei_dslice_ix = wei_gemm_k_trans[1];
+                            index_t cur_wei_ik = wei_gemm_k_trans[2] + (wei_ik + t_ik * tb_nk_thread_stride);
+
+                            // iy = dslice_iy * dtile_y + dtile_iy
+                            // ix = dslice_ix * dtile_x + dtile_ix
+                            index_t cur_wei_iy = cur_wei_dslice_iy * dtile_y + dtile_iy;
+                            index_t cur_wei_ix = cur_wei_dslice_ix * dtile_x + dtile_ix;
+
+                            auto cur_wei_idx = {cur_group, cur_wei_ik, cur_wei_iy, cur_wei_ix, cur_wei_ic};
+                            bool cur_wei_valid = tensor_wei.range_check(cur_wei_idx);
+
+                            index_t cur_wei_offset = tensor_wei.offset(cur_wei_idx) * data_byte;
+                            b_req.req.emplace_back(req_t({tid, data_byte, tb_vector_c, cur_wei_offset, cur_wei_valid}));
+                        }
+                        wei_block_req.push_back(b_req);
+                    }
+                }
+            }
+        };
+        auto cur_block_inp = [&](){
+            if(cur_gemm_k == 0){
+                for(index_t t_inb = 0; t_inb < tc_nb_per_thread; t_inb++){
+                    //index_t i_b_req = inp_block_req_desc.offset({cur_gemm_id, cur_gemm_m / gemm_m_per_block, cur_gemm_n / gemm_n_per_block, t_inb});
+                    //block_req_t & b_req = inp_block_req[i_b_req];
+                    block_req_t b_req;
+                    b_req.block_size = block_size;
+                    b_req.bid.push_back(bid);
+                    b_req.req_idx = tc_block_req_idx[bid];
+                    tc_block_req_idx[bid]++;
+
+                    if(cur_gks == 0){
+                        for(index_t tid = 0; tid < block_size; tid++){
+                            index_t in_inb, in_ic;
+                            in_ic = (tid % cc_c) * tc_vector_c;
+                            in_inb = tid / cc_c;
+
+                            index_t cur_in_ic = cur_gemm_n + in_ic;
+                            index_t cur_in_inb = cur_gemm_m + in_inb + t_inb * tc_nb_thread_stride;
+
+                            auto in_gemm_m_trans = gemm_m_transform.get(cur_in_inb);
+
+                            // ihi = (in_dslice_ih + dslice_h_left) * stride_h + dtile_iy * dilation_h - pad_h
+                            // iwi = (in_dslice_iw + dslice_w_left) * stride_w + dtile_ix * dilation_w - pad_w
+                            index_t cur_in_in = in_gemm_m_trans[0];
+                            index_t cur_in_dslice_ih = in_gemm_m_trans[1];
+                            index_t cur_in_dslice_iw = in_gemm_m_trans[2];
+
+                            index_t cur_in_ihi = (cur_in_dslice_ih + dslice_h_left) * stride_h + dtile_iy * dilation_h - pad_h;
+                            index_t cur_in_iwi = (cur_in_dslice_iw + dslice_w_left) * stride_w + dtile_ix * dilation_w - pad_w;
+
+                            auto cur_in_idx = {cur_in_in, cur_in_ihi, cur_in_iwi, cur_group, cur_in_ic};
+                            auto cur_in_valid = tensor_inp.range_check(cur_in_idx);
+
+                            index_t cur_in_offset = tensor_inp.offset(cur_in_idx) * data_byte;
+                            b_req.req.emplace_back(req_t({tid, data_byte, tc_vector_c, cur_in_offset, cur_in_valid}));
+                        }
+                        inp_block_req.push_back(b_req);
+                    }
+                }
+            }
+        };
+
+        cur_block_out();
+        cur_block_wei();
+        cur_block_inp();
+    };
+
+    for(index_t bid = 0; bid < grid_size; bid++){
+        auto cur_block_position = block_mapping.get(bid);   // position of this block in ndim space
+        auto cur_gks = cur_block_position[0];
+        auto cur_gemm_id = cur_block_position[1];
+        auto cur_group = cur_block_position[2];
+        auto cur_gemm_m = cur_block_position[3] * gemm_m_per_block;
+        auto cur_gemm_n = cur_block_position[4] * gemm_n_per_block;
+
+        index_t i_y_tilda = cur_gemm_id / x_tilda;
+        index_t i_x_tilda = cur_gemm_id % x_tilda;
+        index_t y_dot_slice = utility_integer_divide_ceil(y - i_y_tilda, y_tilda);
+        index_t x_dot_slice = utility_integer_divide_ceil(x - i_x_tilda, x_tilda);
+
+        index_t gemm_k = (k / group) * y_dot_slice * x_dot_slice / num_global_splits;
+
+        bool is_gemm_not_empty = (i_y_tilda < y) && (i_x_tilda < x);
+
+        if(!is_gemm_not_empty)
+            continue;
+
+        for(index_t cur_gemm_k = 0; cur_gemm_k < gemm_k; cur_gemm_k += gemm_k_per_block){
+            cur_block(bid, cur_gks, cur_gemm_id, cur_group, cur_gemm_m, cur_gemm_n, cur_gemm_k);
+        }
+    }
+
+    gmap_serialize_and_valid(conv_args, tunable->tensor_layout, inp_block_req, wei_block_req, out_block_req,
+        record_inp, record_wei, record_out, tensor_inp, tensor_wei, tensor_out, fp_inp, fp_wei, fp_out);
+}
+
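Editor's note: grid_size above enumerates blocks over {num_global_splits, num_of_gemm, group, gemm_m tiles, gemm_n tiles}, and block_mapping.get(bid) decodes a flat block id back into that 5-D position. A sketch of the assumed row-major decode (linear_tensor_t's real implementation lives elsewhere in the driver):

#include <cstddef>
#include <vector>
typedef size_t index_t;

// decode a flat id into an n-D position, last dimension varying fastest
std::vector<index_t> decode(index_t bid, const std::vector<index_t>& dims){
    std::vector<index_t> pos(dims.size());
    for(size_t i = dims.size(); i-- > 0;){
        pos[i] = bid % dims[i];
        bid /= dims[i];
    }
    return pos;
}
// e.g. dims = {num_global_splits, num_of_gemm, group,
//              gemm_m/gemm_m_per_block, gemm_n/gemm_n_per_block};
// pos[4] (the gemm_n tile) advances fastest as bid increases.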
 void gmap_dump_fwd_nhwc(const args_t *conv_args, const igemm_gtc_tunable_t * tunable, int gks, FILE *fp_inp, FILE *fp_wei, FILE *fp_out)
 {
     index_t hi = conv_args->get_int("in_h");
@@ -287,17 +683,17 @@ void gmap_dump_fwd_nhwc(const args_t *conv_args, const igemm_gtc_tunable_t * tun
     index_t block_size = ca_e * ca_c * ca_nb0 * ca_nb1;
     assert(block_size == (cb_e * cb_c * cb_k0 * cb_k1));
     assert((gemm_m % gemm_m_per_block == 0) && (gemm_n % gemm_n_per_block == 0));
-    index_t grid_size = group * num_global_splits * (gemm_m / gemm_m_per_block) * (gemm_n / gemm_n_per_block);
-    linear_tensor_t block_mapping({group, num_global_splits, (gemm_m / gemm_m_per_block), (gemm_n / gemm_n_per_block)});
+    index_t grid_size = num_global_splits * group * (gemm_m / gemm_m_per_block) * (gemm_n / gemm_n_per_block);
+    linear_tensor_t block_mapping({num_global_splits, group, (gemm_m / gemm_m_per_block), (gemm_n / gemm_n_per_block)});
     linear_tensor_t gemm_m_transform({n, ho, wo});
     linear_tensor_t gemm_k_transform({y, x, c / group});
 
     linear_tensor_t tensor_inp({n, hi, wi, group, c/group});
     linear_tensor_t tensor_wei({group, k/group, y, x, c/group});
     linear_tensor_t tensor_out({n, ho, wo, group, k/group});
-    std::vector<bool> record_inp(n*hi*wi*(c/group), false);
+    std::vector<bool> record_inp(n*hi*wi*group*(c/group), false);
     std::vector<bool> record_wei(group*(k/group)*y*x*(c/group), false);
-    std::vector<bool> record_out(n*ho*wo*(k/group), false);
+    std::vector<bool> record_out(n*ho*wo*group*(k/group), false);
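Editor's note: this hunk is the "fix bug for group conv" part of the commit. The record buffers previously dropped the group dimension, under-allocating by a factor of group even though tensor_inp/tensor_out are declared with it. A minimal check, with made-up sizes:

// assuming the NHWGC layout used by tensor_inp above
constexpr int n = 2, hi = 4, wi = 4, group = 2, c = 8;
// old: n*hi*wi*(c/group)       = 2*4*4*4   = 128 entries (short by a factor of group)
// new: n*hi*wi*group*(c/group) = 2*4*4*2*4 = 256 entries
static_assert(n*hi*wi*group*(c/group) == n*hi*wi*c, "record covers every input element");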
 
     index_t ta_nb_per_thread = ta_nb0 != 1 ? ta_nb0 : ta_nb1;
     index_t ta_vector_c = utility_gcd(ta_c, 4 * (4 / data_byte));
@@ -343,23 +739,23 @@ void gmap_dump_fwd_nhwc(const args_t *conv_args, const igemm_gtc_tunable_t * tun
     std::vector<index_t> tc_block_req_idx(grid_size, 0);
 
     std::vector<block_req_t> inp_block_req;
-    linear_tensor_t inp_block_req_desc({num_global_splits, gemm_m / gemm_m_per_block, gemm_k / gemm_k_per_block, ta_nb_per_thread, ta_nc_per_thread});
+    linear_tensor_t inp_block_req_desc({num_global_splits, group, gemm_m / gemm_m_per_block, gemm_k / gemm_k_per_block, ta_nb_per_thread, ta_nc_per_thread});
     inp_block_req.resize(inp_block_req_desc.size());
 
     std::vector<block_req_t> wei_block_req;
-    linear_tensor_t wei_block_req_desc({num_global_splits, gemm_n / gemm_n_per_block, gemm_k / gemm_k_per_block, tb_nk_per_thread, tb_nc_per_thread});
+    linear_tensor_t wei_block_req_desc({num_global_splits, group, gemm_n / gemm_n_per_block, gemm_k / gemm_k_per_block, tb_nk_per_thread, tb_nc_per_thread});
     wei_block_req.resize(wei_block_req_desc.size());
 
     std::vector<block_req_t> out_block_req;
-    linear_tensor_t out_block_req_desc({gemm_m / gemm_m_per_block, gemm_n / gemm_n_per_block, tc_nb_per_thread});
+    linear_tensor_t out_block_req_desc({group, gemm_m / gemm_m_per_block, gemm_n / gemm_n_per_block, tc_nb_per_thread});
     out_block_req.resize(out_block_req_desc.size());
 
-    auto cur_block = [&](index_t bid, index_t cur_group, index_t cur_gks, index_t cur_gemm_m, index_t cur_gemm_n, index_t cur_gemm_k){
+    auto cur_block = [&](index_t bid, index_t cur_gks, index_t cur_group, index_t cur_gemm_m, index_t cur_gemm_n, index_t cur_gemm_k){
         // inp
         auto cur_block_inp = [&](){
             for(index_t t_inb = 0; t_inb < ta_nb_per_thread; t_inb++){
                 for(index_t t_ic = 0; t_ic < ta_nc_per_thread; t_ic++){
-                    index_t i_b_req = inp_block_req_desc.offset({cur_gks, cur_gemm_m / gemm_m_per_block, cur_gemm_k / gemm_k_per_block, t_inb, t_ic});
+                    index_t i_b_req = inp_block_req_desc.offset({cur_gks, cur_group, cur_gemm_m / gemm_m_per_block, cur_gemm_k / gemm_k_per_block, t_inb, t_ic});
                     block_req_t & b_req = inp_block_req[i_b_req];
                     b_req.block_size = block_size;
                     b_req.bid.push_back(bid);
@@ -413,7 +809,7 @@ void gmap_dump_fwd_nhwc(const args_t *conv_args, const igemm_gtc_tunable_t * tun
         auto cur_block_wei = [&](){
             for(index_t t_ik = 0; t_ik < tb_nk_per_thread; t_ik++){
                 for(index_t t_ic = 0; t_ic < tb_nc_per_thread; t_ic++){
-                    index_t i_b_req = wei_block_req_desc.offset({cur_gks, cur_gemm_n / gemm_n_per_block, cur_gemm_k / gemm_k_per_block, t_ik, t_ic});
+                    index_t i_b_req = wei_block_req_desc.offset({cur_gks, cur_group, cur_gemm_n / gemm_n_per_block, cur_gemm_k / gemm_k_per_block, t_ik, t_ic});
                     block_req_t & b_req = wei_block_req[i_b_req];
                     b_req.block_size = block_size;
                     b_req.bid.push_back(bid);
@@ -452,7 +848,7 @@ void gmap_dump_fwd_nhwc(const args_t *conv_args, const igemm_gtc_tunable_t * tun
         auto cur_block_out = [&](){
             if(cur_gemm_k == 0){
                 for(index_t t_inb = 0; t_inb < tc_nb_per_thread; t_inb++){
-                    index_t i_b_req = out_block_req_desc.offset({cur_gemm_m / gemm_m_per_block, cur_gemm_n / gemm_n_per_block, t_inb});
+                    index_t i_b_req = out_block_req_desc.offset({cur_group, cur_gemm_m / gemm_m_per_block, cur_gemm_n / gemm_n_per_block, t_inb});
                     block_req_t & b_req = out_block_req[i_b_req];
                     b_req.block_size = block_size;
                     b_req.bid.push_back(bid);
@@ -492,58 +888,17 @@ void gmap_dump_fwd_nhwc(const args_t *conv_args, const igemm_gtc_tunable_t * tun
 
     for(index_t bid = 0; bid < grid_size; bid++){
         auto cur_block_position = block_mapping.get(bid);   // position of this block in ndim space
-        auto cur_group = cur_block_position[0];
-        auto cur_gks = cur_block_position[1];
+        auto cur_gks = cur_block_position[0];
+        auto cur_group = cur_block_position[1];
         auto cur_gemm_m = cur_block_position[2] * gemm_m_per_block;
         auto cur_gemm_n = cur_block_position[3] * gemm_n_per_block;
 
         for(index_t cur_gemm_k = 0; cur_gemm_k < gemm_k; cur_gemm_k += gemm_k_per_block){
-            cur_block(bid, cur_group, cur_gks, cur_gemm_m, cur_gemm_n, cur_gemm_k);
+            cur_block(bid, cur_gks, cur_group, cur_gemm_m, cur_gemm_n, cur_gemm_k);
         }
     }
 
-    // serialize block request
-    for(auto itr_ibr = inp_block_req.begin(); itr_ibr != inp_block_req.end(); itr_ibr++)
-        serialize_block_req(&(*itr_ibr), fp_inp, &record_inp);
-
-    for(auto itr_ibr = wei_block_req.begin(); itr_ibr != wei_block_req.end(); itr_ibr++)
-        serialize_block_req(&(*itr_ibr), fp_wei, &record_wei);
-
-    for(auto itr_ibr = out_block_req.begin(); itr_ibr != out_block_req.end(); itr_ibr++)
-        serialize_block_req(&(*itr_ibr), fp_out, &record_out);
-
-    // valid all record
-    std::vector<bool> valid_hi, valid_wi;
-    std::tie(valid_hi, valid_wi) = gmap_get_input_access_map(conv_args);
-    for(auto it = record_inp.begin(); it != record_inp.end(); it++){
-        index_t idx = std::distance(record_inp.begin(), it);
-        std::vector<index_t> inp_position = tensor_inp.get(idx);
-        index_t ihi = inp_position[1];
-        index_t iwi = inp_position[2];
-        if(valid_hi[ihi] && valid_wi[iwi]){
-            if(!(*it)){
-                printf("WARNING! input not touched pixel at %zu\n", idx);
-            }
-        }
-        else{
-            if(*it){
-                printf("WARNING! input touched unused pixel at %zu\n", idx);
-            }
-        }
-    }
-
-    for(auto it = record_wei.begin(); it != record_wei.end(); it++){
-        index_t idx = std::distance(record_wei.begin(), it);
-        if(!(*it)){
-            printf("WARNING! weight not touched pixel at %zu\n", idx);
-        }
-    }
-
-    for(auto it = record_out.begin(); it != record_out.end(); it++){
-        index_t idx = std::distance(record_out.begin(), it);
-        if(!(*it)){
-            printf("WARNING! output not touched pixel at %zu\n", idx);
-        }
-    }
+    gmap_serialize_and_valid(conv_args, tunable->tensor_layout, inp_block_req, wei_block_req, out_block_req,
+        record_inp, record_wei, record_out, tensor_inp, tensor_wei, tensor_out, fp_inp, fp_wei, fp_out);
 }
 
 void gmap_dump_banner(const args_t *conv_args, const igemm_gtc_tunable_t * tunable, FILE *fp_inp, FILE *fp_wei, FILE *fp_out)
@@ -645,6 +1000,15 @@ void gmap_dump(const args_t *conv_args, const igemm_gtc_tunable_t * tunable, int
             assert(0);
         }
     }
+    else if(direction == "bwd"){
+        if(tensor_layout == "nchw"){
+            // bwd nchw gmap dump not implemented yet
+        }else if(tensor_layout == "nhwc"){
+            gmap_dump_bwd_nhwc(conv_args, tunable, gks, fp_inp, fp_wei, fp_out);
+        }else{
+            assert(0);
+        }
+    }
 
     fclose(fp_inp);
     fclose(fp_wei);
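Editor's note: for reference, the shapes of the request records this file serializes, as implied by their usage in the patch. The real definitions live elsewhere in driver/perf/gmap.cpp; the name of the third req_t field is a guess, the others are taken from the accesses above:

#include <cstddef>
#include <vector>
typedef size_t index_t;

struct req_t {
    index_t tid;        // thread id inside the workgroup
    index_t data_byte;  // bytes per element
    index_t vector;     // elements accessed per request (ta_vector_k / tb_vector_c / tc_vector_c)
    index_t offset;     // byte offset into the tensor
    bool    valid;      // false when the address falls in padding / out of range
};

struct block_req_t {
    index_t block_size;          // threads per workgroup
    std::vector<index_t> bid;    // block ids issuing this request pattern
    index_t req_idx;             // ordinal of this request within the block
    std::vector<req_t> req;      // one entry per thread
};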