forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
FusedAdamKernel.cu
91 lines (84 loc) · 3.98 KB
/
FusedAdamKernel.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/TypeDefault.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/cuda/fused_adam_amsgrad_impl.cuh>
#include <ATen/native/cuda/fused_adam_impl.cuh>
#include <c10/util/Exception.h>
namespace at::native {

// note(crcrpar): To stay within the CI rule of at most 20 minutes of compile time per file,
// the kernel instantiations are defensively split into separate _impl files.
// This mainly matters for CUDA 11.3, where compiling took about 20 minutes on my workstation
// and about 28 minutes on CI. As a data point, it took about 20 seconds with CUDA 11.7 in my
// environment.
// See https://github.com/pytorch/pytorch/pull/81705 for details.

namespace {

// Verifies that the tensor lists consumed by the fused update satisfy
// check_fast_path_restrictions (same dtype, device, and layout).
// max_exp_avg_sqs only takes part in the check when amsgrad is true, because the
// non-amsgrad implementation is never handed that list.
void check_fused_adam_tensor_lists(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    const bool amsgrad) {
  if (amsgrad) {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
        "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
  } else {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs}),
        "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
  }
}

// Checks that the optional grad_scale / found_inf tensors, when provided, are on
// param_device. Absent optionals are accepted.
void check_optional_tensors_on_device(
    const Device param_device,
    const c10::optional<at::Tensor>& grad_scale,
    const c10::optional<at::Tensor>& found_inf) {
  if (grad_scale.has_value()) {
    TORCH_CHECK(grad_scale->device() == param_device, "grad_scale must be on the same GPU device as the params");
  }
  if (found_inf.has_value()) {
    TORCH_CHECK(found_inf->device() == param_device, "found_inf must be on the same GPU device as the params");
  }
}

} // namespace

// Fused Adam step with a host-side (double) learning rate.
// Validates that the tensor lists are compatible, then dispatches to the amsgrad or
// plain implementation. max_exp_avg_sqs is only forwarded when amsgrad is true.
void _fused_adam_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const c10::optional<at::Tensor>& grad_scale,
    const c10::optional<at::Tensor>& found_inf
) {
  check_fused_adam_tensor_lists(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, amsgrad);
  if (amsgrad) {
    _fused_adam_amsgrad_cuda_impl_(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, maximize, grad_scale, found_inf);
  } else {
    _fused_adam_cuda_impl_(params, grads, exp_avgs, exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, maximize, grad_scale, found_inf);
  }
}

// Same as the overload above, except that lr is passed as a Tensor.
// A CPU lr is converted with item<double>() and forwarded to the double overload.
// Otherwise lr must be on the same device as params, and the tensor-lr implementations
// are called directly.
void _fused_adam_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const at::Tensor& lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const c10::optional<at::Tensor>& grad_scale,
    const c10::optional<at::Tensor>& found_inf
) {
  // params[0] provides the reference device below; indexing an empty list would be UB.
  TORCH_CHECK(!params.empty(), "params must contain at least one tensor");

  // Manually check devices since we specify no device check in native_functions.yaml.
  const Device param_device = params[0].device();

  // Validate grad_scale / found_inf before the CPU-lr shortcut: that shortcut calls the
  // double overload directly rather than through the dispatcher, so these checks would
  // otherwise never run on that path.
  check_optional_tensors_on_device(param_device, grad_scale, found_inf);

  if (lr.is_cpu()) {
    _fused_adam_kernel_cuda_(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr.item<double>(), beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf);
    return;
  }

  TORCH_CHECK(lr.device() == param_device, "lr must be on the same GPU device as the params");
  check_fused_adam_tensor_lists(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, amsgrad);
  if (amsgrad) {
    _fused_adam_amsgrad_cuda_impl_(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, maximize, grad_scale, found_inf);
  } else {
    _fused_adam_cuda_impl_(params, grads, exp_avgs, exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, maximize, grad_scale, found_inf);
  }
}

} // namespace at::native