Remove lattice system from JAX, especially raise_to_shaped (except as a no-op for backwards compat)

PiperOrigin-RevId: 692557993
dougalm authored and Google-ML-Automation committed Nov 3, 2024
1 parent d679c0a commit ec39b59
Showing 24 changed files with 96 additions and 211 deletions.
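The commit replaces JAX's abstract-value lattice (lattice_join, the Bot value, and the raise_to_shaped machinery) with plain type equality via typematch. A hedged sketch of the resulting behavior, not part of the commit itself, assuming the internal module path jax._src.core used throughout this diff:

import numpy as np
from jax._src import core

a = core.ShapedArray((3, 4), np.dtype('float32'))
b = core.ShapedArray((3, 4), np.dtype('float32'), weak_type=True)

assert core.typematch(a, b)           # type equality ignores weak_type
assert core.lattice_join(a, b) is a   # now just an assert-and-return stub
assert core.raise_to_shaped(a) is a   # now a no-op kept for backwards compat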
3 changes: 2 additions & 1 deletion jax/_src/ad_util.py
@@ -43,7 +43,8 @@ def add_impl(x, y):

@add_jaxvals_p.def_abstract_eval
def add_abstract(x, y):
return core.lattice_join(x, y)
assert core.typematch(x, y)
return x

def zeros_like_aval(aval: core.AbstractValue) -> Array:
return aval_zeros_likers[type(aval)](aval)
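One place the add_jaxvals primitive (and hence add_abstract above) is exercised is when reverse-mode AD sums two cotangent contributions to the same input. A small, hedged example of that path:

import jax
import jax.numpy as jnp

# x is used twice, so its two cotangent contributions are summed with
# add_jaxvals, whose abstract eval now asserts typematch instead of joining.
f = lambda x: jnp.sin(x) + jnp.cos(x)
print(jax.grad(f)(1.0))   # cos(1) - sin(1) ≈ -0.3012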
112 changes: 29 additions & 83 deletions jax/_src/core.py
@@ -368,7 +368,7 @@ class Var:
def __init__(self, suffix: str, aval: AbstractValue):
self.count = next(_var_counter)
self.suffix = suffix
self.aval = raise_to_shaped(aval)
self.aval = aval

# TODO(phawkins, mattjj): remove ordering of variables. JAX itself does not
# care about variable ordering, but the downstream package kfac_jax does.
@@ -662,7 +662,7 @@ def __init__(self, trace: Trace):
def _error_repr(self):
if self.aval is None:
return f"traced array with aval {self.aval}"
return f"traced array with shape {raise_to_shaped(self.aval).str_short()}"
return f"traced array with shape {self.aval.str_short()}"

def __array__(self, *args, **kw):
raise TracerArrayConversionError(self)
@@ -1302,19 +1302,18 @@ def __repr__(self):
except AttributeError:
return self.__class__.__name__

def strip_weak_type(self) -> AbstractValue:
def update_weak_type(self, weak_type):
return self

def join(self, other):
raise NotImplementedError("must override")
def strip_weak_type(self) -> AbstractValue:
return self.update_weak_type(False)

def update(self, **kwargs):
raise NotImplementedError("must override")

def str_short(self, short_dtypes=False):
return str(self)


# For type signatures involving dynamic shapes, we use lists of abstract values
# which may contain (reverse) de Bruijn indices in their shapes.
class DBIdx(NamedTuple):
@@ -1348,26 +1347,10 @@ def _jaxpr_type_to_callable_annotation(jaxpr: Jaxpr) -> InputType:
for v in jaxpr.invars]
return tuple(out)

class Bot(AbstractValue): pass
bot = Bot()


def lattice_join(x: AbstractValue | None,
y: AbstractValue | None) -> AbstractValue:
if x is None:
assert y is not None
return y
elif y is None:
return x
elif isinstance(x, type(y)):
return y.join(x)
elif isinstance(y, type(x)):
return x.join(y)
elif isinstance(x, DShapedArray) and isinstance(y, ShapedArray):
# TODO(mattjj): remove this special case after dynamic shapes are integrated
return x.join(y)
else:
raise TypeError(x, y)
# TODO(dougalm): Deprecate. This is here for backwards compat.
def lattice_join(x, y):
assert typematch(x, y)
return x

# For use in typing annotations to denote either a Tracer or a `valid_jaxtype`.
Value = Any
@@ -1530,9 +1513,8 @@ def __repr__(self):
def str_short(self, short_dtypes=False) -> str:
return dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name

def strip_weak_type(self):
"""Returns a copy of the aval with weak_type=False."""
return self.update(weak_type=False)
def update_weak_type(self, weak_type):
return self.update(weak_type=weak_type)

def _canonicalize_dimension(dim: DimSize) -> DimSize:
# Dimensions are most commonly integral (by far), so we check that first.
@@ -1656,13 +1638,6 @@ def to_tangent_aval(self):
return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type)

def join(self, other):
if definitely_equal_shape(self.shape, other.shape) and self.dtype == other.dtype:
weak_type = self.weak_type and other.weak_type
return self.update(weak_type=weak_type)
else:
raise TypeError(self, other)

def str_short(self, short_dtypes=False):
dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else
self.dtype.name)
@@ -1762,14 +1737,6 @@ def __eq__(self, other):
def __hash__(self):
return hash((self.shape, self.dtype, self.weak_type))

def join(self, other):
if (definitely_equal_shape(self.shape, other.shape) and
self.dtype == other.dtype):
weak_type = self.weak_type and other.weak_type
return self.update(weak_type=weak_type)
else:
raise TypeError(self, other)

def to_tangent_aval(self):
return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type)
@@ -1881,16 +1848,11 @@ def mutable_array_abstract_eval(init_aval):
@mutable_array_p.def_impl
def _mutable_array_impl(init_val):
from jax._src.state.types import AbstractRef # pytype: disable=import-error
aval = raise_to_shaped(get_aval(init_val))
aval = get_aval(init_val)
return MutableArray(AbstractRef(aval), init_val)


class AbstractToken(AbstractValue):
def join(self, other):
if isinstance(other, AbstractToken):
return self
else:
assert False, f"Cannot join {self} with {other}"
def str_short(self, short_dtypes=False): return 'Tok'
def to_tangent_aval(self): return self
abstract_token: AbstractToken = AbstractToken()
@@ -1910,30 +1872,9 @@ def block_until_ready(self):
pytype_aval_mappings[Token] = lambda _: abstract_token


def raise_to_shaped(aval: AbstractValue, weak_type=None):
aval_type = type(aval)
if aval_type is ShapedArray and weak_type is None:
return aval
if aval_type is DShapedArray and weak_type is None:
return aval
if weak_type is None:
weak_type = getattr(aval, 'weak_type', False)
for typ in aval_type.__mro__:
handler = raise_to_shaped_mappings.get(typ)
if handler: return handler(aval, weak_type)
raise TypeError(type(aval))

def _shaped_array_mapping(aval, weak_type):
if config.sharding_in_types.value:
return ShapedArray(aval.shape, aval.dtype, weak_type, sharding=aval.sharding)
return ShapedArray(aval.shape, aval.dtype, weak_type)

raise_to_shaped_mappings: dict[type, Callable] = {
AbstractToken: lambda aval, _: aval,
Bot: lambda aval, _: aval,
ShapedArray: _shaped_array_mapping,
DShapedArray: lambda aval, _: aval
}
# TODO(dougalm): Deprecate. This is just here for backwards compat.
def raise_to_shaped(aval):
return aval

### Operations on shapes and dimension sizes.

@@ -2341,18 +2282,23 @@ def typecheck(aval: AbstractValue, x) -> bool:
def typecompat(aval_ref: AbstractValue, aval: AbstractValue) -> bool:
"""Determine whether `aval` conforms to `aval_ref`. Ignores weak_type."""
try:
return typematch(aval_ref, lattice_join(aval_ref, aval))
return typematch(aval_ref, aval)
except TypeError:
return False

def typematch(aval1: AbstractValue, aval2: AbstractValue) -> bool:
"""Determine whether `aval1` and `aval2` are equivalent. Ignores weak_type."""
if aval1 == aval2: return True
# unequal avals may still represent the same type, because type is represented
# by avals at the shaped level, and because weak type tags aren't considered
# part of the type
return (raise_to_shaped(aval1, weak_type=False) ==
raise_to_shaped(aval2, weak_type=False))
def typematch(t1: AbstractValue, t2: AbstractValue) -> bool:
"""Determine whether `t1` and `t2` are equivalent. Ignores weak_type."""
t1 = t1.strip_weak_type()
t2 = t2.strip_weak_type()
if t1 == t2:
return True
elif (isinstance(t1, (ShapedArray, DShapedArray)) and
isinstance(t2, (ShapedArray, DShapedArray))):
# This case handles DShapedArray and shape polynomials. Alternatively we
# could try normalizing first and then doing simple equality.
return t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape)
else:
return False

class JaxprTypeError(TypeError): pass

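A hedged sketch of the weak-type helpers that replace the old raise_to_shaped(aval, weak_type=...) path, using the update_weak_type/strip_weak_type methods added above (internal module path assumed):

import numpy as np
from jax._src import core

weak = core.ShapedArray((2,), np.dtype('int32'), weak_type=True)
assert not weak.strip_weak_type().weak_type          # strip_weak_type == update_weak_type(False)
assert weak.update_weak_type(True).weak_type
assert core.typematch(weak, weak.strip_weak_type())  # typematch never looks at weak_type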
27 changes: 13 additions & 14 deletions jax/_src/custom_derivatives.py
@@ -31,7 +31,6 @@
stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
from jax._src.api_util import (
argnums_partial, flatten_fun_nokwargs, resolve_kwargs)
from jax._src.core import raise_to_shaped
from jax._src.errors import UnexpectedTracerError
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@@ -81,7 +80,7 @@ def _flatten_fun_nokwargs(in_tree, *args_flat):
py_args = tree_unflatten(in_tree, args_flat)
ans = yield py_args, {}
ans_flat, ans_tree = tree_flatten(ans)
ans_avals = [core.raise_to_shaped(core.get_aval(x)) for x in ans_flat]
ans_avals = [core.get_aval(x) for x in ans_flat]
yield ans_flat, (ans_tree, ans_avals)


@@ -287,7 +286,7 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
py_primals_out, py_tangents_out = pair_out
primals_out, out_tree = tree_flatten(py_primals_out)
tangents_out, out_tree2 = tree_flatten(py_tangents_out)
primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out]
primal_avals = [core.get_aval(x) for x in primals_out]
if out_tree != out_tree2:
msg = (f"Custom JVP rule {jvp_name} for function {primal_name} must "
"produce primal and tangent outputs with equal container (pytree) "
@@ -327,11 +326,11 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
"shapes/dtypes of:\n"
f""" {str(ty_tree_).replace("'", "")}""")
raise TypeError(m)
primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False) for x in primals_out]
primal_avals_out = [core.get_aval(x).strip_weak_type() for x in primals_out]
expected_tangent_avals_out = [
raise_to_shaped(core.get_aval(x), weak_type=False).to_tangent_aval()
core.get_aval(x).strip_weak_type().to_tangent_aval()
for x in primals_out]
tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False)
tangent_avals_out = [core.get_aval(t).strip_weak_type()
if type(t) is not SymbolicZero else t.aval.strip_weak_type()
for t in tangents_out]
if expected_tangent_avals_out != tangent_avals_out:
@@ -606,7 +605,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable
f_, dyn_args = lu.wrap_init(self.fun), args
fwd_, bwd = lu.wrap_init(fwd), lu.wrap_init(self.bwd)
args_flat, in_tree = tree_flatten(dyn_args)
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
in_avals = [core.get_aval(x) for x in args_flat]
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(fwd_, self.symbolic_zeros, primal_name,
fwd_name, in_tree, out_type)
@@ -674,7 +673,7 @@ def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
py_primals_out, res = pair_out
primals_out, out_tree = tree_flatten(py_primals_out)
res, res_tree = tree_flatten(res)
primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out]
primal_avals = [core.get_aval(x) for x in primals_out]
# If the primal function already ran, check out_tree agreement.
try: out_type_ = maybe_out_type()
except lu.StoreException: out_type_ = None
@@ -772,7 +771,7 @@ def append(x, d):
msg = ("Custom VJP bwd rule must produce an output with the same "
"shape/dtypes as the args tuple of the primal function, but at "
f"output{keystr(kp)} the bwd rule produced an output of "
f"shape/dtype {raise_to_shaped(a_).str_short()} corresponding "
f"shape/dtype {a_.str_short()} corresponding "
f"to an input of shape/dtype {a.str_short()}.")
raise ValueError(msg)
results.append(ct)
@@ -831,7 +830,7 @@ def _custom_vjp_call_jaxpr_jvp(
_, res_tree = out_trees()
res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args)
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out]
avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out]
args_dot = map(ad.instantiate_zeros, args_dot)
tangents_out = ad.custom_lin_p.bind(
*res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd,
@@ -1110,7 +1109,7 @@ def merge(l1, l2):
return out, merge

def abstractify(x):
return core.raise_to_shaped(core.get_aval(x))
return core.get_aval(x)


### Custom transposition
@@ -1211,7 +1210,7 @@ def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
lin_avals = map(abstractify, operands_lin)
f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals))
f_jaxpr = _close_jaxpr(f_jaxpr)
out_avals = map(core.raise_to_shaped, f_jaxpr.out_avals)
out_avals = f_jaxpr.out_avals

t_in_tree = treedef_tuple((res_tree, out_tree()))
t, t_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun_transpose), t_in_tree)
@@ -1265,7 +1264,7 @@ def _linear_call_transpose_rule(cts, *args, callee, transpose,
return [None] * (num_callee_consts + num_transpose_consts + num_res) + cts_out

def _linear_call_abstract_eval(*args, **kwargs):
return map(core.raise_to_shaped, kwargs['callee'].out_avals)
return kwargs['callee'].out_avals

linear_call_p = core.Primitive('linear_call')
linear_call_p.multiple_results = True
@@ -1398,7 +1397,7 @@ def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]:
in_tree, out_type)
flat_fwd = _fix_fwd_args(flat_fwd)

in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
in_avals = [core.get_aval(x) for x in args_flat]
fwd_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fwd, in_avals)
fwd_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr))
prim_tree, res_tree = out_trees()
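With raise_to_shaped reduced to an identity, abstracting a concrete value in custom_derivatives is just core.get_aval. A minimal sketch, assuming the internal module paths used in this diff:

import jax.numpy as jnp
from jax._src import core

aval = core.get_aval(jnp.arange(6.0).reshape(2, 3))
print(aval.str_short())   # something like float32[2,3]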
9 changes: 4 additions & 5 deletions jax/_src/interpreters/ad.py
@@ -33,8 +33,7 @@
replace_rule_output_symbolic_zeros, Zero, zeros_like_aval)
from jax._src.ad_util import zeros_like_p, add_jaxvals_p # noqa: F401
from jax._src.api_util import flatten_fun, flatten_fun_nokwargs
from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
raise_to_shaped)
from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal)
from jax._src.dtypes import dtype, float0
from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name,
as_hashable_function, weakref_lru_cache,
@@ -362,7 +361,7 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,

_, res_tree = out_trees()
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out]
avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out]
# TODO(frostig,mattjj): avoid instantiating zeros when we don't have to!
with core.set_current_trace(self.parent_trace):
tangents_in = map(instantiate_zeros, tangents_in)
@@ -434,8 +433,8 @@ def to_concrete_value(self):

def _primal_tangent_shapes_match(primal, tangent):
if type(tangent) is not Zero:
primal_aval = raise_to_shaped(get_aval(primal), weak_type=False)
tangent_aval = raise_to_shaped(get_aval(tangent), weak_type=False)
primal_aval = get_aval(primal).strip_weak_type()
tangent_aval = get_aval(tangent).strip_weak_type()
assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape)
expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype)
assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype)
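A hedged sketch of the invariant _primal_tangent_shapes_match checks, with strip_weak_type standing in for the removed raise_to_shaped(..., weak_type=False):

import jax.numpy as jnp
from jax._src import core

primal = jnp.ones((3,), jnp.float32)
tangent = jnp.zeros((3,), jnp.float32)
pa = core.get_aval(primal).strip_weak_type()
ta = core.get_aval(tangent).strip_weak_type()
assert core.definitely_equal_shape(pa.shape, ta.shape)
assert core.primal_dtype_to_tangent_dtype(pa.dtype) == ta.dtype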
8 changes: 4 additions & 4 deletions jax/_src/interpreters/batching.py
@@ -29,7 +29,7 @@
from jax._src.ad_util import (Zero, instantiate, SymbolicZero,
replace_rule_output_symbolic_zeros,
add_jaxvals, add_jaxvals_p)
from jax._src.core import raise_to_shaped, Trace, Tracer, TraceTag, AxisName
from jax._src.core import Trace, Tracer, TraceTag, AxisName
from jax._src.interpreters import partial_eval as pe
from jax._src.tree_util import (tree_unflatten, tree_flatten,
register_pytree_node)
@@ -217,7 +217,7 @@ def __init__(self, a): self.a = a
for d in a.shape))
if type(a) is core.DShapedArray else a for a, e in orig_type if e]

new_avals = [core.raise_to_shaped(core.get_aval(s)) for s in segment_lens]
new_avals = [core.get_aval(s) for s in segment_lens]
sz = Name(axis_size.aval) if isinstance(axis_size, Tracer) else axis_size
for a, d in zip(avals, explicit_in_dims):
if isinstance(d, RaggedAxis):
@@ -387,7 +387,7 @@ def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis,
if config.enable_checks.value:
assert type(batch_dim) in (NotMapped, int, RaggedAxis)
if type(batch_dim) is int:
aval = raise_to_shaped(core.get_aval(val))
aval = core.get_aval(val)
assert 0 <= batch_dim < len(aval.shape)
self._trace = trace
self.val = val
@@ -396,7 +396,7 @@ def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis,

@property
def aval(self):
aval = raise_to_shaped(core.get_aval(self.val))
aval = core.get_aval(self.val)
if self.batch_dim is not_mapped:
return aval
elif type(self.batch_dim) is int:
[Diffs for the remaining 19 changed files are not shown.]
