diff --git a/XUSGMachineLearning/MachineLearning/XUSGMachineLearning.h b/XUSGMachineLearning/MachineLearning/XUSGMachineLearning.h
index 1617b01..4369441 100644
--- a/XUSGMachineLearning/MachineLearning/XUSGMachineLearning.h
+++ b/XUSGMachineLearning/MachineLearning/XUSGMachineLearning.h
@@ -109,7 +109,7 @@ namespace XUSG
 		enum class TensorFlag
 		{
 			NONE,
-			MANAGED,
+			MANAGED
 		};
 
 		DEFINE_ENUM_FLAG_OPERATORS(TensorFlag);
@@ -119,7 +119,7 @@ namespace XUSG
 			NONE,
 			ALLOW_HALF_PRECISION_COMPUTATION,
 			DISABLE_META_COMMANDS,
-			DESCRIPTORS_VOLATILE,
+			DESCRIPTORS_VOLATILE
 		};
 
 		DEFINE_ENUM_FLAG_OPERATORS(ExecutionFlag);
@@ -521,6 +521,57 @@ namespace XUSG
 			uint32_t K;
 		};
 
+		// Batch normalization operator description; pFusedActivation may be null (no fused activation).
+		struct BatchNormalization
+		{
+			const Tensor* pInput;
+			const Tensor* pMean;
+			const Tensor* pVariance;
+			const Tensor* pScale;
+			const Tensor* pBias;
+			const Tensor* pOutput;
+			bool Spatial;
+			float Epsilon;
+			OperatorType FusedActivationType;
+			const void* pFusedActivation;
+		};
+
+		// Mean-variance normalization operator description; pFusedActivation may be null (no fused activation).
+		struct MeanVarianceNormalization
+		{
+			const Tensor* pInput;
+			const Tensor* pScale;
+			const Tensor* pBias;
+			const Tensor* pOutput;
+			bool CrossChannel;
+			bool NormalizeVariance;
+			float Epsilon;
+			OperatorType FusedActivationType;
+			const void* pFusedActivation;
+		};
+
+		// Local response normalization operator description.
+		struct LocalResponseNormalization
+		{
+			const Tensor* pInput;
+			const Tensor* pOutput;
+			bool CrossChannel;
+			uint32_t LocalSize;
+			float Alpha;
+			float Beta;
+			float Bias;
+		};
+
+		// Lp normalization operator description.
+		struct LPNormalization
+		{
+			const Tensor* pInput;
+			const Tensor* pOutput;
+			uint32_t Axis;
+			float Epsilon;
+			uint32_t P;
+		};
+
 		//--------------------------------------------------------------------------------------
 		// Device
 		//--------------------------------------------------------------------------------------
diff --git a/XUSGMachineLearning/MachineLearning/XUSGMachineLearning_DML.cpp b/XUSGMachineLearning/MachineLearning/XUSGMachineLearning_DML.cpp
index 660a99c..7e21b63 100644
--- a/XUSGMachineLearning/MachineLearning/XUSGMachineLearning_DML.cpp
+++ b/XUSGMachineLearning/MachineLearning/XUSGMachineLearning_DML.cpp
@@ -692,6 +692,110 @@ void ML::GetDMLTypedOperator(vector& dmlTypedOpDesc, OperatorType type,
 		pDMLDesc->K = desc.K;
 	};
 
