diff --git a/.travis.yml b/.travis.yml index 19b7c1c..87176e3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,6 +2,7 @@ language: go go: - "1.14.x" +- "1.15.x" services: - docker diff --git a/cmd/README.md b/cmd/README.md deleted file mode 100644 index 12fbd87..0000000 --- a/cmd/README.md +++ /dev/null @@ -1,187 +0,0 @@ -# Model - -## Word2Vec - -Word2Vec is composed of the following modules: - -Model: -- Skip-Gram -- CBOW - -Optimizer: -- Hierarchical Softmax -- Negative Sampling - -### Usage - -``` -Word2Vec: Continuous Bag-of-Words and Skip-gram model - -Usage: - wego word2vec [flags] - -Flags: - --batchSize int interval word size to update learning rate (default 10000) - -d, --dimension int dimension of word vector (default 10) - -h, --help help for word2vec - --initlr float initial learning rate (default 0.025) - -i, --inputFile string input file path for corpus (default "example/input.txt") - --iter int number of iteration (default 15) - --lower whether the words on corpus convert to lowercase or not - --maxDepth int times to track huffman tree, max-depth=0 means to track full path from root to word (for hierarchical softmax only) - --min-count int lower limit to filter rare words (default 5) - --model string which model does it use? one of: cbow|skip-gram (default "cbow") - --optimizer string which optimizer does it use? one of: hs|ns (default "hs") - -o, --outputFile string output file path to save word vectors (default "example/word_vectors.txt") - --prof profiling mode to check the performances - --sample int negative sample size(for negative sampling only) (default 5) - --theta float lower limit of learning rate (lr >= initlr * theta) (default 0.0001) - --thread int number of goroutine (default 8) - --threshold float threshold for subsampling (default 0.001) - --verbose verbose mode - -w, --window int context window size (default 5) -``` - -## GloVe - -GloVe is weighted matrix factorization model for co-occurrence map between words. - -### Usage - -``` -GloVe: Global Vectors for Word Representation - -Usage: - wego glove [flags] - -Flags: - --alpha float exponent of weighting function (default 0.75) - -d, --dimension int dimension of word vector (default 10) - -h, --help help for glove - --initlr float initial learning rate (default 0.025) - -i, --inputFile string input file path for corpus (default "example/input.txt") - --iter int number of iteration (default 15) - --lower whether the words on corpus convert to lowercase or not - --min-count int lower limit to filter rare words (default 5) - -o, --outputFile string output file path to save word vectors (default "example/word_vectors.txt") - --prof profiling mode to check the performances - --solver string solver for GloVe objective. One of: sgd|adagrad (default "sgd") - --thread int number of goroutine (default 8) - --verbose verbose mode - -w, --window int context window size (default 5) - --xmax int specifying cutoff in weighting function (default 100) -``` - -## Lexvec - -### Usage - -``` -Lexvec: Matrix Factorization using Window Sampling and Negative Sampling for Improved Word Representations - -Usage: - wego lexvec [flags] - -Flags: - --batchSize int interval word size to update learning rate (default 10000) - -d, --dimension int dimension of word vector (default 10) - -h, --help help for lexvec - --initlr float initial learning rate (default 0.025) - -i, --inputFile string input file path for corpus (default "example/input.txt") - --iter int number of iteration (default 15) - --lower whether the words on corpus convert to lowercase or not - --min-count int lower limit to filter rare words (default 5) - -o, --outputFile string output file path to save word vectors (default "example/word_vectors.txt") - --prof profiling mode to check the performances - --rel string relation type for counting co-occurrence. One of ppmi|pmi|co|logco (default "ppmi") - --sample int negative sample size(for negative sampling only) (default 5) - --save-vec string save vector type. One of: normal|add (default "normal") - --smooth float smoothing value (default 0.75) - --theta float lower limit of learning rate (lr >= initlr * theta) (default 0.0001) - --thread int number of goroutine (default 12) - --verbose verbose mode - -w, --window int context window size (default 5) -``` - -# Search - -Similarity search between word vectors. - -## Usage - -``` -Search similar words - -Usage: - wego search [flags] - -Examples: - wego search -i example/word_vectors.txt microsoft - -Flags: - -h, --help help for search - -i, --inputFile string input file path for trained word vector (default "example/input.txt") - -r, --rank int how many the most similar words will be displayed (default 10) -``` - -## Example - -``` -$ go run wego.go search -i example/word_vectors_sg.txt microsoft - RANK | WORD | SIMILARITY -+------+------------+------------+ - 1 | apple | 0.994008 - 2 | operating | 0.992855 - 3 | versions | 0.992800 - 4 | ibm | 0.992232 - 5 | os | 0.989174 - 6 | computers | 0.988998 - 7 | machines | 0.988804 - 8 | dvd | 0.988732 - 9 | cd | 0.988259 - 10 | compatible | 0.988200 -``` - -# REPL for search - -Similarity search between word vectors with REPL mode. - -## Usage - -``` -Search similar words with REPL mode - -Usage: - wego repl [flags] - -Examples: - wego repl -i example/word_vectors.txt - >> apple + banana - ... - -Flags: - -h, --help help for repl - -i, --inputFile string input file path for trained word vector (default "example/word_vectors.txt") - -r, --rank int how many the most similar words will be displayed (default 10) -``` - -## Example - -Now, it is able to use `+`, `-` for arithmetic operations. - -``` -$ go run wego.go repl -i example/word_vectors_sg.txt ->> a + b - RANK | WORD | SIMILARITY -+------+---------+------------+ - 1 | phi | 0.907975 - 2 | q | 0.904593 - 3 | mathbf | 0.903066 - 4 | cdot | 0.902205 - 5 | b | 0.901952 - 6 | becomes | 0.900346 - 7 | int | 0.898680 - 8 | z | 0.897895 - 9 | named | 0.896480 - 10 | v | 0.895456 -``` diff --git a/cmd/model/glove/glove.go b/cmd/model/glove/glove.go index 9a80ca0..e40fe60 100644 --- a/cmd/model/glove/glove.go +++ b/cmd/model/glove/glove.go @@ -23,9 +23,6 @@ import ( "github.com/spf13/cobra" "github.com/ynqa/wego/cmd/model/cmdutil" - "github.com/ynqa/wego/pkg/corpus" - "github.com/ynqa/wego/pkg/corpus/pairwise" - "github.com/ynqa/wego/pkg/model" "github.com/ynqa/wego/pkg/model/glove" "github.com/ynqa/wego/pkg/model/modelutil/save" ) @@ -51,10 +48,6 @@ func New() *cobra.Command { cmdutil.AddOutputFlags(cmd, &outputFile) cmdutil.AddProfFlags(cmd, &prof) cmdutil.AddSaveVectorTypeFlags(cmd, &saveVectorType) - - corpus.LoadForCmd(cmd, &opts.CorpusOptions) - pairwise.LoadForCmd(cmd, &opts.PairwiseOptions) - model.LoadForCmd(cmd, &opts.ModelOptions) glove.LoadForCmd(cmd, &opts) return cmd } diff --git a/cmd/model/lexvec/lexvec.go b/cmd/model/lexvec/lexvec.go index cae534b..75dcd57 100644 --- a/cmd/model/lexvec/lexvec.go +++ b/cmd/model/lexvec/lexvec.go @@ -23,8 +23,6 @@ import ( "github.com/spf13/cobra" "github.com/ynqa/wego/cmd/model/cmdutil" - "github.com/ynqa/wego/pkg/corpus" - "github.com/ynqa/wego/pkg/model" "github.com/ynqa/wego/pkg/model/lexvec" "github.com/ynqa/wego/pkg/model/modelutil/save" ) @@ -50,9 +48,6 @@ func New() *cobra.Command { cmdutil.AddOutputFlags(cmd, &outputFile) cmdutil.AddProfFlags(cmd, &prof) cmdutil.AddSaveVectorTypeFlags(cmd, &saveVectorType) - - corpus.LoadForCmd(cmd, &opts.CorpusOptions) - model.LoadForCmd(cmd, &opts.ModelOptions) lexvec.LoadForCmd(cmd, &opts) return cmd } diff --git a/cmd/model/word2vec/word2vec.go b/cmd/model/word2vec/word2vec.go index c63e617..11d7829 100644 --- a/cmd/model/word2vec/word2vec.go +++ b/cmd/model/word2vec/word2vec.go @@ -23,8 +23,6 @@ import ( "github.com/spf13/cobra" "github.com/ynqa/wego/cmd/model/cmdutil" - "github.com/ynqa/wego/pkg/corpus" - "github.com/ynqa/wego/pkg/model" "github.com/ynqa/wego/pkg/model/modelutil/save" "github.com/ynqa/wego/pkg/model/word2vec" ) @@ -50,9 +48,6 @@ func New() *cobra.Command { cmdutil.AddOutputFlags(cmd, &outputFile) cmdutil.AddProfFlags(cmd, &prof) cmdutil.AddSaveVectorTypeFlags(cmd, &saveVectorType) - - corpus.LoadForCmd(cmd, &opts.CorpusOptions) - model.LoadForCmd(cmd, &opts.ModelOptions) word2vec.LoadForCmd(cmd, &opts) return cmd } diff --git a/cmd/search/cmdutil/cmdutil.go b/cmd/query/cmdutil/cmdutil.go similarity index 100% rename from cmd/search/cmdutil/cmdutil.go rename to cmd/query/cmdutil/cmdutil.go diff --git a/cmd/search/repl/repl.go b/cmd/query/console/console.go similarity index 81% rename from cmd/search/repl/repl.go rename to cmd/query/console/console.go index 80fa2b6..19f1823 100644 --- a/cmd/search/repl/repl.go +++ b/cmd/query/console/console.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package repl +package console import ( "os" @@ -20,10 +20,10 @@ import ( "github.com/pkg/errors" "github.com/spf13/cobra" - "github.com/ynqa/wego/cmd/search/cmdutil" + "github.com/ynqa/wego/cmd/query/cmdutil" "github.com/ynqa/wego/pkg/embedding" "github.com/ynqa/wego/pkg/search" - "github.com/ynqa/wego/pkg/search/repl" + "github.com/ynqa/wego/pkg/search/console" ) var ( @@ -33,10 +33,9 @@ var ( func New() *cobra.Command { cmd := &cobra.Command{ - Use: "search-repl", - Short: "Search similar words (REPL mode)", - Long: "Search similar words (REPL mode)", - Example: " wego search-repl -i example/word_vectors.txt\n" + + Use: "console", + Short: "Console to investigate word vectors", + Example: " wego console -i example/word_vectors.txt\n" + " >> apple + banana\n" + " ...", RunE: func(cmd *cobra.Command, args []string) error { @@ -70,9 +69,9 @@ func execute() error { if err != nil { return err } - repl, err := repl.New(searcher, rank) + console, err := console.New(searcher, rank) if err != nil { return err } - return repl.Run() + return console.Run() } diff --git a/cmd/search/search.go b/cmd/query/query.go similarity index 88% rename from cmd/search/search.go rename to cmd/query/query.go index 834c72f..c4493c6 100644 --- a/cmd/search/search.go +++ b/cmd/query/query.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package search +package query import ( "os" @@ -20,7 +20,7 @@ import ( "github.com/pkg/errors" "github.com/spf13/cobra" - "github.com/ynqa/wego/cmd/search/cmdutil" + "github.com/ynqa/wego/cmd/query/cmdutil" "github.com/ynqa/wego/pkg/embedding" "github.com/ynqa/wego/pkg/search" ) @@ -32,10 +32,9 @@ var ( func New() *cobra.Command { cmd := &cobra.Command{ - Use: "search", - Short: "Search similar words", - Long: "Search similar words", - Example: " wego search -i example/word_vectors.txt microsoft", + Use: "query", + Short: "Query similar words", + Example: " wego query -i example/word_vectors.txt microsoft", RunE: func(cmd *cobra.Command, args []string) error { return execute(args) }, diff --git a/examples/word2vec/main.go b/examples/word2vec/main.go index c160229..696033a 100644 --- a/examples/word2vec/main.go +++ b/examples/word2vec/main.go @@ -23,10 +23,10 @@ import ( func main() { model, err := word2vec.New( - word2vec.WithWindow(5), - word2vec.WithModel(word2vec.Cbow), - word2vec.WithOptimizer(word2vec.NegativeSampling), - word2vec.WithNegativeSampleSize(5), + word2vec.Window(5), + word2vec.Model(word2vec.Cbow), + word2vec.Optimizer(word2vec.NegativeSampling), + word2vec.NegativeSampleSize(5), word2vec.Verbose(), ) if err != nil { diff --git a/go.mod b/go.mod index 80cbc9c..7e6f35d 100644 --- a/go.mod +++ b/go.mod @@ -1,18 +1,15 @@ module github.com/ynqa/wego -go 1.13 +go 1.15 require ( - github.com/davecgh/go-spew v1.1.1 // indirect - github.com/inconshreveable/mousetrap v1.0.0 // indirect - github.com/kr/pretty v0.1.0 // indirect - github.com/olekukonko/tablewriter v0.0.0-20171203151007-65fec0d89a57 - github.com/peterh/liner v1.1.0 - github.com/pkg/errors v0.8.0 - github.com/spf13/cobra v0.0.1 - github.com/spf13/pflag v1.0.3 // indirect - github.com/stretchr/testify v1.4.0 - golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e - gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect - gopkg.in/yaml.v2 v2.2.4 // indirect + github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect + github.com/olekukonko/tablewriter v0.0.4 + github.com/peterh/liner v1.2.0 + github.com/pkg/errors v0.9.1 + github.com/spf13/cobra v1.0.0 + github.com/spf13/pflag v1.0.5 // indirect + github.com/stretchr/testify v1.6.1 + golang.org/x/sync v0.0.0-20200930132711-30421366ff76 + gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b // indirect ) diff --git a/go.sum b/go.sum index 78dfae3..50a6310 100644 --- a/go.sum +++ b/go.sum @@ -1,38 +1,161 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= +github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= +github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= +github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= +github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= +github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= +github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= +github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= +github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= +github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= +github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= +github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/mattn/go-runewidth v0.0.3 h1:a+kO+98RDGEfo6asOGMmpodZq4FNtnGP54yps8BzLR4= github.com/mattn/go-runewidth v0.0.3/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= -github.com/olekukonko/tablewriter v0.0.0-20171203151007-65fec0d89a57 h1:c6g+iEoim6VD2DGy2utQoryQMVNndSvYm/YfGjc5A/o= -github.com/olekukonko/tablewriter v0.0.0-20171203151007-65fec0d89a57/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= -github.com/peterh/liner v1.1.0 h1:f+aAedNJA6uk7+6rXsYBnhdo4Xux7ESLe+kcuVUF5os= -github.com/peterh/liner v1.1.0/go.mod h1:CRroGNssyjTd/qIG2FyxByd2S8JEAZXBl4qUrZf8GS0= +github.com/mattn/go-runewidth v0.0.7 h1:Ei8KR0497xHyKJPAv59M1dkC+rOZCMBJ+t3fZ+twI54= +github.com/mattn/go-runewidth v0.0.7/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= +github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= +github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= +github.com/olekukonko/tablewriter v0.0.4 h1:vHD/YYe1Wolo78koG299f7V/VAS08c6IpCLn+Ejf/w8= +github.com/olekukonko/tablewriter v0.0.4/go.mod h1:zq6QwlOf5SlnkVbMSr5EoBv3636FWnp+qbPhuoO21uA= +github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= +github.com/peterh/liner v1.2.0 h1:w/UPXyl5GfahFxcTOz2j9wCIHNI+pUPr2laqpojKNCg= +github.com/peterh/liner v1.2.0/go.mod h1:CRroGNssyjTd/qIG2FyxByd2S8JEAZXBl4qUrZf8GS0= github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/spf13/cobra v0.0.1 h1:zZh3X5aZbdnoj+4XkaBxKfhO4ot82icYdhhREIAXIj8= -github.com/spf13/cobra v0.0.1/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= +github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/common v0.0.0-20181113130724-41aa239b4cce/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= +github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= +github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= +github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= +github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= +github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= +github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= +github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= +github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= +github.com/spf13/cobra v1.0.0 h1:6m/oheQuQ13N9ks4hubMG6BnvwOeaJrqSPLahSnczz8= +github.com/spf13/cobra v1.0.0/go.mod h1:/6GTrnGXV9HjY+aR4k0oJ5tcvakLuG6EuKReYlHNrgE= +github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= github.com/spf13/pflag v1.0.3 h1:zPAT6CGy6wXeQ7NtTnaTerfKOsV6V6F8agHXFiazDkg= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/yZzE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY= -golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= +github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= +github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= +github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= +go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= +go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190522155817-f3200d17e092/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20200930132711-30421366ff76 h1:JnxiSYT3Nm0BT2a8CyvYyM6cnrWpidecD1UuSYbhKm0= +golang.org/x/sync v0.0.0-20200930132711-30421366ff76/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.21.0/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= +gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b h1:QRR6H1YWRnHb4Y/HeNFCTJLFVxaq6wH4YuVdsUOr75U= +gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= +gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.4 h1:/eiJrUcujPVeJ3xlSWaiNi3uSVmDGBK1pDHUHAnao1I= -gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/pkg/corpus/cooccurrence/cooccurrence.go b/pkg/corpus/cooccurrence/cooccurrence.go new file mode 100644 index 0000000..c44fe5c --- /dev/null +++ b/pkg/corpus/cooccurrence/cooccurrence.go @@ -0,0 +1,99 @@ +// Copyright © 2020 wego authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package co + +import ( + "fmt" + "math" + + "github.com/pkg/errors" + + "github.com/ynqa/wego/pkg/corpus/cooccurrence/encode" +) + +const ( + DefaultCountType = Increment +) + +func invalidCountTypeError(typ CountType) error { + return fmt.Errorf("invalid relation type: %s not in %s|%s", typ, Increment, Proximity) +} + +type CountType string + +func (t *CountType) String() string { + if *t == CountType("") { + *t = DefaultCountType + } + return string(*t) +} + +func (t *CountType) Set(name string) error { + typ := CountType(name) + if typ == Increment || typ == Proximity { + *t = typ + return nil + } + return invalidCountTypeError(typ) +} + +func (t *CountType) Type() string { + return t.String() +} + +const ( + Increment CountType = "inc" + Proximity CountType = "prox" +) + +type Cooccurrence struct { + typ CountType + + ma map[uint64]float64 +} + +func New(typ CountType) (*Cooccurrence, error) { + if typ != Increment && typ != Proximity { + return nil, invalidCountTypeError(typ) + } + return &Cooccurrence{ + typ: typ, + + ma: make(map[uint64]float64), + }, nil +} + +func (c *Cooccurrence) EncodedMatrix() map[uint64]float64 { + return c.ma +} + +func (c *Cooccurrence) Add(left, right int) error { + enc := encode.EncodeBigram(uint64(left), uint64(right)) + var val float64 + switch c.typ { + case Increment: + val = 1 + case Proximity: + div := left - right + if div == 0 { + return errors.Errorf("Divide by zero on counting co-occurrence") + } + val = 1. / math.Abs(float64(div)) + default: + return invalidCountTypeError(c.typ) + } + c.ma[enc] += val + return nil +} diff --git a/pkg/corpus/cooccurrence/cooccurrence_test.go b/pkg/corpus/cooccurrence/cooccurrence_test.go new file mode 100644 index 0000000..e955da9 --- /dev/null +++ b/pkg/corpus/cooccurrence/cooccurrence_test.go @@ -0,0 +1,33 @@ +// Copyright © 2020 wego authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package co + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCooccurrence(t *testing.T) { + pw, err := New(Increment) + assert.NoError(t, err) + assert.NoError(t, pw.Add(1, 2)) + assert.Equal(t, 1, len(pw.EncodedMatrix())) +} + +func TestCooccurrenceWithInvalidCountType(t *testing.T) { + _, err := New(CountType("invalid type")) + assert.Error(t, err) +} diff --git a/pkg/corpus/cooccurrence/encode/encode.go b/pkg/corpus/cooccurrence/encode/encode.go new file mode 100644 index 0000000..1948f82 --- /dev/null +++ b/pkg/corpus/cooccurrence/encode/encode.go @@ -0,0 +1,37 @@ +// Copyright © 2020 wego authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package encode + +// data structure for co-occurrence mapping: +// - https://blog.chewxy.com/2017/07/12/21-bits-english/ + +// EncodeBigram creates id between two words. +func EncodeBigram(l1, l2 uint64) uint64 { + if l1 < l2 { + return encode(l1, l2) + } else { + return encode(l2, l1) + } +} + +func encode(l1, l2 uint64) uint64 { + return l1 | (l2 << 32) +} + +// DecodeBigram reverts pair id to two word ids. +func DecodeBigram(pid uint64) (uint64, uint64) { + f := pid >> 32 + return pid - (f << 32), f +} diff --git a/pkg/corpus/corpus.go b/pkg/corpus/corpus.go index cc3483a..4236b81 100644 --- a/pkg/corpus/corpus.go +++ b/pkg/corpus/corpus.go @@ -1,123 +1,30 @@ +// Copyright © 2020 wego authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package corpus import ( - "bufio" - "fmt" - "io" - "strings" - - "github.com/pkg/errors" - - "github.com/ynqa/wego/pkg/clock" + co "github.com/ynqa/wego/pkg/corpus/cooccurrence" "github.com/ynqa/wego/pkg/corpus/dictionary" - "github.com/ynqa/wego/pkg/corpus/pairwise" - "github.com/ynqa/wego/pkg/verbose" ) -type Corpus struct { - opts Options - dic *dictionary.Dictionary - pair *pairwise.Pairwise - maxLen int - - doc []int - - verbose *verbose.Verbose -} - -func New( - opts Options, - verbose *verbose.Verbose, -) *Corpus { - return &Corpus{ - opts: opts, - dic: dictionary.New(), - - doc: make([]int, 0), - - verbose: verbose, - } -} - -func (c *Corpus) Doc() []int { - return c.doc -} - -func (c *Corpus) Dictionary() *dictionary.Dictionary { - return c.dic -} - -func (c *Corpus) Pairwise() *pairwise.Pairwise { - return c.pair -} - -func (c *Corpus) Len() int { - return c.maxLen -} - -func (c *Corpus) Build(r io.Reader) error { - scanner := bufio.NewScanner(r) - scanner.Split(bufio.ScanWords) - - clk := clock.New() - for scanner.Scan() { - word := scanner.Text() - if c.opts.ToLower { - word = strings.ToLower(word) - } - - c.dic.Add(word) - id, _ := c.dic.ID(word) - c.doc = append(c.doc, id) - c.maxLen++ - - c.verbose.Do(func() { - if c.maxLen%100000 == 0 { - fmt.Printf("read %d words %v\r", c.maxLen, clk.AllElapsed()) - } - }) - } - if err := scanner.Err(); err != nil && err != io.EOF { - return errors.Wrap(err, "failed to scan") - } - - c.verbose.Do(func() { - fmt.Printf("read %d words %v\r\n", c.maxLen, clk.AllElapsed()) - }) - - if err := c.filter(); err != nil { - return err - } - c.verbose.Do(func() { - fmt.Printf("filtered to %d words %v\r\n", len(c.doc), clk.AllElapsed()) - }) - return nil -} - -func (c *Corpus) filter() error { - dst := make([]int, 0) - for _, id := range c.doc { - if c.dic.IDFreq(id) > c.opts.MinCount { - dst = append(dst, id) - } - } - c.doc = dst - return nil -} - -func (c *Corpus) BuildForPairwise(r io.Reader, opts pairwise.Options) error { - c.pair = pairwise.New(opts) - - if err := c.Build(r); err != nil { - return err - } - - for i := 0; i < len(c.doc); i++ { - for j := i + 1; j < len(c.doc) && j <= i+opts.Window; j++ { - if err := c.pair.Add(c.doc[i], c.doc[j]); err != nil { - return err - } - } - } - return nil +type Corpus interface { + IndexedDoc() []int + BatchWords(chan []int, int) error + Dictionary() *dictionary.Dictionary + Cooccurrence() *co.Cooccurrence + Len() int + LoadForDictionary() error + LoadForCooccurrence(co.CountType, int) error } diff --git a/pkg/corpus/cpsutil/cpsutil.go b/pkg/corpus/cpsutil/cpsutil.go new file mode 100644 index 0000000..f15090c --- /dev/null +++ b/pkg/corpus/cpsutil/cpsutil.go @@ -0,0 +1,116 @@ +// Copyright © 2020 wego authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cpsutil + +import ( + "bufio" + "io" + + "github.com/ynqa/wego/pkg/corpus/dictionary" +) + +func scanner(r io.Reader) *bufio.Scanner { + s := bufio.NewScanner(r) + s.Split(bufio.ScanWords) + return s +} + +func ReadWord(r io.Reader, fn func(string) error) error { + scanner := scanner(r) + for scanner.Scan() { + if err := fn(scanner.Text()); err != nil { + return err + } + } + + if err := scanner.Err(); err != nil && err != io.EOF { + return err + } + + return nil +} + +func ReadWordWithForwardContext(r io.Reader, n int, fn func(string, string) error) error { + scanner := scanner(r) + var ( + axis string + ws []string = make([]string, n) + ) + postFn := func() error { + for _, w := range ws { + if err := fn(axis, w); err != nil { + return err + } + } + return nil + } + for { + if axis == "" { + if !scanner.Scan() { + break + } + axis = scanner.Text() + for i := 0; i < n; i++ { + if !scanner.Scan() { + break + } + ws[i] = scanner.Text() + } + } else { + axis = ws[0] + ws = ws[1:] + if !scanner.Scan() { + break + } + ws = append(ws, scanner.Text()) + } + if err := postFn(); err != nil { + return err + } + } + if err := postFn(); err != nil { + return err + } + + if err := scanner.Err(); err != nil && err != io.EOF { + return err + } + + return nil +} + +type Filters []FilterFn + +func (f Filters) Any(id int, dic *dictionary.Dictionary) bool { + var b bool + for _, fn := range f { + b = b || fn(id, dic) + } + return b +} + +type FilterFn func(int, *dictionary.Dictionary) bool + +func MaxCount(v int) FilterFn { + return FilterFn(func(id int, dic *dictionary.Dictionary) bool { + return 0 < v && v < dic.IDFreq(id) + }) +} + +func MinCount(v int) FilterFn { + return FilterFn(func(id int, dic *dictionary.Dictionary) bool { + return 0 <= v && dic.IDFreq(id) < v + }) +} diff --git a/pkg/corpus/cpsutil/cpsutil_test.go b/pkg/corpus/cpsutil/cpsutil_test.go new file mode 100644 index 0000000..a43a28b --- /dev/null +++ b/pkg/corpus/cpsutil/cpsutil_test.go @@ -0,0 +1,48 @@ +// Copyright © 2020 wego authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cpsutil + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestReadWord(t *testing.T) { + var dic []string + fn := func(w string) (err error) { + dic = append(dic, w) + return + } + + r := strings.NewReader("a bc def") + expected := []string{"a", "bc", "def"} + assert.NoError(t, ReadWord(r, fn)) + assert.Equal(t, expected, dic) +} + +func TestReadWordWithForwardContext(t *testing.T) { + var dic []string + fn := func(w1, w2 string) (err error) { + dic = append(dic, w1+w2) + return + } + + r := strings.NewReader("a b c d e") + expected := []string{"ab", "ac", "bc", "bd", "cd", "ce", "de"} + assert.NoError(t, ReadWordWithForwardContext(r, 2, fn)) + assert.Equal(t, expected, dic) +} diff --git a/pkg/corpus/dictionary/dictionary.go b/pkg/corpus/dictionary/dictionary.go index c3ebec9..f91ad5a 100644 --- a/pkg/corpus/dictionary/dictionary.go +++ b/pkg/corpus/dictionary/dictionary.go @@ -1,3 +1,17 @@ +// Copyright © 2020 wego authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package dictionary // inspired by diff --git a/pkg/corpus/dictionary/huffman.go b/pkg/corpus/dictionary/huffman.go index 37ce52d..646de0d 100644 --- a/pkg/corpus/dictionary/huffman.go +++ b/pkg/corpus/dictionary/huffman.go @@ -1,3 +1,17 @@ +// Copyright © 2020 wego authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package dictionary import ( diff --git a/pkg/corpus/dictionary/node/node.go b/pkg/corpus/dictionary/node/node.go index 3d3c328..506f48b 100644 --- a/pkg/corpus/dictionary/node/node.go +++ b/pkg/corpus/dictionary/node/node.go @@ -1,3 +1,17 @@ +// Copyright © 2020 wego authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package node type Node struct { diff --git a/pkg/corpus/fs/fs.go b/pkg/corpus/fs/fs.go new file mode 100644 index 0000000..9d6afcb --- /dev/null +++ b/pkg/corpus/fs/fs.go @@ -0,0 +1,119 @@ +package fs + +import ( + "io" + "strings" + + "github.com/ynqa/wego/pkg/corpus" + co "github.com/ynqa/wego/pkg/corpus/cooccurrence" + "github.com/ynqa/wego/pkg/corpus/cpsutil" + "github.com/ynqa/wego/pkg/corpus/dictionary" +) + +type Corpus struct { + doc io.Reader + + dic *dictionary.Dictionary + cooc *co.Cooccurrence + maxLen int + indexedDoc []int + + toLower bool + filters cpsutil.Filters +} + +func New(r io.Reader, toLower bool, maxCount, minCount int) corpus.Corpus { + return &Corpus{ + doc: r, + dic: dictionary.New(), + indexedDoc: make([]int, 0), + + toLower: toLower, + filters: cpsutil.Filters{ + cpsutil.MaxCount(maxCount), + cpsutil.MinCount(minCount), + }, + } +} + +func (c *Corpus) IndexedDoc() []int { + return nil +} + +func (c *Corpus) BatchWords(ch chan []int, batchSize int) error { + cursor, ids := 0, make([]int, batchSize) + err := cpsutil.ReadWord(c.doc, func(word string) error { + if c.toLower { + word = strings.ToLower(word) + } + + id, _ := c.dic.ID(word) + if c.filters.Any(id, c.dic) { + return nil + } + + ids[cursor] = id + cursor++ + if len(ids) == batchSize { + ch <- ids + cursor, ids = 0, make([]int, batchSize) + } + return nil + }) + if err != nil { + return err + } + + // send left words + ch <- ids[:cursor] + return nil +} + +func (c *Corpus) Dictionary() *dictionary.Dictionary { + return c.dic +} + +func (c *Corpus) Cooccurrence() *co.Cooccurrence { + return c.cooc +} + +func (c *Corpus) Len() int { + return c.maxLen +} + +func (c *Corpus) LoadForDictionary() error { + if err := cpsutil.ReadWord(c.doc, func(word string) error { + if c.toLower { + word = strings.ToLower(word) + } + + c.dic.Add(word) + c.maxLen++ + + return nil + }); err != nil { + return err + } + + return nil +} + +func (c *Corpus) LoadForCooccurrence(typ co.CountType, window int) (err error) { + c.cooc, err = co.New(typ) + if err != nil { + return + } + + if err = cpsutil.ReadWordWithForwardContext( + c.doc, window, func(w1, w2 string) error { + id1, _ := c.dic.ID(w1) + id2, _ := c.dic.ID(w2) + if err := c.cooc.Add(id1, id2); err != nil { + return err + } + return nil + }); err != nil { + return + } + return +} diff --git a/pkg/corpus/memory/memory.go b/pkg/corpus/memory/memory.go new file mode 100644 index 0000000..28697c5 --- /dev/null +++ b/pkg/corpus/memory/memory.go @@ -0,0 +1,100 @@ +package memory + +import ( + "io" + "strings" + + "github.com/ynqa/wego/pkg/corpus" + co "github.com/ynqa/wego/pkg/corpus/cooccurrence" + "github.com/ynqa/wego/pkg/corpus/cpsutil" + "github.com/ynqa/wego/pkg/corpus/dictionary" +) + +type Corpus struct { + doc io.Reader + + dic *dictionary.Dictionary + cooc *co.Cooccurrence + maxLen int + indexedDoc []int + + toLower bool + filters cpsutil.Filters +} + +func New(doc io.Reader, toLower bool, maxCount, minCount int) corpus.Corpus { + return &Corpus{ + doc: doc, + dic: dictionary.New(), + indexedDoc: make([]int, 0), + + toLower: toLower, + filters: cpsutil.Filters{ + cpsutil.MaxCount(maxCount), + cpsutil.MinCount(minCount), + }, + } +} + +func (c *Corpus) IndexedDoc() []int { + var res []int + for _, id := range c.indexedDoc { + if c.filters.Any(id, c.dic) { + continue + } + res = append(res, id) + } + return res +} + +func (c *Corpus) BatchWords(chan []int, int) error { + return nil +} + +func (c *Corpus) Dictionary() *dictionary.Dictionary { + return c.dic +} + +func (c *Corpus) Cooccurrence() *co.Cooccurrence { + return c.cooc +} + +func (c *Corpus) Len() int { + return c.maxLen +} + +func (c *Corpus) LoadForDictionary() error { + if err := cpsutil.ReadWord(c.doc, func(word string) error { + if c.toLower { + word = strings.ToLower(word) + } + + c.dic.Add(word) + id, _ := c.dic.ID(word) + c.maxLen++ + c.indexedDoc = append(c.indexedDoc, id) + + return nil + }); err != nil { + return err + } + + return nil +} + +func (c *Corpus) LoadForCooccurrence(typ co.CountType, window int) (err error) { + c.cooc, err = co.New(typ) + if err != nil { + return + } + + for i := 0; i < len(c.indexedDoc); i++ { + for j := i + 1; j < len(c.indexedDoc) && j <= i+window; j++ { + if err = c.cooc.Add(c.indexedDoc[i], c.indexedDoc[j]); err != nil { + return + } + } + } + + return +} diff --git a/pkg/corpus/options.go b/pkg/corpus/options.go deleted file mode 100644 index 9d89afb..0000000 --- a/pkg/corpus/options.go +++ /dev/null @@ -1,27 +0,0 @@ -package corpus - -import ( - "github.com/spf13/cobra" -) - -const ( - defaultMinCount = 5 - defaultToLower = false -) - -type Options struct { - MinCount int - ToLower bool -} - -func DefaultOptions() Options { - return Options{ - MinCount: defaultMinCount, - ToLower: defaultToLower, - } -} - -func LoadForCmd(cmd *cobra.Command, opts *Options) { - cmd.Flags().IntVar(&opts.MinCount, "min-count", defaultMinCount, "lower limit to filter rare words") - cmd.Flags().BoolVar(&opts.ToLower, "lower", defaultToLower, "whether the words on corpus convert to lowercase or not") -} diff --git a/pkg/corpus/pairwise/encode/encode.go b/pkg/corpus/pairwise/encode/encode.go deleted file mode 100644 index 32abd19..0000000 --- a/pkg/corpus/pairwise/encode/encode.go +++ /dev/null @@ -1,23 +0,0 @@ -package encode - -// data structure for co-occurrence mapping: -// - https://blog.chewxy.com/2017/07/12/21-bits-english/ - -// EncodeBigram creates id between two words. -func EncodeBigram(l1, l2 uint64) uint64 { - if l1 < l2 { - return encode(l1, l2) - } else { - return encode(l2, l1) - } -} - -func encode(l1, l2 uint64) uint64 { - return l1 | (l2 << 32) -} - -// DecodeBigram reverts pair id to two word ids. -func DecodeBigram(pid uint64) (uint64, uint64) { - f := pid >> 32 - return pid - (f << 32), f -} diff --git a/pkg/corpus/pairwise/options.go b/pkg/corpus/pairwise/options.go deleted file mode 100644 index 2e712a8..0000000 --- a/pkg/corpus/pairwise/options.go +++ /dev/null @@ -1,57 +0,0 @@ -package pairwise - -import ( - "fmt" - - "github.com/pkg/errors" - "github.com/spf13/cobra" -) - -func invalidCountTypeError(typ CountType) error { - return errors.Errorf("invalid relation type: %s not in %s|%s", typ, Increment, Distance) -} - -type CountType string - -const ( - Increment CountType = "inc" - Distance CountType = "dis" - defaultCountType = Increment - defaultWindow = 5 -) - -func (t *CountType) String() string { - if *t == CountType("") { - *t = defaultCountType - } - return string(*t) -} - -func (t *CountType) Set(name string) error { - typ := CountType(name) - if typ == Increment || typ == Distance { - *t = typ - return nil - } - return invalidCountTypeError(typ) -} - -func (t *CountType) Type() string { - return t.String() -} - -type Options struct { - CountType CountType - Window int -} - -func DefaultOptions() Options { - return Options{ - CountType: defaultCountType, - Window: defaultWindow, - } -} - -func LoadForCmd(cmd *cobra.Command, opts *Options) { - cmd.Flags().Var(&opts.CountType, "cnt", fmt.Sprintf("count type for co-occurrence words. One of %s|%s", Increment, Distance)) -} diff --git a/pkg/corpus/pairwise/pairwise.go b/pkg/corpus/pairwise/pairwise.go deleted file mode 100644 index ab0635c..0000000 --- a/pkg/corpus/pairwise/pairwise.go +++ /dev/null @@ -1,46 +0,0 @@ -package pairwise - -import ( - "math" - - "github.com/pkg/errors" - - "github.com/ynqa/wego/pkg/corpus/pairwise/encode" -) - -type Pairwise struct { - opts Options - - pm map[uint64]float64 -} - -func New(opts Options) *Pairwise { - return &Pairwise{ - opts: opts, - - pm: make(map[uint64]float64), - } -} - -func (p *Pairwise) PairMap() map[uint64]float64 { - return p.pm -} - -func (p *Pairwise) Add(left, right int) error { - enc := encode.EncodeBigram(uint64(left), uint64(right)) - var val float64 - switch p.opts.CountType { - case Increment: - val = 1 - case Distance: - div := left - right - if div == 0 { - return errors.Errorf("Divide by zero on counting co-occurrence") - } - val = 1. / math.Abs(float64(div)) - default: - return invalidCountTypeError(p.opts.CountType) - } - p.pm[enc] += val - return nil -} diff --git a/pkg/model/glove/glove.go b/pkg/model/glove/glove.go index 4d213e2..708853d 100644 --- a/pkg/model/glove/glove.go +++ b/pkg/model/glove/glove.go @@ -27,7 +27,8 @@ import ( "github.com/ynqa/wego/pkg/clock" "github.com/ynqa/wego/pkg/corpus" - "github.com/ynqa/wego/pkg/corpus/pairwise" + "github.com/ynqa/wego/pkg/corpus/fs" + "github.com/ynqa/wego/pkg/corpus/memory" "github.com/ynqa/wego/pkg/model" "github.com/ynqa/wego/pkg/model/modelutil" "github.com/ynqa/wego/pkg/model/modelutil/matrix" @@ -38,7 +39,7 @@ import ( type glove struct { opts Options - corpus *corpus.Corpus + corpus corpus.Corpus param *matrix.Matrix solver solver @@ -47,16 +48,7 @@ type glove struct { } func New(opts ...ModelOption) (model.Model, error) { - options := Options{ - CorpusOptions: corpus.DefaultOptions(), - PairwiseOptions: pairwise.DefaultOptions(), - ModelOptions: model.DefaultOptions(), - - Alpha: defaultAlpha, - SolverType: defaultSolverType, - Xmax: defaultXmax, - } - + options := DefaultOptions() for _, fn := range opts { fn(&options) } @@ -66,26 +58,29 @@ func New(opts ...ModelOption) (model.Model, error) { func NewForOptions(opts Options) (model.Model, error) { // TODO: validate Options - v := verbose.New(opts.ModelOptions.Verbose) + v := verbose.New(opts.Verbose) return &glove{ opts: opts, - corpus: corpus.New(opts.CorpusOptions, v), - verbose: v, }, nil } func (g *glove) preTrain(r io.Reader) error { - g.opts.PairwiseOptions.Window = g.opts.ModelOptions.Window - if err := g.corpus.BuildForPairwise( - r, - g.opts.PairwiseOptions, - ); err != nil { + if g.opts.DocInMemory { + g.corpus = memory.New(r, g.opts.ToLower, g.opts.MaxCount, g.opts.MinCount) + } else { + g.corpus = fs.New(r, g.opts.ToLower, g.opts.MaxCount, g.opts.MinCount) + } + + if err := g.corpus.LoadForDictionary(); err != nil { + return err + } + if err := g.corpus.LoadForCooccurrence(g.opts.CountType, g.opts.Window); err != nil { return err } - dic, dim := g.corpus.Dictionary(), g.opts.ModelOptions.Dim + dic, dim := g.corpus.Dictionary(), g.opts.Dim g.param = matrix.New( dic.Len()*2, @@ -99,9 +94,9 @@ func (g *glove) preTrain(r io.Reader) error { switch g.opts.SolverType { case Stochastic: - g.solver = newStochastic(g.opts.ModelOptions) + g.solver = newStochastic(g.opts) case AdaGrad: - g.solver = newAdaGrad(dic, g.opts.ModelOptions) + g.solver = newAdaGrad(dic, g.opts) default: return invalidSolverTypeError(g.opts.SolverType) } @@ -113,21 +108,21 @@ func (g *glove) Train(r io.Reader) error { return err } - items := g.makeItems(g.corpus.Pairwise()) + items := g.makeItems(g.corpus.Cooccurrence()) itemSize := len(items) indexPerThread := modelutil.IndexPerThread( - g.opts.ModelOptions.ThreadSize, + g.opts.Goroutines, itemSize, ) - for i := 0; i < g.opts.ModelOptions.Iter; i++ { + for i := 0; i < g.opts.Iter; i++ { trained, clk := make(chan struct{}), clock.New() go g.observe(trained, clk) - sem := semaphore.NewWeighted(int64(g.opts.ModelOptions.ThreadSize)) + sem := semaphore.NewWeighted(int64(g.opts.Goroutines)) wg := &sync.WaitGroup{} - for i := 0; i < g.opts.ModelOptions.ThreadSize; i++ { + for i := 0; i < g.opts.Goroutines; i++ { wg.Add(1) s, e := indexPerThread[i], indexPerThread[i+1] go g.trainPerThread(items[s:e], trained, sem, wg) @@ -169,7 +164,7 @@ func (g *glove) observe(trained chan struct{}, clk *clock.Clock) { for range trained { g.verbose.Do(func() { cnt++ - if cnt%g.opts.ModelOptions.BatchSize == 0 { + if cnt%g.opts.BatchSize == 0 { fmt.Printf("trained %d items %v\r", cnt, clk.AllElapsed()) } }) @@ -190,7 +185,7 @@ func (g *glove) Save(f io.Writer, typ save.VectorType) error { for i := 0; i < dic.Len(); i++ { word, _ := dic.Word(i) fmt.Fprintf(&buf, "%v ", word) - for j := 0; j < g.opts.ModelOptions.Dim; j++ { + for j := 0; j < g.opts.Dim; j++ { var v float64 switch { case typ == save.Aggregated: diff --git a/pkg/model/glove/item.go b/pkg/model/glove/item.go index ed67fbd..c83abe5 100644 --- a/pkg/model/glove/item.go +++ b/pkg/model/glove/item.go @@ -19,8 +19,8 @@ import ( "math" "github.com/ynqa/wego/pkg/clock" - "github.com/ynqa/wego/pkg/corpus/pairwise" - "github.com/ynqa/wego/pkg/corpus/pairwise/encode" + co "github.com/ynqa/wego/pkg/corpus/cooccurrence" + "github.com/ynqa/wego/pkg/corpus/cooccurrence/encode" ) type item struct { @@ -29,10 +29,10 @@ type item struct { coef float64 } -func (g *glove) makeItems(pairwise *pairwise.Pairwise) []item { - pm := pairwise.PairMap() - res, idx, clk := make([]item, len(pm)), 0, clock.New() - for enc, f := range pm { +func (g *glove) makeItems(cooc *co.Cooccurrence) []item { + em := cooc.EncodedMatrix() + res, idx, clk := make([]item, len(em)), 0, clock.New() + for enc, f := range em { u1, u2 := encode.DecodeBigram(enc) l1, l2 := int(u1), int(u2) coef := 1. diff --git a/pkg/model/glove/options.go b/pkg/model/glove/options.go index 025aeed..045a379 100644 --- a/pkg/model/glove/options.go +++ b/pkg/model/glove/options.go @@ -14,12 +14,12 @@ package glove import ( + "fmt" + "runtime" + "github.com/pkg/errors" "github.com/spf13/cobra" - - "github.com/ynqa/wego/pkg/corpus" - "github.com/ynqa/wego/pkg/corpus/pairwise" - "github.com/ynqa/wego/pkg/model" + co "github.com/ynqa/wego/pkg/corpus/cooccurrence" ) func invalidSolverTypeError(typ SolverType) error { @@ -29,9 +29,8 @@ func invalidSolverTypeError(typ SolverType) error { type SolverType string const ( - Stochastic SolverType = "sgd" - AdaGrad SolverType = "adagrad" - defaultSolverType = Stochastic + Stochastic SolverType = "sgd" + AdaGrad SolverType = "adagrad" ) func (t *SolverType) String() string { @@ -54,115 +53,172 @@ func (t *SolverType) Type() string { return t.String() } -const ( +var ( defaultAlpha = 0.75 + defaultBatchSize = 100000 + defaultDim = 10 + defaultDocInMemory = false + defaultGoroutines = runtime.NumCPU() + defaultInitlr = 0.025 + defaultIter = 15 + defaultMaxCount = -1 + defaultMinCount = 5 + defaultSolverType = Stochastic defaultSubsampleThreshold = 1.0e-3 + defaultToLower = false + defaultVerbose = false + defaultWindow = 5 defaultXmax = 100 ) type Options struct { - CorpusOptions corpus.Options - PairwiseOptions pairwise.Options - ModelOptions model.Options - Alpha float64 + BatchSize int + CountType co.CountType + Dim int + DocInMemory bool + Goroutines int + Initlr float64 + Iter int + MaxCount int + MinCount int SolverType SolverType SubsampleThreshold float64 + ToLower bool + Verbose bool + Window int Xmax int } +func DefaultOptions() Options { + return Options{ + Alpha: defaultAlpha, + BatchSize: defaultBatchSize, + CountType: co.DefaultCountType, + Dim: defaultDim, + DocInMemory: defaultDocInMemory, + Goroutines: defaultGoroutines, + Initlr: defaultInitlr, + Iter: defaultIter, + MaxCount: defaultMaxCount, + MinCount: defaultMinCount, + SolverType: defaultSolverType, + SubsampleThreshold: defaultSubsampleThreshold, + ToLower: defaultToLower, + Verbose: defaultVerbose, + Window: defaultWindow, + Xmax: defaultXmax, + } +} + func LoadForCmd(cmd *cobra.Command, opts *Options) { cmd.Flags().Float64Var(&opts.Alpha, "alpha", defaultAlpha, "exponent of weighting function") + cmd.Flags().IntVar(&opts.BatchSize, "batch", defaultBatchSize, "batch size to train") + cmd.Flags().Var(&opts.CountType, "cnt", fmt.Sprintf("count type for co-occurrence words. One of %s|%s", co.Increment, co.Proximity)) + cmd.Flags().IntVarP(&opts.Dim, "dim", "d", defaultDim, "dimension for word vector") + cmd.Flags().IntVar(&opts.Goroutines, "goroutines", defaultGoroutines, "number of goroutine") + cmd.Flags().BoolVar(&opts.DocInMemory, "in-memory", defaultDocInMemory, "whether to store the doc in memory") + cmd.Flags().Float64Var(&opts.Initlr, "initlr", defaultInitlr, "initial learning rate") + cmd.Flags().IntVar(&opts.Iter, "iter", defaultIter, "number of iteration") + cmd.Flags().IntVar(&opts.MaxCount, "max-count", defaultMaxCount, "upper limit to filter words") + cmd.Flags().IntVar(&opts.MinCount, "min-count", defaultMinCount, "lower limit to filter words") + cmd.Flags().Var(&opts.SolverType, "solver", "solver for GloVe objective. One of: sgd|adagrad") cmd.Flags().Float64Var(&opts.SubsampleThreshold, "threshold", defaultSubsampleThreshold, "threshold for subsampling") + cmd.Flags().BoolVar(&opts.ToLower, "to-lower", defaultToLower, "whether the words on corpus convert to lowercase or not") + cmd.Flags().BoolVar(&opts.Verbose, "verbose", defaultVerbose, "verbose mode") + cmd.Flags().IntVarP(&opts.Window, "window", "w", defaultWindow, "context window size") + cmd.Flags().IntVar(&opts.Xmax, "xmax", defaultXmax, "specifying cutoff in weighting function") } type ModelOption func(*Options) -// corpus options -func WithMinCount(v int) ModelOption { +func Alpha(v float64) ModelOption { return ModelOption(func(opts *Options) { - opts.CorpusOptions.MinCount = v + opts.Alpha = v }) } -func ToLower() ModelOption { +func BatchSize(v int) ModelOption { return ModelOption(func(opts *Options) { - opts.CorpusOptions.ToLower = true + opts.BatchSize = v }) } -// pairwise options -func WithCountType(typ pairwise.CountType) ModelOption { +func DocInMemory() ModelOption { return ModelOption(func(opts *Options) { - opts.PairwiseOptions.CountType = typ + opts.DocInMemory = true }) } -// model options -func WithBatchSize(v int) ModelOption { +func Goroutines(v int) ModelOption { return ModelOption(func(opts *Options) { - opts.ModelOptions.BatchSize = v + opts.Goroutines = v }) } -func WithDimension(v int) ModelOption { +func Dim(v int) ModelOption { return ModelOption(func(opts *Options) { - opts.ModelOptions.Dim = v + opts.Dim = v }) } -func WithInitLearningRate(v float64) ModelOption { +func Initlr(v float64) ModelOption { return ModelOption(func(opts *Options) { - opts.ModelOptions.Initlr = v + opts.Initlr = v }) } -func WithIteration(v int) ModelOption { +func Iter(v int) ModelOption { return ModelOption(func(opts *Options) { - opts.ModelOptions.Iter = v + opts.Iter = v }) } -func WithThreadSize(v int) ModelOption { +func MaxCount(v int) ModelOption { return ModelOption(func(opts *Options) { - opts.ModelOptions.ThreadSize = v + opts.MaxCount = v }) } -func WithWindow(v int) ModelOption { +func MinCount(v int) ModelOption { return ModelOption(func(opts *Options) { - opts.ModelOptions.Window = v + opts.MinCount = v }) } -func Verbose() ModelOption { +func Solver(typ SolverType) ModelOption { return ModelOption(func(opts *Options) { - opts.ModelOptions.Verbose = true + opts.SolverType = typ }) } -// for glove options -func WithAlpha(v float64) ModelOption { +func SubsampleThreshold(v float64) ModelOption { return ModelOption(func(opts *Options) { - opts.Alpha = v + opts.SubsampleThreshold = v }) } -func WithSolver(typ SolverType) ModelOption { +func ToLower() ModelOption { return ModelOption(func(opts *Options) { - opts.SolverType = typ + opts.ToLower = true }) } -func WithSubsampleThreshold(v float64) ModelOption { +func Verbose() ModelOption { return ModelOption(func(opts *Options) { - opts.SubsampleThreshold = v + opts.Verbose = true + }) +} + +func Window(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.Window = v }) } -func WithXmax(v int) ModelOption { +func Xmax(v int) ModelOption { return ModelOption(func(opts *Options) { opts.Xmax = v }) diff --git a/pkg/model/glove/solver.go b/pkg/model/glove/solver.go index fcf8586..bdba5f9 100644 --- a/pkg/model/glove/solver.go +++ b/pkg/model/glove/solver.go @@ -18,7 +18,6 @@ import ( "math" "github.com/ynqa/wego/pkg/corpus/dictionary" - "github.com/ynqa/wego/pkg/model" "github.com/ynqa/wego/pkg/model/modelutil/matrix" ) @@ -30,7 +29,7 @@ type stochastic struct { initlr float64 } -func newStochastic(opts model.Options) solver { +func newStochastic(opts Options) solver { return &stochastic{ initlr: opts.Initlr, } @@ -58,7 +57,7 @@ type adaGrad struct { gradsq *matrix.Matrix } -func newAdaGrad(dic *dictionary.Dictionary, opts model.Options) solver { +func newAdaGrad(dic *dictionary.Dictionary, opts Options) solver { dimAndBias := opts.Dim + 1 return &adaGrad{ initlr: opts.Initlr, diff --git a/pkg/model/lexvec/item.go b/pkg/model/lexvec/item.go index 5c9f9fb..0026ad4 100644 --- a/pkg/model/lexvec/item.go +++ b/pkg/model/lexvec/item.go @@ -20,15 +20,15 @@ import ( "github.com/pkg/errors" "github.com/ynqa/wego/pkg/clock" - "github.com/ynqa/wego/pkg/corpus/pairwise" - "github.com/ynqa/wego/pkg/corpus/pairwise/encode" + co "github.com/ynqa/wego/pkg/corpus/cooccurrence" + "github.com/ynqa/wego/pkg/corpus/cooccurrence/encode" ) -func (l *lexvec) makeItems(pairwise *pairwise.Pairwise) (map[uint64]float64, error) { - pm := pairwise.PairMap() +func (l *lexvec) makeItems(cooc *co.Cooccurrence) (map[uint64]float64, error) { + em := cooc.EncodedMatrix() res, idx, clk := make(map[uint64]float64), 0, clock.New() logTotalFreq := math.Log(math.Pow(float64(l.corpus.Len()), l.opts.Smooth)) - for enc, f := range pm { + for enc, f := range em { u1, u2 := encode.DecodeBigram(enc) l1, l2 := int(u1), int(u2) v, err := l.calculateRelation( diff --git a/pkg/model/lexvec/lexvec.go b/pkg/model/lexvec/lexvec.go index aca9553..2ac687e 100644 --- a/pkg/model/lexvec/lexvec.go +++ b/pkg/model/lexvec/lexvec.go @@ -27,8 +27,10 @@ import ( "github.com/ynqa/wego/pkg/clock" "github.com/ynqa/wego/pkg/corpus" - "github.com/ynqa/wego/pkg/corpus/pairwise" - "github.com/ynqa/wego/pkg/corpus/pairwise/encode" + co "github.com/ynqa/wego/pkg/corpus/cooccurrence" + "github.com/ynqa/wego/pkg/corpus/cooccurrence/encode" + "github.com/ynqa/wego/pkg/corpus/fs" + "github.com/ynqa/wego/pkg/corpus/memory" "github.com/ynqa/wego/pkg/model" "github.com/ynqa/wego/pkg/model/modelutil" "github.com/ynqa/wego/pkg/model/modelutil/matrix" @@ -40,7 +42,7 @@ import ( type lexvec struct { opts Options - corpus *corpus.Corpus + corpus corpus.Corpus param *matrix.Matrix subsampler *subsample.Subsampler @@ -50,17 +52,7 @@ type lexvec struct { } func New(opts ...ModelOption) (model.Model, error) { - options := Options{ - CorpusOptions: corpus.DefaultOptions(), - ModelOptions: model.DefaultOptions(), - - NegativeSampleSize: defaultNegativeSampleSize, - RelationType: defaultRelationType, - Smooth: defaultSmooth, - SubsampleThreshold: defaultSubsampleThreshold, - Theta: defaultTheta, - } - + options := DefaultOptions() for _, fn := range opts { fn(&options) } @@ -70,30 +62,30 @@ func New(opts ...ModelOption) (model.Model, error) { func NewForOptions(opts Options) (model.Model, error) { // TODO: validate Options - v := verbose.New(opts.ModelOptions.Verbose) + v := verbose.New(opts.Verbose) return &lexvec{ opts: opts, - corpus: corpus.New(opts.CorpusOptions, v), - - currentlr: opts.ModelOptions.Initlr, + currentlr: opts.Initlr, verbose: v, }, nil } func (l *lexvec) preTrain(r io.Reader) error { - if err := l.corpus.BuildForPairwise( - r, - pairwise.Options{ - CountType: pairwise.Increment, - Window: l.opts.ModelOptions.Window, - }, - ); err != nil { + if l.opts.DocInMemory { + l.corpus = memory.New(r, l.opts.ToLower, l.opts.MaxCount, l.opts.MinCount) + } else { + l.corpus = fs.New(r, l.opts.ToLower, l.opts.MaxCount, l.opts.MinCount) + } + if err := l.corpus.LoadForDictionary(); err != nil { + return err + } + if err := l.corpus.LoadForCooccurrence(co.Increment, l.opts.Window); err != nil { return err } - dic, dim := l.corpus.Dictionary(), l.opts.ModelOptions.Dim + dic, dim := l.corpus.Dictionary(), l.opts.Dim l.param = matrix.New( dic.Len()*2, @@ -114,24 +106,24 @@ func (l *lexvec) Train(r io.Reader) error { return err } - items, err := l.makeItems(l.corpus.Pairwise()) + items, err := l.makeItems(l.corpus.Cooccurrence()) if err != nil { return err } - doc := l.corpus.Doc() + doc := l.corpus.IndexedDoc() indexPerThread := modelutil.IndexPerThread( - l.opts.ModelOptions.ThreadSize, + l.opts.Goroutines, len(doc), ) - for i := 1; i <= l.opts.ModelOptions.Iter; i++ { + for i := 1; i <= l.opts.Iter; i++ { trained, clk := make(chan struct{}), clock.New() go l.observe(trained, clk) - sem := semaphore.NewWeighted(int64(l.opts.ModelOptions.ThreadSize)) + sem := semaphore.NewWeighted(int64(l.opts.Goroutines)) wg := &sync.WaitGroup{} - for i := 0; i < l.opts.ModelOptions.ThreadSize; i++ { + for i := 0; i < l.opts.Goroutines; i++ { wg.Add(1) s, e := indexPerThread[i], indexPerThread[i+1] go l.trainPerThread(doc[s:e], items, trained, sem, wg) @@ -171,12 +163,12 @@ func (l *lexvec) trainPerThread( func (l *lexvec) trainOne(doc []int, pos int, items map[uint64]float64) { dic := l.corpus.Dictionary() - del := modelutil.NextRandom(l.opts.ModelOptions.Window) - for a := del; a < l.opts.ModelOptions.Window*2+1-del; a++ { - if a == l.opts.ModelOptions.Window { + del := modelutil.NextRandom(l.opts.Window) + for a := del; a < l.opts.Window*2+1-del; a++ { + if a == l.opts.Window { continue } - c := pos - l.opts.ModelOptions.Window + a + c := pos - l.opts.Window + a if c < 0 || c >= len(doc) { continue } @@ -192,11 +184,11 @@ func (l *lexvec) trainOne(doc []int, pos int, items map[uint64]float64) { func (l *lexvec) update(l1, l2 int, f float64) { var diff float64 - for i := 0; i < l.opts.ModelOptions.Dim; i++ { + for i := 0; i < l.opts.Dim; i++ { diff += l.param.Slice(l1)[i] * l.param.Slice(l2)[i] } diff = (diff - f) * l.currentlr - for i := 0; i < l.opts.ModelOptions.Dim; i++ { + for i := 0; i < l.opts.Dim; i++ { t1 := diff * l.param.Slice(l2)[i] t2 := diff * l.param.Slice(l1)[i] l.param.Slice(l1)[i] -= t1 @@ -208,12 +200,12 @@ func (l *lexvec) observe(trained chan struct{}, clk *clock.Clock) { var cnt int for range trained { cnt++ - if cnt%l.opts.ModelOptions.BatchSize == 0 { - lower := l.opts.ModelOptions.Initlr * l.opts.Theta + if cnt%l.opts.BatchSize == 0 { + lower := l.opts.Initlr * l.opts.Theta if l.currentlr < lower { l.currentlr = lower } else { - l.currentlr = l.opts.ModelOptions.Initlr * (1.0 - float64(cnt)/float64(l.corpus.Len())) + l.currentlr = l.opts.Initlr * (1.0 - float64(cnt)/float64(l.corpus.Len())) } l.verbose.Do(func() { fmt.Printf("trained %d words %v\r", cnt, clk.AllElapsed()) @@ -236,7 +228,7 @@ func (l *lexvec) Save(f io.Writer, typ save.VectorType) error { for i := 0; i < dic.Len(); i++ { word, _ := dic.Word(i) fmt.Fprintf(&buf, "%v ", word) - for j := 0; j < l.opts.ModelOptions.Dim; j++ { + for j := 0; j < l.opts.Dim; j++ { var v float64 switch { case typ == save.Aggregated: diff --git a/pkg/model/lexvec/options.go b/pkg/model/lexvec/options.go index 9aea2b5..23b74b8 100644 --- a/pkg/model/lexvec/options.go +++ b/pkg/model/lexvec/options.go @@ -16,12 +16,10 @@ package lexvec import ( "fmt" + "runtime" "github.com/pkg/errors" "github.com/spf13/cobra" - - "github.com/ynqa/wego/pkg/corpus" - "github.com/ynqa/wego/pkg/model" ) func invalidRelationTypeError(typ RelationType) error { @@ -58,117 +56,179 @@ func (t *RelationType) Type() string { return t.String() } -const ( +var ( + defaultBatchSize = 100000 + defaultDim = 10 + defaultDocInMemory = false + defaultGoroutines = runtime.NumCPU() + defaultInitlr = 0.025 + defaultIter = 15 + defaultMaxCount = -1 + defaultMinCount = 5 defaultNegativeSampleSize = 5 defaultSmooth = 0.75 defaultSubsampleThreshold = 1.0e-3 defaultTheta = 1.0e-4 + defaultToLower = false + defaultVerbose = false + + defaultWindow = 5 ) type Options struct { - CorpusOptions corpus.Options - ModelOptions model.Options - + BatchSize int + Dim int + DocInMemory bool + Goroutines int + Initlr float64 + Iter int + MaxCount int + MinCount int NegativeSampleSize int RelationType RelationType Smooth float64 SubsampleThreshold float64 Theta float64 + ToLower bool + Verbose bool + + Window int +} + +func DefaultOptions() Options { + return Options{ + BatchSize: defaultBatchSize, + Dim: defaultDim, + DocInMemory: defaultDocInMemory, + Goroutines: defaultGoroutines, + Initlr: defaultInitlr, + Iter: defaultIter, + MaxCount: defaultMaxCount, + MinCount: defaultMinCount, + NegativeSampleSize: defaultNegativeSampleSize, + RelationType: defaultRelationType, + Smooth: defaultSmooth, + SubsampleThreshold: defaultSubsampleThreshold, + Theta: defaultTheta, + ToLower: defaultToLower, + Verbose: defaultVerbose, + Window: defaultWindow, + } } - func LoadForCmd(cmd *cobra.Command, opts *Options) { + cmd.Flags().IntVar(&opts.BatchSize, "batch", defaultBatchSize, "batch size to train") + cmd.Flags().IntVarP(&opts.Dim, "dim", "d", defaultDim, "dimension for word vector") + cmd.Flags().IntVar(&opts.Goroutines, "goroutines", defaultGoroutines, "number of goroutine") + cmd.Flags().BoolVar(&opts.DocInMemory, "in-memory", defaultDocInMemory, "whether to store the doc in memory") + cmd.Flags().Float64Var(&opts.Initlr, "initlr", defaultInitlr, "initial learning rate") + cmd.Flags().IntVar(&opts.Iter, "iter", defaultIter, "number of iteration") + cmd.Flags().IntVar(&opts.MaxCount, "max-count", defaultMaxCount, "upper limit to filter words") + cmd.Flags().IntVar(&opts.MinCount, "min-count", defaultMinCount, "lower limit to filter words") cmd.Flags().IntVar(&opts.NegativeSampleSize, "sample", defaultNegativeSampleSize, "negative sample size") cmd.Flags().Var(&opts.RelationType, "rel", fmt.Sprintf("relation type for co-occurrence words. One of %s|%s|%s|%s", PPMI, PMI, Collocation, LogCollocation)) cmd.Flags().Float64Var(&opts.Smooth, "smooth", defaultSmooth, "smoothing value for co-occurence value") cmd.Flags().Float64Var(&opts.SubsampleThreshold, "threshold", defaultSubsampleThreshold, "threshold for subsampling") cmd.Flags().Float64Var(&opts.Theta, "theta", defaultTheta, "lower limit of learning rate (lr >= initlr * theta)") + cmd.Flags().BoolVar(&opts.ToLower, "to-lower", defaultToLower, "whether the words on corpus convert to lowercase or not") + cmd.Flags().BoolVar(&opts.Verbose, "verbose", defaultVerbose, "verbose mode") + cmd.Flags().IntVarP(&opts.Window, "window", "w", defaultWindow, "context window size") + } type ModelOption func(*Options) -// corpus options -func WithMinCount(v int) ModelOption { +func BatchSize(v int) ModelOption { return ModelOption(func(opts *Options) { - opts.CorpusOptions.MinCount = v + opts.BatchSize = v }) } -func ToLower() ModelOption { +func DocInMemory() ModelOption { return ModelOption(func(opts *Options) { - opts.CorpusOptions.ToLower = true + opts.DocInMemory = true }) } -// model options -func WithBatchSize(v int) ModelOption { +func Goroutines(v int) ModelOption { return ModelOption(func(opts *Options) { - opts.ModelOptions.BatchSize = v + opts.Goroutines = v }) } -func WithDimension(v int) ModelOption { +func Dim(v int) ModelOption { return ModelOption(func(opts *Options) { - opts.ModelOptions.Dim = v + opts.Dim = v }) } -func WithInitLearningRate(v float64) ModelOption { +func Initlr(v float64) ModelOption { return ModelOption(func(opts *Options) { - opts.ModelOptions.Initlr = v + opts.Initlr = v }) } -func WithIteration(v int) ModelOption { +func Iter(v int) ModelOption { return ModelOption(func(opts *Options) { - opts.ModelOptions.Iter = v + opts.Iter = v }) } -func WithThreadSize(v int) ModelOption { +func MaxCount(v int) ModelOption { return ModelOption(func(opts *Options) { - opts.ModelOptions.ThreadSize = v + opts.MaxCount = v }) } -func WithWindow(v int) ModelOption { +func MinCount(v int) ModelOption { return ModelOption(func(opts *Options) { - opts.ModelOptions.Window = v + opts.MinCount = v }) } -func Verbose() ModelOption { - return ModelOption(func(opts *Options) { - opts.ModelOptions.Verbose = true - }) -} - -// for lexvec options -func WithNegativeSampleSize(v int) ModelOption { +func NegativeSampleSize(v int) ModelOption { return ModelOption(func(opts *Options) { opts.NegativeSampleSize = v }) } -func WithRelation(typ RelationType) ModelOption { +func Relation(typ RelationType) ModelOption { return ModelOption(func(opts *Options) { opts.RelationType = typ }) } -func WithSmooth(v float64) ModelOption { +func Smooth(v float64) ModelOption { return ModelOption(func(opts *Options) { opts.Smooth = v }) } -func WithSubsampleThreshold(v float64) ModelOption { +func SubsampleThreshold(v float64) ModelOption { return ModelOption(func(opts *Options) { opts.SubsampleThreshold = v }) } -func WithTheta(v float64) ModelOption { +func Theta(v float64) ModelOption { return ModelOption(func(opts *Options) { opts.Theta = v }) } + +func ToLower() ModelOption { + return ModelOption(func(opts *Options) { + opts.ToLower = true + }) +} + +func Verbose() ModelOption { + return ModelOption(func(opts *Options) { + opts.Verbose = true + }) +} + +func Window(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.Window = v + }) +} diff --git a/pkg/model/model.go b/pkg/model/model.go index ee0dc3f..b4f0e50 100644 --- a/pkg/model/model.go +++ b/pkg/model/model.go @@ -16,9 +16,7 @@ package model import ( "io" - "runtime" - "github.com/spf13/cobra" "github.com/ynqa/wego/pkg/model/modelutil/save" ) @@ -26,46 +24,3 @@ type Model interface { Train(io.Reader) error Save(io.Writer, save.VectorType) error } - -var ( - defaultBatchSize = 100000 - defaultDim = 10 - defaultInitlr = 0.025 - defaultIter = 15 - defaultThreadSize = runtime.NumCPU() - defaultWindow = 5 - defaultVerbose = false -) - -// Options stores common options for each model. -type Options struct { - BatchSize int - Dim int - Initlr float64 - Iter int - ThreadSize int - Window int - Verbose bool -} - -func DefaultOptions() Options { - return Options{ - BatchSize: defaultBatchSize, - Dim: defaultDim, - Initlr: defaultInitlr, - Iter: defaultIter, - ThreadSize: defaultThreadSize, - Window: defaultWindow, - Verbose: defaultVerbose, - } -} - -func LoadForCmd(cmd *cobra.Command, opts *Options) { - cmd.Flags().IntVar(&opts.BatchSize, "batch", defaultBatchSize, "batch size to train") - cmd.Flags().IntVarP(&opts.Dim, "dim", "d", defaultDim, "dimension for word vector") - cmd.Flags().Float64Var(&opts.Initlr, "initlr", defaultInitlr, "initial learning rate") - cmd.Flags().IntVar(&opts.Iter, "iter", defaultIter, "number of iteration") - cmd.Flags().IntVar(&opts.ThreadSize, "thread", defaultThreadSize, "number of goroutine") - cmd.Flags().IntVarP(&opts.Window, "window", "w", defaultWindow, "context window size") - cmd.Flags().BoolVar(&opts.Verbose, "verbose", defaultVerbose, "verbose mode") -} diff --git a/pkg/model/word2vec/model.go b/pkg/model/word2vec/model.go index 5061915..e0e29c6 100644 --- a/pkg/model/word2vec/model.go +++ b/pkg/model/word2vec/model.go @@ -35,13 +35,13 @@ type skipGram struct { } func newSkipGram(opts Options) mod { - ch := make(chan []float64, opts.ModelOptions.ThreadSize) - for i := 0; i < opts.ModelOptions.ThreadSize; i++ { - ch <- make([]float64, opts.ModelOptions.Dim) + ch := make(chan []float64, opts.Goroutines) + for i := 0; i < opts.Goroutines; i++ { + ch <- make([]float64, opts.Dim) } return &skipGram{ ch: ch, - window: opts.ModelOptions.Window, + window: opts.Window, } } @@ -83,13 +83,13 @@ type cbow struct { } func newCbow(opts Options) mod { - ch := make(chan []float64, opts.ModelOptions.ThreadSize*2) - for i := 0; i < opts.ModelOptions.ThreadSize; i++ { - ch <- make([]float64, opts.ModelOptions.Dim) + ch := make(chan []float64, opts.Goroutines*2) + for i := 0; i < opts.Goroutines; i++ { + ch <- make([]float64, opts.Dim) } return &cbow{ ch: ch, - window: opts.ModelOptions.Window, + window: opts.Window, } } diff --git a/pkg/model/word2vec/optimizer.go b/pkg/model/word2vec/optimizer.go index d40d2ff..012c71c 100644 --- a/pkg/model/word2vec/optimizer.go +++ b/pkg/model/word2vec/optimizer.go @@ -34,14 +34,13 @@ type negativeSampling struct { } func newNegativeSampling(dic *dictionary.Dictionary, opts Options) optimizer { - dim := opts.ModelOptions.Dim return &negativeSampling{ ctx: matrix.New( dic.Len(), - dim, + opts.Dim, func(vec []float64) { - for i := 0; i < dim; i++ { - vec[i] = (rand.Float64() - 0.5) / float64(dim) + for i := 0; i < opts.Dim; i++ { + vec[i] = (rand.Float64() - 0.5) / float64(opts.Dim) } }, ), @@ -100,7 +99,7 @@ type hierarchicalSoftmax struct { func newHierarchicalSoftmax(dic *dictionary.Dictionary, opts Options) optimizer { return &hierarchicalSoftmax{ sigtable: newSigmoidTable(), - nodeset: dic.HuffnamTree(opts.ModelOptions.Dim), + nodeset: dic.HuffnamTree(opts.Dim), maxDepth: opts.MaxDepth, } } diff --git a/pkg/model/word2vec/options.go b/pkg/model/word2vec/options.go index eb8caf0..36d23a8 100644 --- a/pkg/model/word2vec/options.go +++ b/pkg/model/word2vec/options.go @@ -16,12 +16,10 @@ package word2vec import ( "fmt" + "runtime" "github.com/pkg/errors" "github.com/spf13/cobra" - - "github.com/ynqa/wego/pkg/corpus" - "github.com/ynqa/wego/pkg/model" ) func invalidModelTypeError(typ ModelType) error { @@ -87,125 +85,187 @@ func (t *OptimizerType) Type() string { return t.String() } -const ( +var ( + defaultBatchSize = 100000 + defaultDim = 10 + defaultDocInMemory = false + defaultGoroutines = runtime.NumCPU() + defaultInitlr = 0.025 + defaultIter = 15 + defaultMaxCount = -1 defaultMaxDepth = 100 + defaultMinCount = 5 defaultNegativeSampleSize = 5 defaultSubsampleThreshold = 1.0e-3 defaultTheta = 1.0e-4 + defaultToLower = false + defaultWindow = 5 + defaultVerbose = false ) type Options struct { - CorpusOptions corpus.Options - ModelOptions model.Options - + BatchSize int + Dim int + DocInMemory bool + Goroutines int + Initlr float64 + Iter int + MaxCount int MaxDepth int + MinCount int ModelType ModelType NegativeSampleSize int OptimizerType OptimizerType SubsampleThreshold float64 Theta float64 + ToLower bool + Verbose bool + Window int +} + +func DefaultOptions() Options { + return Options{ + BatchSize: defaultBatchSize, + Dim: defaultDim, + DocInMemory: defaultDocInMemory, + Goroutines: defaultGoroutines, + Initlr: defaultInitlr, + Iter: defaultIter, + MaxCount: defaultMaxCount, + MaxDepth: defaultMaxDepth, + MinCount: defaultMinCount, + ModelType: defaultModelType, + NegativeSampleSize: defaultNegativeSampleSize, + OptimizerType: defaultOptimizerType, + SubsampleThreshold: defaultSubsampleThreshold, + Theta: defaultTheta, + ToLower: defaultToLower, + Verbose: defaultVerbose, + Window: defaultWindow, + } } func LoadForCmd(cmd *cobra.Command, opts *Options) { - cmd.Flags().IntVar(&opts.MaxDepth, "maxDepth", defaultMaxDepth, "times to track huffman tree, max-depth=0 means to track full path from root to word (for hierarchical softmax only)") + cmd.Flags().IntVar(&opts.BatchSize, "batch", defaultBatchSize, "batch size to train") + cmd.Flags().IntVarP(&opts.Dim, "dim", "d", defaultDim, "dimension for word vector") + cmd.Flags().IntVar(&opts.Goroutines, "goroutines", defaultGoroutines, "number of goroutine") + cmd.Flags().BoolVar(&opts.DocInMemory, "in-memory", defaultDocInMemory, "whether to store the doc in memory") + cmd.Flags().Float64Var(&opts.Initlr, "initlr", defaultInitlr, "initial learning rate") + cmd.Flags().IntVar(&opts.Iter, "iter", defaultIter, "number of iteration") + cmd.Flags().IntVar(&opts.MaxCount, "max-count", defaultMaxCount, "upper limit to filter words") + cmd.Flags().IntVar(&opts.MaxDepth, "max-depth", defaultMaxDepth, "times to track huffman tree, max-depth=0 means to track full path from root to word (for hierarchical softmax only)") + cmd.Flags().IntVar(&opts.MinCount, "min-count", defaultMinCount, "lower limit to filter words") cmd.Flags().Var(&opts.ModelType, "model", fmt.Sprintf("which model does it use? one of: %s|%s", Cbow, SkipGram)) cmd.Flags().IntVar(&opts.NegativeSampleSize, "sample", defaultNegativeSampleSize, "negative sample size(for negative sampling only)") cmd.Flags().Var(&opts.OptimizerType, "optimizer", fmt.Sprintf("which optimizer does it use? one of: %s|%s", HierarchicalSoftmax, NegativeSampling)) cmd.Flags().Float64Var(&opts.SubsampleThreshold, "threshold", defaultSubsampleThreshold, "threshold for subsampling") cmd.Flags().Float64Var(&opts.Theta, "theta", defaultTheta, "lower limit of learning rate (lr >= initlr * theta)") + cmd.Flags().BoolVar(&opts.ToLower, "to-lower", defaultToLower, "whether the words on corpus convert to lowercase or not") + cmd.Flags().BoolVar(&opts.Verbose, "verbose", defaultVerbose, "verbose mode") + cmd.Flags().IntVarP(&opts.Window, "window", "w", defaultWindow, "context window size") + } type ModelOption func(*Options) -// corpus options -func WithMinCount(v int) ModelOption { +func BatchSize(v int) ModelOption { return ModelOption(func(opts *Options) { - opts.CorpusOptions.MinCount = v + opts.BatchSize = v }) } -func ToLower() ModelOption { +func DocInMemory() ModelOption { return ModelOption(func(opts *Options) { - opts.CorpusOptions.ToLower = true + opts.DocInMemory = true }) } -// model options -func WithBatchSize(v int) ModelOption { +func Goroutines(v int) ModelOption { return ModelOption(func(opts *Options) { - opts.ModelOptions.BatchSize = v + opts.Goroutines = v }) } -func WithDimension(v int) ModelOption { +func Dim(v int) ModelOption { return ModelOption(func(opts *Options) { - opts.ModelOptions.Dim = v + opts.Dim = v }) } -func WithInitLearningRate(v float64) ModelOption { +func Initlr(v float64) ModelOption { return ModelOption(func(opts *Options) { - opts.ModelOptions.Initlr = v + opts.Initlr = v }) } -func WithIteration(v int) ModelOption { +func Iter(v int) ModelOption { return ModelOption(func(opts *Options) { - opts.ModelOptions.Iter = v + opts.Iter = v }) } -func WithThreadSize(v int) ModelOption { +func MaxCount(v int) ModelOption { return ModelOption(func(opts *Options) { - opts.ModelOptions.ThreadSize = v + opts.MaxCount = v }) } -func WithWindow(v int) ModelOption { +func MaxDepth(v int) ModelOption { return ModelOption(func(opts *Options) { - opts.ModelOptions.Window = v - }) -} - -func Verbose() ModelOption { - return ModelOption(func(opts *Options) { - opts.ModelOptions.Verbose = true + opts.MaxDepth = v }) } -// word2vec options -func WithMaxDepth(v int) ModelOption { +func MinCount(v int) ModelOption { return ModelOption(func(opts *Options) { - opts.MaxDepth = v + opts.MinCount = v }) } -func WithModel(typ ModelType) ModelOption { +func Model(typ ModelType) ModelOption { return ModelOption(func(opts *Options) { opts.ModelType = typ }) } -func WithNegativeSampleSize(v int) ModelOption { +func NegativeSampleSize(v int) ModelOption { return ModelOption(func(opts *Options) { opts.NegativeSampleSize = v }) } -func WithOptimizer(typ OptimizerType) ModelOption { +func Optimizer(typ OptimizerType) ModelOption { return ModelOption(func(opts *Options) { opts.OptimizerType = typ }) } -func WithSubsampleThreshold(v float64) ModelOption { +func SubsampleThreshold(v float64) ModelOption { return ModelOption(func(opts *Options) { opts.SubsampleThreshold = v }) } -func WithTheta(v float64) ModelOption { +func Theta(v float64) ModelOption { return ModelOption(func(opts *Options) { opts.Theta = v }) } + +func ToLower() ModelOption { + return ModelOption(func(opts *Options) { + opts.ToLower = true + }) +} + +func Verbose() ModelOption { + return ModelOption(func(opts *Options) { + opts.Verbose = true + }) +} + +func Window(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.Window = v + }) +} diff --git a/pkg/model/word2vec/word2vec.go b/pkg/model/word2vec/word2vec.go index 89060f0..1ee5597 100644 --- a/pkg/model/word2vec/word2vec.go +++ b/pkg/model/word2vec/word2vec.go @@ -27,6 +27,8 @@ import ( "github.com/ynqa/wego/pkg/clock" "github.com/ynqa/wego/pkg/corpus" + "github.com/ynqa/wego/pkg/corpus/fs" + "github.com/ynqa/wego/pkg/corpus/memory" "github.com/ynqa/wego/pkg/model" "github.com/ynqa/wego/pkg/model/modelutil" "github.com/ynqa/wego/pkg/model/modelutil/matrix" @@ -38,7 +40,7 @@ import ( type word2vec struct { opts Options - corpus *corpus.Corpus + corpus corpus.Corpus param *matrix.Matrix subsampler *subsample.Subsampler @@ -50,18 +52,7 @@ type word2vec struct { } func New(opts ...ModelOption) (model.Model, error) { - options := Options{ - CorpusOptions: corpus.DefaultOptions(), - ModelOptions: model.DefaultOptions(), - - MaxDepth: defaultMaxDepth, - ModelType: defaultModelType, - NegativeSampleSize: defaultNegativeSampleSize, - OptimizerType: defaultOptimizerType, - SubsampleThreshold: defaultSubsampleThreshold, - Theta: defaultTheta, - } - + options := DefaultOptions() for _, fn := range opts { fn(&options) } @@ -71,24 +62,27 @@ func New(opts ...ModelOption) (model.Model, error) { func NewForOptions(opts Options) (model.Model, error) { // TODO: validate Options - v := verbose.New(opts.ModelOptions.Verbose) + v := verbose.New(opts.Verbose) return &word2vec{ opts: opts, - corpus: corpus.New(opts.CorpusOptions, v), - - currentlr: opts.ModelOptions.Initlr, + currentlr: opts.Initlr, verbose: v, }, nil } func (w *word2vec) preTrain(r io.Reader) error { - if err := w.corpus.Build(r); err != nil { + if w.opts.DocInMemory { + w.corpus = memory.New(r, w.opts.ToLower, w.opts.MaxCount, w.opts.MinCount) + } else { + w.corpus = fs.New(r, w.opts.ToLower, w.opts.MaxCount, w.opts.MinCount) + } + if err := w.corpus.LoadForDictionary(); err != nil { return err } - dic, dim := w.corpus.Dictionary(), w.opts.ModelOptions.Dim + dic, dim := w.corpus.Dictionary(), w.opts.Dim w.param = matrix.New( dic.Len(), @@ -133,20 +127,20 @@ func (w *word2vec) Train(r io.Reader) error { return err } - doc := w.corpus.Doc() + doc := w.corpus.IndexedDoc() indexPerThread := modelutil.IndexPerThread( - w.opts.ModelOptions.ThreadSize, + w.opts.Goroutines, len(doc), ) - for i := 1; i <= w.opts.ModelOptions.Iter; i++ { + for i := 1; i <= w.opts.Iter; i++ { trained, clk := make(chan struct{}), clock.New() go w.observe(trained, clk) - sem := semaphore.NewWeighted(int64(w.opts.ModelOptions.ThreadSize)) + sem := semaphore.NewWeighted(int64(w.opts.Goroutines)) wg := &sync.WaitGroup{} - for i := 0; i < w.opts.ModelOptions.ThreadSize; i++ { + for i := 0; i < w.opts.Goroutines; i++ { wg.Add(1) s, e := indexPerThread[i], indexPerThread[i+1] go w.trainPerThread(doc[s:e], trained, sem, wg) @@ -187,12 +181,12 @@ func (w *word2vec) observe(trained chan struct{}, clk *clock.Clock) { var cnt int for range trained { cnt++ - if cnt%w.opts.ModelOptions.BatchSize == 0 { - lower := w.opts.ModelOptions.Initlr * w.opts.Theta + if cnt%w.opts.BatchSize == 0 { + lower := w.opts.Initlr * w.opts.Theta if w.currentlr < lower { w.currentlr = lower } else { - w.currentlr = w.opts.ModelOptions.Initlr * (1.0 - float64(cnt)/float64(w.corpus.Len())) + w.currentlr = w.opts.Initlr * (1.0 - float64(cnt)/float64(w.corpus.Len())) } w.verbose.Do(func() { fmt.Printf("trained %d words %v\r", cnt, clk.AllElapsed()) @@ -220,7 +214,7 @@ func (w *word2vec) Save(f io.Writer, typ save.VectorType) error { for i := 0; i < dic.Len(); i++ { word, _ := dic.Word(i) fmt.Fprintf(&buf, "%v ", word) - for j := 0; j < w.opts.ModelOptions.Dim; j++ { + for j := 0; j < w.opts.Dim; j++ { var v float64 switch { case typ == save.Aggregated && ctx.Row() > i: diff --git a/pkg/search/repl/repl.go b/pkg/search/console/console.go similarity index 69% rename from pkg/search/repl/repl.go rename to pkg/search/console/console.go index 6cb2c98..54fe785 100644 --- a/pkg/search/repl/repl.go +++ b/pkg/search/console/console.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package repl +package console import ( "fmt" @@ -37,18 +37,18 @@ type searchcursor struct { vector []float64 } -type Repl struct { +type Console struct { *liner.State searcher *search.Searcher cursor *searchcursor params *searchparams } -func New(searcher *search.Searcher, k int) (*Repl, error) { +func New(searcher *search.Searcher, k int) (*Console, error) { if searcher.Items.Empty() { return nil, errors.New("Number of items for searcher must be over 0") } - return &Repl{ + return &Console{ State: liner.NewLiner(), searcher: searcher, cursor: &searchcursor{ @@ -61,10 +61,10 @@ func New(searcher *search.Searcher, k int) (*Repl, error) { }, nil } -func (r *Repl) Run() error { - defer r.Close() +func (c *Console) Run() error { + defer c.Close() for { - l, err := r.Prompt(">> ") + l, err := c.Prompt(">> ") if err != nil { fmt.Println("error: ", err) } @@ -74,18 +74,18 @@ func (r *Repl) Run() error { case "": continue default: - if err := r.eval(l); err != nil { + if err := c.eval(l); err != nil { fmt.Println(err) } } } } -func (r *Repl) eval(l string) error { +func (c *Console) eval(l string) error { defer func() { - r.cursor.w1 = "" - r.cursor.w2 = "" - r.cursor.vector = make([]float64, r.params.dim) + c.cursor.w1 = "" + c.cursor.w2 = "" + c.cursor.vector = make([]float64, c.params.dim) }() expr, err := parser.ParseExpr(l) @@ -96,20 +96,20 @@ func (r *Repl) eval(l string) error { var neighbors search.Neighbors switch e := expr.(type) { case *ast.Ident: - neighbors, err = r.searcher.SearchInternal(e.String(), r.params.k) + neighbors, err = c.searcher.SearchInternal(e.String(), c.params.k) if err != nil { fmt.Printf("failed to search with word=%s\n", e.String()) } case *ast.BinaryExpr: - if err := r.evalExpr(expr); err != nil { + if err := c.evalExpr(expr); err != nil { return err } - neighbors, err = r.searcher.Search(embedding.Embedding{ - Vector: r.cursor.vector, - Norm: embutil.Norm(r.cursor.vector), - }, r.params.k, r.cursor.w1, r.cursor.w2) + neighbors, err = c.searcher.Search(embedding.Embedding{ + Vector: c.cursor.vector, + Norm: embutil.Norm(c.cursor.vector), + }, c.params.k, c.cursor.w1, c.cursor.w2) if err != nil { - fmt.Printf("failed to search with vector=%v\n", r.cursor.vector) + fmt.Printf("failed to search with vector=%v\n", c.cursor.vector) } default: return errors.Errorf("invalid type %v", e) @@ -118,10 +118,10 @@ func (r *Repl) eval(l string) error { return nil } -func (r *Repl) evalExpr(expr ast.Expr) error { +func (c *Console) evalExpr(expr ast.Expr) error { switch e := expr.(type) { case *ast.BinaryExpr: - return r.evalBinaryExpr(e) + return c.evalBinaryExpr(e) case *ast.Ident: return nil default: @@ -129,30 +129,30 @@ func (r *Repl) evalExpr(expr ast.Expr) error { } } -func (r *Repl) evalBinaryExpr(expr *ast.BinaryExpr) error { - xi, err := r.evalAsEmbedding(expr.X) +func (c *Console) evalBinaryExpr(expr *ast.BinaryExpr) error { + xi, err := c.evalAsEmbedding(expr.X) if err != nil { return err } - yi, err := r.evalAsEmbedding(expr.Y) + yi, err := c.evalAsEmbedding(expr.Y) if err != nil { return nil } - r.cursor.w1 = xi.Word - r.cursor.w2 = yi.Word - r.cursor.vector, err = arithmetic(xi.Vector, expr.Op, yi.Vector) + c.cursor.w1 = xi.Word + c.cursor.w2 = yi.Word + c.cursor.vector, err = arithmetic(xi.Vector, expr.Op, yi.Vector) return err } -func (r *Repl) evalAsEmbedding(expr ast.Expr) (embedding.Embedding, error) { - if err := r.evalExpr(expr); err != nil { +func (c *Console) evalAsEmbedding(expr ast.Expr) (embedding.Embedding, error) { + if err := c.evalExpr(expr); err != nil { return embedding.Embedding{}, err } v, ok := expr.(*ast.Ident) if !ok { return embedding.Embedding{}, errors.Errorf("failed to parse %v", expr) } - vi, ok := r.searcher.Items.Find(v.String()) + vi, ok := c.searcher.Items.Find(v.String()) if !ok { return embedding.Embedding{}, errors.Errorf("not found word=%s in vector map", v.String()) } else if err := vi.Validate(); err != nil { diff --git a/pkg/search/repl/op.go b/pkg/search/console/op.go similarity index 98% rename from pkg/search/repl/op.go rename to pkg/search/console/op.go index 22b1cdd..edc3d38 100644 --- a/pkg/search/repl/op.go +++ b/pkg/search/console/op.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package repl +package console import ( "github.com/pkg/errors" diff --git a/scripts/e2e.sh b/scripts/e2e.sh index ab294c6..8449f8b 100755 --- a/scripts/e2e.sh +++ b/scripts/e2e.sh @@ -48,55 +48,55 @@ function get_corpus() { function train_word2vec() { echo "train: skipgram with ns" - ./wego word2vec -i text8 -o word2vec_sg_ns.txt \ - --model skipgram --optimizer ns -d 100 -w 5 --verbose --iter 3 --min-count 5 --save-type agg --thread 20 --batch 100000 + ./wego word2vec -i text8 -o word2vec_sg_ns.txt --in-memory \ + --model skipgram --optimizer ns -d 100 -w 5 --verbose --iter 3 --min-count 5 --save-type agg --goroutines 20 --batch 100000 echo "train: skipgram with hs" - ./wego word2vec -i text8 -o word2vec_sg_hs.txt \ - --model skipgram --optimizer hs -d 100 -w 5 --verbose --iter 3 --min-count 5 --thread 20 --batch 100000 + ./wego word2vec -i text8 -o word2vec_sg_hs.txt --in-memory \ + --model skipgram --optimizer hs -d 100 -w 5 --verbose --iter 3 --min-count 5 --goroutines 20 --batch 100000 echo "train: cbow with ns" - ./wego word2vec -i text8 -o word2vec_cbow_ns.txt \ - --model cbow --optimizer ns -d 100 -w 5 --verbose --iter 3 --min-count 5 --save-type agg --thread 20 --batch 100000 + ./wego word2vec -i text8 -o word2vec_cbow_ns.txt --in-memory \ + --model cbow --optimizer ns -d 100 -w 5 --verbose --iter 3 --min-count 5 --save-type agg --goroutines 20 --batch 100000 echo "train: cbow with hs" - ./wego word2vec -i text8 -o word2vec_cbow_hs.txt \ - --model cbow --optimizer hs -d 100 -w 5 --verbose --iter 3 --min-count 5 --thread 20 --batch 100000 + ./wego word2vec -i text8 -o word2vec_cbow_hs.txt --in-memory \ + --model cbow --optimizer hs -d 100 -w 5 --verbose --iter 3 --min-count 5 --goroutines 20 --batch 100000 } function train_glove() { echo "train: glove with sgd" - ./wego glove -d 50 -i text8 -o glove_sgd.txt \ - --iter 10 --thread 12 --initlr 0.05 --min-count 5 -w 15 --solver sgd --save-type agg --verbose + ./wego glove -d 50 -i text8 -o glove_sgd.txt --in-memory \ + --iter 5 --goroutines 12 --initlr 0.01 --min-count 5 -w 10 --solver sgd --save-type agg --verbose echo "train: glove with adagrad" - ./wego glove -d 50 -i text8 -o glove_adagrad.txt \ - --iter 10 --thread 12 --initlr 0.05 --min-count 5 -w 15 --solver adagrad --save-type agg --verbose + ./wego glove -d 50 -i text8 -o glove_adagrad.txt --in-memory \ + --iter 5 --goroutines 12 --initlr 0.05 --min-count 5 -w 10 --solver adagrad --save-type agg --verbose } function train_lexvec() { echo "train: lexvec" - ./wego lexvec -d 50 -i text8 -o lexvec.txt \ - --iter 3 --thread 12 --initlr 0.05 --min-count 5 -w 5 --rel ppmi --save-type agg --verbose + ./wego lexvec -d 50 -i text8 -o lexvec.txt --in-memory \ + --iter 3 --goroutines 12 --initlr 0.05 --min-count 5 -w 5 --rel ppmi --save-type agg --verbose } function search_word2vec() { echo "similarity search: skipgram with ns" - ./wego search -i word2vec_sg_ns.txt microsoft + ./wego query -i word2vec_sg_ns.txt microsoft echo "similarity search: skipgram with hs" - ./wego search -i word2vec_sg_hs.txt microsoft + ./wego query -i word2vec_sg_hs.txt microsoft echo "similarity search: cbow with ns" - ./wego search -i word2vec_cbow_ns.txt microsoft + ./wego query -i word2vec_cbow_ns.txt microsoft echo "similarity search: cbow with hs" - ./wego search -i word2vec_cbow_hs.txt microsoft + ./wego query -i word2vec_cbow_hs.txt microsoft } function search_glove() { echo "similarity search: glove with sgd" - ./wego search -i glove_sgd.txt microsoft + ./wego query -i glove_sgd.txt microsoft echo "similarity search: glove with adagrad" - ./wego search -i glove_adagrad.txt microsoft + ./wego query -i glove_adagrad.txt microsoft } function search_lexvec() { echo "similarity search: lexvec" - ./wego search -i lexvec.txt microsoft + ./wego query -i lexvec.txt microsoft } for OPT in "$@"; do diff --git a/wego.go b/wego.go index ce092b3..b550a6a 100644 --- a/wego.go +++ b/wego.go @@ -23,16 +23,16 @@ import ( "github.com/ynqa/wego/cmd/model/glove" "github.com/ynqa/wego/cmd/model/lexvec" "github.com/ynqa/wego/cmd/model/word2vec" - "github.com/ynqa/wego/cmd/search" - "github.com/ynqa/wego/cmd/search/repl" + "github.com/ynqa/wego/cmd/query" + "github.com/ynqa/wego/cmd/query/console" ) func main() { word2vec := word2vec.New() glove := glove.New() lexvec := lexvec.New() - search := search.New() - repl := repl.New() + query := query.New() + console := console.New() cmd := &cobra.Command{ Use: "wego", @@ -42,16 +42,16 @@ func main() { word2vec.Name(), glove.Name(), lexvec.Name(), - search.Name(), - repl.Name(), + query.Name(), + console.Name(), ) }, } cmd.AddCommand(word2vec) cmd.AddCommand(glove) cmd.AddCommand(lexvec) - cmd.AddCommand(search) - cmd.AddCommand(repl) + cmd.AddCommand(query) + cmd.AddCommand(console) if err := cmd.Execute(); err != nil { os.Exit(1)