Can I train MUNIT on grayscale images? #170

Open
edwardcho opened this issue Mar 4, 2022 · 2 comments

Comments

@edwardcho

Hello Sir,

I am interested in image-to-image translation.
So I tried your code with my own dataset.

My dataset is as follows:

  1. grayscale (1 channel)
  2. 256 x 256

When I start training, I get the following error.

Namespace(b1=0.5, b2=0.999, batch_size=4, channels=1, checkpoint_interval=-1, dataset_name='noise2clip', decay_epoch=2, dim=64, epoch=0, img_height=256, img_width=256, lr=0.0001, n_cpu=8, n_downsample=2, n_epochs=4, n_residual=3, sample_interval=400, style_dim=8)
/home/itsme/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py:288: UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.
  "Argument interpolation should be of type InterpolationMode instead of int. "
../../data/noise2clip/trainA
../../data/noise2clip/valA
Traceback (most recent call last):
  File "munit.py", line 171, in <module>
    for i, batch in enumerate(dataloader):
  File "/home/itsme/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/home/itsme/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in _next_data
    return self._process_data(data)
  File "/home/itsme/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1229, in _process_data
    data.reraise()
  File "/home/itsme/anaconda3/lib/python3.7/site-packages/torch/_utils.py", line 434, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/itsme/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/itsme/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/itsme/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/data/TESTBOARD/additional_networks/generation/PyTorch-GAN_eriklindernoren/implementations/munit/datasets.py", line 40, in __getitem__
    img_A = self.transform(img_A)
  File "/home/itsme/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 61, in __call__
    img = t(img)
  File "/home/itsme/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/itsme/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 226, in forward
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "/home/itsme/anaconda3/lib/python3.7/site-packages/torchvision/transforms/functional.py", line 351, in normalize
    tensor.sub_(mean).div_(std)
RuntimeError: output with shape [1, 256, 256] doesn't match the broadcast shape [3, 256, 256]

How can I train MUNIT in my case (on grayscale data)?

Thanks.
Edward Cho.

@Ignoramus-Sage

This is my hypothesis: the error stems from the fact that your images have only one channel. What you can do is add two channels by stacking the same image along the channel axis, duplicating the values you already have (R = G = B = gray). It should work, although I am not sure about performance, since I am not familiar with the paper or the code.
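
For illustration, a minimal sketch of that idea with torchvision: Grayscale(num_output_channels=3) does the R = G = B duplication, so a 3-element Normalize (the 0.5 values here are just assumed defaults, adjust to match your pipeline) broadcasts cleanly against the resulting [3, 256, 256] tensor.

import torchvision.transforms as transforms

# Duplicate the single gray channel into three identical channels (R = G = B = gray)
# so a 3-element mean/std in Normalize matches a [3, H, W] tensor.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])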

@algocompretto

You are applying RGB transforms to grayscale data; instead, add a grayscale transformation to your data loading pipeline.

import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

# Single-channel pipeline: keep images grayscale and use a 1-element
# mean/std so Normalize matches the [1, H, W] tensor shape.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])])

dataloader = torch.utils.data.DataLoader(
    ImageFolder("dataset_folder_path/", transform=transform),
    batch_size=16, shuffle=True)
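
Since your run already passes channels=1 (see the Namespace output above), a single-element mean/std like this should match the [1, 256, 256] tensor and avoid the broadcast error in your traceback; the alternative is to keep three channels everywhere, as suggested in the previous comment.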
