diff --git a/pkg/model/word2vec/model.go b/pkg/model/word2vec/model.go index e0e29c6..818d9ef 100644 --- a/pkg/model/word2vec/model.go +++ b/pkg/model/word2vec/model.go @@ -77,15 +77,23 @@ func (mod *skipGram) trainOne( } } +type cbowToken struct { + agg []float64 + tmp []float64 +} + type cbow struct { - ch chan []float64 + ch chan cbowToken window int } func newCbow(opts Options) mod { - ch := make(chan []float64, opts.Goroutines*2) + ch := make(chan cbowToken, opts.Goroutines) for i := 0; i < opts.Goroutines; i++ { - ch <- make([]float64, opts.Dim) + ch <- cbowToken{ + agg: make([]float64, opts.Dim), + tmp: make([]float64, opts.Dim), + } } return &cbow{ ch: ch, @@ -100,10 +108,11 @@ func (mod *cbow) trainOne( param *matrix.Matrix, optimizer optimizer, ) { - agg, tmp := <-mod.ch, <-mod.ch + token := <-mod.ch + agg, tmp := token.agg, token.tmp defer func() { - mod.ch <- agg - mod.ch <- tmp + token := cbowToken{agg, tmp} + mod.ch <- token }() for i := 0; i < len(agg); i++ { agg[i], tmp[i] = 0, 0