diff --git a/include/ffconst.h b/include/ffconst.h
index 2df3b4585a..886c6d97e1 100644
--- a/include/ffconst.h
+++ b/include/ffconst.h
@@ -136,6 +136,7 @@ enum OperatorType {
   OP_GELU,
   OP_MULTIHEAD_ATTENTION,
   OP_FUSED, // Fused operator type for internal fusion optimizations
+  OP_RSQRT, // Elementwise reciprocal square root; see https://pytorch.org/docs/stable/generated/torch.rsqrt.html
 };
 
 #endif // _FLEXFLOW_CONST_H_
diff --git a/include/model.h b/include/model.h
index d5e9b20845..8ed764723b 100644
--- a/include/model.h
+++ b/include/model.h
@@ -314,6 +314,10 @@ class FFModel {
                 const Tensor& y,
                 bool inplace_a = false,
                 char const *name = NULL);
+  // Add a rsqrt layer (elementwise 1/sqrt(x))
+  Tensor rsqrt(const Tensor& x,
+               bool inplace = true,
+               char const *name = NULL);
   // Add a scalar operation layer
   Tensor scalar_multiply(const Tensor& x,
                          const float scalar,
diff --git a/src/ops/element_unary.cu b/src/ops/element_unary.cu
index 73a1df4e72..a763902381 100644
--- a/src/ops/element_unary.cu
+++ b/src/ops/element_unary.cu
@@ -85,6 +85,11 @@ Tensor FFModel::elu(const Tensor& x, bool inplace, const char *name)
   return this->unary(OP_ELU, x, inplace, name);
 }
 
+Tensor FFModel::rsqrt(const Tensor& x, bool inplace, const char *name)
+{
+  return this->unary(OP_RSQRT, x, inplace, name);
+}
+
 ElementUnary::ElementUnary(FFModel& model,
                            OperatorType _op_type,
                            const Tensor& x,
@@ -342,6 +347,11 @@ void elewise_unary_forward_kernel(coord_t volume,
         out[i] = in[i] * 0.5 * erfc(-in[i]*M_SQRT1_2);
         break;
       }
+      case OP_RSQRT: // y = 1 / sqrt(x)
+      {
+        out[i] = 1.0f / sqrt(in[i]);
+        break;
+      }
       default:
         assert(false);
     }
@@ -459,6 +469,7 @@ void elewise_unary_backward_kernel(coord_t volume,
                                    const float beta,
                                    const float scalar,
                                    OperatorType type,
+                                   const float* output,
                                    const float* output_grad,
                                    const float* input,
                                    float* input_grad)
@@ -502,6 +513,11 @@ void elewise_unary_backward_kernel(coord_t volume,
         input_grad[i] = output_grad[i]*(0.5 * erfc(-input[i]*M_SQRT1_2)-0.5*M_SQRT1_2*input[i]*exp(-input[i]*input[i]*0.5));
         break;
       }
+      case OP_RSQRT: // y = x^(-1/2)  =>  dy/dx = -0.5 * x^(-3/2) = -0.5 * y^3
+      {
+        input_grad[i] = -0.5f * output_grad[i] * output[i] * output[i] * output[i];
+        break;
+      }
       default:
         assert(false);
     }
@@ -526,7 +542,7 @@ void ElementUnary::backward_kernel(const ElementUnaryMeta* m,
         m->inputTensor, input_ptr, &alpha, m->inputTensor, input_grad_ptr));
   } else {
     elewise_unary_backward_kernel<<>>(
-        num_elements, alpha, alpha, m->scalar, m->op_type, output_grad_ptr, input_ptr, input_grad_ptr);
+        num_elements, alpha, alpha, m->scalar, m->op_type, output_ptr, output_grad_ptr, input_ptr, input_grad_ptr);
   }
 }
 