Inconsistent results when recreating model instance #77

Open
b-nils opened this issue Jan 20, 2024 · 2 comments


b-nils commented Jan 20, 2024

Hello there,

First of all, thanks a lot for publishing your code and actively answering questions on GitHub!!

I ran into problems when instantiating a new instance of the FEDformer model class and loading the weights of a previous run: I could not reproduce the same scores. Moreover, the scores differed substantially each time I reloaded the weights.

Hence, I checked whether the model outputs stay the same when previous model weights are reloaded while the inputs are kept fixed. To do so, I created three model outputs, which should all be the same:

  1. creating a model instance and running inference
  2. creating a NEW model instance, loading the model weights from 1) and running inference
  3. loading the model instance from 1) as a whole and running inference

Outputs 1, 2 and 3 should be the same. This is true if the model version is Wavelets. However, if the model version is Fourier, this leads to 1 != 2. As a result, I am not able to use the weights of a Fourier model I trained. I could save and load the model instance as a whole, but I am still wondering why this is the case. Or did I miss something?

This is the code to reproduce my findings:

import torch
# assuming the repo layout, where the FEDformer model class lives in models/FEDformer.py
from models.FEDformer import Model


class Configs(object):
    ab = 0
    modes = 32
    mode_select = 'random'
    # version = 'Wavelets'
    version = 'Fourier'
    moving_avg = [12, 24]
    L = 1
    base = 'legendre'
    cross_activation = 'tanh'
    seq_len = 96
    label_len = 48
    pred_len = 96
    output_attention = False
    enc_in = 7
    dec_in = 7
    d_model = 16
    embed = 'timeF'
    dropout = 0.05
    freq = 'h'
    factor = 1
    n_heads = 8
    d_ff = 16
    e_layers = 2
    d_layers = 1
    c_out = 7
    activation = 'gelu'
    wavelet = 0
    
# consistent input of just ones
configs = Configs()
enc = torch.ones([3, configs.seq_len, 7])
enc_mark = torch.ones([3, configs.seq_len, 4])
dec = torch.ones([3, configs.seq_len//2+configs.pred_len, 7])
dec_mark = torch.ones([3, configs.seq_len//2+configs.pred_len, 4])

# 1) creating a model instance and running inference
model = Model(configs)
model.eval()
out_1 = model.forward(enc, enc_mark, dec, dec_mark)
# saving only the model weights like it is done in the training
torch.save(model.state_dict(), "./model_weights.pt")
# saving the class instance as a whole
torch.save(model, "./model_class_instance.pt")

# 2) creating a new model instance and loading previous model weights
model = Model(configs)
model.load_state_dict(torch.load("./model_weights.pt"))
model.eval()
out_2 = model.forward(enc, enc_mark, dec, dec_mark)  # <-- this leads to inconsistent results when 'version' is "Fourier"

# 3) loading the whole model instance
model = torch.load("./model_class_instance.pt")
model.eval()
out_3 = model.forward(enc, enc_mark, dec, dec_mark)

# this should always output: True, True
print(torch.equal(out_1, out_2), torch.equal(out_1, out_3))
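
As a sanity check, the saved and reloaded parameters can be compared directly; they match bit for bit, which points to non-parameter state created in __init__ as the source of the mismatch. A rough sketch on top of the script above:

# compare the saved weights from 1) with a freshly created model after loading them;
# if these match while out_1 != out_2, the difference must come from state that is
# not part of the state dict
saved = torch.load("./model_weights.pt")
model_check = Model(configs)
model_check.load_state_dict(saved)
params_match = all(torch.equal(saved[k], v) for k, v in model_check.state_dict().items())
print("parameters identical:", params_match)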

I appreciate any help on this matter!!


b-nils commented Jan 20, 2024

Update: This is due to mode_select being set to random.

However, how do you intend to reload a model if mode_select is set to random? Should I save the whole class instance with torch.save(model, "./model.pt") then?

When saving only the model's state dict, I lose the information about which modes were selected during model creation...
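
One possible workaround (a rough sketch, not something the repo provides): seed the global RNGs right before constructing the model, so that every fresh instance selects the same random modes and the saved state dict then lines up with it. This assumes the random mode indices are drawn from the global numpy/torch RNGs inside the constructor.

import random
import numpy as np
import torch

from models.FEDformer import Model  # assuming the repo layout, as in the snippet above

def build_seeded_model(configs, seed=0):
    # hypothetical helper: fix the global RNGs so that the random frequency-mode
    # selection in __init__ comes out the same for every new instance
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    return Model(configs)

# training run: build with a fixed seed, train, save only the weights
model = build_seeded_model(configs, seed=0)   # configs as defined above
torch.save(model.state_dict(), "./model_weights.pt")

# later: rebuild with the same seed, then load the weights
model_reloaded = build_seeded_model(configs, seed=0)
model_reloaded.load_state_dict(torch.load("./model_weights.pt"))
model_reloaded.eval()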


efg001 commented Jul 6, 2024

Can you elaborate on your problem with reloading a random-mode model?
The repo uses random mode selection by default, and reloading such a model does not just work out of the box: the weights are already being saved, I just had to change this line in run.py:

    setting = '{}_{}_{}_modes{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_fc{}_eb{}_dt{}_{}_{}'.format(args.task_id,
                                                                                                  args.model,
                                                                                                  args.mode_select,
                                                                                                  args.modes,
                                                                                                  args.data,
                                                                                                  args.features,
                                                                                                  args.seq_len,
                                                                                                  args.label_len,
                                                                                                  args.pred_len,
                                                                                                  args.d_model,
                                                                                                  args.n_heads,
                                                                                                  args.e_layers,
                                                                                                  args.d_layers,
                                                                                                  args.d_ff,
                                                                                                  args.factor,
                                                                                                  args.embed,
                                                                                                  args.distil,
                                                                                                  args.des, ii)
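
As far as I can tell, this setting string is used to build the checkpoint/results folder names, so adding mode_select and modes to it keeps runs with different mode configurations in separate directories instead of overwriting each other.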
