diff --git a/sot/opcode_translator/executor/executor_cache.py b/sot/opcode_translator/executor/executor_cache.py index 79c66b78..12220ab6 100644 --- a/sot/opcode_translator/executor/executor_cache.py +++ b/sot/opcode_translator/executor/executor_cache.py @@ -2,8 +2,11 @@ import traceback import types -from typing import List, Tuple +from typing import Any, List, Tuple +import paddle + +from ...infer_meta import MetaInfo from ...profiler import EventGuard, event_register from ...psdb import NO_FALLBACK_CODES from ...utils import ( @@ -27,6 +30,8 @@ dummy_guard.expr = "lambda frame: True" dummy_guard.lambda_expr = "lambda frame: True" +ConstTypes = (int, float, str, bool, type(None)) + @Singleton class OpcodeExecutorCache: @@ -39,9 +44,15 @@ class OpcodeExecutorCache: translate_count (int): The count of how many instructions have been translated. It is used to test whether the cache hits. """ - MAX_CACHE_SIZE = 20 - cache: dict[types.CodeType, GuardedFunctions] + class _PlaceHolder: + def __str__(self): + return "PlaceHolder" + + MAX_CACHE_SIZE = 10 + MAX_BUCKET_SIZE = 20 + cache: dict[types.CodeType, dict[tuple[Any], GuardedFunctions]] translate_count: int + place_holder = _PlaceHolder() def __init__(self): self.cache = {} @@ -56,14 +67,45 @@ def clear(self): def __call__(self, frame: types.FrameType, **kwargs) -> CustomCode: code: types.CodeType = frame.f_code + code_key = self.get_key(frame) + if code not in self.cache: - log(2, f"[Cache]: Firstly call {code}\n") + log( + 2, f"[Cache]: First time call {code} with code_key {code_key}\n" + ) + new_custom_code, guard_fn = self.translate(frame, **kwargs) + self.cache[code] = {code_key: [(new_custom_code, guard_fn)]} + return new_custom_code + elif code_key not in self.cache[code]: + if len(self.cache[code]) >= self.MAX_BUCKET_SIZE: + log(2, "[Cache]: Exceed max bucket size, skip it\n") + return CustomCode(None, False) + log(2, f"[Cache]: Firstly call {code} with code_key {code_key}\n") new_custom_code, guard_fn = self.translate(frame, **kwargs) - self.cache[code] = [(new_custom_code, guard_fn)] + self.cache[code][code_key] = [(new_custom_code, guard_fn)] return new_custom_code - guarded_fns = self.cache[code] + + guarded_fns = self.cache[code][code_key] return self.lookup(frame, guarded_fns, **kwargs) + def get_key(self, frame): + def get_code_key(name): + var = frame.f_locals[name] + if isinstance(var, ConstTypes): + return var + elif isinstance(var, paddle.Tensor): + return str(MetaInfo.from_tensor(var)) + elif isinstance(var, paddle.nn.Layer): + return id(var) + else: + return self.place_holder + + code = frame.f_code + n_args = code.co_argcount + input_names = code.co_varnames[0:n_args] + code_key = tuple(map(get_code_key, input_names)) + return code_key + @event_register("lookup") def lookup( self, frame: types.FrameType, guarded_fns: GuardedFunctions, **kwargs @@ -96,7 +138,7 @@ def lookup( else: log_do( 4, - self.analyse_guard_global_object(guard_fn), + analyse_guard_global_object(guard_fn), ) log( 2, @@ -104,7 +146,7 @@ def lookup( ) log_do( 2, - self.analyse_guard_error(guard_fn, frame), + analyse_guard_error(guard_fn, frame), ) except Exception as e: log(2, f"[Cache]: Guard function error: {e}\n") @@ -127,43 +169,10 @@ def translate( Returns: tuple[CustomCode, Guard]: The cache getter function and a guarded function for the translated code object. """ - code: types.CodeType = frame.f_code self.translate_count += 1 custom_new_code, guard_fn = start_translate(frame, **kwargs) return custom_new_code, guard_fn - def analyse_guard_global_object(self, guard_fn): - def inner(): - for key in guard_fn.__globals__.keys(): - if key.startswith("__object"): - print( - f"[Cache] meet global object: {key} : {guard_fn.__globals__[key]}", - ) - - return inner - - def analyse_guard_error(self, guard_fn, frame): - def inner(): - guard_expr = guard_fn.lambda_expr - lambda_head = "lambda frame: " - guard_expr = guard_expr.replace(lambda_head, "") - guards = guard_expr.split(" and ") - for guard_str in guards: - guard = eval(lambda_head + guard_str, guard_fn.__globals__) - result = False - try: - result = guard(frame) - except Exception as e: - print( - f"[Cache]: skip checking {guard_str}\n because error occured {e}" - ) - if result is False: - print(f"[Cache]: missed at {guard_str}") - return - print("[Cache]: missed guard not found.") - - return inner - def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction: """ @@ -214,3 +223,40 @@ def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction: raise InnerError(OpcodeExecutorBase.error_message_summary(e)) from e finally: simulator.cleanup() + + +# log utils + + +def analyse_guard_global_object(guard_fn): + def inner(): + for key in guard_fn.__globals__.keys(): + if key.startswith("__object"): + print( + f"[Cache] meet global object: {key} : {guard_fn.__globals__[key]}", + ) + + return inner + + +def analyse_guard_error(guard_fn, frame): + def inner(): + guard_expr = guard_fn.lambda_expr + lambda_head = "lambda frame: " + guard_expr = guard_expr.replace(lambda_head, "") + guards = guard_expr.split(" and ") + for guard_str in guards: + guard = eval(lambda_head + guard_str, guard_fn.__globals__) + result = False + try: + result = guard(frame) + except Exception as e: + print( + f"[Cache]: skip checking {guard_str}\n because error occured {e}" + ) + if result is False: + print(f"[Cache]: missed at {guard_str}") + return + print("[Cache]: missed guard not found.") + + return inner diff --git a/sot/opcode_translator/instruction_utils/opcode_analysis.py b/sot/opcode_translator/instruction_utils/opcode_analysis.py index e4e635ba..3f36a6c3 100644 --- a/sot/opcode_translator/instruction_utils/opcode_analysis.py +++ b/sot/opcode_translator/instruction_utils/opcode_analysis.py @@ -185,7 +185,11 @@ def walk(state: SpaceState, start: int) -> SpaceState: assert instr.jump_to is not None target_idx = instructions.index(instr.jump_to) # Fork to two branches, jump or not - jump_branch = fork(state, i, True, target_idx) + jump_branch = ( + fork(state, i, True, target_idx) + if target_idx >= start_instr_idx and target_idx < end + else state + ) not_jump_branch = ( fork(state, i, False, target_idx) if instr.opname not in UNCONDITIONAL_JUMP diff --git a/sot/opcode_translator/transform.py b/sot/opcode_translator/transform.py index 0c4710d7..e91b18f4 100644 --- a/sot/opcode_translator/transform.py +++ b/sot/opcode_translator/transform.py @@ -86,8 +86,8 @@ def eval_frame_callback(frame, **kwargs) -> CustomCode: # just check those codes which need open eval_frame if ( - custom_code.disable_eval_frame is False - and CodeStatus().is_code_without_graph(new_code) + CodeStatus().is_code_without_graph(new_code) + and custom_code.disable_eval_frame is False ): log( 3, diff --git a/tests/test_code_status.py b/tests/test_code_status.py index 4a24f305..065901a4 100644 --- a/tests/test_code_status.py +++ b/tests/test_code_status.py @@ -12,7 +12,7 @@ class SimpleNet1(paddle.nn.Layer): def __init__(self): super().__init__() self.layers = paddle.nn.LayerList( - [paddle.nn.Linear(10, 10) for _ in range(30)] + [paddle.nn.Linear(10, 10) for _ in range(20)] ) def forward(self, x): @@ -20,8 +20,6 @@ def forward(self, x): sot.psdb.breakgraph() x = self.layers[i](x) x = self.layers[i](x) - x = self.layers[i](x) - x = self.layers[i](x) return x @@ -29,7 +27,7 @@ class SimpleNet2(paddle.nn.Layer): def __init__(self): super().__init__() self.layers = paddle.nn.LayerList( - [paddle.nn.Linear(10, 10) for _ in range(30)] + [paddle.nn.Linear(10, 10) for _ in range(20)] ) def forward(self, x): @@ -37,8 +35,6 @@ def forward(self, x): for i in range(len(self.layers)): x = self.layers[i](x) x = self.layers[i](x) - x = self.layers[i](x) - x = self.layers[i](x) return x @@ -53,6 +49,7 @@ def test_case_1(self): CodeStatus().clear() net = SimpleNet1() inp = paddle.rand((10, 10)) + inp.stop_gradient = False self.assert_results(run_net, net, inp) code_map = CodeStatus().code_map states = [] @@ -63,18 +60,23 @@ def test_case_1(self): assert v.state == CodeState.WITH_GRAPH else: assert v.state == CodeState.WITHOUT_GRAPH - # run_net, forward, loop body, resumed part2 in loop body - assert len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 4 - # resumed part1 in loop body + # run_net, loop_body in run_net, forward => 3 + # (forward loop_body + resumed part in loop_body) * 20 => 40 + assert len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 43 + # part after loop in forward + # resumed part in loop_body of run_net assert ( - len([v for v in states if v.state == CodeState.WITHOUT_GRAPH]) == 1 + len([v for v in states if v.state == CodeState.WITHOUT_GRAPH]) == 2 ) + # part after loop in run_net, it is only called once, so UNKNOW + assert len([v for v in states if v.state == CodeState.UNKNOW]) == 1 def test_case_2(self): with strict_mode_guard(0): CodeStatus().clear() net = SimpleNet2() inp = paddle.rand((10, 10)) + inp.stop_gradient = False self.assert_results(run_net, net, inp) code_map = CodeStatus().code_map states = []