simplifying the engine
edyoshikun committed Dec 10, 2024
1 parent ef19a99 commit 2b4a359
Showing 1 changed file with 30 additions and 48 deletions.
78 changes: 30 additions & 48 deletions viscy/representation/engine.py
@@ -108,11 +108,22 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]):
             key, grid, self.current_epoch, dataformats="HWC"
         )
 
+    def _log_step_samples(self, batch_idx, samples, stage: Literal["train", "val"]):
+        """Common method for logging step samples"""
+        if batch_idx < self.log_batches_per_epoch:
+            output_list = (
+                self.training_step_outputs
+                if stage == "train"
+                else self.validation_step_outputs
+            )
+            output_list.extend(detach_sample(samples, self.log_samples_per_batch))
+
     def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor:
         anchor_img = batch["anchor"]
         pos_img = batch["positive"]
         anchor_projection = self(anchor_img)
         positive_projection = self(pos_img)
+        negative_projection = None
         if isinstance(self.loss_function, NTXentLoss):
             indices = torch.arange(
                 0, anchor_projection.size(0), device=anchor_projection.device
@@ -121,36 +132,21 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor:
             # Note: we assume the two augmented views are the anchor and positive samples
             embeddings = torch.cat((anchor_projection, positive_projection))
             loss = self.loss_function(embeddings, labels)
-            self._log_metrics(
-                loss=loss,
-                anchor=anchor_projection,
-                positive=positive_projection,
-                negative=None,
-                stage="train",
-            )
-            if batch_idx < self.log_batches_per_epoch:
-                self.training_step_outputs.extend(
-                    detach_sample((anchor_img, pos_img), self.log_samples_per_batch)
-                )
+            self._log_step_samples(batch_idx, (anchor_img, pos_img), "train")
         else:
             neg_img = batch["negative"]
             negative_projection = self(neg_img)
             loss = self.loss_function(
                 anchor_projection, positive_projection, negative_projection
             )
-            self._log_metrics(
-                loss=loss,
-                anchor=anchor_projection,
-                positive=positive_projection,
-                negative=negative_projection,
-                stage="train",
-            )
-            if batch_idx < self.log_batches_per_epoch:
-                self.training_step_outputs.extend(
-                    detach_sample(
-                        (anchor_img, pos_img, neg_img), self.log_samples_per_batch
-                    )
-                )
+            self._log_step_samples(batch_idx, (anchor_img, pos_img, neg_img), "train")
+        self._log_metrics(
+            loss=loss,
+            anchor=anchor_projection,
+            positive=positive_projection,
+            negative=negative_projection,
+            stage="train",
+        )
         return loss
 
     def on_train_epoch_end(self) -> None:
@@ -164,6 +160,7 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor:
         pos_img = batch["positive"]
         anchor_projection = self(anchor)
         positive_projection = self(pos_img)
+        negative_projection = None
         if isinstance(self.loss_function, NTXentLoss):
             indices = torch.arange(
                 0, anchor_projection.size(0), device=anchor_projection.device
@@ -172,36 +169,21 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor:
             # Note: we assume the two augmented views are the anchor and positive samples
             embeddings = torch.cat((anchor_projection, positive_projection))
             loss = self.loss_function(embeddings, labels)
-            self._log_metrics(
-                loss=loss,
-                anchor=anchor_projection,
-                positive=positive_projection,
-                negative=None,
-                stage="val",
-            )
-            if batch_idx < self.log_batches_per_epoch:
-                self.validation_step_outputs.extend(
-                    detach_sample((anchor, pos_img), self.log_samples_per_batch)
-                )
+            self._log_step_samples(batch_idx, (anchor, pos_img), "val")
         else:
             neg_img = batch["negative"]
             negative_projection = self(neg_img)
             loss = self.loss_function(
                 anchor_projection, positive_projection, negative_projection
             )
-            self._log_metrics(
-                loss=loss,
-                anchor=anchor_projection,
-                positive=positive_projection,
-                negative=negative_projection,
-                stage="val",
-            )
-            if batch_idx < self.log_batches_per_epoch:
-                self.validation_step_outputs.extend(
-                    detach_sample(
-                        (anchor, pos_img, neg_img), self.log_samples_per_batch
-                    )
-                )
+            self._log_step_samples(batch_idx, (anchor, pos_img, neg_img), "val")
+        self._log_metrics(
+            loss=loss,
+            anchor=anchor_projection,
+            positive=positive_projection,
+            negative=negative_projection,
+            stage="val",
+        )
        return loss
 
     def on_validation_epoch_end(self) -> None:
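For context on what the consolidation buys: the new _log_step_samples helper replaces the four duplicated logging branches (train/val, with and without a negative sample) with a single call site, so training_step and validation_step only decide which images to pass. Below is a minimal, self-contained sketch of that pattern. The body of _log_step_samples mirrors the method added in this commit; SampleLogger and the detach_sample stub are hypothetical stand-ins for the Lightning module and the repository helper, included only so the snippet runs on its own.

from typing import Literal, Sequence

import torch
from torch import Tensor


def detach_sample(imgs: Sequence[Tensor], num_samples: int) -> list[tuple[Tensor, ...]]:
    # Stand-in for the repository helper: keep a few detached CPU copies
    # of each image in the batch for later rendering.
    n = min(num_samples, imgs[0].shape[0])
    return [tuple(img[i].detach().cpu() for img in imgs) for i in range(n)]


class SampleLogger:
    # Toy container with the same attributes the LightningModule relies on.
    def __init__(self, log_batches_per_epoch: int = 2, log_samples_per_batch: int = 1):
        self.log_batches_per_epoch = log_batches_per_epoch
        self.log_samples_per_batch = log_samples_per_batch
        self.training_step_outputs: list = []
        self.validation_step_outputs: list = []

    def _log_step_samples(self, batch_idx, samples, stage: Literal["train", "val"]):
        # Same logic as the helper added in this commit: choose the output
        # list by stage and only keep samples from the first few batches.
        if batch_idx < self.log_batches_per_epoch:
            output_list = (
                self.training_step_outputs
                if stage == "train"
                else self.validation_step_outputs
            )
            output_list.extend(detach_sample(samples, self.log_samples_per_batch))


if __name__ == "__main__":
    logger = SampleLogger()
    anchor = torch.rand(4, 1, 32, 32)
    positive = torch.rand(4, 1, 32, 32)
    logger._log_step_samples(0, (anchor, positive), "train")  # logged: batch 0 < 2
    logger._log_step_samples(5, (anchor, positive), "train")  # skipped: batch 5 >= 2
    print(len(logger.training_step_outputs))  # -> 1

Keeping the stage check and the per-epoch batch cap in one place also makes the epoch-end hooks simpler to reason about, since both output lists are filled through the same code path.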
