Fine-tune Model to calculate CTC loss in Inference part #2991
-
In the inference part, the model used is a fine-tuned stt-en-citrinet model. Now I want to move one step back and only calculate the loss (CTC loss). Is there a function for that, and how do I do it?
-
If you have the ground truth labels, you can follow the implementation of training_step() in EncDecCTCModel and see how we call forward() and then pass the logits to the loss function.

Normally you could use model.transcribe() with logprobs=True to get the logits to pass to the loss function; however, that doesn't provide the length of the actual encoded audio segment, which CTC loss requires. You can approximate it as the acoustic length after preprocessing // model stride (the stride depends on each model) and pass that to the CTC loss.

We will look into more useful ways of storing this information and providing it to users via transcribe(), but this approach should work in the meantime.

Needless to say, the ground truth labels will need to undergo the exact same preprocessing - basically tokenization and detokenization. It would make things easier to follow the data setup methods and then use the created dataloader directly.
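A minimal sketch of the length-approximation approach, written in plain PyTorch so it runs standalone: the random log-probs stand in for what model.transcribe(..., logprobs=True) would return, and the stride of 8 (Citrinet's downsampling factor), the vocabulary size, and all shapes are illustrative assumptions, not values from any real checkpoint.

```python
import torch

torch.manual_seed(0)

# Stand-ins for transcribe(..., logprobs=True) output on a CTC model.
batch, vocab_size, stride = 2, 29, 8            # 28 tokens + 1 CTC blank (assumed)
feat_len = torch.tensor([128, 96])              # acoustic frames after preprocessing

# Approximate the encoded length: preprocessed length // model stride.
enc_len = torch.div(feat_len, stride, rounding_mode="floor")

# Fake log-probabilities with shape (batch, time, vocab).
max_t = int(enc_len.max())
log_probs = torch.randn(batch, max_t, vocab_size).log_softmax(dim=-1)

# Ground-truth token ids; in practice these come from the model's
# tokenizer so they match the training-time preprocessing exactly.
targets = torch.randint(low=0, high=vocab_size - 1, size=(batch, 10))
target_len = torch.tensor([10, 7])

# torch.nn.CTCLoss expects (T, B, C); NeMo CTC models use the last
# vocabulary index as the blank symbol.
ctc = torch.nn.CTCLoss(blank=vocab_size - 1, zero_infinity=True)
loss = ctc(log_probs.transpose(0, 1), targets, enc_len, target_len)
print(float(loss))
```

In real use you would replace the random tensors with the actual log-probs and the per-file preprocessed lengths, and tokenize your reference transcripts with the model's own tokenizer before passing them as targets.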