Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

upgrade guard with buckets #415

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 87 additions & 41 deletions sot/opcode_translator/executor/executor_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

import traceback
import types
from typing import List, Tuple
from typing import Any, List, Tuple

import paddle

from ...infer_meta import MetaInfo
from ...profiler import EventGuard, event_register
from ...psdb import NO_FALLBACK_CODES
from ...utils import (
Expand All @@ -27,6 +30,8 @@
dummy_guard.expr = "lambda frame: True"
dummy_guard.lambda_expr = "lambda frame: True"

ConstTypes = (int, float, str, bool, type(None))


@Singleton
class OpcodeExecutorCache:
Expand All @@ -39,9 +44,15 @@ class OpcodeExecutorCache:
translate_count (int): The count of how many instructions have been translated. It is used to test whether the cache hits.
"""

MAX_CACHE_SIZE = 20
cache: dict[types.CodeType, GuardedFunctions]
class _PlaceHolder:
    """Marker object with a fixed, readable textual form.

    Both ``str()`` and ``repr()`` yield ``"PlaceHolder"``.
    """

    def __str__(self):
        return "PlaceHolder"

    # Containers such as tuples render their elements with repr(), not
    # str(); mirror __str__ so the marker stays readable when a tuple
    # holding it is interpolated into a log message.
    __repr__ = __str__

MAX_CACHE_SIZE = 10
MAX_BUCKET_SIZE = 20
cache: dict[types.CodeType, dict[tuple[Any], GuardedFunctions]]
translate_count: int
place_holder = _PlaceHolder()

def __init__(self):
self.cache = {}
Expand All @@ -56,14 +67,45 @@ def clear(self):

def __call__(self, frame: types.FrameType, **kwargs) -> CustomCode:
code: types.CodeType = frame.f_code
code_key = self.get_key(frame)

if code not in self.cache:
log(2, f"[Cache]: Firstly call {code}\n")
log(
2, f"[Cache]: First time call {code} with code_key {code_key}\n"
)
new_custom_code, guard_fn = self.translate(frame, **kwargs)
self.cache[code] = {code_key: [(new_custom_code, guard_fn)]}
return new_custom_code
elif code_key not in self.cache[code]:
if len(self.cache[code]) >= self.MAX_BUCKET_SIZE:
log(2, "[Cache]: Exceed max bucket size, skip it\n")
return CustomCode(None, False)
log(2, f"[Cache]: Firstly call {code} with code_key {code_key}\n")
new_custom_code, guard_fn = self.translate(frame, **kwargs)
self.cache[code] = [(new_custom_code, guard_fn)]
self.cache[code][code_key] = [(new_custom_code, guard_fn)]
return new_custom_code
guarded_fns = self.cache[code]

guarded_fns = self.cache[code][code_key]
return self.lookup(frame, guarded_fns, **kwargs)

def get_key(self, frame):
    """Build the bucket key for ``frame`` from its positional arguments.

    Args:
        frame: The frame whose ``f_code`` and ``f_locals`` are inspected.

    Returns:
        tuple: One entry per name in ``co_varnames[:co_argcount]``:
        the value itself for instances of ``ConstTypes``, the string form
        of ``MetaInfo.from_tensor(value)`` for ``paddle.Tensor``, ``id(value)``
        for ``paddle.nn.Layer``, and ``self.place_holder`` for anything else.
    """

    def describe(value):
        # Order matters: the first matching category decides the entry.
        if isinstance(value, ConstTypes):
            return value
        if isinstance(value, paddle.Tensor):
            return str(MetaInfo.from_tensor(value))
        if isinstance(value, paddle.nn.Layer):
            return id(value)
        return self.place_holder

    code_obj = frame.f_code
    arg_names = code_obj.co_varnames[: code_obj.co_argcount]
    return tuple(describe(frame.f_locals[arg]) for arg in arg_names)

@event_register("lookup")
def lookup(
self, frame: types.FrameType, guarded_fns: GuardedFunctions, **kwargs
Expand Down Expand Up @@ -96,15 +138,15 @@ def lookup(
else:
log_do(
4,
self.analyse_guard_global_object(guard_fn),
analyse_guard_global_object(guard_fn),
)
log(
2,
f"[Cache]: Cache miss, Guard is \n{getattr(guard_fn, 'expr', 'None')}\n",
)
log_do(
2,
self.analyse_guard_error(guard_fn, frame),
analyse_guard_error(guard_fn, frame),
)
except Exception as e:
log(2, f"[Cache]: Guard function error: {e}\n")
Expand All @@ -127,43 +169,10 @@ def translate(
Returns:
tuple[CustomCode, Guard]: The cache getter function and a guarded function for the translated code object.
"""
code: types.CodeType = frame.f_code
self.translate_count += 1
custom_new_code, guard_fn = start_translate(frame, **kwargs)
return custom_new_code, guard_fn

def analyse_guard_global_object(self, guard_fn):
def inner():
for key in guard_fn.__globals__.keys():
if key.startswith("__object"):
print(
f"[Cache] meet global object: {key} : {guard_fn.__globals__[key]}",
)

return inner

def analyse_guard_error(self, guard_fn, frame):
def inner():
guard_expr = guard_fn.lambda_expr
lambda_head = "lambda frame: "
guard_expr = guard_expr.replace(lambda_head, "")
guards = guard_expr.split(" and ")
for guard_str in guards:
guard = eval(lambda_head + guard_str, guard_fn.__globals__)
result = False
try:
result = guard(frame)
except Exception as e:
print(
f"[Cache]: skip checking {guard_str}\n because error occured {e}"
)
if result is False:
print(f"[Cache]: missed at {guard_str}")
return
print("[Cache]: missed guard not found.")

return inner


def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction:
"""
Expand Down Expand Up @@ -214,3 +223,40 @@ def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction:
raise InnerError(OpcodeExecutorBase.error_message_summary(e)) from e
finally:
simulator.cleanup()


# log utils


def analyse_guard_global_object(guard_fn):
    """Return a zero-argument callable that lists the guard's global objects.

    Args:
        guard_fn: Object exposing a ``__globals__`` mapping.

    Returns:
        A callable that, when invoked, prints one line for every entry of
        ``guard_fn.__globals__`` whose key starts with ``"__object"``.
        Nothing is printed until the returned callable runs.
    """

    def inner():
        # Read __globals__ lazily so the report reflects call-time state.
        for name, value in guard_fn.__globals__.items():
            if not name.startswith("__object"):
                continue
            print(
                f"[Cache] meet global object: {name} : {value}",
            )

    return inner


def analyse_guard_error(guard_fn, frame):
    """Return a zero-argument callable that reports which guard clause missed.

    Args:
        guard_fn: Guard exposing ``lambda_expr`` (source text of the form
            ``"lambda frame: <a> and <b> ..."``) and ``__globals__``.
        frame: The object passed to each clause when it is evaluated.

    Returns:
        A callable that splits the expression on ``" and "``, evaluates each
        clause as its own lambda against ``frame`` and prints the first
        clause whose result is ``False`` (a clause that raises is reported
        as skipped and then treated as ``False``). If every clause passes,
        it prints that no missed guard was found.
    """

    def inner():
        prefix = "lambda frame: "
        clauses = guard_fn.lambda_expr.replace(prefix, "").split(" and ")
        for clause in clauses:
            # NOTE(review): eval() runs the guard text; presumably this is
            # always generated internally -- confirm it never carries
            # external input.
            check = eval(prefix + clause, guard_fn.__globals__)
            try:
                outcome = check(frame)
            except Exception as e:
                print(
                    f"[Cache]: skip checking {clause}\n because error occured {e}"
                )
                outcome = False
            if outcome is False:
                print(f"[Cache]: missed at {clause}")
                break
        else:
            print("[Cache]: missed guard not found.")

    return inner
6 changes: 5 additions & 1 deletion sot/opcode_translator/instruction_utils/opcode_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,11 @@ def walk(state: SpaceState, start: int) -> SpaceState:
assert instr.jump_to is not None
target_idx = instructions.index(instr.jump_to)
# Fork to two branches, jump or not
jump_branch = fork(state, i, True, target_idx)
jump_branch = (
fork(state, i, True, target_idx)
if target_idx >= start_instr_idx and target_idx < end
else state
)
not_jump_branch = (
fork(state, i, False, target_idx)
if instr.opname not in UNCONDITIONAL_JUMP
Expand Down
4 changes: 2 additions & 2 deletions sot/opcode_translator/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def eval_frame_callback(frame, **kwargs) -> CustomCode:

# just check those codes which need open eval_frame
if (
custom_code.disable_eval_frame is False
and CodeStatus().is_code_without_graph(new_code)
CodeStatus().is_code_without_graph(new_code)
and custom_code.disable_eval_frame is False
):
log(
3,
Expand Down
22 changes: 12 additions & 10 deletions tests/test_code_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,29 @@ class SimpleNet1(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.layers = paddle.nn.LayerList(
[paddle.nn.Linear(10, 10) for _ in range(30)]
[paddle.nn.Linear(10, 10) for _ in range(20)]
)

def forward(self, x):
for i in range(len(self.layers)):
sot.psdb.breakgraph()
x = self.layers[i](x)
x = self.layers[i](x)
x = self.layers[i](x)
x = self.layers[i](x)
return x


class SimpleNet2(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.layers = paddle.nn.LayerList(
[paddle.nn.Linear(10, 10) for _ in range(30)]
[paddle.nn.Linear(10, 10) for _ in range(20)]
)

def forward(self, x):
sot.psdb.fallback()
for i in range(len(self.layers)):
x = self.layers[i](x)
x = self.layers[i](x)
x = self.layers[i](x)
x = self.layers[i](x)
return x


Expand All @@ -53,6 +49,7 @@ def test_case_1(self):
CodeStatus().clear()
net = SimpleNet1()
inp = paddle.rand((10, 10))
inp.stop_gradient = False
self.assert_results(run_net, net, inp)
code_map = CodeStatus().code_map
states = []
Expand All @@ -63,18 +60,23 @@ def test_case_1(self):
assert v.state == CodeState.WITH_GRAPH
else:
assert v.state == CodeState.WITHOUT_GRAPH
# run_net, forward, loop body, resumed part2 in loop body
assert len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 4
# resumed part1 in loop body
# run_net, loop_body in run_net, forward => 3
# (forward loop_body + resumed part in loop_body) * 20 => 40
assert len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 43
# part after loop in forward
# resumed part in loop_body of run_net
assert (
len([v for v in states if v.state == CodeState.WITHOUT_GRAPH]) == 1
len([v for v in states if v.state == CodeState.WITHOUT_GRAPH]) == 2
)
# part after loop in run_net, it is only called once, so UNKNOW
assert len([v for v in states if v.state == CodeState.UNKNOW]) == 1

def test_case_2(self):
with strict_mode_guard(0):
CodeStatus().clear()
net = SimpleNet2()
inp = paddle.rand((10, 10))
inp.stop_gradient = False
self.assert_results(run_net, net, inp)
code_map = CodeStatus().code_map
states = []
Expand Down
Loading