From dc8f2d3919035a886d19336cebfe810f7682ca0c Mon Sep 17 00:00:00 2001 From: ynqa Date: Fri, 27 Dec 2019 20:21:20 +0900 Subject: [PATCH] refactoring --- .gitignore | 2 +- README.md | 41 +- {pkg/model => cmd}/README.md | 83 +++++ cmd/glove.go | 98 ----- cmd/glove_test.go | 46 --- cmd/lexvec.go | 101 ----- cmd/model/cmdutil/cmdutil.go | 44 +++ cmd/model/glove/glove.go | 102 +++++ cmd/model/lexvec/lexvec.go | 100 +++++ cmd/model/word2vec/word2vec.go | 100 +++++ cmd/repl.go | 68 ---- cmd/repl_test.go | 34 -- cmd/root.go | 89 ----- cmd/root_test.go | 45 --- cmd/search.go | 76 ---- .../search/cmdutil/cmdutil.go | 24 +- cmd/search/repl/repl.go | 68 ++++ cmd/search/search.go | 73 ++++ cmd/search_test.go | 34 -- cmd/word2vec.go | 108 ------ cmd/word2vec_test.go | 47 --- examples/glove/glove.go | 54 --- examples/lexvec/lexvec.go | 55 --- examples/word2vec/main.go | 29 ++ examples/word2vec/word2vec.go | 56 --- go.mod | 36 +- go.sum | 101 +---- pkg/builder/glove.go | 216 ----------- pkg/builder/glove_test.go | 150 -------- pkg/builder/lexvec.go | 226 ----------- pkg/builder/word2vec.go | 270 -------------- pkg/builder/word2vec_test.go | 203 ---------- pkg/{model/util_test.go => clock/clock.go} | 22 +- pkg/co/co.go | 37 -- pkg/config/config.go | 154 -------- pkg/config/config_test.go | 142 ------- pkg/corpus/core.go | 82 ---- pkg/corpus/corpus.go | 106 ++++++ pkg/corpus/count_model.go | 227 ----------- pkg/corpus/dictionary/dictionary.go | 67 ++++ pkg/corpus/dictionary/huffman.go | 43 +++ pkg/corpus/dictionary/node/node.go | 29 ++ pkg/corpus/options.go | 23 ++ pkg/corpus/pairwise/encode/encode.go | 23 ++ pkg/corpus/pairwise/options.go | 54 +++ pkg/corpus/pairwise/pairwise.go | 46 +++ pkg/corpus/word2vec.go | 47 --- pkg/corpus/word2vec_test.go | 63 ---- pkg/item/item.go | 88 +++++ pkg/item/item_test.go | 68 ++++ pkg/model/glove/adagrad.go | 69 ---- pkg/model/glove/adagrad_test.go | 66 ---- pkg/model/glove/glove.go | 286 +++++++------- pkg/model/glove/item.go | 45 +++ pkg/model/glove/options.go | 170 +++++++++ pkg/model/glove/sgd.go | 55 --- pkg/model/glove/sgd_test.go | 50 --- pkg/model/glove/solver.go | 108 ++++-- pkg/model/lexvec/item.go | 72 ++++ pkg/model/lexvec/lexvec.go | 352 +++++++++--------- pkg/model/lexvec/options.go | 174 +++++++++ pkg/model/model.go | 67 +++- pkg/model/modelutil/matrix/matrix.go | 36 ++ pkg/model/modelutil/modelutil.go | 27 ++ pkg/model/modelutil/save/save.go | 37 ++ pkg/model/option.go | 29 -- pkg/model/subsample/subsample.go | 38 ++ pkg/model/util.go | 39 -- pkg/model/word2vec/cbow.go | 87 ----- pkg/model/word2vec/hs.go | 78 ---- pkg/model/word2vec/hs_test.go | 46 --- pkg/model/word2vec/model.go | 156 ++++++-- pkg/model/word2vec/ns.go | 93 ----- pkg/model/word2vec/ns_test.go | 46 --- pkg/model/word2vec/opt.go | 43 --- pkg/model/word2vec/optimizer.go | 116 ++++++ pkg/model/word2vec/options.go | 211 +++++++++++ pkg/model/word2vec/sigmoid_table.go | 12 +- pkg/model/word2vec/sigmoid_table_test.go | 5 +- pkg/model/word2vec/skipgram.go | 65 ---- pkg/model/word2vec/word2vec.go | 313 ++++++++-------- pkg/node/node.go | 96 ----- pkg/repl/README.md | 43 --- pkg/search/README.md | 38 -- pkg/search/describer.go | 41 -- pkg/search/describer_test.go | 43 --- pkg/search/neighbor.go | 28 -- pkg/search/neighbor_test.go | 77 ---- pkg/search/parser.go | 61 --- pkg/search/parser_test.go | 54 --- pkg/{ => search}/repl/op.go | 34 +- pkg/{ => search}/repl/repl.go | 102 +++-- pkg/search/search.go | 216 ++++++++--- pkg/search/search_test.go | 204 ++++++++-- pkg/search/testing.go | 25 -- pkg/search/util.go | 35 -- pkg/search/util_test.go | 40 -- pkg/timer/timer.go | 73 ---- pkg/validate/validate.go | 25 -- pkg/validate/validate_test.go | 25 -- pkg/verbose/verbose.go | 17 + scripts/demo.sh | 6 +- scripts/e2e.sh | 56 +-- wego.go | 36 +- 104 files changed, 3365 insertions(+), 5101 deletions(-) rename {pkg/model => cmd}/README.md (70%) delete mode 100644 cmd/glove.go delete mode 100644 cmd/glove_test.go delete mode 100644 cmd/lexvec.go create mode 100644 cmd/model/cmdutil/cmdutil.go create mode 100644 cmd/model/glove/glove.go create mode 100644 cmd/model/lexvec/lexvec.go create mode 100644 cmd/model/word2vec/word2vec.go delete mode 100644 cmd/repl.go delete mode 100644 cmd/repl_test.go delete mode 100644 cmd/root.go delete mode 100644 cmd/root_test.go delete mode 100644 cmd/search.go rename pkg/corpus/testing.go => cmd/search/cmdutil/cmdutil.go (58%) create mode 100644 cmd/search/repl/repl.go create mode 100644 cmd/search/search.go delete mode 100644 cmd/search_test.go delete mode 100644 cmd/word2vec.go delete mode 100644 cmd/word2vec_test.go delete mode 100644 examples/glove/glove.go delete mode 100644 examples/lexvec/lexvec.go create mode 100644 examples/word2vec/main.go delete mode 100644 examples/word2vec/word2vec.go delete mode 100644 pkg/builder/glove.go delete mode 100644 pkg/builder/glove_test.go delete mode 100644 pkg/builder/lexvec.go delete mode 100644 pkg/builder/word2vec.go delete mode 100644 pkg/builder/word2vec_test.go rename pkg/{model/util_test.go => clock/clock.go} (73%) delete mode 100644 pkg/co/co.go delete mode 100644 pkg/config/config.go delete mode 100644 pkg/config/config_test.go delete mode 100644 pkg/corpus/core.go create mode 100644 pkg/corpus/corpus.go delete mode 100644 pkg/corpus/count_model.go create mode 100644 pkg/corpus/dictionary/dictionary.go create mode 100644 pkg/corpus/dictionary/huffman.go create mode 100644 pkg/corpus/dictionary/node/node.go create mode 100644 pkg/corpus/options.go create mode 100644 pkg/corpus/pairwise/encode/encode.go create mode 100644 pkg/corpus/pairwise/options.go create mode 100644 pkg/corpus/pairwise/pairwise.go delete mode 100644 pkg/corpus/word2vec.go delete mode 100644 pkg/corpus/word2vec_test.go create mode 100644 pkg/item/item.go create mode 100644 pkg/item/item_test.go delete mode 100644 pkg/model/glove/adagrad.go delete mode 100644 pkg/model/glove/adagrad_test.go create mode 100644 pkg/model/glove/item.go create mode 100644 pkg/model/glove/options.go delete mode 100644 pkg/model/glove/sgd.go delete mode 100644 pkg/model/glove/sgd_test.go create mode 100644 pkg/model/lexvec/item.go create mode 100644 pkg/model/lexvec/options.go create mode 100644 pkg/model/modelutil/matrix/matrix.go create mode 100644 pkg/model/modelutil/modelutil.go create mode 100644 pkg/model/modelutil/save/save.go delete mode 100644 pkg/model/option.go create mode 100644 pkg/model/subsample/subsample.go delete mode 100644 pkg/model/util.go delete mode 100644 pkg/model/word2vec/cbow.go delete mode 100644 pkg/model/word2vec/hs.go delete mode 100644 pkg/model/word2vec/hs_test.go delete mode 100644 pkg/model/word2vec/ns.go delete mode 100644 pkg/model/word2vec/ns_test.go delete mode 100644 pkg/model/word2vec/opt.go create mode 100644 pkg/model/word2vec/optimizer.go create mode 100644 pkg/model/word2vec/options.go delete mode 100644 pkg/model/word2vec/skipgram.go delete mode 100644 pkg/node/node.go delete mode 100644 pkg/repl/README.md delete mode 100644 pkg/search/README.md delete mode 100644 pkg/search/describer.go delete mode 100644 pkg/search/describer_test.go delete mode 100644 pkg/search/neighbor.go delete mode 100644 pkg/search/neighbor_test.go delete mode 100644 pkg/search/parser.go delete mode 100644 pkg/search/parser_test.go rename pkg/{ => search}/repl/op.go (51%) rename pkg/{ => search}/repl/repl.go (60%) delete mode 100644 pkg/search/testing.go delete mode 100644 pkg/search/util.go delete mode 100644 pkg/search/util_test.go delete mode 100644 pkg/timer/timer.go delete mode 100644 pkg/validate/validate.go delete mode 100644 pkg/validate/validate_test.go create mode 100644 pkg/verbose/verbose.go diff --git a/.gitignore b/.gitignore index 3e48cb5..21dea3c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ .idea/ vendor/ -example/*.txt +*.txt text8 text8.zip diff --git a/README.md b/README.md index fc7b717..4c94aa0 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ Wego outputs a .txt file that is described word vector is subject to the followi ... ``` -## Example +## API It's also able to train word vectors using wego APIs. Examples are as follows. @@ -78,41 +78,28 @@ package main import ( "os" - "github.com/ynqa/wego/pkg/builder" + "github.com/ynqa/wego/pkg/model/modelutil/save" "github.com/ynqa/wego/pkg/model/word2vec" ) func main() { - b := builder.NewWord2vecBuilder() - - b.Dimension(10). - Window(5). - Model(word2vec.CBOW). - Optimizer(word2vec.NEGATIVE_SAMPLING). - NegativeSampleSize(5). - Verbose() - - m, err := b.Build() + model, err := word2vec.New( + word2vec.WithWindow(5), + word2vec.WithModel(word2vec.Cbow), + word2vec.WithOptimizer(word2vec.NegativeSampling), + word2vec.WithNegativeSampleSize(5), + word2vec.Verbose(), + ) if err != nil { - // Failed to build word2vec. + // failed to create word2vec. } input, _ := os.Open("text8") - - // Start to Train. - if err = m.Train(input); err != nil { - // Failed to train by word2vec. + if err = model.Train(input); err != nil { + // failed to train. } - output, err := os.Create("example.txt") - if err != nil { - // Failed to create output file. - } - - defer func() { - output.Close() - }() - - m.Save(output) + // write word vector. + model.Save(os.Stdin, save.AggregatedVector) } ``` diff --git a/pkg/model/README.md b/cmd/README.md similarity index 70% rename from pkg/model/README.md rename to cmd/README.md index 1ee1670..12fbd87 100644 --- a/pkg/model/README.md +++ b/cmd/README.md @@ -102,3 +102,86 @@ Flags: --verbose verbose mode -w, --window int context window size (default 5) ``` + +# Search + +Similarity search between word vectors. + +## Usage + +``` +Search similar words + +Usage: + wego search [flags] + +Examples: + wego search -i example/word_vectors.txt microsoft + +Flags: + -h, --help help for search + -i, --inputFile string input file path for trained word vector (default "example/input.txt") + -r, --rank int how many the most similar words will be displayed (default 10) +``` + +## Example + +``` +$ go run wego.go search -i example/word_vectors_sg.txt microsoft + RANK | WORD | SIMILARITY ++------+------------+------------+ + 1 | apple | 0.994008 + 2 | operating | 0.992855 + 3 | versions | 0.992800 + 4 | ibm | 0.992232 + 5 | os | 0.989174 + 6 | computers | 0.988998 + 7 | machines | 0.988804 + 8 | dvd | 0.988732 + 9 | cd | 0.988259 + 10 | compatible | 0.988200 +``` + +# REPL for search + +Similarity search between word vectors with REPL mode. + +## Usage + +``` +Search similar words with REPL mode + +Usage: + wego repl [flags] + +Examples: + wego repl -i example/word_vectors.txt + >> apple + banana + ... + +Flags: + -h, --help help for repl + -i, --inputFile string input file path for trained word vector (default "example/word_vectors.txt") + -r, --rank int how many the most similar words will be displayed (default 10) +``` + +## Example + +Now, it is able to use `+`, `-` for arithmetic operations. + +``` +$ go run wego.go repl -i example/word_vectors_sg.txt +>> a + b + RANK | WORD | SIMILARITY ++------+---------+------------+ + 1 | phi | 0.907975 + 2 | q | 0.904593 + 3 | mathbf | 0.903066 + 4 | cdot | 0.902205 + 5 | b | 0.901952 + 6 | becomes | 0.900346 + 7 | int | 0.898680 + 8 | z | 0.897895 + 9 | named | 0.896480 + 10 | v | 0.895456 +``` diff --git a/cmd/glove.go b/cmd/glove.go deleted file mode 100644 index 29e1b2e..0000000 --- a/cmd/glove.go +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cmd - -import ( - "os" - "runtime/pprof" - - "github.com/pkg/errors" - "github.com/spf13/cobra" - "github.com/spf13/viper" - - "github.com/ynqa/wego/pkg/builder" - "github.com/ynqa/wego/pkg/config" - "github.com/ynqa/wego/pkg/validate" -) - -var gloveCmd = &cobra.Command{ - Use: "glove", - Short: "GloVe: Global Vectors for Word Representation", - PreRun: func(cmd *cobra.Command, args []string) { - bindConfig(cmd) - bindGlove(cmd) - }, - RunE: func(cmd *cobra.Command, args []string) error { - if viper.GetBool(config.Prof.String()) { - f, err := os.Create("cpu.prof") - if err != nil { - os.Exit(1) - } - pprof.StartCPUProfile(f) - defer pprof.StopCPUProfile() - } - return runGlove() - }, -} - -func init() { - gloveCmd.Flags().AddFlagSet(configFlagSet()) - gloveCmd.Flags().String(config.Solver.String(), config.DefaultSolver.String(), - "solver for GloVe objective. One of: sgd|adagrad") - gloveCmd.Flags().Int(config.Xmax.String(), config.DefaultXmax, - "specifying cutoff in weighting function") - gloveCmd.Flags().Float64(config.Alpha.String(), config.DefaultAlpha, - "exponent of weighting function") -} - -func bindGlove(cmd *cobra.Command) { - viper.BindPFlag(config.Solver.String(), cmd.Flags().Lookup(config.Solver.String())) - viper.BindPFlag(config.Xmax.String(), cmd.Flags().Lookup(config.Xmax.String())) - viper.BindPFlag(config.Alpha.String(), cmd.Flags().Lookup(config.Alpha.String())) -} - -func runGlove() error { - outputFileName := viper.GetString(config.OutputFile.String()) - if validate.FileExists(outputFileName) { - return errors.Errorf("%s is already existed", outputFileName) - } - glove, err := builder.NewGloveBuilderFromViper() - if err != nil { - return err - } - mod, err := glove.Build() - if err != nil { - return err - } - inputFile := viper.GetString(config.InputFile.String()) - if !validate.FileExists(inputFile) { - return errors.Errorf("Not such a file %s", inputFile) - } - input, err := os.Open(inputFile) - if err != nil { - return err - } - if err := mod.Train(input); err != nil { - return err - } - outputFile, err := os.Create(outputFileName) - if err != nil { - return err - } - defer func() { - outputFile.Close() - }() - return mod.Save(outputFile) -} diff --git a/cmd/glove_test.go b/cmd/glove_test.go deleted file mode 100644 index 36bb718..0000000 --- a/cmd/glove_test.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cmd - -import ( - "testing" - - "github.com/spf13/viper" -) - -const gloveFlagSize = 3 - -func TestGloveBind(t *testing.T) { - defer viper.Reset() - - bindGlove(gloveCmd) - - if len(viper.AllKeys()) != gloveFlagSize { - t.Errorf("Expected gloveBind maps %v keys: %v", - gloveFlagSize, viper.AllKeys()) - } -} - -func TestGloveCmdPreRun(t *testing.T) { - defer viper.Reset() - - var empty []string - gloveCmd.PreRun(gloveCmd, empty) - - if len(viper.AllKeys()) != gloveFlagSize+configFlagSize { - t.Errorf("Expected PreRun of gloveCmd maps %v keys: %v", - gloveFlagSize+configFlagSize, viper.AllKeys()) - } -} diff --git a/cmd/lexvec.go b/cmd/lexvec.go deleted file mode 100644 index 8411c69..0000000 --- a/cmd/lexvec.go +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright © 2019 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cmd - -import ( - "os" - "runtime/pprof" - - "github.com/pkg/errors" - "github.com/spf13/cobra" - "github.com/spf13/viper" - - "github.com/ynqa/wego/pkg/builder" - "github.com/ynqa/wego/pkg/config" - "github.com/ynqa/wego/pkg/validate" -) - -var lexvecCmd = &cobra.Command{ - Use: "lexvec", - Short: "Lexvec: Matrix Factorization using Window Sampling and Negative Sampling for Improved Word Representations", - PreRun: func(cmd *cobra.Command, args []string) { - bindConfig(cmd) - bindLexvec(cmd) - }, - RunE: func(cmd *cobra.Command, args []string) error { - if viper.GetBool(config.Prof.String()) { - f, err := os.Create("cpu.prof") - if err != nil { - os.Exit(1) - } - pprof.StartCPUProfile(f) - defer pprof.StopCPUProfile() - } - return runLexvec() - }, -} - -func init() { - lexvecCmd.Flags().AddFlagSet(configFlagSet()) - lexvecCmd.Flags().Int(config.NegativeSampleSize.String(), config.DefaultNegativeSampleSize, - "negative sample size(for negative sampling only)") - lexvecCmd.Flags().Float64(config.Theta.String(), config.DefaultTheta, - "lower limit of learning rate (lr >= initlr * theta)") - lexvecCmd.Flags().Float64(config.Smooth.String(), config.DefaultSmooth, - "smoothing value for co-occurence value") - lexvecCmd.Flags().String(config.RelationType.String(), config.DefaultRelationType.String(), - "relation type for counting co-occurrence. One of ppmi|pmi|co|logco") -} - -func bindLexvec(cmd *cobra.Command) { - viper.BindPFlag(config.NegativeSampleSize.String(), cmd.Flags().Lookup(config.NegativeSampleSize.String())) - viper.BindPFlag(config.Theta.String(), cmd.Flags().Lookup(config.Theta.String())) - viper.BindPFlag(config.Smooth.String(), cmd.Flags().Lookup(config.Smooth.String())) - viper.BindPFlag(config.RelationType.String(), cmd.Flags().Lookup(config.RelationType.String())) -} - -func runLexvec() error { - outputFileName := viper.GetString(config.OutputFile.String()) - if validate.FileExists(outputFileName) { - return errors.Errorf("%s is already existed", outputFileName) - } - lexvec, err := builder.NewLexvecBuilderFromViper() - if err != nil { - return err - } - mod, err := lexvec.Build() - if err != nil { - return err - } - inputFile := viper.GetString(config.InputFile.String()) - if !validate.FileExists(inputFile) { - return errors.Errorf("Not such a file %s", inputFile) - } - input, err := os.Open(inputFile) - if err != nil { - return err - } - if err := mod.Train(input); err != nil { - return err - } - outputFile, err := os.Create(outputFileName) - if err != nil { - return err - } - defer func() { - outputFile.Close() - }() - return mod.Save(outputFile) -} diff --git a/cmd/model/cmdutil/cmdutil.go b/cmd/model/cmdutil/cmdutil.go new file mode 100644 index 0000000..79065f3 --- /dev/null +++ b/cmd/model/cmdutil/cmdutil.go @@ -0,0 +1,44 @@ +// Copyright © 2017 Makoto Ito +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmdutil + +import ( + "fmt" + "github.com/spf13/cobra" + + "github.com/ynqa/wego/pkg/model/modelutil/save" +) + +const ( + defaultInputFile = "example/input.txt" + defaultOutputFile = "example/word_vectors.txt" + defaultProf = false +) + +func AddInputFlags(cmd *cobra.Command, input *string) { + cmd.Flags().StringVarP(input, "input", "i", defaultInputFile, "input file path for corpus") +} + +func AddOutputFlags(cmd *cobra.Command, output *string) { + cmd.Flags().StringVarP(output, "output", "o", defaultOutputFile, "output file path to save word vectors") +} + +func AddProfFlags(cmd *cobra.Command, prof *bool) { + cmd.Flags().BoolVar(prof, "prof", defaultProf, "profiling mode to check the performances") +} + +func AddSaveVectorTypeFlags(cmd *cobra.Command, typ *save.VectorType) { + cmd.Flags().Var(typ, "save-vec", fmt.Sprintf("save vector type. One of: %s|%s", save.SingleVector, save.AggregatedVector)) +} diff --git a/cmd/model/glove/glove.go b/cmd/model/glove/glove.go new file mode 100644 index 0000000..01931b4 --- /dev/null +++ b/cmd/model/glove/glove.go @@ -0,0 +1,102 @@ +// Copyright © 2017 Makoto Ito +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package glove + +import ( + "os" + "path/filepath" + "runtime/pprof" + + "github.com/pkg/errors" + "github.com/spf13/cobra" + + "github.com/ynqa/wego/cmd/model/cmdutil" + "github.com/ynqa/wego/pkg/corpus" + "github.com/ynqa/wego/pkg/corpus/pairwise" + "github.com/ynqa/wego/pkg/model" + "github.com/ynqa/wego/pkg/model/glove" + "github.com/ynqa/wego/pkg/model/modelutil/save" +) + +var ( + prof bool + inputFile string + outputFile string + saveVectorType save.VectorType +) + +func New() *cobra.Command { + var opts glove.Options + cmd := &cobra.Command{ + Use: "glove", + Short: "GloVe: Global Vectors for Word Representation", + RunE: func(cmd *cobra.Command, args []string) error { + return execute(opts) + }, + } + + cmdutil.AddInputFlags(cmd, &inputFile) + cmdutil.AddOutputFlags(cmd, &outputFile) + cmdutil.AddProfFlags(cmd, &prof) + cmdutil.AddSaveVectorTypeFlags(cmd, &saveVectorType) + + corpus.LoadForCmd(cmd, &opts.CorpusOptions) + pairwise.LoadForCmd(cmd, &opts.PairwiseOptions) + model.LoadForCmd(cmd, &opts.ModelOptions) + glove.LoadForCmd(cmd, &opts) + return cmd +} + +func fileExists(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +func execute(opts glove.Options) error { + if prof { + f, err := os.Create("cpu.prof") + if err != nil { + return err + } + pprof.StartCPUProfile(f) + defer pprof.StopCPUProfile() + } + + if fileExists(outputFile) { + return errors.Errorf("%s is already existed", outputFile) + } else if !fileExists(inputFile) { + return errors.Errorf("Not such a file %s", inputFile) + } + if err := os.MkdirAll(filepath.Dir(outputFile), 0777); err != nil { + return err + } + output, err := os.Create(outputFile) + if err != nil { + return err + } + input, err := os.Open(inputFile) + if err != nil { + return err + } + defer input.Close() + mod, err := glove.NewForOptions(opts) + if err != nil { + return err + } + if err := mod.Train(input); err != nil { + return err + } + return mod.Save(output, saveVectorType) +} diff --git a/cmd/model/lexvec/lexvec.go b/cmd/model/lexvec/lexvec.go new file mode 100644 index 0000000..8dd9d52 --- /dev/null +++ b/cmd/model/lexvec/lexvec.go @@ -0,0 +1,100 @@ +// Copyright © 2019 Makoto Ito +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lexvec + +import ( + "os" + "path/filepath" + "runtime/pprof" + + "github.com/pkg/errors" + "github.com/spf13/cobra" + + "github.com/ynqa/wego/cmd/model/cmdutil" + "github.com/ynqa/wego/pkg/corpus" + "github.com/ynqa/wego/pkg/model" + "github.com/ynqa/wego/pkg/model/lexvec" + "github.com/ynqa/wego/pkg/model/modelutil/save" +) + +var ( + prof bool + inputFile string + outputFile string + saveVectorType save.VectorType +) + +func New() *cobra.Command { + var opts lexvec.Options + cmd := &cobra.Command{ + Use: "lexvec", + Short: "Lexvec: Matrix Factorization using Window Sampling and Negative Sampling for Improved Word Representations", + RunE: func(cmd *cobra.Command, args []string) error { + return execute(opts) + }, + } + + cmdutil.AddInputFlags(cmd, &inputFile) + cmdutil.AddOutputFlags(cmd, &outputFile) + cmdutil.AddProfFlags(cmd, &prof) + cmdutil.AddSaveVectorTypeFlags(cmd, &saveVectorType) + + corpus.LoadForCmd(cmd, &opts.CorpusOptions) + model.LoadForCmd(cmd, &opts.ModelOptions) + lexvec.LoadForCmd(cmd, &opts) + return cmd +} + +func fileExists(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +func execute(opts lexvec.Options) error { + if prof { + f, err := os.Create("cpu.prof") + if err != nil { + return err + } + pprof.StartCPUProfile(f) + defer pprof.StopCPUProfile() + } + + if fileExists(outputFile) { + return errors.Errorf("%s is already existed", outputFile) + } else if !fileExists(inputFile) { + return errors.Errorf("Not such a file %s", inputFile) + } + if err := os.MkdirAll(filepath.Dir(outputFile), 0777); err != nil { + return err + } + output, err := os.Create(outputFile) + if err != nil { + return err + } + input, err := os.Open(inputFile) + if err != nil { + return err + } + defer input.Close() + mod, err := lexvec.NewForOptions(opts) + if err != nil { + return err + } + if err := mod.Train(input); err != nil { + return err + } + return mod.Save(output, saveVectorType) +} diff --git a/cmd/model/word2vec/word2vec.go b/cmd/model/word2vec/word2vec.go new file mode 100644 index 0000000..9dc7237 --- /dev/null +++ b/cmd/model/word2vec/word2vec.go @@ -0,0 +1,100 @@ +// Copyright © 2017 Makoto Ito +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package word2vec + +import ( + "os" + "path/filepath" + "runtime/pprof" + + "github.com/pkg/errors" + "github.com/spf13/cobra" + + "github.com/ynqa/wego/cmd/model/cmdutil" + "github.com/ynqa/wego/pkg/corpus" + "github.com/ynqa/wego/pkg/model" + "github.com/ynqa/wego/pkg/model/modelutil/save" + "github.com/ynqa/wego/pkg/model/word2vec" +) + +var ( + prof bool + inputFile string + outputFile string + saveVectorType save.VectorType +) + +func New() *cobra.Command { + var opts word2vec.Options + cmd := &cobra.Command{ + Use: "word2vec", + Short: "Word2Vec: Continuous Bag-of-Words and Skip-gram model", + RunE: func(cmd *cobra.Command, args []string) error { + return execute(opts) + }, + } + + cmdutil.AddInputFlags(cmd, &inputFile) + cmdutil.AddOutputFlags(cmd, &outputFile) + cmdutil.AddProfFlags(cmd, &prof) + cmdutil.AddSaveVectorTypeFlags(cmd, &saveVectorType) + + corpus.LoadForCmd(cmd, &opts.CorpusOptions) + model.LoadForCmd(cmd, &opts.ModelOptions) + word2vec.LoadForCmd(cmd, &opts) + return cmd +} + +func fileExists(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +func execute(opts word2vec.Options) error { + if prof { + f, err := os.Create("cpu.prof") + if err != nil { + return err + } + pprof.StartCPUProfile(f) + defer pprof.StopCPUProfile() + } + + if fileExists(outputFile) { + return errors.Errorf("%s is already existed", outputFile) + } else if !fileExists(inputFile) { + return errors.Errorf("%s is not found", inputFile) + } + if err := os.MkdirAll(filepath.Dir(outputFile), 0777); err != nil { + return err + } + output, err := os.Create(outputFile) + if err != nil { + return err + } + input, err := os.Open(inputFile) + if err != nil { + return err + } + defer input.Close() + mod, err := word2vec.NewForOptions(opts) + if err != nil { + return err + } + if err := mod.Train(input); err != nil { + return err + } + return mod.Save(output, saveVectorType) +} diff --git a/cmd/repl.go b/cmd/repl.go deleted file mode 100644 index e9429d7..0000000 --- a/cmd/repl.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright © 2019 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cmd - -import ( - "os" - - "github.com/spf13/cobra" - "github.com/spf13/viper" - - "github.com/ynqa/wego/pkg/config" - "github.com/ynqa/wego/pkg/repl" -) - -var replCmd = &cobra.Command{ - Use: "repl", - Short: "Search similar words with REPL mode", - Long: "Search similar words with REPL mode", - Example: " wego repl -i example/word_vectors.txt\n" + - " >> apple + banana\n" + - " ...", - PreRun: func(cmd *cobra.Command, args []string) { - replBind(cmd) - }, - RunE: func(cmd *cobra.Command, args []string) error { - return executeRepl() - }, -} - -func init() { - replCmd.Flags().StringP(config.InputFile.String(), "i", config.DefaultOutputFile, - "input file path for trained word vector") - replCmd.Flags().IntP(config.Rank.String(), "r", config.DefaultRank, - "how many the most similar words will be displayed") -} - -func replBind(cmd *cobra.Command) { - viper.BindPFlag(config.InputFile.String(), cmd.Flags().Lookup(config.InputFile.String())) - viper.BindPFlag(config.Rank.String(), cmd.Flags().Lookup(config.Rank.String())) -} - -func executeRepl() error { - inputFile := viper.GetString(config.InputFile.String()) - f, err := os.Open(inputFile) - if err != nil { - return err - } - defer f.Close() - - k := viper.GetInt(config.Rank.String()) - repl, err := repl.NewRepl(f, k) - if err != nil { - return err - } - return repl.Run() -} diff --git a/cmd/repl_test.go b/cmd/repl_test.go deleted file mode 100644 index 92e71b8..0000000 --- a/cmd/repl_test.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright © 2019 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cmd - -import ( - "testing" - - "github.com/spf13/viper" -) - -const replFlagSize = 2 - -func TestReplBind(t *testing.T) { - defer viper.Reset() - - replBind(replCmd) - - if len(viper.AllKeys()) != replFlagSize { - t.Errorf("Expected replBind maps %v keys: %v", - replFlagSize, viper.AllKeys()) - } -} diff --git a/cmd/root.go b/cmd/root.go deleted file mode 100644 index e164e94..0000000 --- a/cmd/root.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cmd - -import ( - "github.com/pkg/errors" - "github.com/spf13/cobra" - "github.com/spf13/pflag" - "github.com/spf13/viper" - - "github.com/ynqa/wego/pkg/config" -) - -// RootCmd is the root command for word embedding. -var RootCmd = &cobra.Command{ - Use: "wego", - Short: "tools for embedding words into vector space", - RunE: func(cmd *cobra.Command, args []string) error { - return errors.Errorf("Set sub-command. One of %s|%s|%s|%s|%s", - word2vecCmd.Name(), gloveCmd.Name(), lexvecCmd.Name(), searchCmd.Name(), replCmd.Name()) - }, -} - -func configFlagSet() *pflag.FlagSet { - fs := pflag.NewFlagSet(RootCmd.Name(), pflag.ExitOnError) - fs.StringP(config.InputFile.String(), "i", config.DefaultInputFile, - "input file path for corpus") - fs.StringP(config.OutputFile.String(), "o", config.DefaultOutputFile, - "output file path to save word vectors") - fs.IntP(config.Dimension.String(), "d", config.DefaultDimension, - "dimension of word vector") - fs.Int(config.Iteration.String(), config.DefaultIteration, - "number of iteration") - fs.Int(config.MinCount.String(), config.DefaultMinCount, - "lower limit to filter rare words") - fs.Int(config.ThreadSize.String(), config.DefaultThreadSize, - "number of goroutine") - fs.Int(config.BatchSize.String(), config.DefaultBatchSize, - "interval word size to preprocess/train") - fs.IntP(config.Window.String(), "w", config.DefaultWindow, - "context window size") - fs.Float64(config.Initlr.String(), config.DefaultInitlr, - "initial learning rate") - fs.Bool(config.Prof.String(), config.DefaultProf, - "profiling mode to check the performances") - fs.Bool(config.ToLower.String(), config.DefaultToLower, - "whether the words on corpus convert to lowercase or not") - fs.Bool(config.Verbose.String(), config.DefaultVerbose, - "verbose mode") - fs.String(config.SaveVectorType.String(), config.DefaultSaveVectorType.String(), - "save vector type. One of: normal|add") - return fs -} - -func bindConfig(cmd *cobra.Command) { - viper.BindPFlag(config.InputFile.String(), cmd.Flags().Lookup(config.InputFile.String())) - viper.BindPFlag(config.OutputFile.String(), cmd.Flags().Lookup(config.OutputFile.String())) - viper.BindPFlag(config.Dimension.String(), cmd.Flags().Lookup(config.Dimension.String())) - viper.BindPFlag(config.Iteration.String(), cmd.Flags().Lookup(config.Iteration.String())) - viper.BindPFlag(config.MinCount.String(), cmd.Flags().Lookup(config.MinCount.String())) - viper.BindPFlag(config.ThreadSize.String(), cmd.Flags().Lookup(config.ThreadSize.String())) - viper.BindPFlag(config.BatchSize.String(), cmd.Flags().Lookup(config.BatchSize.String())) - viper.BindPFlag(config.Window.String(), cmd.Flags().Lookup(config.Window.String())) - viper.BindPFlag(config.Initlr.String(), cmd.Flags().Lookup(config.Initlr.String())) - viper.BindPFlag(config.Prof.String(), cmd.Flags().Lookup(config.Prof.String())) - viper.BindPFlag(config.ToLower.String(), cmd.Flags().Lookup(config.ToLower.String())) - viper.BindPFlag(config.Verbose.String(), cmd.Flags().Lookup(config.Verbose.String())) - viper.BindPFlag(config.SaveVectorType.String(), cmd.Flags().Lookup(config.SaveVectorType.String())) -} - -func init() { - RootCmd.AddCommand(word2vecCmd) - RootCmd.AddCommand(gloveCmd) - RootCmd.AddCommand(lexvecCmd) - RootCmd.AddCommand(searchCmd) - RootCmd.AddCommand(replCmd) -} diff --git a/cmd/root_test.go b/cmd/root_test.go deleted file mode 100644 index 3c3a386..0000000 --- a/cmd/root_test.go +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cmd - -import ( - "testing" - - "github.com/spf13/cobra" - "github.com/spf13/viper" -) - -const configFlagSize = 13 - -func TestConfigFlagSet(t *testing.T) { - fs := configFlagSet() - - if !fs.HasAvailableFlags() { - t.Error("Expected that ConfigFlagSet() retruns *pflag.FlagSet without empty") - } -} - -func TestConfigBind(t *testing.T) { - defer viper.Reset() - - config := &cobra.Command{} - config.Flags().AddFlagSet(configFlagSet()) - bindConfig(config) - - if len(viper.AllKeys()) != configFlagSize { - t.Errorf("Expected configBind maps %v keys: %v", - configFlagSize, viper.AllKeys()) - } -} diff --git a/cmd/search.go b/cmd/search.go deleted file mode 100644 index bd4a405..0000000 --- a/cmd/search.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cmd - -import ( - "os" - - "github.com/pkg/errors" - "github.com/spf13/cobra" - "github.com/spf13/viper" - - "github.com/ynqa/wego/pkg/config" - "github.com/ynqa/wego/pkg/search" -) - -var searchCmd = &cobra.Command{ - Use: "search", - Short: "Search similar words", - Long: "Search similar words", - Example: " wego search -i example/word_vectors.txt microsoft", - PreRun: func(cmd *cobra.Command, args []string) { - searchBind(cmd) - }, - RunE: func(cmd *cobra.Command, args []string) error { - if len(args) == 1 { - return executeSearch(args[0]) - } - return errors.New("Input a single word") - }, -} - -func init() { - searchCmd.Flags().StringP(config.InputFile.String(), "i", config.DefaultOutputFile, - "input file path for trained word vector") - searchCmd.Flags().IntP(config.Rank.String(), "r", config.DefaultRank, - "how many the most similar words will be displayed") -} - -func searchBind(cmd *cobra.Command) { - viper.BindPFlag(config.Rank.String(), cmd.Flags().Lookup(config.Rank.String())) - viper.BindPFlag(config.InputFile.String(), cmd.Flags().Lookup(config.InputFile.String())) -} - -func executeSearch(query string) error { - inputFile := viper.GetString(config.InputFile.String()) - f, err := os.Open(inputFile) - if err != nil { - return err - } - defer f.Close() - - searcher, err := search.NewSearcher(f) - if err != nil { - return err - } - - k := viper.GetInt(config.Rank.String()) - neighbors, err := searcher.SearchWithQuery(query, k) - if err != nil { - return err - } - - return search.Describe(neighbors) -} diff --git a/pkg/corpus/testing.go b/cmd/search/cmdutil/cmdutil.go similarity index 58% rename from pkg/corpus/testing.go rename to cmd/search/cmdutil/cmdutil.go index e2c516f..bb4e3c0 100644 --- a/pkg/corpus/testing.go +++ b/cmd/search/cmdutil/cmdutil.go @@ -12,21 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -package corpus +package cmdutil import ( - "bytes" - "io" - "io/ioutil" + "github.com/spf13/cobra" ) -type fakeNopSeeker struct{ io.ReadCloser } +const ( + defaultInputFile = "example/word_vectors.txt" + defaultRank = 10 +) -func (fake fakeNopSeeker) Seek(offset int64, whence int) (int64, error) { return 0, nil } +func AddInputFlags(cmd *cobra.Command, input *string) { + cmd.Flags().StringVarP(input, "input", "i", defaultInputFile, "input file path for trained word vector") +} -var ( - text = "a b b c c c c" - FakeSeeker = fakeNopSeeker{ - ReadCloser: ioutil.NopCloser(bytes.NewReader([]byte(text))), - } -) +func AddRankFlags(cmd *cobra.Command, rank *int) { + cmd.Flags().IntVarP(rank, "rank", "r", defaultRank, "how many similar words will be displayed") +} diff --git a/cmd/search/repl/repl.go b/cmd/search/repl/repl.go new file mode 100644 index 0000000..34940e9 --- /dev/null +++ b/cmd/search/repl/repl.go @@ -0,0 +1,68 @@ +// Copyright © 2019 Makoto Ito +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package repl + +import ( + "os" + + "github.com/pkg/errors" + "github.com/spf13/cobra" + + "github.com/ynqa/wego/cmd/search/cmdutil" + "github.com/ynqa/wego/pkg/search/repl" +) + +var ( + inputFile string + rank int +) + +func New() *cobra.Command { + cmd := &cobra.Command{ + Use: "search-repl", + Short: "Search similar words (REPL mode)", + Long: "Search similar words (REPL mode)", + Example: " wego search-repl -i example/word_vectors.txt\n" + + " >> apple + banana\n" + + " ...", + RunE: func(cmd *cobra.Command, args []string) error { + return execute() + }, + } + cmdutil.AddInputFlags(cmd, &inputFile) + cmdutil.AddRankFlags(cmd, &rank) + return cmd +} + +func fileExists(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +func execute() error { + if !fileExists(inputFile) { + return errors.Errorf("Not such a file %s", inputFile) + } + input, err := os.Open(inputFile) + if err != nil { + return err + } + defer input.Close() + repl, err := repl.New(input, rank) + if err != nil { + return err + } + return repl.Run() +} diff --git a/cmd/search/search.go b/cmd/search/search.go new file mode 100644 index 0000000..5d032bc --- /dev/null +++ b/cmd/search/search.go @@ -0,0 +1,73 @@ +// Copyright © 2017 Makoto Ito +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package search + +import ( + "os" + + "github.com/pkg/errors" + "github.com/spf13/cobra" + + "github.com/ynqa/wego/cmd/search/cmdutil" + "github.com/ynqa/wego/pkg/search" +) + +var ( + inputFile string + rank int +) + +func New() *cobra.Command { + cmd := &cobra.Command{ + Use: "search", + Short: "Search similar words", + Long: "Search similar words", + Example: " wego search -i example/word_vectors.txt microsoft", + RunE: func(cmd *cobra.Command, args []string) error { + return execute(args) + }, + } + cmdutil.AddInputFlags(cmd, &inputFile) + cmdutil.AddRankFlags(cmd, &rank) + return cmd +} + +func fileExists(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +func execute(args []string) error { + if !fileExists(inputFile) { + return errors.Errorf("Not such a file %s", inputFile) + } else if len(args) != 1 { + return errors.Errorf("Input a single word %v", args) + } + input, err := os.Open(inputFile) + if err != nil { + return err + } + defer input.Close() + searcher, err := search.NewForVectorFile(input) + if err != nil { + return err + } + neighbors, err := searcher.InternalSearch(args[0], rank) + if err != nil { + return err + } + neighbors.Describe() + return nil +} diff --git a/cmd/search_test.go b/cmd/search_test.go deleted file mode 100644 index 4f2672a..0000000 --- a/cmd/search_test.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cmd - -import ( - "testing" - - "github.com/spf13/viper" -) - -const searchFlagSize = 2 - -func TestSearchBind(t *testing.T) { - defer viper.Reset() - - searchBind(searchCmd) - - if len(viper.AllKeys()) != searchFlagSize { - t.Errorf("Expected searchBind maps %v keys: %v", - searchFlagSize, viper.AllKeys()) - } -} diff --git a/cmd/word2vec.go b/cmd/word2vec.go deleted file mode 100644 index c68c2bf..0000000 --- a/cmd/word2vec.go +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cmd - -import ( - "os" - "runtime/pprof" - - "github.com/pkg/errors" - "github.com/spf13/cobra" - "github.com/spf13/viper" - - "github.com/ynqa/wego/pkg/builder" - "github.com/ynqa/wego/pkg/config" - "github.com/ynqa/wego/pkg/validate" -) - -var word2vecCmd = &cobra.Command{ - Use: "word2vec", - Short: "Word2Vec: Continuous Bag-of-Words and Skip-gram model", - PreRun: func(cmd *cobra.Command, args []string) { - bindConfig(cmd) - bindWord2vec(cmd) - }, - RunE: func(cmd *cobra.Command, args []string) error { - if viper.GetBool(config.Prof.String()) { - f, err := os.Create("cpu.prof") - if err != nil { - os.Exit(1) - } - pprof.StartCPUProfile(f) - defer pprof.StopCPUProfile() - } - - return runWord2vec() - }, -} - -func init() { - word2vecCmd.Flags().AddFlagSet(configFlagSet()) - word2vecCmd.Flags().String(config.Model.String(), config.DefaultModel.String(), - "which model does it use? one of: cbow|skip-gram") - word2vecCmd.Flags().String(config.Optimizer.String(), config.DefaultOptimizer.String(), - "which optimizer does it use? one of: hs|ns") - word2vecCmd.Flags().Int(config.MaxDepth.String(), config.DefaultMaxDepth, - "times to track huffman tree, max-depth=0 means to track full path from root to word (for hierarchical softmax only)") - word2vecCmd.Flags().Int(config.NegativeSampleSize.String(), config.DefaultNegativeSampleSize, - "negative sample size(for negative sampling only)") - word2vecCmd.Flags().Float64(config.SubsampleThreshold.String(), config.DefaultSubsampleThreshold, - "threshold for subsampling") - word2vecCmd.Flags().Float64(config.Theta.String(), config.DefaultTheta, - "lower limit of learning rate (lr >= initlr * theta)") -} - -func bindWord2vec(cmd *cobra.Command) { - viper.BindPFlag(config.Model.String(), cmd.Flags().Lookup(config.Model.String())) - viper.BindPFlag(config.Optimizer.String(), cmd.Flags().Lookup(config.Optimizer.String())) - viper.BindPFlag(config.MaxDepth.String(), cmd.Flags().Lookup(config.MaxDepth.String())) - viper.BindPFlag(config.NegativeSampleSize.String(), cmd.Flags().Lookup(config.NegativeSampleSize.String())) - viper.BindPFlag(config.SubsampleThreshold.String(), cmd.Flags().Lookup(config.SubsampleThreshold.String())) - viper.BindPFlag(config.Theta.String(), cmd.Flags().Lookup(config.Theta.String())) -} - -func runWord2vec() error { - outputFileName := viper.GetString(config.OutputFile.String()) - if validate.FileExists(outputFileName) { - return errors.Errorf("%s is already existed", outputFileName) - } - w2v, err := builder.NewWord2vecBuilderFromViper() - if err != nil { - return err - } - mod, err := w2v.Build() - if err != nil { - return err - } - inputFile := viper.GetString(config.InputFile.String()) - if !validate.FileExists(inputFile) { - return errors.Errorf("Not such a file %s", inputFile) - } - input, err := os.Open(inputFile) - if err != nil { - return err - } - if err := mod.Train(input); err != nil { - return err - } - outputFile, err := os.Create(outputFileName) - if err != nil { - return err - } - defer func() { - outputFile.Close() - }() - return mod.Save(outputFile) -} diff --git a/cmd/word2vec_test.go b/cmd/word2vec_test.go deleted file mode 100644 index c0544ca..0000000 --- a/cmd/word2vec_test.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cmd - -import ( - "testing" - - "github.com/spf13/viper" -) - -const word2vecFlagSize = 6 - -func TestWord2vecBind(t *testing.T) { - defer viper.Reset() - - bindWord2vec(word2vecCmd) - - if len(viper.AllKeys()) != word2vecFlagSize { - t.Errorf("Expected word2vecBind maps %v keys: %v", - word2vecFlagSize, - viper.AllKeys()) - } -} - -func TestWord2vecCmdPreRun(t *testing.T) { - defer viper.Reset() - - var empty []string - word2vecCmd.PreRun(word2vecCmd, empty) - - if len(viper.AllKeys()) != word2vecFlagSize+configFlagSize { - t.Errorf("Expected PreRun of word2vecCmd maps %v keys: %v", - word2vecFlagSize+configFlagSize, viper.AllKeys()) - } -} diff --git a/examples/glove/glove.go b/examples/glove/glove.go deleted file mode 100644 index 1056c5e..0000000 --- a/examples/glove/glove.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright © 2019 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "os" - - "github.com/ynqa/wego/pkg/builder" - "github.com/ynqa/wego/pkg/model/glove" -) - -func main() { - b := builder.NewGloveBuilder() - - b.Dimension(10). - Window(5). - Solver(glove.SGD). - Verbose() - - m, err := b.Build() - if err != nil { - // Failed to build word2vec. - } - - input, _ := os.Open("text8") - - // Start to Train. - if err = m.Train(input); err != nil { - // Failed to train by word2vec. - } - - output, err := os.Create("example.txt") - if err != nil { - // Failed to create output file. - } - - defer func() { - output.Close() - }() - - m.Save(output) -} diff --git a/examples/lexvec/lexvec.go b/examples/lexvec/lexvec.go deleted file mode 100644 index 9ef78f4..0000000 --- a/examples/lexvec/lexvec.go +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright © 2019 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "os" - - "github.com/ynqa/wego/pkg/builder" - "github.com/ynqa/wego/pkg/corpus" -) - -func main() { - b := builder.NewLexvecBuilder() - - b.Dimension(10). - Window(5). - RelationType(corpus.PPMI). - NegativeSampleSize(5). - Verbose() - - m, err := b.Build() - if err != nil { - // Failed to build word2vec. - } - - input, _ := os.Open("text8") - - // Start to Train. - if err = m.Train(input); err != nil { - // Failed to train by word2vec. - } - - output, err := os.Create("example.txt") - if err != nil { - // Failed to create output file. - } - - defer func() { - output.Close() - }() - - m.Save(output) -} diff --git a/examples/word2vec/main.go b/examples/word2vec/main.go new file mode 100644 index 0000000..692e157 --- /dev/null +++ b/examples/word2vec/main.go @@ -0,0 +1,29 @@ +package main + +import ( + "os" + + "github.com/ynqa/wego/pkg/model/modelutil/save" + "github.com/ynqa/wego/pkg/model/word2vec" +) + +func main() { + model, err := word2vec.New( + word2vec.WithWindow(5), + word2vec.WithModel(word2vec.Cbow), + word2vec.WithOptimizer(word2vec.NegativeSampling), + word2vec.WithNegativeSampleSize(5), + word2vec.Verbose(), + ) + if err != nil { + // failed to create word2vec. + } + + input, _ := os.Open("text8") + if err = model.Train(input); err != nil { + // failed to train. + } + + // write word vector. + model.Save(os.Stdin, save.AggregatedVector) +} diff --git a/examples/word2vec/word2vec.go b/examples/word2vec/word2vec.go deleted file mode 100644 index 1c7599c..0000000 --- a/examples/word2vec/word2vec.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright © 2019 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "os" - - "github.com/ynqa/wego/pkg/builder" - "github.com/ynqa/wego/pkg/model/word2vec" -) - -func main() { - b := builder.NewWord2vecBuilder() - - b.Dimension(10). - Window(5). - Model(word2vec.CBOW). - Optimizer(word2vec.NEGATIVE_SAMPLING). - NegativeSampleSize(5). - Verbose() - - m, err := b.Build() - if err != nil { - // Failed to build word2vec. - } - - input, _ := os.Open("text8") - - // Start to Train. - if err = m.Train(input); err != nil { - // Failed to train by word2vec. - } - - output, err := os.Create("example.txt") - if err != nil { - // Failed to create output file. - } - - defer func() { - output.Close() - }() - - m.Save(output) -} diff --git a/go.mod b/go.mod index 6ee4284..80cbc9c 100644 --- a/go.mod +++ b/go.mod @@ -3,38 +3,16 @@ module github.com/ynqa/wego go 1.13 require ( - github.com/BurntSushi/toml v0.3.1 // indirect - github.com/awalterschulze/gographviz v0.0.0-20170410065617-c84395e536e1 // indirect - github.com/chewxy/hm v1.0.0 // indirect - github.com/chewxy/lingo v0.0.0-20180424035724-8f8059f54389 - github.com/chewxy/math32 v1.0.0 // indirect - github.com/fatih/color v1.7.0 // indirect - github.com/fsnotify/fsnotify v1.4.2 // indirect - github.com/hashicorp/hcl v0.0.0-20171017181929-23c074d0eceb // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect - github.com/leesper/go_rng v0.0.0-20171009123644-5344a9259b21 // indirect - github.com/magiconair/properties v1.7.4 // indirect - github.com/mattn/go-colorable v0.1.4 // indirect - github.com/mitchellh/mapstructure v0.0.0-20171017171808-06020f85339e // indirect + github.com/kr/pretty v0.1.0 // indirect github.com/olekukonko/tablewriter v0.0.0-20171203151007-65fec0d89a57 - github.com/pelletier/go-toml v1.0.1 // indirect github.com/peterh/liner v1.1.0 github.com/pkg/errors v0.8.0 - github.com/spf13/afero v1.0.0 // indirect - github.com/spf13/cast v1.1.0 // indirect github.com/spf13/cobra v0.0.1 - github.com/spf13/jwalterweatherman v0.0.0-20170901151539-12bd96e66386 // indirect - github.com/spf13/pflag v1.0.0 - github.com/spf13/viper v1.0.0 - github.com/stretchr/testify v1.4.0 // indirect - github.com/xtgo/set v0.0.0-20151204082305-4431f6b51265 // indirect - golang.org/x/exp v0.0.0-20191224044220-1fea468a75e9 // indirect - golang.org/x/text v0.3.1-0.20171227012246-e19ae1496984 // indirect - gonum.org/v1/gonum v0.0.0-20171227085449-cd47c93d5448 // indirect - gopkg.in/cheggaaa/pb.v1 v1.0.19 - gorgonia.org/cu v0.8.0 // indirect - gorgonia.org/gorgonia v0.8.0 // indirect - gorgonia.org/tensor v0.8.0 // indirect - gorgonia.org/vecf32 v0.7.1-0.20171210012140-9c61466a81d9 // indirect - gorgonia.org/vecf64 v0.7.1-0.20171210012113-92cacded62a7 // indirect + github.com/spf13/pflag v1.0.3 // indirect + github.com/stretchr/testify v1.4.0 + golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e + gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect + gopkg.in/yaml.v2 v2.2.4 // indirect ) diff --git a/go.sum b/go.sum index cf7836d..78dfae3 100644 --- a/go.sum +++ b/go.sum @@ -1,103 +1,38 @@ -dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= -github.com/awalterschulze/gographviz v0.0.0-20170410065617-c84395e536e1 h1:r2lcIqPAm8+z4sEiWTJW3JR3/tc9WWH95hZFXLd2Y0g= -github.com/awalterschulze/gographviz v0.0.0-20170410065617-c84395e536e1/go.mod h1:GEV5wmg4YquNw7v1kkyoX9etIk8yVmXj+AkDHuuETHs= -github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k= -github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0= -github.com/chewxy/lingo v0.0.0-20180424035724-8f8059f54389 h1:PJP6KdmGiApIZNsxibkzfljG9mNnVmeR/2G2arHVbeo= -github.com/chewxy/lingo v0.0.0-20180424035724-8f8059f54389/go.mod h1:q5CiNpjdywd0YKLC4xgjZEWe02MRmIfpmpt4aZRf9jg= -github.com/chewxy/math32 v1.0.0 h1:RTt2SACA7BTzvbsAKVQJLZpV6zY2MZw4bW9L2HEKkHg= -github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys= -github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= -github.com/fsnotify/fsnotify v1.4.2 h1:v5tKwtf2hNhBV24eNYfQ5UmvFOGlOCmRqk7/P1olxtk= -github.com/fsnotify/fsnotify v1.4.2/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/hashicorp/hcl v0.0.0-20171017181929-23c074d0eceb h1:1OvvPvZkn/yCQ3xBcM8y4020wdkMXPHLB4+NfoGWh4U= -github.com/hashicorp/hcl v0.0.0-20171017181929-23c074d0eceb/go.mod h1:oZtUIOe8dh44I2q6ScRibXws4Ajl+d+nod3AaR9vL5w= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= -github.com/leesper/go_rng v0.0.0-20171009123644-5344a9259b21 h1:O75p5GUdUfhJqNCMM1ntthjtJCOHVa1lzMSfh5Qsa0Y= -github.com/leesper/go_rng v0.0.0-20171009123644-5344a9259b21/go.mod h1:N0SVk0uhy+E1PZ3C9ctsPRlvOPAFPkCNlcPBDkt0N3U= -github.com/magiconair/properties v1.7.4 h1:UVo0TkHGd4lQSN1dVDzs9URCIgReuSIcCXpAVB9nZ80= -github.com/magiconair/properties v1.7.4/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= -github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA= -github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= -github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= -github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/mattn/go-runewidth v0.0.3 h1:a+kO+98RDGEfo6asOGMmpodZq4FNtnGP54yps8BzLR4= github.com/mattn/go-runewidth v0.0.3/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= -github.com/mitchellh/mapstructure v0.0.0-20171017171808-06020f85339e h1:PtGHLB3CX3TFPcksODQMxncoeQKWwCgTg0bJ40VLJP4= -github.com/mitchellh/mapstructure v0.0.0-20171017171808-06020f85339e/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/olekukonko/tablewriter v0.0.0-20171203151007-65fec0d89a57 h1:c6g+iEoim6VD2DGy2utQoryQMVNndSvYm/YfGjc5A/o= github.com/olekukonko/tablewriter v0.0.0-20171203151007-65fec0d89a57/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= -github.com/pelletier/go-toml v1.0.1 h1:0nx4vKBl23+hEaCOV1mFhKS9vhhBtFYWC7rQY0vJAyE= -github.com/pelletier/go-toml v1.0.1/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/peterh/liner v1.1.0 h1:f+aAedNJA6uk7+6rXsYBnhdo4Xux7ESLe+kcuVUF5os= github.com/peterh/liner v1.1.0/go.mod h1:CRroGNssyjTd/qIG2FyxByd2S8JEAZXBl4qUrZf8GS0= github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/spf13/afero v1.0.0 h1:Z005C09nPzwTTsDRJCQBVnpTU0bjTr/NhyWLj1nSPP4= -github.com/spf13/afero v1.0.0/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= -github.com/spf13/cast v1.1.0 h1:0Rhw4d6C8J9VPu6cjZLIhZ8+aAOHcDvGeKn+cq5Aq3k= -github.com/spf13/cast v1.1.0/go.mod h1:r2rcYCSwa1IExKTDiTfzaxqT2FNHs8hODu4LnUfgKEg= github.com/spf13/cobra v0.0.1 h1:zZh3X5aZbdnoj+4XkaBxKfhO4ot82icYdhhREIAXIj8= github.com/spf13/cobra v0.0.1/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= -github.com/spf13/jwalterweatherman v0.0.0-20170901151539-12bd96e66386 h1:zBoLErXXAvWnNsu+pWkRYl6Cx1KXmIfAVsIuYkPN6aY= -github.com/spf13/jwalterweatherman v0.0.0-20170901151539-12bd96e66386/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= -github.com/spf13/pflag v1.0.0 h1:oaPbdDe/x0UncahuwiPxW1GYJyilRAdsPnq3e1yaPcI= -github.com/spf13/pflag v1.0.0/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= -github.com/spf13/viper v1.0.0 h1:RUA/ghS2i64rlnn4ydTfblY8Og8QzcPtCcHvgMn+w/I= -github.com/spf13/viper v1.0.0/go.mod h1:A8kyI5cUJhb8N+3pkfONlcEcZbueH6nhAm0Fq7SrnBM= +github.com/spf13/pflag v1.0.3 h1:zPAT6CGy6wXeQ7NtTnaTerfKOsV6V6F8agHXFiazDkg= +github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/xtgo/set v0.0.0-20151204082305-4431f6b51265 h1:86yslOGLdMhh3xJR1+1UleoyTbyzmBWAGDdw0qPg1HI= -github.com/xtgo/set v0.0.0-20151204082305-4431f6b51265/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20191224044220-1fea468a75e9 h1:HLuLY2KniBsHW28uXd1i2UZKjifeJUy//P/wTK6AJwI= -golang.org/x/exp v0.0.0-20191224044220-1fea468a75e9/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= -golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= -golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= -golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20171222143536-83801418e1b5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.1-0.20171227012246-e19ae1496984 h1:4S3Dic2vY09agWhKAjYa6buMB7HsLkVrliEHZclmmSU= -golang.org/x/text v0.3.1-0.20171227012246-e19ae1496984/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -gonum.org/v1/gonum v0.0.0-20171227085449-cd47c93d5448 h1:sTUWU+bhQAWFzZe659aEHnUcwkw6x/chMg4u1+SFSoo= -gonum.org/v1/gonum v0.0.0-20171227085449-cd47c93d5448/go.mod h1:cucAdkem48eM79EG1fdGOGASXorNZIYAO9duTse+1cI= -gonum.org/v1/netlib v0.0.0-20191031114514-eccb95939662 h1:yBPy8lLj+GituDSGQjvXBqT6yTch2BdT9Z/FbX19+to= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/cheggaaa/pb.v1 v1.0.19 h1:FiMbj8xLGIsj8TLj3O+0GkiydM2OLJhyerwuyNozYug= -gopkg.in/cheggaaa/pb.v1 v1.0.19/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= -gopkg.in/yaml.v2 v2.0.0-20171116090243-287cf08546ab h1:yZ6iByf7GKeJ3gsd1Dr/xaj1DyJ//wxKX1Cdh8LhoAw= -gopkg.in/yaml.v2 v2.0.0-20171116090243-287cf08546ab/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gorgonia.org/cu v0.8.0 h1:XpTkl5IpMlTPNJl6pKQPEXVV/9TnEtiRB7j1gGkrzCI= -gorgonia.org/cu v0.8.0/go.mod h1:RPEPIfaxxqUmeRe7T1T8a0NER+KxBI2McoLEXhP1Vd8= -gorgonia.org/gorgonia v0.8.0 h1:fCRGlTTr3KF0xjyJaWt46zs1EijQMD9iJyOHn2e6UCE= -gorgonia.org/gorgonia v0.8.0/go.mod h1:qucT7YHm/2OuSHWEw/6Je/LQ5htRJNQJ1+qpB58fY8c= -gorgonia.org/tensor v0.8.0 h1:+JDsIEnx+wQVf9brvrYyU+/uZit2PnzP9Dt4CWNeirQ= -gorgonia.org/tensor v0.8.0/go.mod h1:05Y4laKuVlj4qFoZIZW1q/9n1jZkgDBOLmKXZdBLG1w= -gorgonia.org/vecf32 v0.7.1-0.20171210012140-9c61466a81d9 h1:L4MiT1K2R1dTZEgy9Qo7cbPz3vnYlCXBplZvmdPi4Ko= -gorgonia.org/vecf32 v0.7.1-0.20171210012140-9c61466a81d9/go.mod h1:iHG+kvTMqGYA0SgahfO2k62WRnxmHsqAREGbayRDzy8= -gorgonia.org/vecf64 v0.7.1-0.20171210012113-92cacded62a7 h1:UX4nbW/6w1hRHsWVpYJZjqjlbSvql5ynxioh1z/KLTI= -gorgonia.org/vecf64 v0.7.1-0.20171210012113-92cacded62a7/go.mod h1:1y4pmcSd+wh3phG+InwWQjYrqwyrtN9h27WLFVQfV1Q= +gopkg.in/yaml.v2 v2.2.4 h1:/eiJrUcujPVeJ3xlSWaiNi3uSVmDGBK1pDHUHAnao1I= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/pkg/builder/glove.go b/pkg/builder/glove.go deleted file mode 100644 index 0340ff5..0000000 --- a/pkg/builder/glove.go +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package builder - -import ( - "github.com/pkg/errors" - "github.com/spf13/viper" - - "github.com/ynqa/wego/pkg/config" - "github.com/ynqa/wego/pkg/model" - "github.com/ynqa/wego/pkg/model/glove" -) - -// GloveBuilder manages the members to build Model interface. -type GloveBuilder struct { - // common configs. - dimension int - iteration int - minCount int - threadSize int - batchSize int - window int - initlr float64 - toLower bool - verbose bool - saveVectorType model.SaveVectorType - - // glove configs. - solver glove.SolverType - xmax int - alpha float64 -} - -// NewGloveBuilder creates *GloveBuilder -func NewGloveBuilder() *GloveBuilder { - return &GloveBuilder{ - dimension: config.DefaultDimension, - iteration: config.DefaultIteration, - minCount: config.DefaultMinCount, - threadSize: config.DefaultThreadSize, - batchSize: config.DefaultBatchSize, - window: config.DefaultWindow, - initlr: config.DefaultInitlr, - toLower: config.DefaultToLower, - verbose: config.DefaultVerbose, - saveVectorType: config.DefaultSaveVectorType, - - solver: config.DefaultSolver, - xmax: config.DefaultXmax, - alpha: config.DefaultAlpha, - } -} - -// NewGloveBuilderFromViper creates *GloveBuilder from viper. -func NewGloveBuilderFromViper() (*GloveBuilder, error) { - var saveVectorType model.SaveVectorType - saveVectorTypeStr := viper.GetString(config.SaveVectorType.String()) - switch saveVectorTypeStr { - case model.NORMAL.String(): - saveVectorType = model.NORMAL - case model.ADD.String(): - saveVectorType = model.ADD - default: - return nil, errors.Errorf("Invalid save vector type=%s", saveVectorTypeStr) - } - - var solver glove.SolverType - solverTypeStr := viper.GetString(config.Solver.String()) - switch solverTypeStr { - case glove.SGD.String(): - solver = glove.SGD - case glove.ADAGRAD.String(): - solver = glove.ADAGRAD - default: - return nil, errors.Errorf("Invalid solver type=%s", solverTypeStr) - } - return &GloveBuilder{ - dimension: viper.GetInt(config.Dimension.String()), - iteration: viper.GetInt(config.Iteration.String()), - minCount: viper.GetInt(config.MinCount.String()), - threadSize: viper.GetInt(config.ThreadSize.String()), - batchSize: viper.GetInt(config.BatchSize.String()), - window: viper.GetInt(config.Window.String()), - initlr: viper.GetFloat64(config.Initlr.String()), - toLower: viper.GetBool(config.ToLower.String()), - verbose: viper.GetBool(config.Verbose.String()), - saveVectorType: saveVectorType, - - solver: solver, - xmax: viper.GetInt(config.Xmax.String()), - alpha: viper.GetFloat64(config.Alpha.String()), - }, nil -} - -// Dimension sets dimension of word vector. -func (gb *GloveBuilder) Dimension(dimension int) *GloveBuilder { - gb.dimension = dimension - return gb -} - -// Iteration sets number of iteration. -func (gb *GloveBuilder) Iteration(iter int) *GloveBuilder { - gb.iteration = iter - return gb -} - -// MinCount sets min count. -func (gb *GloveBuilder) MinCount(minCount int) *GloveBuilder { - gb.minCount = minCount - return gb -} - -// ThreadSize sets number of goroutine. -func (gb *GloveBuilder) ThreadSize(threadSize int) *GloveBuilder { - gb.threadSize = threadSize - return gb -} - -// BatchSize sets batch size to to preprocess/train. -func (gb *GloveBuilder) BatchSize(batchSize int) *GloveBuilder { - gb.batchSize = batchSize - return gb -} - -// Window sets context window size. -func (gb *GloveBuilder) Window(window int) *GloveBuilder { - gb.window = window - return gb -} - -// Initlr sets initial learning rate. -func (gb *GloveBuilder) Initlr(initlr float64) *GloveBuilder { - gb.initlr = initlr - return gb -} - -// ToLower is whether converts the words in corpus to lowercase or not. -func (gb *GloveBuilder) ToLower() *GloveBuilder { - gb.toLower = true - return gb -} - -// Verbose sets verbose mode. -func (gb *GloveBuilder) Verbose() *GloveBuilder { - gb.verbose = true - return gb -} - -func (gb *GloveBuilder) SaveVectorType(typ model.SaveVectorType) *GloveBuilder { - gb.saveVectorType = typ - return gb -} - -// Solver sets solver. -func (gb *GloveBuilder) Solver(typ glove.SolverType) *GloveBuilder { - gb.solver = typ - return gb -} - -// Xmax sets x-max. -func (gb *GloveBuilder) Xmax(xmax int) *GloveBuilder { - gb.xmax = xmax - return gb -} - -// Alpha sets alpha. -func (gb *GloveBuilder) Alpha(alpha float64) *GloveBuilder { - gb.alpha = alpha - return gb -} - -// Build creates model.Model interface. -func (gb *GloveBuilder) Build() (model.Model, error) { - o := &model.Option{ - Dimension: gb.dimension, - Iteration: gb.iteration, - MinCount: gb.minCount, - ThreadSize: gb.threadSize, - BatchSize: gb.batchSize, - Window: gb.window, - Initlr: gb.initlr, - ToLower: gb.toLower, - Verbose: gb.verbose, - SaveVectorType: gb.saveVectorType, - } - - var solver glove.Solver - switch gb.solver { - case glove.SGD: - solver = glove.NewSgd(gb.dimension, gb.initlr) - case glove.ADAGRAD: - solver = glove.NewAdaGrad(gb.dimension, gb.initlr) - default: - return nil, errors.Errorf("Invalid solver: %s not in sgd|adagrad", gb.solver) - } - - g := &glove.GloveOption{ - Solver: solver, - Xmax: gb.xmax, - Alpha: gb.alpha, - } - - return glove.NewGlove(o, g), nil -} diff --git a/pkg/builder/glove_test.go b/pkg/builder/glove_test.go deleted file mode 100644 index 0fb114a..0000000 --- a/pkg/builder/glove_test.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package builder - -import ( - "testing" - - "github.com/ynqa/wego/pkg/model/glove" -) - -func TestGloveDimension(t *testing.T) { - b := &GloveBuilder{} - - expectedDimension := 100 - b.Dimension(expectedDimension) - - if b.dimension != expectedDimension { - t.Errorf("Expected builder.dimension=%v: %v", expectedDimension, b.dimension) - } -} - -func TestGloveIteration(t *testing.T) { - b := &GloveBuilder{} - - expectedIteration := 50 - b.Iteration(expectedIteration) - - if b.iteration != expectedIteration { - t.Errorf("Expected builder.iteration=%v: %v", expectedIteration, b.iteration) - } -} - -func TestGloveMinCount(t *testing.T) { - b := &GloveBuilder{} - - expectedMinCount := 10 - b.MinCount(expectedMinCount) - - if b.minCount != expectedMinCount { - t.Errorf("Expected builder.minCount=%v: %v", expectedMinCount, b.minCount) - } -} - -func TestGloveThreadSize(t *testing.T) { - b := &GloveBuilder{} - - expectedThreadSize := 8 - b.ThreadSize(expectedThreadSize) - - if b.threadSize != expectedThreadSize { - t.Errorf("Expected builder.threadSize=%v: %v", expectedThreadSize, b.threadSize) - } -} - -func TestGloveWindow(t *testing.T) { - b := &GloveBuilder{} - - expectedWindow := 10 - b.Window(expectedWindow) - - if b.window != expectedWindow { - t.Errorf("Expected builder.window=%v: %v", expectedWindow, b.window) - } -} - -func TestGloveInitlr(t *testing.T) { - b := &GloveBuilder{} - - expectedInitlr := 0.001 - b.Initlr(expectedInitlr) - - if b.initlr != expectedInitlr { - t.Errorf("Expected builder.initlr=%v: %v", expectedInitlr, b.initlr) - } -} - -func TestGloveToLower(t *testing.T) { - b := &GloveBuilder{} - - b.ToLower() - - if !b.toLower { - t.Errorf("Expected builder.lower=true: %v", b.toLower) - } -} - -func TestGloveVerbose(t *testing.T) { - b := &GloveBuilder{} - - b.Verbose() - - if !b.verbose { - t.Errorf("Expected builder.verbose=true: %v", b.verbose) - } -} - -func TestGloveSolver(t *testing.T) { - b := &GloveBuilder{} - - expectedSolver := glove.ADAGRAD - b.Solver(expectedSolver) - - if b.solver != expectedSolver { - t.Errorf("Expected builder.solver=%v: %v", expectedSolver, b.solver) - } -} - -func TestGloveXmax(t *testing.T) { - b := &GloveBuilder{} - - expectedXmax := 10 - b.Xmax(expectedXmax) - - if b.xmax != expectedXmax { - t.Errorf("Expected builder.xmax=%v: %v", expectedXmax, b.xmax) - } -} - -func TestGloveAlpha(t *testing.T) { - b := &GloveBuilder{} - - exoectedAlpha := 0.1 - b.Alpha(exoectedAlpha) - - if b.alpha != exoectedAlpha { - t.Errorf("Expected builder.alpha=%v: %v", exoectedAlpha, b.alpha) - } -} - -func TestGloveInvalidSolverBuild(t *testing.T) { - b := &GloveBuilder{} - - b.Solver(glove.SolverType(10)) - - if _, err := b.Build(); err == nil { - t.Errorf("Expected to fail building with invalid solver except for sgd|adagrad: %v", b.solver) - } -} diff --git a/pkg/builder/lexvec.go b/pkg/builder/lexvec.go deleted file mode 100644 index d48c9c8..0000000 --- a/pkg/builder/lexvec.go +++ /dev/null @@ -1,226 +0,0 @@ -// Copyright © 2019 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package builder - -import ( - "github.com/pkg/errors" - "github.com/spf13/viper" - - "github.com/ynqa/wego/pkg/config" - "github.com/ynqa/wego/pkg/corpus" - "github.com/ynqa/wego/pkg/model" - "github.com/ynqa/wego/pkg/model/lexvec" -) - -// LexvecBuilder manages the members to build Model interface. -type LexvecBuilder struct { - // common configs. - dimension int - iteration int - minCount int - threadSize int - batchSize int - window int - initlr float64 - toLower bool - verbose bool - saveVectorType model.SaveVectorType - - // lexvec configs. - negativeSampleSize int - subsampleThreshold float64 - theta float64 - smooth float64 - relationType corpus.RelationType -} - -// NewLexvecBuilder creates *LexvecBuilder. -func NewLexvecBuilder() *LexvecBuilder { - return &LexvecBuilder{ - dimension: config.DefaultDimension, - iteration: config.DefaultIteration, - minCount: config.DefaultMinCount, - threadSize: config.DefaultThreadSize, - batchSize: config.DefaultBatchSize, - window: config.DefaultWindow, - initlr: config.DefaultInitlr, - toLower: config.DefaultToLower, - verbose: config.DefaultVerbose, - saveVectorType: config.DefaultSaveVectorType, - - negativeSampleSize: config.DefaultNegativeSampleSize, - subsampleThreshold: config.DefaultSubsampleThreshold, - theta: config.DefaultTheta, - smooth: config.DefaultSmooth, - relationType: config.DefaultRelationType, - } -} - -// NewLexvecBuilderFromViper creates *LexvecBuilder from viper. -func NewLexvecBuilderFromViper() (*LexvecBuilder, error) { - var saveVectorType model.SaveVectorType - saveVectorTypeStr := viper.GetString(config.SaveVectorType.String()) - switch saveVectorTypeStr { - case model.NORMAL.String(): - saveVectorType = model.NORMAL - case model.ADD.String(): - saveVectorType = model.ADD - default: - return nil, errors.Errorf("Invalid save vector type=%s", saveVectorTypeStr) - } - - var relationType corpus.RelationType - relationTypeStr := viper.GetString(config.RelationType.String()) - switch relationTypeStr { - case corpus.PPMI.String(): - relationType = corpus.PPMI - case corpus.PMI.String(): - relationType = corpus.PMI - case corpus.CO.String(): - relationType = corpus.CO - case corpus.LOGCO.String(): - relationType = corpus.LOGCO - } - - return &LexvecBuilder{ - dimension: viper.GetInt(config.Dimension.String()), - iteration: viper.GetInt(config.Iteration.String()), - minCount: viper.GetInt(config.MinCount.String()), - threadSize: viper.GetInt(config.ThreadSize.String()), - batchSize: viper.GetInt(config.BatchSize.String()), - window: viper.GetInt(config.Window.String()), - initlr: viper.GetFloat64(config.Initlr.String()), - toLower: viper.GetBool(config.ToLower.String()), - verbose: viper.GetBool(config.Verbose.String()), - saveVectorType: saveVectorType, - - subsampleThreshold: viper.GetFloat64(config.SubsampleThreshold.String()), - negativeSampleSize: viper.GetInt(config.NegativeSampleSize.String()), - smooth: viper.GetFloat64(config.Smooth.String()), - relationType: relationType, - }, nil -} - -// Dimension sets dimension of word vector. -func (lb *LexvecBuilder) Dimension(dimension int) *LexvecBuilder { - lb.dimension = dimension - return lb -} - -// Iteration sets number of iteration. -func (lb *LexvecBuilder) Iteration(iter int) *LexvecBuilder { - lb.iteration = iter - return lb -} - -// MinCount sets min count. -func (lb *LexvecBuilder) MinCount(minCount int) *LexvecBuilder { - lb.minCount = minCount - return lb -} - -// ThreadSize sets number of goroutine. -func (lb *LexvecBuilder) ThreadSize(threadSize int) *LexvecBuilder { - lb.threadSize = threadSize - return lb -} - -// BatchSize sets batch size to preprocess/train. -func (lb *LexvecBuilder) BatchSize(batchSize int) *LexvecBuilder { - lb.batchSize = batchSize - return lb -} - -// Window sets context window size. -func (lb *LexvecBuilder) Window(window int) *LexvecBuilder { - lb.window = window - return lb -} - -// Initlr sets initial learning rate. -func (lb *LexvecBuilder) Initlr(initlr float64) *LexvecBuilder { - lb.initlr = initlr - return lb -} - -// ToLower is whether converts the words in corpus to lowercase or not. -func (lb *LexvecBuilder) ToLower() *LexvecBuilder { - lb.toLower = true - return lb -} - -// Verbose sets verbose mode. -func (lb *LexvecBuilder) Verbose() *LexvecBuilder { - lb.verbose = true - return lb -} - -func (lb *LexvecBuilder) SaveVectorType(typ model.SaveVectorType) *LexvecBuilder { - lb.saveVectorType = typ - return lb -} - -// NegativeSampleSize sets number of samples as negative. -func (lb *LexvecBuilder) NegativeSampleSize(size int) *LexvecBuilder { - lb.negativeSampleSize = size - return lb -} - -// SubSampleThreshold sets threshold for subsampling. -func (lb *LexvecBuilder) SubSampleThreshold(threshold float64) *LexvecBuilder { - lb.subsampleThreshold = threshold - return lb -} - -func (lb *LexvecBuilder) Theta(theta float64) *LexvecBuilder { - lb.theta = theta - return lb -} - -func (lb *LexvecBuilder) Smooth(smooth float64) *LexvecBuilder { - lb.smooth = smooth - return lb -} - -func (lb *LexvecBuilder) RelationType(typ corpus.RelationType) *LexvecBuilder { - lb.relationType = typ - return lb -} - -// Build creates Lexvec model. -func (lb *LexvecBuilder) Build() (model.Model, error) { - o := &model.Option{ - Dimension: lb.dimension, - Iteration: lb.iteration, - MinCount: lb.minCount, - ThreadSize: lb.threadSize, - BatchSize: lb.batchSize, - Window: lb.window, - Initlr: lb.initlr, - ToLower: lb.toLower, - Verbose: lb.verbose, - SaveVectorType: lb.saveVectorType, - } - - l := &lexvec.LexvecOption{ - NegativeSampleSize: lb.negativeSampleSize, - SubSampleThreshold: lb.subsampleThreshold, - Theta: lb.theta, - Smooth: lb.smooth, - RelationType: lb.relationType, - } - - return lexvec.NewLexvec(o, l), nil -} diff --git a/pkg/builder/word2vec.go b/pkg/builder/word2vec.go deleted file mode 100644 index a6b5fe6..0000000 --- a/pkg/builder/word2vec.go +++ /dev/null @@ -1,270 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package builder - -import ( - "github.com/pkg/errors" - "github.com/spf13/viper" - - "github.com/ynqa/wego/pkg/config" - "github.com/ynqa/wego/pkg/model" - "github.com/ynqa/wego/pkg/model/word2vec" -) - -// Word2vecBuilder manages the members to build Model interface. -type Word2vecBuilder struct { - // common configs. - dimension int - iteration int - minCount int - threadSize int - batchSize int - window int - initlr float64 - toLower bool - verbose bool - saveVectorType model.SaveVectorType - - // word2vec configs. - model word2vec.ModelType - optimizer word2vec.OptimizerType - maxDepth int - negativeSampleSize int - subsampleThreshold float64 - theta float64 -} - -// NewWord2vecBuilder creates *Word2vecBuilder. -func NewWord2vecBuilder() *Word2vecBuilder { - return &Word2vecBuilder{ - dimension: config.DefaultDimension, - iteration: config.DefaultIteration, - minCount: config.DefaultMinCount, - threadSize: config.DefaultThreadSize, - batchSize: config.DefaultBatchSize, - window: config.DefaultWindow, - initlr: config.DefaultInitlr, - toLower: config.DefaultToLower, - verbose: config.DefaultVerbose, - saveVectorType: config.DefaultSaveVectorType, - - model: config.DefaultModel, - optimizer: config.DefaultOptimizer, - maxDepth: config.DefaultMaxDepth, - negativeSampleSize: config.DefaultNegativeSampleSize, - subsampleThreshold: config.DefaultSubsampleThreshold, - theta: config.DefaultTheta, - } -} - -// NewWord2vecBuilderFromViper creates *Word2vecBuilder from viper. -func NewWord2vecBuilderFromViper() (*Word2vecBuilder, error) { - var saveVectorType model.SaveVectorType - saveVectorTypeStr := viper.GetString(config.SaveVectorType.String()) - switch saveVectorTypeStr { - case model.NORMAL.String(): - saveVectorType = model.NORMAL - case model.ADD.String(): - saveVectorType = model.ADD - default: - return nil, errors.Errorf("Invalid save vector type=%s", saveVectorTypeStr) - } - - var model word2vec.ModelType - modelTypeStr := viper.GetString(config.Model.String()) - switch modelTypeStr { - case word2vec.CBOW.String(): - model = word2vec.CBOW - case word2vec.SKIP_GRAM.String(): - model = word2vec.SKIP_GRAM - default: - return nil, errors.Errorf("Invalid model type=%s", modelTypeStr) - } - - var optimizer word2vec.OptimizerType - optimizerTypeStr := viper.GetString(config.Optimizer.String()) - switch optimizerTypeStr { - case word2vec.NEGATIVE_SAMPLING.String(): - optimizer = word2vec.NEGATIVE_SAMPLING - case word2vec.HIERARCHICAL_SOFTMAX.String(): - optimizer = word2vec.HIERARCHICAL_SOFTMAX - default: - return nil, errors.Errorf("Invalid optimizer type=%s", optimizerTypeStr) - } - - return &Word2vecBuilder{ - dimension: viper.GetInt(config.Dimension.String()), - iteration: viper.GetInt(config.Iteration.String()), - minCount: viper.GetInt(config.MinCount.String()), - threadSize: viper.GetInt(config.ThreadSize.String()), - batchSize: viper.GetInt(config.BatchSize.String()), - window: viper.GetInt(config.Window.String()), - initlr: viper.GetFloat64(config.Initlr.String()), - toLower: viper.GetBool(config.ToLower.String()), - verbose: viper.GetBool(config.Verbose.String()), - saveVectorType: saveVectorType, - - model: model, - optimizer: optimizer, - maxDepth: viper.GetInt(config.MaxDepth.String()), - negativeSampleSize: viper.GetInt(config.NegativeSampleSize.String()), - subsampleThreshold: viper.GetFloat64(config.SubsampleThreshold.String()), - theta: viper.GetFloat64(config.Theta.String()), - }, nil -} - -// Dimension sets dimension of word vector. -func (wb *Word2vecBuilder) Dimension(dimension int) *Word2vecBuilder { - wb.dimension = dimension - return wb -} - -// Iteration sets number of iteration. -func (wb *Word2vecBuilder) Iteration(iter int) *Word2vecBuilder { - wb.iteration = iter - return wb -} - -// MinCount sets min count. -func (wb *Word2vecBuilder) MinCount(minCount int) *Word2vecBuilder { - wb.minCount = minCount - return wb -} - -// ThreadSize sets number of goroutine. -func (wb *Word2vecBuilder) ThreadSize(threadSize int) *Word2vecBuilder { - wb.threadSize = threadSize - return wb -} - -// BatchSize sets batch size to to preprocess/train. -func (wb *Word2vecBuilder) BatchSize(batchSize int) *Word2vecBuilder { - wb.batchSize = batchSize - return wb -} - -// Window sets context window size. -func (wb *Word2vecBuilder) Window(window int) *Word2vecBuilder { - wb.window = window - return wb -} - -// Initlr sets initial learning rate. -func (wb *Word2vecBuilder) Initlr(initlr float64) *Word2vecBuilder { - wb.initlr = initlr - return wb -} - -// ToLower is whether converts the words in corpus to lowercase or not. -func (wb *Word2vecBuilder) ToLower() *Word2vecBuilder { - wb.toLower = true - return wb -} - -// Verbose sets verbose mode. -func (wb *Word2vecBuilder) Verbose() *Word2vecBuilder { - wb.verbose = true - return wb -} - -func (wb *Word2vecBuilder) SaveVectorType(typ model.SaveVectorType) *Word2vecBuilder { - wb.saveVectorType = typ - return wb -} - -// Model sets model of Word2vec. One of: cbow|skip-gram -func (wb *Word2vecBuilder) Model(typ word2vec.ModelType) *Word2vecBuilder { - wb.model = typ - return wb -} - -// Optimizer sets optimizer of Word2vec. One of: hs|ns -func (wb *Word2vecBuilder) Optimizer(typ word2vec.OptimizerType) *Word2vecBuilder { - wb.optimizer = typ - return wb -} - -// MaxDepth sets number of times to track huffman tree. -func (wb *Word2vecBuilder) MaxDepth(maxDepth int) *Word2vecBuilder { - wb.maxDepth = maxDepth - return wb -} - -// NegativeSampleSize sets number of samples as negative. -func (wb *Word2vecBuilder) NegativeSampleSize(size int) *Word2vecBuilder { - wb.negativeSampleSize = size - return wb -} - -// SubSampleThreshold sets threshold for subsampling. -func (wb *Word2vecBuilder) SubSampleThreshold(threshold float64) *Word2vecBuilder { - wb.subsampleThreshold = threshold - return wb -} - -// Theta sets lower limit of learning rate (lr >= initlr * theta). -func (wb *Word2vecBuilder) Theta(theta float64) *Word2vecBuilder { - wb.theta = theta - return wb -} - -// Build creates model.Model interface. -func (wb *Word2vecBuilder) Build() (model.Model, error) { - if wb.optimizer == word2vec.HIERARCHICAL_SOFTMAX && wb.saveVectorType == model.ADD { - return nil, errors.Errorf("Invalid pair of optimizer=%s and save vector type=%s", wb.optimizer, wb.saveVectorType) - } - - o := &model.Option{ - Dimension: wb.dimension, - Iteration: wb.iteration, - MinCount: wb.minCount, - ThreadSize: wb.threadSize, - BatchSize: wb.batchSize, - Window: wb.window, - Initlr: wb.initlr, - ToLower: wb.toLower, - Verbose: wb.verbose, - SaveVectorType: wb.saveVectorType, - } - - var opt word2vec.Optimizer - switch wb.optimizer { - case word2vec.HIERARCHICAL_SOFTMAX: - opt = word2vec.NewHierarchicalSoftmax(wb.maxDepth) - case word2vec.NEGATIVE_SAMPLING: - opt = word2vec.NewNegativeSampling(wb.negativeSampleSize) - default: - return nil, errors.Errorf("Invalid optimizer: %s not in hs|ns", wb.optimizer) - } - - var mod word2vec.Model - switch wb.model { - case word2vec.CBOW: - mod = word2vec.NewCbow(wb.dimension, wb.window, wb.threadSize) - case word2vec.SKIP_GRAM: - mod = word2vec.NewSkipGram(wb.dimension, wb.window, wb.threadSize) - default: - return nil, errors.Errorf("Invalid model: %s not in cbow|skip-gram", wb.model) - } - - w := &word2vec.Word2vecOption{ - Mod: mod, - Opt: opt, - SubsampleThreshold: wb.subsampleThreshold, - Theta: wb.theta, - } - - return word2vec.NewWord2vec(o, w), nil -} diff --git a/pkg/builder/word2vec_test.go b/pkg/builder/word2vec_test.go deleted file mode 100644 index 1efa3a3..0000000 --- a/pkg/builder/word2vec_test.go +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package builder - -import ( - "github.com/ynqa/wego/pkg/model/word2vec" - "testing" -) - -func TestWord2vecDimension(t *testing.T) { - b := &Word2vecBuilder{} - - expectedDimension := 100 - b.Dimension(expectedDimension) - - if b.dimension != expectedDimension { - t.Errorf("Expected builder.dimension=%v: %v", expectedDimension, b.dimension) - } -} - -func TestWord2vecIteration(t *testing.T) { - b := &Word2vecBuilder{} - - expectedIteration := 50 - b.Iteration(expectedIteration) - - if b.iteration != expectedIteration { - t.Errorf("Expected builder.iteration=%v: %v", expectedIteration, b.iteration) - } -} - -func TestWord2vecMinCount(t *testing.T) { - b := &Word2vecBuilder{} - - expectedMinCount := 10 - b.MinCount(expectedMinCount) - - if b.minCount != expectedMinCount { - t.Errorf("Expected builder.minCount=%v: %v", expectedMinCount, b.minCount) - } -} - -func TestWord2vecThreadSize(t *testing.T) { - b := &Word2vecBuilder{} - - expectedThreadSize := 8 - b.ThreadSize(expectedThreadSize) - - if b.threadSize != expectedThreadSize { - t.Errorf("Expected builder.threadSize=%v: %v", expectedThreadSize, b.threadSize) - } -} - -func TestWord2vecWindow(t *testing.T) { - b := &Word2vecBuilder{} - - expectedWindow := 10 - b.Window(expectedWindow) - - if b.window != expectedWindow { - t.Errorf("Expected builder.window=%v: %v", expectedWindow, b.window) - } -} - -func TestWord2vecInitlr(t *testing.T) { - b := &Word2vecBuilder{} - - expectedInitlr := 0.001 - b.Initlr(expectedInitlr) - - if b.initlr != expectedInitlr { - t.Errorf("Expected builder.initlr=%v: %v", expectedInitlr, b.initlr) - } -} - -func TestWord2vecToLower(t *testing.T) { - b := &Word2vecBuilder{} - - b.ToLower() - - if !b.toLower { - t.Errorf("Expected builder.lower=true: %v", b.toLower) - } -} - -func TestWord2vecVerbose(t *testing.T) { - b := &Word2vecBuilder{} - - b.Verbose() - - if !b.verbose { - t.Errorf("Expected builder.verbose=true: %v", b.verbose) - } -} - -func TestWord2vecModel(t *testing.T) { - b := &Word2vecBuilder{} - - expectedModel := word2vec.CBOW - b.Model(expectedModel) - - if b.model != expectedModel { - t.Errorf("Expected builder.model=%v: %v", expectedModel, b.model) - } -} - -func TestWord2vecOptimizer(t *testing.T) { - b := &Word2vecBuilder{} - - expectedOptimizer := word2vec.NEGATIVE_SAMPLING - b.Optimizer(expectedOptimizer) - - if b.optimizer != expectedOptimizer { - t.Errorf("Expected builder.optimizer=%v: %v", expectedOptimizer, b.optimizer) - } -} - -func TestWord2vecBatchSize(t *testing.T) { - b := &Word2vecBuilder{} - - expectedBatchSize := 2048 - b.BatchSize(expectedBatchSize) - - if b.batchSize != expectedBatchSize { - t.Errorf("Expected builder.batchSize=%v: %v", expectedBatchSize, b.batchSize) - } -} - -func TestWord2vecMaxDepth(t *testing.T) { - b := &Word2vecBuilder{} - - expectedMaxDepth := 40 - b.MaxDepth(expectedMaxDepth) - - if b.maxDepth != expectedMaxDepth { - t.Errorf("Expected builder.maxDepth=%v: %v", expectedMaxDepth, b.maxDepth) - } -} - -func TestWord2vecNegativeSampleSize(t *testing.T) { - b := &Word2vecBuilder{} - - expectedNegativeSampleSize := 20 - b.NegativeSampleSize(expectedNegativeSampleSize) - - if b.negativeSampleSize != expectedNegativeSampleSize { - t.Errorf("Expected builder.negativeSampleSize=%v: %v", expectedNegativeSampleSize, b.negativeSampleSize) - } -} - -func TestWord2vecSubSampleThreshold(t *testing.T) { - b := &Word2vecBuilder{} - - expectedSubSampleThreshold := 0.001 - b.SubSampleThreshold(expectedSubSampleThreshold) - - if b.subsampleThreshold != expectedSubSampleThreshold { - t.Errorf("Expected builder.subsampleThreshold=%v: %v", expectedSubSampleThreshold, b.subsampleThreshold) - } -} - -func TestWord2vecTheta(t *testing.T) { - b := &Word2vecBuilder{} - - expectedTheta := 1.0e-5 - b.Theta(expectedTheta) - - if b.theta != expectedTheta { - t.Errorf("Expected builder.theta=%v: %v", expectedTheta, b.theta) - } -} - -func TestWord2vecInvalidModelBuild(t *testing.T) { - b := &Word2vecBuilder{} - - b.Model(word2vec.ModelType(10)) - - if _, err := b.Build(); err == nil { - t.Errorf("Expected to fail building with invalid model except for skip-gram|cbow: %v", b.model) - } -} - -func TestWord2vecInvalidOptimizerBuild(t *testing.T) { - b := &Word2vecBuilder{} - - b.Optimizer(word2vec.OptimizerType(10)) - - if _, err := b.Build(); err == nil { - t.Errorf("Expected to fail building with invalid optimizer except for ns|hs: %v", b.optimizer) - } -} diff --git a/pkg/model/util_test.go b/pkg/clock/clock.go similarity index 73% rename from pkg/model/util_test.go rename to pkg/clock/clock.go index 889e538..37e7991 100644 --- a/pkg/model/util_test.go +++ b/pkg/clock/clock.go @@ -12,16 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -package model +package clock import ( - "testing" + "time" ) -func TestNextRandom(t *testing.T) { - // TODO: Fuzzy Test - r := NextRandom(5) - if !(0 <= r && r < 5) { - t.Errorf("Extected range between 0 < nextRandom(x) < 5: %v", r) +type Clock struct { + start, last time.Time +} + +func New() *Clock { + n := time.Now() + return &Clock{ + start: n, + last: n, } } + +func (c *Clock) AllElapsed() time.Duration { + return time.Now().Sub(c.start) +} diff --git a/pkg/co/co.go b/pkg/co/co.go deleted file mode 100644 index 1c125c5..0000000 --- a/pkg/co/co.go +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package co - -// The data structure for co-occurrence is referred from: -// https://blog.chewxy.com/2017/07/12/21-bits-english/ - -// EncodeBigram creates id between two words. -func EncodeBigram(l1, l2 uint64) uint64 { - if l1 < l2 { - return encode(l1, l2) - } else { - return encode(l2, l1) - } -} - -func encode(l1, l2 uint64) uint64 { - return l1 | (l2 << 32) -} - -// DecodeBigram reverts pair id to two word ids. -func DecodeBigram(pid uint64) (uint64, uint64) { - f := pid >> 32 - return pid - (f << 32), f -} diff --git a/pkg/config/config.go b/pkg/config/config.go deleted file mode 100644 index 0a4d5c9..0000000 --- a/pkg/config/config.go +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package config - -import ( - "runtime" - - "github.com/ynqa/wego/pkg/corpus" - "github.com/ynqa/wego/pkg/model" - "github.com/ynqa/wego/pkg/model/glove" - "github.com/ynqa/wego/pkg/model/word2vec" -) - -// Config is enum of the common config. -type Config int - -// The list of Config. -const ( - InputFile Config = iota - OutputFile - Dimension - Iteration - MinCount - ThreadSize - BatchSize - Window - Initlr - Prof - ToLower - Verbose - SaveVectorType - // Word2Vec - Model - Optimizer - MaxDepth - NegativeSampleSize - SubsampleThreshold - Theta - // GloVe - Solver - Xmax - Alpha - // Lexvec - RelationType - Smooth - // Search - Rank -) - -// The defaults of Config. -var ( - DefaultInputFile string = "example/input.txt" - DefaultOutputFile string = "example/word_vectors.txt" - DefaultDimension int = 10 - DefaultIteration int = 15 - DefaultMinCount int = 5 - DefaultThreadSize int = runtime.NumCPU() - DefaultBatchSize int = 10000 - DefaultWindow int = 5 - DefaultInitlr float64 = 0.025 - DefaultProf bool = false - DefaultToLower bool = false - DefaultVerbose bool = false - DefaultSaveVectorType model.SaveVectorType = model.NORMAL - // Word2Vec - DefaultModel word2vec.ModelType = word2vec.CBOW - DefaultOptimizer word2vec.OptimizerType = word2vec.NEGATIVE_SAMPLING - DefaultMaxDepth int = 0 - DefaultNegativeSampleSize int = 5 - DefaultSubsampleThreshold float64 = 1.0e-3 - DefaultTheta float64 = 1.0e-4 - // GloVe - DefaultSolver glove.SolverType = glove.SGD - DefaultXmax int = 100 - DefaultAlpha float64 = 0.75 - // Lexvex - DefaultRelationType corpus.RelationType = corpus.PPMI - DefaultSmooth float64 = 0.75 - // Search - DefaultRank int = 10 -) - -func (c Config) String() string { - switch c { - case InputFile: - return "inputFile" - case OutputFile: - return "outputFile" - case Dimension: - return "dimension" - case Iteration: - return "iter" - case MinCount: - return "min-count" - case ThreadSize: - return "thread" - case BatchSize: - return "batchSize" - case Window: - return "window" - case Initlr: - return "initlr" - case Prof: - return "prof" - case ToLower: - return "lower" - case Verbose: - return "verbose" - case SaveVectorType: - return "save-vec" - // Word2Vec - case Model: - return "model" - case Optimizer: - return "optimizer" - case MaxDepth: - return "maxDepth" - case NegativeSampleSize: - return "sample" - case SubsampleThreshold: - return "threshold" - case Theta: - return "theta" - // GloVe - case Solver: - return "solver" - case Xmax: - return "xmax" - case Alpha: - return "alpha" - // Lexvec - case RelationType: - return "rel" - case Smooth: - return "smooth" - // Search - case Rank: - return "rank" - default: - return "unknown" - } -} diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go deleted file mode 100644 index 145c92d..0000000 --- a/pkg/config/config_test.go +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package config - -import ( - "testing" -) - -func TestInvalidConfigString(t *testing.T) { - var Fake Config = Config(1024) - - if Fake.String() != "unknown" { - t.Errorf("Fake should be not registered in Config: %v", Fake.String()) - } -} - -func TestConfigString(t *testing.T) { - testCases := []struct { - input Config - expected string - }{ - { - input: InputFile, - expected: "inputFile", - }, - { - input: OutputFile, - expected: "outputFile", - }, - { - input: Dimension, - expected: "dimension", - }, - { - input: Iteration, - expected: "iter", - }, - { - input: MinCount, - expected: "min-count", - }, - { - input: ThreadSize, - expected: "thread", - }, - { - input: Window, - expected: "window", - }, - { - input: Initlr, - expected: "initlr", - }, - { - input: Prof, - expected: "prof", - }, - { - input: ToLower, - expected: "lower", - }, - { - input: Verbose, - expected: "verbose", - }, - { - input: SaveVectorType, - expected: "save-vec", - }, - // Word2ec - { - input: Model, - expected: "model", - }, - { - input: Optimizer, - expected: "optimizer", - }, - { - input: BatchSize, - expected: "batchSize", - }, - { - input: MaxDepth, - expected: "maxDepth", - }, - { - input: NegativeSampleSize, - expected: "sample", - }, - { - input: SubsampleThreshold, - expected: "threshold", - }, - { - input: Theta, - expected: "theta", - }, - // GloVe - { - input: Solver, - expected: "solver", - }, - { - input: Xmax, - expected: "xmax", - }, - { - input: Alpha, - expected: "alpha", - }, - // Lexvec - { - input: RelationType, - expected: "rel", - }, - // Search - { - input: Rank, - expected: "rank", - }, - } - - for _, testCase := range testCases { - actual := testCase.input.String() - if actual != testCase.expected { - t.Errorf("Config: %v with String() should be %v, but get %v", testCase.input, testCase.expected, actual) - } - } -} diff --git a/pkg/corpus/core.go b/pkg/corpus/core.go deleted file mode 100644 index 2325e16..0000000 --- a/pkg/corpus/core.go +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package corpus - -import ( - "bufio" - "fmt" - "io" - "strings" - - "github.com/chewxy/lingo/corpus" - "github.com/pkg/errors" - - "github.com/ynqa/wego/pkg/timer" -) - -type core struct { - *corpus.Corpus - // TODO: more efficient data structure, such as radix tree (trie). - Document []int -} - -func newCore() *core { - c, _ := corpus.Construct() - return &core{ - Corpus: c, - Document: make([]int, 0), - } -} - -func (c *core) Parse(f io.Reader, toLower bool, minCount int, batchSize int, verbose bool) error { - fullDoc := make([]int, 0) - scanner := bufio.NewScanner(f) - scanner.Split(bufio.ScanWords) - - var t *timer.Timer - if verbose { - t = timer.NewTimer() - } - var i int - for scanner.Scan() { - word := scanner.Text() - if toLower { - word = strings.ToLower(word) - } - // TODO: delete words less than minCount in Corpus. - c.Add(word) - wordID, _ := c.Id(word) - fullDoc = append(fullDoc, wordID) - if verbose && i%batchSize == 0 { - fmt.Printf("Read %d words %v\r", i, t.AllElapsed()) - } - i++ - } - if err := scanner.Err(); err != nil && err != io.EOF { - return errors.Wrap(err, "Unable to complete scanning") - } - if verbose { - fmt.Printf("Read %d words %v\r\n", i, t.AllElapsed()) - } - for _, d := range fullDoc { - if c.IDFreq(d) > minCount { - c.Document = append(c.Document, d) - } - } - if verbose { - fmt.Printf("Filter words less than minCount=%d > documentSize=%d\n", minCount, len(c.Document)) - } - return nil -} diff --git a/pkg/corpus/corpus.go b/pkg/corpus/corpus.go new file mode 100644 index 0000000..17f2022 --- /dev/null +++ b/pkg/corpus/corpus.go @@ -0,0 +1,106 @@ +package corpus + +import ( + "bufio" + "fmt" + "io" + "strings" + + "github.com/pkg/errors" + + "github.com/ynqa/wego/pkg/clock" + "github.com/ynqa/wego/pkg/corpus/dictionary" + "github.com/ynqa/wego/pkg/corpus/pairwise" + "github.com/ynqa/wego/pkg/verbose" +) + +type Corpus struct { + opts Options + dic *dictionary.Dictionary + pair *pairwise.Pairwise + maxLen int + + doc []int + + verbose *verbose.Verbose +} + +func New( + opts Options, + verbose *verbose.Verbose, +) *Corpus { + return &Corpus{ + opts: opts, + dic: dictionary.New(), + + doc: make([]int, 0), + + verbose: verbose, + } +} + +func (c *Corpus) Doc() []int { + return c.doc +} + +func (c *Corpus) Dictionary() *dictionary.Dictionary { + return c.dic +} + +func (c *Corpus) Pairwise() *pairwise.Pairwise { + return c.pair +} + +func (c *Corpus) Len() int { + return c.maxLen +} + +func (c *Corpus) Build(r io.Reader) error { + scanner := bufio.NewScanner(r) + scanner.Split(bufio.ScanWords) + + cnt, clk := 0, clock.New() + for scanner.Scan() { + word := scanner.Text() + if c.opts.ToLower { + word = strings.ToLower(word) + } + + c.dic.Add(word) + id, _ := c.dic.ID(word) + c.doc = append(c.doc, id) + c.maxLen++ + + c.verbose.Do(func() { + if cnt%100000 == 0 { + fmt.Printf("read %d words %v\r", cnt, clk.AllElapsed()) + } + cnt++ + }) + } + if err := scanner.Err(); err != nil && err != io.EOF { + return errors.Wrap(err, "failed to scan") + } + + c.verbose.Do(func() { + fmt.Printf("read %d words %v\r\n", cnt, clk.AllElapsed()) + }) + return nil +} + +func (c *Corpus) BuildWithPairwise(r io.Reader, opts pairwise.Options, window int) error { + c.pair = pairwise.New(opts) + + if err := c.Build(r); err != nil { + return err + } + + for i := 0; i < len(c.doc); i++ { + for j := i + 1; j < len(c.doc) && j <= i+window; j++ { + if err := c.pair.Add(c.doc[i], c.doc[j]); err != nil { + return err + } + } + } + return nil +} diff --git a/pkg/corpus/count_model.go b/pkg/corpus/count_model.go deleted file mode 100644 index b9cb42e..0000000 --- a/pkg/corpus/count_model.go +++ /dev/null @@ -1,227 +0,0 @@ -// Copyright © 2019 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package corpus - -import ( - "fmt" - "math" - "math/rand" - - "github.com/pkg/errors" - "gopkg.in/cheggaaa/pb.v1" - - "github.com/ynqa/wego/pkg/co" -) - -// CountModelCorpus stores corpus and co-occurrence values between words. -type CountModelCorpus struct { - *core -} - -// Pair stores co-occurrence information. -type Pair struct { - // L1 and L2 store index number for two co-occurrence words. - L1, L2 int - // F stores the measures of co-occurrence, such as PMI. - F float64 - // Coefficient stores a coefficient for weighted matrix factorization. - Coefficient float64 -} - -// PairMap stores co-occurrences. -type PairMap map[uint64]float64 - -// RelationType is a list of types for strength relations between co-occurrence words. -type RelationType int - -const ( - PPMI RelationType = iota - PMI - CO - LOGCO -) - -// String describes relation type name. -func (r RelationType) String() string { - switch r { - case PPMI: - return "ppmi" - case PMI: - return "pmi" - case CO: - return "co" - case LOGCO: - return "logco" - default: - return "unknown" - } -} - -func (c *CountModelCorpus) relationValue(typ RelationType, l1, l2 int, co, logTotalFreq, smooth float64) (float64, error) { - switch typ { - case PPMI: - if co == 0 { - return 0, nil - } - // TODO: avoid log for l1, l2 every time - ppmi := math.Log(co) - math.Log(float64(c.IDFreq(l1))) - math.Log(math.Pow(float64(c.IDFreq(l2)), smooth)) + logTotalFreq - if ppmi < 0 { - ppmi = 0 - } - return ppmi, nil - case PMI: - if co == 0 { - return 1, nil - } - pmi := math.Log(co) - math.Log(float64(c.IDFreq(l1))) - math.Log(math.Pow(float64(c.IDFreq(l2)), smooth)) + logTotalFreq - return pmi, nil - case CO: - return co, nil - case LOGCO: - return math.Log(co), nil - default: - return 0, errors.Errorf("Invalid measure type") - } -} - -// CountType is a list of types to count co-occurences. -type CountType int - -const ( - INCREMENT CountType = iota - // DISTANCE weights values for co-occurrence times. - DISTANCE -) - -func countValue(typ CountType, left, right int) (float64, error) { - switch typ { - case INCREMENT: - return 1., nil - case DISTANCE: - div := left - right - if div == 0 { - return 0, errors.Errorf("Divide by zero on counting co-occurrence") - } - return 1. / math.Abs(float64(div)), nil - default: - return 0, errors.Errorf("Invalid count type") - } -} - -// NewCountModelCorpus creates *CountModelCorpus. -func NewCountModelCorpus() *CountModelCorpus { - return &CountModelCorpus{ - core: newCore(), - } -} - -func (c *CountModelCorpus) cooccurrence(window int, typ CountType, verbose bool) (PairMap, error) { - documentSize := len(c.Document) - - var progress *pb.ProgressBar - if verbose { - fmt.Println("Scan corpus for cooccurrences") - progress = pb.New(documentSize).SetWidth(80) - defer progress.Finish() - progress.Start() - } - - cooccurrence := make(map[uint64]float64) - for i := 0; i < documentSize; i++ { - for j := i + 1; j <= i+window; j++ { - if j >= documentSize { - continue - } - f, err := countValue(typ, i, j) - if err != nil { - return nil, errors.Wrap(err, "Failed to count co-occurrence between words") - } - cooccurrence[co.EncodeBigram(uint64(c.Document[i]), uint64(c.Document[j]))] += f - } - if verbose { - progress.Increment() - } - } - return cooccurrence, nil -} - -func (c *CountModelCorpus) PairsIntoLexvec(window int, relationType RelationType, smooth float64, verbose bool) (PairMap, error) { - cooccurrence, err := c.cooccurrence(window, INCREMENT, verbose) - if err != nil { - return nil, errors.Wrap(err, "Failed to create Pairs for Lexvec") - } - cooccurrenceSize := len(cooccurrence) - - var progress *pb.ProgressBar - if verbose { - fmt.Println("Scan cooccurrences for pairs") - progress = pb.New(cooccurrenceSize).SetWidth(80) - defer progress.Finish() - progress.Start() - } - - logTotalFreq := math.Log(math.Pow(float64(c.TotalFreq()), smooth)) - for p, f := range cooccurrence { - ul1, ul2 := co.DecodeBigram(p) - v, err := c.relationValue(relationType, int(ul1), int(ul2), f, logTotalFreq, smooth) - if err != nil { - return nil, errors.Wrap(err, "Failed to calculate relation value") - } - cooccurrence[p] = v - if verbose { - progress.Increment() - } - } - return cooccurrence, nil -} - -func (c *CountModelCorpus) PairsIntoGlove(window int, xmax int, alpha float64, verbose bool) ([]Pair, error) { - cooccurrence, err := c.cooccurrence(window, DISTANCE, verbose) - if err != nil { - return nil, errors.Wrap(err, "Failed to create Pairs for GloVe") - } - pairSize := len(cooccurrence) - pairs := make([]Pair, pairSize) - shuffle := rand.Perm(pairSize) - - var progress *pb.ProgressBar - if verbose { - fmt.Println("Scan cooccurrences for pairs") - progress = pb.New(pairSize).SetWidth(80) - defer progress.Finish() - progress.Start() - } - - var i int - for p, f := range cooccurrence { - coefficient := 1.0 - if f < float64(xmax) { - coefficient = math.Pow(f/float64(xmax), alpha) - } - - ul1, ul2 := co.DecodeBigram(p) - pairs[shuffle[i]] = Pair{ - L1: int(ul1), - L2: int(ul2), - F: math.Log(f), - Coefficient: coefficient, - } - i++ - if verbose { - progress.Increment() - } - } - return pairs, nil -} diff --git a/pkg/corpus/dictionary/dictionary.go b/pkg/corpus/dictionary/dictionary.go new file mode 100644 index 0000000..c3ebec9 --- /dev/null +++ b/pkg/corpus/dictionary/dictionary.go @@ -0,0 +1,67 @@ +package dictionary + +// inspired by +// - https://github.com/chewxy/lingo/blob/master/corpus/corpus.go +// - https://github.com/RaRe-Technologies/gensim/blob/3.8.1/gensim/corpora/dictionary.py + +type Dictionary struct { + word2id map[string]int + id2word []string + + cfs []int + + maxid int +} + +func New() *Dictionary { + return &Dictionary{ + word2id: make(map[string]int), + id2word: make([]string, 0), + + cfs: make([]int, 0), + } +} + +func (d *Dictionary) Len() int { + return d.maxid +} + +func (d *Dictionary) ID(word string) (int, bool) { + id, ok := d.word2id[word] + return id, ok +} + +func (d *Dictionary) WordFreq(word string) int { + id, ok := d.word2id[word] + if !ok { + return 0 + } + return d.cfs[id] +} + +func (d *Dictionary) Word(id int) (string, bool) { + if id >= d.maxid { + return "", false + } + return d.id2word[id], true +} + +func (d *Dictionary) IDFreq(id int) int { + if id >= d.maxid { + return 0 + } + return d.cfs[id] +} + +func (d *Dictionary) Add(words ...string) { + for _, word := range words { + if id, ok := d.word2id[word]; ok { + d.cfs[id]++ + } else { + d.word2id[word] = d.maxid + d.id2word = append(d.id2word, word) + d.cfs = append(d.cfs, 1) + d.maxid++ + } + } +} diff --git a/pkg/corpus/dictionary/huffman.go b/pkg/corpus/dictionary/huffman.go new file mode 100644 index 0000000..37ce52d --- /dev/null +++ b/pkg/corpus/dictionary/huffman.go @@ -0,0 +1,43 @@ +package dictionary + +import ( + "sort" + + "github.com/ynqa/wego/pkg/corpus/dictionary/node" +) + +func (d *Dictionary) HuffnamTree(dim int) []*node.Node { + nodes := make([]*node.Node, d.maxid) + set := make([]*node.Node, d.maxid) + for i := 0; i < d.maxid; i++ { + n := &node.Node{ + Val: d.IDFreq(i), + } + nodes[i] = n + set[i] = n + } + + sort.SliceStable(nodes, func(i, j int) bool { + return nodes[i].Val < nodes[j].Val + }) + for len(nodes) > 1 { + left, right := nodes[0], nodes[1] + merged := &node.Node{ + Val: left.Val + right.Val, + Vector: make([]float64, dim), + } + left.Code, right.Code = 0, 1 + left.Parent, right.Parent = merged, merged + + nodes = nodes[2:] + idx := sort.Search(len(nodes), func(i int) bool { + return nodes[i].Val >= merged.Val + }) + + nodes = append(nodes, &node.Node{}) + copy(nodes[idx+1:], nodes[idx:]) + nodes[idx] = merged + } + + return set +} diff --git a/pkg/corpus/dictionary/node/node.go b/pkg/corpus/dictionary/node/node.go new file mode 100644 index 0000000..3d3c328 --- /dev/null +++ b/pkg/corpus/dictionary/node/node.go @@ -0,0 +1,29 @@ +package node + +type Node struct { + cache []*Node + Parent *Node + Val int + + Code int + Vector []float64 +} + +func (n *Node) GetPath(depth int) []*Node { + if n.cache == nil { + re := func(nodes []*Node) { + for i, j := 0, len(nodes)-1; i < j; i, j = i+1, j-1 { + nodes[i], nodes[j] = nodes[j], nodes[i] + } + } + n.cache = make([]*Node, 0) + for p := n; p != nil; p = p.Parent { + n.cache = append(n.cache, p) + } + re(n.cache) + } + if depth > len(n.cache) { + depth = len(n.cache) + } + return n.cache[:depth] +} diff --git a/pkg/corpus/options.go b/pkg/corpus/options.go new file mode 100644 index 0000000..18fbe05 --- /dev/null +++ b/pkg/corpus/options.go @@ -0,0 +1,23 @@ +package corpus + +import ( + "github.com/spf13/cobra" +) + +const ( + defaultToLower = false +) + +type Options struct { + ToLower bool +} + +func DefaultOptions() Options { + return Options{ + ToLower: defaultToLower, + } +} + +func LoadForCmd(cmd *cobra.Command, opts *Options) { + cmd.Flags().BoolVar(&opts.ToLower, "lower", defaultToLower, "whether the words on corpus convert to lowercase or not") +} diff --git a/pkg/corpus/pairwise/encode/encode.go b/pkg/corpus/pairwise/encode/encode.go new file mode 100644 index 0000000..32abd19 --- /dev/null +++ b/pkg/corpus/pairwise/encode/encode.go @@ -0,0 +1,23 @@ +package encode + +// data structure for co-occurrence mapping: +// - https://blog.chewxy.com/2017/07/12/21-bits-english/ + +// EncodeBigram creates id between two words. +func EncodeBigram(l1, l2 uint64) uint64 { + if l1 < l2 { + return encode(l1, l2) + } else { + return encode(l2, l1) + } +} + +func encode(l1, l2 uint64) uint64 { + return l1 | (l2 << 32) +} + +// DecodeBigram reverts pair id to two word ids. +func DecodeBigram(pid uint64) (uint64, uint64) { + f := pid >> 32 + return pid - (f << 32), f +} diff --git a/pkg/corpus/pairwise/options.go b/pkg/corpus/pairwise/options.go new file mode 100644 index 0000000..96a3b18 --- /dev/null +++ b/pkg/corpus/pairwise/options.go @@ -0,0 +1,54 @@ +package pairwise + +import ( + "fmt" + + "github.com/pkg/errors" + "github.com/spf13/cobra" +) + +func invalidCountTypeError(typ CountType) error { + return errors.Errorf("invalid relation type: %s not in %s|%s", typ, Increment, Distance) +} + +type CountType string + +const ( + Increment CountType = "inc" + Distance CountType = "dis" + defaultCountType = Increment +) + +func (t *CountType) String() string { + if *t == CountType("") { + *t = defaultCountType + } + return string(*t) +} + +func (t *CountType) Set(name string) error { + typ := CountType(name) + if typ == Increment || typ == Distance { + *t = typ + return nil + } + return invalidCountTypeError(typ) +} + +func (t *CountType) Type() string { + return t.String() +} + +type Options struct { + CountType CountType +} + +func DefaultOptions() Options { + return Options{ + CountType: defaultCountType, + } +} + +func LoadForCmd(cmd *cobra.Command, opts *Options) { + cmd.Flags().Var(&opts.CountType, "cnt", fmt.Sprintf("count type for co-occurrence words. One of %s|%s", Increment, Distance)) +} diff --git a/pkg/corpus/pairwise/pairwise.go b/pkg/corpus/pairwise/pairwise.go new file mode 100644 index 0000000..e9fa23a --- /dev/null +++ b/pkg/corpus/pairwise/pairwise.go @@ -0,0 +1,46 @@ +package pairwise + +import ( + "math" + + "github.com/pkg/errors" + + "github.com/ynqa/wego/pkg/corpus/pairwise/encode" +) + +type Pairwise struct { + opts Options + + colloc map[uint64]float64 +} + +func New(opts Options) *Pairwise { + return &Pairwise{ + opts: opts, + + colloc: make(map[uint64]float64), + } +} + +func (p *Pairwise) Colloc() map[uint64]float64 { + return p.colloc +} + +func (p *Pairwise) Add(left, right int) error { + enc := encode.EncodeBigram(uint64(left), uint64(right)) + var val float64 + switch p.opts.CountType { + case Increment: + val = 1 + case Distance: + div := left - right + if div == 0 { + return errors.Errorf("Divide by zero on counting co-occurrence") + } + val = 1. / math.Abs(float64(div)) + default: + return invalidCountTypeError(p.opts.CountType) + } + p.colloc[enc] += val + return nil +} diff --git a/pkg/corpus/word2vec.go b/pkg/corpus/word2vec.go deleted file mode 100644 index 453b12f..0000000 --- a/pkg/corpus/word2vec.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package corpus - -import ( - "github.com/ynqa/wego/pkg/node" -) - -// Word2vecCorpus stores corpus. -type Word2vecCorpus struct { - *core -} - -// NewWord2vecCorpus creates *Word2vecCorpus. -func NewWord2vecCorpus() *Word2vecCorpus { - return &Word2vecCorpus{ - core: newCore(), - } -} - -// HuffmanTree builds word nodes map. -func (wc *Word2vecCorpus) HuffmanTree(dimension int) (map[int]*node.Node, error) { - ns := make(node.Nodes, 0, wc.Size()) - nm := make(map[int]*node.Node) - for i := 0; i < wc.Size(); i++ { - n := new(node.Node) - n.Value = wc.IDFreq(i) - nm[i] = n - ns = append(ns, n) - } - if err := ns.Build(dimension); err != nil { - return nil, err - } - return nm, nil -} diff --git a/pkg/corpus/word2vec_test.go b/pkg/corpus/word2vec_test.go deleted file mode 100644 index 8c0d0c9..0000000 --- a/pkg/corpus/word2vec_test.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package corpus - -import ( - "bytes" - "strconv" - "testing" - - "github.com/ynqa/wego/pkg/node" -) - -func TestGetPath(t *testing.T) { - c := NewWord2vecCorpus() - c.Parse(FakeSeeker, true, 0, 0, false) - huffmanTree, err := c.HuffmanTree(5) - - if err != nil { - t.Errorf(err.Error()) - } - - if len(huffmanTree) != 3 { - t.Errorf("Expected len=3: %d", len(huffmanTree)) - } - - testCases := []struct { - word string - expected string - }{ - {"a", "00"}, - {"b", "01"}, - {"c", "1"}, - } - - for _, testCase := range testCases { - wordID, _ := c.Id(testCase.word) - actual := codes(huffmanTree[wordID].GetPath()) - if actual != testCase.expected { - t.Errorf("Expected codes: %v, but got %v in %v", - testCase.expected, actual, testCase.word) - } - } -} - -func codes(nodes node.Nodes) string { - c := bytes.NewBuffer(make([]byte, 0)) - for _, v := range nodes { - c.WriteString(strconv.Itoa(v.Code)) - } - return c.String()[1:] -} diff --git a/pkg/item/item.go b/pkg/item/item.go new file mode 100644 index 0000000..cd4fcb0 --- /dev/null +++ b/pkg/item/item.go @@ -0,0 +1,88 @@ +// Copyright © 2017 Makoto Ito +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package item + +import ( + "bufio" + "io" + "strconv" + "strings" + + "github.com/pkg/errors" +) + +type Item struct { + Word string + Dim int + Vector []float64 +} + +func (i Item) Validate() error { + if i.Word == "" { + return errors.New("Word is empty") + } else if i.Dim == 0 || len(i.Vector) == 0 { + return errors.Errorf("Dim of %s is zero", i.Word) + } else if i.Dim != len(i.Vector) { + return errors.Errorf("Dim and length of Vector must be same, Dim=%d, len(Vec)=%d", i.Dim, len(i.Vector)) + } + return nil +} + +type ItemOp func(Item) error + +func Parse(r io.Reader, op ItemOp) error { + s := bufio.NewScanner(r) + for s.Scan() { + line := s.Text() + if strings.HasPrefix(line, " ") { + continue + } + item, err := ParseLine(line) + if err != nil { + return err + } + if err := op(item); err != nil { + return err + } + } + if err := s.Err(); err != nil && err != io.EOF { + return errors.Wrapf(err, "failed to scan") + } + return nil +} + +func ParseLine(line string) (Item, error) { + slice := strings.Fields(line) + if len(slice) < 2 { + return Item{}, errors.New("Must be over 2 lenghth for word and vector elems") + } + word := slice[0] + elems := slice[1:] + dim := len(elems) + + vec := make([]float64, dim) + for k, elem := range elems { + val, err := strconv.ParseFloat(elem, 64) + if err != nil { + return Item{}, err + } + vec[k] = val + } + return Item{ + Word: word, + Dim: dim, + Vector: vec, + }, nil +} diff --git a/pkg/item/item_test.go b/pkg/item/item_test.go new file mode 100644 index 0000000..d3c9909 --- /dev/null +++ b/pkg/item/item_test.go @@ -0,0 +1,68 @@ +// Copyright © 2019 Makoto Ito +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package item + +import ( + "bytes" + "github.com/stretchr/testify/assert" + "io/ioutil" + "reflect" + "testing" +) + +func TestParse(t *testing.T) { + testNumVector := 4 + testVectorStr := `apple 1 1 1 1 1 +banana 1 1 1 1 1 +chocolate 0 0 0 0 0 +dragon -1 -1 -1 -1 -1` + + f := ioutil.NopCloser(bytes.NewReader([]byte(testVectorStr))) + defer f.Close() + + items := make([]Item, 0) + op := func(item Item) error { + items = append(items, item) + return nil + } + + assert.NoError(t, Parse(f, op)) + assert.Equal(t, testNumVector, len(items)) +} + +func TestParseLine(t *testing.T) { + testCases := []struct { + name string + line string + expected Item + }{ + { + name: "parse line into Item", + line: "apple 1 1 1 1 1", + expected: Item{ + Word: "apple", + Dim: 5, + Vector: []float64{1, 1, 1, 1, 1}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + item, _ := ParseLine(tc.line) + assert.Truef(t, reflect.DeepEqual(tc.expected, item), "Must be equal %v and %v", tc.expected, item) + }) + } +} diff --git a/pkg/model/glove/adagrad.go b/pkg/model/glove/adagrad.go deleted file mode 100644 index a68d7a0..0000000 --- a/pkg/model/glove/adagrad.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package glove - -import ( - "math" -) - -// AdaGrad behaviors as one of Glove solver. -type AdaGrad struct { - dimension int - initlr float64 - gradsq []float64 -} - -// NewAdaGrad creates *AdaGrad. -func NewAdaGrad(dimension int, initlr float64) *AdaGrad { - return &AdaGrad{ - dimension: dimension, - initlr: initlr, - } -} - -func (a *AdaGrad) initialize(vectorSize int) { - a.gradsq = make([]float64, vectorSize) - for i := 0; i < vectorSize; i++ { - a.gradsq[i] = 1. - } -} - -func (a *AdaGrad) trainOne(l1, l2 int, f, coefficient float64, vector []float64) float64 { - var diff, cost float64 - for i := 0; i < a.dimension; i++ { - diff += vector[l1+i] * vector[l2+i] - } - diff += vector[l1+a.dimension] + vector[l2+a.dimension] - f - fdiff := diff * coefficient - cost = 0.5 * fdiff * diff - fdiff *= a.initlr - for i := 0; i < a.dimension; i++ { - temp1 := fdiff * vector[l2+i] - temp2 := fdiff * vector[l1+i] - a.gradsq[l1+i] += temp1 * temp1 - a.gradsq[l2+i] += temp2 * temp2 - - temp1 /= math.Sqrt(a.gradsq[l1+i]) - temp2 /= math.Sqrt(a.gradsq[l2+i]) - vector[l1+i] -= temp1 - vector[l2+i] -= temp2 - } - vector[l1+a.dimension] -= fdiff / math.Sqrt(a.gradsq[l1+a.dimension]) - vector[l2+a.dimension] -= fdiff / math.Sqrt(a.gradsq[l2+a.dimension]) - fdiff *= fdiff - a.gradsq[l1+a.dimension] += fdiff - a.gradsq[l2+a.dimension] += fdiff - return cost -} diff --git a/pkg/model/glove/adagrad_test.go b/pkg/model/glove/adagrad_test.go deleted file mode 100644 index 29ed625..0000000 --- a/pkg/model/glove/adagrad_test.go +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package glove - -import ( - "testing" -) - -func TestNewAdaGrad(t *testing.T) { - expectDimension := 10 - expectInitlr := 0.01 - solver := NewAdaGrad(expectDimension, expectInitlr) - - if solver.gradsq != nil { - t.Error("AdaGrad: gradsq is initialized before calling initialize") - } - - if solver.dimension != expectDimension { - t.Errorf("AdaGrad: dimension=%v: %v", - expectDimension, solver.dimension) - } - - if solver.initlr != expectInitlr { - t.Errorf("AdaGrad: initLearningRate=%v: %v", - expectInitlr, solver.initlr) - } -} - -func TestAdaGradInit(t *testing.T) { - dimension := 10 - initlr := 0.01 - solver := NewAdaGrad(dimension, initlr) - - expectedVectorSize := 100 - solver.initialize(expectedVectorSize) - - if len(solver.gradsq) != expectedVectorSize { - t.Errorf("AdaGrad: after init, len(gradsq)=%v: %v", expectedVectorSize, len(solver.gradsq)) - } -} - -func TestAdaGradCallBack(t *testing.T) { - dimension := 10 - initlr := 0.01 - solver := NewAdaGrad(dimension, initlr) - - before := solver.initlr - after := solver.initlr - - if before != after { - t.Errorf("AdaGrad: without changing after callback: %v -> %v", - before, after) - } -} diff --git a/pkg/model/glove/glove.go b/pkg/model/glove/glove.go index 4baf512..5e4d8e6 100644 --- a/pkg/model/glove/glove.go +++ b/pkg/model/glove/glove.go @@ -1,200 +1,210 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - package glove import ( + "bufio" "bytes" + "context" "fmt" "io" "math/rand" "sync" - "github.com/pkg/errors" - "gopkg.in/cheggaaa/pb.v1" + "golang.org/x/sync/semaphore" + "github.com/ynqa/wego/pkg/clock" "github.com/ynqa/wego/pkg/corpus" + "github.com/ynqa/wego/pkg/corpus/pairwise" "github.com/ynqa/wego/pkg/model" + "github.com/ynqa/wego/pkg/model/modelutil" + "github.com/ynqa/wego/pkg/model/modelutil/matrix" + "github.com/ynqa/wego/pkg/model/modelutil/save" + "github.com/ynqa/wego/pkg/model/subsample" + "github.com/ynqa/wego/pkg/verbose" ) -type GloveOption struct { - Solver Solver - Xmax int - Alpha float64 -} +type glove struct { + opts Options -// Glove stores the configs for Glove models. -type Glove struct { - *model.Option - *GloveOption - *corpus.CountModelCorpus + corpus *corpus.Corpus - // word pairs. - pairs []corpus.Pair + param *matrix.Matrix + subsampler *subsample.Subsampler + solver solver - // words' vector. - vector []float64 + verbose *verbose.Verbose +} - // manage data range per thread. - indexPerThread []int +func New(opts ...ModelOption) (model.Model, error) { + options := Options{ + CorpusOptions: corpus.DefaultOptions(), + PairwiseOptions: pairwise.DefaultOptions(), + ModelOptions: model.DefaultOptions(), - // progress bar. - progress *pb.ProgressBar -} + Alpha: defaultAlpha, + SolverType: defaultSolverType, + Xmax: defaultXmax, + } -// NewGlove creates *Glove. -func NewGlove(option *model.Option, gloveOption *GloveOption) *Glove { - return &Glove{ - Option: option, - GloveOption: gloveOption, + for _, fn := range opts { + fn(&options) } + + return NewForOptions(options) } -func (g *Glove) initialize() (err error) { - // Build pairs based on co-occurrence. - g.pairs, err = g.CountModelCorpus.PairsIntoGlove(g.Window, g.Xmax, g.Alpha, g.Verbose) - if err != nil { - return errors.Wrapf(err, "Failed to initialize for GloVe") - } +func NewForOptions(opts Options) (model.Model, error) { + // TODO: validate Options + v := verbose.New(opts.ModelOptions.Verbose) + return &glove{ + opts: opts, - // Initialize word vector. - vectorSize := g.CountModelCorpus.Size() * (g.Dimension + 1) * 2 - g.vector = make([]float64, vectorSize) - for i := 0; i < vectorSize; i++ { - g.vector[i] = rand.Float64() / float64(g.Dimension) - } + corpus: corpus.New(opts.CorpusOptions, v), - // Initialize solver. - switch solver := g.Solver.(type) { - case *AdaGrad: - solver.initialize(vectorSize) - } - return nil + verbose: v, + }, nil } -// Train trains words' vector on corpus. -func (g *Glove) Train(f io.Reader) error { - c := corpus.NewCountModelCorpus() - if err := c.Parse(f, g.ToLower, g.MinCount, g.BatchSize, g.Verbose); err != nil { - return errors.Wrap(err, "Unable to generate *Glove") +func (g *glove) preTrain(r io.Reader) error { + if err := g.corpus.BuildWithPairwise( + r, + g.opts.PairwiseOptions, + g.opts.ModelOptions.Window, + ); err != nil { + return err } - g.CountModelCorpus = c - if err := g.initialize(); err != nil { - return errors.Wrap(err, "Failed to initialize") + + dic, dim := g.corpus.Dictionary(), g.opts.ModelOptions.Dim + + g.param = matrix.New( + dic.Len()*2, + (dim + 1), + func(vec []float64) { + for i := 0; i < dim+1; i++ { + vec[i] = rand.Float64() / float64(dim) + } + }, + ) + + g.subsampler = subsample.New(dic, g.opts.SubsampleThreshold) + + switch g.opts.SolverType { + case Stochastic: + g.solver = newStochastic(g.opts.ModelOptions) + case AdaGrad: + g.solver = newAdaGrad(dic, g.opts.ModelOptions) + default: + return invalidSolverTypeError(g.opts.SolverType) } - return g.train() + return nil } -func (g *Glove) train() error { - pairSize := len(g.pairs) - if pairSize <= 0 { - return errors.Errorf("No pairs for training") +func (g *glove) Train(r io.Reader) error { + if err := g.preTrain(r); err != nil { + return err } - g.indexPerThread = model.IndexPerThread(g.ThreadSize, pairSize) - - semaphore := make(chan struct{}, g.ThreadSize) - waitGroup := &sync.WaitGroup{} + items := g.preCalculateItems(g.corpus.Pairwise()) + itemSize := len(items) + indexPerThread := modelutil.IndexPerThread( + g.opts.ModelOptions.ThreadSize, + itemSize, + ) - for i := 1; i <= g.Iteration; i++ { - if g.Verbose { - fmt.Printf("Train %d-th:\n", i) - g.progress = pb.New(pairSize).SetWidth(80) - g.progress.Start() - } + for i := 0; i < g.opts.ModelOptions.Iter; i++ { + trained, clk := make(chan struct{}), clock.New() + go g.observe(trained, clk) - for j := 0; j < g.ThreadSize; j++ { - waitGroup.Add(1) - go g.trainPerThread(g.indexPerThread[j], g.indexPerThread[j+1], - semaphore, waitGroup) - } + sem := semaphore.NewWeighted(int64(g.opts.ModelOptions.ThreadSize)) + wg := &sync.WaitGroup{} - switch solver := g.Solver.(type) { - case *Sgd: - solver.postOneIter() + for i := 0; i < g.opts.ModelOptions.ThreadSize; i++ { + wg.Add(1) + s, e := indexPerThread[i], indexPerThread[i+1] + go g.trainPerThread(items[s:e], trained, sem, wg) } - waitGroup.Wait() - if g.Verbose { - g.progress.Finish() - } + wg.Wait() + close(trained) } return nil } -func (g *Glove) trainPerThread(beginIdx, endIdx int, - semaphore chan struct{}, waitGroup *sync.WaitGroup) { - +func (g *glove) trainPerThread( + items []item, + trained chan struct{}, + sem *semaphore.Weighted, + wg *sync.WaitGroup, +) error { defer func() { - <-semaphore - waitGroup.Done() + wg.Done() + sem.Release(1) }() - semaphore <- struct{}{} - for i := beginIdx; i < endIdx; i++ { - if g.Verbose { - g.progress.Increment() + if err := sem.Acquire(context.Background(), 1); err != nil { + return err + } + + dic := g.corpus.Dictionary() + for _, item := range items { + if g.subsampler.Trial(item.l1) && + g.subsampler.Trial(item.l2) && + dic.IDFreq(item.l1) > g.opts.ModelOptions.MinCount && + dic.IDFreq(item.l2) > g.opts.ModelOptions.MinCount { + g.solver.trainOne(item.l1, item.l2+dic.Len(), g.param, item.f, item.coef) + g.solver.trainOne(item.l1+dic.Len(), item.l2, g.param, item.f, item.coef) } - pair := g.pairs[i] - l1 := pair.L1 * (g.Dimension + 1) - l2 := (pair.L2 + g.CountModelCorpus.Size()) * (g.Dimension + 1) - g.Solver.trainOne(l1, l2, pair.F, pair.Coefficient, g.vector) - ll1 := (pair.L1 + g.CountModelCorpus.Size()) * (g.Dimension + 1) - ll2 := pair.L2 * (g.Dimension + 1) - g.Solver.trainOne(ll1, ll2, pair.F, pair.Coefficient, g.vector) + trained <- struct{}{} } + + return nil } -// Save saves the word vector to output writer. -func (g *Glove) Save(output io.Writer) error { - if output == nil { - return errors.New("Invalid output writer: must not be nil") +func (g *glove) observe(trained chan struct{}, clk *clock.Clock) { + var cnt int + for range trained { + g.verbose.Do(func() { + cnt++ + if cnt%g.opts.ModelOptions.BatchSize == 0 { + fmt.Printf("trained %d items %v\r", cnt, clk.AllElapsed()) + } + }) } + g.verbose.Do(func() { + fmt.Printf("trained %d items %v\r\n", cnt, clk.AllElapsed()) + }) +} - wordSize := g.CountModelCorpus.Size() - if g.Verbose { - fmt.Println("Save:") - g.progress = pb.New(wordSize).SetWidth(80) - defer g.progress.Finish() - g.progress.Start() - } +func (g *glove) Save(f io.Writer, typ save.VectorType) error { + writer := bufio.NewWriter(f) + defer writer.Flush() + + dic := g.corpus.Dictionary() var buf bytes.Buffer - for i := 0; i < wordSize; i++ { - word, _ := g.CountModelCorpus.Word(i) + clk := clock.New() + for i := 0; i < dic.Len(); i++ { + word, _ := dic.Word(i) fmt.Fprintf(&buf, "%v ", word) - for j := 0; j < g.Dimension; j++ { - l1 := i*(g.Dimension+1) + j + for j := 0; j < g.opts.ModelOptions.Dim; j++ { var v float64 - switch g.SaveVectorType { - case model.NORMAL: - v = g.vector[l1] - case model.ADD: - l2 := (i+wordSize)*(g.Dimension+1) + j - v = g.vector[l1] + g.vector[l2] + switch { + case typ == save.AggregatedVector: + v = g.param.Slice(i)[j] + g.param.Slice(i + dic.Len())[j] + case typ == save.SingleVector: + v = g.param.Slice(i)[j] default: - return errors.Errorf("Invalid save vector type=%s", g.SaveVectorType) + return save.InvalidVectorTypeError(typ) } - - fmt.Fprintf(&buf, "%v ", v) + fmt.Fprintf(&buf, "%f ", v) } fmt.Fprintln(&buf) - if g.Verbose { - g.progress.Increment() - } + g.verbose.Do(func() { + fmt.Printf("save %d words %v\r", i, clk.AllElapsed()) + }) } - - output.Write(buf.Bytes()) + writer.WriteString(fmt.Sprintf("%v", buf.String())) + g.verbose.Do(func() { + fmt.Printf("save %d words %v\r\n", dic.Len(), clk.AllElapsed()) + }) return nil } diff --git a/pkg/model/glove/item.go b/pkg/model/glove/item.go new file mode 100644 index 0000000..c00be3d --- /dev/null +++ b/pkg/model/glove/item.go @@ -0,0 +1,45 @@ +package glove + +import ( + "fmt" + "math" + + "github.com/ynqa/wego/pkg/clock" + "github.com/ynqa/wego/pkg/corpus/pairwise" + "github.com/ynqa/wego/pkg/corpus/pairwise/encode" +) + +type item struct { + l1, l2 int + f float64 + coef float64 +} + +func (g *glove) preCalculateItems(pairwise *pairwise.Pairwise) []item { + col := pairwise.Colloc() + res, idx, clk := make([]item, len(col)), 0, clock.New() + for enc, f := range col { + u1, u2 := encode.DecodeBigram(enc) + l1, l2 := int(u1), int(u2) + coef := 1. + if f < float64(g.opts.Xmax) { + coef = math.Pow(f/float64(g.opts.Xmax), g.opts.Alpha) + } + res[idx] = item{ + l1: l1, + l2: l2, + f: math.Log(f), + coef: coef, + } + idx++ + g.verbose.Do(func() { + if idx%100000 == 0 { + fmt.Printf("build %d items %v\r", idx, clk.AllElapsed()) + } + }) + } + g.verbose.Do(func() { + fmt.Printf("build %d items %v\r\n", idx, clk.AllElapsed()) + }) + return res +} diff --git a/pkg/model/glove/options.go b/pkg/model/glove/options.go new file mode 100644 index 0000000..b33c323 --- /dev/null +++ b/pkg/model/glove/options.go @@ -0,0 +1,170 @@ +// Copyright © 2017 Makoto Ito +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package glove + +import ( + "github.com/pkg/errors" + "github.com/spf13/cobra" + + "github.com/ynqa/wego/pkg/corpus" + "github.com/ynqa/wego/pkg/corpus/pairwise" + "github.com/ynqa/wego/pkg/model" +) + +func invalidSolverTypeError(typ SolverType) error { + return errors.Errorf("invalid solver: %s not in %s|%s", typ, Stochastic, AdaGrad) +} + +type SolverType string + +const ( + Stochastic SolverType = "sgd" + AdaGrad SolverType = "adagrad" + defaultSolverType = Stochastic +) + +func (t *SolverType) String() string { + if *t == SolverType("") { + *t = defaultSolverType + } + return string(*t) +} + +func (t *SolverType) Set(name string) error { + typ := SolverType(name) + if typ == Stochastic || typ == AdaGrad { + *t = typ + return nil + } + return invalidSolverTypeError(typ) +} + +func (t *SolverType) Type() string { + return t.String() +} + +const ( + defaultAlpha = 0.75 + defaultSubsampleThreshold = 1.0e-3 + defaultXmax = 100 +) + +type Options struct { + CorpusOptions corpus.Options + PairwiseOptions pairwise.Options + ModelOptions model.Options + + Alpha float64 + SolverType SolverType + SubsampleThreshold float64 + Xmax int +} + +func LoadForCmd(cmd *cobra.Command, opts *Options) { + cmd.Flags().Float64Var(&opts.Alpha, "alpha", defaultAlpha, "exponent of weighting function") + cmd.Flags().Var(&opts.SolverType, "solver", "solver for GloVe objective. One of: sgd|adagrad") + cmd.Flags().Float64Var(&opts.SubsampleThreshold, "threshold", defaultSubsampleThreshold, "threshold for subsampling") + cmd.Flags().IntVar(&opts.Xmax, "xmax", defaultXmax, "specifying cutoff in weighting function") +} + +type ModelOption func(*Options) + +// corpus options +func ToLower() ModelOption { + return ModelOption(func(opts *Options) { + opts.CorpusOptions.ToLower = true + }) +} + +// pairwise options +func WithCountType(typ pairwise.CountType) ModelOption { + return ModelOption(func(opts *Options) { + opts.PairwiseOptions.CountType = typ + }) +} + +// model options +func WithBatchSize(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.BatchSize = v + }) +} + +func WithDimension(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.Dim = v + }) +} + +func WithInitLearningRate(v float64) ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.Initlr = v + }) +} + +func WithIteration(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.Iter = v + }) +} + +func WithMinCount(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.MinCount = v + }) +} + +func WithThreadSize(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.ThreadSize = v + }) +} + +func WithWindow(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.Window = v + }) +} + +func Verbose() ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.Verbose = true + }) +} + +// for glove options +func WithAlpha(v float64) ModelOption { + return ModelOption(func(opts *Options) { + opts.Alpha = v + }) +} + +func WithSolver(typ SolverType) ModelOption { + return ModelOption(func(opts *Options) { + opts.SolverType = typ + }) +} + +func WithSubsampleThreshold(v float64) ModelOption { + return ModelOption(func(opts *Options) { + opts.SubsampleThreshold = v + }) +} + +func WithXmax(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.Xmax = v + }) +} diff --git a/pkg/model/glove/sgd.go b/pkg/model/glove/sgd.go deleted file mode 100644 index de98308..0000000 --- a/pkg/model/glove/sgd.go +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package glove - -// Sgd is stochastic gradient descent that behaviors as one of GloVe solver. -type Sgd struct { - dimension int - currentlr float64 - shrinkage float64 -} - -// NewSgd creates *Sgd. -func NewSgd(dimension int, initlr float64) *Sgd { - return &Sgd{ - dimension: dimension, - currentlr: initlr, - shrinkage: 0.9, - } -} - -func (s *Sgd) trainOne(l1, l2 int, f, coefficient float64, vector []float64) float64 { - var diff, cost float64 - for i := 0; i < s.dimension; i++ { - diff += vector[l1+i] * vector[l2+i] - } - diff += vector[l1+s.dimension] + vector[l2+s.dimension] - f - fdiff := diff * coefficient - cost = 0.5 * fdiff * diff - fdiff *= s.currentlr - for i := 0; i < s.dimension; i++ { - temp1 := fdiff * vector[l2+i] - temp2 := fdiff * vector[l1+i] - vector[l1+i] -= temp1 - vector[l2+i] -= temp2 - } - vector[l1+s.dimension] -= fdiff - vector[l2+s.dimension] -= fdiff - return cost -} - -func (s *Sgd) postOneIter() { - s.currentlr *= s.shrinkage -} diff --git a/pkg/model/glove/sgd_test.go b/pkg/model/glove/sgd_test.go deleted file mode 100644 index e02d063..0000000 --- a/pkg/model/glove/sgd_test.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package glove - -import ( - "testing" -) - -func TestNewSgd(t *testing.T) { - expectDimension := 10 - expectInitlr := 0.01 - solver := NewSgd(expectDimension, expectInitlr) - - if solver.dimension != expectDimension { - t.Errorf("Sgd: dimension=%v: %v", - expectDimension, solver.dimension) - } - - if solver.currentlr != expectInitlr { - t.Errorf("Sgd: currentlr=%v: %v", - expectInitlr, solver.currentlr) - } -} - -func TestSgdCallBack(t *testing.T) { - dimension := 10 - initlr := 0.01 - solver := NewSgd(dimension, initlr) - - before := solver.currentlr - solver.postOneIter() - after := solver.currentlr - - if before < after { - t.Errorf("Sgd: currentlr is smaller than after postOneIter: %v -> %v", - before, after) - } -} diff --git a/pkg/model/glove/solver.go b/pkg/model/glove/solver.go index 4281ab0..2f81202 100644 --- a/pkg/model/glove/solver.go +++ b/pkg/model/glove/solver.go @@ -1,38 +1,86 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - package glove -// Solver is the interface for training with GloVe. -type Solver interface { - trainOne(l1, l2 int, f, coefficient float64, vector []float64) (cost float64) +import ( + "math" + + "github.com/ynqa/wego/pkg/corpus/dictionary" + "github.com/ynqa/wego/pkg/model" + "github.com/ynqa/wego/pkg/model/modelutil/matrix" +) + +type solver interface { + trainOne(l1, l2 int, param *matrix.Matrix, f, coef float64) } -type SolverType int +type stochastic struct { + initlr float64 +} -const ( - SGD SolverType = iota - ADAGRAD -) +func newStochastic(opts model.Options) solver { + return &stochastic{ + initlr: opts.Initlr, + } +} + +func (sol *stochastic) trainOne(l1, l2 int, param *matrix.Matrix, f, coef float64) { + v1, v2 := param.Slice(l1), param.Slice(l2) + dim, diff := len(v1)-1, 0. + for i := 0; i < dim; i++ { + diff += v1[i] * v2[i] + } + diff += v1[dim] + v2[dim] - f + diff *= coef * sol.initlr + for i := 0; i < dim; i++ { + t1, t2 := diff*v2[i], diff*v1[i] + v1[i] -= t1 + v2[i] -= t2 + } + v1[dim] -= diff + v2[dim] -= diff +} + +type adaGrad struct { + initlr float64 + gradsq *matrix.Matrix +} -func (t SolverType) String() string { - switch t { - case SGD: - return "sgd" - case ADAGRAD: - return "adagrad" - default: - return "unknown" +func newAdaGrad(dic *dictionary.Dictionary, opts model.Options) solver { + dimAndBias := opts.Dim + 1 + return &adaGrad{ + initlr: opts.Initlr, + gradsq: matrix.New( + dic.Len()*2, + dimAndBias, + func(vec []float64) { + for i := 0; i < dimAndBias; i++ { + vec[i] = 1. + } + }, + ), + } +} + +func (sol *adaGrad) trainOne(l1, l2 int, param *matrix.Matrix, f, coef float64) { + v1, v2 := param.Slice(l1), param.Slice(l2) + g1, g2 := sol.gradsq.Slice(l1), sol.gradsq.Slice(l2) + dim, diff := len(v1)-1, 0. + for i := 0; i < dim; i++ { + diff += v1[i] * v2[i] + } + diff += v1[dim] + v2[dim] - f + diff *= coef * sol.initlr + for i := 0; i < dim; i++ { + t1, t2 := diff*v2[i], diff*v1[i] + g1[i] += t1 * t1 + g2[i] += t2 * t2 + t1 /= math.Sqrt(g1[i]) + t2 /= math.Sqrt(g2[i]) + v1[i] -= t1 + v2[i] -= t2 } + v1[dim] -= diff / math.Sqrt(g1[dim]) + v2[dim] -= diff / math.Sqrt(g2[dim]) + diff *= diff + g1[dim] += diff + g2[dim] += diff } diff --git a/pkg/model/lexvec/item.go b/pkg/model/lexvec/item.go new file mode 100644 index 0000000..2a1271c --- /dev/null +++ b/pkg/model/lexvec/item.go @@ -0,0 +1,72 @@ +package lexvec + +import ( + "fmt" + "math" + + "github.com/pkg/errors" + "github.com/ynqa/wego/pkg/clock" + "github.com/ynqa/wego/pkg/corpus/pairwise" + "github.com/ynqa/wego/pkg/corpus/pairwise/encode" +) + +func (l *lexvec) preCalculateItems(pairwise *pairwise.Pairwise) (map[uint64]float64, error) { + col := pairwise.Colloc() + res, idx, clk := make(map[uint64]float64), 0, clock.New() + logTotalFreq := math.Log(math.Pow(float64(l.corpus.Len()), l.opts.Smooth)) + for enc, f := range col { + u1, u2 := encode.DecodeBigram(enc) + l1, l2 := int(u1), int(u2) + v, err := l.calculateRelation( + l.opts.RelationType, + l1, l2, + f, logTotalFreq, + ) + if err != nil { + return nil, err + } + res[enc] = v + idx++ + l.verbose.Do(func() { + if idx%100000 == 0 { + fmt.Printf("build %d items %v\r", idx, clk.AllElapsed()) + } + }) + } + l.verbose.Do(func() { + fmt.Printf("build %d items %v\r\n", idx, clk.AllElapsed()) + }) + return res, nil +} + +func (l *lexvec) calculateRelation( + typ RelationType, + l1, l2 int, + co, logTotalFreq float64, +) (float64, error) { + dic := l.corpus.Dictionary() + switch typ { + case PPMI: + if co == 0 { + return 0, nil + } + // TODO: avoid log for l1, l2 every time + ppmi := math.Log(co) - math.Log(float64(dic.IDFreq(l1))) - math.Log(math.Pow(float64(dic.IDFreq(l2)), l.opts.Smooth)) + logTotalFreq + if ppmi < 0 { + ppmi = 0 + } + return ppmi, nil + case PMI: + if co == 0 { + return 1, nil + } + pmi := math.Log(co) - math.Log(float64(dic.IDFreq(l1))) - math.Log(math.Pow(float64(dic.IDFreq(l2)), l.opts.Smooth)) + logTotalFreq + return pmi, nil + case Collocation: + return co, nil + case LogCollocation: + return math.Log(co), nil + default: + return 0, errors.Errorf("invalid measure type") + } +} diff --git a/pkg/model/lexvec/lexvec.go b/pkg/model/lexvec/lexvec.go index fa318da..b76feb9 100644 --- a/pkg/model/lexvec/lexvec.go +++ b/pkg/model/lexvec/lexvec.go @@ -1,262 +1,248 @@ -// Copyright © 2019 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - package lexvec import ( + "bufio" "bytes" + "context" "fmt" "io" - "math" "math/rand" "sync" - "github.com/pkg/errors" - "gopkg.in/cheggaaa/pb.v1" + "golang.org/x/sync/semaphore" - "github.com/ynqa/wego/pkg/co" + "github.com/ynqa/wego/pkg/clock" "github.com/ynqa/wego/pkg/corpus" + "github.com/ynqa/wego/pkg/corpus/pairwise" + "github.com/ynqa/wego/pkg/corpus/pairwise/encode" "github.com/ynqa/wego/pkg/model" + "github.com/ynqa/wego/pkg/model/modelutil" + "github.com/ynqa/wego/pkg/model/modelutil/matrix" + "github.com/ynqa/wego/pkg/model/modelutil/save" + "github.com/ynqa/wego/pkg/model/subsample" + "github.com/ynqa/wego/pkg/verbose" ) -type LexvecOption struct { - NegativeSampleSize int - SubSampleThreshold float64 - Theta float64 - Smooth float64 - RelationType corpus.RelationType -} +type lexvec struct { + opts Options -// Lexvec stores the configs for Lexvec models. -type Lexvec struct { - *model.Option - *LexvecOption - *corpus.CountModelCorpus + corpus *corpus.Corpus - subSamples []float64 + param *matrix.Matrix + subsampler *subsample.Subsampler + currentlr float64 - // word pairs. - pairs corpus.PairMap + verbose *verbose.Verbose +} - // words' vector. - vector []float64 +func New(opts ...ModelOption) (model.Model, error) { + options := Options{ + CorpusOptions: corpus.DefaultOptions(), + ModelOptions: model.DefaultOptions(), - // manage learning rate. - currentlr float64 - trained chan struct{} - trainedWordCount int + NegativeSampleSize: defaultNegativeSampleSize, + RelationType: defaultRelationType, + Smooth: defaultSmooth, + SubsampleThreshold: defaultSubsampleThreshold, + Theta: defaultTheta, + } - // data size per thread. - indexPerThread []int + for _, fn := range opts { + fn(&options) + } - // progress bar. - progress *pb.ProgressBar + return NewForOptions(options) } -// NewLexvec create *Lexvec. -func NewLexvec(option *model.Option, lexvecOption *LexvecOption) *Lexvec { - return &Lexvec{ - Option: option, - LexvecOption: lexvecOption, - - currentlr: option.Initlr, - trained: make(chan struct{}), - } -} +func NewForOptions(opts Options) (model.Model, error) { + // TODO: validate Options + v := verbose.New(opts.ModelOptions.Verbose) + return &lexvec{ + opts: opts, -func (l *Lexvec) initialize() (err error) { - // Build pairs based on co-occurrence. - l.pairs, err = l.CountModelCorpus.PairsIntoLexvec(l.Window, l.RelationType, l.Smooth, l.Verbose) + corpus: corpus.New(opts.CorpusOptions, v), - // Store subsample before training. - l.subSamples = make([]float64, l.Corpus.Size()) - for i := 0; i < l.Corpus.Size(); i++ { - z := 1. - math.Sqrt(l.SubSampleThreshold/float64(l.IDFreq(i))) - if z < 0 { - z = 0 - } - l.subSamples[i] = z - } + currentlr: opts.ModelOptions.Initlr, - // Initialize word vector. - vectorSize := l.Corpus.Size() * l.Dimension * 2 - l.vector = make([]float64, vectorSize) - for i := 0; i < vectorSize; i++ { - l.vector[i] = (rand.Float64() - 0.5) / float64(l.Dimension) - } - return nil + verbose: v, + }, nil } -// Train trains words' vector on corpus. -func (l *Lexvec) Train(f io.Reader) error { - c := corpus.NewCountModelCorpus() - if err := c.Parse(f, l.ToLower, l.MinCount, l.BatchSize, l.Verbose); err != nil { - return errors.Wrap(err, "Failed to parse corpus") +func (l *lexvec) preTrain(r io.Reader) error { + if err := l.corpus.BuildWithPairwise( + r, + pairwise.Options{ + CountType: pairwise.Increment, + }, + l.opts.ModelOptions.Window, + ); err != nil { + return err } - l.CountModelCorpus = c - if err := l.initialize(); err != nil { - return errors.Wrap(err, "Failed to initialize") - } - return l.train() -} -func (l *Lexvec) train() error { - document := l.Document - documentSize := len(document) - if documentSize <= 0 { - return errors.New("No words for training") - } + dic, dim := l.corpus.Dictionary(), l.opts.ModelOptions.Dim - l.indexPerThread = model.IndexPerThread(l.ThreadSize, documentSize) + l.param = matrix.New( + dic.Len()*2, + dim, + func(vec []float64) { + for i := 0; i < dim; i++ { + vec[i] = (rand.Float64() - 0.5) / float64(dim) + } + }, + ) - for i := 1; i <= l.Iteration; i++ { - if l.Verbose { - fmt.Printf("Train %d-th:\n", i) - l.progress = pb.New(documentSize).SetWidth(80) - l.progress.Start() - } - go l.observeLearningRate(i) + l.subsampler = subsample.New(dic, l.opts.SubsampleThreshold) + return nil +} - semaphore := make(chan struct{}, l.ThreadSize) - waitGroup := &sync.WaitGroup{} +func (l *lexvec) Train(r io.Reader) error { + if err := l.preTrain(r); err != nil { + return err + } - for j := 0; j < l.ThreadSize; j++ { - waitGroup.Add(1) - go l.trainPerThread(document[l.indexPerThread[j]:l.indexPerThread[j+1]], semaphore, waitGroup) + items, err := l.preCalculateItems(l.corpus.Pairwise()) + if err != nil { + return err + } + doc := l.corpus.Doc() + indexPerThread := modelutil.IndexPerThread( + l.opts.ModelOptions.ThreadSize, + len(doc), + ) + + for i := 1; i <= l.opts.ModelOptions.Iter; i++ { + trained, clk := make(chan struct{}), clock.New() + go l.observe(trained, clk) + + sem := semaphore.NewWeighted(int64(l.opts.ModelOptions.ThreadSize)) + wg := &sync.WaitGroup{} + + for i := 0; i < l.opts.ModelOptions.ThreadSize; i++ { + wg.Add(1) + s, e := indexPerThread[i], indexPerThread[i+1] + go l.trainPerThread(doc[s:e], items, trained, sem, wg) } - waitGroup.Wait() - if l.Verbose { - l.progress.Finish() - } + wg.Wait() + close(trained) } return nil } -func (l *Lexvec) trainPerThread(document []int, semaphore chan struct{}, waitGroup *sync.WaitGroup) { +func (l *lexvec) trainPerThread( + doc []int, + items map[uint64]float64, + trained chan struct{}, + sem *semaphore.Weighted, + wg *sync.WaitGroup, +) error { defer func() { - waitGroup.Done() - <-semaphore + wg.Done() + sem.Release(1) }() - semaphore <- struct{}{} - for idx, wordID := range document { - if l.Verbose { - l.progress.Increment() - } + if err := sem.Acquire(context.Background(), 1); err != nil { + return err + } - bernoulliTrial := rand.Float64() - p := l.subSamples[wordID] - if p < bernoulliTrial { - continue + dic := l.corpus.Dictionary() + for pos, id := range doc { + if l.subsampler.Trial(id) && dic.IDFreq(id) > l.opts.ModelOptions.MinCount { + l.trainOne(doc, pos, items) } - l.scan(document, idx, l.vector, l.currentlr) - l.trained <- struct{}{} + trained <- struct{}{} } + + return nil } -func (l *Lexvec) scan(document []int, wordIndex int, wordVector []float64, lr float64) { - word := document[wordIndex] - l1 := word * l.Dimension - shrinkage := model.NextRandom(l.Window) - for a := shrinkage; a < l.Window*2+1-shrinkage; a++ { - if a == l.Window { +func (l *lexvec) trainOne(doc []int, pos int, items map[uint64]float64) { + dic := l.corpus.Dictionary() + del := modelutil.NextRandom(l.opts.ModelOptions.Window) + for a := del; a < l.opts.ModelOptions.Window*2+1-del; a++ { + if a == l.opts.ModelOptions.Window { continue } - c := wordIndex - l.Window + a - if c < 0 || c >= len(document) { + c := pos - l.opts.ModelOptions.Window + a + if c < 0 || c >= len(doc) { continue } - context := document[c] - l2 := context * l.Dimension - encoded := co.EncodeBigram(uint64(word), uint64(context)) - l.trainOne(l1, l2, l.pairs[encoded]) - for n := 0; n < l.NegativeSampleSize; n++ { - sample := model.NextRandom(l.CountModelCorpus.Size()) - encoded := co.EncodeBigram(uint64(word), uint64(sample)) - l2 := (sample + l.CountModelCorpus.Size()) * l.Dimension - l.trainOne(l1, l2, l.pairs[encoded]) + enc := encode.EncodeBigram(uint64(doc[pos]), uint64(doc[c])) + l.update(doc[pos], doc[c], items[enc]) + for n := 0; n < l.opts.NegativeSampleSize; n++ { + sample := modelutil.NextRandom(dic.Len()) + enc := encode.EncodeBigram(uint64(doc[pos]), uint64(sample)) + l.update(doc[pos], sample+dic.Len(), items[enc]) } } } -func (l *Lexvec) trainOne(l1, l2 int, f float64) { +func (l *lexvec) update(l1, l2 int, f float64) { var diff float64 - for i := 0; i < l.Dimension; i++ { - diff += l.vector[l1+i] * l.vector[l2+i] + for i := 0; i < l.opts.ModelOptions.Dim; i++ { + diff += l.param.Slice(l1)[i] * l.param.Slice(l2)[i] } diff = (diff - f) * l.currentlr - for i := 0; i < l.Dimension; i++ { - t1 := diff * l.vector[l2+i] - t2 := diff * l.vector[l1+i] - l.vector[l1+i] -= t1 - l.vector[l2+i] -= t2 + for i := 0; i < l.opts.ModelOptions.Dim; i++ { + t1 := diff * l.param.Slice(l2)[i] + t2 := diff * l.param.Slice(l1)[i] + l.param.Slice(l1)[i] -= t1 + l.param.Slice(l2)[i] -= t2 } } -func (l *Lexvec) observeLearningRate(iteration int) { - for range l.trained { - l.trainedWordCount++ - if l.trainedWordCount%l.BatchSize == 0 { - l.currentlr = l.Initlr * - (1. - float64(l.trainedWordCount)/ - (float64(l.Corpus.TotalFreq())-float64(iteration))) - if l.currentlr < l.Initlr*l.Theta { - l.currentlr = l.Initlr * l.Theta +func (l *lexvec) observe(trained chan struct{}, clk *clock.Clock) { + var cnt int + for range trained { + cnt++ + if cnt%l.opts.ModelOptions.BatchSize == 0 { + lower := l.opts.ModelOptions.Initlr * l.opts.Theta + if l.currentlr < lower { + l.currentlr = lower + } else { + l.currentlr = l.opts.ModelOptions.Initlr * (1.0 - float64(cnt)/float64(l.corpus.Len())) } + l.verbose.Do(func() { + fmt.Printf("trained %d words %v\r", cnt, clk.AllElapsed()) + }) } } + l.verbose.Do(func() { + fmt.Printf("trained %d words %v\r\n", cnt, clk.AllElapsed()) + }) } -// Save saves the word vector to output writer. -func (l *Lexvec) Save(output io.Writer) error { - if output == nil { - return errors.New("Invalid output writer: must not be nil") - } +func (l *lexvec) Save(f io.Writer, typ save.VectorType) error { + writer := bufio.NewWriter(f) + defer writer.Flush() - wordSize := l.CountModelCorpus.Size() - if l.Verbose { - fmt.Println("Save:") - l.progress = pb.New(wordSize).SetWidth(80) - defer l.progress.Finish() - l.progress.Start() - } + dic := l.corpus.Dictionary() var buf bytes.Buffer - for i := 0; i < wordSize; i++ { - word, _ := l.CountModelCorpus.Word(i) + clk := clock.New() + for i := 0; i < dic.Len(); i++ { + word, _ := dic.Word(i) fmt.Fprintf(&buf, "%v ", word) - for j := 0; j < l.Dimension; j++ { - l1 := i*l.Dimension + j + for j := 0; j < l.opts.ModelOptions.Dim; j++ { var v float64 - switch l.SaveVectorType { - case model.NORMAL: - v = l.vector[l1] - case model.ADD: - l2 := (i+wordSize)*l.Dimension + j - v = l.vector[l1] + l.vector[l2] + switch { + case typ == save.AggregatedVector: + v = l.param.Slice(i)[j] + l.param.Slice(i)[j] + case typ == save.SingleVector: + v = l.param.Slice(i)[j] default: - return errors.Errorf("Invalid save vector type=%s", l.SaveVectorType) + return save.InvalidVectorTypeError(typ) } - fmt.Fprintf(&buf, "%v ", v) + fmt.Fprintf(&buf, "%f ", v) } fmt.Fprintln(&buf) - if l.Verbose { - l.progress.Increment() - } + l.verbose.Do(func() { + fmt.Printf("save %d words %v\r", i, clk.AllElapsed()) + }) } - - output.Write(buf.Bytes()) + writer.WriteString(fmt.Sprintf("%v", buf.String())) + l.verbose.Do(func() { + fmt.Printf("save %d words %v\r\n", dic.Len(), clk.AllElapsed()) + }) return nil } diff --git a/pkg/model/lexvec/options.go b/pkg/model/lexvec/options.go new file mode 100644 index 0000000..17b6ea5 --- /dev/null +++ b/pkg/model/lexvec/options.go @@ -0,0 +1,174 @@ +// Copyright © 2017 Makoto Ito +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lexvec + +import ( + "fmt" + + "github.com/pkg/errors" + "github.com/spf13/cobra" + + "github.com/ynqa/wego/pkg/corpus" + "github.com/ynqa/wego/pkg/model" +) + +func invalidRelationTypeError(typ RelationType) error { + return errors.Errorf("invalid relation type: %s not in %s|%s|%s|%s", typ, PPMI, PMI, Collocation, LogCollocation) +} + +type RelationType string + +const ( + PPMI RelationType = "ppmi" + PMI RelationType = "pmi" + Collocation RelationType = "co" + LogCollocation RelationType = "logco" + defaultRelationType = PPMI +) + +func (t *RelationType) String() string { + if *t == RelationType("") { + *t = defaultRelationType + } + return string(*t) +} + +func (t *RelationType) Set(name string) error { + typ := RelationType(name) + if typ == PPMI || typ == PMI || typ == Collocation || typ == LogCollocation { + *t = typ + return nil + } + return invalidRelationTypeError(typ) +} + +func (t *RelationType) Type() string { + return t.String() +} + +const ( + defaultNegativeSampleSize = 5 + defaultSmooth = 0.75 + defaultSubsampleThreshold = 1.0e-3 + defaultTheta = 1.0e-4 +) + +type Options struct { + CorpusOptions corpus.Options + ModelOptions model.Options + + NegativeSampleSize int + RelationType RelationType + Smooth float64 + SubsampleThreshold float64 + Theta float64 +} + +func LoadForCmd(cmd *cobra.Command, opts *Options) { + cmd.Flags().IntVar(&opts.NegativeSampleSize, "sample", defaultNegativeSampleSize, "negative sample size") + cmd.Flags().Var(&opts.RelationType, "rel", fmt.Sprintf("relation type for co-occurrence words. One of %s|%s|%s|%s", PPMI, PMI, Collocation, LogCollocation)) + cmd.Flags().Float64Var(&opts.Smooth, "smooth", defaultSmooth, "smoothing value for co-occurence value") + cmd.Flags().Float64Var(&opts.SubsampleThreshold, "threshold", defaultSubsampleThreshold, "threshold for subsampling") + cmd.Flags().Float64Var(&opts.Theta, "theta", defaultTheta, "lower limit of learning rate (lr >= initlr * theta)") +} + +type ModelOption func(*Options) + +// corpus options +func ToLower() ModelOption { + return ModelOption(func(opts *Options) { + opts.CorpusOptions.ToLower = true + }) +} + +// model options +func WithBatchSize(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.BatchSize = v + }) +} + +func WithDimension(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.Dim = v + }) +} + +func WithInitLearningRate(v float64) ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.Initlr = v + }) +} + +func WithIteration(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.Iter = v + }) +} + +func WithMinCount(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.MinCount = v + }) +} + +func WithThreadSize(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.ThreadSize = v + }) +} + +func WithWindow(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.Window = v + }) +} + +func Verbose() ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.Verbose = true + }) +} + +// for lexvec options +func WithNegativeSampleSize(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.NegativeSampleSize = v + }) +} + +func WithRelation(typ RelationType) ModelOption { + return ModelOption(func(opts *Options) { + opts.RelationType = typ + }) +} + +func WithSmooth(v float64) ModelOption { + return ModelOption(func(opts *Options) { + opts.Smooth = v + }) +} + +func WithSubsampleThreshold(v float64) ModelOption { + return ModelOption(func(opts *Options) { + opts.SubsampleThreshold = v + }) +} + +func WithTheta(v float64) ModelOption { + return ModelOption(func(opts *Options) { + opts.Theta = v + }) +} diff --git a/pkg/model/model.go b/pkg/model/model.go index 03df6da..027b6c5 100644 --- a/pkg/model/model.go +++ b/pkg/model/model.go @@ -16,31 +16,60 @@ package model import ( "io" + "runtime" + + "github.com/spf13/cobra" + "github.com/ynqa/wego/pkg/model/modelutil/save" ) -// Model is the interface that has Train, Save. type Model interface { - Train(f io.Reader) error - Save(o io.Writer) error + Train(io.Reader) error + Save(io.Writer, save.VectorType) error } -// SaveVectorType is a list of types to save model. -type SaveVectorType int - -const ( - // NORMAL saves word vectors only. - NORMAL SaveVectorType = iota - // ADD add word to context vectors, and save them. - ADD +var ( + defaultBatchSize = 100000 + defaultDim = 10 + defaultInitlr = 0.025 + defaultIter = 15 + defaultMinCount = 5 + defaultThreadSize = runtime.NumCPU() + defaultWindow = 5 + defaultVerbose = false ) -func (t SaveVectorType) String() string { - switch t { - case NORMAL: - return "normal" - case ADD: - return "add" - default: - return "unknown" +// Options stores common options for each model. +type Options struct { + BatchSize int + Dim int + Initlr float64 + Iter int + MinCount int + ThreadSize int + Window int + Verbose bool +} + +func DefaultOptions() Options { + return Options{ + BatchSize: defaultBatchSize, + Dim: defaultDim, + Initlr: defaultInitlr, + Iter: defaultIter, + MinCount: defaultMinCount, + ThreadSize: defaultThreadSize, + Window: defaultWindow, + Verbose: defaultVerbose, } } + +func LoadForCmd(cmd *cobra.Command, opts *Options) { + cmd.Flags().IntVar(&opts.BatchSize, "batch", defaultBatchSize, "batch size to train") + cmd.Flags().IntVarP(&opts.Dim, "dim", "d", defaultDim, "dimension for word vector") + cmd.Flags().Float64Var(&opts.Initlr, "initlr", defaultInitlr, "initial learning rate") + cmd.Flags().IntVar(&opts.Iter, "iter", defaultIter, "number of iteration") + cmd.Flags().IntVar(&opts.MinCount, "min-count", defaultMinCount, "lower limit to filter rare words") + cmd.Flags().IntVar(&opts.ThreadSize, "thread", defaultThreadSize, "number of goroutine") + cmd.Flags().IntVarP(&opts.Window, "window", "w", defaultWindow, "context window size") + cmd.Flags().BoolVar(&opts.Verbose, "verbose", defaultVerbose, "verbose mode") +} diff --git a/pkg/model/modelutil/matrix/matrix.go b/pkg/model/modelutil/matrix/matrix.go new file mode 100644 index 0000000..53382a8 --- /dev/null +++ b/pkg/model/modelutil/matrix/matrix.go @@ -0,0 +1,36 @@ +package matrix + +type Matrix struct { + array []float64 + row int + col int +} + +func New(row, col int, fn func([]float64)) *Matrix { + mat := &Matrix{ + array: make([]float64, row*col), + row: row, + col: col, + } + for i := 0; i < row; i++ { + fn(mat.Slice(i)) + } + return mat +} + +func (m *Matrix) startIndex(id int) int { + return id * m.col +} + +func (m *Matrix) Row() int { + return m.row +} + +func (m *Matrix) Col() int { + return m.col +} + +func (m *Matrix) Slice(id int) []float64 { + start := m.startIndex(id) + return m.array[start : start+m.col] +} diff --git a/pkg/model/modelutil/modelutil.go b/pkg/model/modelutil/modelutil.go new file mode 100644 index 0000000..9f1defc --- /dev/null +++ b/pkg/model/modelutil/modelutil.go @@ -0,0 +1,27 @@ +package modelutil + +import ( + "math" +) + +var ( + next uint64 = 1 +) + +// NextRandom is linear congruential generator (rand.Intn). +func NextRandom(value int) int { + next = next*uint64(25214903917) + 11 + return int(next % uint64(value)) +} + +// IndexPerThread creates interval of indices per thread. +func IndexPerThread(threadSize, dataSize int) []int { + indexPerThread := make([]int, threadSize+1) + indexPerThread[0] = 0 + indexPerThread[threadSize] = dataSize + for i := 1; i < threadSize; i++ { + indexPerThread[i] = indexPerThread[i-1] + + int(math.Trunc(float64((dataSize+i)/threadSize))) + } + return indexPerThread +} diff --git a/pkg/model/modelutil/save/save.go b/pkg/model/modelutil/save/save.go new file mode 100644 index 0000000..25c33e1 --- /dev/null +++ b/pkg/model/modelutil/save/save.go @@ -0,0 +1,37 @@ +package save + +import ( + "github.com/pkg/errors" +) + +func InvalidVectorTypeError(typ VectorType) error { + return errors.Errorf("invalid vector type: %s not in %s|%s", typ, SingleVector, AggregatedVector) +} + +type VectorType string + +const ( + SingleVector VectorType = "single" + AggregatedVector VectorType = "agg" + defaultSaveVectorType = SingleVector +) + +func (t *VectorType) String() string { + if *t == VectorType("") { + *t = defaultSaveVectorType + } + return string(*t) +} + +func (t *VectorType) Set(name string) error { + typ := VectorType(name) + if typ == SingleVector || typ == AggregatedVector { + *t = typ + return nil + } + return InvalidVectorTypeError(typ) +} + +func (t *VectorType) Type() string { + return t.String() +} diff --git a/pkg/model/option.go b/pkg/model/option.go deleted file mode 100644 index 90df71a..0000000 --- a/pkg/model/option.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package model - -// Option stores common options for each model. -type Option struct { - Dimension int - Iteration int - MinCount int - ThreadSize int - BatchSize int - Window int - Initlr float64 - ToLower bool - Verbose bool - SaveVectorType SaveVectorType -} diff --git a/pkg/model/subsample/subsample.go b/pkg/model/subsample/subsample.go new file mode 100644 index 0000000..cd58e6d --- /dev/null +++ b/pkg/model/subsample/subsample.go @@ -0,0 +1,38 @@ +package subsample + +import ( + "math" + "math/rand" + + "github.com/ynqa/wego/pkg/corpus/dictionary" +) + +type Subsampler struct { + samples []float64 +} + +func New( + dic *dictionary.Dictionary, + threshold float64, +) *Subsampler { + samples := make([]float64, dic.Len()) + for i := 0; i < dic.Len(); i++ { + z := 1. - math.Sqrt(threshold/float64(dic.IDFreq(i))) + if z < 0 { + z = 0 + } + samples[i] = z + } + return &Subsampler{ + samples: samples, + } +} + +func (s *Subsampler) Trial(id int) bool { + bernoulliTrial := rand.Float64() + var ok bool + if s.samples[id] > bernoulliTrial { + ok = true + } + return ok +} diff --git a/pkg/model/util.go b/pkg/model/util.go deleted file mode 100644 index 4b55566..0000000 --- a/pkg/model/util.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package model - -import ( - "math" -) - -// IndexPerThread creates interval of indices per thread. -func IndexPerThread(threadSize, dataSize int) []int { - indexPerThread := make([]int, threadSize+1) - indexPerThread[0] = 0 - indexPerThread[threadSize] = dataSize - for i := 1; i < threadSize; i++ { - indexPerThread[i] = indexPerThread[i-1] + - int(math.Trunc(float64((dataSize+i)/threadSize))) - } - return indexPerThread -} - -var next uint64 = 1 - -// NextRandom is linear congruential generator like rand.Intn(window) -func NextRandom(value int) int { - next = next*uint64(25214903917) + 11 - return int(next % uint64(value)) -} diff --git a/pkg/model/word2vec/cbow.go b/pkg/model/word2vec/cbow.go deleted file mode 100644 index 2b3626d..0000000 --- a/pkg/model/word2vec/cbow.go +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package word2vec - -import ( - "github.com/ynqa/wego/pkg/model" -) - -// Cbow behaviors as one of Word2vec solver. -type Cbow struct { - sums, pools chan []float64 - - dimension int - window int -} - -// NewCbow creates *Cbow -func NewCbow(dimension, window, threadSize int) *Cbow { - pools := make(chan []float64, threadSize) - sums := make(chan []float64, threadSize) - for i := 0; i < threadSize; i++ { - pools <- make([]float64, dimension) - sums <- make([]float64, dimension) - } - return &Cbow{ - sums: sums, - pools: pools, - - dimension: dimension, - window: window, - } -} - -func (c *Cbow) trainOne(document []int, wordIndex int, wordVector []float64, lr float64, optimizer Optimizer) { - sum := <-c.sums - pool := <-c.pools - word := document[wordIndex] - for i := 0; i < c.dimension; i++ { - sum[i] = 0.0 - pool[i] = 0.0 - } - c.dowith(document, wordIndex, sum, pool, wordVector, c.initSum) - optimizer.update(word, lr, sum, pool) - c.dowith(document, wordIndex, sum, pool, wordVector, c.updateContext) - c.sums <- sum - c.pools <- pool -} - -func (c *Cbow) dowith(document []int, wordIndex int, sum, pool, wordVector []float64, - opr func(context int, sum, pool, wordVector []float64)) { - - shrinkage := model.NextRandom(c.window) - for a := shrinkage; a < c.window*2+1-shrinkage; a++ { - if a != c.window { - c := wordIndex - c.window + a - if c < 0 || c >= len(document) { - continue - } - context := document[c] - opr(context, sum, pool, wordVector) - } - } -} - -func (c *Cbow) initSum(context int, sum, pool, wordVector []float64) { - for i := 0; i < c.dimension; i++ { - sum[i] += wordVector[context*c.dimension+i] - } -} - -func (c *Cbow) updateContext(context int, sum, pool, wordVector []float64) { - for i := 0; i < c.dimension; i++ { - wordVector[context*c.dimension+i] += pool[i] - } -} diff --git a/pkg/model/word2vec/hs.go b/pkg/model/word2vec/hs.go deleted file mode 100644 index d153494..0000000 --- a/pkg/model/word2vec/hs.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package word2vec - -import ( - "github.com/ynqa/wego/pkg/corpus" - "github.com/ynqa/wego/pkg/node" - - "github.com/pkg/errors" -) - -// HierarchicalSoftmax is a piece of Word2Vec optimizer. -type HierarchicalSoftmax struct { - *SigmoidTable - nodeMap map[int]*node.Node - maxDepth int - - dimension int - vocabulary int -} - -// NewHierarchicalSoftmax creates *HierarchicalSoftmax. -func NewHierarchicalSoftmax(maxDepth int) *HierarchicalSoftmax { - return &HierarchicalSoftmax{ - SigmoidTable: newSigmoidTable(), - maxDepth: maxDepth, - } -} - -func (hs *HierarchicalSoftmax) initialize(cps *corpus.Word2vecCorpus, dimension int) error { - nodeMap, err := cps.HuffmanTree(dimension) - if err != nil { - return errors.Wrap(err, "Failed to initialize of *HierarchicalSoftmax") - } - hs.nodeMap = nodeMap - hs.dimension = dimension - hs.vocabulary = cps.Size() - return nil -} - -func (hs *HierarchicalSoftmax) update(word int, lr float64, vector, poolVector []float64) { - path := hs.nodeMap[word].GetPath() - for p := 0; p < len(path)-1; p++ { - relayPoint := path[p] - childCode := path[p+1].Code - hs.gradUpd(childCode, lr, relayPoint.Vector, vector, poolVector) - if hs.maxDepth > 0 && p >= hs.maxDepth { - break - } - } -} - -func (hs *HierarchicalSoftmax) gradUpd(childCode int, lr float64, relayPointVec, vector, poolVector []float64) { - var inner float64 - for i := 0; i < hs.dimension; i++ { - inner += vector[i] * relayPointVec[i] - } - if inner <= -hs.maxExp || inner >= hs.maxExp { - return - } - g := (1.0 - float64(childCode) - hs.sigmoid(inner)) * lr - for i := 0; i < hs.dimension; i++ { - poolVector[i] += g * relayPointVec[i] - relayPointVec[i] += g * vector[i] - } -} diff --git a/pkg/model/word2vec/hs_test.go b/pkg/model/word2vec/hs_test.go deleted file mode 100644 index 6c23640..0000000 --- a/pkg/model/word2vec/hs_test.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package word2vec - -import ( - "testing" - - "github.com/ynqa/wego/pkg/corpus" -) - -func TestNewHierarchicalSoftmax(t *testing.T) { - maxDepth := 10 - hs := NewHierarchicalSoftmax(maxDepth) - - if hs.nodeMap != nil { - t.Error("HierarchicalSoftmax: Initializing without building huffman tree") - } -} - -func TestHSInit(t *testing.T) { - maxDepth := 10 - hs := NewHierarchicalSoftmax(maxDepth) - - dimension := 10 - c := corpus.NewWord2vecCorpus() - c.Parse(corpus.FakeSeeker, true, 0, 0, false) - hs.initialize(c, dimension) - - expectedNodeMapSize := c.Size() - if len(hs.nodeMap) != expectedNodeMapSize { - t.Errorf("HierarchicalSoftmax: Init returns nodeMap with length=%v: %v", - expectedNodeMapSize, len(hs.nodeMap)) - } -} diff --git a/pkg/model/word2vec/model.go b/pkg/model/word2vec/model.go index e30aa21..8b230cc 100644 --- a/pkg/model/word2vec/model.go +++ b/pkg/model/word2vec/model.go @@ -1,38 +1,134 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - package word2vec -// Model is the interface to train a word vector. -type Model interface { - trainOne(document []int, wordIndex int, wordVector []float64, lr float64, optimizer Optimizer) +import ( + "github.com/ynqa/wego/pkg/model/modelutil" + "github.com/ynqa/wego/pkg/model/modelutil/matrix" +) + +type mod interface { + trainOne( + doc []int, + pos int, + lr float64, + param *matrix.Matrix, + optimizer optimizer, + ) } -type ModelType int +type skipGram struct { + ch chan []float64 + window int +} -const ( - CBOW ModelType = iota - SKIP_GRAM -) +func newSkipGram(opts Options) mod { + ch := make(chan []float64, opts.ModelOptions.ThreadSize) + for i := 0; i < opts.ModelOptions.ThreadSize; i++ { + ch <- make([]float64, opts.ModelOptions.Dim) + } + return &skipGram{ + ch: ch, + window: opts.ModelOptions.Window, + } +} + +func (mod *skipGram) trainOne( + doc []int, + pos int, + lr float64, + param *matrix.Matrix, + optimizer optimizer, +) { + tmp := <-mod.ch + defer func() { + mod.ch <- tmp + }() + del := modelutil.NextRandom(mod.window) + for a := del; a < mod.window*2+1-del; a++ { + if a == mod.window { + continue + } + c := pos - mod.window + a + if c < 0 || c >= len(doc) { + continue + } + for i := 0; i < len(tmp); i++ { + tmp[i] = 0 + } + ctxID := doc[c] + ctx := param.Slice(ctxID) + optimizer.optim(doc[pos], lr, ctx, tmp) + for i := 0; i < len(ctx); i++ { + ctx[i] += tmp[i] + } + } +} + +type cbow struct { + ch chan []float64 + window int +} + +func newCbow(opts Options) mod { + ch := make(chan []float64, opts.ModelOptions.ThreadSize*2) + for i := 0; i < opts.ModelOptions.ThreadSize; i++ { + ch <- make([]float64, opts.ModelOptions.Dim) + } + return &cbow{ + ch: ch, + window: opts.ModelOptions.Window, + } +} + +func (mod *cbow) trainOne( + doc []int, + pos int, + lr float64, + param *matrix.Matrix, + optimizer optimizer, +) { + agg, tmp := <-mod.ch, <-mod.ch + defer func() { + mod.ch <- agg + mod.ch <- tmp + }() + for i := 0; i < len(agg); i++ { + agg[i], tmp[i] = 0, 0 + } + mod.dowith(doc, pos, param, agg, tmp, mod.aggregate) + optimizer.optim(doc[pos], lr, agg, tmp) + mod.dowith(doc, pos, param, agg, tmp, mod.update) +} + +func (mod *cbow) dowith( + doc []int, + pos int, + param *matrix.Matrix, + agg, tmp []float64, + fn func(ctx, agg, tmp []float64), +) { + del := modelutil.NextRandom(mod.window) + for a := del; a < mod.window*2+1-del; a++ { + if a == mod.window { + continue + } + c := pos - mod.window + a + if c < 0 || c >= len(doc) { + continue + } + ctxID := doc[c] + ctx := param.Slice(ctxID) + fn(ctx, agg, tmp) + } +} + +func (c *cbow) aggregate(ctx, agg, _ []float64) { + for i := 0; i < len(ctx); i++ { + agg[i] += ctx[i] + } +} -func (t ModelType) String() string { - switch t { - case CBOW: - return "cbow" - case SKIP_GRAM: - return "skip-gram" - default: - return "unknown" +func (c *cbow) update(ctx, _, tmp []float64) { + for i := 0; i < len(ctx); i++ { + ctx[i] += tmp[i] } } diff --git a/pkg/model/word2vec/ns.go b/pkg/model/word2vec/ns.go deleted file mode 100644 index b97d15b..0000000 --- a/pkg/model/word2vec/ns.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package word2vec - -import ( - "github.com/ynqa/wego/pkg/corpus" - "github.com/ynqa/wego/pkg/model" -) - -// NegativeSampling is a piece of Word2Vec optimizer. -type NegativeSampling struct { - *SigmoidTable - ContextVector []float64 - sampleSize int - - dimension int - vocabulary int -} - -// NewNegativeSampling creates *NegativeSampling. -func NewNegativeSampling(sampleSize int) *NegativeSampling { - return &NegativeSampling{ - SigmoidTable: newSigmoidTable(), - sampleSize: sampleSize, - } -} - -func (ns *NegativeSampling) initialize(cps *corpus.Word2vecCorpus, dimension int) error { - ns.vocabulary = cps.Size() - ns.dimension = dimension - ns.ContextVector = make([]float64, ns.vocabulary*ns.dimension) - return nil -} - -func (ns *NegativeSampling) update(word int, lr float64, vector, poolVector []float64) { - var label int - var sample int - var sampleVector []float64 - for n := -1; n < ns.sampleSize; n++ { - if n == -1 { - label = 1 - sampleVector = ns.ContextVector[word*ns.dimension : word*ns.dimension+ns.dimension] - } else { - label = 0 - sample = model.NextRandom(ns.vocabulary) - sampleVector = ns.ContextVector[sample*ns.dimension : sample*ns.dimension+ns.dimension] - if word == sample { - continue - } - } - ns.gradUpd(label, lr, sampleVector, vector, poolVector) - var index int - if n == -1 { - index = word - } else { - index = sample - } - for i := 0; i < ns.dimension; i++ { - ns.ContextVector[index*ns.dimension+i] = sampleVector[i] - } - } -} - -func (ns *NegativeSampling) gradUpd(label int, lr float64, sampledVector, vector, poolVector []float64) { - var inner float64 - for i := 0; i < ns.dimension; i++ { - inner += sampledVector[i] * vector[i] - } - var g float64 - if inner <= -ns.maxExp { - g = (float64(label - 0)) * lr - } else if inner >= ns.maxExp { - g = (float64(label - 1)) * lr - } else { - g = (float64(label) - ns.sigmoid(inner)) * lr - } - for i := 0; i < ns.dimension; i++ { - poolVector[i] += g * sampledVector[i] - sampledVector[i] += g * vector[i] - } -} diff --git a/pkg/model/word2vec/ns_test.go b/pkg/model/word2vec/ns_test.go deleted file mode 100644 index 029c2be..0000000 --- a/pkg/model/word2vec/ns_test.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package word2vec - -import ( - "testing" - - "github.com/ynqa/wego/pkg/corpus" -) - -func TestNewNegativeSampling(t *testing.T) { - sampleSize := 10 - ns := NewNegativeSampling(sampleSize) - - if ns.ContextVector != nil { - t.Error("NegativeSampling: Initializing without building negative vactors") - } -} - -func TestInitialize(t *testing.T) { - sampleSize := 10 - ns := NewNegativeSampling(sampleSize) - - dimension := 10 - c := corpus.NewWord2vecCorpus() - c.Parse(corpus.FakeSeeker, true, 0, 0, false) - ns.initialize(c, dimension) - - expectedVectorSize := c.Size() * dimension - if len(ns.ContextVector) != expectedVectorSize { - t.Errorf("NegativeSampling: Init returns negativeTensor with length=%v: %v", - expectedVectorSize, len(ns.ContextVector)) - } -} diff --git a/pkg/model/word2vec/opt.go b/pkg/model/word2vec/opt.go deleted file mode 100644 index 8fd5354..0000000 --- a/pkg/model/word2vec/opt.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package word2vec - -import ( - "github.com/ynqa/wego/pkg/corpus" -) - -// Optimizer is the interface to initialize after scanning corpus once, and update the word vector. -type Optimizer interface { - initialize(cps *corpus.Word2vecCorpus, dimension int) error - update(word int, lr float64, vector, poolVector []float64) -} - -type OptimizerType int - -const ( - NEGATIVE_SAMPLING OptimizerType = iota - HIERARCHICAL_SOFTMAX -) - -func (t OptimizerType) String() string { - switch t { - case NEGATIVE_SAMPLING: - return "ns" - case HIERARCHICAL_SOFTMAX: - return "hs" - default: - return "unknown" - } -} diff --git a/pkg/model/word2vec/optimizer.go b/pkg/model/word2vec/optimizer.go new file mode 100644 index 0000000..65488ab --- /dev/null +++ b/pkg/model/word2vec/optimizer.go @@ -0,0 +1,116 @@ +package word2vec + +import ( + "math/rand" + + "github.com/ynqa/wego/pkg/corpus/dictionary" + "github.com/ynqa/wego/pkg/corpus/dictionary/node" + "github.com/ynqa/wego/pkg/model/modelutil" + "github.com/ynqa/wego/pkg/model/modelutil/matrix" +) + +type optimizer interface { + optim(id int, lr float64, ctx, tmp []float64) +} + +type negativeSampling struct { + ctx *matrix.Matrix + sigtable *sigmoidTable + sampleSize int +} + +func newNegativeSampling(dic *dictionary.Dictionary, opts Options) optimizer { + dim := opts.ModelOptions.Dim + return &negativeSampling{ + ctx: matrix.New( + dic.Len(), + dim, + func(vec []float64) { + for i := 0; i < dim; i++ { + vec[i] = (rand.Float64() - 0.5) / float64(dim) + } + }, + ), + sigtable: newSigmoidTable(), + sampleSize: opts.NegativeSampleSize, + } +} + +func (opt *negativeSampling) optim( + id int, + lr float64, + ctx, tmp []float64, +) { + var ( + label int + picked int + ) + dim := len(ctx) + for n := -1; n < opt.sampleSize; n++ { + if n == -1 { + label = 1 + picked = id + } else { + label = 0 + picked = modelutil.NextRandom(opt.ctx.Row()) + if id == picked { + continue + } + } + rnd := opt.ctx.Slice(picked) + var inner float64 + for i := 0; i < dim; i++ { + inner += rnd[i] * ctx[i] + } + var g float64 + if inner <= -opt.sigtable.maxExp { + g = (float64(label - 0)) * lr + } else if inner >= opt.sigtable.maxExp { + g = (float64(label - 1)) * lr + } else { + g = (float64(label) - opt.sigtable.sigmoid(inner)) * lr + } + for i := 0; i < dim; i++ { + tmp[i] += g * rnd[i] + rnd[i] += g * ctx[i] + } + } +} + +type hierarchicalSoftmax struct { + sigtable *sigmoidTable + nodeset []*node.Node + maxDepth int +} + +func newHierarchicalSoftmax(dic *dictionary.Dictionary, opts Options) optimizer { + return &hierarchicalSoftmax{ + sigtable: newSigmoidTable(), + nodeset: dic.HuffnamTree(opts.ModelOptions.Dim), + maxDepth: opts.MaxDepth, + } +} + +func (opt *hierarchicalSoftmax) optim( + id int, + lr float64, + ctx, tmp []float64, +) { + path := opt.nodeset[id].GetPath(opt.maxDepth) + for i := 0; i < len(path)-1; i++ { + p := path[i] + childCode := path[i+1].Code + var inner float64 + for j := 0; j < len(p.Vector); j++ { + inner += ctx[j] * p.Vector[j] + } + if inner <= -opt.sigtable.maxExp || inner >= opt.sigtable.maxExp { + return + } + g := (1.0 - float64(childCode) - opt.sigtable.sigmoid(inner)) * lr + for j := 0; j < len(p.Vector); j++ { + tmp[j] += g * p.Vector[j] + p.Vector[j] += g * ctx[j] + } + } +} diff --git a/pkg/model/word2vec/options.go b/pkg/model/word2vec/options.go new file mode 100644 index 0000000..727b872 --- /dev/null +++ b/pkg/model/word2vec/options.go @@ -0,0 +1,211 @@ +// Copyright © 2017 Makoto Ito +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package word2vec + +import ( + "fmt" + + "github.com/pkg/errors" + "github.com/spf13/cobra" + + "github.com/ynqa/wego/pkg/corpus" + "github.com/ynqa/wego/pkg/model" +) + +func invalidModelTypeError(typ ModelType) error { + return errors.Errorf("invalid model: %s not in %s|%s", typ, Cbow, SkipGram) +} +func invalidOptimizerTypeError(typ OptimizerType) error { + return errors.Errorf("invalid optimizer: %s not in %s|%s", typ, NegativeSampling, HierarchicalSoftmax) +} + +type ModelType string + +const ( + Cbow ModelType = "cbow" + SkipGram ModelType = "skipgram" + defaultModelType = Cbow +) + +func (t *ModelType) String() string { + if *t == ModelType("") { + *t = defaultModelType + } + return string(*t) +} + +func (t *ModelType) Set(name string) error { + typ := ModelType(name) + if typ == SkipGram || typ == Cbow { + *t = typ + return nil + } + return invalidModelTypeError(typ) +} + +func (t *ModelType) Type() string { + return t.String() +} + +type OptimizerType string + +const ( + NegativeSampling OptimizerType = "ns" + HierarchicalSoftmax OptimizerType = "hs" + defaultOptimizerType = NegativeSampling +) + +func (t *OptimizerType) String() string { + if *t == OptimizerType("") { + *t = defaultOptimizerType + } + return string(*t) +} + +func (t *OptimizerType) Set(name string) error { + typ := OptimizerType(name) + if typ == NegativeSampling || typ == HierarchicalSoftmax { + *t = typ + return nil + } + return invalidOptimizerTypeError(typ) +} + +func (t *OptimizerType) Type() string { + return t.String() +} + +const ( + defaultMaxDepth = 100 + defaultNegativeSampleSize = 5 + defaultSubsampleThreshold = 1.0e-3 + defaultTheta = 1.0e-4 +) + +type Options struct { + CorpusOptions corpus.Options + ModelOptions model.Options + + MaxDepth int + ModelType ModelType + NegativeSampleSize int + OptimizerType OptimizerType + SubsampleThreshold float64 + Theta float64 +} + +func LoadForCmd(cmd *cobra.Command, opts *Options) { + cmd.Flags().IntVar(&opts.MaxDepth, "maxDepth", defaultMaxDepth, "times to track huffman tree, max-depth=0 means to track full path from root to word (for hierarchical softmax only)") + cmd.Flags().Var(&opts.ModelType, "model", fmt.Sprintf("which model does it use? one of: %s|%s", Cbow, SkipGram)) + cmd.Flags().IntVar(&opts.NegativeSampleSize, "sample", defaultNegativeSampleSize, "negative sample size(for negative sampling only)") + cmd.Flags().Var(&opts.OptimizerType, "optimizer", fmt.Sprintf("which optimizer does it use? one of: %s|%s", HierarchicalSoftmax, NegativeSampling)) + cmd.Flags().Float64Var(&opts.SubsampleThreshold, "threshold", defaultSubsampleThreshold, "threshold for subsampling") + cmd.Flags().Float64Var(&opts.Theta, "theta", defaultTheta, "lower limit of learning rate (lr >= initlr * theta)") +} + +type ModelOption func(*Options) + +// corpus options +func ToLower() ModelOption { + return ModelOption(func(opts *Options) { + opts.CorpusOptions.ToLower = true + }) +} + +// model options +func WithBatchSize(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.BatchSize = v + }) +} + +func WithDimension(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.Dim = v + }) +} + +func WithInitLearningRate(v float64) ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.Initlr = v + }) +} + +func WithIteration(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.Iter = v + }) +} + +func WithMinCount(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.MinCount = v + }) +} + +func WithThreadSize(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.ThreadSize = v + }) +} + +func WithWindow(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.Window = v + }) +} + +func Verbose() ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelOptions.Verbose = true + }) +} + +// word2vec options +func WithMaxDepth(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.MaxDepth = v + }) +} + +func WithModel(typ ModelType) ModelOption { + return ModelOption(func(opts *Options) { + opts.ModelType = typ + }) +} + +func WithNegativeSampleSize(v int) ModelOption { + return ModelOption(func(opts *Options) { + opts.NegativeSampleSize = v + }) +} + +func WithOptimizer(typ OptimizerType) ModelOption { + return ModelOption(func(opts *Options) { + opts.OptimizerType = typ + }) +} + +func WithSubsampleThreshold(v float64) ModelOption { + return ModelOption(func(opts *Options) { + opts.SubsampleThreshold = v + }) +} + +func WithTheta(v float64) ModelOption { + return ModelOption(func(opts *Options) { + opts.Theta = v + }) +} diff --git a/pkg/model/word2vec/sigmoid_table.go b/pkg/model/word2vec/sigmoid_table.go index 3962354..9574861 100644 --- a/pkg/model/word2vec/sigmoid_table.go +++ b/pkg/model/word2vec/sigmoid_table.go @@ -18,22 +18,18 @@ import ( "math" ) -// SigmoidTable stores the values of sigmoid function. -type SigmoidTable struct { +type sigmoidTable struct { expTable []float64 expTableSize int maxExp float64 cache float64 } -// newSigmoidTable creates sigmoid table, which acquires the sigmoid value f(x) from: -func newSigmoidTable() *SigmoidTable { - s := new(SigmoidTable) +func newSigmoidTable() *sigmoidTable { + s := new(sigmoidTable) s.expTableSize = 1000 s.maxExp = 6.0 - s.cache = float64(s.expTableSize) / s.maxExp / 2.0 - s.expTable = make([]float64, s.expTableSize) for i := 0; i < s.expTableSize; i++ { expval := math.Exp((float64(i)/float64(s.expTableSize)*2. - 1.) * s.maxExp) @@ -44,6 +40,6 @@ func newSigmoidTable() *SigmoidTable { // sigmoid returns: f(x) = (x + max_exp) * (exp_table_size / max_exp / 2) // If you set x to over |max_exp|, it raises index out of range error. -func (s *SigmoidTable) sigmoid(x float64) float64 { +func (s *sigmoidTable) sigmoid(x float64) float64 { return s.expTable[int((x+s.maxExp)*s.cache)] } diff --git a/pkg/model/word2vec/sigmoid_table_test.go b/pkg/model/word2vec/sigmoid_table_test.go index 72f5c0d..b75820a 100644 --- a/pkg/model/word2vec/sigmoid_table_test.go +++ b/pkg/model/word2vec/sigmoid_table_test.go @@ -18,11 +18,10 @@ import ( "testing" ) -func TestSigmoidOverLength(t *testing.T) { +func TestSigmoid(t *testing.T) { table := newSigmoidTable() - // TODO: fuzzy testing. f := table.sigmoid(3) if !(f >= 0 || f <= 1) { - t.Errorf("Extected range between 0 < Sigmoid(x) < 1: %v", f) + t.Errorf("Expected range is 0 < sigmoid(x) < 1, but got %v", f) } } diff --git a/pkg/model/word2vec/skipgram.go b/pkg/model/word2vec/skipgram.go deleted file mode 100644 index 0a44b46..0000000 --- a/pkg/model/word2vec/skipgram.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package word2vec - -import ( - "github.com/ynqa/wego/pkg/model" -) - -// SkipGram behaviors as one of Word2vec solver. -type SkipGram struct { - pools chan []float64 - - dimension int - window int -} - -// NewSkipGram creates *SkipGram -func NewSkipGram(dimension, window, threadSize int) *SkipGram { - pools := make(chan []float64, threadSize) - for i := 0; i < threadSize; i++ { - pools <- make([]float64, dimension) - } - return &SkipGram{ - pools: pools, - - dimension: dimension, - window: window, - } -} - -func (s *SkipGram) trainOne(document []int, wordIndex int, wordVector []float64, lr float64, optimizer Optimizer) { - pool := <-s.pools - word := document[wordIndex] - shrinkage := model.NextRandom(s.window) - for a := shrinkage; a < s.window*2+1-shrinkage; a++ { - if a == s.window { - continue - } - c := wordIndex - s.window + a - if c < 0 || c >= len(document) { - continue - } - context := document[c] - for i := 0; i < s.dimension; i++ { - pool[i] = 0.0 - } - optimizer.update(word, lr, wordVector[context*s.dimension:context*s.dimension+s.dimension], pool) - for i := 0; i < s.dimension; i++ { - wordVector[context*s.dimension+i] += pool[i] - } - } - s.pools <- pool -} diff --git a/pkg/model/word2vec/word2vec.go b/pkg/model/word2vec/word2vec.go index 65e36a5..64d3667 100644 --- a/pkg/model/word2vec/word2vec.go +++ b/pkg/model/word2vec/word2vec.go @@ -1,221 +1,232 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - package word2vec import ( + "bufio" "bytes" + "context" "fmt" "io" - "math" "math/rand" "sync" - "github.com/pkg/errors" - "gopkg.in/cheggaaa/pb.v1" + "golang.org/x/sync/semaphore" + "github.com/ynqa/wego/pkg/clock" "github.com/ynqa/wego/pkg/corpus" "github.com/ynqa/wego/pkg/model" + "github.com/ynqa/wego/pkg/model/modelutil" + "github.com/ynqa/wego/pkg/model/modelutil/matrix" + "github.com/ynqa/wego/pkg/model/modelutil/save" + "github.com/ynqa/wego/pkg/model/subsample" + "github.com/ynqa/wego/pkg/verbose" ) -type Word2vecOption struct { - Mod Model - Opt Optimizer - SubsampleThreshold float64 - Theta float64 -} +type word2vec struct { + opts Options -// Word2vec stores the configs for Word2vec models. -type Word2vec struct { - *model.Option - *Word2vecOption - *corpus.Word2vecCorpus + corpus *corpus.Corpus - subSamples []float64 + param *matrix.Matrix + subsampler *subsample.Subsampler + currentlr float64 + mod mod + optimizer optimizer - // words' vector. - vector []float64 + verbose *verbose.Verbose +} - // manage learning rate. - currentlr float64 - trained chan struct{} - trainedWordCount int +func New(opts ...ModelOption) (model.Model, error) { + options := Options{ + CorpusOptions: corpus.DefaultOptions(), + ModelOptions: model.DefaultOptions(), + + MaxDepth: defaultMaxDepth, + ModelType: defaultModelType, + NegativeSampleSize: defaultNegativeSampleSize, + OptimizerType: defaultOptimizerType, + SubsampleThreshold: defaultSubsampleThreshold, + Theta: defaultTheta, + } - // manage data range per thread. - indexPerThread []int + for _, fn := range opts { + fn(&options) + } - // progress bar. - progress *pb.ProgressBar + return NewForOptions(options) } -// NewWord2vec creates *Word2Vec. -func NewWord2vec(option *model.Option, word2vecOption *Word2vecOption) *Word2vec { - return &Word2vec{ - Option: option, - Word2vecOption: word2vecOption, +func NewForOptions(opts Options) (model.Model, error) { + // TODO: validate Options + v := verbose.New(opts.ModelOptions.Verbose) + return &word2vec{ + opts: opts, - currentlr: option.Initlr, - trained: make(chan struct{}), - } -} + corpus: corpus.New(opts.CorpusOptions, v), -func (w *Word2vec) initialize() error { - // Store subsumple before training. - w.subSamples = make([]float64, w.Word2vecCorpus.Size()) - for i := 0; i < w.Word2vecCorpus.Size(); i++ { - z := float64(w.Word2vecCorpus.IDFreq(i)) / float64(w.Word2vecCorpus.TotalFreq()) - w.subSamples[i] = (math.Sqrt(z/w.SubsampleThreshold) + 1.0) * - w.SubsampleThreshold / z - } + currentlr: opts.ModelOptions.Initlr, + + verbose: v, + }, nil +} - // Initialize word vector. - vectorSize := w.Word2vecCorpus.Size() * w.Dimension - w.vector = make([]float64, vectorSize) - for i := 0; i < vectorSize; i++ { - w.vector[i] = (rand.Float64() - 0.5) / float64(w.Dimension) +func (w *word2vec) preTrain(r io.Reader) error { + if err := w.corpus.Build(r); err != nil { + return err } - // Initialize optimizer. - return w.Opt.initialize(w.Word2vecCorpus, w.Dimension) -} + dic, dim := w.corpus.Dictionary(), w.opts.ModelOptions.Dim -// Train trains words' vector on corpus. -func (w *Word2vec) Train(f io.Reader) error { - c := corpus.NewWord2vecCorpus() - if err := c.Parse(f, w.ToLower, w.MinCount, w.BatchSize, w.Verbose); err != nil { - return errors.Wrap(err, "Failed to parse corpus") + w.param = matrix.New( + dic.Len(), + dim, + func(vec []float64) { + for i := 0; i < dim; i++ { + vec[i] = (rand.Float64() - 0.5) / float64(dim) + } + }, + ) + + w.subsampler = subsample.New(dic, w.opts.SubsampleThreshold) + + switch w.opts.ModelType { + case SkipGram: + w.mod = newSkipGram(w.opts) + case Cbow: + w.mod = newCbow(w.opts) + default: + return invalidModelTypeError(w.opts.ModelType) } - w.Word2vecCorpus = c - if err := w.initialize(); err != nil { - return errors.Wrap(err, "Failed to initialize") + + switch w.opts.OptimizerType { + case NegativeSampling: + w.optimizer = newNegativeSampling( + w.corpus.Dictionary(), + w.opts, + ) + case HierarchicalSoftmax: + w.optimizer = newHierarchicalSoftmax( + w.corpus.Dictionary(), + w.opts, + ) + default: + return invalidOptimizerTypeError(w.opts.OptimizerType) } - return w.train() + return nil } -func (w *Word2vec) train() error { - document := w.Document - documentSize := len(document) - if documentSize <= 0 { - return errors.New("No words for training") +func (w *word2vec) Train(r io.Reader) error { + if err := w.preTrain(r); err != nil { + return err } - w.indexPerThread = model.IndexPerThread(w.ThreadSize, documentSize) + doc := w.corpus.Doc() + indexPerThread := modelutil.IndexPerThread( + w.opts.ModelOptions.ThreadSize, + len(doc), + ) - for i := 1; i <= w.Iteration; i++ { - if w.Verbose { - fmt.Printf("Train %d-th:\n", i) - w.progress = pb.New(documentSize).SetWidth(80) - w.progress.Start() - } - go w.observeLearningRate() + for i := 1; i <= w.opts.ModelOptions.Iter; i++ { + trained, clk := make(chan struct{}), clock.New() + go w.observe(trained, clk) - semaphore := make(chan struct{}, w.ThreadSize) - waitGroup := &sync.WaitGroup{} + sem := semaphore.NewWeighted(int64(w.opts.ModelOptions.ThreadSize)) + wg := &sync.WaitGroup{} - for j := 0; j < w.ThreadSize; j++ { - waitGroup.Add(1) - go w.trainPerThread(document[w.indexPerThread[j]:w.indexPerThread[j+1]], w.Mod.trainOne, - semaphore, waitGroup) - } - waitGroup.Wait() - if w.Verbose { - w.progress.Finish() + for i := 0; i < w.opts.ModelOptions.ThreadSize; i++ { + wg.Add(1) + s, e := indexPerThread[i], indexPerThread[i+1] + go w.trainPerThread(doc[s:e], trained, sem, wg) } + + wg.Wait() + close(trained) } return nil } -func (w *Word2vec) trainPerThread(document []int, - trainOne func(document []int, wordIndex int, wordVector []float64, lr float64, optimizer Optimizer), - semaphore chan struct{}, waitGroup *sync.WaitGroup) { - +func (w *word2vec) trainPerThread( + doc []int, + trained chan struct{}, + sem *semaphore.Weighted, + wg *sync.WaitGroup, +) error { defer func() { - <-semaphore - waitGroup.Done() + wg.Done() + sem.Release(1) }() - semaphore <- struct{}{} - for idx, wordID := range document { - if w.Verbose { - w.progress.Increment() - } + if err := sem.Acquire(context.Background(), 1); err != nil { + return err + } - bernoulliTrial := rand.Float64() - p := w.subSamples[wordID] - if p < bernoulliTrial { - continue + dic := w.corpus.Dictionary() + for pos, id := range doc { + if w.subsampler.Trial(id) && dic.IDFreq(id) > w.opts.ModelOptions.MinCount { + w.mod.trainOne(doc, pos, w.currentlr, w.param, w.optimizer) } - trainOne(document, idx, w.vector, w.currentlr, w.Opt) - w.trained <- struct{}{} + trained <- struct{}{} } + + return nil } -func (w *Word2vec) observeLearningRate() { - for range w.trained { - w.trainedWordCount++ - if w.trainedWordCount%w.BatchSize == 0 { - w.currentlr = w.Initlr * (1.0 - float64(w.trainedWordCount)/float64(w.TotalFreq())) - if w.currentlr < w.Initlr*w.Theta { - w.currentlr = w.Initlr * w.Theta +func (w *word2vec) observe(trained chan struct{}, clk *clock.Clock) { + var cnt int + for range trained { + cnt++ + if cnt%w.opts.ModelOptions.BatchSize == 0 { + lower := w.opts.ModelOptions.Initlr * w.opts.Theta + if w.currentlr < lower { + w.currentlr = lower + } else { + w.currentlr = w.opts.ModelOptions.Initlr * (1.0 - float64(cnt)/float64(w.corpus.Len())) } + w.verbose.Do(func() { + fmt.Printf("trained %d words %v\r", cnt, clk.AllElapsed()) + }) } } + w.verbose.Do(func() { + fmt.Printf("trained %d words %v\r\n", cnt, clk.AllElapsed()) + }) } -// Save saves the word vector to output writer. -func (w *Word2vec) Save(output io.Writer) error { - if output == nil { - return errors.New("Invalid output writer: must not be nil") - } - - wordSize := w.Size() - if w.Verbose { - fmt.Println("Save:") - w.progress = pb.New(wordSize).SetWidth(80) - defer w.progress.Finish() - w.progress.Start() - } +func (w *word2vec) Save(f io.Writer, typ save.VectorType) error { + writer := bufio.NewWriter(f) + defer writer.Flush() - var contextVector []float64 - switch opt := w.Opt.(type) { - case *NegativeSampling: - contextVector = opt.ContextVector + dic := w.corpus.Dictionary() + var ctx *matrix.Matrix + ng, ok := w.optimizer.(*negativeSampling) + if ok { + ctx = ng.ctx } var buf bytes.Buffer - for i := 0; i < wordSize; i++ { - word, _ := w.Word(i) + clk := clock.New() + for i := 0; i < dic.Len(); i++ { + word, _ := dic.Word(i) fmt.Fprintf(&buf, "%v ", word) - for j := 0; j < w.Dimension; j++ { + for j := 0; j < w.opts.ModelOptions.Dim; j++ { var v float64 - l := i*w.Dimension + j switch { - case w.SaveVectorType == model.ADD && len(contextVector) != 0: - v = w.vector[l] + contextVector[l] - case w.SaveVectorType == model.NORMAL: - v = w.vector[l] + case typ == save.AggregatedVector && ctx.Row() > i: + v = w.param.Slice(i)[j] + ctx.Slice(i)[j] + case typ == save.SingleVector: + v = w.param.Slice(i)[j] default: - return errors.Errorf("Invalid case to save vector type=%s", w.SaveVectorType) + return save.InvalidVectorTypeError(typ) } fmt.Fprintf(&buf, "%f ", v) } fmt.Fprintln(&buf) - if w.Verbose { - w.progress.Increment() - } + w.verbose.Do(func() { + fmt.Printf("save %d words %v\r", i, clk.AllElapsed()) + }) } - - output.Write(buf.Bytes()) + writer.WriteString(fmt.Sprintf("%v", buf.String())) + w.verbose.Do(func() { + fmt.Printf("save %d words %v\r\n", dic.Len(), clk.AllElapsed()) + }) return nil } diff --git a/pkg/node/node.go b/pkg/node/node.go deleted file mode 100644 index e71dd81..0000000 --- a/pkg/node/node.go +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package node - -import ( - "errors" - "sort" -) - -// Node stores the node with vector in huffman tree. -type Node struct { - parent *Node - cachePath Nodes - Code int - Value int - Vector []float64 -} - -// GetPath returns the nodes from root to word on huffman tree. -func (n *Node) GetPath() Nodes { - // Reverse - re := func(n Nodes) { - for i, j := 0, len(n)-1; i < j; i, j = i+1, j-1 { - n[i], n[j] = n[j], n[i] - } - } - - trace := func() Nodes { - nodes := make(Nodes, 0) - nodes = append(nodes, n) - for parent := n.parent; parent != nil; parent = parent.parent { - nodes = append(nodes, parent) - } - re(nodes) - return nodes - } - - path := n.cachePath - if path == nil { - path = trace() - } - - n.cachePath = path - return path -} - -// Nodes is the list of Node. -type Nodes []*Node - -func (n *Nodes) Len() int { return len(*n) } -func (n *Nodes) Less(i, j int) bool { return (*n)[i].Value < (*n)[j].Value } -func (n *Nodes) Swap(i, j int) { (*n)[i], (*n)[j] = (*n)[j], (*n)[i] } - -// Build builds huffman tree based on word frequencies. -func (n *Nodes) Build(dimension int) error { - if len(*n) == 0 { - return errors.New("The length of Nodes is 0") - } - sort.Sort(n) - - for len(*n) > 1 { - // Pop - left, right := (*n)[0], (*n)[1] - *n = (*n)[2:] - - parentValue := left.Value + right.Value - parent := &Node{ - Value: parentValue, - Vector: make([]float64, dimension), - } - left.parent = parent - left.Code = 0 - right.parent = parent - right.Code = 1 - - idx := sort.Search(len(*n), func(i int) bool { return (*n)[i].Value >= parentValue }) - - // Insert - *n = append(*n, &Node{}) - copy((*n)[idx+1:], (*n)[idx:]) - (*n)[idx] = parent - } - return nil -} diff --git a/pkg/repl/README.md b/pkg/repl/README.md deleted file mode 100644 index 6f19020..0000000 --- a/pkg/repl/README.md +++ /dev/null @@ -1,43 +0,0 @@ -# REPL - -Similarity search between word vectors with REPL mode. - -## Usage - -``` -Search similar words with REPL mode - -Usage: - wego repl [flags] - -Examples: - wego repl -i example/word_vectors.txt - >> apple + banana - ... - -Flags: - -h, --help help for repl - -i, --inputFile string input file path for trained word vector (default "example/word_vectors.txt") - -r, --rank int how many the most similar words will be displayed (default 10) -``` - -## Example - -Now, it is able to use `+`, `-` for arithmetic operations. - -``` -$ go run wego.go repl -i example/word_vectors_sg.txt ->> a + b - RANK | WORD | SIMILARITY -+------+---------+------------+ - 1 | phi | 0.907975 - 2 | q | 0.904593 - 3 | mathbf | 0.903066 - 4 | cdot | 0.902205 - 5 | b | 0.901952 - 6 | becomes | 0.900346 - 7 | int | 0.898680 - 8 | z | 0.897895 - 9 | named | 0.896480 - 10 | v | 0.895456 -``` diff --git a/pkg/search/README.md b/pkg/search/README.md deleted file mode 100644 index fb20e13..0000000 --- a/pkg/search/README.md +++ /dev/null @@ -1,38 +0,0 @@ -# Search - -Similarity search between word vectors. - -## Usage - -``` -Search similar words - -Usage: - wego search [flags] - -Examples: - wego search -i example/word_vectors.txt microsoft - -Flags: - -h, --help help for search - -i, --inputFile string input file path for trained word vector (default "example/input.txt") - -r, --rank int how many the most similar words will be displayed (default 10) -``` - -## Example - -``` -$ go run wego.go search -i example/word_vectors_sg.txt microsoft - RANK | WORD | SIMILARITY -+------+------------+------------+ - 1 | apple | 0.994008 - 2 | operating | 0.992855 - 3 | versions | 0.992800 - 4 | ibm | 0.992232 - 5 | os | 0.989174 - 6 | computers | 0.988998 - 7 | machines | 0.988804 - 8 | dvd | 0.988732 - 9 | cd | 0.988259 - 10 | compatible | 0.988200 -``` diff --git a/pkg/search/describer.go b/pkg/search/describer.go deleted file mode 100644 index 9e8cc0f..0000000 --- a/pkg/search/describer.go +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright © 2019 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package search - -import ( - "fmt" - "os" - - "github.com/olekukonko/tablewriter" -) - -// Describe shows the similar words list for target word. -func Describe(neighbors Neighbors) error { - table := make([][]string, len(neighbors)) - for k := 0; k < len(neighbors); k++ { - table[k] = []string{ - fmt.Sprintf("%d", k+1), - neighbors[k].word, - fmt.Sprintf("%f", neighbors[k].similarity), - } - } - - writer := tablewriter.NewWriter(os.Stdout) - writer.SetHeader([]string{"Rank", "Word", "Similarity"}) - writer.SetBorder(false) - writer.AppendBulk(table) - writer.Render() - return nil -} diff --git a/pkg/search/describer_test.go b/pkg/search/describer_test.go deleted file mode 100644 index 0d238c1..0000000 --- a/pkg/search/describer_test.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright © 2019 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package search - -import ( - "testing" -) - -func TestDescribe(t *testing.T) { - neighbors := Neighbors{ - Neighbor{ - word: "a", - similarity: 0.95, - }, - Neighbor{ - word: "b", - similarity: 0.9, - }, - Neighbor{ - word: "c", - similarity: 0.85, - }, - Neighbor{ - word: "d", - similarity: 0.8, - }, - } - if err := Describe(neighbors); err != nil { - t.Errorf("Failed to describe neighbors=%v: %s", neighbors, err.Error()) - } -} diff --git a/pkg/search/neighbor.go b/pkg/search/neighbor.go deleted file mode 100644 index 97c7bc9..0000000 --- a/pkg/search/neighbor.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package search - -// Neighbor stores the word with cosine similarity value on the target. -type Neighbor struct { - word string - similarity float64 -} - -// Neighbors is the list of Sim. -type Neighbors []Neighbor - -func (n Neighbors) Len() int { return len(n) } -func (n Neighbors) Less(i, j int) bool { return n[i].similarity < n[j].similarity } -func (n Neighbors) Swap(i, j int) { n[i], n[j] = n[j], n[i] } diff --git a/pkg/search/neighbor_test.go b/pkg/search/neighbor_test.go deleted file mode 100644 index b23b71a..0000000 --- a/pkg/search/neighbor_test.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package search - -import ( - "testing" -) - -func NewFakeNeighbors() Neighbors { - ns := make(Neighbors, 0) - ns = append(ns, Neighbor{ - word: "Cupcake", - similarity: 0.1, - }) - ns = append(ns, Neighbor{ - word: "Donut", - similarity: 0.2, - }) - ns = append(ns, Neighbor{ - word: "Eclair", - similarity: 0.3, - }) - ns = append(ns, Neighbor{ - word: "Froyo", - similarity: 0.4, - }) - ns = append(ns, Neighbor{ - word: "Gingerbread", - similarity: 0.5, - }) - return ns -} - -func TestNeighborsLen(t *testing.T) { - neighbors := NewFakeNeighbors() - - if neighbors.Len() != 5 { - t.Errorf("Expected len=5: %v", neighbors.Len()) - } -} - -func TestNeighborsLess(t *testing.T) { - neighbors := NewFakeNeighbors() - - if !neighbors.Less(0, 3) { - t.Errorf("Expected less(0, 3)=true: neighbors[0].similarity %v vs. neighbors[3].similarity %v", - neighbors[0].similarity, neighbors[3].similarity) - } -} - -func TestNeighborsSwap(t *testing.T) { - neighbors := NewFakeNeighbors() - n0 := neighbors[0] - n3 := neighbors[3] - - neighbors.Swap(0, 3) - - if neighbors[0] != n3 { - t.Errorf("Expected to equal %v to %v", neighbors[0], n3) - } - - if neighbors[3] != n0 { - t.Errorf("Expected to equal %v to %v", neighbors[3], n0) - } -} diff --git a/pkg/search/parser.go b/pkg/search/parser.go deleted file mode 100644 index c440e36..0000000 --- a/pkg/search/parser.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright © 2019 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package search - -import ( - "bufio" - "io" - "strconv" - "strings" - - "github.com/pkg/errors" -) - -type StoreFunc func(string, []float64, int) - -func ParseAll(f io.Reader, store StoreFunc) error { - s := bufio.NewScanner(f) - for s.Scan() { - line := s.Text() - if strings.HasPrefix(line, " ") { - continue - } - word, vec, dim, err := parse(line) - if err != nil { - return err - } - store(word, vec, dim) - } - if err := s.Err(); err != nil && err != io.EOF { - return errors.Wrapf(err, "Failed to scan %v", f) - } - return nil -} - -func parse(line string) (string, []float64, int, error) { - sep := strings.Fields(line) - word := sep[0] - elems := sep[1:] - dim := len(elems) - vec := make([]float64, dim) - for k, elem := range elems { - val, err := strconv.ParseFloat(elem, 64) - if err != nil { - return "", nil, 0, err - } - vec[k] = val - } - return word, vec, dim, nil -} diff --git a/pkg/search/parser_test.go b/pkg/search/parser_test.go deleted file mode 100644 index 7a3ffc7..0000000 --- a/pkg/search/parser_test.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright © 2019 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package search - -import ( - "bytes" - "io/ioutil" - "testing" -) - -func TestParseAll(t *testing.T) { - f := ioutil.NopCloser(bytes.NewReader([]byte(testVectorStr))) - defer f.Close() - - vectors := make(map[string][]float64) - storeFunc := func(word string, vec []float64, dim int) { - vectors[word] = vec - } - if err := ParseAll(f, storeFunc); err != nil { - t.Errorf("Failed to parse vector file: %s", err.Error()) - } - - if len(vectors) != testNumVector { - t.Errorf("Expected vector len=%d, but got %d", testNumVector, len(vectors)) - } -} - -func TestParse(t *testing.T) { - f := ioutil.NopCloser(bytes.NewReader([]byte(testVectorStr))) - defer f.Close() - - word, vec, _, err := parse("apple 1 1 1 1 1") - if err != nil { - t.Errorf("Failed to parse a vector str: %s", err.Error()) - } - if word != "apple" { - t.Errorf("Expected word=apple, but got %s", word) - } - if len(vec) != 5 { - t.Errorf("Expected vector len=5, but got %d", len(vec)) - } -} diff --git a/pkg/repl/op.go b/pkg/search/repl/op.go similarity index 51% rename from pkg/repl/op.go rename to pkg/search/repl/op.go index d2ab6bc..89d09a9 100644 --- a/pkg/repl/op.go +++ b/pkg/search/repl/op.go @@ -14,27 +14,39 @@ package repl +import ( + "github.com/pkg/errors" +) + type Operator func(float64, float64) float64 -func elementWise(v1, v2 []float64, op Operator) []float64 { +func elementWise(v1, v2 []float64, op Operator) ([]float64, error) { + if len(v1) != len(v2) { + return nil, errors.Errorf("Both lengths of vector must be the same, got %d and %d", len(v1), len(v2)) + } + v := make([]float64, len(v1)) for i := 0; i < len(v1); i++ { - v1[i] = op(v1[i], v2[i]) + v[i] = op(v1[i], v2[i]) } - return v1 + return v, nil } -func addOp(x, y float64) float64 { - return x + y +func add(v1, v2 []float64) ([]float64, error) { + return elementWise(v1, v2, addOp()) } -func Add(v1, v2 []float64) []float64 { - return elementWise(v1, v2, addOp) +func addOp() Operator { + return Operator(func(x, y float64) float64 { + return x + y + }) } -func subOp(x, y float64) float64 { - return x - y +func sub(v1, v2 []float64) ([]float64, error) { + return elementWise(v1, v2, subOp()) } -func Sub(v1, v2 []float64) []float64 { - return elementWise(v1, v2, subOp) +func subOp() Operator { + return Operator(func(x, y float64) float64 { + return x - y + }) } diff --git a/pkg/repl/repl.go b/pkg/search/repl/repl.go similarity index 60% rename from pkg/repl/repl.go rename to pkg/search/repl/repl.go index 8d211ec..62d4f2e 100644 --- a/pkg/repl/repl.go +++ b/pkg/search/repl/repl.go @@ -23,36 +23,46 @@ import ( "github.com/peterh/liner" "github.com/pkg/errors" - "github.com/ynqa/wego/pkg/search" ) +var ( + vector []float64 +) + +type searchParam struct { + dim int + k int +} + type Repl struct { + *liner.State searcher *search.Searcher - line *liner.State - vector []float64 - k int + param *searchParam } -func NewRepl(f io.Reader, k int) (*Repl, error) { - searcher, err := search.NewSearcher(f) +func New(r io.Reader, k int) (*Repl, error) { + searcher, err := search.NewForVectorFile(r) if err != nil { return nil, err } - line := liner.NewLiner() - vector := make([]float64, searcher.Dimension) + if searcher.Items.Empty() { + return nil, errors.New("Number of items for searcher must be over 0") + } return &Repl{ + State: liner.NewLiner(), searcher: searcher, - line: line, - vector: vector, - k: k, + param: &searchParam{ + dim: len(searcher.Items), + k: k, + }, }, nil } func (r *Repl) Run() error { - defer r.line.Close() + defer r.Close() for { - l, err := r.line.Prompt(">> ") + l, err := r.Prompt(">> ") if err != nil { fmt.Println("error: ", err) } @@ -71,7 +81,7 @@ func (r *Repl) Run() error { func (r *Repl) eval(l string) error { defer func() { - r.vector = make([]float64, r.searcher.Dimension) + vector = make([]float64, r.param.dim) }() expr, err := parser.ParseExpr(l) @@ -82,24 +92,23 @@ func (r *Repl) eval(l string) error { var neighbors search.Neighbors switch e := expr.(type) { case *ast.Ident: - neighbors, err = r.searcher.SearchWithQuery(e.String(), r.k) + neighbors, err = r.searcher.InternalSearch(e.String(), r.param.k) if err != nil { fmt.Printf("failed to search with word=%s\n", e.String()) } - return search.Describe(neighbors) case *ast.BinaryExpr: if err := r.evalExpr(expr); err != nil { return err } - neighbors, err := r.searcher.Search(r.vector, r.k) + neighbors, err = r.searcher.Search(vector, r.param.k) if err != nil { - fmt.Printf("failed to search with vector=%v\n", r.vector) + fmt.Printf("failed to search with vector=%v\n", vector) } - return search.Describe(neighbors) default: return errors.Errorf("invalid type %v", e) } - + neighbors.Describe() + return nil } func (r *Repl) evalExpr(expr ast.Expr) error { @@ -114,55 +123,42 @@ func (r *Repl) evalExpr(expr ast.Expr) error { } func (r *Repl) evalBinaryExpr(expr *ast.BinaryExpr) error { - if err := r.evalExpr(expr.X); err != nil { + xi, err := r.evalAsItem(expr.X) + if err != nil { return err } - - if err := r.evalExpr(expr.Y); err != nil { - return err + yi, err := r.evalAsItem(expr.Y) + if err != nil { + return nil } + vector, err = arithmetic(xi.Vector, expr.Op, yi.Vector) + return err +} - x, ok := expr.X.(*ast.Ident) - if ok && isZeros(r.vector) { - xv, ok := r.searcher.Vectors[x.String()] - if !ok { - return errors.Errorf("not found word=%s in vector map", x.String()) - } - copy(r.vector, xv) +func (r *Repl) evalAsItem(expr ast.Expr) (search.Item, error) { + if err := r.evalExpr(expr); err != nil { + return search.Item{}, err } - - y, ok := expr.Y.(*ast.Ident) + v, ok := expr.(*ast.Ident) if !ok { - return errors.Errorf("failed to parse %v", expr.Y) + return search.Item{}, errors.Errorf("failed to parse %v", expr) } - - yv, ok := r.searcher.Vectors[y.String()] + vi, ok := r.searcher.Items.Find(v.String()) if !ok { - return errors.Errorf("not found word=%s in vector map", y.String()) + return search.Item{}, errors.Errorf("not found word=%s in vector map", v.String()) + } else if err := vi.Validate(); err != nil { + return search.Item{}, err } - - var err error - r.vector, err = arithmetic(r.vector, expr.Op, yv) - return err + return vi, nil } func arithmetic(v1 []float64, op token.Token, v2 []float64) ([]float64, error) { switch op { case token.ADD: - v1 = Add(v1, v2) + return add(v1, v2) case token.SUB: - v1 = Sub(v1, v2) + return sub(v1, v2) default: return nil, errors.Errorf("invalid operator %v", op.String()) } - return v1, nil -} - -func isZeros(vec []float64) bool { - for _, v := range vec { - if v != 0. { - return false - } - } - return true } diff --git a/pkg/search/search.go b/pkg/search/search.go index 8c51bad..72c7184 100644 --- a/pkg/search/search.go +++ b/pkg/search/search.go @@ -15,86 +15,200 @@ package search import ( + "fmt" "io" + "math" + "os" "sort" + "github.com/olekukonko/tablewriter" "github.com/pkg/errors" + + "github.com/ynqa/wego/pkg/item" ) -// Searcher stores the elements for cosine similarity. +// Neighbor stores the word with cosine similarity value on the target. +type Neighbor struct { + Word string + Rank uint + Similarity float64 +} + +type Neighbors []Neighbor + +func (neighbors Neighbors) Describe() { + table := make([][]string, len(neighbors)) + for i, n := range neighbors { + table[i] = []string{ + fmt.Sprintf("%d", n.Rank), + n.Word, + fmt.Sprintf("%f", n.Similarity), + } + } + + writer := tablewriter.NewWriter(os.Stdout) + writer.SetHeader([]string{"Rank", "Word", "Similarity"}) + writer.SetBorder(false) + writer.AppendBulk(table) + writer.Render() +} + +type Item struct { + item.Item + Norm float64 +} + +type Items []Item + +func (items Items) Empty() bool { + return len(items) == 0 +} + +func (items Items) Find(word string) (Item, bool) { + for _, item := range items { + if word == item.Word { + return item, true + } + } + return Item{}, false +} + type Searcher struct { - Vectors map[string][]float64 - Dimension int + Items Items } -// NewSearcher creates *Searcher -func NewSearcher(f io.Reader) (*Searcher, error) { - vectors := make(map[string][]float64) - var d int - storeFunc := func(word string, vec []float64, dim int) { - if d == 0 { - d = dim - } else if d != dim { - return +func New(items ...item.Item) (*Searcher, error) { + elems := make(Items, len(items)) + wholeDim := 0 + for i, item := range items { + if err := item.Validate(); err != nil { + return nil, err } - vectors[word] = vec + if i != 0 && wholeDim != item.Dim { + return nil, errors.Errorf("whole of word Dim for searcher must be same, maybe %d but got %d", wholeDim, item.Dim) + } + elems[i] = Item{ + Item: item, + Norm: norm(item.Vector), + } + wholeDim = item.Dim } - if err := ParseAll(f, storeFunc); err != nil { - return nil, errors.Wrap(err, "Failed to parse") + return &Searcher{ + Items: elems, + }, nil +} + +func NewForVectorFile(r io.Reader) (*Searcher, error) { + var ( + elems Items + i int + wholeDim int + ) + err := item.Parse(r, item.ItemOp(func(item item.Item) error { + if err := item.Validate(); err != nil { + return err + } + if i != 0 && wholeDim != item.Dim { + return errors.Errorf("whole of dim for searcher must be same, maybe %d but got %d", wholeDim, item.Dim) + } + elems = append(elems, Item{ + Item: item, + Norm: norm(item.Vector), + }) + i++ + wholeDim = item.Dim + return nil + })) + if err != nil { + return nil, err } return &Searcher{ - Vectors: vectors, - Dimension: d, + Items: elems, }, nil } -// Search searches similar words for query word and returns top-k nearest neighbors with similarity. -func (s *Searcher) SearchWithQuery(query string, k int) (Neighbors, error) { - queryVec, ok := s.Vectors[query] - if !ok { - return nil, errors.Errorf("%s is not found in vector map", query) +func (s *Searcher) InternalSearch(word string, k int) (Neighbors, error) { + var q Item + for _, item := range s.Items { + if item.Word == word { + q = item + break + } + } + if q.Word == "" { + return nil, errors.Errorf("%s is not found in searcher", word) } - queryNorm := norm(queryVec) - if k > len(s.Vectors) { - k = len(s.Vectors) - 1 + neighbors, err := s.search(q, k, word) + if err != nil { + return nil, err } - neighbors := make(Neighbors, k) - for word, vec := range s.Vectors { - if word == query { - continue + idx := -1 + for i, neighbor := range neighbors { + if neighbor.Word == word { + idx = i + break } - n := norm(vec) - neighbors = append(neighbors, Neighbor{ - word: word, - similarity: cosine(queryVec, vec, queryNorm, n), - }) } - sort.Sort(sort.Reverse(neighbors)) - - return neighbors[:k], nil - + if idx >= 0 { + neighbors = append(neighbors[:idx], neighbors[idx+1:]...) + } + return neighbors, nil } -func (s *Searcher) Search(queryVec []float64, k int) (Neighbors, error) { - queryNorm := norm(queryVec) +func (s *Searcher) Search(query []float64, k int) (Neighbors, error) { + return s.search(Item{ + Item: item.Item{ + Vector: query, + }, + Norm: norm(query), + }, k) +} - if k > len(s.Vectors) { - k = len(s.Vectors) +func (s *Searcher) search(query Item, k int, ignoreWord ...string) (Neighbors, error) { + neighbors := make(Neighbors, len(s.Items)) + for i, item := range s.Items { + var ignore bool + for _, w := range ignoreWord { + ignore = ignore || item.Word == w + } + if !ignore { + neighbors[i] = Neighbor{ + Word: item.Word, + Similarity: cosine(query.Vector, item.Vector, query.Norm, item.Norm), + } + } } - neighbors := make(Neighbors, k) - for word, vec := range s.Vectors { - n := norm(vec) - neighbors = append(neighbors, Neighbor{ - word: word, - similarity: cosine(queryVec, vec, queryNorm, n), - }) + sort.SliceStable(neighbors, func(i, j int) bool { + return neighbors[i].Similarity > neighbors[j].Similarity + }) + for i := range neighbors { + neighbors[i].Rank = uint(i) + 1 } + if k > len(s.Items) { + k = len(s.Items) + } + return neighbors[:k], nil +} - sort.Sort(sort.Reverse(neighbors)) +func norm(vec []float64) float64 { + var n float64 + for _, v := range vec { + n += math.Pow(v, 2) + } + return math.Sqrt(n) +} - return neighbors[:k], nil +func cosine(v1, v2 []float64, n1, n2 float64) float64 { + if n1 == 0 || n2 == 0 { + return 0 + } + var dot float64 + for i := range v1 { + dot += v1[i] * v2[i] + } + return dot / n1 / n2 } diff --git a/pkg/search/search_test.go b/pkg/search/search_test.go index a74671a..c2e9f09 100644 --- a/pkg/search/search_test.go +++ b/pkg/search/search_test.go @@ -17,61 +17,197 @@ package search import ( "bytes" "io/ioutil" + "reflect" "testing" -) -func TestSearchWithQuery(t *testing.T) { - f := ioutil.NopCloser(bytes.NewReader([]byte(testVectorStr))) - defer f.Close() + "github.com/stretchr/testify/assert" + + "github.com/ynqa/wego/pkg/item" +) - searcher, err := NewSearcher(f) - if err != nil { - t.Errorf("Failed to create searcher: %s", err.Error()) +func TestNewForVectorFile(t *testing.T) { + testCases := []struct { + name string + contents string + itemSize int + }{ + { + name: "read vector file", + contents: `apple 1 1 1 1 1 + banana 1 1 1 1 1 + chocolate 0 0 0 0 0 + dragon -1 -1 -1 -1 -1`, + itemSize: 4, + }, } - neighbors, err := searcher.SearchWithQuery("banana", 20) - if err != nil { - t.Errorf("Failed to search with word=banana, rank=20: %s", err.Error()) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := ioutil.NopCloser(bytes.NewReader([]byte(tc.contents))) + defer r.Close() + + s, _ := NewForVectorFile(r) + assert.Equal(t, tc.itemSize, len(s.Items)) + }) } +} - if len(searcher.Vectors) != testNumVector { - t.Errorf("Expected searcher.Vectors len=%d, but got %d", testNumVector, len(searcher.Vectors)) +func TestInternalSearch(t *testing.T) { + type args struct { + word string + k int } - if len(neighbors) != testNumVector-1 { - t.Errorf("Expected neighbors len=%d, but got %d", testNumVector-1, len(neighbors)) + testCases := []struct { + name string + items []item.Item + args args + expect Neighbors + }{ + { + name: "internal search", + items: []item.Item{ + { + Word: "apple", + Dim: 5, + Vector: []float64{1, 1, 1, 1, 1}, + }, + { + Word: "banana", + Dim: 5, + Vector: []float64{1, 1, 1, 1, 1}, + }, + { + Word: "chocolate", + Dim: 5, + Vector: []float64{0, 0, 0, 0, 0}, + }, + { + Word: "dragon", + Dim: 5, + Vector: []float64{-1, -1, -1, -1, -1}, + }, + }, + args: args{ + word: "apple", + k: 1, + }, + expect: Neighbors{ + { + Word: "banana", + Rank: 1, + Similarity: 1., + }, + }, + }, } - // NOTE: temporarily comment out the following lines - // if neighbors[0].word != "apple" { - // t.Errorf("Expected the most near word is `apple` for `banana`, but got neighbors=%v", neighbors) - // } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s, _ := New(tc.items...) + neighbors, _ := s.InternalSearch(tc.args.word, tc.args.k) + assert.Truef(t, reflect.DeepEqual(neighbors, tc.expect), "Must be equal %v and %v", neighbors, tc.expect) + }) + } } func TestSearch(t *testing.T) { - f := ioutil.NopCloser(bytes.NewReader([]byte(testVectorStr))) - defer f.Close() + type args struct { + query []float64 + k int + } + + testCases := []struct { + name string + items []item.Item + args args + expect Neighbors + }{ + { + name: "internal search", + items: []item.Item{ + { + Word: "apple", + Dim: 5, + Vector: []float64{1, 1, 1, 1, 1}, + }, + { + Word: "banana", + Dim: 5, + Vector: []float64{1, 1, 1, 1, 1}, + }, + { + Word: "chocolate", + Dim: 5, + Vector: []float64{0, 0, 0, 0, 0}, + }, + { + Word: "dragon", + Dim: 5, + Vector: []float64{-1, -1, -1, -1, -1}, + }, + }, + args: args{ + query: []float64{-1, -1, -1, -1, -1}, + k: 1, + }, + expect: Neighbors{ + { + Word: "dragon", + Rank: 1, + Similarity: 1., + }, + }, + }, + } - searcher, err := NewSearcher(f) - if err != nil { - t.Errorf("Failed to create searcher: %s", err.Error()) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s, _ := New(tc.items...) + neighbors, _ := s.Search(tc.args.query, tc.args.k) + assert.Truef(t, reflect.DeepEqual(tc.expect, neighbors), "Must be equal %v and %v", tc.expect, neighbors) + }) } +} - neighbors, err := searcher.Search(dragonVector, 20) - if err != nil { - t.Errorf("Failed to search with vector=%v, rank=20: %s", dragonVector, err.Error()) +func TestNorm(t *testing.T) { + testCases := []struct { + name string + vec []float64 + expect float64 + }{ + { + name: "norm", + vec: []float64{1, 1, 1, 1, 0, 0}, + expect: 2., + }, } - if len(searcher.Vectors) != testNumVector { - t.Errorf("Expected searcher.Vectors len=%d, but got %d", testNumVector, len(searcher.Vectors)) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expect, norm(tc.vec)) + }) } +} - if len(neighbors) != testNumVector { - t.Errorf("Expected neighbors len=%d, but got %d", testNumVector, len(neighbors)) +func TestCosine(t *testing.T) { + testCases := []struct { + name string + v1 []float64 + v2 []float64 + expect float64 + }{ + { + name: "cosine", + v1: []float64{1, 1, 1, 1, 0, 0}, + v2: []float64{1, 1, 0, 0, 1, 1}, + expect: 0.5, + }, } - // NOTE: temporarily comment out the following lines - // if neighbors[0].word != "dragon" { - // t.Errorf("Expected the most near word is vector=%v for `dragon`, but got neighbors=%v", dragonVector, neighbors) - // } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expect, cosine(tc.v1, tc.v2, norm(tc.v1), norm(tc.v2))) + }) + } } diff --git a/pkg/search/testing.go b/pkg/search/testing.go deleted file mode 100644 index ef58433..0000000 --- a/pkg/search/testing.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright © 2019 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package search - -const ( - testNumVector = 4 - testVectorStr = `apple 1 1 1 1 1 -banana 1 1 1 1 1 -chocolate 0 0 0 0 0 -dragon -1 -1 -1 -1 -1` -) - -var dragonVector = []float64{-1, -1, -1, -1, -1} diff --git a/pkg/search/util.go b/pkg/search/util.go deleted file mode 100644 index 596e140..0000000 --- a/pkg/search/util.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright © 2019 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package search - -import ( - "math" -) - -func norm(vec []float64) float64 { - var n float64 - for _, v := range vec { - n += math.Pow(v, 2) - } - return math.Sqrt(n) -} - -func cosine(v1, v2 []float64, n1, n2 float64) float64 { - var dot float64 - for i := range v1 { - dot += v1[i] * v2[i] - } - return dot / n1 / n2 -} diff --git a/pkg/search/util_test.go b/pkg/search/util_test.go deleted file mode 100644 index 0d1a902..0000000 --- a/pkg/search/util_test.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright © 2019 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package search - -import ( - "testing" -) - -var ( - v1 = []float64{1, 1, 1, 1, 0, 0} - v2 = []float64{1, 1, 0, 0, 1, 1} -) - -func TestNorm(t *testing.T) { - n1 := norm(v1) - if n1 != 2. { - t.Errorf("Expect norm of v1=%v is 2, but got %f", v1, n1) - } -} - -func TestCosine(t *testing.T) { - n1 := norm(v1) - n2 := norm(v2) - sim := cosine(v1, v2, n1, n2) - if sim != 0.5 { - t.Errorf("Expect sim(v1, v2)=0.5, but got %f", sim) - } -} diff --git a/pkg/timer/timer.go b/pkg/timer/timer.go deleted file mode 100644 index 1d1247d..0000000 --- a/pkg/timer/timer.go +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright © 2019 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package timer - -import ( - "fmt" - "time" -) - -const ( - ms = 1000 - s = 60 - m = 60 - h = 24 -) - -type Timer struct { - start, last time.Time -} - -func NewTimer() *Timer { - n := time.Now() - return &Timer{ - start: n, - last: n, - } -} - -func (t *Timer) AllElapsed() string { - d := time.Now().Sub(t.start) - return durationFmt(d) -} - -func (t *Timer) Elapsed() string { - n := time.Now() - d := n.Sub(t.last) - t.last = n - return durationFmt(d) -} - -func durationFmt(d time.Duration) string { - second := int(d.Seconds()) % s - minute := int(d.Minutes()) % m - hour := int(d.Hours()) % h - day := int(d / (h * time.Hour)) - millisecond := int(d/time.Millisecond) - (second * ms) - (minute * ms * s) - (hour * ms * s * m) - (day * ms * s * m * h) - switch { - case day > 0: - return fmt.Sprintf("%dd%dh%dm%ds%dms", day, hour, minute, second, millisecond) - case hour > 0: - return fmt.Sprintf("%dh%dm%ds%dms", hour, minute, second, millisecond) - case minute > 0: - return fmt.Sprintf("%dm%ds%dms", minute, second, millisecond) - case second > 0: - return fmt.Sprintf("%ds%dms", second, millisecond) - case millisecond > 0: - return fmt.Sprintf("%dms", millisecond) - default: - return fmt.Sprint("0ms") - } -} diff --git a/pkg/validate/validate.go b/pkg/validate/validate.go deleted file mode 100644 index 09b84e2..0000000 --- a/pkg/validate/validate.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package validate - -import ( - "os" -) - -// FileExists validates whether the file path exists or not. -func FileExists(path string) bool { - _, err := os.Stat(path) - return err == nil -} diff --git a/pkg/validate/validate_test.go b/pkg/validate/validate_test.go deleted file mode 100644 index 1aa84b1..0000000 --- a/pkg/validate/validate_test.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright © 2017 Makoto Ito -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package validate - -import ( - "testing" -) - -func TestFileExists(t *testing.T) { - if FileExists("fake.go") { - t.Error("fake.go is not existed") - } -} diff --git a/pkg/verbose/verbose.go b/pkg/verbose/verbose.go new file mode 100644 index 0000000..3d91fcc --- /dev/null +++ b/pkg/verbose/verbose.go @@ -0,0 +1,17 @@ +package verbose + +type Verbose struct { + flag bool +} + +func New(flag bool) *Verbose { + return &Verbose{ + flag: flag, + } +} + +func (v *Verbose) Do(fn func()) { + if v.flag { + fn() + } +} diff --git a/scripts/demo.sh b/scripts/demo.sh index 1b18841..71799f6 100755 --- a/scripts/demo.sh +++ b/scripts/demo.sh @@ -1,4 +1,4 @@ -#!/bin/sh -e +#!/bin/bash -e if [ ! -e text8 ]; then echo "Download text8 corpus" @@ -13,5 +13,5 @@ if [ ! -e text8 ]; then rm text8.zip fi -go run wego.go word2vec --verbose --inputFile text8 --model skip-gram --optimizer ns --outputFile example/word_vectors_sg.txt -d 100 && \ - echo "Save trained vectors to example/word_vectors_sg.txt" +go run wego.go word2vec --verbose --inputFile text8 --model skipgram --optimizer ns --outputFile examples/word_vectors.txt -d 100 && \ + echo "Save trained vectors to examples/word_vectors.txt" diff --git a/scripts/e2e.sh b/scripts/e2e.sh index c9c0175..6f9c9c8 100755 --- a/scripts/e2e.sh +++ b/scripts/e2e.sh @@ -1,4 +1,4 @@ -#!/bin/sh -e +#!/bin/bash -e e2e=$(basename $0) @@ -21,11 +21,11 @@ usage() { } function build() { - make build + go build } function clean_examples() { - make clean-example + rm -rf *.txt } function get_corpus() { @@ -44,56 +44,56 @@ function get_corpus() { } function train_word2vec() { - echo "train: skip-gram with ns" - ./wego word2vec -i text8 -o example/word2vec_sg_ns.txt \ - --model skip-gram --optimizer ns -d 100 -w 5 --verbose --iter 3 --min-count 5 --save-vec add --thread 20 - echo "train: skip-gram with hs" - ./wego word2vec -i text8 -o example/word2vec_sg_hs.txt \ - --model skip-gram --optimizer hs -d 100 -w 5 --verbose --iter 3 --min-count 5 --thread 20 + echo "train: skipgram with ns" + ./wego word2vec -i text8 -o word2vec_sg_ns.txt \ + --model skipgram --optimizer ns -d 100 -w 5 --verbose --iter 3 --min-count 5 --save-vec agg --thread 20 --batch 100000 + echo "train: skipgram with hs" + ./wego word2vec -i text8 -o word2vec_sg_hs.txt \ + --model skipgram --optimizer hs -d 100 -w 5 --verbose --iter 3 --min-count 5 --thread 20 --batch 100000 echo "train: cbow with ns" - ./wego word2vec -i text8 -o example/word2vec_cbow_ns.txt \ - --model cbow --optimizer ns -d 100 -w 5 --verbose --iter 3 --min-count 5 --save-vec add --thread 20 + ./wego word2vec -i text8 -o word2vec_cbow_ns.txt \ + --model cbow --optimizer ns -d 100 -w 5 --verbose --iter 3 --min-count 5 --save-vec agg --thread 20 --batch 100000 echo "train: cbow with hs" - ./wego word2vec -i text8 -o example/word2vec_cbow_hs.txt \ - --model cbow --optimizer hs -d 100 -w 5 --verbose --iter 3 --min-count 5 --thread 20 + ./wego word2vec -i text8 -o word2vec_cbow_hs.txt \ + --model cbow --optimizer hs -d 100 -w 5 --verbose --iter 3 --min-count 5 --thread 20 --batch 100000 } function train_glove() { echo "train: glove with sgd" - ./wego glove -d 50 -i text8 -o example/glove_sgd.txt \ - --iter 10 --thread 12 --initlr 0.05 --min-count 5 -w 15 --solver sgd --save-vec add --verbose + ./wego glove -d 50 -i text8 -o glove_sgd.txt \ + --iter 10 --thread 12 --initlr 0.05 --min-count 5 -w 15 --solver sgd --save-vec agg --verbose echo "train: glove with adagrad" - ./wego glove -d 50 -i text8 -o example/glove_adagrad.txt \ - --iter 10 --thread 12 --initlr 0.05 --min-count 5 -w 15 --solver adagrad --save-vec add --verbose + ./wego glove -d 50 -i text8 -o glove_adagrad.txt \ + --iter 10 --thread 12 --initlr 0.05 --min-count 5 -w 15 --solver adagrad --save-vec agg --verbose } function train_lexvec() { echo "train: lexvec" - ./wego lexvec -d 50 -i text8 -o example/lexvec.txt \ - --iter 3 --thread 12 --initlr 0.05 --min-count 5 -w 5 --rel ppmi --save-vec add --verbose + ./wego lexvec -d 50 -i text8 -o lexvec.txt \ + --iter 3 --thread 12 --initlr 0.05 --min-count 5 -w 5 --rel ppmi --save-vec agg --verbose } function search_word2vec() { - echo "similarity search: skip-gram with ns" - ./wego search -i example/word2vec_sg_ns.txt microsoft - echo "similarity search: skip-gram with hs" - ./wego search -i example/word2vec_sg_hs.txt microsoft + echo "similarity search: skipgram with ns" + ./wego search -i word2vec_sg_ns.txt microsoft + echo "similarity search: skipgram with hs" + ./wego search -i word2vec_sg_hs.txt microsoft echo "similarity search: cbow with ns" - ./wego search -i example/word2vec_cbow_ns.txt microsoft + ./wego search -i word2vec_cbow_ns.txt microsoft echo "similarity search: cbow with hs" - ./wego search -i example/word2vec_cbow_hs.txt microsoft + ./wego search -i word2vec_cbow_hs.txt microsoft } function search_glove() { echo "similarity search: glove with sgd" - ./wego search -i example/glove_sgd.txt microsoft + ./wego search -i glove_sgd.txt microsoft echo "similarity search: glove with adagrad" - ./wego search -i example/glove_adagrad.txt microsoft + ./wego search -i glove_adagrad.txt microsoft } function search_lexvec() { echo "similarity search: lexvec" - ./wego search -i example/lexvec.txt microsoft + ./wego search -i lexvec.txt microsoft } for OPT in "$@"; do diff --git a/wego.go b/wego.go index f906cb0..9c61708 100644 --- a/wego.go +++ b/wego.go @@ -17,11 +17,43 @@ package main import ( "os" - "github.com/ynqa/wego/cmd" + "github.com/pkg/errors" + "github.com/spf13/cobra" + + "github.com/ynqa/wego/cmd/model/glove" + "github.com/ynqa/wego/cmd/model/lexvec" + "github.com/ynqa/wego/cmd/model/word2vec" + "github.com/ynqa/wego/cmd/search" + "github.com/ynqa/wego/cmd/search/repl" ) func main() { - if err := cmd.RootCmd.Execute(); err != nil { + word2vec := word2vec.New() + glove := glove.New() + lexvec := lexvec.New() + search := search.New() + repl := repl.New() + + cmd := &cobra.Command{ + Use: "wego", + Short: "tools for embedding words into vector space", + RunE: func(cmd *cobra.Command, args []string) error { + return errors.Errorf("Set sub-command. One of %s|%s|%s|%s|%s", + word2vec.Name(), + glove.Name(), + lexvec.Name(), + search.Name(), + repl.Name(), + ) + }, + } + cmd.AddCommand(word2vec) + cmd.AddCommand(glove) + cmd.AddCommand(lexvec) + cmd.AddCommand(search) + cmd.AddCommand(repl) + + if err := cmd.Execute(); err != nil { os.Exit(1) } }