Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
TBE UVM cache line locking - backend (pytorch#1883)
Summary: Pull Request resolved: pytorch#1883 This diff is to support cache prefetch pipeline, where cache insert can execute in parallel with embedding table forward/backward. As cache prefetch may evict cache lines, we must make sure that cache lines that are used by forward/backward won't be evicted. The implementation here targets at training kernel and LRU cache policy. We create a `lxu_cache_locking_counter` of size `(cache_sets, warp_size)` to indicate whether a cache slot is in use (`counter > 0`) or not (`counter = 0`). Operations on `lxu_cache_locking_counter`: In `lru_cache_find_uncached_cuda`, if an index is already in cache, the `lxu_cache_locking_counter` of the corresponding cache_slot is incremented. In `lru_cache_insert_cuda`, we first sort the cache slots based on timestamp within a cache set as the original LRU implementation. When inserting, we check whether the `lxu_cache_locking_counter` of each cache slot to insert is positive of not. If the counter of a cache slot is positive, we skip inserting and move on to next cache slot. If a cache slot is inserted, the `lxu_cache_locking_counter` of that slot is incremented. After the backward pass is done, we call `lxu_cache_locking_counter_decrement` through a backward hook. For any cache_slot in lxu_cache_locations, the counter of that cache_slot is decremented by 1. Duplicate cache_slots only get decrement once. With pipeline, in the case that the same index is not inserted into cache in batch_i, but it is inserted in batch_{i+1}, the cache can be invalid in the sense that the cached weight for this index does not have the backward update of batch_i. Example of the issue is as follows: ``` idx is in batch_i, batch_{i+1} prefetch(batch_i) - failed to insert idx into cache, cache_locations_batch_i of idx is -1 (cache miss) forward(batch_i) prefetch(batch_{i+1}) - insert idx into cache, cache is loaded from host memory backward(batch_i) - cache_locations_batch_i of idx is -1, the host memory is updated forward(batch_{i+1}) - OUTPUT IS WRONG. the weight for idx is fetched from cache, but the cache is outdated. ``` The fix to this cache invalidation is to update the cache_locations_batch_i before backward of batch_i,so that the cache gets updated correctly by the backward pass of TBE. Reviewed By: sryap Differential Revision: D46172802 fbshipit-source-id: fcd4948f9933f6de28a34063081380ccef321574
- Loading branch information