Add template info into generated files #3325

Open
Wants to merge 2 commits into base: main
23 changes: 21 additions & 2 deletions fbgemm_gpu/codegen/genscript/common.py
@@ -27,16 +27,35 @@

@dataclass
class CodeTemplate:
relative_path: str
template: jinja2.Template

@staticmethod
# pyre-ignore[3]
def load(relative_path: str):
return CodeTemplate(env.get_template(relative_path))
return CodeTemplate(relative_path, env.get_template(relative_path))

def write(self, filename: str, **kwargs: Any) -> None:
# Render the generated file header
comment = (
"##"
if (
self.relative_path.endswith(".py")
or self.relative_path.endswith(".template")
)
else "//"
)
generated_file_header = (
f"{comment * 40}\n"
f"{comment} GENERATED FILE INFO\n"
f"{comment}\n"
f"{comment} Template Source: {self.relative_path}\n"
f"{comment * 40}\n"
"\n"
)

# Render the template
output = self.template.render(**kwargs)
output = generated_file_header + self.template.render(**kwargs)

# All generated files are written to the specified install directory.
with open(os.path.join(args.install_dir, filename), "w") as f:
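For illustration, a generated C++/CUDA file would now start with a banner like the one below (Python and `.template` outputs use `##` instead of `//`). This is only a sketch of the rendered header; the template path shown is a made-up example, not one emitted by this PR:

```cpp
////////////////////////////////////////////////////////////////////////////////
// GENERATED FILE INFO
//
// Template Source: training/forward/example_forward_template.cpp
////////////////////////////////////////////////////////////////////////////////
```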
63 changes: 41 additions & 22 deletions fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp
@@ -16,7 +16,6 @@
#include "fbgemm_gpu/utils/ops_utils.h"
#ifdef FBCODE_CAFFE2
#include <libdivide.h>
#include "folly/container/F14Map.h"
#else
#include <omp.h>
#endif
@@ -29,7 +28,12 @@
using Tensor = at::Tensor;
using namespace fbgemm_gpu;

template <typename weights_t, typename ind_weights_t, typename output_t>
template <
typename weights_t,
typename ind_weights_t,
typename index_t,
typename offset_t,
typename output_t>
void split_embedding_forward_cpu_kernel(
Tensor weights,
Tensor weights_offsets,
@@ -56,8 +60,8 @@ void split_embedding_forward_cpu_kernel(

const auto D_offsets_data = D_offsets.accessor<int, 1>();
const auto weights_offsets_data = weights_offsets.accessor<int64_t, 1>();
const auto indices_data = indices.data_ptr<int64_t>();
const auto offsets_data = offsets.data_ptr<int64_t>();
const auto indices_data = indices.data_ptr<index_t>();
const auto offsets_data = offsets.data_ptr<offset_t>();
const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();

const auto weights_data = weights.data_ptr<weights_t>();
@@ -97,8 +101,8 @@
weights_t>::type;
auto kernel = fbgemm::GenerateEmbeddingSpMDMWithStrides<
fbgemm_weight_t,
/*IndexType=*/int64_t,
/*OffsetType=*/int64_t>(
/*IndexType=*/index_t,
/*OffsetType=*/offset_t>(
D,
indice_weights.defined(),
static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN,
@@ -203,29 +207,44 @@ Tensor split_embedding_codegen_forward_cpu(
// It is assumed that the indice_weights will always be float
TORCH_CHECK(
!indice_weights.defined() || indice_weights.scalar_type() != at::kHalf);

FBGEMM_DISPATCH_FLOAT_AND_HALF(
output.scalar_type(), "split_embedding_cpu_forward", [&]() {
output.scalar_type(), "split_embedding_cpu_forward_1", [&]() {
using output_t = scalar_t;

FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
weights.scalar_type(), "split_embedding_cpu_forward", [&] {
weights.scalar_type(), "split_embedding_cpu_forward_2", [&] {
using ind_weights_t = std::conditional<
std::is_same<scalar_t, double>::value,
double,
float>::type;
split_embedding_forward_cpu_kernel<
scalar_t,
ind_weights_t,
output_t>(
weights,
weights_offsets,
D_offsets,
total_D,
hash_size_cumsum,
indices,
offsets,
pooling_mode,
indice_weights,
output);

AT_DISPATCH_INDEX_TYPES(
offsets.scalar_type(), "split_embedding_cpu_forward_3", [&] {
using offset_t = index_t;

AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(),
"split_embedding_cpu_forward_4",
[&] {
split_embedding_forward_cpu_kernel<
scalar_t,
ind_weights_t,
index_t,
offset_t,
output_t>(
weights,
weights_offsets,
D_offsets,
total_D,
hash_size_cumsum,
indices,
offsets,
pooling_mode,
indice_weights,
output);
});
});
});
});
return output;
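As a reading aid for the dispatch above, here is a self-contained sketch (not the actual FBGEMM code; `example_kernel` and `dispatch_example` are placeholder names) of the nested `AT_DISPATCH_INDEX_TYPES` pattern: the outer dispatch binds `index_t` to the offsets dtype, the alias captures it as `offset_t`, and the inner dispatch rebinds `index_t` to the indices dtype, so `int32_t` and `int64_t` indices and offsets can be mixed freely:

```cpp
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Placeholder kernel that accepts independently typed indices and offsets.
template <typename index_t, typename offset_t>
void example_kernel(const index_t* indices, const offset_t* offsets, int64_t num_indices) {
  // ... embedding-lookup work elided ...
}

void dispatch_example(const at::Tensor& indices, const at::Tensor& offsets) {
  AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "dispatch_example_offsets", [&] {
    // AT_DISPATCH_INDEX_TYPES always names its bound type index_t, so alias the
    // offsets dtype to offset_t before the inner dispatch shadows the name.
    using offset_t = index_t;
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "dispatch_example_indices", [&] {
      example_kernel<index_t, offset_t>(
          indices.data_ptr<index_t>(),
          offsets.data_ptr<offset_t>(),
          indices.numel());
    });
  });
}
```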
@@ -95,14 +95,14 @@ batch_index_select_dim0_codegen_forward_small_kernel(
indices_start = total_L_start + L_start;
L = (total_L - L_start >= fixed_L_per_warp) ? fixed_L_per_warp : (total_L - L_start);
{%- else %}
index_t indices_start = offsets[b_t];
int32_t L = offsets[b_t + 1] - indices_start;
const auto indices_start = offsets[b_t];
const auto L = offsets[b_t + 1] - indices_start;
{%- endif %}

{%- if is_index_select %}
const int32_t D_start = D_offsets[t];
const int32_t D_end = D_offsets[t + 1];
const int32_t D = D_end - D_start;
const auto D_start = D_offsets[t];
const auto D_end = D_offsets[t + 1];
const auto D = D_end - D_start;

// Check D in the kernel to avoid iterating through the list on host
CUDA_KERNEL_ASSERT(D % 4 == 0 && "The column size must be multiple of 4");
@@ -221,7 +221,7 @@ batch_index_select_dim0_codegen_forward_small_kernel(
{%- for emb_type in ['float', 'at::Half'] %}
{%- for cache_type in ['float', 'at::Half'] %}
{%- for kEmbeddingSize in [4, 8, 16, 32] %}
{%- set index_type = 'int64_t' %}
{%- for index_type in ['int32_t', 'int64_t'] %}

template __launch_bounds__(kForwardMaxThreads) __global__ void
{%- if is_index_select %}
@@ -268,3 +268,4 @@ batch_index_select_dim0_codegen_forward_small_kernel
{%- endfor %}
{%- endfor %}
{%- endfor %}
{%- endfor %}
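To show what the new `index_type` loop buys, here is a simplified, hypothetical stand-in for the explicit instantiations (plain host C++ rather than the real `__global__` kernel, with `cache_t` and other parameters dropped): each existing combination is now emitted for both 32-bit and 64-bit indices instead of `int64_t` only.

```cpp
// Hypothetical, simplified stand-in for the templated forward kernel.
template <typename emb_t, typename output_t, typename index_t, int kEmbeddingSize>
void example_forward_small_kernel(const emb_t* dev_weights,
                                  const index_t* indices,
                                  const index_t* offsets,
                                  output_t* output) {
  // ... embedding lookup elided ...
}

// Previously each combination was instantiated only for int64_t indices;
// the added Jinja loop also emits an int32_t variant.
template void example_forward_small_kernel<float, float, int32_t, 4>(
    const float*, const int32_t*, const int32_t*, float*);
template void example_forward_small_kernel<float, float, int64_t, 4>(
    const float*, const int64_t*, const int64_t*, float*);
```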
@@ -560,6 +560,7 @@ batch_index_select_dim0_codegen_forward_kernel(
emb_type,
cache_type,
output_type,
index_type,
use_cache,
kMaxVecsPerThread,
kThreadGroupSize)
@@ -577,7 +578,7 @@ batch_index_select_dim0_codegen_forward_kernel
{%- if not dense %}
{{ use_cache }},
{%- endif %}
int64_t,
{{ index_type }},
{%- if not nobag %}
{{ kMaxVecsPerThread }},
{%- endif %}
@@ -603,9 +604,9 @@ batch_index_select_dim0_codegen_forward_kernel
{%- else %}
FixedDivisor fd_B,
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> indices,
const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> indices,
{%- if not is_index_select %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> offsets,
{%- endif %}
{%- if not nobag %}
int64_t pooling_mode,
@@ -638,17 +639,20 @@ batch_index_select_dim0_codegen_forward_kernel
{%- for emb_type in ['float', 'at::Half'] %}
{%- for cache_type in ['float', 'at::Half'] %}
{%- for output_type in ['float', 'at::Half', 'at::BFloat16'] %}
{%- for index_type in ['int32_t', 'int64_t'] %}
{{ template_instantiation(
emb_type,
cache_type,
output_type,
index_type,
use_cache,
kMaxVecsPerThread,
kThreadGroupSize)
}}
{%- endfor %}
{%- endfor %}
{%- endfor %}
{%- endfor %}
{%- endmacro %}

{%- macro instantiate_templates(use_subwarp_shuffle) %}
@@ -986,6 +986,7 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel(
*/

{%- for output_type in ['float', 'at::Half', 'at::BFloat16'] %}
{%- for index_type in ['int32_t', 'int64_t'] %}
{%- for emb_type in ['float', 'at::Half'] %}
{%- for cache_type in ['float', 'at::Half'] %}
{%- for use_cache in ['true', 'false'] %}
@@ -996,7 +997,7 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel
{{ emb_type }},
{{ cache_type }},
{{ output_type }},
int64_t, // index_t
{{ index_type }},
{{ use_cache }}
> (
const {{ emb_type }}* __restrict__ const dev_weights,
@@ -1008,11 +1009,11 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel
const bool mean_pooling,
const uint32_t max_D_cache,
const FixedDivisor fd_num_warps_per_table,
const int64_t* __restrict__ const indices,
const {{ index_type }}* __restrict__ const indices,
{%- if weighted %}
const float* __restrict__ const index_weights,
{%- endif %}
const int64_t* __restrict__ const offsets,
const {{ index_type }}* __restrict__ const offsets,
const uint32_t* __restrict__ const D_offsets,
const int64_t* __restrict__ const weights_offsets,
const int32_t* __restrict__ const lxu_cache_locations,
Expand All @@ -1022,3 +1023,4 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel
{%- endfor %}
{%- endfor %}
{%- endfor %}
{%- endfor %}
@@ -549,6 +549,8 @@ batch_index_select_dim0_codegen_forward_cuda(
return output;
}


AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "batched_embedding{{ ndesc }}_forward_kernel_1", [&] {
DISPATCH_EMB_CACHE_OUTPUT_TYPES(
dev_weights.scalar_type(),
{%- if not dense %}
@@ -590,7 +592,7 @@ batch_index_select_dim0_codegen_forward_cuda(
emb_t,
cache_t,
output_t,
int64_t,
index_t,
kEmbeddingSize / 4>
<<<
div_round_up(total_B, kForwardMaxThreads / kWarpSize),
@@ -611,9 +613,9 @@
D,
{%- endif %}
FixedDivisor(B),
MAKE_PTA_WITH_NAME(func_name, indices, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
{%- if not is_index_select %}
MAKE_PTA_WITH_NAME(func_name, offsets, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
{%- endif %}
{%- if not dense %}
MAKE_PTA_WITH_NAME(func_name, {{ locs_or_addrs_tensor }}, {{ locs_or_addrs_type }}, 1, 32),
@@ -644,9 +646,9 @@ batch_index_select_dim0_codegen_forward_cuda(

{{ nobag_kernel }}
{%- if dense or is_index_select %}
<emb_t, cache_t, output_t, int64_t>
<emb_t, cache_t, output_t, index_t>
{%- else %}
<emb_t, cache_t, output_t, use_cache_t, int64_t>
<emb_t, cache_t, output_t, use_cache_t, index_t>
{%- endif %}
<<<
div_round_up(total_B, kForwardMaxThreads / kWarpSize),
@@ -667,9 +669,9 @@
D,
{%- endif %}
FixedDivisor(B),
MAKE_PTA_WITH_NAME(func_name, indices, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
{%- if not is_index_select %}
MAKE_PTA_WITH_NAME(func_name, offsets, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
{%- endif %}
{%- if not dense %}
MAKE_PTA_WITH_NAME(func_name, {{ locs_or_addrs_tensor }}, {{ locs_or_addrs_type }}, 1, 32),
@@ -717,7 +719,7 @@ batch_index_select_dim0_codegen_forward_cuda(
{%- if not dense%}
use_cache_t,
{%- endif %}
int64_t,
index_t,
kMaxVecsPerThread,
kThreadGroupSize>
<<<
@@ -742,8 +744,8 @@
{%- else %}
FixedDivisor(B),
{%- endif %}
MAKE_PTA_WITH_NAME(func_name, indices, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, offsets, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
pooling_mode,
{%- if weighted %}
MAKE_PTA_ACC_WITH_NAME(func_name, indice_weights, cache_t, 1, 32),
@@ -784,9 +786,9 @@ batch_index_select_dim0_codegen_forward_cuda(

const auto kernel_func =
(use_lxu_cache ? split_embedding_codegen_forward_{{ wdesc }}_v2_kernel<
emb_t, cache_t, output_t, int64_t, true>
emb_t, cache_t, output_t, index_t, true>
: split_embedding_codegen_forward_{{ wdesc }}_v2_kernel<
emb_t, cache_t, output_t, int64_t, false>);
emb_t, cache_t, output_t, index_t, false>);

kernel_func
<<<
@@ -804,12 +806,12 @@
static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN,
use_lxu_cache ? lxu_cache_weights.size(1) : 0,
FixedDivisor(num_warps_per_table),
indices.data_ptr<int64_t>(),
indices.data_ptr<index_t>(),
{%- if weighted %}
// TODO: update indice_weights type
indice_weights.data_ptr<float>(),
{%- endif %}
offsets.data_ptr<int64_t>(),
offsets.data_ptr<index_t>(),
reinterpret_cast<uint32_t*>(D_offsets.data_ptr<int32_t>()),
weights_offsets.data_ptr<int64_t>(),
lxu_cache_locations.data_ptr<int32_t>(),
@@ -819,7 +821,7 @@
}
{%- endif %} // if has_experimental
});

});
return output;
}

4 changes: 2 additions & 2 deletions fbgemm_gpu/include/fbgemm_gpu/utils/cpu_utils.h
@@ -22,13 +22,13 @@ namespace fbgemm_gpu {
* scale_bias_last == false that can take -1 indices (output from
* pruned embedding id mapping)
*/
template <typename IndexType>
template <typename IndexType, typename OffsetType>
void report_embedding_error(
int t,
int B,
int b_begin,
int b_end,
const IndexType* offsets_data,
const OffsetType* offsets_data,
const IndexType* indices_data,
int64_t hash_size,
bool allow_minus_one = false) {
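A standalone sketch (simplified, not FBGEMM's actual signature or call site; `check_bounds` is a placeholder name) of why the extra `OffsetType` parameter is needed: once offsets and indices may carry different integer widths, a single `IndexType` template parameter can no longer be deduced from a mixed-width call.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Simplified stand-in for report_embedding_error, with independent index/offset types.
template <typename IndexType, typename OffsetType>
void check_bounds(const OffsetType* offsets, const IndexType* indices,
                  int64_t num_bags, int64_t hash_size) {
  for (int64_t b = 0; b < num_bags; ++b) {
    for (OffsetType i = offsets[b]; i < offsets[b + 1]; ++i) {
      if (indices[i] < 0 || indices[i] >= hash_size) {
        std::printf("bag %lld: index %lld out of range\n",
                    static_cast<long long>(b), static_cast<long long>(indices[i]));
      }
    }
  }
}

int main() {
  std::vector<int32_t> offsets{0, 2, 3};   // 32-bit offsets
  std::vector<int64_t> indices{5, 7, 11};  // 64-bit indices
  // With a single IndexType parameter this mixed-width call would not compile;
  // separate IndexType/OffsetType parameters make it deducible.
  check_bounds(offsets.data(), indices.data(), /*num_bags=*/2, /*hash_size=*/100);
  return 0;
}
```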