TBE UVM cache line locking - backend (pytorch#1883)
Summary:
Pull Request resolved: pytorch#1883

This diff supports the cache prefetch pipeline, where cache insertion can execute in parallel with the embedding table forward/backward passes. Since cache prefetch may evict cache lines, we must make sure that cache lines used by an in-flight forward/backward pass are not evicted.

The implementation targets the training kernel and the LRU cache policy. We create an `lxu_cache_locking_counter` of size `(cache_sets, warp_size)` that indicates whether a cache slot is in use (`counter > 0`) or not (`counter = 0`).
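
A minimal host-side sketch of this bookkeeping, not the actual FBGEMM code (the sizes below are illustrative):

```
import torch

cache_sets = 1024  # illustrative number of cache sets
warp_size = 32     # slots per cache set

# One counter per cache slot; a slot with counter > 0 must not be evicted.
lxu_cache_locking_counter = torch.zeros(
    (cache_sets, warp_size), dtype=torch.int32
)

def is_locked(cache_set, slot):
    return bool(lxu_cache_locking_counter[cache_set, slot] > 0)
```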

Operations on `lxu_cache_locking_counter`:

In `lru_cache_find_uncached_cuda`, if an index is already in the cache, the `lxu_cache_locking_counter` of the corresponding cache_slot is incremented.
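
A host-side sketch of this hit path (the real logic is inside the CUDA kernel; `idx % num_sets` is an assumed stand-in for the actual index-to-set mapping):

```
def lock_on_hit(linear_indices, lxu_cache_state, lxu_cache_locking_counter):
    # lxu_cache_state[set, slot] holds the linear index cached in that slot.
    num_sets = lxu_cache_state.shape[0]
    for idx in linear_indices.tolist():
        cache_set = idx % num_sets  # assumed set mapping, not the real hash
        hits = (lxu_cache_state[cache_set] == idx).nonzero()
        if hits.numel() > 0:
            slot = int(hits[0])
            # The index is already cached: pin its slot for this batch.
            lxu_cache_locking_counter[cache_set, slot] += 1
```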

In `lru_cache_insert_cuda`, we first sort the cache slots within a cache set by timestamp, as in the original LRU implementation. When inserting, we check whether the `lxu_cache_locking_counter` of each candidate cache slot is positive. If the counter of a cache slot is positive, we skip it and move on to the next cache slot. When an index is inserted into a cache slot, the `lxu_cache_locking_counter` of that slot is incremented.
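
A sketch of the per-set insertion decision, assuming the slots are already sorted in LRU order; `sorted_slots`, `indices_to_insert`, and the host-side loop are illustrative rather than the kernel code:

```
def insert_into_set(
    cache_set, sorted_slots, indices_to_insert,
    lxu_cache_state, lxu_cache_locking_counter,
):
    it = iter(indices_to_insert)
    for slot in sorted_slots:  # least-recently-used slots first
        if lxu_cache_locking_counter[cache_set, slot] > 0:
            continue  # slot is locked by an in-flight batch: do not evict
        try:
            idx = next(it)
        except StopIteration:
            break  # nothing left to insert
        lxu_cache_state[cache_set, slot] = idx  # evict old entry, insert idx
        lxu_cache_locking_counter[cache_set, slot] += 1  # lock the new entry
```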

After the backward pass is done, we call `lxu_cache_locking_counter_decrement` through a backward hook. For every cache_slot that appears in lxu_cache_locations, the counter of that cache_slot is decremented by 1; duplicate cache_slots are decremented only once.
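
A sketch of the decrement, assuming `lxu_cache_locations` stores a flat slot index (`cache_set * warp_size + slot`) and `-1` for misses:

```
import torch

def locking_counter_decrement(lxu_cache_locking_counter, lxu_cache_locations):
    warp_size = lxu_cache_locking_counter.shape[1]
    hits = lxu_cache_locations[lxu_cache_locations >= 0]  # drop misses (-1)
    unique_locations = torch.unique(hits)  # duplicates decrement only once
    cache_sets = unique_locations // warp_size
    slots = unique_locations % warp_size
    lxu_cache_locking_counter[cache_sets, slots] -= 1
```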

With the pipeline, if an index fails to be inserted into the cache for batch_i but is inserted for batch_{i+1}, the cache can become stale in the sense that the cached weight for this index does not include the backward update of batch_i.

An example of the issue is as follows:

```
        idx is in batch_i, batch_{i+1}
        prefetch(batch_i)
          - failed to insert idx into cache, cache_locations_batch_i of idx is -1 (cache miss)
        forward(batch_i)
        prefetch(batch_{i+1})
          - insert idx into cache, cache is loaded from host memory
        backward(batch_i)
          - cache_locations_batch_i of idx is -1, the host memory is updated
        forward(batch_{i+1})
          - OUTPUT IS WRONG. the weight for idx is fetched from cache, but the cache is outdated.
```

The fix for this cache staleness is to update cache_locations_batch_i before the backward pass of batch_i, so that the cache gets updated correctly by the TBE backward pass.
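
A sketch of that update, mirroring the semantics declared for `lxu_cache_locations_update_cuda` below (only locations that were a miss and have since become valid are overwritten):

```
def lxu_cache_locations_update(lxu_cache_locations, lxu_cache_locations_new):
    # Only refresh entries that were a miss (-1) in batch_i but are now
    # cached (>= 0) after prefetch(batch_{i+1}).
    update = (lxu_cache_locations == -1) & (lxu_cache_locations_new >= 0)
    lxu_cache_locations[update] = lxu_cache_locations_new[update]
```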

Reviewed By: sryap

Differential Revision: D46172802

fbshipit-source-id: fcd4948f9933f6de28a34063081380ccef321574
yuguo68 authored and facebook-github-bot committed Jul 20, 2023
1 parent 8ce4fde commit 4096e8d
Showing 5 changed files with 283 additions and 18 deletions.
22 changes: 20 additions & 2 deletions fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
@@ -32,7 +32,9 @@ std::pair<at::Tensor, at::Tensor> lru_cache_find_uncached_cuda(
     int64_t time_stamp,
     at::Tensor lru_state,
     bool gather_cache_stats,
-    at::Tensor uvm_cache_stats);
+    at::Tensor uvm_cache_stats,
+    bool lock_cache_line,
+    at::Tensor lxu_cache_locking_counter);
 
 ///@ingroup table-batched-embed-cuda
 /// Map index to cache_set. h_in: linear_indices; C: #cache_sets.
@@ -71,7 +73,9 @@ void lru_cache_populate_cuda(
     at::Tensor lru_state,
     bool stochastic_rounding,
     bool gather_cache_stats,
-    c10::optional<at::Tensor> uvm_cache_stats);
+    c10::optional<at::Tensor> uvm_cache_stats,
+    bool lock_cache_line,
+    c10::optional<at::Tensor> lxu_cache_locking_counter);
 
 ///@ingroup table-batched-embed-cuda
 /// LRU cache: fetch the rows corresponding to `linear_cache_indices` from
@@ -206,3 +210,17 @@ void reset_weight_momentum_cuda(
     at::Tensor cache_hash_size_cumsum,
     at::Tensor lxu_cache_state,
     int64_t total_cache_hash_size);
+
+///@ingroup table-batched-embed-cuda
+/// Decrement the LRU/LFU cache counter based on lxu_cache_locations.
+void lxu_cache_locking_counter_decrement_cuda(
+    at::Tensor lxu_cache_locking_counter,
+    at::Tensor lxu_cache_locations);
+
+///@ingroup table-batched-embed-cuda
+/// Inplace update lxu_cache_locations to the new one
+/// should only update if lxu_cache_locations[i] == -1
+/// and lxu_cache_locations_new[i] >= 0
+void lxu_cache_locations_update_cuda(
+    at::Tensor lxu_cache_locations,
+    at::Tensor lxu_cache_locations_new);