Fix the bug in asymmetric activation quantization inference logic.
PiperOrigin-RevId: 573866807
The praxis Authors authored and chandrasekhard2 committed Oct 16, 2023
1 parent a31b34c commit 1a51bb8
Showing 2 changed files with 31 additions and 6 deletions.
18 changes: 13 additions & 5 deletions praxis/layers/quantization/operations.py
@@ -240,6 +240,17 @@ def einsum(
  Returns:
    A JTensor.
  """
+  # Non-performant equation, for inference testing purposes.
+  # TODO: b/305735188 - Improve the performance by using the integer einsum op.
+  if zp_act is not None:
+    dequantized_x = jnp.multiply(x, scale_act) - zp_act
+    # Explicit broadcast if necessary.
+    if w.ndim == 3 and scale.ndim == 1:
+      scale = jnp.expand_dims(scale, (1, 2))
+    dequantized_w = jnp.multiply(w, scale)
+    if zp is not None:
+      dequantized_w = dequantized_w - zp
+    return jnp.einsum(eqn, dequantized_x, dequantized_w)

  use_int_dot_general = (
      x.dtype in QUANTIZED_TYPES and w.dtype in QUANTIZED_TYPES
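
The added early-return path above dequantizes both operands and falls back to a float einsum whenever the activation carries a zero point. A minimal standalone sketch of the convention it appears to assume — the zero point is pre-scaled, so dequantization is q * scale - zp rather than (q - zp) * scale. The names and per-tensor range logic below are illustrative, not praxis code:

    import jax.numpy as jnp

    def fake_quantize_asym(x, bits=8):
      # Per-tensor asymmetric quantization with a pre-scaled zero point
      # (assumed convention, for illustration only).
      qmax = 2.0**bits - 1.0
      scale = (x.max() - x.min()) / qmax
      zp = -x.min()                    # pre-scaled: dequant = q * scale - zp
      q = jnp.round((x + zp) / scale)  # q lies in [0, qmax]
      return q, scale, zp

    x = jnp.array([-0.3, 0.1, 0.5, 0.9])
    q, scale, zp = fake_quantize_asym(x)
    x_deq = q * scale - zp  # mirrors jnp.multiply(x, scale_act) - zp_act above
    assert jnp.allclose(x_deq, x, atol=float(scale))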
@@ -302,11 +313,6 @@ def einsum(
    offset = compute_offset(x, zp, eqn)
    ret = ret - offset

-  if zp_act is not None:
-    # Non performent equation for inference testing purposes
-    dequantized_x = scale_act * x - zp_act
-    dequantized_w = scale * w - zp
-    ret = jnp.einsum(eqn, dequantized_x, dequantized_w)
  return ret
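
The deleted block above is the bug named in the commit message: it ran after ret had already been computed by the quantized path (silently discarding it), and scale * w - zp fails outright when the weight is symmetric (zp is None) or when scale needs an explicit broadcast — exactly the cases the replacement path guards against. Note that the retained fast path keeps the weight zero point out of the einsum by subtracting an offset instead; a worked check of that identity (illustrative shapes and names, not praxis code):

    import jax
    import jax.numpy as jnp

    kx, kw = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.uniform(kx, (2, 4))
    wq = jax.random.randint(kw, (4, 3), 0, 8).astype(jnp.float32)
    s = jnp.array([0.5, 0.25, 0.125])  # per-output-channel scale, shape (3,)
    zp = jnp.array([0.1, 0.2, 0.3])    # pre-scaled zero point, shape (3,)

    # The einsum against the dequantized weight equals the quantized einsum
    # minus an offset that depends only on the row sums of x and on zp.
    lhs = jnp.einsum('by,yz->bz', x, wq * s - zp)
    rhs = jnp.einsum('by,yz->bz', x, wq * s) - x.sum(axis=1, keepdims=True) * zp
    assert jnp.allclose(lhs, rhs, atol=1e-5)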


@@ -623,6 +629,8 @@ def reduce_einsum_activation_precision(

  if squeeze:
    scale = jnp.squeeze(scale, axis=contract_dims)
+    if zp is not None:
+      zp = jnp.squeeze(zp, axis=contract_dims)
  return t, scale, zp
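
The second half of the fix: reduce_einsum_activation_precision previously squeezed the contracting dimensions out of scale but not out of zp, leaving the two parameters with mismatched ranks for downstream broadcasting. A shape-only sketch (the reduction logic is assumed for illustration, not copied from praxis):

    import jax.numpy as jnp

    x = jnp.arange(8.0).reshape(2, 4)
    contract_dims = (1,)
    # Per-row quantization parameters, computed with keepdims: both (2, 1).
    scale = (x.max(axis=contract_dims, keepdims=True)
             - x.min(axis=contract_dims, keepdims=True)) / 255.0
    zp = -x.min(axis=contract_dims, keepdims=True)

    scale = jnp.squeeze(scale, axis=contract_dims)  # (2,)
    zp = jnp.squeeze(zp, axis=contract_dims)        # (2,)  -- mirrors the fix
    assert scale.shape == zp.shape == (2,)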


19 changes: 18 additions & 1 deletion praxis/layers/quantization/operations_test.py
@@ -118,7 +118,24 @@ def test_quantized_einsum_with_asym_weight_act(self, eqn):

    ret = operations.einsum(eqn, qx, qw, sw, zpw, sx, zpx)
    expected = jnp.einsum(eqn, x, w)
-    self.assertAllClose(ret, expected, rtol=0.1, atol=0.5)
+    self.assertAllClose(ret, expected, rtol=0.02, atol=0.02)

+  @parameterized.named_parameters(
+      ('eqn_with_dot', '...y,yz->...z'),
+  )
+  def test_quantized_einsum_with_sym_weight_asym_act(self, eqn):
+    w = jax.random.uniform(jax.random.PRNGKey(0), (4, 3))
+    x = jax.random.uniform(jax.random.PRNGKey(0), (2, 4))
+    qw, sw, zpw = operations.reduce_einsum_weight_precision(
+        eqn, w, use_symmetric=True
+    )
+    qx, sx, zpx = operations.reduce_einsum_activation_precision(
+        eqn, x, symmetric=False
+    )
+
+    ret = operations.einsum(eqn, qx, qw, sw, zpw, sx, zpx)
+    expected = jnp.einsum(eqn, x, w)
+    self.assertAllClose(ret, expected, rtol=0.02, atol=0.02)

  @parameterized.parameters(
      ('ab,bc->ac', (10, 4), (4, 5)),
