Commit

Add support for constant-initializing int4 variables on CPU/GPU to praxis.

PiperOrigin-RevId: 603501597
phoenix-meadowlark authored and pax authors committed Feb 1, 2024
1 parent fb41168 commit dd07dd6
Showing 2 changed files with 34 additions and 3 deletions.
15 changes: 12 additions & 3 deletions praxis/base_layer.py
@@ -382,7 +382,7 @@ class WeightInit:
     scale: Initialization scale.
   """
   method: str
-  scale: float
+  scale: float | int

   @pax_fiddle.auto_config
   @staticmethod
@@ -415,7 +415,7 @@ def XavierWithFixupParams(

   @pax_fiddle.auto_config
   @staticmethod
-  def Constant(scale: float | bool = 1.0) -> WeightInit:
+  def Constant(scale: float | int = 1.0) -> WeightInit:
     """scale."""
     return WeightInit('constant', scale)

@@ -750,7 +750,16 @@ def init_var(
     return scale * jrandom.truncated_normal(
         prng_key, lower=-2.0, upper=2.0, shape=shape, dtype=init_dtype)
   elif method in ['constant']:
-    return scale + jnp.zeros(shape=shape, dtype=init_dtype)
+    if jnp.issubdtype(init_dtype, jnp.integer) and not isinstance(scale, int):
+      raise ValueError(
+          'An integer scale must be provided when initializing an '
+          f'integer variable (of type {init_dtype}), but got {scale=}'
+      )
+    if init_dtype in [jnp.int4, jnp.uint4]:
+      # jnp.zeros(dtype=int4) is not currently supported.
+      return (scale + jnp.zeros(shape=shape, dtype=jnp.int8)).astype(init_dtype)
+    else:
+      return scale + jnp.zeros(shape=shape, dtype=init_dtype)
   elif method in ['xavier']:
     fan_in, fan_out = get_fan_in_fan_out(shape, fan_in_axes, fan_out_axes)
     limit = scale * math.sqrt(6. / (fan_in + fan_out))
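
The new 'constant' branch above is the heart of the change: integer dtypes now reject float scales up front, and int4/uint4 take a detour through int8 because jnp.zeros cannot materialize 4-bit arrays directly. A minimal standalone sketch of the same logic follows; constant_init is a hypothetical name used only for illustration (the committed code lives in praxis.base_layer.init_var), and it assumes a JAX build that exposes jnp.int4 and jnp.uint4.

    import jax.numpy as jnp

    # Illustrative re-implementation of the 'constant' branch above.
    def constant_init(scale, shape, init_dtype):
      if jnp.issubdtype(init_dtype, jnp.integer) and not isinstance(scale, int):
        # A float scale would be silently truncated, so it is rejected.
        raise ValueError(
            'An integer scale must be provided when initializing an '
            f'integer variable (of type {init_dtype}), but got {scale=}'
        )
      if init_dtype in [jnp.int4, jnp.uint4]:
        # jnp.zeros(dtype=int4) is not supported on CPU/GPU, so build the
        # constant in int8 and cast down to 4 bits afterwards.
        return (scale + jnp.zeros(shape=shape, dtype=jnp.int8)).astype(init_dtype)
      return scale + jnp.zeros(shape=shape, dtype=init_dtype)

    w = constant_init(3, (16, 32), jnp.int4)
    assert w.dtype == jnp.int4          # every element holds the constant 3
    constant_init(0.5, (4,), jnp.int8)  # raises ValueError
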
22 changes: 22 additions & 0 deletions praxis/layers/quantization/linears_test.py
@@ -251,6 +251,28 @@ def test_linear_step_count_in_train(self):
         updated_vars[SUMMARIES]['step_count_scalar'], np.array([1])
     )

+  def test_int4_weight_init(self):
+    p = pax_fiddle.Config(
+        qlinears.Linear,
+        name='linear',
+        input_dims=16,
+        output_dims=32,
+        quantization=QuantizationParams(
+            mode=QuantizationMode.INFERENCE,
+            weight_params=quantization_hparams.WeightQuantizationParams(
+                precision=4,
+                dtype=jnp.int4,
+                use_int4_packed_weights=False,
+            ),
+            act_params=quantization_hparams.ActQuantizationParams(precision=4),
+        ),
+    )
+    linear = instantiate(p)
+    with base_layer.JaxContext.new_context():
+      inputs = jnp.zeros([1, p.input_dims], dtype=jnp.float32)
+      linear_vars = linear.init(jax.random.PRNGKey(123), inputs)
+    self.assertEqual(linear_vars['params']['w'].dtype, jnp.int4)
+
   @parameterized.product(
       input_dim=[64, 256, 1024],
       apply_jit=[False, True],
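
The test exercises the change end to end: a quantized Linear with int4 weights is initialized and the resulting parameter dtype is checked. For the narrower API change, a small usage sketch (assuming praxis is importable; Constant is the auto_config staticmethod from the first diff):

    from praxis import base_layer

    # An integer scale is now accepted by the constant initializer,
    # matching the widened float | int annotation.
    init = base_layer.WeightInit.Constant(3)
    assert init.method == 'constant' and init.scale == 3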
