Align/add rocm stream headers
acoskunses-AMD committed Sep 19, 2024
2 parents 863f596 + ebbebd4 commit e94d9e9
Showing 63 changed files with 3,541 additions and 1,374 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/fbgemm_gpu_pip.yml
@@ -160,7 +160,7 @@ jobs:
run: . $PRELUDE; install_fbgemm_gpu_pip $BUILD_ENV ${{ github.event.inputs.fbgemm_gpu_channel_version || 'nightly' }} cuda/${{ matrix.cuda-version }}

- name: Test with PyTest
timeout-minutes: 20
timeout-minutes: 40
run: . $PRELUDE; test_all_fbgemm_gpu_modules $BUILD_ENV


3 changes: 2 additions & 1 deletion fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py
@@ -14,7 +14,8 @@
import torch
from torch import Tensor

logging.basicConfig(level=logging.DEBUG)
logger: logging.Logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

try:
# pyre-ignore[21]
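Several benchmark scripts in this commit swap a module-level logging.basicConfig(level=logging.DEBUG) call for an explicit root-logger handle, as in the hunk above. A minimal standalone sketch of the two patterns and their practical difference (standard library only; nothing here is taken from FBGEMM itself):

```python
import logging

# Old pattern (removed in this commit): basicConfig() both attaches a
# StreamHandler to the root logger and sets its level in one call.
# logging.basicConfig(level=logging.DEBUG)

# New pattern: take an explicit handle on the root logger and raise its level.
# Unlike basicConfig(), this does not attach a handler by itself, so DEBUG
# records only appear once a handler is configured elsewhere in the process.
logger: logging.Logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

logger.debug("benchmark logging configured")
```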
3 changes: 2 additions & 1 deletion fbgemm_gpu/bench/jagged_tensor_benchmark.py
@@ -16,7 +16,8 @@
import torch
from torch.profiler import profile

logging.basicConfig(level=logging.DEBUG)
logger: logging.Logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
open_source: bool = getattr(fbgemm_gpu, "open_source", False)
3 changes: 3 additions & 0 deletions fbgemm_gpu/bench/merge_embeddings_benchmark.py
@@ -32,6 +32,9 @@
# pyre-fixme[21]: Could not find name `ProfilerActivity` in `torch.profiler`.
from torch.profiler import profile, ProfilerActivity

logger: logging.Logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
open_source: bool = getattr(fbgemm_gpu, "open_source", False)

4 changes: 2 additions & 2 deletions fbgemm_gpu/bench/quantize_ops_benchmark.py
@@ -22,8 +22,8 @@
# pyre-ignore[21]
from torch.profiler import profile, ProfilerActivity


logging.basicConfig(level=logging.DEBUG)
logger: logging.Logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
open_source: bool = getattr(fbgemm_gpu, "open_source", False)
3 changes: 2 additions & 1 deletion fbgemm_gpu/bench/sparse_ops_benchmark.py
@@ -20,7 +20,8 @@

from torch.profiler import profile

logging.basicConfig(level=logging.DEBUG)
logger: logging.Logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
open_source: bool = getattr(fbgemm_gpu, "open_source", False)
3 changes: 2 additions & 1 deletion fbgemm_gpu/bench/split_embeddings_cache_benchmark.py
@@ -26,7 +26,8 @@

from torch import nn, Tensor

logging.basicConfig(level=logging.DEBUG)
logger: logging.Logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

try:
# pyre-ignore[21]
3 changes: 3 additions & 0 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
@@ -48,6 +48,9 @@
from torch import Tensor
from torch.profiler import profile

logger: logging.Logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

haveAIBench = False
try:
from aibench_observer.utils.observer import emitMetric
5 changes: 2 additions & 3 deletions fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py
@@ -40,14 +40,13 @@
from torch.autograd.profiler import record_function
from torch.profiler import profile

logging.basicConfig(level=logging.DEBUG)
logger: logging.Logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

load_torch_module(
"//deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings",
)

logging.basicConfig(level=logging.DEBUG)


@click.group()
def cli() -> None:
3 changes: 2 additions & 1 deletion fbgemm_gpu/bench/stride_gemm_benchmark.py
@@ -13,7 +13,8 @@
import torch
from fbgemm_gpu.bench.bench_utils import benchmark_torch_function

logging.basicConfig(level=logging.DEBUG)
logger: logging.Logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

try:
# pyre-ignore[21]
50 changes: 25 additions & 25 deletions fbgemm_gpu/codegen/genscript/generate_backward_split.py
@@ -163,38 +163,38 @@ def generate_backward_split_gpu(**kwargs: Any) -> None:
# has_gpu_support=True
if not kwargs.get("dense"):
# Generate CUDA autograd
template_filepath = (
"training/backward/embedding_backward_split_host_template.cpp"
)

for ssd in [True, False] if kwargs.get("has_ssd_support") else [False]:
template_filepath = (
"training/backward/embedding_backward_split_host_template.cpp"
)
desc = "ssd" if ssd else "split"
sdesc = "_ssd" if ssd else ""
filename = f"gen_embedding_backward_{desc}_{optimizer}.cpp"
CodeTemplate.load(template_filepath).write(
filename, is_forward=False, ssd=ssd, **kwargs
)

# Generate PT2 unified autograd, and PT2 backward wrapper
for template_filepath, filename in [
(
"training/pt2/embedding_split_host_pt2_autograd_template.cpp",
f"gen_embedding_split_{optimizer}_pt2_autograd.cpp",
),
(
"training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp",
f"gen_embedding_backward_split_{optimizer}_pt2_cuda_wrapper.cpp",
),
]:
CodeTemplate.load(template_filepath).write(
filename, is_forward=False, **kwargs
)

if kwargs.get("has_cpu_support") or kwargs.get("has_gpu_support"):
# Generates Python invoker for CUDA + CPU, and PT2
template = CodeTemplate.load(
"training/python/split_embedding_codegen_lookup_invoker.template"
)
for ssd in [True, False] if kwargs.get("has_ssd_support") else [False]:
sdesc = "_ssd" if ssd else ""
# Generate PT2 unified autograd, and PT2 backward wrapper for all optimizers
for template_filepath, filename in [
(
"training/pt2/embedding_split_host_pt2_autograd_template.cpp",
f"gen_embedding_{desc}_{optimizer}_pt2_autograd.cpp",
),
(
"training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp",
f"gen_embedding_backward_{desc}_{optimizer}_pt2_cuda_wrapper.cpp",
),
]:
CodeTemplate.load(template_filepath).write(
filename, is_forward=False, ssd=ssd, **kwargs
)

if kwargs.get("has_cpu_support") or kwargs.get("has_gpu_support"):
# Generates Python invoker for CUDA + CPU, and PT2
template = CodeTemplate.load(
"training/python/split_embedding_codegen_lookup_invoker.template"
)
for filename in [
f"lookup_{optimizer}{sdesc}.py",
f"lookup_{optimizer}{sdesc}_pt2.py",
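The restructured loop above now emits the PT2 autograd and backward wrapper files per variant, so optimizers with SSD support get both split and ssd outputs. A small illustrative sketch of the filenames those f-strings produce (the optimizer name below is a placeholder, not taken from the diff):

```python
# Illustrative only: print the backward/PT2 filenames generated for one
# optimizer when has_ssd_support is enabled ("rowwise_adagrad" is a stand-in).
optimizer = "rowwise_adagrad"
for ssd in (True, False):
    desc = "ssd" if ssd else "split"
    print(f"gen_embedding_backward_{desc}_{optimizer}.cpp")
    print(f"gen_embedding_{desc}_{optimizer}_pt2_autograd.cpp")
    print(f"gen_embedding_backward_{desc}_{optimizer}_pt2_cuda_wrapper.cpp")
```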
11 changes: 11 additions & 0 deletions fbgemm_gpu/codegen/genscript/generate_forward_split.py
@@ -85,6 +85,17 @@ def generate_pt2_wrappers() -> None:
is_forward=True,
)

# Generate PT2 forward wrapper (CUDA)
CodeTemplate.load(
"training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp",
).write(
f"gen_embedding_forward_ssd_pt2_cuda_wrapper.cpp",
has_gpu_support=True,
is_forward=True,
has_vbe_support=True,
ssd=True,
)

@staticmethod
def generate_small_kernels() -> None:
# Generate the small kernels (for nobag only) for the forward splits
102 changes: 89 additions & 13 deletions fbgemm_gpu/codegen/genscript/optimizer_args.py
@@ -27,6 +27,23 @@
from torch_type_utils import arg_type_to_tensor_type, ArgType, TensorType


######################################################################
# Optimizer Args Set Item
######################################################################


@dataclass
class OptimizerArgsSetItem:
ty: ArgType # type
name: str
default: Union[float, ArgType] = 0 # DEFAULT_ARG_VAL
ph_tys: Optional[List[ArgType]] = None # placeholder types


# Alias b/c the name is too long
OptimItem = OptimizerArgsSetItem


######################################################################
## Helper functions for the code generator script ##
######################################################################
@@ -117,6 +134,11 @@ def int_tensor_arg(name: str, gpu: bool = True, pass_by_ref: bool = False) -> st
return _arg("int32_t", name, gpu=gpu, pass_by_ref=pass_by_ref)


def tensor_list_arg_no_default(name: str, pass_by_ref: bool) -> str:
ref = "&" if pass_by_ref else ""
return f"at::TensorList{ref} {name}"


def tensor_arg(name: str) -> str:
return f"Tensor {name}"

@@ -165,6 +187,10 @@ def schema_sym_int_arg_no_default(name: str) -> str:
return f"SymInt {name}"


def schema_tensor_list_arg_no_default(name: str) -> str:
return f"Tensor[] {name}"


def make_kernel_arg(
# pyre-fixme[11]: Annotation `ArgType` is not defined as a type.
ty: ArgType,
@@ -282,21 +308,69 @@ def make_ivalue_cast(ty: ArgType) -> str:
}[ty]


######################################################################
# Optimizer Args Set Item
######################################################################


@dataclass
class OptimizerArgsSetItem:
ty: ArgType # type
name: str
default: Union[float, ArgType] = 0 # DEFAULT_ARG_VAL
ph_tys: Optional[List[ArgType]] = None # placeholder types

class PT2ArgsSet:
split_function_args: List[str]
split_function_arg_names: List[str]
split_function_schemas: List[str]
split_saved_tensor_list: List[str]

# Alias b/c the name is too long
OptimItem = OptimizerArgsSetItem
@staticmethod
# pyre-ignore[3]
def create(
split_arg_spec: List[OptimItem],
):
"""
PT2ArgsSet.create() is a method that creates different formats given the optimization arguments
to be used in TBE codegen PT2 templates.
Parameters:
split_arg_spec: List[OptimItem] - list of argument specs
Returns:
PT2ArgsSet object with the following attributes:
split_function_args: List[str] - List of function arguments
e.g., ['at::TensorList momentum1', 'double eps', 'double weight_decay'].
split_function_arg_names: List[str] - List of argument names
e.g., ['momentum1', 'eps', 'weight_decay'].
split_function_schemas: List[str] - List of arguments in the schema format
e.g., ['Tensor[] momentum1', 'float eps', 'float weight_decay'].
split_saved_tensor_list: List[str] - List of saved tensors for the split function
e.g., ['momentum1'].
"""
split_function_arg_names = []
split_function_args = []
split_function_schemas = []
split_saved_tensor_list = []
for s in split_arg_spec:
if s.ty in (
ArgType.TENSOR,
ArgType.INT_TENSOR,
ArgType.LONG_TENSOR,
ArgType.PLACEHOLDER_TENSOR,
):
name = s.name.rsplit("_", 1)[0]
if name not in split_function_arg_names:
split_function_arg_names.append(name)
split_saved_tensor_list.append(name)
split_function_args.append(
tensor_list_arg_no_default(name, pass_by_ref=False)
)
split_function_schemas.append(
schema_tensor_list_arg_no_default(name)
)
else:
split_function_arg_names.append(s.name)
split_function_args.append(make_function_arg(s.ty, s.name, s.default))
split_function_schemas.append(
make_function_schema_arg(s.ty, s.name, s.default)
)
return PT2ArgsSet(
split_function_args=split_function_args,
split_function_arg_names=split_function_arg_names,
split_function_schemas=split_function_schemas,
split_saved_tensor_list=split_saved_tensor_list,
)
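The docstring above describes how create() collapses the per-tensor dev/uvm/placements/offsets arguments into a single TensorList-style entry while scalar arguments pass through. A standalone sketch of just that name-collapsing rule, with illustrative argument names and no repository imports:

```python
# Standalone illustration of the grouping rule in PT2ArgsSet.create(): tensor
# arguments sharing the prefix before the last "_" collapse into one entry.
tensor_args = [
    "momentum1_dev",
    "momentum1_uvm",
    "momentum1_placements",
    "momentum1_offsets",
]
scalar_args = ["eps", "weight_decay"]

arg_names: list[str] = []
for arg in tensor_args:
    name = arg.rsplit("_", 1)[0]  # "momentum1_dev" -> "momentum1"
    if name not in arg_names:
        arg_names.append(name)
arg_names.extend(scalar_args)

print(arg_names)  # ['momentum1', 'eps', 'weight_decay']
```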


######################################################################
@@ -324,6 +398,7 @@ class OptimizerArgs:
placeholder_tensor_names: List[str]
# pyre-fixme[11]: Annotation `TensorType` is not defined as a type.
placeholder_type_combos: Union[List[Dict[str, TensorType]], List[None]]
unified_pt2: PT2ArgsSet

@staticmethod
# pyre-ignore[3]
@@ -419,6 +494,7 @@ def create(
],
placeholder_tensor_names=ph_tensor_names,
placeholder_type_combos=ph_combos,
unified_pt2=PT2ArgsSet.create(split_arg_spec),
)


@@ -103,6 +103,9 @@ Tensor int_nbit_split_embedding_codegen_lookup_function_cpu(
std::optional<int64_t> max_float8_D,
std::optional<int64_t> fp8_exponent_bits,
std::optional<int64_t> fp8_exponent_bias) {
if (offsets.scalar_type() != indices.scalar_type()) {
offsets = offsets.toType(indices.scalar_type());
}
if (static_cast<PoolingMode>(pooling_mode) == PoolingMode::NONE) {
std::vector<int64_t> max_D_list{
max_int2_D,
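The added guard above coerces offsets to the scalar type of indices before the CPU lookup proceeds. A hedged PyTorch-side sketch of the same alignment (illustrative only; in the commit the cast happens inside the C++ operator):

```python
import torch

# Illustrative equivalent of the added C++ check: if offsets and indices
# disagree on integer dtype (e.g. int32 vs int64), cast offsets to match.
indices = torch.tensor([1, 5, 7], dtype=torch.int64)
offsets = torch.tensor([0, 2, 3], dtype=torch.int32)

if offsets.dtype != indices.dtype:
    offsets = offsets.to(indices.dtype)

assert offsets.dtype == indices.dtype
```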
@@ -62,7 +62,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no
{%- set func_name = "nbit::" + emb_weight_type + "_split_embedding" + ("_nobag" if nobag else "") + "_codegen_forward_" + wdesc + "_kernel_small_L" %}

#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name_{{ emb_weight_type }} = func_name_{{ emb_weight_type }};
const auto func_name_{{ emb_weight_type }} = "{{ func_name }}_{{ emb_weight_type }}";
#endif

#ifdef X