Skip to content

Commit

Permalink
Merge pull request #58 from ynqa/fix-race-condition-of-cbow
Browse files Browse the repository at this point in the history
Fix race condition
  • Loading branch information
ynqa authored Jul 28, 2022
2 parents dc41a6e + 2dcde9e commit 4ce56c0
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions pkg/model/word2vec/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,23 @@ func (mod *skipGram) trainOne(
}
}

type cbowToken struct {
agg []float64
tmp []float64
}

type cbow struct {
ch chan []float64
ch chan cbowToken
window int
}

func newCbow(opts Options) mod {
ch := make(chan []float64, opts.Goroutines*2)
ch := make(chan cbowToken, opts.Goroutines)
for i := 0; i < opts.Goroutines; i++ {
ch <- make([]float64, opts.Dim)
ch <- cbowToken{
agg: make([]float64, opts.Dim),
tmp: make([]float64, opts.Dim),
}
}
return &cbow{
ch: ch,
Expand All @@ -100,10 +108,11 @@ func (mod *cbow) trainOne(
param *matrix.Matrix,
optimizer optimizer,
) {
agg, tmp := <-mod.ch, <-mod.ch
token := <-mod.ch
agg, tmp := token.agg, token.tmp
defer func() {
mod.ch <- agg
mod.ch <- tmp
token := cbowToken{agg, tmp}
mod.ch <- token
}()
for i := 0; i < len(agg); i++ {
agg[i], tmp[i] = 0, 0
Expand Down

0 comments on commit 4ce56c0

Please sign in to comment.