Getting a 'nan' Loss after about 12 training epochs #85

Open
marc-martini opened this issue Jul 15, 2024 · 4 comments

Comments

@marc-martini

Hi,

Well done on this amazing work, and thank you so much for putting this on GitHub and sharing.

My apologies if this is a silly question.
I have been struggling to figure out where I am going wrong.

I am trying to recreate the results on the Electricity dataset.
Training runs perfectly until about epoch 12 or 13, at which point I get a NaN loss.

Please help me understand where I am going wrong.

thank you

@tianzhou2011
Collaborator

I am not sure what happened here... Maybe adding a breakpoint with pdb.set_trace() can help you identify the problem. Just add something like if torch.isnan(loss): pdb.set_trace() so you can check the intermediate variables or input values.
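For example, a minimal self-contained sketch of that check (the model, loader, criterion, and optimizer below are toy stand-ins, not the repo's code):

import pdb

import torch
import torch.nn as nn

# Toy stand-ins so the loop runs on its own; in the real experiment these are
# the repo's model, data loader, optimizer, and criterion.
model = nn.Linear(8, 1)
criterion = nn.MSELoss()
model_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
train_loader = [(torch.randn(32, 8), torch.randn(32, 1)) for _ in range(10)]

for i, (batch_x, batch_y) in enumerate(train_loader):
    model_optim.zero_grad()
    outputs = model(batch_x)
    loss = criterion(outputs, batch_y)

    # Drop into the debugger the moment the loss stops being finite, so the
    # offending batch and intermediate tensors are still in scope.
    if not torch.isfinite(loss):
        pdb.set_trace()

    loss.backward()
    model_optim.step()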

@efg001

efg001 commented Jul 16, 2024

Was going to open an issue on this...
@marc-martini In addition to what Tianzhou shared, also make sure you are training on an A100-equivalent GPU with good FP64 support (let me know what device you ran it on :) )
If you are running on A100, continue reading....

I ran into similar issues and have isolated the problem to the line that applies the activation function to the result of the frequency-domain attention, which contains complex numbers (this line: xqkv_ft = torch.einsum("bhxy,bhey->bhex", xqk_ft, xk_ft_). All entries in the input matrices are 'good' [not NaN or inf], but I got some inf entries after it).
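A hedged spot-check of that step (variable names copied from the quoted line; the tensor shapes are invented for illustration and are not the repo's):

import torch

# Stand-in complex tensors with the same einsum pattern as the quoted line;
# the shapes (batch, heads, modes, embed) are made up for this sketch.
xqk_ft = torch.randn(2, 8, 16, 16, dtype=torch.cfloat)
xk_ft_ = torch.randn(2, 8, 32, 16, dtype=torch.cfloat)

# Inputs are finite...
assert torch.isfinite(xqk_ft).all() and torch.isfinite(xk_ft_).all()

xqkv_ft = torch.einsum("bhxy,bhey->bhex", xqk_ft, xk_ft_)

# ...so check the result: torch.isfinite treats a complex entry as finite only
# when both its real and imaginary parts are, which flags the inf entries.
if not torch.isfinite(xqkv_ft).all():
    print("non-finite entries after the frequency-domain einsum")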
I 'fixed' it by using the softmax activation instead of tanh; can you give it a try?
(First set softmax as the activation function in the run config, then add activation=configs.cross_activation to the blocks below.)

encoder_self_att = FourierBlock(in_channels=configs.d_model,
                                out_channels=configs.d_model,
                                seq_len=self.seq_len,
                                modes=configs.modes,
                                mode_select_method=configs.mode_select)
decoder_self_att = FourierBlock(in_channels=configs.d_model,
                                out_channels=configs.d_model,
                                seq_len=self.seq_len // 2 + self.pred_len,
                                modes=configs.modes,
                                mode_select_method=configs.mode_select)
decoder_cross_att = FourierCrossAttention(in_channels=configs.d_model,
                                          out_channels=configs.d_model,
                                          seq_len_q=self.seq_len // 2 + self.pred_len,
                                          seq_len_kv=self.seq_len,
                                          modes=configs.modes,
                                          mode_select_method=configs.mode_select)
(detail below)

First of all, I don't know whether PyTorch's support for complex numbers has stood the test of time; see
1. pytorch/pytorch#47052
2. https://pytorch.org/docs/stable/complex_numbers.html ("Complex tensors is a beta feature and subject to change.")

Second, I believe the default activation function, tanh, is not a good fit for complex numbers.

I swapped out tanh for softmax and no longer see NaN weights/losses/gradients.
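A minimal sketch of the kind of swap described above (names and shapes are invented for illustration; this is not the repo's exact code): take the softmax over the magnitudes of the complex attention scores and cast back to a complex tensor, instead of applying tanh directly.

import torch

# Invented complex "attention scores" standing in for the frequency-domain output.
scores = torch.randn(2, 8, 16, 16, dtype=torch.cfloat)

# Default path: tanh applied directly to the complex tensor.
attn_tanh = torch.tanh(scores)

# Workaround path: softmax over the (real-valued) magnitudes, then cast back to
# a complex tensor so the downstream complex einsum still works.
attn_soft = torch.softmax(scores.abs(), dim=-1)
attn_soft = torch.complex(attn_soft, torch.zeros_like(attn_soft))

# Softmax of finite magnitudes stays finite and bounded in [0, 1].
print(torch.isfinite(attn_soft).all())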

I haven't seen this when running the code on some of the ETT datasets, and I haven't tried running it on the Electricity data either; I got NaN when trying to run the model on my own dataset.

This is just something I found. I haven't found enough evidence to support my theory for opening an issue, so feel free to ignore it, TianZhou; I am only sharing it now because Marc just ran into the same issue. I am working on something else at the moment and will loop back to this.

@marc-martini
Author

Thank you for the guidance. I tried changing to softmax, but it made no difference.
What I have narrowed it down to is that the weights of the Conv1d and Linear layers in the TokenEmbedding and TemporalEmbedding layers become NaN after some training time.
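A generic sketch of this kind of parameter check (placeholder names, added here for clarity):

import torch

# `model` is the model being trained and `step` the current iteration.
def report_nan_params(model, step):
    for name, param in model.named_parameters():
        if torch.isnan(param).any():
            print(f"step {step}: NaN weights in {name}")
        if param.grad is not None and torch.isnan(param.grad).any():
            print(f"step {step}: NaN gradients in {name}")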
Any ideas?

thank you

@efg001

efg001 commented Jul 20, 2024

Weights becoming NaN is the result of an invalid gradient update; invalid gradients are computed from an invalid layer output during the SGD update.

Whatever code you added to capture the NaN, I'm guessing it either 1. missed the initial NaN output from a layer, or 2. did not capture an invalid layer output, for example inf. I think we want to know which layer is the root cause of the NaN weights.

import pdb

import torch
import torch.nn as nn

def nan_hook(module, input, output, model_name):
    # Forward hook: flag NaNs in a module's output and drop into the debugger.
    # Check if the output is a tuple and handle accordingly.
    if isinstance(output, tuple):
        outputs = output
    else:
        outputs = (output,)

    for out in outputs:
        # The attention list output is probably only there for debugging; skip it for now (TODO).
        if isinstance(out, list) or out is None:
            continue
        if torch.isnan(out).any():
            print(f"NaN detected in {module}, model {model_name}")
            for name, param in module.named_parameters():
                if torch.isnan(param).any():
                    print(f"NaN detected in parameter: {name}")
            pdb.set_trace()
            # raise ValueError("NaN detected during forward pass")

model = model_dict[self.args.model].Model(self.args).float()
if self.args.detect_nan:
    def register_hooks(module):
        # Register the hook on the current module regardless of whether it has children.
        module.register_forward_hook(lambda m, i, o: nan_hook(m, i, o, self.args.model))

        # Recurse through child modules to register hooks on them as well.
        for child_name, child_module in module.named_children():
            register_hooks(child_module)

    register_hooks(model)
if self.args.use_multi_gpu and self.args.use_gpu:
    model = nn.DataParallel(model, device_ids=self.args.device_ids)
return model
...
This is what I used.

Also, try printing the loss every iteration and check whether the loss and gradients are within a reasonable range.
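A sketch of that per-iteration logging (the helper name log_iteration is mine; `model`, `loss`, and the surrounding training loop are placeholders):

import pdb

import torch

def log_iteration(i, loss, model):
    # Call right after loss.backward() and before optimizer.step(); prints the
    # loss and total gradient norm so a drift toward inf/NaN is visible early.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float("inf"))
    print(f"iter {i}: loss={loss.item():.6f} grad_norm={grad_norm.item():.3e}")
    if not (torch.isfinite(loss) and torch.isfinite(grad_norm)):
        pdb.set_trace()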
