Add SSDScratchPadIndicesQueue lookup in frontend (pytorch#2948)
Summary:
Pull Request resolved: pytorch#2948

X-link: facebookresearch/FBGEMM#50

This diff updates the SSD-TBE frontend to use
`torch.classes.fbgemm.SSDScratchPadIndicesQueue` for scratch pad
lookup (added in D60363607). `SSDScratchPadIndicesQueue` stores
scratch pad indices (conflict missed indices) from previous
iterations. It is used during the L1 cache prefetching step: instead
of fetching the missed rows directly from SSD, TBE first looks up the
scratch pad index queue to check whether the missing data is already
in the scratch pad from the previous iteration.
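
To make the queue's contract concrete, here is a pure-Python analogue of the behavior described above and in the diff's code comments. This is an illustrative sketch only: the real `SSDScratchPadIndicesQueue` is a C++ class exposed via `torch.classes.fbgemm`, its `insert_cuda` and `lookup_mask_and_pop_front_cuda` methods operate on pinned CPU tensors in stream order, and the `actions_count` bound on valid positions is simplified away here.

```python
from collections import deque
from typing import Deque, Dict, List


class ScratchPadIndicesQueueSketch:
    """Illustrative stand-in for torch.classes.fbgemm.SSDScratchPadIndicesQueue."""

    def __init__(self, invalid_index: int = -1) -> None:
        self.invalid_index = invalid_index
        # One map per iteration: conflict-missed linear index -> row position
        # in that iteration's scratch pad.
        self.queue: Deque[Dict[int, int]] = deque()

    def insert(self, evicted_indices: List[int]) -> None:
        # Called at prefetch time with the indices that will be evicted from
        # the scratch pad after the backward pass (invalid slots are skipped).
        self.queue.append(
            {
                idx: pos
                for pos, idx in enumerate(evicted_indices)
                if idx != self.invalid_index
            }
        )

    def lookup_mask_and_pop_front(self, missed_indices: List[int]) -> List[int]:
        # For each L1 cache miss, return its row in the previous iteration's
        # scratch pad, or invalid_index if it must be fetched from SSD.
        # Misses that are found are masked out of missed_indices so the SSD
        # fetch skips them.
        front = self.queue.popleft()
        locations = []
        for i, idx in enumerate(missed_indices):
            loc = front.get(idx, self.invalid_index)
            locations.append(loc)
            if loc != self.invalid_index:
                missed_indices[i] = self.invalid_index  # already in scratch pad
        return locations
```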

The high-level workflow of the prefetch step in SSD-TBE is shown in
the figure below:

 {F1795380801}

https://internalfb.com/excalidraw/EX264055
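
In terms of the toy queue above, one prefetch-to-prefetch cycle looks like this (the indices are made up; in the real code both arguments are pinned CPU tensors and the calls are stream-ordered):

```python
q = ScratchPadIndicesQueueSketch(invalid_index=-1)

# Iteration t: rows for linear indices 10, 42, and 7 are conflict misses and
# land in scratch pad rows 0, 1, and 2; their indices go into the queue.
q.insert([10, 42, 7])

# Iteration t+1: the L1 cache misses indices 42, 99, and 7.
misses = [42, 99, 7]
locations = q.lookup_mask_and_pop_front(misses)

print(locations)  # [1, -1, 2] -> 42 and 7 are copied from the old scratch pad
print(misses)     # [-1, 99, -1] -> only 99 is fetched from SSD
```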

Reviewed By: ehsanardestani

Differential Revision: D60413116

fbshipit-source-id: 95be2f372e7ee5270288ed7f6a643350969b67d6
sryap authored and facebook-github-bot committed Aug 13, 2024
1 parent 9f17b23 commit 4ae45b7
127 changes: 107 additions & 20 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -341,17 +341,22 @@ def __init__(
         self.ssd_event_evict = torch.cuda.Event()
         # SSD backward completion event
         self.ssd_event_backward = torch.cuda.Event()
-        # SSD scratch pad eviction completion event
-        self.ssd_event_evict_sp = torch.cuda.Event()
         # SSD get's input copy completion event
         self.ssd_event_get_inputs_cpy = torch.cuda.Event()
+        # SSD scratch pad index queue insert completion event
+        self.ssd_event_sp_idxq_insert = torch.cuda.Event()

         self.timesteps_prefetched: List[int] = []
-        self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor, bool]] = []
         # TODO: add type annotation
         # pyre-fixme[4]: Attribute must be annotated.
         self.ssd_prefetch_data = []

+        # Scratch pad value queue
+        self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor, bool]] = []
+        # pyre-ignore[4]
+        # Scratch pad index queue
+        self.scratch_pad_idx_queue = torch.classes.fbgemm.SSDScratchPadIndicesQueue(-1)
+
         if weight_decay_mode == WeightDecayMode.COUNTER or counter_based_regularization:
             raise AssertionError(
                 "weight_decay_mode = WeightDecayMode.COUNTER is not supported for SSD TBE."
@@ -427,10 +432,6 @@ def __init__(
             torch.zeros(0, device=self.current_device, dtype=torch.float)
         )

-        # Register backward hook for evicting rows from a scratch pad to SSD
-        # post backward
-        self.placeholder_autograd_tensor.register_hook(self._evict_from_scratch_pad)
-
         assert optimizer in (
             OptimType.EXACT_ROWWISE_ADAGRAD,
         ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
@@ -578,8 +579,8 @@ def evict(
         indices_cpu: Tensor,
         actions_count_cpu: Tensor,
         stream: torch.cuda.Stream,
-        pre_event: torch.cuda.Event,
-        post_event: torch.cuda.Event,
+        pre_event: Optional[torch.cuda.Event],
+        post_event: Optional[torch.cuda.Event],
         is_rows_uvm: bool,
         name: Optional[str] = "",
     ) -> None:
@@ -607,7 +608,8 @@ def evict(
         """
         with record_function(f"## ssd_evict_{name} ##"):
             with torch.cuda.stream(stream):
-                stream.wait_event(pre_event)
+                if pre_event is not None:
+                    stream.wait_event(pre_event)

                 rows_cpu = rows if is_rows_uvm else self.to_pinned_cpu(rows)

@@ -622,13 +624,17 @@
                     self.timestep,
                 )

-                # TODO: is this needed?
-                # Need a way to synchronize
-                # actions_count_cpu.record_stream(self.ssd_eviction_stream)
-                stream.record_event(post_event)
+                if post_event is not None:
+                    stream.record_event(post_event)

+    def _evict_from_scratch_pad(self, return_on_empty: bool) -> None:
+        scratch_pad_len = len(self.ssd_scratch_pads)
+
+        if not return_on_empty:
+            assert scratch_pad_len > 0, "There must be at least one scratch pad"
+        elif scratch_pad_len == 0:
+            return
+
-    def _evict_from_scratch_pad(self, grad: Tensor) -> None:
-        assert len(self.ssd_scratch_pads) > 0, "There must be at least one scratch pad"
         (inserted_rows, post_bwd_evicted_indices_cpu, actions_count_cpu, do_evict) = (
             self.ssd_scratch_pads.pop(0)
         )
@@ -640,7 +646,7 @@ def _evict_from_scratch_pad(self, grad: Tensor) -> None:
             actions_count_cpu=actions_count_cpu,
             stream=self.ssd_eviction_stream,
             pre_event=self.ssd_event_backward,
-            post_event=self.ssd_event_evict_sp,
+            post_event=None,
             is_rows_uvm=True,
             name="scratch_pad",
         )
@@ -683,6 +689,20 @@ def _compute_cache_ptrs(
             )
         )

+        # Insert conflict miss indices in the index queue for future lookup
+        # post_bwd_evicted_indices_cpu is transferred on the ssd_eviction_stream stream
+        # actions_count_cpu is transferred on the ssd_memcpy_stream stream
+        with torch.cuda.stream(self.ssd_eviction_stream):
+            # Ensure that actions_count_cpu transfer is done
+            self.ssd_eviction_stream.wait_event(self.ssd_event_get_inputs_cpy)
+            self.record_function_via_dummy_profile(
+                "## ssd_scratch_pad_idx_queue_insert ##",
+                self.scratch_pad_idx_queue.insert_cuda,
+                post_bwd_evicted_indices_cpu,
+                actions_count_cpu,
+            )
+            self.ssd_eviction_stream.record_event(self.ssd_event_sp_idxq_insert)
+
         with record_function("## ssd_scratch_pads ##"):
             # Store scratch pad info for post backward eviction
             self.ssd_scratch_pads.append(
@@ -778,12 +798,76 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
             )

         current_stream = torch.cuda.current_stream()

+        inserted_indices_cpu = self.to_pinned_cpu(inserted_indices)
+        if len(self.ssd_scratch_pads) > 0:
+            with record_function("## ssd_lookup_scratch_pad ##"):
+                current_stream.wait_event(self.ssd_event_sp_idxq_insert)
+                current_stream.wait_event(self.ssd_event_get_inputs_cpy)
+
+                (
+                    inserted_rows_prev,
+                    post_bwd_evicted_indices_cpu_prev,
+                    actions_count_cpu_prev,
+                    do_evict_prev,
+                ) = self.ssd_scratch_pads.pop(0)
+
+                # Inserted indices that are found in the scratch pad
+                # from the previous iteration
+                sp_locations_cpu = torch.empty(
+                    inserted_indices_cpu.shape,
+                    dtype=inserted_indices_cpu.dtype,
+                    pin_memory=True,
+                )
+
+                # Before entering this function: inserted_indices_cpu
+                # contains all linear indices that are missed from the
+                # L1 cache
+                #
+                # After this function: inserted indices that are found
+                # in the scratch pad from the previous iteration are
+                # stored in sp_locations_cpu, while the rest are
+                # stored in inserted_indices_cpu
+                #
+                # An invalid index is -1 or its position >
+                # actions_count_cpu
+                self.record_function_via_dummy_profile(
+                    "## ssd_lookup_mask_and_pop_front ##",
+                    self.scratch_pad_idx_queue.lookup_mask_and_pop_front_cuda,
+                    sp_locations_cpu,
+                    post_bwd_evicted_indices_cpu_prev,
+                    inserted_indices_cpu,
+                    actions_count_cpu,
+                )
+
+                # Transfer sp_locations_cpu to GPU
+                sp_locations_gpu = sp_locations_cpu.cuda(non_blocking=True)
+
+                # Copy data from the previous iteration's scratch pad to
+                # the current iteration's scratch pad
+                torch.ops.fbgemm.masked_index_select(
+                    inserted_rows,
+                    sp_locations_gpu,
+                    inserted_rows_prev,
+                    actions_count_gpu,
+                )
+
+                # Evict from scratch pad
+                if do_evict_prev:
+                    torch.cuda.current_stream().record_event(
+                        self.ssd_event_backward
+                    )
+                    self.evict(
+                        rows=inserted_rows_prev,
+                        indices_cpu=post_bwd_evicted_indices_cpu_prev,
+                        actions_count_cpu=actions_count_cpu_prev,
+                        stream=self.ssd_eviction_stream,
+                        pre_event=self.ssd_event_backward,
+                        post_event=None,
+                        is_rows_uvm=True,
+                        name="scratch_pad",
+                    )
+
         # Ensure the previous iteration's l3_db.set(..) has completed.
         current_stream.wait_event(self.ssd_event_evict)
-        current_stream.wait_event(self.ssd_event_evict_sp)
         current_stream.wait_event(self.ssd_event_get_inputs_cpy)

         if linear_cache_indices.numel() > 0:
@@ -1030,6 +1114,9 @@ def flush(self) -> None:
                 active_slots_mask_cpu.view(-1)
             )

+        # Evict data from the scratch pad if there is a scratch pad in the queue
+        self._evict_from_scratch_pad(return_on_empty=True)
+
         torch.cuda.current_stream().wait_stream(self.ssd_eviction_stream)

         self.ssd_db.set_cuda(
