I got the error when I run run_event.py #6

Open
lauht opened this issue Jan 2, 2022 · 1 comment

lauht commented Jan 2, 2022

    Traceback (most recent call last):
      File "run_event.py", line 489, in <module>
        main()
      File "run_event.py", line 463, in main
        outputs = model(input_ids, attention_mask=attention_mask, seq_labels=seq_labels, ner_labels=ner_labels)
      File "/data1/lht/anaconda3/envs/tte/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/data1/lht/anaconda3/envs/tte/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 167, in forward
        outputs = self.parallel_apply(replicas, inputs, kwargs)
      File "/data1/lht/anaconda3/envs/tte/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 177, in parallel_apply
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
      File "/data1/lht/anaconda3/envs/tte/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
        output.reraise()
      File "/data1/lht/anaconda3/envs/tte/lib/python3.7/site-packages/torch/_utils.py", line 429, in reraise
        raise self.exc_type(msg)
    RuntimeError: Caught RuntimeError in replica 0 on device 0.
    Original Traceback (most recent call last):
      File "/data1/lht/anaconda3/envs/tte/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
        output = module(*input, **kwargs)
      File "/data1/lht/anaconda3/envs/tte/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/data1/lht/TradeTheEvent/utils/model.py", line 170, in forward
        seq_logits = self.final_classifier1(seq_logits)
      File "/data1/lht/anaconda3/envs/tte/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/data1/lht/anaconda3/envs/tte/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 94, in forward
        return F.linear(input, self.weight, self.bias)
      File "/data1/lht/anaconda3/envs/tte/lib/python3.7/site-packages/torch/nn/functional.py", line 1753, in linear
        return torch._C._nn.linear(input, weight, bias)
    RuntimeError: mat1 dim 1 must match mat2 dim 0

Is this error caused by the model (the network) itself, or by an incompatible version of torch?
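
For context, the last line of the traceback is the shape rule a linear layer enforces: `F.linear(input, weight, bias)` multiplies `input` by the transpose of `weight`, and `weight` is stored as `(out_features, in_features)`, so the last dimension of `input` has to equal `in_features`. Below is a minimal sketch of that rule in plain Python, with no torch dependency. The function name `linear_output_shape` and the sizes are hypothetical, picked only for illustration, and are not taken from `utils/model.py`.

```python
def linear_output_shape(input_shape, weight_shape):
    """Return the output shape of a linear layer, or raise on a mismatch.

    input_shape  -- shape of the incoming tensor, e.g. (batch, features)
    weight_shape -- (out_features, in_features), the layout nn.Linear uses
    """
    out_features, in_features = weight_shape
    if input_shape[-1] != in_features:
        raise ValueError(
            "input has {} features but the layer expects {}".format(
                input_shape[-1], in_features))
    # All leading dimensions pass through; only the last one changes.
    return tuple(input_shape[:-1]) + (out_features,)


# Hypothetical sizes, chosen only to illustrate the rule.
print(linear_output_shape((8, 768), (12, 768)))   # → (8, 12)
try:
    linear_output_shape((8, 256), (12, 768))      # last dim 256 != 768
except ValueError as err:
    print("mismatch:", err)
```

If this is what happens here, the tensor reaching `final_classifier1` has a different last dimension than the layer was built for, which would point at the model configuration rather than the torch version.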

@jeremytanjianle

run_event.py seems to be untested.
Among other issues, config.num_labels is not set from the argparse arguments; it keeps the value from BERT's default config, which is probably what leads to the shape mismatch above.
Insert the two marked lines below into run_event.py, roughly around line 397:

    logger.info(
        'Total training batch size: {}'.format(args.per_gpu_batch_size * args.gradient_accumulation_steps * args.n_gpu))

    config = BertConfig.from_pretrained(args.model_type)
    # config.num_labels = 12
    config.num_labels = args.num_labels                   # insert this line
    config.max_seq_length = args.max_seq_length           # insert this line
    model = MODEL_CLASS.from_pretrained(args.model_type, config=config)
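
The pattern behind this fix is to copy the parsed command-line values onto the config object before the model is built from it. It can be sketched without the real libraries. In the sketch below, `DummyConfig` and `classifier_weight_shape` are hypothetical stand-ins, not part of transformers or of this repo, and the default values (2 labels, hidden size 768, the argparse defaults) are assumptions for illustration only.

```python
import argparse


class DummyConfig:
    """Hypothetical stand-in for a model config carrying library defaults."""

    def __init__(self):
        self.hidden_size = 768
        self.num_labels = 2  # assumed default; not read from run_event.py


def classifier_weight_shape(config):
    """Shape (out_features, in_features) a final classifier would be given."""
    return (config.num_labels, config.hidden_size)


parser = argparse.ArgumentParser()
parser.add_argument("--num_labels", type=int, default=12)
parser.add_argument("--max_seq_length", type=int, default=256)
# An explicit argv list keeps the sketch independent of the real command line.
args = parser.parse_args(["--num_labels", "12", "--max_seq_length", "256"])

config = DummyConfig()
before = classifier_weight_shape(config)   # built from defaults only

# The same two assignments as the marked lines in the snippet above.
config.num_labels = args.num_labels
config.max_seq_length = args.max_seq_length
after = classifier_weight_shape(config)    # now follows the CLI values

print(before, after)
```

The order matters: the assignments have to happen before the model is constructed, since layers that depend on these values are sized when the model is created.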
