Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
update call compile fn (#409)
Browse files Browse the repository at this point in the history
  • Loading branch information
feifei-111 authored Oct 9, 2023
1 parent 4371abe commit 994e37e
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 34 deletions.
9 changes: 3 additions & 6 deletions sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def start_compile(self, *ret_vars: VariableBase):
found = False
for variable in self.input_variables:
if (
isinstance(variable, (TensorVariable, PaddleLayerVariable))
isinstance(variable, TensorVariable)
and variable.get_symbol().name == name
):
variable.tracker.gen_instructions(self.pycode_gen)
Expand Down Expand Up @@ -426,15 +426,12 @@ def call_layer(
"""

def infer_meta_fn(layer, *metas, **kwmetas):
metas = metas[1:]
metas = LayerInferMetaCache()(layer.value, *metas, **kwmetas)
return metas

def compute_fn(layer, inputs, outputs, stacks):
inputs = (layer.get_symbol(), *inputs)
inputs = inputs[1:]
self.sir_ctx.call_LAYER(
layer.value.__class__.__name__,
layer.value,
inputs=inputs,
outputs=outputs,
stacks=stacks,
Expand All @@ -444,7 +441,7 @@ def message_handler(*args, **kwargs):
return f"Call paddle layer error: {layer}, may be not a valid paddle layer ?"

return inner_error_default_handler(self.symbolic_call, message_handler)(
infer_meta_fn, compute_fn, layer, *[layer, *args], **kwargs
infer_meta_fn, compute_fn, layer, *args, **kwargs
)

@event_register("symbolic_call", event_level=2)
Expand Down
9 changes: 1 addition & 8 deletions sot/opcode_translator/executor/variables/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
import paddle

from .... import psdb
from ....symbolic.statement_ir import Symbol
from ....utils import (
EventGuard,
NameGenerator,
is_break_graph_api,
is_break_graph_tensor_methods,
is_builtin_fn,
Expand Down Expand Up @@ -503,18 +501,13 @@ class PaddleLayerVariable(LayerVariable):
tracker(Tracker): The Tracker object that tracks the information of this variable.
"""

layer_name_generator = NameGenerator("layer_")

def __init__(
self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker
):
super().__init__(layer, graph, tracker)
self.name = self.layer_name_generator.next()

def get_symbol(self) -> Symbol:
return Symbol(self.name)

def call_function(self, /, *args, **kwargs):
self.graph.add_global_guarded_variable(self)
return self.graph.call_layer(self, *args, **kwargs)

def make_stringify_guard(self) -> list[StringifyExpression]:
Expand Down
14 changes: 6 additions & 8 deletions sot/symbolic/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,27 +120,25 @@ def _set(v, s):
return replace_symbol(SIR.outputs, state)

def call(self, stmt: Statement, inputs):
SIR = self.get_sir(stmt.name)
SIR = self.get_sir(stmt.sir_name)
state = prepare_state(SIR, inputs)
return self.run_sir(stmt.name, state)
return self.run_sir(stmt.sir_name, state)

def api(self, stmt, inputs):
args, kwargs = inputs
return stmt.name(*args, **kwargs)
return stmt.api(*args, **kwargs)

def method(self, stmt, inputs):
args, kwargs = inputs
var = args[0]
return getattr(var, stmt.name)(*args[1:], **kwargs)
return getattr(var, stmt.method)(*args[1:], **kwargs)

def layer(self, stmt, inputs):
args, kwargs = inputs
layer, args = args[0], args[1:]
layer = stmt.layer()
assert layer is not None, "SIR bound layer is None."
return layer(*args, **kwargs)

def delete(self, stmt, inputs):
pass


def compile_sir(context: SymbolicTraceContext, name: str):
"""
Expand Down
63 changes: 57 additions & 6 deletions sot/symbolic/statement_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
"""
from __future__ import annotations

import weakref
from typing import Callable

import paddle
from paddle.utils import is_sequence, map_structure

from ..utils import NameGenerator, OrderedSet, Singleton, flatten_extend
Expand Down Expand Up @@ -69,22 +73,69 @@ def to_string(inps):
inps = (x.__str__() for x in inps)
return ", ".join(inps)

name = (
self.name
if isinstance(self.name, str)
else "paddle." + self.name.__name__
)
return "{} || {} = {} ({}) ".format(
self.type + " " * (10 - len(self.type)),
to_string(self.outputs),
name,
self.name,
to_string(self.inputs),
)

def __repr__(self):
return self.__str__()


class CallStatement(Statement):
def __init__(
self,
name: str,
inputs: list[Symbol],
outputs: list[Symbol],
stacks: list[str],
):
super().__init__("call", name, inputs, outputs, stacks)
self.sir_name = name


class ApiStatement(Statement):
def __init__(
self,
api: Callable,
inputs: list[Symbol],
outputs: list[Symbol],
stacks: list[str],
):
super().__init__(
"api", "paddle." + api.__name__, inputs, outputs, stacks
)
self.api = api


class MethodStatement(Statement):
def __init__(
self,
name: str,
inputs: list[Symbol],
outputs: list[Symbol],
stacks: list[str],
):
super().__init__("method", name, inputs, outputs, stacks)
self.method = name


class LayerStatement(Statement):
def __init__(
self,
layer: paddle.nn.Layer,
inputs: list[Symbol],
outputs: list[Symbol],
stacks: list[str],
):
super().__init__(
"layer", layer.__class__.__name__, inputs, outputs, stacks
)
self.layer = weakref.ref(layer)


class StatementIR:
"""
StatementIR is the carrier that records the code for building the neural network model.It is
Expand Down
20 changes: 14 additions & 6 deletions sot/symbolic/symbolic_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@

from ..utils import event_register, log
from .compile_cache import CompileSIRCache
from .statement_ir import Statement, StatementIR, StatementIRFactory, Symbol
from .statement_ir import (
ApiStatement,
CallStatement,
LayerStatement,
MethodStatement,
StatementIR,
StatementIRFactory,
Symbol,
)


class SymbolicTraceContext:
Expand Down Expand Up @@ -41,7 +49,7 @@ def call_SIR(self, sirname, inputs, outputs, stacks):
Call a SIR, which is a subgraph.
"""

stmt = Statement("call", sirname, inputs, outputs, stacks)
stmt = CallStatement(sirname, inputs, outputs, stacks)
self.TOS.add_statement(stmt)

@event_register("call_API", event_level=2)
Expand All @@ -51,7 +59,7 @@ def call_API(self, api, inputs, outputs, stacks):
"""

assert callable(api), "call_API must receive a paddle api."
stmt = Statement("api", api, inputs, outputs, stacks)
stmt = ApiStatement(api, inputs, outputs, stacks)
self.TOS.add_statement(stmt)

@event_register("call_METHOD", event_level=2)
Expand All @@ -65,15 +73,15 @@ def call_METHOD(self, method_name, inputs, outputs, stacks):
assert isinstance(
inputs[0][0], Symbol
), "call_METHOD must first augument must be Symbol Variable."
stmt = Statement("method", method_name, inputs, outputs, stacks)
stmt = MethodStatement(method_name, inputs, outputs, stacks)
self.TOS.add_statement(stmt)

@event_register("call_LAYER", event_level=2)
def call_LAYER(self, layer_name, inputs, outputs, stacks):
def call_LAYER(self, layer, inputs, outputs, stacks):
"""
Call a layer of a api.
"""
stmt = Statement("layer", layer_name, inputs, outputs, stacks)
stmt = LayerStatement(layer, inputs, outputs, stacks)
self.TOS.add_statement(stmt)

def get_sir(self, name: str):
Expand Down

0 comments on commit 994e37e

Please sign in to comment.