From 1732581b56eb4e1d2d0f3b2deb3bb8a5ae43d59c Mon Sep 17 00:00:00 2001 From: "augusto.yjh" Date: Tue, 6 Aug 2024 16:46:03 +0800 Subject: [PATCH] add function ensureCollectTraceDone to wait and cleanup collectTraceThread --- libkineto/src/CuptiActivityProfiler.cpp | 13 ++++++++----- libkineto/src/CuptiActivityProfiler.h | 2 ++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/libkineto/src/CuptiActivityProfiler.cpp b/libkineto/src/CuptiActivityProfiler.cpp index 36682ae04..9ef6b91b3 100644 --- a/libkineto/src/CuptiActivityProfiler.cpp +++ b/libkineto/src/CuptiActivityProfiler.cpp @@ -1096,6 +1096,13 @@ void CuptiActivityProfiler::collectTrace(bool collection_done, UST_LOGGER_MARK_COMPLETED(kCollectionStage); } +void CuptiActivityProfiler::ensureCollectTraceDone() { + if (collectTraceThread && collectTraceThread->joinable()) { + std::lock_guard guard(mutex_); + collectTraceThread->join(); + collectTraceThread.reset(nullptr); + } +} void CuptiActivityProfiler::startTraceInternal( const time_point& now) { captureWindowStartTime_ = libkineto::timeSinceEpoch(now); @@ -1251,11 +1258,7 @@ const time_point CuptiActivityProfiler::performRunLoopStep( } // Before processing, we should wait for collectTrace thread to be done. - if (collectTraceThread && collectTraceThread->joinable()) { - std::lock_guard guard(mutex_); - collectTraceThread->join(); - collectTraceThread.reset(nullptr); - } + ensureCollectTraceDone(); // FIXME: Probably want to allow interruption here // for quickly handling trace request via synchronous API diff --git a/libkineto/src/CuptiActivityProfiler.h b/libkineto/src/CuptiActivityProfiler.h index 0b309e0e8..b865e768c 100644 --- a/libkineto/src/CuptiActivityProfiler.h +++ b/libkineto/src/CuptiActivityProfiler.h @@ -167,6 +167,8 @@ class CuptiActivityProfiler { // Collect CPU and GPU traces void collectTrace(bool collectionDone, const std::chrono::time_point& now ); + // Ensure collectTrace is done + void ensureCollectTraceDone(); // Process CPU and GPU traces void processTrace(ActivityLogger& logger) { std::lock_guard guard(mutex_);