+	// Fills dmlTypedOpDesc with a DML_BATCH_NORMALIZATION_OPERATOR_DESC built from the description at pOpDesc.
+	// Buffer layout: [typed desc][DML_OPERATOR_DESC of the fused activation][fused activation's typed desc];
+	// the last two parts exist only when pFusedActivation is set. The stored pointers refer into dmlTypedOpDesc itself.
+	static const auto getDMLBatchNormalization = [](vector& dmlTypedOpDesc, const void* pOpDesc)
+	{
+		const auto& desc = *static_cast(pOpDesc);
+
+		vector typedFused(0);
+		if (desc.pFusedActivation) GetDMLTypedOperator(typedFused, desc.FusedActivationType, desc.pFusedActivation);
+
+		dmlTypedOpDesc.resize(sizeof(DML_BATCH_NORMALIZATION_OPERATOR_DESC) +
+			(desc.pFusedActivation ? sizeof(DML_OPERATOR_DESC) + typedFused.size() : 0));
+		const auto pDMLDesc = reinterpret_cast(dmlTypedOpDesc.data());
+		const auto pDMLFused = desc.pFusedActivation ? reinterpret_cast(
+			&dmlTypedOpDesc[sizeof(DML_BATCH_NORMALIZATION_OPERATOR_DESC)]) : nullptr;
+
+		pDMLDesc->InputTensor = desc.pInput ? static_cast(desc.pInput->GetHandle()) : nullptr;
+		pDMLDesc->MeanTensor = desc.pMean ? static_cast(desc.pMean->GetHandle()) : nullptr;
+		pDMLDesc->VarianceTensor = desc.pVariance ? static_cast(desc.pVariance->GetHandle()) : nullptr;
+		pDMLDesc->ScaleTensor = desc.pScale ? static_cast(desc.pScale->GetHandle()) : nullptr;
+		pDMLDesc->BiasTensor = desc.pBias ? static_cast(desc.pBias->GetHandle()) : nullptr;
+		pDMLDesc->OutputTensor = desc.pOutput ? static_cast(desc.pOutput->GetHandle()) : nullptr;
+		pDMLDesc->Spatial = desc.Spatial;
+		pDMLDesc->Epsilon = desc.Epsilon;
+		pDMLDesc->FusedActivation = pDMLFused;
+
+		if (pDMLFused)
+		{
+			assert(desc.pFusedActivation);
+			// Must use this operator's own desc size (not the convolution one) so the payload stays inside the resized buffer.
+			const auto offset = sizeof(DML_BATCH_NORMALIZATION_OPERATOR_DESC) + sizeof(DML_OPERATOR_DESC);
+			pDMLFused->Type = GetDMLOpteratorType(desc.FusedActivationType);
+			pDMLFused->Desc = &dmlTypedOpDesc[offset];
+			memcpy(&dmlTypedOpDesc[offset], typedFused.data(), typedFused.size());
+		}
+	};
+
+	// Fills dmlTypedOpDesc with a DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC built from the description at pOpDesc.
+	// Buffer layout: [typed desc][DML_OPERATOR_DESC of the fused activation][fused activation's typed desc];
+	// the last two parts exist only when pFusedActivation is set. The stored pointers refer into dmlTypedOpDesc itself.
+	static const auto getDMLMeanVarianceNormalization = [](vector& dmlTypedOpDesc, const void* pOpDesc)
+	{
+		const auto& desc = *static_cast(pOpDesc);
+
+		vector typedFused(0);
+		if (desc.pFusedActivation) GetDMLTypedOperator(typedFused, desc.FusedActivationType, desc.pFusedActivation);
+
+		dmlTypedOpDesc.resize(sizeof(DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC) +
+			(desc.pFusedActivation ? sizeof(DML_OPERATOR_DESC) + typedFused.size() : 0));
+		const auto pDMLDesc = reinterpret_cast(dmlTypedOpDesc.data());
+		const auto pDMLFused = desc.pFusedActivation ? reinterpret_cast(
+			&dmlTypedOpDesc[sizeof(DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC)]) : nullptr;
+
+		pDMLDesc->InputTensor = desc.pInput ? static_cast(desc.pInput->GetHandle()) : nullptr;
+		pDMLDesc->ScaleTensor = desc.pScale ? static_cast(desc.pScale->GetHandle()) : nullptr;
+		pDMLDesc->BiasTensor = desc.pBias ? static_cast(desc.pBias->GetHandle()) : nullptr;
+		pDMLDesc->OutputTensor = desc.pOutput ? static_cast(desc.pOutput->GetHandle()) : nullptr;
+		pDMLDesc->CrossChannel = desc.CrossChannel;
+		pDMLDesc->NormalizeVariance = desc.NormalizeVariance;
+		pDMLDesc->Epsilon = desc.Epsilon;
+		pDMLDesc->FusedActivation = pDMLFused;
+
+		if (pDMLFused)
+		{
+			assert(desc.pFusedActivation);
+			// Must use this operator's own desc size (not the convolution one) so the payload stays inside the resized buffer.
+			const auto offset = sizeof(DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC) + sizeof(DML_OPERATOR_DESC);
+			pDMLFused->Type = GetDMLOpteratorType(desc.FusedActivationType);
+			pDMLFused->Desc = &dmlTypedOpDesc[offset];
+			memcpy(&dmlTypedOpDesc[offset], typedFused.data(), typedFused.size());
+		}
+	};
+
+	// Fills dmlTypedOpDesc with a DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC built from the description at pOpDesc.
+	static const auto getDMLLocalResponseNormalization = [](vector& dmlTypedOpDesc, const void* pOpDesc)
+	{
+		dmlTypedOpDesc.resize(sizeof(DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC));
+		const auto pDMLDesc = reinterpret_cast(dmlTypedOpDesc.data());
+		const auto& desc = *static_cast(pOpDesc);
+
+		pDMLDesc->InputTensor = desc.pInput ? static_cast(desc.pInput->GetHandle()) : nullptr;
+		pDMLDesc->OutputTensor = desc.pOutput ? static_cast(desc.pOutput->GetHandle()) : nullptr;
+		pDMLDesc->CrossChannel = desc.CrossChannel;
+		pDMLDesc->LocalSize = desc.LocalSize;
+		pDMLDesc->Alpha = desc.Alpha;
+		pDMLDesc->Beta = desc.Beta;
+		pDMLDesc->Bias = desc.Bias;
+	};
+
+	// Fills dmlTypedOpDesc with a DML_LP_NORMALIZATION_OPERATOR_DESC built from the description at pOpDesc.
+	static const auto getDMLLPNormalization = [](vector& dmlTypedOpDesc, const void* pOpDesc)
+	{
+		dmlTypedOpDesc.resize(sizeof(DML_LP_NORMALIZATION_OPERATOR_DESC));
+		const auto pDMLDesc = reinterpret_cast(dmlTypedOpDesc.data());
+		const auto& desc = *static_cast(pOpDesc);
+
+		pDMLDesc->InputTensor = desc.pInput ? static_cast(desc.pInput->GetHandle()) : nullptr;
+		pDMLDesc->OutputTensor = desc.pOutput ? static_cast(desc.pOutput->GetHandle()) : nullptr;
+		pDMLDesc->Axis = desc.Axis;
+		pDMLDesc->Epsilon = desc.Epsilon;
+		pDMLDesc->P = desc.P;
+	};
+
+
 	static const function&, const void*)> pfnGetDMLOps[] =
 	{
 		nullptr, // INVALID
@@ -768,6 +872,11 @@ void ML::GetDMLTypedOperator(vector& dmlTypedOpDesc, OperatorType type,
 		getDMLSpaceDepth, // DEPTH_TO_SPACE
 		getDMLTile, // TILE
 		getDMLTopK, // TOP_K
+
+		getDMLBatchNormalization, // BATCH_NORMALIZATION
+		getDMLMeanVarianceNormalization, // MEAN_VARIANCE_NORMALIZATION
+		getDMLLocalResponseNormalization, // LOCAL_RESPONSE_NORMALIZATION
+		getDMLLPNormalization, // LP_NORMALIZATION
 	};
 
 	pfnGetDMLOps[static_cast(type)](dmlTypedOpDesc, pOpDesc);
diff --git a/XUSGMachineLearning/MachineLearning/XUSGMachineLearning_DML.h b/XUSGMachineLearning/MachineLearning/XUSGMachineLearning_DML.h
index a6cb979..e456eb1 100644
--- a/XUSGMachineLearning/MachineLearning/XUSGMachineLearning_DML.h
+++ b/XUSGMachineLearning/MachineLearning/XUSGMachineLearning_DML.h
@@ -47,11 +47,6 @@ namespace XUSG
 			com_ptr m_device;
 		};
 
-		using BatchNormalization = DML_BATCH_NORMALIZATION_OPERATOR_DESC;
-		using MeanVarianceNormalization = DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC;
-		using LocalResponseNormalization = DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC;
-		using LPNormalization = DML_LP_NORMALIZATION_OPERATOR_DESC;
-
 		using RNNOperator = DML_RNN_OPERATOR_DESC;
 		using LSTMOperator = DML_LSTM_OPERATOR_DESC;
 		using GRUOperator = DML_GRU_OPERATOR_DESC;