From 1f36b92c879c5f0193e9a2a4efcbb9a58b74d391 Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Fri, 30 Jun 2023 13:39:57 +0800 Subject: [PATCH] upstream static --- static/csrc/model_container.cpp | 6 +++--- static/include/cuda_device_functions.h | 2 +- static/include/debug_utility.h | 1 - static/include/macros.h | 2 +- static/include/model.h | 29 ++++++++++++++++++++++---- static/include/rocm_device_functions.h | 15 ++++++------- 6 files changed, 38 insertions(+), 17 deletions(-) diff --git a/static/csrc/model_container.cpp b/static/csrc/model_container.cpp index 5548a97f0..085ad0b58 100644 --- a/static/csrc/model_container.cpp +++ b/static/csrc/model_container.cpp @@ -403,9 +403,9 @@ void ModelContainer::SetConstantImpl( ". Check that the provided tensor's shape is correct."); } } else { - throw std::runtime_error( - std::string("Called SetConstant on ") + name + - std::string(" but can't find in either bound or unbound constant set")); + LOG(WARNING) << "Called SetConstant on " << name + << " but can't find in either bound or unbound constant set"; + return; } auto* src = tensor.ptr; diff --git a/static/include/cuda_device_functions.h b/static/include/cuda_device_functions.h index b03e25f78..4d2c3f463 100644 --- a/static/include/cuda_device_functions.h +++ b/static/include/cuda_device_functions.h @@ -403,7 +403,7 @@ inline DeviceError QueryEvent(EventType event) { return cudaEventQuery(event); } -inline const char* GetErrorString(DeviceError err) { +inline std::string GetErrorString(DeviceError err) { return cudaGetErrorString(err); } diff --git a/static/include/debug_utility.h b/static/include/debug_utility.h index 332f07890..d5f7ce65c 100644 --- a/static/include/debug_utility.h +++ b/static/include/debug_utility.h @@ -11,7 +11,6 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
- #pragma once #include "device_functions-generated.h" diff --git a/static/include/macros.h b/static/include/macros.h index 59fcde94b..462b25255 100644 --- a/static/include/macros.h +++ b/static/include/macros.h @@ -22,7 +22,7 @@ #define DEVICE_CHECK(call) \ if ((call) != GetDeviceSuccess()) { \ throw std::runtime_error( \ - #call " API call failed: " + GetLastErrorString() + " at " + \ + #call " API call failed: " + GetErrorString(call) + " at " + \ __FILE__ + ", line" + std::to_string(__LINE__)); \ } diff --git a/static/include/model.h b/static/include/model.h index 759fd95c7..3963724df 100644 --- a/static/include/model.h +++ b/static/include/model.h @@ -22,8 +22,7 @@ namespace ait { inline void DeviceCheckLastError(const char* file, int line) { auto device_error = GetLastError(); if (device_error != GetDeviceSuccess()) { - std::string msg = std::string("Got error: ") + - cudaGetErrorString(device_error) + + std::string msg = std::string("Got error: ") + GetErrorString(device_error) + " enum: " + std::to_string(device_error) + " at " + file + ": " + std::to_string(line); LOG(ERROR) << msg; @@ -217,6 +216,29 @@ class ModelBase { } void RunAsGraph(StreamType stream) { +#ifdef __HIP_PLATFORM_HCC__ + if (graph_exec_ == nullptr) { + DEVICE_CHECK(StreamBeginCapture(graph_capture_stream_, /*global=*/false)); + try { + static_cast<ModelType*>(this)->RunImpl(graph_capture_stream_); + } catch (...) { + GraphType graph; + // No need to DEVICE_CHECK here, we want to see the original exception. + EndCapture(&graph); + if (graph != nullptr && GraphDestroy(graph) != GetDeviceSuccess()) { + LOG(WARNING) + << "Graph destruction failed while handling exception! Memory will be leaked."; + } + throw; + } + // The following function ends the capture and creates a graph + // inside a unique_ptr that cleans up it when it goes out of scope. + // Note that it throws an exception if EndCapture fails. 
+ auto graph = RAII_EndCaptureAndCreateGraph( + [this](GraphType* graph_ptr) { return EndCapture(graph_ptr); }); + DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get())); + } +#else DEVICE_CHECK(StreamBeginCapture(graph_capture_stream_, /*global=*/false)); try { static_cast<ModelType*>(this)->RunImpl(graph_capture_stream_); @@ -230,13 +252,11 @@ class ModelBase { } throw; } - // The following function ends the capture and creates a graph // inside a unique_ptr that cleans up it when it goes out of scope. // Note that it throws an exception if EndCapture fails. auto graph = RAII_EndCaptureAndCreateGraph( [this](GraphType* graph_ptr) { return EndCapture(graph_ptr); }); - if (graph_exec_ == nullptr) { DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get())); } else if ( @@ -247,6 +267,7 @@ class ModelBase { DEVICE_CHECK(GraphExecDestroy(graph_exec_)); DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get())); } +#endif DEVICE_CHECK(GraphExecLaunch(graph_exec_, stream)); } diff --git a/static/include/rocm_device_functions.h b/static/include/rocm_device_functions.h index 8fc7adf3c..18d3aa297 100644 --- a/static/include/rocm_device_functions.h +++ b/static/include/rocm_device_functions.h @@ -28,7 +28,7 @@ namespace ait { -inline thread_local bool target_has_graph_mode = false; +inline thread_local bool target_has_graph_mode = true; using DeviceError = hipError_t; using DevicePropertyType = hipDeviceProp_t; @@ -57,7 +57,7 @@ inline std::string PrintArchFeatureFlags(const hipDeviceArch_t& arch) { << "\n Has 32-bit integer atomics for shared memory: " << (arch.hasSharedInt32Atomics ? "yes" : "no") << "\n Has 32-bit float atomic exch for shared memory: " - << (arch.hasSharedFloatAtomicExch ? "yes" : "no" + << (arch.hasSharedFloatAtomicExch ? "yes" : "no") << "\n Has 32-bit float atomic add in global and shared memory: " << (arch.hasFloatAtomicAdd ? 
"yes" : "no") << "\n Has 64-bit integer atomics for global memory: " @@ -67,9 +67,9 @@ inline std::string PrintArchFeatureFlags(const hipDeviceArch_t& arch) { << "\n Has double-precision floating point: " << (arch.hasDoubles ? "yes" : "no") << "\n Has warp vote instructions (__any, __all): " - << (arch.hasWarpVote: ? "yes" : "no") + << (arch.hasWarpVote ? "yes" : "no") << "\n Has warp ballot instructions (__ballot): " - << (arch.hasWarpBallot: ? "yes" : "no") + << (arch.hasWarpBallot ? "yes" : "no") << "\n Has warp shuffle operations. (__shfl_*): " << (arch.hasWarpShuffle ? "yes" : "no") << "\n Has funnel two words into one with shift&mask caps: " @@ -187,7 +187,7 @@ inline DeviceError StreamDestroy(StreamType stream) { } inline DeviceError StreamWaitEvent(StreamType stream, EventType event) { - return hipStreamWaitEvent(stream, event); + return hipStreamWaitEvent(stream, event, 0); } inline DeviceError GraphInstantiate( @@ -202,7 +202,8 @@ inline DeviceError GraphDestroy(GraphType graph) { inline DeviceError GraphExecUpdate(GraphExecType graph_exec, GraphType graph) { // We don't have hipGraphExecUpdate in some versions of rocm - return hipErrorUnknown; + hipGraphExecUpdateResult update; + return hipGraphExecUpdate(graph_exec, graph, nullptr, &update); } inline DeviceError GraphExecDestroy(GraphExecType graph_exec) { @@ -314,7 +315,7 @@ inline DeviceError QueryEvent(EventType event) { return hipEventQuery(event); } -inline const char* GetErrorString(DeviceError err) { +inline std::string GetErrorString(DeviceError err) { return hipGetErrorString(err); }