From 3b3fc396eba49f08fbd8a541ecdce82773d71a8a Mon Sep 17 00:00:00 2001
From: Juliano Viana
Date: Fri, 23 Sep 2022 14:42:21 -0400
Subject: [PATCH] implement ColBERT document scoring task

---
 pkg/tasks/scoring/colbert/.gitignore      |   1 +
 pkg/tasks/scoring/colbert/ranking.go      | 141 ++++++++++++++++++++++
 pkg/tasks/scoring/colbert/ranking_test.go |  75 +++++++++++
 3 files changed, 217 insertions(+)
 create mode 100644 pkg/tasks/scoring/colbert/.gitignore
 create mode 100644 pkg/tasks/scoring/colbert/ranking.go
 create mode 100644 pkg/tasks/scoring/colbert/ranking_test.go

diff --git a/pkg/tasks/scoring/colbert/.gitignore b/pkg/tasks/scoring/colbert/.gitignore
new file mode 100644
index 0000000..d383c56
--- /dev/null
+++ b/pkg/tasks/scoring/colbert/.gitignore
@@ -0,0 +1 @@
+testdata
diff --git a/pkg/tasks/scoring/colbert/ranking.go b/pkg/tasks/scoring/colbert/ranking.go
new file mode 100644
index 0000000..fcc0832
--- /dev/null
+++ b/pkg/tasks/scoring/colbert/ranking.go
@@ -0,0 +1,141 @@
+package colbert
+
+import (
+	"fmt"
+	"path/filepath"
+	"strings"
+
+	"github.com/nlpodyssey/cybertron/pkg/models/bert"
+	"github.com/nlpodyssey/cybertron/pkg/tokenizers"
+	"github.com/nlpodyssey/cybertron/pkg/tokenizers/wordpiecetokenizer"
+	"github.com/nlpodyssey/cybertron/pkg/vocabulary"
+	"github.com/nlpodyssey/spago/ag"
+	"github.com/nlpodyssey/spago/embeddings/store/diskstore"
+	"github.com/nlpodyssey/spago/nn"
+)
+
+// SpecialDocumentMarker is the marker token prepended to every document.
+const SpecialDocumentMarker = "[unused1]"
+
+// SpecialQueryMarker is the marker token prepended to every query.
+const SpecialQueryMarker = "[unused0]"
+
+const maxSentenceLength = 509 // 512 minus the [CLS], marker and [SEP] special tokens
+
+// DocumentScorer scores documents against queries using a ColBERT model.
+type DocumentScorer struct {
+	Model     *bert.ColbertModel
+	Tokenizer *wordpiecetokenizer.WordPieceTokenizer
+}
+
+// LoadDocumentScorer loads a converted ColBERT model, together with its
+// vocabulary and embeddings repository, from modelPath.
+func LoadDocumentScorer(modelPath string) (*DocumentScorer, error) {
+	vocab, err := vocabulary.NewFromFile(filepath.Join(modelPath, "vocab.txt"))
+	if err != nil {
+		return nil, fmt.Errorf("failed to load vocabulary: %w", err)
+	}
+
+	tokenizer := wordpiecetokenizer.New(vocab)
+
+	embeddingsRepo, err := diskstore.NewRepository(filepath.Join(modelPath, "repo"), diskstore.ReadOnlyMode)
+	if err != nil {
+		return nil, fmt.Errorf("failed to load embeddings repository: %w", err)
+	}
+
+	m, err := nn.LoadFromFile[*bert.ColbertModel](filepath.Join(modelPath, "spago_model.bin"))
+	if err != nil {
+		return nil, fmt.Errorf("failed to load colbert model: %w", err)
+	}
+
+	err = m.Bert.SetEmbeddings(embeddingsRepo)
+	if err != nil {
+		return nil, fmt.Errorf("failed to set embeddings: %w", err)
+	}
+	return &DocumentScorer{
+		Model:     m,
+		Tokenizer: tokenizer,
+	}, nil
+}
+
+func (r *DocumentScorer) encode(text string, specialMarker string) []ag.Node {
+	tokens := r.Tokenizer.Tokenize(strings.ToLower(text))
+	if len(tokens) > maxSentenceLength {
+		tokens = tokens[:maxSentenceLength] // truncate so the input fits the model's 512-token limit
+	}
+
+	stringTokens := tokenizers.GetStrings(tokens)
+	stringTokens = append([]string{wordpiecetokenizer.DefaultClassToken, specialMarker}, stringTokens...)
+	stringTokens = append(stringTokens, wordpiecetokenizer.DefaultSequenceSeparator)
+	embeddings := normalizeEmbeddings(r.Model.Forward(stringTokens))
+	return filterEmbeddings(embeddings, stringTokens)
+}
+
+// EncodeDocument returns the normalized, punctuation-filtered token embeddings of a document.
+func (r *DocumentScorer) EncodeDocument(text string) []ag.Node {
+	return r.encode(text, SpecialDocumentMarker)
+}
+
+// EncodeQuery returns the normalized, punctuation-filtered token embeddings of a query.
+func (r *DocumentScorer) EncodeQuery(text string) []ag.Node {
+	return r.encode(text, SpecialQueryMarker)
+}
+
+// ScoreDocument computes the ColBERT relevance score of a document against
+// a query: the sum, over all query tokens, of the maximum similarity
+// between that token and any document token (MaxSim).
+func (r *DocumentScorer) ScoreDocument(query []ag.Node, document []ag.Node) ag.Node {
+	var score ag.Node
+	score = ag.Scalar(0.0)
+	for i, q := range query {
+		if i < 2 || i == len(query)-1 {
+			continue // skip the [CLS], marker and [SEP] special tokens
+		}
+		score = ag.Add(score, r.maxSimilarity(q, document))
+	}
+	return score
+}
+
+func (r *DocumentScorer) maxSimilarity(query ag.Node, document []ag.Node) ag.Node {
+	var max ag.Node
+	max = ag.Scalar(0.0)
+	for i, d := range document {
+		if i < 2 || i == len(document)-1 {
+			continue // skip the [CLS], marker and [SEP] special tokens
+		}
+		sim := ag.Dot(query, d)
+		max = ag.Max(max, sim)
+	}
+	return max
+}
+
+func normalizeEmbeddings(embeddings []ag.Node) []ag.Node {
+	// Perform L2 normalization of each embedding, so that the dot product
+	// of two embeddings is their cosine similarity.
+	normalized := make([]ag.Node, len(embeddings))
+	for i, e := range embeddings {
+		normalized[i] = ag.DivScalar(e, ag.Sqrt(ag.ReduceSum(ag.Square(e))))
+	}
+	return normalized
+}
+
+func isPunctuation(token string) bool {
+	switch token {
+	case ".", ",", "!", "?", ":", ";", "-", "'", "\"", "(", ")",
+		"[", "]", "{", "}", "*", "&", "%", "$", "#", "@", "=",
+		"+", "_", "~", "/", "\\", "|", "`", "^", ">", "<":
+		return true
+	}
+	return false
+}
+
+func filterEmbeddings(embeddings []ag.Node, tokens []string) []ag.Node {
+	filtered := make([]ag.Node, 0, len(embeddings))
+	for i, e := range embeddings {
+		if isPunctuation(tokens[i]) {
+			continue
+		}
+		filtered = append(filtered, e)
+	}
+	return filtered
+}
diff --git a/pkg/tasks/scoring/colbert/ranking_test.go b/pkg/tasks/scoring/colbert/ranking_test.go
new file mode 100644
index 0000000..2fc7cbf
--- /dev/null
+++ b/pkg/tasks/scoring/colbert/ranking_test.go
@@ -0,0 +1,75 @@
+package colbert
+
+import (
+	"sort"
+	"testing"
+
+	"github.com/nlpodyssey/spago/ag"
+	"github.com/stretchr/testify/require"
+)
+
+func TestDocumentScorer_ScoreDocument(t *testing.T) {
+	tests := []struct {
+		name        string
+		query       string
+		documents   []string
+		wantRanking []int
+		wantScores  []float64
+	}{
+		{
+			name:        "test1",
+			query:       "hello world",
+			documents:   []string{"hello world"},
+			wantRanking: []int{0},
+			wantScores:  []float64{1.0},
+		},
+		{
+			name:  "test2",
+			query: "In which year was the first iPhone released?",
+			documents: []string{
+				"The first Nokia phone was released in 1987.",
+				"The iPhone 3G was released in 2008.",
+				"The original iPhone was first sold in 2007.",
+			},
+			wantRanking: []int{2, 0, 1},
+		},
+	}
+	// Directory where the converted ColBERT test model is stored:
+	colbertModelDir := "testdata/colbert"
+
+	scorer, err := LoadDocumentScorer(colbertModelDir)
+	require.NoError(t, err)
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			query := scorer.EncodeQuery(tt.query)
+			var scores []float64
+			for i, doc := range tt.documents {
+				document := scorer.EncodeDocument(doc)
+				score := scorer.ScoreDocument(query, document)
+				// Normalize the score by the number of non-special query tokens.
+				// This is not part of the original ColBERT paper, but it keeps
+				// scores comparable across queries of different lengths.
+				score = ag.Div(score, ag.Scalar(float64(len(query)-3)))
+				if tt.wantScores != nil {
+					require.InDelta(t, tt.wantScores[i], score.Value().Data().F64()[0], 0.01)
+				}
+				scores = append(scores, score.Value().Data().F64()[0])
+			}
+			ranking := rank(scores)
+			require.Equal(t, tt.wantRanking, ranking)
+		})
+	}
+}
+
+// rank returns the document indices sorted by descending score.
+func rank(scores []float64) []int {
+	var ranking []int
+	for i := range scores {
+		ranking = append(ranking, i)
+	}
+	sort.SliceStable(ranking, func(i, j int) bool {
+		return scores[ranking[i]] > scores[ranking[j]]
+	})
+	return ranking
+}
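
Usage sketch (illustrative only, not part of the patch; it assumes a ColBERT
checkpoint already converted to the spago format under a hypothetical
"models/colbert" directory):

	scorer, err := colbert.LoadDocumentScorer("models/colbert") // hypothetical path
	if err != nil {
		log.Fatal(err)
	}
	query := scorer.EncodeQuery("In which year was the first iPhone released?")
	docs := []string{
		"The iPhone 3G was released in 2008.",
		"The original iPhone was first sold in 2007.",
	}
	for _, doc := range docs {
		score := scorer.ScoreDocument(query, scorer.EncodeDocument(doc))
		// Normalize by the number of non-special query tokens, as in the test.
		score = ag.Div(score, ag.Scalar(float64(len(query)-3)))
		fmt.Printf("%.3f  %s\n", score.Value().Data().F64()[0], doc)
	}

Higher scores mean a better match; sorting the documents by score (as the
rank helper in the test does) yields the final ranking.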