From 8fb2f176aac2d9c70adbee7eb9c1569361e32299 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Wed, 11 Oct 2023 11:33:28 +0000 Subject: [PATCH] fix stop-gradient not put in SIR problem --- .../executor/function_graph.py | 47 ++++++++++--------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/sot/opcode_translator/executor/function_graph.py b/sot/opcode_translator/executor/function_graph.py index 1bbfa969..cda7302f 100644 --- a/sot/opcode_translator/executor/function_graph.py +++ b/sot/opcode_translator/executor/function_graph.py @@ -475,30 +475,35 @@ def symbolic_call(self, infer_meta_fn, compute_fn, func, *args, **kwargs): FunctionGraph.get_opcode_executor_stack() ), ) - if outputs is not None: - if is_inplace_api(func): - # if we want to use a non-inplace api (static api) to replace an inplace behavior (in simulation) - # just set it back in SIR, and return outputs to replace tensor meta (it might changes?) - # in this case, the output will not exactly be used - compute_fn( - func, - inputs_symbols, - convert_to_symbol(args[0]), - stmt_stacks, - ) - else: - compute_fn( - func, - inputs_symbols, - convert_to_symbol(outputs), - stmt_stacks, - ) # symbolic only contain symbols. - self._put_inner(outputs) + if is_inplace_api(func): + # if we want to use a non-inplace api (static api) to replace an inplace behavior (in simulation) + # just set it back in SIR, and return outputs to replace tensor meta (it might changes?) + # in this case, the output will not exactly be used + compute_fn( + func, + inputs_symbols, + convert_to_symbol(args[0]), + stmt_stacks, + ) + elif outputs is not None: + compute_fn( + func, + inputs_symbols, + convert_to_symbol(outputs), + stmt_stacks, + ) # symbolic only contain symbols. + self._put_inner(outputs) return VariableFactory.from_value( outputs, self, DummyTracker(list(args) + list(kwargs.values())) ) - else: - return None + elif outputs is None: + # tensor.stop_gradient=True + compute_fn( + func, + inputs_symbols, + None, + stmt_stacks, + ) def _put_inner(self, vars: VariableBase): """