Skip to content

Commit

Permalink
Remove a number of expired deprecations.
Browse files Browse the repository at this point in the history
These APIs were all removed 3 or more months ago, and the registrations
here cause them to raise informative AttributeErrors. Enough time has
passed now that we can remove these.
  • Loading branch information
jakevdp committed Oct 31, 2024
1 parent 7af7a60 commit 2b9c73d
Show file tree
Hide file tree
Showing 8 changed files with 0 additions and 125 deletions.
7 changes: 0 additions & 7 deletions jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,6 @@
"jax.clear_backends is deprecated.",
_deprecated_clear_backends
),
# Remove after jax 0.4.35 release.
"xla_computation": (
"jax.xla_computation is deleted. Please use the AOT APIs; see "
"https://jax.readthedocs.io/en/latest/aot.html. For example, replace "
"xla_computation(f)(*xs) with jit(f).lower(*xs).compiler_ir('hlo'). See "
"CHANGELOG.md for 0.4.30 for more examples.", None
),
}

import typing as _typing
Expand Down
22 changes: 0 additions & 22 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,28 +147,6 @@
"pp_kv_pairs": ("jax.core.pp_kv_pairs was removed in JAX v0.4.34.", None),
"pp_var": ("jax.core.pp_var was removed in JAX v0.4.34.", None),
"pp_vars": ("jax.core.pp_vars was removed in JAX v0.4.34.", None),
# Finalized 2024-05-13; remove after 2024-08-13
"DimSize": (
"jax.core.DimSize is deprecated. Use DimSize = int | Any.",
None,
),
"Shape": (
"jax.core.Shape is deprecated. Use Shape = Sequence[int | Any].",
None,
),
# Finalized 2024-06-24; remove after 2024-09-24
"canonicalize_shape": (
"jax.core.canonicalize_shape is deprecated.", None,
),
"dimension_as_value": (
"jax.core.dimension_as_value is deprecated. Use jnp.array.", None,
),
"definitely_equal": (
"jax.core.definitely_equal is deprecated. Use ==.", None,
),
"symbolic_equal_dim": (
"jax.core.symbolic_equal_dim is deprecated. Use ==.", None,
),
# Added Jan 8, 2024
"non_negative_dim": (
"jax.core.non_negative_dim is deprecated. Use max_dim(..., 0).", _src_core.non_negative_dim,
Expand Down
11 changes: 0 additions & 11 deletions jax/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,6 @@
zeros_like_p as zeros_like_p,
)

_deprecations = {
# Finalized Mar 18, 2024; remove after June 18, 2024
"config": (
"jax.interpreters.ad.config is deprecated. Use jax.config directly.",
None,
),
"source_info_util": (
"jax.interpreters.ad.source_info_util is deprecated. Use jax.extend.source_info_util.",
None,
),
}

def backward_pass(jaxpr, reduce_axes, transform_stack,
consts, primals_in, cotangents_in):
Expand Down
36 changes: 0 additions & 36 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,42 +42,6 @@
("jax.interpreters.xla.xe was removed in JAX v0.4.36. "
"Use jax.lib.xla_extension instead."), None
),
# Finalized 2024-05-13; remove after 2024-08-13
"backend_specific_translations": (
"jax.interpreters.xla.backend_specific_translations is deprecated. "
"Register custom primitives via jax.interpreters.mlir instead.",
None,
),
"translations": (
"jax.interpreters.xla.translations is deprecated. "
"Register custom primitives via jax.interpreters.mlir instead.",
None,
),
"register_translation": (
"jax.interpreters.xla.register_translation is deprecated. "
"Register custom primitives via jax.interpreters.mlir instead.",
None,
),
"xla_destructure": (
"jax.interpreters.xla.xla_destructure is deprecated. "
"Register custom primitives via jax.interpreters.mlir instead.",
None,
),
"TranslationRule": (
"jax.interpreters.xla.TranslationRule is deprecated. "
"Register custom primitives via jax.interpreters.mlir instead.",
None,
),
"TranslationContext": (
"jax.interpreters.xla.TranslationContext is deprecated. "
"Register custom primitives via jax.interpreters.mlir instead.",
None,
),
"XlaOp": (
"jax.interpreters.xla.XlaOp is deprecated. "
"Register custom primitives via jax.interpreters.mlir instead.",
None,
),
}

from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
Expand Down
13 changes: 0 additions & 13 deletions jax/lax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,16 +377,3 @@
from jax._src.pjit import with_sharding_constraint as with_sharding_constraint
from jax._src.pjit import sharding_constraint_p as sharding_constraint_p
from jax._src.dispatch import device_put_p as device_put_p


_deprecations = {
# Finalized 2024-05-13; remove after 2024-08-13
"tie_in": (
"jax.lax.tie_in is deprecated: it has been a no-op since JAX v0.2.0. "
"Replace z = tie_in(x, y) with z = y.", None,
),
}

from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
14 changes: 0 additions & 14 deletions jax/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,3 @@
squareplus as squareplus,
mish as mish,
)

# Deprecations

_deprecations = {
# Finalized 2024-05-13; remove after 2024-08-13
"normalize": (
"jax.nn.normalize is deprecated. Use jax.nn.standardize instead.",
None,
),
}

from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
5 changes: 0 additions & 5 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,11 +472,6 @@
"jnp.round_ is deprecated; use jnp.round instead.",
round
),
# Deprecated 18 Sept 2023 and removed 06 Feb 2024
"trapz": (
"jnp.trapz is deprecated; use jnp.trapezoid instead.",
None
),
}

import typing
Expand Down
17 changes: 0 additions & 17 deletions jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,20 +251,3 @@
weibull_min as weibull_min,
wrap_key_data as wrap_key_data,
)

_deprecations = {
# Finalized Jul 26 2024; remove after Nov 2024.
"shuffle": (
"jax.random.shuffle is deprecated. Use jax.random.permutation with independent=True.",
None,
)
}

import typing
if typing.TYPE_CHECKING:
pass
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing

0 comments on commit 2b9c73d

Please sign in to comment.