Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

update call compile fn #409

Merged
merged 5 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def start_compile(self, *ret_vars: VariableBase):
found = False
for variable in self.input_variables:
if (
isinstance(variable, (TensorVariable, PaddleLayerVariable))
isinstance(variable, TensorVariable)
and variable.get_symbol().name == name
):
variable.tracker.gen_instructions(self.pycode_gen)
Expand Down Expand Up @@ -426,15 +426,12 @@ def call_layer(
"""

def infer_meta_fn(layer, *metas, **kwmetas):
metas = metas[1:]
metas = LayerInferMetaCache()(layer.value, *metas, **kwmetas)
return metas

def compute_fn(layer, inputs, outputs, stacks):
inputs = (layer.get_symbol(), *inputs)
inputs = inputs[1:]
self.sir_ctx.call_LAYER(
layer.value.__class__.__name__,
layer.value,
inputs=inputs,
outputs=outputs,
stacks=stacks,
Expand All @@ -444,7 +441,7 @@ def message_handler(*args, **kwargs):
return f"Call paddle layer error: {layer}, may be not a valid paddle layer ?"

return inner_error_default_handler(self.symbolic_call, message_handler)(
infer_meta_fn, compute_fn, layer, *[layer, *args], **kwargs
infer_meta_fn, compute_fn, layer, *args, **kwargs
)

@event_register("symbolic_call", event_level=2)
Expand Down
9 changes: 1 addition & 8 deletions sot/opcode_translator/executor/variables/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
import paddle

from .... import psdb
from ....symbolic.statement_ir import Symbol
from ....utils import (
EventGuard,
NameGenerator,
is_break_graph_api,
is_break_graph_tensor_methods,
is_builtin_fn,
Expand Down Expand Up @@ -503,18 +501,13 @@ class PaddleLayerVariable(LayerVariable):
tracker(Tracker): The Tracker object that tracks the information of this variable.
"""

layer_name_generator = NameGenerator("layer_")

def __init__(
self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker
):
super().__init__(layer, graph, tracker)
self.name = self.layer_name_generator.next()

def get_symbol(self) -> Symbol:
return Symbol(self.name)

def call_function(self, /, *args, **kwargs):
self.graph.add_global_guarded_variable(self)
return self.graph.call_layer(self, *args, **kwargs)

def make_stringify_guard(self) -> list[StringifyExpression]:
Expand Down
14 changes: 6 additions & 8 deletions sot/symbolic/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,27 +120,25 @@ def _set(v, s):
return replace_symbol(SIR.outputs, state)

def call(self, stmt: Statement, inputs):
SIR = self.get_sir(stmt.name)
SIR = self.get_sir(stmt.sir_name)
state = prepare_state(SIR, inputs)
return self.run_sir(stmt.name, state)
return self.run_sir(stmt.sir_name, state)

def api(self, stmt, inputs):
args, kwargs = inputs
return stmt.name(*args, **kwargs)
return stmt.api(*args, **kwargs)

def method(self, stmt, inputs):
args, kwargs = inputs
var = args[0]
return getattr(var, stmt.name)(*args[1:], **kwargs)
return getattr(var, stmt.method)(*args[1:], **kwargs)

def layer(self, stmt, inputs):
args, kwargs = inputs
layer, args = args[0], args[1:]
layer = stmt.layer()
assert layer is not None, "SIR bound layer is None."
return layer(*args, **kwargs)

def delete(self, stmt, inputs):
pass


def compile_sir(context: SymbolicTraceContext, name: str):
"""
Expand Down
63 changes: 57 additions & 6 deletions sot/symbolic/statement_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
"""
from __future__ import annotations

import weakref
from typing import Callable

import paddle
from paddle.utils import is_sequence, map_structure

from ..utils import NameGenerator, OrderedSet, Singleton, flatten_extend
Expand Down Expand Up @@ -69,22 +73,69 @@ def to_string(inps):
inps = (x.__str__() for x in inps)
return ", ".join(inps)

name = (
self.name
if isinstance(self.name, str)
else "paddle." + self.name.__name__
)
return "{} || {} = {} ({}) ".format(
self.type + " " * (10 - len(self.type)),
to_string(self.outputs),
name,
self.name,
to_string(self.inputs),
)

def __repr__(self):
return self.__str__()


class CallStatement(Statement):
    """A statement that invokes another SIR (a subgraph) by name.

    The callee's identifier is kept on ``sir_name`` so the interpreter
    can look the subgraph up when executing this statement.
    """

    def __init__(
        self,
        name: str,
        inputs: list[Symbol],
        outputs: list[Symbol],
        stacks: list[str],
    ):
        super().__init__("call", name, inputs, outputs, stacks)
        # Preserve the target SIR's name for later lookup.
        self.sir_name = name


class ApiStatement(Statement):
    """A statement that calls a paddle API.

    The callable itself is stored on ``api`` so the interpreter can
    invoke it directly; the statement's display name is the dotted path
    (assumes the callable lives under the ``paddle`` namespace — the
    caller asserts it receives a paddle api).
    """

    def __init__(
        self,
        api: Callable,
        inputs: list[Symbol],
        outputs: list[Symbol],
        stacks: list[str],
    ):
        display_name = "paddle." + api.__name__
        super().__init__("api", display_name, inputs, outputs, stacks)
        self.api = api


class MethodStatement(Statement):
    """A statement that invokes ``name`` as a method of its first input.

    ``method`` keeps the attribute name; the interpreter resolves it on
    the first input value via ``getattr`` at execution time.
    """

    def __init__(
        self,
        name: str,
        inputs: list[Symbol],
        outputs: list[Symbol],
        stacks: list[str],
    ):
        super().__init__("method", name, inputs, outputs, stacks)
        # Record the method name for getattr-based dispatch.
        self.method = name


class LayerStatement(Statement):
    """A statement that calls a ``paddle.nn.Layer`` instance.

    The layer is held through a weak reference so the recorded IR does
    not extend the layer's lifetime; the interpreter dereferences it
    (and checks it is still alive) at execution time.
    """

    def __init__(
        self,
        layer: paddle.nn.Layer,
        inputs: list[Symbol],
        outputs: list[Symbol],
        stacks: list[str],
    ):
        layer_cls_name = layer.__class__.__name__
        super().__init__("layer", layer_cls_name, inputs, outputs, stacks)
        # Weak reference: dereference with self.layer() when executing.
        self.layer = weakref.ref(layer)


class StatementIR:
"""
StatementIR is the carrier that records the code for building the neural network model.It is
Expand Down
20 changes: 14 additions & 6 deletions sot/symbolic/symbolic_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@

from ..utils import event_register, log
from .compile_cache import CompileSIRCache
from .statement_ir import Statement, StatementIR, StatementIRFactory, Symbol
from .statement_ir import (
ApiStatement,
CallStatement,
LayerStatement,
MethodStatement,
StatementIR,
StatementIRFactory,
Symbol,
)


class SymbolicTraceContext:
Expand Down Expand Up @@ -41,7 +49,7 @@ def call_SIR(self, sirname, inputs, outputs, stacks):
Call a SIR, which is a subgraph.
"""

stmt = Statement("call", sirname, inputs, outputs, stacks)
stmt = CallStatement(sirname, inputs, outputs, stacks)
self.TOS.add_statement(stmt)

@event_register("call_API", event_level=2)
Expand All @@ -51,7 +59,7 @@ def call_API(self, api, inputs, outputs, stacks):
"""

assert callable(api), "call_API must receive a paddle api."
stmt = Statement("api", api, inputs, outputs, stacks)
stmt = ApiStatement(api, inputs, outputs, stacks)
self.TOS.add_statement(stmt)

@event_register("call_METHOD", event_level=2)
Expand All @@ -65,15 +73,15 @@ def call_METHOD(self, method_name, inputs, outputs, stacks):
assert isinstance(
inputs[0][0], Symbol
), "call_METHOD must first augument must be Symbol Variable."
stmt = Statement("method", method_name, inputs, outputs, stacks)
stmt = MethodStatement(method_name, inputs, outputs, stacks)
self.TOS.add_statement(stmt)

@event_register("call_LAYER", event_level=2)
def call_LAYER(self, layer_name, inputs, outputs, stacks):
def call_LAYER(self, layer, inputs, outputs, stacks):
"""
        Call a paddle layer.
"""
stmt = Statement("layer", layer_name, inputs, outputs, stacks)
stmt = LayerStatement(layer, inputs, outputs, stacks)
self.TOS.add_statement(stmt)

def get_sir(self, name: str):
Expand Down
Loading