Get some fbgemm::split_embedding_codegen_lookup_{} functions pt2_compliant (pytorch#2198)

Summary:
Pull Request resolved: pytorch#2198

In particular, `split_embedding_codegen_lookup_rowwise_adagrad_function`; a number of similar ops can also be marked pt2_compliant by fixing some templates and bugs.
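For context, "pt2 compliant" here means the op passes PyTorch's standard custom-op checks (schema validity, FakeTensor/meta coverage, autograd registration, AOT dispatch). A minimal sketch of that verification on a hypothetical toy op (not one of the fbgemm ops touched here), assuming a PyTorch recent enough to ship torch.library.opcheck:

import torch

# Hypothetical namespace and op, for illustration only.
lib = torch.library.Library("toylib", "DEF")
lib.define("scale(Tensor x, float alpha) -> Tensor")

def scale_cpu(x, alpha):
    return x * alpha

def scale_meta(x, alpha):
    # Shape/dtype-only kernel: never reads data, so FakeTensor tracing works.
    return torch.empty_like(x)

lib.impl("scale", scale_cpu, "CPU")
lib.impl("scale", scale_meta, "Meta")

# Runs the schema/FakeTensor/dispatch checks that back the
# failures_dict entries deleted at the bottom of this commit.
torch.library.opcheck(torch.ops.toylib.scale.default, (torch.randn(4), 2.0))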

Reviewed By: zou3519

Differential Revision: D51960321

fbshipit-source-id: 0d15eedd0aa4fb78d8d7ecc88fc170b648a26a13
williamwen42 authored and facebook-github-bot committed Dec 11, 2023
1 parent f0ade47 commit 8f20c5d
Showing 5 changed files with 83 additions and 105 deletions.
42 changes: 42 additions & 0 deletions fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp
@@ -234,6 +234,42 @@ Tensor split_embedding_codegen_forward_cpu(
  return output;
}

+Tensor split_embedding_codegen_forward_cpu_meta(
+    Tensor weights,
+    Tensor weights_offsets,
+    Tensor D_offsets,
+    int64_t total_D,
+    Tensor hash_size_cumsum,
+    Tensor indices,
+    Tensor offsets,
+    int64_t pooling_mode,
+    Tensor indice_weights,
+    int64_t output_dtype) {
+  c10::SymInt T = D_offsets.sym_numel() - 1;
+  TORCH_CHECK_GT(T, 0);
+  // offsets = [T x B + 1]
+  c10::SymInt B = (offsets.sym_size(0) - 1) / T;
+  TORCH_CHECK_GE(B, 0);
+
+  Tensor output;
+  if (output_dtype == static_cast<int64_t>(SparseType::FP32)) {
+    output =
+        at::empty_symint({B, total_D}, weights.options().dtype(at::kFloat));
+  } else if (output_dtype == static_cast<int64_t>(SparseType::FP16)) {
+    output = at::empty_symint({B, total_D}, weights.options().dtype(at::kHalf));
+  } else if (output_dtype == static_cast<int64_t>(SparseType::BF16)) {
+    output =
+        at::empty_symint({B, total_D}, weights.options().dtype(at::kBFloat16));
+  } else {
+    output = at::empty_symint({B, total_D}, weights.options());
+  }
+
+  // It is assumed that the indice_weights will always be float
+  TORCH_CHECK(
+      !indice_weights.defined() || indice_weights.scalar_type() != at::kHalf);
+  return output;
+}
+
template <typename weights_t, typename grad_t>
void split_embedding_grad_indice_weights_cpu_kernel(
Tensor grad_output,
@@ -632,4 +668,10 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
      split_embedding_codegen_forward_cpu);
}

+TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
+  m.impl(
+      "split_embedding_codegen_forward_cpu",
+      &split_embedding_codegen_forward_cpu_meta);
+}
+
} // namespace
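The meta kernel above is pure shape arithmetic, which is all FakeTensor tracing needs. The same computation in plain Python, with illustrative sizes, assuming (as in fbgemm) that D_offsets is the cumulative sum of table dimensions so its last entry equals total_D:

# Illustrative sizes: T = 2 tables with dims 8 and 24, batch size B = 4.
D_offsets = [0, 8, 32]        # T + 1 entries
offsets_len = 2 * 4 + 1       # offsets = [T x B + 1]

T = len(D_offsets) - 1        # 2
B = (offsets_len - 1) // T    # 4
total_D = D_offsets[-1]       # 32
print((B, total_D))           # output shape -> (4, 32)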
4 changes: 4 additions & 0 deletions fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp
@@ -179,6 +179,10 @@ Tensor
    return output;
  }

+  {%- if not nobag and vbe %}
+  output = output.reshape({-1});
+  {%- endif %}
+
  return output;
}
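The VBE (variable batch size embeddings) branch flattens the output, presumably because per-feature batch sizes differ so no rectangular [B, total_D] view exists, and the meta output must match the eager kernel's layout. In tensor terms (sizes illustrative):

import torch

out = torch.empty(2, 16)      # rectangular (non-VBE) output
vbe_out = out.reshape(-1)     # VBE layout: one flat 1-D buffer
assert vbe_out.shape == (32,)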

2 changes: 2 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/dispatch_macros.h
@@ -203,8 +203,10 @@
    TYPE, NAME, FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16_CASE(__VA_ARGS__))

// We can cleanup the following once fbgemm uses PyTorch 2.2 in January 2024.
+#ifndef PT2_COMPLIANT_TAG
#ifdef HAS_PT2_COMPLIANT_TAG
#define PT2_COMPLIANT_TAG at::Tag::pt2_compliant_tag
#else
#define PT2_COMPLIANT_TAG
#endif
+#endif
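The two added lines are the outer #ifndef / #endif pair, which guards against the macro being defined more than once (for example by the build system or another header). The pre-existing inner #ifdef picks the real at::Tag::pt2_compliant_tag when the PyTorch being built against provides it, and otherwise expands the macro to nothing, so (presumably) a tag argument written as {PT2_COMPLIANT_TAG} degrades to an empty tag list on older PyTorch instead of failing to compile.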
30 changes: 29 additions & 1 deletion fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp
@@ -12,6 +12,30 @@
#include <ATen/ATen.h>
#include <torch/library.h>

+using Tensor = at::Tensor;
+using namespace fbgemm_gpu;
+
+namespace {
+
+std::tuple<Tensor /*row_output_offsets*/, Tensor /*b_t_map*/>
+generate_vbe_metadata_meta(
+    const Tensor& B_offsets,
+    const Tensor& B_offsets_rank_per_feature,
+    const Tensor& output_offsets_feature_rank,
+    const Tensor& D_offsets,
+    const int64_t D,
+    const bool nobag,
+    const int64_t max_B_feature_rank,
+    const int64_t info_B_num_bits,
+    const c10::SymInt total_B) {
+  Tensor row_output_offsets =
+      at::empty_symint({total_B}, output_offsets_feature_rank.options());
+  Tensor b_t_map = at::empty_symint({total_B}, B_offsets.options());
+  return {row_output_offsets, b_t_map};
+}
+
+} // namespace
+
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"transpose_embedding_input("
@@ -40,9 +64,13 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
      " bool nobag, "
      " int max_B_feature_rank, "
      " int info_B_num_bits, "
-      " int total_B"
+      " SymInt total_B"
      ") -> (Tensor, Tensor)");
  DISPATCH_TO_CUDA("transpose_embedding_input", transpose_embedding_input);
  DISPATCH_TO_CUDA("get_infos_metadata", get_infos_metadata);
  DISPATCH_TO_CUDA("generate_vbe_metadata", generate_vbe_metadata);
}
+
+TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
+  m.impl("generate_vbe_metadata", &generate_vbe_metadata_meta);
+}
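Changing total_B in the schema from int to SymInt lets the size flow through torch.compile's dynamic-shape tracing symbolically, and the new Meta implementation allocates from it without touching data. A rough sketch of the mechanism with a toy function (not the fbgemm op):

import torch

@torch.compile(dynamic=True)
def f(x):
    total_B = x.shape[0]                  # traced as a SymInt
    return torch.empty(total_B, dtype=torch.int64)

f(torch.randn(8))     # compile once...
f(torch.randn(16))    # ...and reuse the graph for a new symbolic size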
110 changes: 6 additions & 104 deletions fbgemm_gpu/test/failures_dict_fast.json
@@ -487,115 +487,17 @@
      }
    },
    "fbgemm::split_embedding_codegen_lookup_adagrad_function": {},
-    "fbgemm::split_embedding_codegen_lookup_adagrad_function_cpu": {
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_adagrad": {
-        "comment": "",
-        "status": "xfail"
-      }
-    },
-    "fbgemm::split_embedding_codegen_lookup_adam_function": {
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_adam": {
-        "comment": "",
-        "status": "xfail"
-      }
-    },
+    "fbgemm::split_embedding_codegen_lookup_adagrad_function_cpu": {},
+    "fbgemm::split_embedding_codegen_lookup_adam_function": {},
    "fbgemm::split_embedding_codegen_lookup_lamb_function": {},
    "fbgemm::split_embedding_codegen_lookup_lars_sgd_function": {},
    "fbgemm::split_embedding_codegen_lookup_none_function": {},
    "fbgemm::split_embedding_codegen_lookup_partial_rowwise_adam_function": {},
    "fbgemm::split_embedding_codegen_lookup_partial_rowwise_lamb_function": {},
-    "fbgemm::split_embedding_codegen_lookup_rowwise_adagrad_function": {
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmMEAN": {
-        "comment": "",
-        "status": "skip"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmNONE": {
-        "comment": "",
-        "status": "skip"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmSUM": {
-        "comment": "",
-        "status": "skip"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmMEAN": {
-        "comment": "",
-        "status": "skip"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmNONE": {
-        "comment": "",
-        "status": "skip"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmSUM": {
-        "comment": "",
-        "status": "skip"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_adagrad": {
-        "comment": "",
-        "status": "skip"
-      }
-    },
-    "fbgemm::split_embedding_codegen_lookup_rowwise_adagrad_function_cpu": {
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmMEAN": {
-        "comment": "",
-        "status": "skip"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmNONE": {
-        "comment": "",
-        "status": "skip"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmSUM": {
-        "comment": "",
-        "status": "skip"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmMEAN": {
-        "comment": "",
-        "status": "skip"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmNONE": {
-        "comment": "",
-        "status": "skip"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmSUM": {
-        "comment": "",
-        "status": "skip"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_adagrad": {
-        "comment": "",
-        "status": "skip"
-      }
-    },
+    "fbgemm::split_embedding_codegen_lookup_rowwise_adagrad_function": {},
+    "fbgemm::split_embedding_codegen_lookup_rowwise_adagrad_function_cpu": {},
    "fbgemm::split_embedding_codegen_lookup_rowwise_weighted_adagrad_function": {},
-    "fbgemm::split_embedding_codegen_lookup_sgd_function": {
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_sgd": {
-        "comment": "",
-        "status": "skip"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_sgd_really_long_segments": {
-        "comment": "",
-        "status": "skip"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_pipeline": {
-        "comment": "",
-        "status": "skip"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_prefetch_pipeline": {
-        "comment": "",
-        "status": "skip"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_prefetch_pipeline_stream_1": {
-        "comment": "",
-        "status": "skip"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_prefetch_pipeline_stream_2": {
-        "comment": "",
-        "status": "skip"
-      }
-    },
-    "fbgemm::split_embedding_codegen_lookup_sgd_function_cpu": {
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_sgd": {
-        "comment": "",
-        "status": "xfail"
-      }
-    }
+    "fbgemm::split_embedding_codegen_lookup_sgd_function": {},
+    "fbgemm::split_embedding_codegen_lookup_sgd_function_cpu": {}
  }
}
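For context: failures_dict_fast.json is consumed by PyTorch's generated opcheck tests (the test_faketensor__* IDs above); "xfail" marks a generated check that is expected to fail, while "skip" marks one that is not run at all. Emptying an op's entry asserts that every generated check now passes for it, which is the concrete meaning of pt2 compliance in this test suite; the 6 added lines replace 104 deleted expected-failure entries.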
