From 0f52340eee8edb7b6ccfa53ad10275d9c67ace43 Mon Sep 17 00:00:00 2001 From: jackzipu <74961298+jackzipu@users.noreply.github.com> Date: Tue, 1 Mar 2022 18:15:00 +0800 Subject: [PATCH] Move computation init to odla_BindToArgument to avoid problem caused by ASYNC call (#840) * callback function definition changes, and handle the error raised by poplar SDK * Move computation init to odla_BindToArgument to avoid early call odla_CreteContext before compuation graph constructed in ASYN call --- ODLA/platforms/odla_popart/odla_compute.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ODLA/platforms/odla_popart/odla_compute.cc b/ODLA/platforms/odla_popart/odla_compute.cc index a0da576e0..bc8fbacce 100644 --- a/ODLA/platforms/odla_popart/odla_compute.cc +++ b/ODLA/platforms/odla_popart/odla_compute.cc @@ -197,14 +197,6 @@ odla_status odla_CreateComputation(odla_computation* comp) { } odla_status odla_CreateContext(odla_context* context) { - odla_status status = - _odla_computation::instance(false) - ->init(); // Place the init here to avoid long execution problem - if (status != ODLA_SUCCESS && - _odla_computation::instance()->session == nullptr) { - popart::logging::err("init computation item in CreateContext failed."); - return status; - } if (PopartConfig::instance()->execution_mode() == PIPELINE_ASYNC) *context = new _odla_pipeline_async_context(_odla_computation::instance()); else @@ -299,6 +291,14 @@ odla_value odla_CreateConstant(odla_value_type type, const void* data_ptr, odla_status odla_BindToArgument(odla_value value, const odla_void* data_ptr, odla_context context) { + odla_status status = + _odla_computation::instance(false) + ->init(); // to avoid long execution and async context problem + if (status != ODLA_SUCCESS && + _odla_computation::instance()->session == nullptr) { + popart::logging::err("init computation item in CreateContext failed."); + return status; + } if (!context->hold("odla_BindToArgument")) return ODLA_FAILURE; std::vector shape = context->comp->builder->getTensorShape(value->tensor_id);