-
Notifications
You must be signed in to change notification settings - Fork 2
/
system_test.go
115 lines (100 loc) · 3 KB
/
system_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
package sentencepiece
import (
"bufio"
"bytes"
"fmt"
"io/ioutil"
"log"
"os"
"os/exec"
"path/filepath"
"slices"
"strconv"
"testing"
)
// TestVsSentencepiecePython is a "system" test comparing our Processor with
// the canonical sentencepiece Python package (officially distributed with the
// original C++ implementation of the algorithm).
// It also runs Decode for a round-trip test to ensure we get the original
// text back.
//
// This test will only run if python3 is available and is able to successfully
// load the sentencepiece library. Typically this means that 'go test' will
// have to run from an activated Python virtual environment where the library
// was installed.
//
// Every test/*.txt file becomes its own subtest, named after the file without
// its extension.
func TestVsSentencepiecePython(t *testing.T) {
	proc := createProcessor(t)

	if _, err := exec.Command("python3", "-c", "import sentencepiece").Output(); err != nil {
		t.Skip("This test only runs when python3 with sentencepiece is available")
	}

	pyProgramPath := filepath.Join("test", "sp-dump-ids.py")

	paths, err := filepath.Glob(filepath.Join("test", "*.txt"))
	if err != nil {
		t.Fatal(err)
	}

	for _, path := range paths {
		_, filename := filepath.Split(path)
		testname := filename[:len(filename)-len(filepath.Ext(path))]

		t.Run(testname, func(t *testing.T) {
			// Step 1: run the Python program to tokenize path into IDs.
			pyOut, err := exec.Command("python3", pyProgramPath, path).Output()
			if err != nil {
				t.Fatalf("while running %v on %v: %v", pyProgramPath, path, err)
			}
			pyIDs := pyOutToIDs(pyOut)

			// Step 2: use our Processor to tokenize path into IDs.
			// Failures are reported through t (not log.Fatal) so that only this
			// subtest fails instead of the whole test binary exiting.
			// ioutil.ReadFile is kept because the file imports io/ioutil.
			buf, err := ioutil.ReadFile(path)
			if err != nil {
				t.Fatal(err)
			}
			text := string(buf)

			var goIDs []int
			goTokens := proc.Encode(text)
			// The loop variable is deliberately not called t, to avoid shadowing
			// the *testing.T of this subtest.
			for _, tok := range goTokens {
				goIDs = append(goIDs, tok.ID)
			}

			// Step 3: compare the two; dump IDs to temp files for debugging in case
			// of a mismatch.
			if !slices.Equal(pyIDs, goIDs) {
				tmppy := dumpIDsToTempFile(testname+"-py-", pyIDs)
				tmpgo := dumpIDsToTempFile(testname+"-go-", goIDs)
				t.Errorf("IDs mismatch; dumped to %q and %q", tmppy, tmpgo)
			}

			// Step 4: round-trip Decode to get original text back
			newText := proc.Decode(goIDs)
			if text != newText {
				t.Errorf("text mismatch after Decode")
			}
		})
	}
}
// pyOutToIDs converts the captured stdout of the Python program into a slice
// of integer IDs. Each line of pyOut is parsed as one decimal integer, in
// order. A line that fails to parse, or a scanner error, terminates the
// process through log.Fatal. Empty input yields a nil slice.
func pyOutToIDs(pyOut []byte) []int {
	var (
		parsed []int
		lines  = bufio.NewScanner(bytes.NewReader(pyOut))
	)

	for lines.Scan() {
		id, convErr := strconv.Atoi(lines.Text())
		if convErr != nil {
			log.Fatal(convErr)
		}
		parsed = append(parsed, id)
	}

	// Scan returning false can mean either end of input or a read failure;
	// Err distinguishes the two.
	if scanErr := lines.Err(); scanErr != nil {
		log.Fatal(scanErr)
	}
	return parsed
}
// dumpIDsToTempFile dumps the given IDs (one per line) to a temporary file with
// the given prefix, and returns the name of the temporary file. The file is
// created in the default temporary directory and is left in place so it can
// be inspected after the test run.
//
// Any failure to create, write, flush or close the file terminates the
// process through log.Fatal, so a returned name always refers to a completely
// written file.
func dumpIDsToTempFile(prefix string, IDs []int) string {
	tf, err := os.CreateTemp("", prefix)
	if err != nil {
		log.Fatal(err)
	}

	// Buffer the output so a long ID list is not written one line at a time.
	w := bufio.NewWriter(tf)
	for _, id := range IDs {
		if _, err := fmt.Fprintf(w, "%d\n", id); err != nil {
			log.Fatal(err)
		}
	}
	if err := w.Flush(); err != nil {
		log.Fatal(err)
	}

	// Close explicitly (rather than via defer) so its error is not lost.
	if err := tf.Close(); err != nil {
		log.Fatal(err)
	}
	return tf.Name()
}