Commit
Showing 3 changed files with 203 additions and 0 deletions.
@@ -0,0 +1 @@
testdata
@@ -0,0 +1,130 @@
package colbert

import (
	"fmt"
	"path"
	"path/filepath"
	"strings"

	"github.com/nlpodyssey/cybertron/pkg/models/bert"
	"github.com/nlpodyssey/cybertron/pkg/tokenizers"
	"github.com/nlpodyssey/cybertron/pkg/tokenizers/wordpiecetokenizer"
	"github.com/nlpodyssey/cybertron/pkg/vocabulary"
	"github.com/nlpodyssey/spago/ag"
	"github.com/nlpodyssey/spago/embeddings/store/diskstore"
	"github.com/nlpodyssey/spago/nn"
)

// SpecialDocumentMarker is the marker token prepended to document sequences.
const SpecialDocumentMarker = "[unused1]"

// SpecialQueryMarker is the marker token prepended to query sequences.
const SpecialQueryMarker = "[unused0]"

const maxSentenceLength = 509 // 512 minus special tokens

// DocumentScorer scores documents against queries with a ColBERT model.
type DocumentScorer struct {
	Model     *bert.ColbertModel
	Tokenizer *wordpiecetokenizer.WordPieceTokenizer
}

// LoadDocumentScorer loads the vocabulary, embeddings repository, and ColBERT
// model found under modelPath.
func LoadDocumentScorer(modelPath string) (*DocumentScorer, error) {
	vocab, err := vocabulary.NewFromFile(filepath.Join(modelPath, "vocab.txt"))
	if err != nil {
		return nil, fmt.Errorf("failed to load vocabulary: %w", err)
	}

	tokenizer := wordpiecetokenizer.New(vocab)

	embeddingsRepo, err := diskstore.NewRepository(filepath.Join(modelPath, "repo"), diskstore.ReadOnlyMode)
	if err != nil {
		return nil, fmt.Errorf("failed to load embeddings repository: %w", err)
	}

	m, err := nn.LoadFromFile[*bert.ColbertModel](path.Join(modelPath, "spago_model.bin"))
	if err != nil {
		return nil, fmt.Errorf("failed to load colbert model: %w", err)
	}

	err = m.Bert.SetEmbeddings(embeddingsRepo)
	if err != nil {
		return nil, fmt.Errorf("failed to set embeddings: %w", err)
	}
	return &DocumentScorer{
		Model:     m,
		Tokenizer: tokenizer,
	}, nil
}

// encode lowercases and tokenizes text, wraps the tokens as
// [CLS] <specialMarker> ... [SEP], runs the model, and returns the
// L2-normalized token embeddings with punctuation tokens filtered out.
func (r *DocumentScorer) encode(text string, specialMarker string) []ag.Node {
	tokens := r.Tokenizer.Tokenize(strings.ToLower(text))

	stringTokens := tokenizers.GetStrings(tokens)
	stringTokens = append([]string{wordpiecetokenizer.DefaultClassToken, specialMarker}, stringTokens...)
	stringTokens = append(stringTokens, wordpiecetokenizer.DefaultSequenceSeparator)
	embeddings := normalizeEmbeddings(r.Model.Forward(stringTokens))
	return filterEmbeddings(embeddings, stringTokens)
}

// EncodeDocument encodes text as a document.
func (r *DocumentScorer) EncodeDocument(text string) []ag.Node {
	return r.encode(text, SpecialDocumentMarker)
}

// EncodeQuery encodes text as a query.
func (r *DocumentScorer) EncodeQuery(text string) []ag.Node {
	return r.encode(text, SpecialQueryMarker)
}

// ScoreDocument computes the late-interaction (MaxSim) score: for each
// non-special query token it takes the maximum similarity against the
// document tokens and sums these maxima.
func (r *DocumentScorer) ScoreDocument(query []ag.Node, document []ag.Node) ag.Node {
	var score ag.Node = ag.Scalar(0.0)
	for i, q := range query {
		// Skip the special tokens: [CLS] and the query marker at the start,
		// and the trailing [SEP].
		if i < 2 || i == len(query)-1 {
			continue
		}
		score = ag.Add(score, r.maxSimilarity(q, document))
	}
	return score
}

func (r *DocumentScorer) maxSimilarity(query ag.Node, document []ag.Node) ag.Node {
	var max ag.Node = ag.Scalar(0.0)
	for i, d := range document {
		// Skip the special tokens: [CLS] and the document marker at the start,
		// and the trailing [SEP].
		if i < 2 || i == len(document)-1 {
			continue
		}
		sim := ag.Dot(query, d)
		max = ag.Max(max, sim)
	}
	return max
}

func normalizeEmbeddings(embeddings []ag.Node) []ag.Node {
	// Perform L2 normalization of each embedding.
	normalized := make([]ag.Node, len(embeddings))
	for i, e := range embeddings {
		normalized[i] = ag.DivScalar(e, ag.Sqrt(ag.ReduceSum(ag.Square(e))))
	}
	return normalized
}

func isPunctuation(token string) bool {
	return token == "." || token == "," || token == "!" || token == "?" ||
		token == ":" || token == ";" || token == "-" || token == "'" ||
		token == "\"" || token == "(" || token == ")" || token == "[" ||
		token == "]" || token == "{" || token == "}" || token == "*" ||
		token == "&" || token == "%" || token == "$" || token == "#" ||
		token == "@" || token == "=" || token == "+" ||
		token == "_" || token == "~" || token == "/" || token == "\\" ||
		token == "|" || token == "`" || token == "^" || token == ">" ||
		token == "<"
}

func filterEmbeddings(embeddings []ag.Node, tokens []string) []ag.Node {
	filtered := make([]ag.Node, 0, len(embeddings))
	for i, e := range embeddings {
		if isPunctuation(tokens[i]) {
			continue
		}
		filtered = append(filtered, e)
	}
	return filtered
}
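
For context, ScoreDocument above implements ColBERT-style late interaction: each non-special query token embedding is matched against its most similar document token embedding (a dot product of L2-normalized vectors), and these per-token maxima are summed. The sketch below shows how the pieces fit together; the model directory and example strings are illustrative assumptions, and the scalar extraction mirrors the test file that follows.

// exampleScoring is a usage sketch, not part of the package's API.
func exampleScoring() (float64, error) {
	// "testdata/colbert" is an assumed model location, mirroring the test below;
	// it must contain vocab.txt, repo/ and spago_model.bin.
	scorer, err := LoadDocumentScorer("testdata/colbert")
	if err != nil {
		return 0, err
	}

	query := scorer.EncodeQuery("when was the first iphone released")
	document := scorer.EncodeDocument("The original iPhone was first sold in 2007.")

	// Sum of per-query-token maximum similarities (MaxSim).
	score := scorer.ScoreDocument(query, document)

	// Extract the scalar value, as the test does.
	return score.Value().Data().F64()[0], nil
}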
@@ -0,0 +1,72 @@
package colbert

import (
	"sort"
	"testing"

	"github.com/nlpodyssey/spago/ag"
	"github.com/stretchr/testify/require"
)

func TestDocumentScorer_ScoreDocument(t *testing.T) {
	tests := []struct {
		name        string
		query       string
		documents   []string
		wantRanking []int
		wantScores  []float64
	}{
		{
			name:        "test1",
			query:       "hello world",
			documents:   []string{"hello world"},
			wantRanking: []int{0},
			wantScores:  []float64{1.0},
		},
		{
			name:  "test2",
			query: "In which year was the first iPhone released?",
			documents: []string{
				"The first Nokia phone was released in 1987.",
				"The iPhone 3G was released in 2008.",
				"The original iPhone was first sold in 2007.",
			},
			wantRanking: []int{2, 0, 1},
		},
	}

	// Set the directory where the ColBERT model is stored here:
	colbertModelDir := "testdata/colbert"

	scorer, err := LoadDocumentScorer(colbertModelDir)
	require.NoError(t, err)

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			query := scorer.EncodeQuery(tt.query)
			var scores []float64
			for i, doc := range tt.documents {
				document := scorer.EncodeDocument(doc)
				score := scorer.ScoreDocument(query, document)
				// Normalize the score by the number of non-special query tokens
				// (not in the original ColBERT paper, but it makes scores
				// comparable across queries).
				score = ag.Div(score, ag.Scalar(float64(len(query)-3)))
				if tt.wantScores != nil {
					require.InDelta(t, tt.wantScores[i], score.Value().Data().F64()[0], 0.01)
				}
				scores = append(scores, score.Value().Data().F64()[0])
			}
			ranking := rank(scores)
			require.Equal(t, tt.wantRanking, ranking)
		})
	}
}

// rank returns the document indices sorted by descending score.
func rank(scores []float64) []int {
	var ranking []int
	for i := range scores {
		ranking = append(ranking, i)
	}
	sort.SliceStable(ranking, func(i, j int) bool {
		return scores[ranking[i]] > scores[ranking[j]]
	})
	return ranking
}
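
rank builds the index slice 0..len(scores)-1 and stably sorts it by descending score, so documents with equal scores keep their input order. A tiny sketch with made-up scores:

// exampleRank is an illustration only; the scores are invented.
func exampleRank() []int {
	scores := []float64{0.61, 0.47, 0.83}
	return rank(scores) // [2 0 1]: highest-scoring index first
}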