Skip to content

Commit

Permalink
upstream static
Browse files Browse the repository at this point in the history
  • Loading branch information
fsx950223 committed Jun 30, 2023
1 parent 039bb9f commit 1f36b92
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 17 deletions.
6 changes: 3 additions & 3 deletions static/csrc/model_container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,9 +403,9 @@ void ModelContainer::SetConstantImpl(
". Check that the provided tensor's shape is correct.");
}
} else {
throw std::runtime_error(
std::string("Called SetConstant on ") + name +
std::string(" but can't find in either bound or unbound constant set"));
LOG(WARNING) << "Called SetConstant on " << name
<< " but can't find in either bound or unbound constant set";
return;
}

auto* src = tensor.ptr;
Expand Down
2 changes: 1 addition & 1 deletion static/include/cuda_device_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ inline DeviceError QueryEvent(EventType event) {
return cudaEventQuery(event);
}

inline const char* GetErrorString(DeviceError err) {
inline std::string GetErrorString(DeviceError err) {
return cudaGetErrorString(err);
}

Expand Down
1 change: 0 additions & 1 deletion static/include/debug_utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "device_functions-generated.h"

Expand Down
2 changes: 1 addition & 1 deletion static/include/macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#define DEVICE_CHECK(call) \
if ((call) != GetDeviceSuccess()) { \
throw std::runtime_error( \
#call " API call failed: " + GetLastErrorString() + " at " + \
#call " API call failed: " + GetErrorString(call) + " at " + \
__FILE__ + ", line" + std::to_string(__LINE__)); \
}

Expand Down
29 changes: 25 additions & 4 deletions static/include/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ namespace ait {
inline void DeviceCheckLastError(const char* file, int line) {
auto device_error = GetLastError();
if (device_error != GetDeviceSuccess()) {
std::string msg = std::string("Got error: ") +
cudaGetErrorString(device_error) +
std::string msg = std::string("Got error: ") + GetErrorString(device_error) +
" enum: " + std::to_string(device_error) + " at " + file + ": " +
std::to_string(line);
LOG(ERROR) << msg;
Expand Down Expand Up @@ -217,6 +216,29 @@ class ModelBase {
}

void RunAsGraph(StreamType stream) {
#ifdef __HIP_PLATFORM_HCC__
if (graph_exec_ == nullptr) {
DEVICE_CHECK(StreamBeginCapture(graph_capture_stream_, /*global=*/false));
try {
static_cast<ModelType*>(this)->RunImpl(graph_capture_stream_);
} catch (...) {
GraphType graph;
// No need to DEVICE_CHECK here, we want to see the original exception.
EndCapture(&graph);
if (graph != nullptr && GraphDestroy(graph) != GetDeviceSuccess()) {
LOG(WARNING)
<< "Graph destruction failed while handling exception! Memory will be leaked.";
}
throw;
}
// The following function ends the capture and creates a graph
// inside a unique_ptr that cleans up it when it goes out of scope.
// Note that it throws an exception if EndCapture fails.
auto graph = RAII_EndCaptureAndCreateGraph(
[this](GraphType* graph_ptr) { return EndCapture(graph_ptr); });
DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get()));
}
#else
DEVICE_CHECK(StreamBeginCapture(graph_capture_stream_, /*global=*/false));
try {
static_cast<ModelType*>(this)->RunImpl(graph_capture_stream_);
Expand All @@ -230,13 +252,11 @@ class ModelBase {
}
throw;
}

// The following function ends the capture and creates a graph
// inside a unique_ptr that cleans up it when it goes out of scope.
// Note that it throws an exception if EndCapture fails.
auto graph = RAII_EndCaptureAndCreateGraph(
[this](GraphType* graph_ptr) { return EndCapture(graph_ptr); });

if (graph_exec_ == nullptr) {
DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get()));
} else if (
Expand All @@ -247,6 +267,7 @@ class ModelBase {
DEVICE_CHECK(GraphExecDestroy(graph_exec_));
DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get()));
}
#endif

DEVICE_CHECK(GraphExecLaunch(graph_exec_, stream));
}
Expand Down
15 changes: 8 additions & 7 deletions static/include/rocm_device_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

namespace ait {

inline thread_local bool target_has_graph_mode = false;
inline thread_local bool target_has_graph_mode = true;

using DeviceError = hipError_t;
using DevicePropertyType = hipDeviceProp_t;
Expand Down Expand Up @@ -57,7 +57,7 @@ inline std::string PrintArchFeatureFlags(const hipDeviceArch_t& arch) {
<< "\n Has 32-bit integer atomics for shared memory: "
<< (arch.hasSharedInt32Atomics ? "yes" : "no")
<< "\n Has 32-bit float atomic exch for shared memory: "
<< (arch.hasSharedFloatAtomicExch ? "yes" : "no"
<< (arch.hasSharedFloatAtomicExch ? "yes" : "no")
<< "\n Has 32-bit float atomic add in global and shared memory: "
<< (arch.hasFloatAtomicAdd ? "yes" : "no")
<< "\n Has 64-bit integer atomics for global memory: "
Expand All @@ -67,9 +67,9 @@ inline std::string PrintArchFeatureFlags(const hipDeviceArch_t& arch) {
<< "\n Has double-precision floating point: "
<< (arch.hasDoubles ? "yes" : "no")
<< "\n Has warp vote instructions (__any, __all): "
<< (arch.hasWarpVote: ? "yes" : "no")
<< (arch.hasWarpVote ? "yes" : "no")
<< "\n Has warp ballot instructions (__ballot): "
<< (arch.hasWarpBallot: ? "yes" : "no")
<< (arch.hasWarpBallot ? "yes" : "no")
<< "\n Has warp shuffle operations. (__shfl_*): "
<< (arch.hasWarpShuffle ? "yes" : "no")
<< "\n Has funnel two words into one with shift&mask caps: "
Expand Down Expand Up @@ -187,7 +187,7 @@ inline DeviceError StreamDestroy(StreamType stream) {
}

inline DeviceError StreamWaitEvent(StreamType stream, EventType event) {
return hipStreamWaitEvent(stream, event);
return hipStreamWaitEvent(stream, event, 0);
}

inline DeviceError GraphInstantiate(
Expand All @@ -202,7 +202,8 @@ inline DeviceError GraphDestroy(GraphType graph) {

inline DeviceError GraphExecUpdate(GraphExecType graph_exec, GraphType graph) {
// We don't have hipGraphExecUpdate in some versions of rocm
return hipErrorUnknown;
hipGraphExecUpdateResult update;
return hipGraphExecUpdate(graph_exec, graph, nullptr, &update);
}

inline DeviceError GraphExecDestroy(GraphExecType graph_exec) {
Expand Down Expand Up @@ -314,7 +315,7 @@ inline DeviceError QueryEvent(EventType event) {
return hipEventQuery(event);
}

inline const char* GetErrorString(DeviceError err) {
inline std::string GetErrorString(DeviceError err) {
return hipGetErrorString(err);
}

Expand Down

0 comments on commit 1f36b92

Please sign in to comment.