diff --git a/fbgemm_gpu/codegen/genscript/optimizer_args.py b/fbgemm_gpu/codegen/genscript/optimizer_args.py
index f290e0f62..5db6eaf69 100644
--- a/fbgemm_gpu/codegen/genscript/optimizer_args.py
+++ b/fbgemm_gpu/codegen/genscript/optimizer_args.py
@@ -39,6 +39,7 @@ class OptimizerArgsSetItem:
     name: str
     default: Union[float, ArgType] = 0  # DEFAULT_ARG_VAL
     ph_tys: Optional[List[ArgType]] = None  # placeholder types
+    is_optional: bool = False  # optional variable
 
 
 # Alias b/c the name is too long
@@ -192,6 +193,42 @@ def schema_tensor_list_arg_no_default(name: str) -> str:
     return f"Tensor[] {name}"
 
 
+def bool_arg(name: str, default: bool = False) -> str:
+    return f"bool {name} = {'true' if default else 'false'}"
+
+
+def bool_arg_no_default(name: str) -> str:
+    return f"bool {name}"
+
+
+def schema_bool_arg(name: str, default: bool = False) -> str:
+    return f"bool {name} = {default}"
+
+
+def optional_tensor_arg(name: str) -> str:
+    return f"std::optional<Tensor> {name} = std::nullopt"
+
+
+def optional_tensor_arg_no_default(name: str) -> str:
+    return f"std::optional<Tensor> {name}"
+
+
+def schema_optional_tensor_arg(name: str) -> str:
+    return f"Tensor? {name} = None"
+
+
+def optional_tensorlist_arg(name: str) -> str:
+    return f"std::optional<at::TensorList> {name} = std::nullopt"
+
+
+def optional_tensorlist_arg_no_default(name: str) -> str:
+    return f"std::optional<at::TensorList> {name}"
+
+
+def schema_optional_tensorlist_arg(name: str) -> str:
+    return f"Tensor[]? {name} = None"
+
+
 def make_kernel_arg(
     # pyre-fixme[11]: Annotation `ArgType` is not defined as a type.
     ty: ArgType,
@@ -199,9 +236,6 @@ def make_kernel_arg(
     default: Union[int, float, None],
     pass_by_ref: bool = False,
 ) -> str:
-    if name == "learning_rate_tensor":
-        ty = ArgType.FLOAT
-        name = "learning_rate"
     return {
         ArgType.TENSOR: lambda x: acc_cache_tensor_arg(x, pass_by_ref=pass_by_ref),
         ArgType.INT_TENSOR: lambda x: int_tensor_arg(x, pass_by_ref=pass_by_ref),
@@ -224,14 +258,15 @@ def make_kernel_arg(
             if default is not None
             else float_arg_no_default
         ),
+        ArgType.BOOL: (
+            (lambda x: bool_arg(x, default=bool(default)))
+            if default is not None
+            else bool_arg_no_default
+        ),
     }[ty](name)
 
 
 def make_kernel_arg_constructor(ty: ArgType, name: str) -> str:
-    # learning_rate is a float in kernels
-    if name == "learning_rate_tensor":
-        ty = ArgType.FLOAT
-        name = "learning_rate"
     return {
         ArgType.TENSOR: acc_cache_tensor_arg_constructor,
         ArgType.INT_TENSOR: int_tensor_arg_constructor,
@@ -240,14 +275,11 @@ def make_kernel_arg_constructor(ty: ArgType, name: str) -> str:
         ArgType.INT: lambda x: x,
         ArgType.FLOAT: lambda x: x,
         ArgType.SYM_INT: lambda x: x,
+        ArgType.BOOL: lambda x: x,
     }[ty](name)
 
 
 def make_cpu_kernel_arg(ty: ArgType, name: str, default: Union[int, float]) -> str:
-    # learning_rate is a float in kernels
-    if name == "learning_rate_tensor":
-        ty = ArgType.FLOAT
-        name = "learning_rate"
     return {
         ArgType.TENSOR: lambda x: acc_cache_tensor_arg(x, gpu=False),
         ArgType.INT_TENSOR: lambda x: int_tensor_arg(x, gpu=False),
@@ -256,14 +288,11 @@ def make_cpu_kernel_arg(ty: ArgType, name: str, default: Union[int, float]) -> str:
         ArgType.INT: lambda x: int64_arg(x, default=int(default)),
         ArgType.FLOAT: lambda x: float_arg(x, default=default),
         ArgType.SYM_INT: lambda x: sym_int_arg(x, default=int(default)),
+        ArgType.BOOL: lambda x: bool_arg(x, default=bool(default)),
     }[ty](name)
 
 
 def make_cpu_kernel_arg_constructor(ty: ArgType, name: str) -> str:
-    # learning_rate is a float in kernels
-    if name == "learning_rate_tensor":
-        ty = ArgType.FLOAT
-        name = "learning_rate"
     return {
         ArgType.TENSOR: lambda x: acc_cache_tensor_arg_constructor(x, gpu=False),
         ArgType.INT_TENSOR: lambda x: int_tensor_arg_constructor(x, gpu=False),
@@ -274,17 +303,53 @@ def make_cpu_kernel_arg_constructor(ty: ArgType, name: str) -> str:
         ArgType.INT: lambda x: x,
         ArgType.FLOAT: lambda x: x,
         ArgType.SYM_INT: lambda x: x,
+        ArgType.BOOL: lambda x: x,
     }[ty](name)
 
 
 def make_function_arg(
-    ty: ArgType, name: str, default: Optional[Union[int, float]]
+    ty: ArgType,
+    name: str,
+    default: Optional[Union[int, float]],
+    is_optional: bool = False,
 ) -> str:
     return {
-        ArgType.TENSOR: tensor_arg,
-        ArgType.INT_TENSOR: tensor_arg,
-        ArgType.LONG_TENSOR: tensor_arg,
-        ArgType.PLACEHOLDER_TENSOR: tensor_arg,
+        ArgType.TENSOR: (
+            (lambda x: tensor_arg(x))
+            if not is_optional
+            else (
+                optional_tensor_arg
+                if default is not None
+                else optional_tensor_arg_no_default
+            )
+        ),
+        ArgType.INT_TENSOR: (
+            (lambda x: tensor_arg(x))
+            if not is_optional
+            else (
+                optional_tensor_arg
+                if default is not None
+                else optional_tensor_arg_no_default
+            )
+        ),
+        ArgType.LONG_TENSOR: (
+            (lambda x: tensor_arg(x))
+            if not is_optional
+            else (
+                optional_tensor_arg
+                if default is not None
+                else optional_tensor_arg_no_default
+            )
+        ),
+        ArgType.PLACEHOLDER_TENSOR: (
+            (lambda x: tensor_arg(x))
+            if not is_optional
+            else (
+                optional_tensor_arg
+                if default is not None
+                else optional_tensor_arg_no_default
+            )
+        ),
         ArgType.INT: (
             (lambda x: int64_arg(x, default=int(default)))
             if default is not None
@@ -300,6 +365,11 @@ def make_function_arg(
             if default is not None
             else sym_int_arg_no_default
         ),
+        ArgType.BOOL: (
+            (lambda x: bool_arg(x, default=bool(default)))
+            if default is not None
+            else bool_arg_no_default
+        ),
     }[ty](name)
 
 
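Note: the effect of the new `is_optional` plumbing in `make_function_arg` is easiest to see in isolation. A minimal standalone sketch (helper bodies inlined here for illustration; the real helpers are the ones defined above in optimizer_args.py):

```python
# Sketch of the ArgType.TENSOR branch of make_function_arg, inlined so it runs standalone.
def tensor_arg(name: str) -> str:
    return f"Tensor {name}"

def optional_tensor_arg(name: str) -> str:
    return f"std::optional<Tensor> {name} = std::nullopt"

def optional_tensor_arg_no_default(name: str) -> str:
    return f"std::optional<Tensor> {name}"

def make_tensor_function_arg(name: str, default, is_optional: bool = False) -> str:
    if not is_optional:
        return tensor_arg(name)
    # Optional tensors render as std::optional<Tensor>, with or without a default.
    return (
        optional_tensor_arg(name)
        if default is not None
        else optional_tensor_arg_no_default(name)
    )

assert make_tensor_function_arg("momentum1", 0) == "Tensor momentum1"
assert (
    make_tensor_function_arg("row_counter", 0, is_optional=True)
    == "std::optional<Tensor> row_counter = std::nullopt"
)
```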
@@ -313,10 +383,11 @@ def make_function_schema_arg(ty: ArgType, name: str, default: Union[int, float]) -> str:
         ArgType.FLOAT: lambda x: float_arg(x, default=default),
         # pyre-fixme[6]: For 2nd argument expected `int` but got `Union[float, int]`.
         ArgType.SYM_INT: lambda x: schema_sym_int_arg(x, default=default),
+        ArgType.BOOL: lambda x: schema_bool_arg(x, default=bool(default)),
     }[ty](name)
 
 
-def _extend_tensor_str(name: str, is_cuda: bool) -> str:
+def _extend_tensor_str(name: str, is_cuda: bool, optional: bool) -> str:
     """
     Take a tensor name and extend for cpu or cuda
 
@@ -327,10 +398,12 @@ def _extend_tensor_str(name: str, is_cuda: bool) -> str:
     Returns:
         String of extended tensors
     """
+    opt = "?" if optional else ""
+    default = " = None" if optional else ""
     if is_cuda:
-        return f"Tensor {name}_dev, Tensor {name}_uvm, Tensor {name}_placements, Tensor {name}_offsets"
+        return f"Tensor{opt} {name}_dev{default}, Tensor{opt} {name}_uvm{default}, Tensor{opt} {name}_placements{default}, Tensor{opt} {name}_offsets{default}"
     else:
-        return f"Tensor {name}_host, Tensor {name}_placements, Tensor {name}_offsets"
+        return f"Tensor{opt} {name}_host{default}, Tensor{opt} {name}_placements{default}, Tensor{opt} {name}_offsets{default}"
 
 
 def extend_tensors_args_from_str(args_str: str, example_tensor: str) -> str:
@@ -350,13 +423,18 @@ def extend_tensors_args_from_str(args_str: str, example_tensor: str) -> str:
     num_tensors = args_str.count("Tensor")
     if num_tensors > 0:
         is_cuda = "_dev" in example_tensor
-        args = args_str.split(", ", num_tensors)
-        tensors_args = args[:num_tensors]
-        non_tensors_args = args[-1]
-        extended_tensors_args = [
-            _extend_tensor_str(t.split(" ")[1], is_cuda) for t in tensors_args
-        ]
-        return ", ".join(extended_tensors_args + [non_tensors_args])
+        args = args_str.split(", ")
+        extended_tensors_args = []
+        for arg in args:
+            ty = arg.split(" ")[0]
+            name = arg.split(" ")[1]
+            if ty == "Tensor":
+                extended_tensors_args.append(_extend_tensor_str(name, is_cuda, False))
+            elif ty == "Tensor?":
+                extended_tensors_args.append(_extend_tensor_str(name, is_cuda, True))
+            else:
+                extended_tensors_args.append(arg)
+        return ", ".join(extended_tensors_args)
     else:
         return args_str
 
 
@@ -378,6 +456,9 @@ def make_split_function_args_v1(args_str: str) -> str:
         args_str.replace("int", "int64_t")
         .replace("SymInt", "c10::SymInt")
         .replace("float", "double")
+        .replace("Tensor?", "std::optional<Tensor>")
+        .replace("None", "std::nullopt")
+        .replace("False", "false")
     )
 
 
@@ -386,20 +467,49 @@ def make_ivalue_cast(ty: ArgType) -> str:
         ArgType.INT: "toInt",
         ArgType.FLOAT: "toDouble",
        ArgType.SYM_INT: "toSymInt",
+        ArgType.BOOL: "toBool",
     }[ty]
 
 
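Note: with the rewrite above, `extend_tensors_args_from_str` no longer assumes that all tensor arguments precede the non-tensor ones; it walks every argument and expands `Tensor` and `Tensor?` entries in place. A self-contained sketch of the round trip (simplified `_extend_tensor_str` inlined):

```python
def _extend_tensor_str(name: str, is_cuda: bool, optional: bool) -> str:
    # Simplified copy of the helper above, inlined so the sketch runs standalone.
    opt = "?" if optional else ""
    default = " = None" if optional else ""
    parts = ["dev", "uvm", "placements", "offsets"] if is_cuda else ["host", "placements", "offsets"]
    return ", ".join(f"Tensor{opt} {name}_{p}{default}" for p in parts)

args_str = "Tensor momentum1, Tensor? row_counter, float eps"
extended = []
for arg in args_str.split(", "):
    ty, name = arg.split(" ")[0], arg.split(" ")[1]
    if ty == "Tensor":
        extended.append(_extend_tensor_str(name, True, False))
    elif ty == "Tensor?":
        extended.append(_extend_tensor_str(name, True, True))
    else:
        extended.append(arg)

print(", ".join(extended))
# Tensor momentum1_dev, Tensor momentum1_uvm, Tensor momentum1_placements, Tensor momentum1_offsets,
# Tensor? row_counter_dev = None, Tensor? row_counter_uvm = None,
# Tensor? row_counter_placements = None, Tensor? row_counter_offsets = None, float eps
```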
+def reorder_args(split_arg_spec: List[OptimItem]) -> List[OptimItem]:
+    """
+    Reorder argument items such that tensor arguments come first. This is used in
+    the backend, wrapper and kernel functions, where tensors are no longer optional;
+    tensor arguments must be passed before other argument types, which carry defaults.
+
+    Parameters:
+        split_arg_spec (List[OptimItem]): List of argument items
+
+    Returns:
+        Reordered copy of split_arg_spec
+    """
+    tensor_args = []
+    non_tensor_args = []
+    for s in split_arg_spec:
+        if s.ty in (
+            ArgType.TENSOR,
+            ArgType.INT_TENSOR,
+            ArgType.LONG_TENSOR,
+            ArgType.PLACEHOLDER_TENSOR,
+        ):
+            tensor_args.append(s)
+        else:
+            non_tensor_args.append(s)
+
+    return tensor_args + non_tensor_args
+
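Note: the reorder exists because optional tensors are materialized as plain `Tensor` arguments without defaults in the backend and kernel signatures, and C++ does not allow a defaulted parameter to precede a non-defaulted one. A toy illustration with stand-in items (hypothetical simplification, not the real `OptimItem`):

```python
from dataclasses import dataclass

@dataclass
class Item:  # hypothetical stand-in for OptimItem
    ty: str
    name: str

spec = [
    Item("tensor", "momentum1"),
    Item("float", "eps"),
    Item("bool", "use_rowwise_bias_correction"),
    Item("tensor", "row_counter"),  # optional tensor, now passed as a plain Tensor
]
tensors = [s for s in spec if s.ty == "tensor"]
others = [s for s in spec if s.ty != "tensor"]
print([s.name for s in tensors + others])
# ['momentum1', 'row_counter', 'eps', 'use_rowwise_bias_correction']
```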
""" split_function_arg_names = [] split_function_args = [] split_function_schemas = [] - split_saved_tensor_list = [] - for s in split_arg_spec: + split_saved_tensorlist = [] + split_saved_tensorlist_optional = [] + for s in arg_spec: if s.name == "learning_rate_tensor": split_function_arg_names.append(s.name) split_function_args.append(tensor_arg(s.name)) @@ -438,16 +552,20 @@ def create( ArgType.LONG_TENSOR, ArgType.PLACEHOLDER_TENSOR, ): - name = s.name.rsplit("_", 1)[0] - if name not in split_function_arg_names: - split_function_arg_names.append(name) - split_saved_tensor_list.append(name) + name = s.name + split_function_arg_names.append(name) + if s.is_optional: + split_function_args.append(optional_tensorlist_arg(name)) + split_function_schemas.append(schema_optional_tensorlist_arg(name)) + split_saved_tensorlist_optional.append(name) + else: split_function_args.append( tensor_list_arg_no_default(name, pass_by_ref=False) ) split_function_schemas.append( schema_tensor_list_arg_no_default(name) ) + split_saved_tensorlist.append(name) else: split_function_arg_names.append(s.name) split_function_args.append(make_function_arg(s.ty, s.name, s.default)) @@ -458,7 +576,8 @@ def create( split_function_args=split_function_args, split_function_arg_names=split_function_arg_names, split_function_schemas=split_function_schemas, - split_saved_tensor_list=split_saved_tensor_list, + split_saved_tensorlist=split_saved_tensorlist, + split_saved_tensorlist_optional=split_saved_tensorlist_optional, ) @@ -489,6 +608,9 @@ class OptimizerArgs: placeholder_type_combos: Union[List[Dict[str, TensorType]], List[None]] unified_pt2: PT2ArgsSet split_kernel_arg_names: List[str] + split_function_args_autograd: List[str] + split_function_arg_names_autograd: List[str] + split_saved_tensors_optional: List[str] split_function_args_v1: Optional[str] = None split_function_schemas_v1: Optional[str] = None @@ -499,6 +621,30 @@ def create( arg_spec: List[OptimItem], additional_spec: Optional[dict[str, Any]] = None, ): + # Keep the argument order for forward/backward compatibility + # Arg order: non-optional tensors, learning_rate_tensor, non-tensors, optional tensors + # This is used in lookup and autograd functions + frontend_split_arg_spec = split_arg_spec.copy() + + # Create another spec for kernels where learning_rate is float + # This is used in kernels + kernel_split_arg_spec = split_arg_spec.copy() + for i, s in enumerate(kernel_split_arg_spec): + if s.name == "learning_rate_tensor": + kernel_split_arg_spec[i] = OptimItem(ArgType.FLOAT, "learning_rate") + break + + # Optional tensors are converted to tensor in autograd functions + # Reorganize arguments for wrapper, backend and kernel functions + if additional_spec is not None and "has_optional_tensors" in additional_spec: + assert additional_spec[ + "has_optional_tensors" + ], "`has_optional_tensors` should be set to True, otherwise please remove it from additional_spec" + # Arg order: non-optional tensors, learning_rate_tensor, optional tensors as tensors, non-tensors, + split_arg_spec = reorder_args(split_arg_spec) + # Arg order: non-optional tensors, optional tensors as tensors, learning rate (float), non-tensors + kernel_split_arg_spec = reorder_args(kernel_split_arg_spec) + # Compute placeholder tensor combinations ph_tensor_names = [ s.name for s in arg_spec if s.ty == ArgType.PLACEHOLDER_TENSOR @@ -529,6 +675,33 @@ def create( ArgType.PLACEHOLDER_TENSOR, ) ] + # Create empty tensors based on weights + # weights name convention is different between v1 and pt2 
@@ -489,6 +608,9 @@ class OptimizerArgs:
     placeholder_type_combos: Union[List[Dict[str, TensorType]], List[None]]
     unified_pt2: PT2ArgsSet
     split_kernel_arg_names: List[str]
+    split_function_args_autograd: List[str]
+    split_function_arg_names_autograd: List[str]
+    split_saved_tensors_optional: List[str]
     split_function_args_v1: Optional[str] = None
     split_function_schemas_v1: Optional[str] = None
 
@@ -499,6 +621,30 @@ def create(
         arg_spec: List[OptimItem],
         additional_spec: Optional[dict[str, Any]] = None,
     ):
+        # Keep this argument order for forward/backward compatibility:
+        # non-optional tensors, learning_rate_tensor, non-tensors, optional tensors.
+        # This ordering is used in the lookup and autograd functions.
+        frontend_split_arg_spec = split_arg_spec.copy()
+
+        # Create another spec for kernels, where learning_rate is a float.
+        kernel_split_arg_spec = split_arg_spec.copy()
+        for i, s in enumerate(kernel_split_arg_spec):
+            if s.name == "learning_rate_tensor":
+                kernel_split_arg_spec[i] = OptimItem(ArgType.FLOAT, "learning_rate")
+                break
+
+        # Optional tensors are converted to plain tensors in the autograd functions.
+        # Reorganize arguments for the wrapper, backend and kernel functions.
+        if additional_spec is not None and "has_optional_tensors" in additional_spec:
+            assert additional_spec[
+                "has_optional_tensors"
+            ], "`has_optional_tensors` should be set to True, otherwise please remove it from additional_spec"
+            # Arg order: non-optional tensors, learning_rate_tensor, optional tensors as tensors, non-tensors
+            split_arg_spec = reorder_args(split_arg_spec)
+            # Arg order: non-optional tensors, optional tensors as tensors, learning rate (float), non-tensors
+            kernel_split_arg_spec = reorder_args(kernel_split_arg_spec)
+
         # Compute placeholder tensor combinations
         ph_tensor_names = [
             s.name for s in arg_spec if s.ty == ArgType.PLACEHOLDER_TENSOR
@@ -529,6 +675,33 @@ def create(
                 ArgType.PLACEHOLDER_TENSOR,
             )
         ]
+        # Create empty tensors based on the weights tensors.
+        # The weights naming convention differs between v1 and the PT2 unified interface (v2):
+        # host_weights, dev_weights, uvm_weights, weights_placements, weights_offsets in v1,
+        # vs. weights_{} in v2. This is only used in v1, so we fix the names based on v1.
+        create_empty_tensor = {
+            "host": ".value_or(at::empty({0}, host_weights.options()))",
+            "dev": ".value_or(at::empty({0}, dev_weights.options()))",
+            "uvm": ".value_or(at::empty({0}, uvm_weights.options()))",
+            "placements": ".value_or(at::empty({0}, weights_placements.options()))",
+            "offsets": ".value_or(at::empty({0}, weights_offsets.options()))",
+        }
+        split_saved_tensors_optional = [
+            (
+                s.name + create_empty_tensor[s.name.rsplit("_", 1)[1]]
+                if s.is_optional
+                else s.name
+            )
+            for s in split_arg_spec
+            if s.ty
+            in (
+                ArgType.TENSOR,
+                ArgType.INT_TENSOR,
+                ArgType.LONG_TENSOR,
+                ArgType.PLACEHOLDER_TENSOR,
+            )
+        ]
+
         # Create function args and schemas for V1 interface for backward compatibility
         # V1 interface refers to separate CPU/CUDA lookup functions
         # e.g., split_embedding_codegen_lookup_{}_funtion and split_embedding_codegen_lookup_{}_funtion_cpu)
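Note: the `.value_or(...)` suffix composes at codegen time, so the generated autograd code always saves a real `Tensor` even when the optional input is absent. The composition itself is plain string concatenation:

```python
# How split_saved_tensors_optional composes a saved-tensor expression for one piece.
create_empty_tensor = {
    "host": ".value_or(at::empty({0}, host_weights.options()))",
}
name = "row_counter_host"
print(name + create_empty_tensor[name.rsplit("_", 1)[1]])
# row_counter_host.value_or(at::empty({0}, host_weights.options()))
```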
@@ -548,20 +721,22 @@ def create(
         return OptimizerArgs(
             # GPU kernel args
             split_kernel_args=[
-                make_kernel_arg(s.ty, s.name, s.default) for s in split_arg_spec
+                make_kernel_arg(s.ty, s.name, s.default) for s in kernel_split_arg_spec
             ],
             split_kernel_args_no_defaults=[
-                make_kernel_arg(s.ty, s.name, None) for s in split_arg_spec
+                make_kernel_arg(s.ty, s.name, None) for s in kernel_split_arg_spec
             ],
             split_kernel_arg_constructors=[
-                make_kernel_arg_constructor(s.ty, s.name) for s in split_arg_spec
+                make_kernel_arg_constructor(s.ty, s.name) for s in kernel_split_arg_spec
             ],
             # CPU kernel args
             split_cpu_kernel_args=[
-                make_cpu_kernel_arg(s.ty, s.name, s.default) for s in split_arg_spec
+                make_cpu_kernel_arg(s.ty, s.name, s.default)
+                for s in kernel_split_arg_spec
             ],
             split_cpu_kernel_arg_constructors=[
-                make_cpu_kernel_arg_constructor(s.ty, s.name) for s in split_arg_spec
+                make_cpu_kernel_arg_constructor(s.ty, s.name)
+                for s in kernel_split_arg_spec
             ],
             # Function args
             split_function_args=[
@@ -574,8 +749,11 @@ def create(
             split_tensors=[
                 s.name
                 for s in arg_spec
-                if (s.ty in (ArgType.TENSOR, ArgType.PLACEHOLDER_TENSOR))
-                and s.name != "learning_rate_tensor"
+                if (
+                    s.ty in (ArgType.TENSOR, ArgType.PLACEHOLDER_TENSOR)
+                    and s.name != "learning_rate_tensor"
+                    and not s.is_optional
+                )
             ],
             split_tensor_types={
                 s.name: (
@@ -584,7 +762,11 @@ def create(
                     else (s.name + "_ph_t")
                 )
                 for s in arg_spec
-                if s.ty in (ArgType.TENSOR, ArgType.PLACEHOLDER_TENSOR)
+                if (
+                    s.ty in (ArgType.TENSOR, ArgType.PLACEHOLDER_TENSOR)
+                    and s.name != "learning_rate_tensor"
+                    and not s.is_optional
+                )
             },
             split_saved_tensors=split_saved_tensors,
             saved_data=[
@@ -600,16 +782,22 @@ def create(
             split_variables=["Variable()" for _ in split_arg_spec],
             split_ref_kernel_args=[
                 make_kernel_arg(s.ty, s.name, s.default, pass_by_ref=True)
-                for s in split_arg_spec
+                for s in kernel_split_arg_spec
            ],
             placeholder_tensor_names=ph_tensor_names,
             placeholder_type_combos=ph_combos,
-            unified_pt2=PT2ArgsSet.create(split_arg_spec),
+            unified_pt2=PT2ArgsSet.create(arg_spec),
             # learning rate remains a float in kernels
             split_kernel_arg_names=[
                 "learning_rate" if s.name == "learning_rate_tensor" else s.name
-                for s in split_arg_spec
+                for s in kernel_split_arg_spec
             ],
+            split_function_args_autograd=[
+                make_function_arg(s.ty, s.name, s.default, s.is_optional)
+                for s in frontend_split_arg_spec
+            ],
+            split_function_arg_names_autograd=[s.name for s in frontend_split_arg_spec],
+            split_saved_tensors_optional=split_saved_tensors_optional,
             split_function_args_v1=split_function_args_v1,
             split_function_schemas_v1=split_function_schemas_v1,
         )
 
 
@@ -636,7 +824,7 @@ def create_optim_args(
     for s in arg_spec:
         # no cpu/cuda extension for learning_rate
         if (
-            s.ty in (ArgType.FLOAT, ArgType.INT, ArgType.SYM_INT)
+            s.ty in (ArgType.FLOAT, ArgType.INT, ArgType.SYM_INT, ArgType.BOOL)
             or s.name == "learning_rate_tensor"
         ):
             # pyre-fixme[19]: Expected 1 positional argument.
@@ -651,13 +839,21 @@ def create_optim_args(
 def extend_for_cpu(spec: OptimItem) -> List[OptimItem]:
     name = spec.name
     default = spec.default
+    is_optional = spec.is_optional
     return [
         # pyre-fixme[19]: Expected 1 positional argument.
-        OptimItem(ArgType.TENSOR, f"{name}_host", default),
+        OptimItem(ArgType.TENSOR, f"{name}_host", default, is_optional=is_optional),
         # pyre-fixme[19]: Expected 1 positional argument.
-        OptimItem(ArgType.INT_TENSOR, f"{name}_placements", default),
+        OptimItem(
+            ArgType.INT_TENSOR,
+            f"{name}_placements",
+            default,
+            is_optional=is_optional,
+        ),
         # pyre-fixme[19]: Expected 1 positional argument.
-        OptimItem(ArgType.LONG_TENSOR, f"{name}_offsets", default),
+        OptimItem(
+            ArgType.LONG_TENSOR, f"{name}_offsets", default, is_optional=is_optional
+        ),
     ]
 
 @staticmethod
@@ -666,15 +862,23 @@ def extend_for_cuda(spec: OptimItem) -> List[OptimItem]:
     default = spec.default
     ty = spec.ty
     ph_tys = spec.ph_tys
+    is_optional = spec.is_optional
     return [
         # pyre-fixme[19]: Expected 1 positional argument.
-        OptimItem(ty, f"{name}_dev", default, ph_tys),
+        OptimItem(ty, f"{name}_dev", default, ph_tys, is_optional),
         # pyre-fixme[19]: Expected 1 positional argument.
-        OptimItem(ty, f"{name}_uvm", default, ph_tys),
+        OptimItem(ty, f"{name}_uvm", default, ph_tys, is_optional),
         # pyre-fixme[19]: Expected 1 positional argument.
-        OptimItem(ArgType.INT_TENSOR, f"{name}_placements", default),
+        OptimItem(
+            ArgType.INT_TENSOR,
+            f"{name}_placements",
+            default,
+            is_optional=is_optional,
+        ),
         # pyre-fixme[19]: Expected 1 positional argument.
-        OptimItem(ArgType.LONG_TENSOR, f"{name}_offsets", default),
+        OptimItem(
+            ArgType.LONG_TENSOR, f"{name}_offsets", default, is_optional=is_optional
+        ),
     ]
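Note: the `is_optional` flag has to ride along through the per-device expansion, since each logical state fans out into several physical tensors. A stand-in sketch of what `extend_for_cuda` yields for the optional `row_counter` (hypothetical simplified `Item`, not the real `OptimItem`):

```python
from dataclasses import dataclass

@dataclass
class Item:  # hypothetical stand-in for OptimItem
    ty: str
    name: str
    is_optional: bool = False

def extend_for_cuda_sketch(name: str, is_optional: bool):
    # One logical state expands into four physical tensors, all carrying the flag.
    return [
        Item("tensor", f"{name}_dev", is_optional),
        Item("tensor", f"{name}_uvm", is_optional),
        Item("int_tensor", f"{name}_placements", is_optional),
        Item("long_tensor", f"{name}_offsets", is_optional),
    ]

print([i.name for i in extend_for_cuda_sketch("row_counter", True)])
# ['row_counter_dev', 'row_counter_uvm', 'row_counter_placements', 'row_counter_offsets']
```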
 
 @staticmethod
@@ -683,17 +887,25 @@ def extend_for_any(spec: OptimItem) -> List[OptimItem]:
     default = spec.default
     ty = spec.ty
     ph_tys = spec.ph_tys
+    is_optional = spec.is_optional
     return [
         # pyre-fixme[19]: Expected 1 positional argument.
-        OptimItem(ArgType.TENSOR, f"{name}_host", default),
+        OptimItem(ArgType.TENSOR, f"{name}_host", default, is_optional=is_optional),
         # pyre-fixme[19]: Expected 1 positional argument.
-        OptimItem(ty, f"{name}_dev", default, ph_tys),
+        OptimItem(ty, f"{name}_dev", default, ph_tys, is_optional=is_optional),
         # pyre-fixme[19]: Expected 1 positional argument.
-        OptimItem(ty, f"{name}_uvm", default, ph_tys),
+        OptimItem(ty, f"{name}_uvm", default, ph_tys, is_optional=is_optional),
         # pyre-fixme[19]: Expected 1 positional argument.
-        OptimItem(ArgType.INT_TENSOR, f"{name}_placements", default),
+        OptimItem(
+            ArgType.INT_TENSOR,
+            f"{name}_placements",
+            default,
+            is_optional=is_optional,
+        ),
         # pyre-fixme[19]: Expected 1 positional argument.
-        OptimItem(ArgType.LONG_TENSOR, f"{name}_offsets", default),
+        OptimItem(
+            ArgType.LONG_TENSOR, f"{name}_offsets", default, is_optional=is_optional
+        ),
     ]
 
 @staticmethod
diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py
index a17506131..9988855da 100644
--- a/fbgemm_gpu/codegen/genscript/optimizers.py
+++ b/fbgemm_gpu/codegen/genscript/optimizers.py
@@ -1001,6 +1001,33 @@ def partial_rowwise_lamb() -> Dict[str, Any]:
 
 
 def adam() -> Dict[str, Any]:
+    split_precomputation = """
+    at::acc_type<cache_t, true>* __restrict__ row_counter;
+    if (use_rowwise_bias_correction) {
+        const auto row_counter_placement = static_cast<PlacementType>(row_counter_placements[t]);
+        const int64_t row_counter_offset = row_counter_offsets[t];
+        if (row_counter_placement == PlacementType::DEVICE) {
+            row_counter = &row_counter_dev[row_counter_offset];
+        } else {
+            row_counter = &row_counter_uvm[row_counter_offset];
+        }
+    }
+    at::acc_type<cache_t, true> _row_counter = 0;
+
+    // need to compute the bias correction for each row
+    if (threadIdx.x == 0 && use_rowwise_bias_correction) {
+        _row_counter = row_counter[idx] + 1;
+        row_counter[idx] = _row_counter;
+    }
+    // broadcast the bias correction to all threads
+    if (use_rowwise_bias_correction) {
+        _row_counter = SHFL_SYNC(_row_counter, 0);
+    }
+    else {
+        _row_counter = iter;
+    }
+    """
+
     split_weight_update = """
       Vec4T<cache_t> m_t(&momentum1[idx * D + d]);
       m_t.acc.x *= beta1;
@@ -1023,10 +1050,10 @@ def adam() -> Dict[str, Any]:
       v_t.fma_(grad, 1.0 - beta2);
      v_t.store(&momentum2[idx * D + d]);
 
-      weight_new.acc.x -= learning_rate * (m_t.acc.x / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.x / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.x);
-      weight_new.acc.y -= learning_rate * (m_t.acc.y / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.y / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.y);
-      weight_new.acc.z -= learning_rate * (m_t.acc.z / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.z / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.z);
-      weight_new.acc.w -= learning_rate * (m_t.acc.w / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.w / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.w);
+      weight_new.acc.x -= learning_rate * (m_t.acc.x / (1.0 - powf(beta1, _row_counter)) / (sqrtf((v_t.acc.x / (1.0 - powf(beta2, _row_counter)))) + eps) + weight_decay * weight_new.acc.x);
+      weight_new.acc.y -= learning_rate * (m_t.acc.y / (1.0 - powf(beta1, _row_counter)) / (sqrtf((v_t.acc.y / (1.0 - powf(beta2, _row_counter)))) + eps) + weight_decay * weight_new.acc.y);
+      weight_new.acc.z -= learning_rate * (m_t.acc.z / (1.0 - powf(beta1, _row_counter)) / (sqrtf((v_t.acc.z / (1.0 - powf(beta2, _row_counter)))) + eps) + weight_decay * weight_new.acc.z);
+      weight_new.acc.w -= learning_rate * (m_t.acc.w / (1.0 - powf(beta1, _row_counter)) / (sqrtf((v_t.acc.w / (1.0 - powf(beta2, _row_counter)))) + eps) + weight_decay * weight_new.acc.w);
     """
     split_weight_update_cpu = ""  # TODO
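Note on the math: Adam divides the first and second moments by 1 - beta^t to correct their bias toward zero. With `use_rowwise_bias_correction`, t is the per-row update count from `row_counter` rather than the global `iter`, so sparsely touched rows are corrected by how often they were actually updated. A scalar pure-Python sketch of one update:

```python
import math

def adam_step(w, g, m, v, lr, beta1, beta2, eps, weight_decay, t):
    # t is the per-row counter when rowwise bias correction is on, else the global iter.
    m = beta1 * m + (1.0 - beta1) * g
    v = beta2 * v + (1.0 - beta2) * g * g
    m_hat = m / (1.0 - beta1**t)  # bias-corrected first moment
    v_hat = v / (1.0 - beta2**t)  # bias-corrected second moment
    w -= lr * (m_hat / (math.sqrt(v_hat) + eps) + weight_decay * w)
    return w, m, v

w, m, v = adam_step(1.0, 0.1, 0.0, 0.0, lr=0.01, beta1=0.9,
                    beta2=0.999, eps=1e-8, weight_decay=0.0, t=1)
```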
 
@@ -1043,12 +1070,15 @@ def adam() -> Dict[str, Any]:
             OptimItem(ArgType.FLOAT, "beta2"),
             OptimItem(ArgType.FLOAT, "weight_decay"),
             OptimItem(ArgType.INT, "iter"),
+            OptimItem(ArgType.BOOL, "use_rowwise_bias_correction"),
+            OptimItem(ArgType.TENSOR, "row_counter", is_optional=True),
         ],
         {
-            "v1": "Tensor momentum1, Tensor momentum2, float learning_rate = 0, float eps = 0, float beta1 = 0, float beta2 = 0, float weight_decay = 0, int iter = 0"
+            "v1": "Tensor momentum1, Tensor momentum2, float learning_rate = 0, float eps = 0, float beta1 = 0, float beta2 = 0, float weight_decay = 0, int iter = 0, bool use_rowwise_bias_correction = False, Tensor? row_counter = None",
+            "has_optional_tensors": True,
         },
     ),
-    "split_precomputation": "",
+    "split_precomputation": split_precomputation,
     "split_weight_update": split_weight_update,
     "split_post_update": "",
     "split_weight_update_cpu": split_weight_update_cpu,
diff --git a/fbgemm_gpu/codegen/genscript/torch_type_utils.py b/fbgemm_gpu/codegen/genscript/torch_type_utils.py
index ebd4e0220..aa442ad37 100644
--- a/fbgemm_gpu/codegen/genscript/torch_type_utils.py
+++ b/fbgemm_gpu/codegen/genscript/torch_type_utils.py
@@ -26,6 +26,7 @@ class ArgType(IntEnum):
     INT = 7
     FLOAT = 8
     SYM_INT = 9
+    BOOL = 10
 
 
 @dataclass
diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_cpu_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_cpu_template.cpp
index e3b459a7b..fc7d8a58f 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_cpu_template.cpp
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_cpu_template.cpp
@@ -64,14 +64,14 @@ class SplitLookupFunction_{{ optimizer }}_Op : public torch::autograd::Function<
       bool gradient_clipping,
       double max_gradient,
       bool stochastic_rounding,
-      {{ args.split_function_args | join(", ") }},
+      {{ args.split_function_args_autograd | join(", ") }},
       int64_t output_dtype = static_cast<int64_t>(SparseType::FP32)) {
     Tensor indice_weights_value = indice_weights.value_or(Tensor());
     Tensor feature_requires_grad_value = feature_requires_grad.value_or(Tensor());
     ctx->save_for_backward({
         host_weights, weights_placements, weights_offsets, D_offsets, hash_size_cumsum,
-        indices, offsets, indice_weights_value, feature_requires_grad_value, {{ args.split_saved_tensors | join(", ") }} });
+        indices, offsets, indice_weights_value, feature_requires_grad_value, {{ args.split_saved_tensors_optional | join(", ") }} });
 
     ctx->saved_data["total_D"] = total_D;
     ctx->saved_data["max_D"] = max_D;
@@ -242,7 +242,7 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function_cpu(
         gradient_clipping,
         max_gradient,
         stochastic_rounding,
-        {{ args.split_function_arg_names | join(", ") }},
+        {{ args.split_function_arg_names_autograd | join(", ") }},
         output_dtype)[0];
   {% else %}
   TORCH_CHECK(false, "split_embedding_codegen_lookup_{{ optimizer }}_function_cpu is deprecated. Please see https://github.com/pytorch/FBGEMM/discussions/1727 for more detail.");
diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp
index fb5c6f0e7..adf8f19bb 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp
@@ -360,7 +360,7 @@ enum SSDTensor {
       {%- if ssd %}
       ssd_tensors.value(),
       {%- endif %}
-      {{ args.split_function_arg_names | join(", ") }}
+      {{ args.split_function_arg_names_autograd | join(", ") }}
       {%- endif %}
   )[0];
 {%- endmacro %}
@@ -618,7 +618,7 @@ class {{ autograd_func }} :
       {%- if ssd %}
       const at::TensorList& ssd_tensors,
       {%- endif %}
-      {{ args.split_function_args | join(", ") }}
+      {{ args.split_function_args_autograd | join(", ") }}
       {%- else %}
       {%- if vbe %}
       const std::optional<Tensor>& B_offsets,
@@ -757,7 +757,7 @@ class {{ autograd_func }} :
         ssd_tensors[SSDTensor::{{ tensor | upper }}],
         {%- endfor %}
         {%- endif %}
-        {{ args.split_saved_tensors | join(", ") }}
+        {{ args.split_saved_tensors_optional | join(", ") }}
       });
 
     {%- if not nobag %}
diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp
index dedcd5f91..d8697f366 100644
--- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp
+++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp
@@ -476,12 +476,35 @@ enum SSDTensor {
 
 /* This macro generates a code blob for unpacking the tensor list */
-{%- macro unpack_tensor_list(tensor_list) %}
-    const Tensor {{ tensor_list }}_host = {{ tensor_list }}[0];
-    const Tensor {{ tensor_list }}_dev = {{ tensor_list }}[1];
-    const Tensor {{ tensor_list }}_uvm = {{ tensor_list }}[2];
-    const Tensor {{ tensor_list }}_placements = {{ tensor_list }}[3];
-    const Tensor {{ tensor_list }}_offsets = {{ tensor_list }}[4];
+{%- macro unpack_tensorlist(name) %}
+    const Tensor {{ name }}_host = {{ name }}[0];
+    const Tensor {{ name }}_dev = {{ name }}[1];
+    const Tensor {{ name }}_uvm = {{ name }}[2];
+    const Tensor {{ name }}_placements = {{ name }}[3];
+    const Tensor {{ name }}_offsets = {{ name }}[4];
+{%- endmacro %}
+
+{%- macro unpack_tensorlist_optional(name) %}
+    Tensor {{ name }}_host;
+    Tensor {{ name }}_dev;
+    Tensor {{ name }}_uvm;
+    Tensor {{ name }}_placements;
+    Tensor {{ name }}_offsets;
+    if ({{ name }}.has_value()) {
+      at::TensorList _{{ name }} = {{ name }}.value();
+      {{ name }}_host = _{{ name }}[0];
+      {{ name }}_dev = _{{ name }}[1];
+      {{ name }}_uvm = _{{ name }}[2];
+      {{ name }}_placements = _{{ name }}[3];
+      {{ name }}_offsets = _{{ name }}[4];
+    }
+    else {
+      {{ name }}_host = at::empty({0}, weights_host.options());
+      {{ name }}_dev = at::empty({0}, weights_dev.options());
+      {{ name }}_uvm = at::empty({0}, weights_uvm.options());
+      {{ name }}_placements = at::empty({0}, weights_placements.options());
+      {{ name }}_offsets = at::empty({0}, weights_offsets.options());
+    }
 {%- endmacro %}
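Note: `unpack_tensorlist_optional` guarantees that downstream template code never sees an absent state: when the `Tensor[]?` input is `None`, every piece falls back to an empty tensor that borrows its options from the corresponding `weights` piece. The same fallback, sketched in Python:

```python
from typing import List, Optional
import torch

def unpack_optional(tl: Optional[List[torch.Tensor]], weights: List[torch.Tensor]):
    if tl is not None:
        return tuple(tl)  # host, dev, uvm, placements, offsets
    # Absent: substitute empty tensors matching the dtype/device of the weights pieces.
    return tuple(torch.empty(0, dtype=w.dtype, device=w.device) for w in weights)
```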
 
 
@@ -581,9 +604,12 @@ class {{ autograd_func }} :
     {{ args_pt2.unified_pt2.split_function_args | join(", ") }}) {
 
     // unpack Tensor lists
-    {{ unpack_tensor_list("weights") }}
-    {%- for arg_name in args_pt2.unified_pt2.split_saved_tensor_list %}
-    {{ unpack_tensor_list(arg_name) }}
+    {{ unpack_tensorlist("weights") }}
+    {%- for arg_name in args_pt2.unified_pt2.split_saved_tensorlist %}
+    {{ unpack_tensorlist(arg_name) }}
+    {%- endfor %}
+    {%- for arg_name in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
+    {{ unpack_tensorlist_optional(arg_name) }}
     {%- endfor %}
 
     const auto T = weights_offsets.sym_numel();
diff --git a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template
index e86b27b2d..c69837291 100644
--- a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template
+++ b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template
@@ -60,7 +60,7 @@ def invoke(
     {%- if "prev_iter_dev" in args.split_function_arg_names %}
     prev_iter: Momentum,
     {%- endif %}
-    {%- if "row_counter_dev" in args.split_function_arg_names %}
+    {%- if "row_counter_dev" in args.split_function_arg_names and "row_counter" not in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
     row_counter: Momentum,
     {%- endif %}
     {%- if "iter" in args.split_function_arg_names %}
@@ -209,7 +209,7 @@ def invoke(
             prev_iter_placements=prev_iter.placements,
             {%- endif %}
             # row_counter
-            {%- if "row_counter_dev" in args.split_function_arg_names %}
+            {%- if "row_counter_dev" in args.split_function_arg_names and "row_counter" not in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
             row_counter_host=row_counter.host,
             row_counter_offsets=row_counter.offsets,
             row_counter_placements=row_counter.placements,
@@ -387,7 +387,7 @@ def invoke(
             prev_iter_dev=prev_iter_dev,
             {%- endif %}
             # row_counter
-            {%- if "row_counter_dev" in args.split_function_arg_names %}
+            {%- if "row_counter_dev" in args.split_function_arg_names and "row_counter" not in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
             row_counter_dev=row_counter.dev,
             row_counter_uvm=row_counter.uvm,
             row_counter_offsets=row_counter.offsets,
diff --git a/fbgemm_gpu/test/tbe/training/forward_test.py b/fbgemm_gpu/test/tbe/training/forward_test.py
index 5ea2ff723..e4e54fc99 100644
--- a/fbgemm_gpu/test/tbe/training/forward_test.py
+++ b/fbgemm_gpu/test/tbe/training/forward_test.py
@@ -76,6 +76,13 @@
         "test_faketensor__test_forward_gpu_uvm_cache_int8": [
             unittest.skip("Operator not implemented for Meta tensors"),
         ],
+        # The learning-rate tensor needs to be on CPU to avoid a D->H sync point, since it is read as a float in the kernel.
+        # This fails the fake-tensor test, which expects all tensors to be on the same device.
+        "test_pt2_compliant_tag_fbgemm_split_embedding_codegen_lookup_rowwise_adagrad_function": [
+            unittest.skip(
+                "Operator failed on FakeTensor test since learning rate tensor is always on CPU regardless of other tensors"
+            ),
+        ],
     }
 )
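Note: tying the pieces together, the new tail of the adam v1 schema maps to its C++ signature exactly through the `make_split_function_args_v1` replacements added above:

```python
s = "bool use_rowwise_bias_correction = False, Tensor? row_counter = None"
for old, new in [
    ("int", "int64_t"),
    ("SymInt", "c10::SymInt"),
    ("float", "double"),
    ("Tensor?", "std::optional<Tensor>"),
    ("None", "std::nullopt"),
    ("False", "false"),
]:
    s = s.replace(old, new)
print(s)
# bool use_rowwise_bias_correction = false, std::optional<Tensor> row_counter = std::nullopt
```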