-
Notifications
You must be signed in to change notification settings - Fork 16
Sarkars/flib ngenc compute #529
base: master
Are you sure you want to change the base?
Conversation
@@ -34,7 +34,7 @@ int NGraphClusterManager::NewCluster() { | |||
|
|||
GraphDef* NGraphClusterManager::GetClusterGraph(int idx) { | |||
std::lock_guard<std::mutex> guard(s_cluster_graphs_mutex); | |||
return s_cluster_graphs[idx]; | |||
return idx < s_cluster_graphs.size() ? s_cluster_graphs[idx] : nullptr; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need this change since now we might have nothing stored in NGraphClusterManager
, so need to safely avoid accessing out-of-bounds requests
@@ -18,6 +18,7 @@ | |||
#include <utility> | |||
|
|||
#include "tensorflow/core/common_runtime/dma_helper.h" | |||
#include "tensorflow/core/common_runtime/function.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Included for FunctionBody
} | ||
} | ||
library { | ||
function { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The library of graphdefs
} | ||
node { | ||
name: "Sigmoid" | ||
op: "IdentityN" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IdentityN
due to Kanvi's change
# Comparing with expected value | ||
assert np.isclose(res1, exp).all() | ||
|
||
@pytest.mark.skip(reason="Not passing through grappler") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@kanvi-nervana : I think this is a minimal repro of your squeeze-net issue
test/python/test_flib.py
Outdated
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import pytest, pdb |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove pdb
test/python/common.py
Outdated
@@ -50,7 +50,7 @@ def with_ngraph(self, l, config=tf.ConfigProto()): | |||
|
|||
os.environ['NGRAPH_TF_DISABLE_DEASSIGN_CLUSTERS'] = '1' | |||
ngraph_bridge.enable() | |||
with tf.Session(config=config) as sess: | |||
with tf.Session(graph=graph, config=config) as sess: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need this change? Just Curious.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
when we write python code like a = tf.constant(...); b = tf.constant(...); c = a+b
, this underlying graph is added to the default graph. When we read it from a pbtxt (like in this case), the graph is not added to default graph. So was passing the graph along to common.py
to use during session construction.
But found a way to set the graph read from pbtxt as default graph (as_default
)
const FunctionLibraryDefinition flib = | ||
*ctx->function_library()->GetFunctionLibraryDefinition(); | ||
const FunctionDef* fdef = | ||
flib.Find("Enc_" + to_string(m_ngraph_cluster) + "_native_segment"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we use the node name (eg. ngraph_cluster_251) instead of this "Enc_" + to_string(m_ngraph_cluster)" ?
…naSystems/ngraph-tf into sarkars/flib_ngenc_compute
@@ -64,7 +64,8 @@ Status NgraphOptimizer::Optimize(tensorflow::grappler::Cluster* cluster, | |||
// we will not do anything; all subsequent | |||
// passes become a no-op. | |||
if (config::IsEnabled() == false || | |||
std::getenv("NGRAPH_TF_DISABLE") != nullptr) { | |||
std::getenv("NGRAPH_TF_DISABLE") != nullptr || | |||
IsProcessedByNgraphPass(&graph)) { | |||
NGRAPH_VLOG(0) << "NGTF_OPTIMIZER: Ngraph is disabled "; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Include the or condition in the error message.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Requested minor change.
https://github.com/NervanaSystems/ngraph-tf/pull/529/files#diff-4d55399542be3550bdfca644b849afa2R110
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
No description provided.