Skip to content

Commit

Permalink
Return a custom error from BART text2text if sequence is too long
Browse files Browse the repository at this point in the history
instead of panicking
  • Loading branch information
marco-nicola committed Oct 17, 2022
1 parent d993d4b commit 0876860
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pkg/tasks/text2text/bart/text2text.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ func (m *Text2Text) Generate(ctx context.Context, text string, opts *text2text.O
if err != nil {
return text2text.Response{}, err
}

if l, max := len(tokenized), m.Model.Bart.Config.MaxLength; l > max {
return text2text.Response{}, fmt.Errorf("%w: %d > %d", text2text.ErrInputSequenceTooLong, l, max)
}

sequences, scores := m.process(ctx, tokenized, *opts)
result := text2text.Response{
Texts: make([]string, len(sequences)),
Expand Down
5 changes: 5 additions & 0 deletions pkg/tasks/text2text/text2text.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package text2text

import (
"context"
"errors"
"fmt"
"strings"

Expand Down Expand Up @@ -76,6 +77,10 @@ type Response struct {
Scores []float64
}

// ErrInputSequenceTooLong means that pre-processing the input text
// produced a sequence that exceeds the maximum allowed length.
var ErrInputSequenceTooLong = errors.New("sequence too long")

// DefaultOptions returns the default options for generating text.
func DefaultOptions() *Options {
return &Options{
Expand Down

0 comments on commit 0876860

Please sign in to comment.