diff --git a/sot/symbolic/statement_ir.py b/sot/symbolic/statement_ir.py index 542eb71d..3f3d0823 100644 --- a/sot/symbolic/statement_ir.py +++ b/sot/symbolic/statement_ir.py @@ -6,7 +6,7 @@ from __future__ import annotations import weakref -from typing import Callable +from typing import Any, Callable import paddle from paddle.utils import is_sequence, map_structure @@ -251,7 +251,8 @@ class SIRRuntimeCache: """ def __init__(self): - self.cache = {} # { name : (inputs, outputs, free_vars) } + self.cache = {} + # { name : (inputs, outputs, free_vars) } # inputs : can be used when call_SIR, if free_vars exist # outputs : used for generator new ProxyTensor output before fallback # free_vars: (name, function) @@ -265,7 +266,7 @@ def has_key(self, key: str) -> bool: """ return key in self.cache.keys() - def set_origin_inputs(self, key: str, inputs: any): + def set_origin_inputs(self, key: str, inputs: Any): """ Set Cache origin Inputs of the StatementIR """ @@ -275,7 +276,7 @@ def set_origin_inputs(self, key: str, inputs: any): else: self.cache[key] = (inputs, None, None) - def set_origin_outputs(self, key: str, outputs: any): + def set_origin_outputs(self, key: str, outputs: Any): """ Set Cache origin outputs of the StatementIR """ @@ -285,7 +286,7 @@ def set_origin_outputs(self, key: str, outputs: any): else: self.cache[key] = (None, outputs, None) - def set_free_vars(self, key: str, free_vars: any): + def set_free_vars(self, key: str, free_vars: Any): """ Set Cache free variables of the StatementIR """