Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
feifei-111 committed Oct 10, 2023
1 parent 719cbc6 commit b0a6b10
Show file tree
Hide file tree
Showing 12 changed files with 26 additions and 51 deletions.
3 changes: 1 addition & 2 deletions sot/opcode_translator/executor/executor_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
import types
from typing import List, Tuple

from ...profiler import EventGuard, event_register
from ...psdb import NO_FALLBACK_CODES
from ...utils import (
BreakGraphError,
EventGuard,
FallbackError,
InnerError,
Singleton,
event_register,
is_strict_mode,
log,
log_do,
Expand Down
33 changes: 12 additions & 21 deletions sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@
from typing import Any, Callable

from ...infer_meta import InferMetaCache, LayerInferMetaCache, MetaInfo
from ...profiler import EventGuard, event_register
from ...symbolic.statement_ir import Symbol
from ...symbolic.symbolic_context import SymbolicTraceContext
from ...utils import (
EventGuard,
NameGenerator,
OrderedSet,
event_register,
inner_error_default_handler,
is_inplace_api,
is_paddle_api,
Expand Down Expand Up @@ -162,19 +161,16 @@ def save_memo(self) -> FunctionGraph.Memo:
NOTE:
Why don't use __deepcopy__, because memo is not a deepcopy, i.e inner_out is only a shallow copy, SIR is a deepcopy.
"""
with EventGuard(
f"Save SIR Checkpoint: len({len(self.sir_ctx.TOS)})", event_level=2
):
saved_stmt_ir = deepcopy(self.sir_ctx.TOS)
return FunctionGraph.Memo(
inner_out=set(self.inner_out),
input_variables=list(self.input_variables),
stmt_ir=saved_stmt_ir,
global_guards=OrderedSet(self._global_guarded_variables),
side_effects_state=self.side_effects.get_state(),
print_variables=list(self._print_variables),
inplace_tensors=OrderedSet(self._inplace_tensors),
)
saved_stmt_ir = deepcopy(self.sir_ctx.TOS)
return FunctionGraph.Memo(
inner_out=set(self.inner_out),
input_variables=list(self.input_variables),
stmt_ir=saved_stmt_ir,
global_guards=OrderedSet(self._global_guarded_variables),
side_effects_state=self.side_effects.get_state(),
print_variables=list(self._print_variables),
inplace_tensors=OrderedSet(self._inplace_tensors),
)

def restore_memo(self, memo: FunctionGraph.Memo):
"""
Expand Down Expand Up @@ -333,7 +329,6 @@ def start_compile(self, *ret_vars: VariableBase):

view_tracker(list(ret_vars), tracker_output_path, format="png")

@event_register("call_paddle_api", event_level=2)
def call_paddle_api(
self,
func: Callable[..., Any],
Expand All @@ -359,7 +354,6 @@ def message_handler(*args, **kwargs):
InferMetaCache(), self.sir_ctx.call_API, func, *args, **kwargs
)

@event_register("call_tensor_method", event_level=2)
def call_tensor_method(
self, method_name: str, *args: VariableBase, **kwargs
):
Expand Down Expand Up @@ -411,7 +405,6 @@ def get_opcode_executor_stack():
stack.append(f' {code_line}')
return stack

@event_register("call_layer", event_level=2)
def call_layer(
self,
layer: PaddleLayerVariable,
Expand Down Expand Up @@ -444,7 +437,6 @@ def message_handler(*args, **kwargs):
infer_meta_fn, compute_fn, layer, *args, **kwargs
)

@event_register("symbolic_call", event_level=2)
def symbolic_call(self, infer_meta_fn, compute_fn, func, *args, **kwargs):
"""
Using infer_meta_fn and compute_fn convert func to symbolic function.
Expand All @@ -459,8 +451,7 @@ def symbolic_call(self, infer_meta_fn, compute_fn, func, *args, **kwargs):
metas = convert_to_meta(args)
kwmetas = convert_to_meta(kwargs)

with EventGuard("infer_meta"):
out_metas = infer_meta_fn(func, *metas, **kwmetas)
out_metas = infer_meta_fn(func, *metas, **kwmetas)
inputs_symbols = (
convert_to_symbol(args),
convert_to_symbol(kwargs),
Expand Down
11 changes: 3 additions & 8 deletions sot/opcode_translator/executor/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,8 @@
import weakref
from typing import TYPE_CHECKING, Any, Callable, TypeVar

from ...utils import (
EventGuard,
InnerError,
current_tmp_name_records,
log,
log_do,
)
from ...profiler import EventGuard
from ...utils import InnerError, current_tmp_name_records, log, log_do

Guard = Callable[[types.FrameType], bool]

Expand Down Expand Up @@ -71,7 +66,7 @@ def make_guard(stringify_guards: list[StringifyExpression]) -> Guard:
Args:
stringify_guards: a list of StringifyExpression.
"""
with EventGuard(f"make_guard: ({len(stringify_guards)})"):
with EventGuard("make_guard"):
num_guards = len(stringify_guards)
if not num_guards:
guard = lambda frame: True
Expand Down
7 changes: 1 addition & 6 deletions sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@

import opcode

from ...profiler import EventGuard, event_register
from ...psdb import NO_BREAKGRAPH_CODES
from ...utils import (
BreakGraphError,
EventGuard,
FallbackError,
InnerError,
OrderedSet,
SotUndefinedVar,
event_register,
log,
log_do,
min_graph_size,
Expand Down Expand Up @@ -1509,7 +1508,6 @@ def _create_resume_fn(self, index, stack_size=0):
fn, inputs = pycode_gen.gen_resume_fn_at(index, stack_size)
return fn, inputs

@event_register("_break_graph_in_jump")
@fallback_when_occur_error
def _break_graph_in_jump(self, result: VariableBase, instr: Instruction):
"""
Expand Down Expand Up @@ -1596,7 +1594,6 @@ def _break_graph_in_jump(self, result: VariableBase, instr: Instruction):
self.new_code = self._graph.pycode_gen.gen_pycode()
self.guard_fn = self._graph.guard_fn

@event_register("_break_graph_in_call")
@fallback_when_occur_error
def _break_graph_in_call(
self,
Expand Down Expand Up @@ -1766,7 +1763,6 @@ def _gen_loop_body_between(
pycode_gen.gen_outputs_and_return(inputs)
return pycode_gen.create_fn_with_inputs(inputs)

@event_register("_break_graph_in_for_loop")
@fallback_when_occur_error
def _break_graph_in_for_loop(
self, iterator: VariableBase, for_iter: Instruction
Expand Down Expand Up @@ -1942,7 +1938,6 @@ def _break_graph_in_for_loop(
self.new_code = self._graph.pycode_gen.gen_pycode()
self.guard_fn = self._graph.guard_fn

@event_register("_inline_call_for_loop")
def _inline_call_for_loop(
self, iterator: VariableBase, for_iter: Instruction
):
Expand Down
3 changes: 2 additions & 1 deletion sot/opcode_translator/executor/opcode_inline_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import re
from typing import TYPE_CHECKING

from ...utils import BreakGraphError, event_register, log
from ...profiler import event_register
from ...utils import BreakGraphError, log
from ..instruction_utils import Instruction
from .guard import StringifyExpression, union_free_vars
from .opcode_executor import OpcodeExecutorBase, Stop
Expand Down
3 changes: 2 additions & 1 deletion sot/opcode_translator/executor/variables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

import paddle

from ....utils import NameGenerator, event_register, get_unbound_method, log
from ....profiler import event_register
from ....utils import NameGenerator, get_unbound_method, log
from ....utils.exceptions import FallbackError, HasNoAttributeError
from ..dispatcher import Dispatcher
from ..guard import StringifyExpression, check_guard, union_free_vars
Expand Down
2 changes: 1 addition & 1 deletion sot/opcode_translator/executor/variables/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import paddle

from .... import psdb
from ....profiler import EventGuard
from ....utils import (
EventGuard,
is_break_graph_api,
is_break_graph_tensor_methods,
is_builtin_fn,
Expand Down
3 changes: 2 additions & 1 deletion sot/opcode_translator/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import dis
from functools import partial

from ..utils import CodeStatus, EventGuard, log, log_do
from ..profiler import EventGuard
from ..utils import CodeStatus, log, log_do
from .custom_code import CustomCode
from .executor.executor_cache import OpcodeExecutorCache
from .skip_files import need_skip
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion sot/symbolic/compile_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

import paddle

from ..profiler import EventGuard
from ..utils import (
Cache,
CodeStatus,
EventGuard,
GraphLogger,
Singleton,
StepInfoManager,
Expand Down
6 changes: 1 addition & 5 deletions sot/symbolic/symbolic_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from ..utils import event_register, log
from ..utils import log
from .compile_cache import CompileSIRCache
from .statement_ir import (
ApiStatement,
Expand Down Expand Up @@ -43,7 +43,6 @@ def TOS(self):

return self.sir_stack[-1]

@event_register("call_SIR", event_level=2)
def call_SIR(self, sirname, inputs, outputs, stacks):
"""
Call a SIR, which is a subgraph.
Expand All @@ -52,7 +51,6 @@ def call_SIR(self, sirname, inputs, outputs, stacks):
stmt = CallStatement(sirname, inputs, outputs, stacks)
self.TOS.add_statement(stmt)

@event_register("call_API", event_level=2)
def call_API(self, api, inputs, outputs, stacks):
"""
Call a paddle api.
Expand All @@ -62,7 +60,6 @@ def call_API(self, api, inputs, outputs, stacks):
stmt = ApiStatement(api, inputs, outputs, stacks)
self.TOS.add_statement(stmt)

@event_register("call_METHOD", event_level=2)
def call_METHOD(self, method_name, inputs, outputs, stacks):
"""
Call a method of a api. The API here can be python or Paddle
Expand All @@ -76,7 +73,6 @@ def call_METHOD(self, method_name, inputs, outputs, stacks):
stmt = MethodStatement(method_name, inputs, outputs, stacks)
self.TOS.add_statement(stmt)

@event_register("call_LAYER", event_level=2)
def call_LAYER(self, layer, inputs, outputs, stacks):
"""
Call a layer of a api.
Expand Down
4 changes: 0 additions & 4 deletions sot/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
is_inplace_api,
paddle_tensor_methods,
)
from .SotProfiler import EventGuard, SotProfiler, event_register
from .utils import (
Cache,
GraphLogger,
Expand Down Expand Up @@ -82,9 +81,6 @@
"get_unbound_method",
"GraphLogger",
"SotUndefinedVar",
"event_register",
"EventGuard",
"SotProfiler",
"hashable",
"is_inplace_api",
"sotprof_range",
Expand Down

0 comments on commit b0a6b10

Please sign in to comment.