Skip to content

Commit

Permalink
Merge pull request #44 from ynqa/trainapi
Browse files Browse the repository at this point in the history
Change train apis
  • Loading branch information
ynqa authored Feb 16, 2019
2 parents 6cf4004 + aa2700e commit 4e8bdcb
Show file tree
Hide file tree
Showing 19 changed files with 161 additions and 162 deletions.
9 changes: 6 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,16 @@ It's also able to train word vectors using wego APIs. Examples are as follows.
package main

import (
"os"

"github.com/ynqa/wego/builder"
"github.com/ynqa/wego/model/word2vec"
)

func main() {
b := builder.NewWord2vecBuilder()

b.InputFile("text8").
Dimension(10).
b.Dimension(10).
Window(5).
Model(word2vec.CBOW).
Optimizer(word2vec.NEGATIVE_SAMPLING).
Expand All @@ -103,8 +104,10 @@ func main() {
// Failed to build word2vec.
}

input, _ := os.Open("text8")

// Start to Train.
if err = m.Train(); err != nil {
if err = m.Train(input); err != nil {
// Failed to train by word2vec.
}

Expand Down
27 changes: 1 addition & 26 deletions builder/glove.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,16 @@
package builder

import (
"os"

"github.com/pkg/errors"
"github.com/spf13/viper"

"github.com/ynqa/wego/config"
"github.com/ynqa/wego/model"
"github.com/ynqa/wego/model/glove"
"github.com/ynqa/wego/validate"
)

// GloveBuilder manages the members to build Model interface.
type GloveBuilder struct {
// input file path.
inputFile string

// common configs.
dimension int
iteration int
Expand All @@ -52,8 +46,6 @@ type GloveBuilder struct {
// NewGloveBuilder creates *GloveBuilder
func NewGloveBuilder() *GloveBuilder {
return &GloveBuilder{
inputFile: config.DefaultInputFile,

dimension: config.DefaultDimension,
iteration: config.DefaultIteration,
minCount: config.DefaultMinCount,
Expand Down Expand Up @@ -95,8 +87,6 @@ func NewGloveBuilderFromViper() (*GloveBuilder, error) {
return nil, errors.Errorf("Invalid solver type=%s", solverTypeStr)
}
return &GloveBuilder{
inputFile: viper.GetString(config.InputFile.String()),

dimension: viper.GetInt(config.Dimension.String()),
iteration: viper.GetInt(config.Iteration.String()),
minCount: viper.GetInt(config.MinCount.String()),
Expand All @@ -114,12 +104,6 @@ func NewGloveBuilderFromViper() (*GloveBuilder, error) {
}, nil
}

// InputFile sets input file string.
func (gb *GloveBuilder) InputFile(inputFile string) *GloveBuilder {
gb.inputFile = inputFile
return gb
}

// Dimension sets dimension of word vector.
func (gb *GloveBuilder) Dimension(dimension int) *GloveBuilder {
gb.dimension = dimension
Expand Down Expand Up @@ -199,15 +183,6 @@ func (gb *GloveBuilder) Alpha(alpha float64) *GloveBuilder {

// Build creates model.Model interface.
func (gb *GloveBuilder) Build() (model.Model, error) {
if !validate.FileExists(gb.inputFile) {
return nil, errors.Errorf("Not such a file %s", gb.inputFile)
}

input, err := os.Open(gb.inputFile)
if err != nil {
return nil, err
}

o := &model.Option{
Dimension: gb.dimension,
Iteration: gb.iteration,
Expand Down Expand Up @@ -237,5 +212,5 @@ func (gb *GloveBuilder) Build() (model.Model, error) {
Alpha: gb.alpha,
}

return glove.NewGlove(input, o, g)
return glove.NewGlove(o, g), nil
}
11 changes: 0 additions & 11 deletions builder/glove_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,6 @@ import (
"github.com/ynqa/wego/model/glove"
)

func TestGloveInputFile(t *testing.T) {
b := &GloveBuilder{}

expectedInputFile := "inputfile"
b.InputFile(expectedInputFile)

if b.inputFile != expectedInputFile {
t.Errorf("Expected builder.inputFile=%v: %v", expectedInputFile, b.inputFile)
}
}

func TestGloveDimension(t *testing.T) {
b := &GloveBuilder{}

Expand Down
27 changes: 1 addition & 26 deletions builder/lexvec.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,17 @@
package builder

import (
"os"

"github.com/pkg/errors"
"github.com/spf13/viper"

"github.com/ynqa/wego/config"
"github.com/ynqa/wego/corpus"
"github.com/ynqa/wego/model"
"github.com/ynqa/wego/model/lexvec"
"github.com/ynqa/wego/validate"
)

// LexvecBuilder manages the members to build Model interface.
type LexvecBuilder struct {
// input file path.
inputFile string

// common configs.
dimension int
iteration int
Expand All @@ -55,8 +49,6 @@ type LexvecBuilder struct {
// NewLexvecBuilder creates *LexvecBuilder.
func NewLexvecBuilder() *LexvecBuilder {
return &LexvecBuilder{
inputFile: config.DefaultInputFile,

dimension: config.DefaultDimension,
iteration: config.DefaultIteration,
minCount: config.DefaultMinCount,
Expand Down Expand Up @@ -103,8 +95,6 @@ func NewLexvecBuilderFromViper() (*LexvecBuilder, error) {
}

return &LexvecBuilder{
inputFile: viper.GetString(config.InputFile.String()),

dimension: viper.GetInt(config.Dimension.String()),
iteration: viper.GetInt(config.Iteration.String()),
minCount: viper.GetInt(config.MinCount.String()),
Expand All @@ -123,12 +113,6 @@ func NewLexvecBuilderFromViper() (*LexvecBuilder, error) {
}, nil
}

// InputFile sets input file string.
func (lb *LexvecBuilder) InputFile(inputFile string) *LexvecBuilder {
lb.inputFile = inputFile
return lb
}

// Dimension sets dimension of word vector.
func (lb *LexvecBuilder) Dimension(dimension int) *LexvecBuilder {
lb.dimension = dimension
Expand Down Expand Up @@ -217,15 +201,6 @@ func (lb *LexvecBuilder) RelationType(typ corpus.RelationType) *LexvecBuilder {

// Build creates Lexvec model.
func (lb *LexvecBuilder) Build() (model.Model, error) {
if !validate.FileExists(lb.inputFile) {
return nil, errors.Errorf("Not such a file %s", lb.inputFile)
}

input, err := os.Open(lb.inputFile)
if err != nil {
return nil, err
}

o := &model.Option{
Dimension: lb.dimension,
Iteration: lb.iteration,
Expand All @@ -247,5 +222,5 @@ func (lb *LexvecBuilder) Build() (model.Model, error) {
RelationType: lb.relationType,
}

return lexvec.NewLexvec(input, o, l)
return lexvec.NewLexvec(o, l), nil
}
27 changes: 1 addition & 26 deletions builder/word2vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,16 @@
package builder

import (
"os"

"github.com/pkg/errors"
"github.com/spf13/viper"

"github.com/ynqa/wego/config"
"github.com/ynqa/wego/model"
"github.com/ynqa/wego/model/word2vec"
"github.com/ynqa/wego/validate"
)

// Word2vecBuilder manages the members to build Model interface.
type Word2vecBuilder struct {
// input file path.
inputFile string

// common configs.
dimension int
iteration int
Expand All @@ -55,8 +49,6 @@ type Word2vecBuilder struct {
// NewWord2vecBuilder creates *Word2vecBuilder.
func NewWord2vecBuilder() *Word2vecBuilder {
return &Word2vecBuilder{
inputFile: config.DefaultInputFile,

dimension: config.DefaultDimension,
iteration: config.DefaultIteration,
minCount: config.DefaultMinCount,
Expand Down Expand Up @@ -113,8 +105,6 @@ func NewWord2vecBuilderFromViper() (*Word2vecBuilder, error) {
}

return &Word2vecBuilder{
inputFile: viper.GetString(config.InputFile.String()),

dimension: viper.GetInt(config.Dimension.String()),
iteration: viper.GetInt(config.Iteration.String()),
minCount: viper.GetInt(config.MinCount.String()),
Expand All @@ -135,12 +125,6 @@ func NewWord2vecBuilderFromViper() (*Word2vecBuilder, error) {
}, nil
}

// InputFile sets input file string.
func (wb *Word2vecBuilder) InputFile(inputFile string) *Word2vecBuilder {
wb.inputFile = inputFile
return wb
}

// Dimension sets dimension of word vector.
func (wb *Word2vecBuilder) Dimension(dimension int) *Word2vecBuilder {
wb.dimension = dimension
Expand Down Expand Up @@ -238,19 +222,10 @@ func (wb *Word2vecBuilder) Theta(theta float64) *Word2vecBuilder {

// Build creates model.Model interface.
func (wb *Word2vecBuilder) Build() (model.Model, error) {
if !validate.FileExists(wb.inputFile) {
return nil, errors.Errorf("Not such a file %s", wb.inputFile)
}

if wb.optimizer == word2vec.HIERARCHICAL_SOFTMAX && wb.saveVectorType == model.ADD {
return nil, errors.Errorf("Invalid pair of optimizer=%s and save vector type=%s", wb.optimizer, wb.saveVectorType)
}

input, err := os.Open(wb.inputFile)
if err != nil {
return nil, err
}

o := &model.Option{
Dimension: wb.dimension,
Iteration: wb.iteration,
Expand Down Expand Up @@ -291,5 +266,5 @@ func (wb *Word2vecBuilder) Build() (model.Model, error) {
Theta: wb.theta,
}

return word2vec.NewWord2vec(input, o, w)
return word2vec.NewWord2vec(o, w), nil
}
11 changes: 0 additions & 11 deletions builder/word2vec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,6 @@ import (
"testing"
)

func TestWord2vecInputFile(t *testing.T) {
b := &Word2vecBuilder{}

expectedInputFile := "inputfile"
b.InputFile(expectedInputFile)

if b.inputFile != expectedInputFile {
t.Errorf("Expected builder.inputFile=%v: %v", expectedInputFile, b.inputFile)
}
}

func TestWord2vecDimension(t *testing.T) {
b := &Word2vecBuilder{}

Expand Down
10 changes: 9 additions & 1 deletion cmd/glove.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,15 @@ func runGlove() error {
if err != nil {
return err
}
if err := mod.Train(); err != nil {
inputFile := viper.GetString(config.InputFile.String())
if !validate.FileExists(inputFile) {
return errors.Errorf("Not such a file %s", inputFile)
}
input, err := os.Open(inputFile)
if err != nil {
return err
}
if err := mod.Train(input); err != nil {
return err
}
return mod.Save(outputFile)
Expand Down
10 changes: 9 additions & 1 deletion cmd/lexvec.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,15 @@ func runLexvec() error {
if err != nil {
return err
}
if err := mod.Train(); err != nil {
inputFile := viper.GetString(config.InputFile.String())
if !validate.FileExists(inputFile) {
return errors.Errorf("Not such a file %s", inputFile)
}
input, err := os.Open(inputFile)
if err != nil {
return err
}
if err := mod.Train(input); err != nil {
return err
}
return mod.Save(outputFile)
Expand Down
10 changes: 9 additions & 1 deletion cmd/word2vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,15 @@ func runWord2vec() error {
if err != nil {
return err
}
if err := mod.Train(); err != nil {
inputFile := viper.GetString(config.InputFile.String())
if !validate.FileExists(inputFile) {
return errors.Errorf("Not such a file %s", inputFile)
}
input, err := os.Open(inputFile)
if err != nil {
return err
}
if err := mod.Train(input); err != nil {
return err
}
return mod.Save(outputFile)
Expand Down
2 changes: 1 addition & 1 deletion corpus/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func newCore() *core {
}
}

func (c *core) Parse(f io.ReadCloser, toLower bool, minCount int, batchSize int, verbose bool) error {
func (c *core) Parse(f io.Reader, toLower bool, minCount int, batchSize int, verbose bool) error {
fullDoc := make([]int, 0)
scanner := bufio.NewScanner(f)
scanner.Split(bufio.ScanWords)
Expand Down
9 changes: 6 additions & 3 deletions example/glove/glove.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@
package main

import (
"os"

"github.com/ynqa/wego/builder"
"github.com/ynqa/wego/model/glove"
)

func main() {
b := builder.NewGloveBuilder()

b.InputFile("text8").
Dimension(10).
b.Dimension(10).
Window(5).
Solver(glove.SGD).
Verbose()
Expand All @@ -33,8 +34,10 @@ func main() {
// Failed to build word2vec.
}

input, _ := os.Open("text8")

// Start to Train.
if err = m.Train(); err != nil {
if err = m.Train(input); err != nil {
// Failed to train by word2vec.
}

Expand Down
Loading

0 comments on commit 4e8bdcb

Please sign in to comment.