Skip to content

Commit

Permalink
implementing scoring task
Browse files Browse the repository at this point in the history
  • Loading branch information
jjviana committed Sep 23, 2022
1 parent 3b2ae8f commit 3b3fc39
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 0 deletions.
1 change: 1 addition & 0 deletions pkg/tasks/scoring/colbert/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
testdata
130 changes: 130 additions & 0 deletions pkg/tasks/scoring/colbert/ranking.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package colbert

import (
"fmt"
"path"
"path/filepath"
"strings"

"github.com/nlpodyssey/cybertron/pkg/models/bert"
"github.com/nlpodyssey/cybertron/pkg/tokenizers"
"github.com/nlpodyssey/cybertron/pkg/tokenizers/wordpiecetokenizer"
"github.com/nlpodyssey/cybertron/pkg/vocabulary"
"github.com/nlpodyssey/spago/ag"
"github.com/nlpodyssey/spago/embeddings/store/diskstore"
"github.com/nlpodyssey/spago/nn"
)

const SpecialDocumentMarker = "[unused1]"

const SpecialQueryMarker = "[unused0]"

const maxSentenceLength = 509 // 512 minus special tokens

type DocumentScorer struct {
Model *bert.ColbertModel
Tokenizer *wordpiecetokenizer.WordPieceTokenizer
}

func LoadDocumentScorer(modelPath string) (*DocumentScorer, error) {
vocab, err := vocabulary.NewFromFile(filepath.Join(modelPath, "vocab.txt"))
if err != nil {
return nil, fmt.Errorf("failed to load vocabulary: %w", err)
}

tokenizer := wordpiecetokenizer.New(vocab)

embeddingsRepo, err := diskstore.NewRepository(filepath.Join(modelPath, "repo"), diskstore.ReadOnlyMode)
if err != nil {
return nil, fmt.Errorf("failed to load embeddings repository: %w", err)
}

m, err := nn.LoadFromFile[*bert.ColbertModel](path.Join(modelPath, "spago_model.bin"))
if err != nil {
return nil, fmt.Errorf("failed to load colbert model: %w", err)
}

err = m.Bert.SetEmbeddings(embeddingsRepo)
if err != nil {
return nil, fmt.Errorf("failed to set embeddings: %w", err)
}
return &DocumentScorer{
Model: m,
Tokenizer: tokenizer,
}, nil
}

func (r *DocumentScorer) encode(text string, specialMarker string) []ag.Node {
tokens := r.Tokenizer.Tokenize(strings.ToLower(text))

stringTokens := tokenizers.GetStrings(tokens)
stringTokens = append([]string{wordpiecetokenizer.DefaultClassToken, specialMarker}, stringTokens...)
stringTokens = append(stringTokens, wordpiecetokenizer.DefaultSequenceSeparator)
embeddings := normalizeEmbeddings(r.Model.Forward(stringTokens))
return filterEmbeddings(embeddings, stringTokens)
}

func (r *DocumentScorer) EncodeDocument(text string) []ag.Node {
return r.encode(text, SpecialDocumentMarker)
}

func (r *DocumentScorer) EncodeQuery(text string) []ag.Node {
return r.encode(text, SpecialQueryMarker)
}

func (r *DocumentScorer) ScoreDocument(query []ag.Node, document []ag.Node) ag.Node {
var score ag.Node
score = ag.Scalar(0.0)
for i, q := range query {
if i < 3 || i > len(query)-1 {
continue // don't take special tokens into consideration
}
score = ag.Add(score, r.maxSimilarity(q, document))
}
return score
}

func (r *DocumentScorer) maxSimilarity(query ag.Node, document []ag.Node) ag.Node {
var max ag.Node
max = ag.Scalar(0.0)
for i, d := range document {
if i < 3 || i > len(document)-1 {
continue // don't take special tokens into consideration
}
sim := ag.Dot(query, d)
max = ag.Max(max, sim)
}
return max
}

func normalizeEmbeddings(embeddings []ag.Node) []ag.Node {
// Perform l2 normalization of each embedding
normalized := make([]ag.Node, len(embeddings))
for i, e := range embeddings {
normalized[i] = ag.DivScalar(e, ag.Sqrt(ag.ReduceSum(ag.Square(e))))
}
return normalized
}

func isPunctuation(token string) bool {
return token == "." || token == "," || token == "!" || token == "?" ||
token == ":" || token == ";" || token == "-" || token == "'" ||
token == "\"" || token == "(" || token == ")" || token == "[" ||
token == "]" || token == "{" || token == "}" || token == "*" ||
token == "&" || token == "%" || token == "$" || token == "#" ||
token == "@" || token == "=" || token == "+" ||
token == "_" || token == "~" || token == "/" || token == "\\" ||
token == "|" || token == "`" || token == "^" || token == ">" ||
token == "<"
}

func filterEmbeddings(embeddings []ag.Node, tokens []string) []ag.Node {
filtered := make([]ag.Node, 0, len(embeddings))
for i, e := range embeddings {
if isPunctuation(tokens[i]) {
continue
}
filtered = append(filtered, e)
}
return filtered
}
72 changes: 72 additions & 0 deletions pkg/tasks/scoring/colbert/ranking_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package colbert

import (
"sort"
"testing"

"github.com/nlpodyssey/spago/ag"
"github.com/stretchr/testify/require"
)

func TestDocumentScorer_ScoreDocument(t *testing.T) {

tests := []struct {
name string
query string
documents []string
wantRanking []int
wantScores []float64
}{
{
name: "test1",
query: "hello world",
documents: []string{"hello world"},
wantRanking: []int{0},
wantScores: []float64{1.0},
},
{
name: "test2",
query: "In which year was the first iPhone released?",
documents: []string{"The first Nokia phone was released in 1987.",
"The iPhone 3G was released in 2008.",
"The original iPhone was first sold in 2007."},
wantRanking: []int{2, 0, 1},
},
}
// Set the directory where the colbert model is stored here:
ColbertModelDir := "testdata/colbert"

scorer, err := LoadDocumentScorer(ColbertModelDir)
require.NoError(t, err)

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
query := scorer.EncodeQuery(tt.query)
var scores []float64
for i, doc := range tt.documents {
document := scorer.EncodeDocument(doc)
score := scorer.ScoreDocument(query, document)
// Normalize the score by the length of the non-special tokens of query
// (this is not in the original paper btw, but it makes sense to me)
score = ag.Div(score, ag.Scalar(float64(len(query)-3)))
if tt.wantScores != nil {
require.InDelta(t, tt.wantScores[i], score.Value().Data().F64()[0], 0.01)
}
scores = append(scores, score.Value().Data().F64()[0])
}
ranking := rank(scores)
require.Equal(t, tt.wantRanking, ranking)
})
}
}

func rank(scores []float64) []int {
var ranking []int
for i := range scores {
ranking = append(ranking, i)
}
sort.SliceStable(ranking, func(i, j int) bool {
return scores[ranking[i]] > scores[ranking[j]]
})
return ranking
}

0 comments on commit 3b3fc39

Please sign in to comment.