From 51f62a1c82a81580182ba909f72997b71cc5b5d4 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Thu, 19 Sep 2024 11:35:04 +0000 Subject: [PATCH] 2024-09-19 nightly release (46e309d58b0e2026311480e2763c91591196694e) --- ...histogram_binning_calibration_benchmark.py | 3 +- fbgemm_gpu/bench/jagged_tensor_benchmark.py | 3 +- .../bench/merge_embeddings_benchmark.py | 3 + fbgemm_gpu/bench/quantize_ops_benchmark.py | 4 +- fbgemm_gpu/bench/sparse_ops_benchmark.py | 3 +- .../bench/split_embeddings_cache_benchmark.py | 3 +- ...plit_table_batched_embeddings_benchmark.py | 3 + .../ssd_table_batched_embeddings_benchmark.py | 5 +- fbgemm_gpu/bench/stride_gemm_benchmark.py | 3 +- fbgemm_gpu/docs/requirements.txt | 2 +- .../gen_ai/src/kv_cache/kv_cache.cpp | 7 +- .../gen_ai/src/kv_cache/kv_cache.cu | 204 ++++++++-- fbgemm_gpu/fbgemm_gpu/sparse_ops.py | 17 +- fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py | 3 +- .../ssd_table_batched_embeddings.h | 355 ++++++++++++------ fbgemm_gpu/test/jagged/common.py | 5 +- fbgemm_gpu/test/quantize/common.py | 31 +- fbgemm_gpu/test/sparse/misc_ops_test.py | 26 ++ fbgemm_gpu/test/sparse/pack_segments_test.py | 5 +- .../inference/nbit_split_embeddings_test.py | 56 +++ 20 files changed, 569 insertions(+), 172 deletions(-) diff --git a/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py b/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py index 592f11496..c919199ee 100644 --- a/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py +++ b/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py @@ -14,7 +14,8 @@ import torch from torch import Tensor -logging.basicConfig(level=logging.DEBUG) +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) try: # pyre-ignore[21] diff --git a/fbgemm_gpu/bench/jagged_tensor_benchmark.py b/fbgemm_gpu/bench/jagged_tensor_benchmark.py index 51c231ad0..acbe22fb2 100644 --- a/fbgemm_gpu/bench/jagged_tensor_benchmark.py +++ b/fbgemm_gpu/bench/jagged_tensor_benchmark.py @@ -16,7 +16,8 @@ import torch from torch.profiler import profile -logging.basicConfig(level=logging.DEBUG) +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. open_source: bool = getattr(fbgemm_gpu, "open_source", False) diff --git a/fbgemm_gpu/bench/merge_embeddings_benchmark.py b/fbgemm_gpu/bench/merge_embeddings_benchmark.py index 2c0b62664..95ce71d27 100644 --- a/fbgemm_gpu/bench/merge_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/merge_embeddings_benchmark.py @@ -32,6 +32,9 @@ # pyre-fixme[21]: Could not find name `ProfilerActivity` in `torch.profiler`. from torch.profiler import profile, ProfilerActivity +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) + # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. open_source: bool = getattr(fbgemm_gpu, "open_source", False) diff --git a/fbgemm_gpu/bench/quantize_ops_benchmark.py b/fbgemm_gpu/bench/quantize_ops_benchmark.py index b4e596f4d..9ffbd9911 100644 --- a/fbgemm_gpu/bench/quantize_ops_benchmark.py +++ b/fbgemm_gpu/bench/quantize_ops_benchmark.py @@ -22,8 +22,8 @@ # pyre-ignore[21] from torch.profiler import profile, ProfilerActivity - -logging.basicConfig(level=logging.DEBUG) +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. 
open_source: bool = getattr(fbgemm_gpu, "open_source", False) diff --git a/fbgemm_gpu/bench/sparse_ops_benchmark.py b/fbgemm_gpu/bench/sparse_ops_benchmark.py index fdd051909..2ef9abe8f 100644 --- a/fbgemm_gpu/bench/sparse_ops_benchmark.py +++ b/fbgemm_gpu/bench/sparse_ops_benchmark.py @@ -20,7 +20,8 @@ from torch.profiler import profile -logging.basicConfig(level=logging.DEBUG) +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. open_source: bool = getattr(fbgemm_gpu, "open_source", False) diff --git a/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py b/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py index 432ef3f4d..d3169ca81 100644 --- a/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py +++ b/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py @@ -26,7 +26,8 @@ from torch import nn, Tensor -logging.basicConfig(level=logging.DEBUG) +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) try: # pyre-ignore[21] diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py index a809b5e9b..177c79508 100644 --- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py @@ -48,6 +48,9 @@ from torch import Tensor from torch.profiler import profile +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) + haveAIBench = False try: from aibench_observer.utils.observer import emitMetric diff --git a/fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py index 430087c4a..25540c190 100644 --- a/fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py @@ -40,14 +40,13 @@ from torch.autograd.profiler import record_function from torch.profiler import profile -logging.basicConfig(level=logging.DEBUG) +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) load_torch_module( "//deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings", ) -logging.basicConfig(level=logging.DEBUG) - @click.group() def cli() -> None: diff --git a/fbgemm_gpu/bench/stride_gemm_benchmark.py b/fbgemm_gpu/bench/stride_gemm_benchmark.py index 3c70d734f..2609f7fbf 100644 --- a/fbgemm_gpu/bench/stride_gemm_benchmark.py +++ b/fbgemm_gpu/bench/stride_gemm_benchmark.py @@ -13,7 +13,8 @@ import torch from fbgemm_gpu.bench.bench_utils import benchmark_torch_function -logging.basicConfig(level=logging.DEBUG) +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) try: # pyre-ignore[21] diff --git a/fbgemm_gpu/docs/requirements.txt b/fbgemm_gpu/docs/requirements.txt index 9f3bca439..533232736 100644 --- a/fbgemm_gpu/docs/requirements.txt +++ b/fbgemm_gpu/docs/requirements.txt @@ -14,7 +14,7 @@ sphinx<7 breathe bs4 -docutils +docutils<0.20,>=0.18.1 lxml myst-parser sphinx-lint diff --git a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp index 3d5d337ec..38ed3ec6b 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp @@ -133,7 +133,9 @@ std::tuple dequantize_fp8_cache( at::Tensor cache_V, at::Tensor kv_seqlen, std::optional qparam_k, - std::optional qparam_v); + std::optional qparam_v, + std::optional block_tables, + 
int64_t page_size); at::Tensor mqa_attn( at::Tensor XQ, @@ -162,7 +164,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { "dequantize_int4_cache(Tensor cache_K, Tensor cache_V, Tensor kv_seqlen, int? num_groups=1) -> (Tensor, Tensor)"); m.impl("dequantize_int4_cache", dequantize_int4_cache); m.def( - "dequantize_fp8_cache(Tensor cache_K, Tensor cache_V, Tensor kv_seqlen, Tensor? qparam_k=None, Tensor? qparam_v=None) -> (Tensor, Tensor)"); + "dequantize_fp8_cache(Tensor cache_K, Tensor cache_V, Tensor kv_seqlen, Tensor? qparam_k=None, Tensor? qparam_v=None, Tensor? block_tables=None, int page_size=" STRING( + DEFAULT_PAGE_SIZE) ") -> (Tensor, Tensor)"); m.impl("dequantize_fp8_cache", dequantize_fp8_cache); } diff --git a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu index 0728e6b9a..787c0547c 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu @@ -795,11 +795,27 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_( } else { __half2* qparam_row = nullptr; auto T = cache_K.size(1); - auto idx = b * (T * N_KVH) + (size_t)cache_loc_t * N_KVH + h; - if (qkv == QKV::K) { - qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]); + if (block_tables == nullptr) { + auto idx = b * (T * N_KVH) + (size_t)cache_loc_t * N_KVH + h; + if (qkv == QKV::K) { + qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]); + } else { + qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]); + } } else { - qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]); + // This is duplicate computation with get_dst_row above. + // TODO: Maybe clean up and merge later. + int page_logical_idx = cache_loc_t / page_size; + int page_offset = cache_loc_t % page_size; + int page_physical_idx = + block_tables[b * block_tables_b_stride + page_logical_idx]; + int physical_t = page_physical_idx * page_size + page_offset; + auto idx = physical_t * N_KVH + h; + if (qkv == QKV::K) { + qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]); + } else { + qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]); + } } quantize_fp8_kv(dst, dst_row_q, qparam_row); } @@ -1477,16 +1493,113 @@ __global__ void dequantize_fp8_cache_kernel( *reinterpret_cast(&kv_dq.vals[2]); } } + +// Cloned from dequantize_fp8_cache_kernel because +// branching inside the original kernel runs into +// "too many resources requested for launch" which +// necessitates decreasing the number of warps per block, +// which might have performance implications. Also we +// might have more diverging behaviors for paged kernel +// as noted in the comment below so we will keep a separate +// kernel for now. +__global__ void dequantize_fp8_cache_kernel_paged( + // This code currently represents FP8 version not int4 + at::PackedTensorAccessor64 + cache_K, // [1][MAX_PAGE * PAGE_SIZE][N_KVH][D_H] + at::PackedTensorAccessor64 + cache_V, // [1][MAX_PAGE * PAGE_SIZE][N_KVH][D_H // G] + at::PackedTensorAccessor32 kv_seqlen, + at::PackedTensorAccessor64 + cache_K_dq, // [1][MAX_T][N_KVH][D_H] + at::PackedTensorAccessor64 + cache_V_dq, // [1][MAX_T][N_KVH][D_H] + int32_t* qparam_k_ptr, + int32_t* qparam_v_ptr, + int32_t* block_tables, + int32_t block_tables_b_stride, + int32_t page_size) { + auto N_KVH = cache_K.size(2); + auto MAX_T = cache_K.size(1); + auto D_H = cache_K_dq.size(3); + auto D_H_q = cache_K.size(3); + CUDA_KERNEL_ASSERT(D_H == 128); + + auto b = blockIdx.x; + // only need to dequantize this far. 
+ auto max_t = kv_seqlen[b]; + + // one warp per T/H + for (auto t_h = threadIdx.y + blockIdx.y * blockDim.y; t_h < max_t * N_KVH; + t_h += blockDim.y * gridDim.y) { + auto h = t_h % N_KVH; + auto t = t_h / N_KVH; + + int page_logical_idx = t / page_size; + int page_offset = t % page_size; + int page_physical_idx = + block_tables[b * block_tables_b_stride + page_logical_idx]; + int physical_t = page_physical_idx * page_size + page_offset; + + uint8_t* row_k = &cache_K[0][physical_t][h][0]; + uint8_t* row_v = &cache_V[0][physical_t][h][0]; + + bfx8 kv_dq; + uint8_t qparam_offset_bytes; + __half2* qparam_k_src; + __half2* qparam_v_src; + if (qparam_k_ptr) { + // read from standalone qparam tensor + qparam_offset_bytes = 0; + auto idx = physical_t * N_KVH + h; + qparam_k_src = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]); + qparam_v_src = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]); + } else { + // read from first row + qparam_offset_bytes = 4; + qparam_k_src = reinterpret_cast<__half2*>(&row_k[0]); + qparam_v_src = reinterpret_cast<__half2*>(&row_v[0]); + } + // Assert the quantized row dim is as expected + CUDA_KERNEL_ASSERT(D_H_q - D_H == qparam_offset_bytes); + if (4 * threadIdx.x >= D_H) { + continue; + } + // each thread reads 4 x 8 bits + + uint64_t kq = *reinterpret_cast( + &row_k[threadIdx.x * 4 + qparam_offset_bytes]); + uint64_t vq = *reinterpret_cast( + &row_v[threadIdx.x * 4 + qparam_offset_bytes]); + + uint64_t packed = kq | (vq << 32); + + kv_dq = dequantize_packed_fp8(packed, *qparam_k_src, *qparam_v_src); + + // now, write our outputs + auto* row_k_dq = &cache_K_dq[0][physical_t][h][0]; + auto* row_v_dq = &cache_V_dq[0][physical_t][h][0]; + // each thread writes 4 elements of type bf16 + *reinterpret_cast(&row_k_dq[4 * threadIdx.x]) = + *reinterpret_cast(&kv_dq.vals[0]); + *reinterpret_cast(&row_v_dq[4 * threadIdx.x]) = + *reinterpret_cast(&kv_dq.vals[2]); + } +} std::tuple dequantize_fp8_cache( at::Tensor cache_K, at::Tensor cache_V, at::Tensor kv_seqlen, std::optional qparam_k, - std::optional qparam_v) { + std::optional qparam_v, + std::optional block_tables, + int64_t page_size) { TORCH_CHECK(cache_K.is_cuda()); TORCH_CHECK(cache_V.is_cuda()); TORCH_CHECK(kv_seqlen.is_cuda()); - auto B = cache_K.size(0); + auto B = kv_seqlen.size(0); + // vanilla: B_KV = B, paged: B_KV = 1 + auto B_KV = cache_K.size(0); + // vanilla: MAX_T = MAX_T, paged: MAX_T = MAX_PAGE * PAGE_SIZE auto MAX_T = cache_K.size(1); auto N_KVH = cache_K.size(2); auto D_HQ = cache_K.size(3); @@ -1500,31 +1613,72 @@ std::tuple dequantize_fp8_cache( } auto D_H = (D_HQ - fp8_qparam_offset); - auto cache_K_dq = - at::empty({B, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16)); - auto cache_V_dq = - at::empty({B, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16)); + // TODO: + // The below allocates Tensors that have the same shape as cache_K and cache_V + // to store their dequantize results. For paged KV cache, this can be a bit + // inefficient because it has the shape of [1 x (MAX_PAGES * PAGE_SIZE) x + // N_KVH x D_H] to accommodate pages globally across batch instances, and + // if we have very large MAX_PAGES then we are essentially allocating a very + // huge Tensor here. The benefit is that the following users of this + // dequantized results can reuse the existing block_tables to access their + // elements. 
If we want to be more efficient, there are two possible + // approaches: (1) Allocate a shorter Tensor here and store the dequantize + // results in a more compact manner, but that requires creating a new + // block_tables here and making sure the following users all use the + // correct block_tables. (2) From outside, keep a persistent buffer that has a + // matching shape with the original paged KV and feed the same buffer + // into this function at every layer to reuse it and prevent allocation. + auto cache_K_dq = at::empty( + {B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16)); + auto cache_V_dq = at::empty( + {B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16)); if (B == 0) { return {cache_K_dq, cache_V_dq}; } + int32_t* block_tables_ptr = nullptr; + int32_t block_tables_b_stride = 0; + if (block_tables.has_value()) { + block_tables_ptr = static_cast(block_tables.value().data_ptr()); + block_tables_b_stride = block_tables.value().stride(0); + } + constexpr int32_t kMaxBlocks = 256; dim3 blocks(B, std::max(1, kMaxBlocks / B)); dim3 threads(kThreadsPerWarp, kWarpsPerBlock); - dequantize_fp8_cache_kernel<<< - blocks, - threads, - 0, - at::cuda::getCurrentCUDAStream()>>>( - cache_K.packed_accessor64(), - cache_V.packed_accessor64(), - kv_seqlen.packed_accessor32(), - cache_K_dq.packed_accessor64(), - cache_V_dq.packed_accessor64(), - qparam_k_ptr, - qparam_v_ptr); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + if (block_tables_ptr == nullptr) { + dequantize_fp8_cache_kernel<<< + blocks, + threads, + 0, + at::cuda::getCurrentCUDAStream()>>>( + cache_K.packed_accessor64(), + cache_V.packed_accessor64(), + kv_seqlen.packed_accessor32(), + cache_K_dq.packed_accessor64(), + cache_V_dq.packed_accessor64(), + qparam_k_ptr, + qparam_v_ptr); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + dequantize_fp8_cache_kernel_paged<<< + blocks, + threads, + 0, + at::cuda::getCurrentCUDAStream()>>>( + cache_K.packed_accessor64(), + cache_V.packed_accessor64(), + kv_seqlen.packed_accessor32(), + cache_K_dq.packed_accessor64(), + cache_V_dq.packed_accessor64(), + qparam_k_ptr, + qparam_v_ptr, + block_tables_ptr, + block_tables_b_stride, + page_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } return {cache_K_dq, cache_V_dq}; } @@ -1606,7 +1760,9 @@ std::tuple dequantize_fp8_cache( at::Tensor cache_V, at::Tensor kv_seqlen, std::optional qparam_k, - std::optional qparam_v) { + std::optional qparam_v, + std::optional block_tables, + int64_t page_size) { throw std::runtime_error( "CUDA version is older than 12.0"); // requires CUDA>=12 } diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index 024170a68..21e858e1a 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -7,7 +7,7 @@ # pyre-strict import math -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Optional, Sequence, Tuple import torch @@ -1113,3 +1113,18 @@ def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None _setup() + + +@torch.library.register_fake("fbgemm::lengths_range") +def lengths_range_abstract( + lengths: Tensor, + output_shape: Optional[Sequence[int]] = None, +) -> Tensor: + torch._check(lengths.dim() == 1, lambda: "lengths must be a 1D tensor") + output_size = 0 + if output_shape is not None: + output_size = math.prod(output_shape) + else: + ctx = torch.library.get_ctx() + output_size = ctx.new_dynamic_size() + return lengths.new_empty([output_size], dtype=lengths.dtype) diff --git 
a/fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py b/fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py index fbccebd9a..19927a5ea 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py @@ -11,6 +11,7 @@ from typing import List, Optional, Tuple import numpy as np +import numpy.typing as npt import torch # pyre-fixme[21]: Could not find name `default_rng` in `numpy.random` (stubbed). @@ -135,7 +136,7 @@ def generate_int_data_from_stats( sigma: int, size: int, distribution: str, -) -> np.ndarray: +) -> npt.NDArray: """ Generate integer data based on stats """ diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h index 77588b5b9..3d1f2c977 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h @@ -8,11 +8,14 @@ #pragma once +#include +#include + #include +#include #include #include #include -#include #ifdef FBGEMM_FBCODE #include "common/strings/UUID.h" #include "common/time/Time.h" @@ -126,7 +129,50 @@ class Initializer { /// @brief An implementation of EmbeddingKVDB for RocksDB /// class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { + using snapshot_ptr_t = const rocksdb::Snapshot*; + public: + class SnapshotHandle { + public: + explicit SnapshotHandle(EmbeddingRocksDB* db) : db_(db) { + auto num_shards = db->num_shards(); + CHECK_GT(num_shards, 0); + shard_snapshots_.reserve(num_shards); + for (auto shard = 0; shard < num_shards; ++shard) { + const auto* snapshot = db->dbs_[shard]->GetSnapshot(); + CHECK(snapshot != nullptr) + << "ERROR: create_snapshot fails to create a snapshot " + << "for db shard " << shard << ". Please make sure that " + << "inplace_update_support is set to false" << std::endl; + shard_snapshots_.push_back(snapshot); + } + } + + ~SnapshotHandle() { + for (auto shard = 0; shard < db_->dbs_.size(); ++shard) { + snapshot_ptr_t snapshot = shard_snapshots_[shard]; + CHECK(snapshot != nullptr) + << "Unexpected nullptr for snapshot " << shard; + db_->dbs_[shard]->ReleaseSnapshot(snapshot); + } + } + + void release() { + db_->release_snapshot(this); + } + + snapshot_ptr_t get_snapshot_for_shard(size_t shard) const { + CHECK_LE(shard, shard_snapshots_.size()); + return shard_snapshots_[shard]; + } + + private: + friend class EmbeddingRocksDB; + + EmbeddingRocksDB* db_; + std::vector shard_snapshots_; + }; + explicit EmbeddingRocksDB( std::string path, int64_t num_shards, @@ -152,7 +198,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { max_D, l2_cache_size_gb, tbe_unqiue_id, - row_storage_bitwidth / 8) { + row_storage_bitwidth / 8), + max_D_(max_D) { // TODO: lots of tunables. NNI or something for this? 
rocksdb::Options options; options.create_if_missing = true; @@ -193,7 +240,7 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { // causing flush set this to true to make update on the existing key // allow_concurrent_memtable_write is toggled in pair with // inplace_update_support - options.inplace_update_support = true; + options.inplace_update_support = false; options.avoid_unnecessary_blocking_io = true; options.use_direct_reads = true; @@ -341,6 +388,30 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { } } + ~EmbeddingRocksDB() override { + // clear all the snapshots if not released + if (snapshots_.size() > 0) { + LOG(WARNING) + << snapshots_.size() + << " snapshots have not been released when db is closing. Releasing them now."; + } + snapshots_.clear(); + for (auto shard = 0; shard < dbs_.size(); ++shard) { + dbs_[shard]->Close(); + } + } + + folly::coro::Task get_kv_db_async( + const at::Tensor& indices, + const at::Tensor& weights, + const at::Tensor& count) override { + co_await get_kv_db_async_impl( + indices, + weights, + count, + /*snapshot_handle=*/nullptr); + } + folly::coro::Task set_kv_db_async( const at::Tensor& indices, const at::Tensor& weights, @@ -416,10 +487,162 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { #endif } - folly::coro::Task get_kv_db_async( + bool is_valid_snapshot(const SnapshotHandle* snapshot_handle) const { + return snapshots_.find(snapshot_handle) != snapshots_.end(); + } + + const SnapshotHandle* create_snapshot() { + const auto num_snapshots = snapshots_.size(); + if (num_snapshots > 0) { + std::cerr << "WARNING: create_snapshot found " << num_snapshots + << " other snapshots" << std::endl; + } + + auto handle = std::make_unique(this); + auto handlePtr = handle.get(); + snapshots_[handlePtr] = std::move(handle); + return handlePtr; + } + + void release_snapshot(const SnapshotHandle* snapshot_handle) { + snapshots_.erase(snapshot_handle); + } + + void get_range_from_snapshot( + const at::Tensor& weights, + const int64_t start, + const int64_t length, + const SnapshotHandle* snapshot_handle) { + const auto seq_indices = + at::arange(start, start + length, at::TensorOptions().dtype(at::kLong)); + int64_t* count_ = new int64_t[1]; + count_[0] = length; + const auto count = at::from_blob(count_, {1}, at::kLong); + folly::coro::blockingWait( + get_kv_db_async_impl(seq_indices, weights, count, snapshot_handle)); + } + + int64_t get_max_D() { + return max_D_; + } + + // collect mem usage on all db shards, checkout rocks_db_mem_properties + std::vector get_mem_usage() { + int num_mem_component = rocks_db_mem_properties.size(); + std::vector mem_usages(num_mem_component); + for (auto& db : dbs_) { + for (int i = 0; i < num_mem_component; i++) { + std::string property = rocks_db_mem_properties[i]; + std::string val; + db->GetProperty(property, &val); + if (val != "") { + if (i != 0) { + mem_usages[i] += folly::to(val); + } else { + mem_usages[i] = folly::to(val); + } + } + } + } + return mem_usages; + } + + std::vector get_rocksdb_io_duration( + const int64_t step, + const int64_t interval) { + std::vector ret; + ret.reserve(5); + if (step > 0 && step % interval == 0) { + int64_t reset_val = 0; + auto read_dur = read_total_duration_.exchange(reset_val); + + auto fwd_rocksdb_read_dur = fwd_rocksdb_read_dur_.exchange(reset_val); + auto fwd_l1_eviction_dur = fwd_l1_eviction_dur_.exchange(reset_val); + auto bwd_l1_cnflct_miss_write_back_dur = + bwd_l1_cnflct_miss_write_back_dur_.exchange(reset_val); + auto flush_write_dur = 
flush_write_dur_.exchange(reset_val); + + ret.push_back(double(read_dur) / interval); + ret.push_back(double(fwd_rocksdb_read_dur) / interval); + ret.push_back(double(fwd_l1_eviction_dur) / interval); + ret.push_back(double(bwd_l1_cnflct_miss_write_back_dur) / interval); + ret.push_back(double(flush_write_dur) / interval); + } + return ret; + } + + void compact() override { + for (auto& db : dbs_) { + db->CompactRange(rocksdb::CompactRangeOptions(), nullptr, nullptr); + } + } + + void flush() { + kv_db::EmbeddingKVDB::flush(); + for (auto& db : dbs_) { + db->Flush(rocksdb::FlushOptions()); + } + } + + int64_t num_shards() const { + return dbs_.size(); + } + + private: + void flush_or_compact(const int64_t timestep) override { + // Only do manual Flush/Compactions if enabled + if (memtable_flush_period_ > 0) { + { + RECORD_USER_SCOPE("FlushCompactIfNecessary"); + if (!done_staggered_flushes_) { + flush_if_necessary(timestep); + } else { + compact_if_necessary(timestep); + } + } + } + } + + void flush_if_necessary(const int64_t timestep) { + for (int64_t i = 0; i < dbs_.size(); i++) { + if (shard_flush_compaction_deadlines_[i] == timestep) { + rocksdb::FlushOptions fo; + fo.wait = false; + fo.allow_write_stall = false; + dbs_[i]->Flush(fo); + if (i == dbs_.size() - 1) { + done_staggered_flushes_ = true; + int64_t period_per_shard = compaction_period_ / dbs_.size(); + int64_t offset = memtable_flush_offset_ + compaction_period_; + for (int64_t j = 0; j < dbs_.size(); j++) { + shard_flush_compaction_deadlines_[j] = + offset + (j * period_per_shard); + } + } + } + } + } + + void compact_if_necessary(const int64_t timestep) { + for (int64_t i = 0; i < dbs_.size(); i++) { + if (shard_flush_compaction_deadlines_[i] == timestep) { + rocksdb::ColumnFamilyMetaData meta; + dbs_[i]->GetColumnFamilyMetaData(&meta); + int32_t num_level0 = meta.levels[0].files.size(); + if (num_level0 >= l0_files_per_compact_) { + dbs_[i]->CompactRange( + rocksdb::CompactRangeOptions(), nullptr, nullptr); + } + shard_flush_compaction_deadlines_[i] += compaction_period_; + } + } + } + + folly::coro::Task get_kv_db_async_impl( const at::Tensor& indices, const at::Tensor& weights, - const at::Tensor& count) override { + const at::Tensor& count, + const SnapshotHandle* snapshot_handle) { RECORD_USER_SCOPE("EmbeddingRocksDB::get"); #ifdef FBGEMM_FBCODE auto start_ts = facebook::WallClockUtil::NowInUsecFast(); @@ -428,9 +651,13 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { auto count_ = count.item().toLong(); for (auto shard = 0; shard < dbs_.size(); ++shard) { + // Get a snapshot for the shard + snapshot_ptr_t snapshot = snapshot_handle == nullptr + ? 
nullptr + : snapshot_handle->get_snapshot_for_shard(shard); tasks.emplace_back( folly::coro::co_invoke( - [this, &indices, &weights, count_, shard]() mutable + [this, &indices, &weights, count_, shard, snapshot]() mutable -> folly::coro::Task { FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE( weights.scalar_type(), "ssd_get", [&] { @@ -487,6 +714,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { values.resize(keys.size()); statuses.resize(keys.size()); + // Set a snapshot if it is available + ro_.snapshot = snapshot; dbs_[shard]->MultiGet( ro_, keys.size(), @@ -545,114 +774,6 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { #endif } - // collect mem usage on all db shards, checkout rocks_db_mem_properties - std::vector get_mem_usage() { - int num_mem_component = rocks_db_mem_properties.size(); - std::vector mem_usages(num_mem_component); - for (auto& db : dbs_) { - for (int i = 0; i < num_mem_component; i++) { - std::string property = rocks_db_mem_properties[i]; - std::string val; - db->GetProperty(property, &val); - if (val != "") { - if (i != 0) { - mem_usages[i] += folly::to(val); - } else { - mem_usages[i] = folly::to(val); - } - } - } - } - return mem_usages; - } - - std::vector get_rocksdb_io_duration( - const int64_t step, - const int64_t interval) { - std::vector ret; - ret.reserve(5); - if (step > 0 && step % interval == 0) { - int64_t reset_val = 0; - auto read_dur = read_total_duration_.exchange(reset_val); - - auto fwd_rocksdb_read_dur = fwd_rocksdb_read_dur_.exchange(reset_val); - auto fwd_l1_eviction_dur = fwd_l1_eviction_dur_.exchange(reset_val); - auto bwd_l1_cnflct_miss_write_back_dur = - bwd_l1_cnflct_miss_write_back_dur_.exchange(reset_val); - auto flush_write_dur = flush_write_dur_.exchange(reset_val); - - ret.push_back(double(read_dur) / interval); - ret.push_back(double(fwd_rocksdb_read_dur) / interval); - ret.push_back(double(fwd_l1_eviction_dur) / interval); - ret.push_back(double(bwd_l1_cnflct_miss_write_back_dur) / interval); - ret.push_back(double(flush_write_dur) / interval); - } - return ret; - } - - void compact() override { - for (auto& db : dbs_) { - db->CompactRange(rocksdb::CompactRangeOptions(), nullptr, nullptr); - } - } - - void flush() { - kv_db::EmbeddingKVDB::flush(); - for (auto& db : dbs_) { - db->Flush(rocksdb::FlushOptions()); - } - } - - private: - void flush_or_compact(const int64_t timestep) override { - // Only do manual Flush/Compactions if enabled - if (memtable_flush_period_ > 0) { - { - RECORD_USER_SCOPE("FlushCompactIfNecessary"); - if (!done_staggered_flushes_) { - flush_if_necessary(timestep); - } else { - compact_if_necessary(timestep); - } - } - } - } - - void flush_if_necessary(const int64_t timestep) { - for (int64_t i = 0; i < dbs_.size(); i++) { - if (shard_flush_compaction_deadlines_[i] == timestep) { - rocksdb::FlushOptions fo; - fo.wait = false; - fo.allow_write_stall = false; - dbs_[i]->Flush(fo); - if (i == dbs_.size() - 1) { - done_staggered_flushes_ = true; - int64_t period_per_shard = compaction_period_ / dbs_.size(); - int64_t offset = memtable_flush_offset_ + compaction_period_; - for (int64_t j = 0; j < dbs_.size(); j++) { - shard_flush_compaction_deadlines_[j] = - offset + (j * period_per_shard); - } - } - } - } - } - - void compact_if_necessary(const int64_t timestep) { - for (int64_t i = 0; i < dbs_.size(); i++) { - if (shard_flush_compaction_deadlines_[i] == timestep) { - rocksdb::ColumnFamilyMetaData meta; - dbs_[i]->GetColumnFamilyMetaData(&meta); - int32_t num_level0 = meta.levels[0].files.size(); - if 
(num_level0 >= l0_files_per_compact_) { - dbs_[i]->CompactRange( - rocksdb::CompactRangeOptions(), nullptr, nullptr); - } - shard_flush_compaction_deadlines_[i] += compaction_period_; - } - } - } - std::vector> dbs_; std::vector> initializers_; std::unique_ptr executor_; @@ -672,6 +793,10 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { std::atomic fwd_l1_eviction_dur_{0}; std::atomic bwd_l1_cnflct_miss_write_back_dur_{0}; std::atomic flush_write_dur_{0}; -}; // class EmbeddingKVDB + + std::unordered_map> + snapshots_; + int64_t max_D_; +}; // class EmbeddingRocksDB } // namespace ssd diff --git a/fbgemm_gpu/test/jagged/common.py b/fbgemm_gpu/test/jagged/common.py index 6cd60d96c..3bdcacf98 100644 --- a/fbgemm_gpu/test/jagged/common.py +++ b/fbgemm_gpu/test/jagged/common.py @@ -16,6 +16,7 @@ import fbgemm_gpu import fbgemm_gpu.sparse_ops import numpy as np +import numpy.typing as npt import torch from hypothesis import HealthCheck, settings @@ -122,7 +123,7 @@ def generate_jagged_tensor( # dynamo to mark the input as dynamic shape to make sure symbolic # shape is generated mark_dynamic: bool = False, -) -> Tuple[torch.Tensor, List[torch.LongTensor], np.ndarray]: +) -> Tuple[torch.Tensor, List[torch.LongTensor], npt.NDArray]: max_lengths = np.random.randint(low=1, high=10, size=(num_jagged_dim,)) x_offsets: List[torch.LongTensor] = [] num_lengths = outer_dense_size @@ -167,7 +168,7 @@ def generate_jagged_tensor( def to_padded_dense( values: torch.Tensor, offsets: List[torch.LongTensor], - max_lengths: np.ndarray, + max_lengths: npt.NDArray, padding_value: float = 0, ) -> torch.Tensor: outer_dense_size = len(offsets[0]) - 1 diff --git a/fbgemm_gpu/test/quantize/common.py b/fbgemm_gpu/test/quantize/common.py index 392bbfcac..5333cc893 100644 --- a/fbgemm_gpu/test/quantize/common.py +++ b/fbgemm_gpu/test/quantize/common.py @@ -12,6 +12,7 @@ import fbgemm_gpu import numpy as np +import numpy.typing as npt import torch # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. 
@@ -30,17 +31,17 @@ torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") # Eigen/Python round 0.5 away from 0, Numpy rounds to even -round_to_nearest: Callable[[np.ndarray], np.ndarray] = np.vectorize(round) +round_to_nearest: Callable[[npt.NDArray], npt.NDArray] = np.vectorize(round) -def bytes_to_floats(byte_matrix: np.ndarray) -> np.ndarray: +def bytes_to_floats(byte_matrix: npt.NDArray) -> npt.NDArray: floats = np.empty([np.shape(byte_matrix)[0], 1], dtype=np.float32) for i, byte_values in enumerate(byte_matrix): (floats[i],) = struct.unpack("f", bytearray(byte_values)) return floats -def floats_to_bytes(floats: np.ndarray) -> np.ndarray: +def floats_to_bytes(floats: npt.NDArray) -> npt.NDArray: byte_matrix = np.empty([np.shape(floats)[0], 4], dtype=np.uint8) for i, value in enumerate(floats): assert isinstance(value, np.float32), (value, floats) @@ -53,7 +54,7 @@ def floats_to_bytes(floats: np.ndarray) -> np.ndarray: return byte_matrix -def bytes_to_half_floats(byte_matrix: np.ndarray) -> np.ndarray: +def bytes_to_half_floats(byte_matrix: npt.NDArray) -> npt.NDArray: floats = np.empty([np.shape(byte_matrix)[0], 1], dtype=np.float16) for i, byte_values in enumerate(byte_matrix): (floats[i],) = np.frombuffer( @@ -62,7 +63,7 @@ def bytes_to_half_floats(byte_matrix: np.ndarray) -> np.ndarray: return floats -def half_floats_to_bytes(floats: np.ndarray) -> np.ndarray: +def half_floats_to_bytes(floats: npt.NDArray) -> npt.NDArray: byte_matrix = np.empty([np.shape(floats)[0], 2], dtype=np.uint8) for i, value in enumerate(floats): assert isinstance(value, np.float16), (value, floats) @@ -72,7 +73,7 @@ def half_floats_to_bytes(floats: np.ndarray) -> np.ndarray: return byte_matrix -def fused_rowwise_8bit_quantize_reference(data: np.ndarray) -> np.ndarray: +def fused_rowwise_8bit_quantize_reference(data: npt.NDArray) -> npt.NDArray: minimum = np.min(data, axis=-1, keepdims=True) maximum = np.max(data, axis=-1, keepdims=True) span = maximum - minimum @@ -87,7 +88,9 @@ def fused_rowwise_8bit_quantize_reference(data: np.ndarray) -> np.ndarray: return np.concatenate([quantized_data, scale_bytes, bias_bytes], axis=-1) -def fused_rowwise_8bit_dequantize_reference(fused_quantized: np.ndarray) -> np.ndarray: +def fused_rowwise_8bit_dequantize_reference( + fused_quantized: npt.NDArray, +) -> npt.NDArray: scale = bytes_to_floats(fused_quantized[..., -8:-4].astype(np.uint8).reshape(-1, 4)) scale = scale.reshape(fused_quantized.shape[:-1] + (scale.shape[-1],)) bias = bytes_to_floats(fused_quantized[..., -4:].astype(np.uint8).reshape(-1, 4)) @@ -97,8 +100,8 @@ def fused_rowwise_8bit_dequantize_reference(fused_quantized: np.ndarray) -> np.n def fused_rowwise_8bit_dequantize_2bytes_padding_scale_bias_first_reference( - fused_quantized: np.ndarray, -) -> np.ndarray: + fused_quantized: npt.NDArray, +) -> npt.NDArray: scale = bytes_to_half_floats( fused_quantized[..., 0:2].astype(np.uint8).reshape(-1, 2) ) @@ -112,8 +115,8 @@ def fused_rowwise_8bit_dequantize_2bytes_padding_scale_bias_first_reference( def fused_rowwise_8bit_dequantize_reference_half( - fused_quantized: np.ndarray, -) -> np.ndarray: + fused_quantized: npt.NDArray, +) -> npt.NDArray: scale = bytes_to_half_floats( fused_quantized[..., -8:-4].astype(np.uint8).reshape(-1, 4) ) @@ -126,7 +129,7 @@ def fused_rowwise_8bit_dequantize_reference_half( return quantized_data * scale + bias -def fused_rowwise_nbit_quantize_reference(data: np.ndarray, bit: int) -> np.ndarray: +def fused_rowwise_nbit_quantize_reference(data: npt.NDArray, bit: 
int) -> npt.NDArray: minimum = np.min(data, axis=1).astype(np.float16).astype(np.float32) maximum = np.max(data, axis=1) span = maximum - minimum @@ -165,8 +168,8 @@ def fused_rowwise_nbit_quantize_reference(data: np.ndarray, bit: int) -> np.ndar def fused_rowwise_nbit_quantize_dequantize_reference( - data: np.ndarray, bit: int -) -> np.ndarray: + data: npt.NDArray, bit: int +) -> npt.NDArray: fused_quantized = fused_rowwise_nbit_quantize_reference(data, bit) scale = bytes_to_half_floats(fused_quantized[:, -4:-2].astype(np.uint8)).astype( np.float32 diff --git a/fbgemm_gpu/test/sparse/misc_ops_test.py b/fbgemm_gpu/test/sparse/misc_ops_test.py index fb9e29c81..41187b502 100644 --- a/fbgemm_gpu/test/sparse/misc_ops_test.py +++ b/fbgemm_gpu/test/sparse/misc_ops_test.py @@ -18,6 +18,7 @@ import numpy as np import torch from hypothesis import given, settings, Verbosity +from torch.fx.experimental.symbolic_shapes import ShapeEnv from .common import extend_test_class, open_source @@ -261,6 +262,31 @@ def test_bottom_unique_k_per_row( all_indices_deduped_ref = torch.as_tensor(all_indices[:, :, :L]) torch.testing.assert_close(all_indices_deduped, all_indices_deduped_ref) + def test_lengths_range(self) -> None: + # When 'output_shape' is None, the function will return a tensor with dynamic shape. + with self.assertRaisesRegex( + torch._subclasses.fake_tensor.DynamicOutputShapeException, + "fbgemm.lengths_range.default", + ): + with torch._subclasses.fake_tensor.FakeTensorMode( + shape_env=ShapeEnv( + allow_dynamic_output_shape_ops=False, + ), + ): + lengths = torch.tensor([3, 2, 4, 10], dtype=torch.int32) + _ = torch.ops.fbgemm.lengths_range(lengths, None) + + with torch._subclasses.fake_tensor.FakeTensorMode( + shape_env=ShapeEnv( + allow_dynamic_output_shape_ops=False, + ), + ): + lengths = torch.tensor([3, 2, 4, 10], dtype=torch.int32) + output_shape = [1, 2, 4, 4] + actual_result = torch.ops.fbgemm.lengths_range(lengths, output_shape) + + self.assertEqual(actual_result.shape, (1 * 2 * 4 * 4,)) + extend_test_class(MiscOpsTest) diff --git a/fbgemm_gpu/test/sparse/pack_segments_test.py b/fbgemm_gpu/test/sparse/pack_segments_test.py index c6383017a..d6b40328e 100644 --- a/fbgemm_gpu/test/sparse/pack_segments_test.py +++ b/fbgemm_gpu/test/sparse/pack_segments_test.py @@ -15,6 +15,7 @@ import hypothesis.strategies as st import numpy as np +import numpy.typing as npt import torch from hypothesis import given, settings @@ -27,7 +28,7 @@ from fbgemm_gpu.test.test_utils import gpu_available -def get_n_rand_num_summing_to_k(n: int, k: int) -> np.ndarray: +def get_n_rand_num_summing_to_k(n: int, k: int) -> npt.NDArray: """Get a list of `n` integers which collectively sum to `k`, drawn uniformly from the set of all such lists. 
@@ -58,7 +59,7 @@ def _pack_segments_ref( lengths: torch.Tensor, tensor: torch.Tensor, max_length: Optional[int] = None, - ) -> np.ndarray: + ) -> npt.NDArray: lengths = lengths.numpy() sections = np.split(tensor, np.cumsum(lengths)) max_length = np.max(lengths, initial=0) if max_length is None else max_length diff --git a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py index ed4245dbd..439797688 100644 --- a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py +++ b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py @@ -347,6 +347,62 @@ def test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function( ) torch.testing.assert_close(output_uvm, output_ref, equal_nan=True) + @given( + weights_ty=st.sampled_from( + [ + SparseType.FP32, + SparseType.FP16, + SparseType.INT8, + SparseType.INT4, + SparseType.INT2, + ] + ), + ) + @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) + def test_int_nbit_split_embedding_cpu_mixed_indices_offsets_dtypes( + self, + weights_ty: SparseType, + ) -> None: + T = random.randint(1, 5) + B = random.randint(1, 128) + L = random.randint(1, 20) + D = random.randint(2, 256) + log_E = random.randint(3, 5) + + iters = 4 + E = int(10**log_E) + + D_alignment = ( + 1 if weights_ty.bit_rate() % 8 == 0 else int(8 / weights_ty.bit_rate()) + ) + D = round_up(D, D_alignment) + + Ds = [D] * T + Es = [E] * T + cpu_locations = [EmbeddingLocation.HOST] * T + + cc = IntNBitTableBatchedEmbeddingBagsCodegen( + [("", E, D, weights_ty, M) for (E, D, M) in zip(Es, Ds, cpu_locations)], + device=torch.device("cpu"), + ) + cc.fill_random_weights() + + requests = generate_requests( + iters, B, T, L, min(Es), reuse=0.1, emulate_pruning=False, use_cpu=True + ) + dtypes_combo = [ + (torch.int64, torch.int64), + (torch.int32, torch.int32), + (torch.int32, torch.int64), + (torch.int64, torch.int32), + ] + for i, req in enumerate(requests): + indices, offsets = req.unpack_2() + indices_dtype, offsets_dtype = dtypes_combo[i] + indices = indices.to(indices_dtype) + offsets = offsets.to(offsets_dtype) + _ = cc(indices, offsets) + if __name__ == "__main__": unittest.main()
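
The central API change above extends `dequantize_fp8_cache` with optional `block_tables`/`page_size` arguments for paged KV caches. Below is a minimal call sketch against the new schema; all shapes, page counts, and values are illustrative assumptions, and it presumes the `fbgemm_gpu` gen_ai extension is built and loaded on a CUDA 12+ device with the FP8 scale/shift stored inline in the first 4 bytes of each quantized row (the `qparam_k`/`qparam_v = None` path of the paged kernel).

```python
# Sketch of calling the paged FP8 KV-cache dequantization path (illustrative only).
# Assumptions: fbgemm_gpu gen_ai extension built and loaded, CUDA 12+ device,
# D_H == 128 (asserted by the kernel), and inline per-row qparams (qparam_k/v None).
import torch

B, MAX_PAGES, PAGE_SIZE, N_KVH, D_H = 2, 8, 64, 1, 128
D_H_q = D_H + 4  # FP8 payload plus 4 bytes of inline scale/shift per row

# Paged layout: batch dim collapses to 1; the time dim spans all physical pages,
# i.e. [1, MAX_PAGES * PAGE_SIZE, N_KVH, D_H_q].
cache_K = torch.randint(
    0, 255, (1, MAX_PAGES * PAGE_SIZE, N_KVH, D_H_q), dtype=torch.uint8, device="cuda"
)
cache_V = torch.randint(
    0, 255, (1, MAX_PAGES * PAGE_SIZE, N_KVH, D_H_q), dtype=torch.uint8, device="cuda"
)
kv_seqlen = torch.tensor([70, 3], dtype=torch.int32, device="cuda")  # per-batch lengths
# block_tables[b, logical_page] -> physical page index; int32, row-major [B, pages]
block_tables = torch.arange(MAX_PAGES, dtype=torch.int32, device="cuda").reshape(B, -1)

cache_K_dq, cache_V_dq = torch.ops.fbgemm.dequantize_fp8_cache(
    cache_K,
    cache_V,
    kv_seqlen,
    None,          # qparam_k: None -> read inline qparams from each row
    None,          # qparam_v
    block_tables,
    PAGE_SIZE,
)
# Outputs keep the paged layout, [1, MAX_PAGES * PAGE_SIZE, N_KVH, D_H] in bf16,
# so downstream attention kernels can reuse the same block_tables to index them.
```

Passing `block_tables=None` (the schema default) dispatches to the original contiguous-cache kernel, so existing callers are unaffected.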
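
The new `register_fake` entry for `fbgemm::lengths_range` also allows tracing with a data-dependent output size when `output_shape` is omitted, complementing the exception-path test above. A small assumed sketch, presuming the fbgemm native ops library is loaded and `fbgemm_gpu.sparse_ops` has been imported so the fake implementation is registered:

```python
# Sketch of the dynamic-shape path enabled by the lengths_range fake registration.
# Assumes the fbgemm_gpu native library is loaded and fbgemm_gpu.sparse_ops has
# been imported so the register_fake implementation is active.
import torch
import fbgemm_gpu.sparse_ops  # noqa: F401  (registers the fake impl)
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

# With dynamic output shapes allowed, ctx.new_dynamic_size() yields an unbacked
# symbolic size instead of raising DynamicOutputShapeException.
with FakeTensorMode(shape_env=ShapeEnv(allow_dynamic_output_shape_ops=True)):
    lengths = torch.tensor([3, 2, 4, 10], dtype=torch.int32)
    out = torch.ops.fbgemm.lengths_range(lengths, None)
    print(out.shape)  # e.g. torch.Size([u0]) -- a symbolic, data-dependent size
```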