diff --git a/tensorflow/lite/micro/compression/compress.py b/tensorflow/lite/micro/compression/compress.py index 18390d0d735..93bfb2e814b 100644 --- a/tensorflow/lite/micro/compression/compress.py +++ b/tensorflow/lite/micro/compression/compress.py @@ -20,7 +20,7 @@ import bitarray.util from dataclasses import dataclass, field import sys -from typing import ByteString, Iterable +from typing import ByteString, Iterable, Optional import absl.app import absl.flags @@ -107,7 +107,7 @@ def _add_subgraph(self): @dataclass class _LutCompressedArray: - compression_axis: int = 0 + compression_axis: Optional[int] = None lookup_tables: list[np.ndarray] = field(default_factory=list) indices: np.ndarray = field(default_factory=lambda: np.array([])) @@ -121,27 +121,46 @@ def index_bitwidth(self) -> int: return max_index.bit_length() or 1 -def _lut_compress_array(tensor: np.ndarray, axis: int) -> _LutCompressedArray: - """Compresses using a lookup table per subarray along the given axis. +def _lut_compress_array(tensor: np.ndarray, + axis: Optional[int]) -> _LutCompressedArray: + """Compresses the given tensor using lookup tables. - Compressing a tensor with a lookup table per subarray along a particular axis - is analogous to quantizing a tensor with different quantization parameters - per subarray along a particular axis (dimension). + Args: + tensor (np.ndarray): The tensor to be compressed. + + axis (Optional[int]): The axis along which to compress the tensor. If an + axis is given, a lookup table is created for each slice along the + axis. If axis is None, a single lookup table is used for the entire + tensor. + + Compressing a tensor with a lookup table per slice along a + particular axis is analogous to quantizing a tensor with different + quantization parameters per slice along a particular axis (dimension). + + Returns: + _LutCompressedArray: An object containing the compressed tensor data, + including the lookup tables and indices. """ compressed = _LutCompressedArray() compressed.compression_axis = axis - # Iterate over subarrays along the compression axis - subarray_indices = [] - for subarray in np.moveaxis(tensor, axis, 0): - values, indices = np.unique(subarray, return_inverse=True) + if axis is None: + # Compute unique values and indices for the entire tensor + values, indices = np.unique(tensor, return_inverse=True) compressed.lookup_tables.append(values) - indices = indices.reshape(subarray.shape) - subarray_indices.append(indices) - - # Reconstruct a tensor of indices from the subarrays - stacked = np.stack(subarray_indices, axis=0) - compressed.indices = np.moveaxis(stacked, 0, axis) + compressed.indices = indices.reshape(tensor.shape) + else: + # Iterate over slices along the compression axis + slice_indices = [] + for slice in np.moveaxis(tensor, axis, 0): + values, indices = np.unique(slice, return_inverse=True) + compressed.lookup_tables.append(values) + indices = indices.reshape(slice.shape) + slice_indices.append(indices) + + # Reconstruct a tensor of indices from the slices + stacked = np.stack(slice_indices, axis=0) + compressed.indices = np.moveaxis(stacked, 0, axis) return compressed @@ -155,18 +174,34 @@ def _check_lut_compression(compression) -> spec.LookUpTableCompression: return compression[0] -def _identify_compression_axis(tensor: model_facade._Tensor) -> int: - """Finds the axis along which to compress. +def _identify_compression_axis(tensor: model_facade._Tensor) -> Optional[int]: + """Determines the axis along which to compress. + + The axis along which to compress is inferred from the tensor's quantization + parameters. + + Returns: + The axis along which to compress, or None to indicate one value table for + the entire tensor. - Use the quantization axis, else the NWHC channel dimension. If necessary, - an user-specified override could be added to the compression spec schema. + Raises: + CompressionError: If the axis cannot be determined. """ - if tensor.quantization is not None: - axis = tensor.quantization.quantizedDimension - else: - axis = tensor.array.ndim - 1 + q = tensor.quantization + if q is not None \ + and q.scale is not None \ + and q.quantizedDimension < len(tensor.shape): + quantization_channels = len(q.scale) + if quantization_channels == 1: + # Use one value table for the entire tensor + return None + + if quantization_channels == tensor.shape[q.quantizedDimension]: + return q.quantizedDimension - return axis + raise CompressionError( + f"Invalid or no quanitzation parameters from which to " + f"infer the axis along which tensor should be compressed.") def _check_bitwidth(compressed: int, specified: int, spec: spec.Tensor): @@ -204,7 +239,7 @@ def _pack_lookup_tables(tables: list[np.ndarray], table_len: int) -> bytearray: Pack the value tables of a LutCompressedArray into a bytes object in the format writable to a value_table buffer in the .tflite flatbuffer. The - tables, one per subarray, are concatinated. + tables are concatinated. """ buffer = bytearray() for t in tables: diff --git a/tensorflow/lite/micro/compression/compress_test.py b/tensorflow/lite/micro/compression/compress_test.py index 4c9d2fa3d01..1167c421f6e 100644 --- a/tensorflow/lite/micro/compression/compress_test.py +++ b/tensorflow/lite/micro/compression/compress_test.py @@ -200,31 +200,61 @@ def test_multiple_tables_with_padding(self): "shape": (16, 1), "type": tflite.TensorType.UINT8, "buffer": 1, + "quantization": { + "quantized_dimension": 1, + "scale": (1,), + "zero_point": (0,), + }, }, 1: { "shape": (16, 1), "type": tflite.TensorType.INT8, "buffer": 2, + "quantization": { + "quantized_dimension": 1, + "scale": (1,), + "zero_point": (0,), + }, }, 2: { "shape": (16, 1), "type": tflite.TensorType.INT16, "buffer": 3, + "quantization": { + "quantized_dimension": 1, + "scale": (1,), + "zero_point": (0,), + }, }, 3: { "shape": (16, 1), "type": tflite.TensorType.INT32, "buffer": 4, + "quantization": { + "quantized_dimension": 1, + "scale": (1,), + "zero_point": (0,), + }, }, 4: { "shape": (16, 1), "type": tflite.TensorType.INT32, "buffer": 5, + "quantization": { + "quantized_dimension": 1, + "scale": (1,), + "zero_point": (0,), + }, }, 5: { "shape": (4, 5), "type": tflite.TensorType.INT16, "buffer": 6, + "quantization": { + "quantized_dimension": 1, + "scale": (1, 1, 1, 1, 1), + "zero_point": (0, 0, 0, 0, 0), + }, }, 6: { "shape": (5, 4), @@ -232,8 +262,25 @@ def test_multiple_tables_with_padding(self): "buffer": 7, "quantization": { "quantized_dimension": 0, + "scale": (1, 1, 1, 1, 1), + "zero_point": (0, 0, 0, 0, 0), }, }, + 7: { + "shape": (5, 4), + "type": tflite.TensorType.INT16, + "buffer": 8, + "quantization": { + "quantized_dimension": 0, + "scale": (1,), + "zero_point": (0,), + }, + }, + 8: { + "shape": (16, 1), + "type": tflite.TensorType.UINT8, + "buffer": 9, + }, }, }, }, @@ -260,6 +307,14 @@ def test_multiple_tables_with_padding(self): (9, 10, 11, 12), (13, 14, 15, 16), (17, 18, 19, 20)), dtype=np.dtype(" bytearray: tensor_t.type = tensor["type"] tensor_t.buffer = tensor["buffer"] - try: - d = tensor["quantization"]["quantized_dimension"] + if "quantization" in tensor: tensor_t.quantization = tflite.QuantizationParametersT() - tensor_t.quantization.quantizedDimension = d - except KeyError: - tensor_t.quantization = None + tensor_t.quantization.quantizedDimension = \ + tensor["quantization"].get("quantized_dimension", None) + tensor_t.quantization.scale = \ + tensor["quantization"].get("scale", None) + tensor_t.quantization.zeroPoint = \ + tensor["quantization"].get("zero_point", None) subgraph_t.tensors.append(tensor_t)