From c1a3a23fc3457e979e1a943a27c6212c684a1aa5 Mon Sep 17 00:00:00 2001 From: Gary Linscott Date: Mon, 12 Mar 2018 20:55:12 -0700 Subject: [PATCH] Match support Support matches on the client, and add support for automatically promoting new networks to best when they finish with more wins than losses versus the previous. --- go/src/client/README.md | 5 + go/src/client/http/http.go | 46 +++- go/src/client/main.go | 256 +++++++++++++++------ go/src/server/db/models.go | 6 + go/src/server/main.go | 331 ++++++++++++++++++++++++--- go/src/server/main_test.go | 188 ++++++++++++++- go/src/server/templates/base.tmpl | 3 + go/src/server/templates/match.tmpl | 28 +++ go/src/server/templates/matches.tmpl | 30 +++ 9 files changed, 783 insertions(+), 110 deletions(-) create mode 100755 go/src/server/templates/match.tmpl create mode 100755 go/src/server/templates/matches.tmpl diff --git a/go/src/client/README.md b/go/src/client/README.md index 0d030a9c3..ee64c9c3c 100644 --- a/go/src/client/README.md +++ b/go/src/client/README.md @@ -8,6 +8,11 @@ export GOPATH=/Users/gary/go:/Users/gary/Development/leela-chess/go ``` Here, I've set my system install of go as the first entry, and then the leela-chess/go directory as the second. +Pre-reqs: +``` +go get -u github.com/notnil/chess +``` + Then you just need to `go build`, and it should produce a `client` executable. # Running diff --git a/go/src/client/http/http.go b/go/src/client/http/http.go index 9f32c9ee5..63357ca90 100644 --- a/go/src/client/http/http.go +++ b/go/src/client/http/http.go @@ -11,20 +11,33 @@ import ( "log" "mime/multipart" "net/http" + "net/url" "os" "path/filepath" + "strconv" + "strings" ) -func postJson(httpClient *http.Client, uri string, target interface{}) error { - r, err := httpClient.Post(uri, "application/json", bytes.NewBuffer([]byte{})) +func postParams(httpClient *http.Client, uri string, data map[string]string, target interface{}) error { + var encoded string + if data != nil { + values := url.Values{} + for key, val := range data { + values.Set(key, val) + } + encoded = values.Encode() + } + r, err := httpClient.Post(uri, "application/x-www-form-urlencoded", strings.NewReader(encoded)) if err != nil { return err } defer r.Body.Close() b, _ := ioutil.ReadAll(r.Body) - err = json.Unmarshal(b, target) - if err != nil { - log.Printf("Bad JSON from %s -- %s\n", uri, string(b)) + if target != nil { + err = json.Unmarshal(b, target) + if err != nil { + log.Printf("Bad JSON from %s -- %s\n", uri, string(b)) + } } return err } @@ -62,15 +75,19 @@ func BuildUploadRequest(uri string, params map[string]string, paramName, path st } type NextGameResponse struct { - Type string - TrainingId uint - NetworkId uint - Sha string + Type string + TrainingId uint + NetworkId uint + Sha string + CandidateSha string + Params string + Flip bool + MatchGameId uint } -func NextGame(httpClient *http.Client, hostname string) (NextGameResponse, error) { +func NextGame(httpClient *http.Client, hostname string, params map[string]string) (NextGameResponse, error) { resp := NextGameResponse{} - err := postJson(httpClient, hostname+"/next_game", &resp) + err := postParams(httpClient, hostname+"/next_game", params, &resp) if len(resp.Sha) == 0 { return resp, errors.New("Server gave back empty SHA") @@ -79,6 +96,13 @@ func NextGame(httpClient *http.Client, hostname string) (NextGameResponse, error return resp, err } +func UploadMatchResult(httpClient *http.Client, hostname string, match_game_id uint, result int, pgn string, params map[string]string) error { + params["match_game_id"] = strconv.Itoa(int(match_game_id)) + params["result"] = strconv.Itoa(result) + params["pgn"] = pgn + return postParams(httpClient, hostname+"/match_result", params, nil) +} + func DownloadNetwork(httpClient *http.Client, hostname string, networkPath string, sha string) error { uri := hostname + fmt.Sprintf("/get_network?sha=%s", sha) r, err := httpClient.Get(uri) diff --git a/go/src/client/main.go b/go/src/client/main.go index 91064578f..3fc19cb24 100644 --- a/go/src/client/main.go +++ b/go/src/client/main.go @@ -4,8 +4,10 @@ import ( "bufio" "bytes" "encoding/json" + "errors" "flag" "fmt" + "io" "io/ioutil" "log" "net/http" @@ -14,15 +16,18 @@ import ( "path" "path/filepath" "strconv" + "strings" "time" "client/http" + "github.com/notnil/chess" ) var HOSTNAME = flag.String("hostname", "http://162.217.248.187", "Address of the server") var USER = flag.String("user", "", "Username") var PASSWORD = flag.String("password", "", "Password") -var GPU = flag.Int("gpu", 0, "ID of the OpenCL device to use") +var GPU = flag.Int("gpu", -1, "ID of the OpenCL device to use (-1 for default, or no GPU)") +var DEBUG = flag.Bool("debug", false, "Enable debug mode to see verbose output") type Settings struct { User string @@ -66,15 +71,19 @@ func readSettings(path string) (string, string) { return settings.User, settings.Pass } -func uploadGame(httpClient *http.Client, path string, pgn string, nextGame client.NextGameResponse) error { - extraParams := map[string]string{ - "user": *USER, - "password": *PASSWORD, - "version": "1", - "training_id": strconv.Itoa(int(nextGame.TrainingId)), - "network_id": strconv.Itoa(int(nextGame.NetworkId)), - "pgn": pgn, +func getExtraParams() map[string]string { + return map[string]string{ + "user": *USER, + "password": *PASSWORD, + "version": "2", } +} + +func uploadGame(httpClient *http.Client, path string, pgn string, nextGame client.NextGameResponse) error { + extraParams := getExtraParams() + extraParams["training_id"] = strconv.Itoa(int(nextGame.TrainingId)) + extraParams["network_id"] = strconv.Itoa(int(nextGame.NetworkId)) + extraParams["pgn"] = pgn request, err := client.BuildUploadRequest(*HOSTNAME+"/upload_game", extraParams, "file", path) if err != nil { return err @@ -96,51 +105,47 @@ func uploadGame(httpClient *http.Client, path string, pgn string, nextGame clien return nil } -/* -func playMatch() { - p1 := exec.Command("lczero") - p1_in, _ := p1.StdinPipe() - p1_out, _ := p1.StdoutPipe() - p1.Start() - p1.Write("...") +type CmdWrapper struct { + Cmd *exec.Cmd + Pgn string + Input io.WriteCloser + BestMove chan string } -*/ -func train(networkPath string) (string, string) { - // pid is intended for use in multi-threaded training - pid := os.Getpid() +func (c *CmdWrapper) openInput() { + var err error + c.Input, err = c.Cmd.StdinPipe() + if err != nil { + log.Fatal(err) + } +} +func (c *CmdWrapper) launch(networkPath string, args []string, input bool) { + c.BestMove = make(chan string) + weights := fmt.Sprintf("--weights=%s", networkPath) dir, _ := os.Getwd() - train_dir := path.Join(dir, fmt.Sprintf("data-%v", pid)) - if _, err := os.Stat(train_dir); err == nil { - files, err := ioutil.ReadDir(train_dir) - if err != nil { - log.Fatal(err) - } - fmt.Printf("Cleanup training files:\n") - for _, f := range files { - fmt.Printf("%s/%s\n", train_dir, f.Name()) - } - err = os.RemoveAll(train_dir) - if err != nil { - log.Fatal(err) - } + c.Cmd = exec.Command(path.Join(dir, "lczero"), weights, "-t1") + c.Cmd.Args = append(c.Cmd.Args, args...) + if *GPU != -1 { + c.Cmd.Args = append(c.Cmd.Args, fmt.Sprintf("--gpu=%v", *GPU)) } + if !*DEBUG { + c.Cmd.Args = append(c.Cmd.Args, "--quiet") + } + fmt.Printf("Args: %v\n", c.Cmd.Args) - num_games := 1 - gpu_id := fmt.Sprintf("--gpu=%v", *GPU) - train_cmd := fmt.Sprintf("--start=train %v %v", pid, num_games) - weights := fmt.Sprintf("--weights=%s", networkPath) - // cmd := exec.Command(path.Join(dir, "lczero"), weights, "--randomize", "-n", "-t1", "-p20", "--noponder", "--quiet", train_cmd) - cmd := exec.Command(path.Join(dir, "lczero"), weights, gpu_id, "--randomize", "-n", "-t1", "--quiet", train_cmd) + stdout, err := c.Cmd.StdoutPipe() + if err != nil { + log.Fatal(err) + } - stdout, err := cmd.StdoutPipe() + stderr, err := c.Cmd.StderrPipe() if err != nil { log.Fatal(err) } - stdoutScanner := bufio.NewScanner(stdout) - pgn := "" + go func() { + stdoutScanner := bufio.NewScanner(stdout) reading_pgn := false for stdoutScanner.Scan() { line := stdoutScanner.Text() @@ -150,44 +155,145 @@ func train(networkPath string) (string, string) { } else if line == "END" { reading_pgn = false } else if reading_pgn { - pgn += line + "\n" + c.Pgn += line + "\n" + } else if strings.HasPrefix(line, "bestmove ") { + c.BestMove <- strings.Split(line, " ")[1] } } }() - stderr, err := cmd.StderrPipe() - if err != nil { - log.Fatal(err) - } - stderrScanner := bufio.NewScanner(stderr) go func() { + stderrScanner := bufio.NewScanner(stderr) for stderrScanner.Scan() { fmt.Printf("%s\n", stderrScanner.Text()) } }() - err = cmd.Start() + if input { + c.openInput() + } + + err = c.Cmd.Start() if err != nil { log.Fatal(err) } +} + +func playMatch(baselinePath string, candidatePath string, params []string, flip bool) (int, string) { + baseline := CmdWrapper{} + baseline.launch(baselinePath, params, true) + defer baseline.Input.Close() + + candidate := CmdWrapper{} + candidate.launch(candidatePath, params, true) + defer candidate.Input.Close() + + p1 := &candidate + p2 := &baseline + + if flip { + p2, p1 = p1, p2 + } + + io.WriteString(baseline.Input, "uci\n") + io.WriteString(candidate.Input, "uci\n") + + // Play a game using UCI + var result int + game := chess.NewGame(chess.UseNotation(chess.LongAlgebraicNotation{})) + move_history := "" + for { + moves := game.ValidMoves() + if len(moves) == 0 { + if game.Outcome() == chess.WhiteWon { + result = 1 + } else if game.Outcome() == chess.BlackWon { + result = -1 + } else { + result = 0 + } - err = cmd.Wait() + // Always report the result relative to the candidate engine (which defaults to white, unless flip = true) + if flip { + result = -result + } + break + } + + var p *CmdWrapper + if game.Position().Turn() == chess.White { + p = p1 + } else { + p = p2 + } + io.WriteString(p.Input, "position startpos"+move_history+"\n") + io.WriteString(p.Input, "go\n") + + best_move := <-p.BestMove + err := game.MoveStr(best_move) + if err != nil { + log.Println("Error decoding: " + best_move + " for game:\n" + game.String()) + log.Fatal(err) + } + if len(move_history) == 0 { + move_history = " moves" + } + move_history += " " + best_move + } + + chess.UseNotation(chess.AlgebraicNotation{})(game) + return result, game.String() +} + +func train(networkPath string, params []string) (string, string) { + // pid is intended for use in multi-threaded training + pid := os.Getpid() + + dir, _ := os.Getwd() + train_dir := path.Join(dir, fmt.Sprintf("data-%v", pid)) + if _, err := os.Stat(train_dir); err == nil { + files, err := ioutil.ReadDir(train_dir) + if err != nil { + log.Fatal(err) + } + fmt.Printf("Cleanup training files:\n") + for _, f := range files { + fmt.Printf("%s/%s\n", train_dir, f.Name()) + } + err = os.RemoveAll(train_dir) + if err != nil { + log.Fatal(err) + } + } + + num_games := 1 + train_cmd := fmt.Sprintf("--start=train %v %v", pid, num_games) + params = append(params, train_cmd) + + c := CmdWrapper{} + c.launch(networkPath, params, false) + + err := c.Cmd.Wait() if err != nil { log.Fatal(err) } - return path.Join(train_dir, "training.0.gz"), pgn + return path.Join(train_dir, "training.0.gz"), c.Pgn } -func getNetwork(httpClient *http.Client, sha string) (string, error) { +func getNetwork(httpClient *http.Client, sha string, clearOld bool) (string, error) { // Sha already exists? path := filepath.Join("networks", sha) - if _, err := os.Stat(path); err == nil { - return path, nil + if stat, err := os.Stat(path); err == nil { + if stat.Size() != 0 { + return path, nil + } } - // Clean out any old networks - os.RemoveAll("networks") + if clearOld { + // Clean out any old networks + os.RemoveAll("networks") + } os.MkdirAll("networks", os.ModePerm) fmt.Printf("Downloading network...\n") @@ -199,18 +305,40 @@ func getNetwork(httpClient *http.Client, sha string) (string, error) { return path, nil } -func nextGame(httpClient *http.Client, hostname string) error { - nextGame, err := client.NextGame(httpClient, *HOSTNAME) +func nextGame(httpClient *http.Client) error { + nextGame, err := client.NextGame(httpClient, *HOSTNAME, getExtraParams()) if err != nil { return err } - networkPath, err := getNetwork(httpClient, nextGame.Sha) + var params []string + err = json.Unmarshal([]byte(nextGame.Params), ¶ms) if err != nil { return err } - trainFile, pgn := train(networkPath) - uploadGame(httpClient, trainFile, pgn, nextGame) - return nil + + if nextGame.Type == "match" { + networkPath, err := getNetwork(httpClient, nextGame.Sha, false) + if err != nil { + return err + } + candidatePath, err := getNetwork(httpClient, nextGame.CandidateSha, false) + if err != nil { + return err + } + result, pgn := playMatch(networkPath, candidatePath, params, nextGame.Flip) + client.UploadMatchResult(httpClient, *HOSTNAME, nextGame.MatchGameId, result, pgn, getExtraParams()) + return nil + } else if nextGame.Type == "train" { + networkPath, err := getNetwork(httpClient, nextGame.Sha, true) + if err != nil { + return err + } + trainFile, pgn := train(networkPath, params) + uploadGame(httpClient, trainFile, pgn, nextGame) + return nil + } + + return errors.New("Unknown game type: " + nextGame.Type) } func main() { @@ -229,7 +357,7 @@ func main() { httpClient := &http.Client{} for { - err := nextGame(httpClient, *HOSTNAME) + err := nextGame(httpClient) if err != nil { log.Print(err) log.Print("Sleeping for 30 seconds...") diff --git a/go/src/server/db/models.go b/go/src/server/db/models.go index 9fe95672f..2663a5a99 100644 --- a/go/src/server/db/models.go +++ b/go/src/server/db/models.go @@ -45,12 +45,15 @@ type Match struct { gorm.Model TrainingRunID uint + Parameters string Candidate Network CandidateID uint CurrentBest Network CurrentBestID uint + GamesCreated int + Wins int Losses int Draws int @@ -70,6 +73,9 @@ type MatchGame struct { Version uint Pgn string + Result int + Done bool + Flip bool } type TrainingGame struct { diff --git a/go/src/server/main.go b/go/src/server/main.go index f96edef32..021dfe221 100644 --- a/go/src/server/main.go +++ b/go/src/server/main.go @@ -20,25 +20,91 @@ import ( "github.com/gin-gonic/gin" ) +func checkUser(c *gin.Context) (*db.User, error) { + if len(c.PostForm("user")) == 0 { + return nil, errors.New("No user supplied") + } + + user := &db.User{ + Password: c.PostForm("password"), + } + err := db.GetDB().Where(db.User{Username: c.PostForm("user")}).FirstOrCreate(&user).Error + if err != nil { + return nil, err + } + + // Ensure passwords match + if user.Password != c.PostForm("password") { + return nil, errors.New("Incorrect password") + } + + return user, nil +} + func nextGame(c *gin.Context) { + user, err := checkUser(c) + training_run := db.TrainingRun{ Active: true, } - // TODO(gary): Need to set some sort of priority system here. - err := db.GetDB().Preload("BestNetwork").Where(&training_run).First(&training_run).Error + // TODO(gary): Only really supports one training run right now... + err = db.GetDB().Where(&training_run).First(&training_run).Error if err != nil { log.Println(err) c.String(http.StatusBadRequest, "Invalid training run") return } - // TODO: Check for active matches. + network := db.Network{} + err = db.GetDB().Where("id = ?", training_run.BestNetworkID).First(&network).Error + if err != nil { + log.Println(err) + c.String(500, "Internal error") + return + } + + if user != nil { + var match []db.Match + err = db.GetDB().Preload("Candidate").Where("done=false").Limit(1).Find(&match).Error + if err != nil { + log.Println(err) + c.String(500, "Internal error") + return + } + if len(match) > 0 { + // Return this match + match_game := db.MatchGame{ + UserID: user.ID, + MatchID: match[0].ID, + } + err = db.GetDB().Create(&match_game).Error + // Note, this could cause an imbalance of white/black games for a particular match, + // but it's good enough for now. + flip := (match_game.ID & 1) == 1 + db.GetDB().Model(&match_game).Update("flip", flip) + if err != nil { + log.Println(err) + c.String(500, "Internal error") + return + } + result := gin.H{ + "type": "match", + "matchGameId": match_game.ID, + "sha": network.Sha, + "candidateSha": match[0].Candidate.Sha, + "params": match[0].Parameters, + "flip": flip, + } + c.JSON(http.StatusOK, result) + return + } + } result := gin.H{ "type": "train", "trainingId": training_run.ID, - "networkId": training_run.BestNetwork.ID, - "sha": training_run.BestNetwork.Sha, + "networkId": training_run.BestNetworkID, + "sha": network.Sha, "params": training_run.TrainParameters, } c.JSON(http.StatusOK, result) @@ -68,14 +134,9 @@ func computeSha(http_file *multipart.FileHeader) (string, error) { return sha, nil } -func getTrainingRun(training_id string) (*db.TrainingRun, error) { - id, err := strconv.ParseUint(training_id, 10, 32) - if err != nil { - return nil, err - } - +func getTrainingRun(training_id uint) (*db.TrainingRun, error) { var training_run db.TrainingRun - err = db.GetDB().Where("id = ?", id).First(&training_run).Error + err := db.GetDB().Where("id = ?", training_id).First(&training_run).Error if err != nil { return nil, err } @@ -115,6 +176,9 @@ func uploadNetwork(c *gin.Context) { } // Create new network + // TODO(gary): Just hardcoding this for now. + var training_run_id uint = 1 + network.TrainingRunID = training_run_id layers, err := strconv.ParseInt(c.PostForm("layers"), 10, 32) network.Layers = int(layers) filters, err := strconv.ParseInt(c.PostForm("filters"), 10, 32) @@ -141,18 +205,26 @@ func uploadNetwork(c *gin.Context) { return } - // Set the best network of this training_run - training_run, err := getTrainingRun(c.PostForm("training_id")) + // Create a match to see if this network is better + training_run, err := getTrainingRun(training_run_id) if err != nil { log.Println(err) - c.String(http.StatusBadRequest, "Invalid training run") + c.String(500, "Internal error") return } - training_run.BestNetwork = network - err = db.GetDB().Save(training_run).Error + + match := db.Match{ + TrainingRunID: training_run_id, + CandidateID: network.ID, + CurrentBestID: training_run.BestNetworkID, + Done: false, + GameCap: 400, + Parameters: `["--noise"]`, + } + err = db.GetDB().Create(&match).Error if err != nil { log.Println(err) - c.String(500, "Failed to update best training_run") + c.String(500, "Internal error") return } @@ -160,25 +232,23 @@ func uploadNetwork(c *gin.Context) { } func uploadGame(c *gin.Context) { - var user db.User - user.Password = c.PostForm("password") - err := db.GetDB().Where(db.User{Username: c.PostForm("user")}).FirstOrCreate(&user).Error + user, err := checkUser(c) if err != nil { log.Println(err) - c.String(http.StatusBadRequest, "Invalid user") + c.String(http.StatusBadRequest, err.Error()) return } - // Ensure passwords match - if user.Password != c.PostForm("password") { - c.String(http.StatusBadRequest, "Incorrect password") - return + training_id, err := strconv.ParseUint(c.PostForm("training_id"), 10, 32) + if err != nil { + log.Println(err) + c.String(http.StatusBadRequest, "Invalid training_id") } - training_run, err := getTrainingRun(c.PostForm("training_id")) + training_run, err := getTrainingRun(uint(training_id)) if err != nil { log.Println(err) - c.String(http.StatusBadRequest, "Invalid training run") + c.String(500, "Internal error") return } @@ -258,6 +328,123 @@ func getNetwork(c *gin.Context) { c.File(network.Path) } +func setBestNetwork(training_id uint, network_id uint) error { + // Set the best network of this training_run + training_run, err := getTrainingRun(training_id) + if err != nil { + return err + } + err = db.GetDB().Model(&training_run).Update("best_network_id", network_id).Error + if err != nil { + return err + } + return nil +} + +func checkMatchFinished(match_id uint) error { + // Now check to see if match is finished + var match db.Match + err := db.GetDB().Where("id = ?", match_id).First(&match).Error + if err != nil { + return err + } + + // Already done? Just return + if match.Done { + return nil + } + + if match.Wins+match.Losses+match.Draws >= match.GameCap { + err = db.GetDB().Model(&match).Update("done", true).Error + if err != nil { + return err + } + // Update to our new best network + // TODO(SPRT) + if match.Wins > match.Losses { + err = setBestNetwork(match.TrainingRunID, match.CandidateID) + if err != nil { + return err + } + } + } + + return nil +} + +func matchResult(c *gin.Context) { + user, err := checkUser(c) + if err != nil { + log.Println(err) + c.String(http.StatusBadRequest, err.Error()) + return + } + + match_game_id, err := strconv.ParseUint(c.PostForm("match_game_id"), 10, 32) + if err != nil { + log.Println(err) + c.String(http.StatusBadRequest, "Invalid match_game_id") + return + } + + var match_game db.MatchGame + err = db.GetDB().Where("id = ?", match_game_id).First(&match_game).Error + if err != nil { + log.Println(err) + c.String(http.StatusBadRequest, "Invalid match_game") + return + } + + result, err := strconv.ParseInt(c.PostForm("result"), 10, 32) + if err != nil { + log.Println(err) + c.String(http.StatusBadRequest, "Unable to parse result") + return + } + + good_result := result == 0 || result == -1 || result == 1 + if !good_result { + c.String(http.StatusBadRequest, "Bad result") + return + } + + err = db.GetDB().Model(&match_game).Updates(db.MatchGame{ + Result: int(result), + Done: true, + Pgn: c.PostForm("pgn"), + }).Error + if err != nil { + log.Println(err) + c.String(500, "Internal error") + return + } + + col := "" + if result == 0 { + col = "draws" + } else if result == 1 { + col = "wins" + } else { + col = "losses" + } + // Atomic update of game count + err = db.GetDB().Exec(fmt.Sprintf("UPDATE matches SET %s = %s + 1 WHERE id = ?", col, col), match_game.MatchID).Error + if err != nil { + log.Println(err) + c.String(500, "Internal error") + return + } + + err = checkMatchFinished(match_game.MatchID) + if err != nil { + log.Println(err) + c.String(500, "Internal error") + return + } + + c.String(http.StatusOK, fmt.Sprintf("Match game %d successfuly uploaded from user=%s.", match_game.ID, user.Username)) +} + func getActiveUsers() ([]gin.H, error) { rows, err := db.GetDB().Raw(`SELECT username, training_games.version, training_games.created_at, c.count FROM users LEFT JOIN training_games @@ -425,6 +612,29 @@ func game(c *gin.Context) { }) } +func viewMatchGame(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + log.Println(err) + c.String(500, "Internal error") + return + } + + game := db.MatchGame{ + ID: uint64(id), + } + err = db.GetDB().Where(&game).First(&game).Error + if err != nil { + log.Println(err) + c.String(500, "Internal error") + return + } + + c.HTML(http.StatusOK, "game", gin.H{ + "pgn": game.Pgn, + }) +} + func getNetworkCounts(networks []db.Network) map[uint]uint64 { counts := make(map[uint]uint64) for _, network := range networks { @@ -434,7 +644,7 @@ func getNetworkCounts(networks []db.Network) map[uint]uint64 { } func viewNetworks(c *gin.Context) { - // TODO(gary): Whole things needs to take training_run into account... + // TODO(gary): Whole thing needs to take training_run into account... var networks []db.Network err := db.GetDB().Order("id desc").Find(&networks).Error if err != nil { @@ -506,6 +716,63 @@ func viewStats(c *gin.Context) { }) } +func viewMatches(c *gin.Context) { + var matches []db.Match + err := db.GetDB().Order("id desc").Find(&matches).Error + if err != nil { + log.Println(err) + c.String(500, "Internal error") + return + } + + json := []gin.H{} + for _, match := range matches { + json = append(json, gin.H{ + "id": match.ID, + "current_id": match.CurrentBestID, + "candidate_id": match.CandidateID, + "score": fmt.Sprintf("%d - %d - %d", match.Wins, match.Losses, match.Draws), + "done": match.Done, + }) + } + + c.HTML(http.StatusOK, "matches", gin.H{ + "matches": json, + }) +} + +func viewMatch(c *gin.Context) { + match := db.Match{} + err := db.GetDB().Where("id = ?", c.Param("id")).First(&match).Error + if err != nil { + log.Println(err) + c.String(500, "Internal error") + return + } + + games := []db.MatchGame{} + err = db.GetDB().Where(&db.MatchGame{MatchID: match.ID}).Find(&games).Error + if err != nil { + log.Println(err) + c.String(500, "Internal error") + return + } + + gamesJson := []gin.H{} + for _, game := range games { + gamesJson = append(gamesJson, gin.H{ + "id": game.ID, + "created_at": game.CreatedAt.String(), + "result": game.Result, + "done": game.Done, + }) + } + + c.HTML(http.StatusOK, "match", gin.H{ + "games": gamesJson, + }) +} + func createTemplates() multitemplate.Render { r := multitemplate.New() r.AddFromFiles("index", "templates/base.tmpl", "templates/index.tmpl") @@ -514,6 +781,8 @@ func createTemplates() multitemplate.Render { r.AddFromFiles("networks", "templates/base.tmpl", "templates/networks.tmpl") r.AddFromFiles("training_runs", "templates/base.tmpl", "templates/training_runs.tmpl") r.AddFromFiles("stats", "templates/base.tmpl", "templates/stats.tmpl") + r.AddFromFiles("match", "templates/base.tmpl", "templates/match.tmpl") + r.AddFromFiles("matches", "templates/base.tmpl", "templates/matches.tmpl") return r } @@ -532,9 +801,13 @@ func setupRouter() *gin.Engine { router.GET("/networks", viewNetworks) router.GET("/stats", viewStats) router.GET("/training_runs", viewTrainingRuns) + router.GET("/match/:id", viewMatch) + router.GET("/matches", viewMatches) + router.GET("/match_game/:id", viewMatchGame) router.POST("/next_game", nextGame) router.POST("/upload_game", uploadGame) router.POST("/upload_network", uploadNetwork) + router.POST("/match_result", matchResult) return router } diff --git a/go/src/server/main_test.go b/go/src/server/main_test.go index f3ae670e7..966acbd33 100644 --- a/go/src/server/main_test.go +++ b/go/src/server/main_test.go @@ -10,8 +10,10 @@ import ( "log" "net/http" "net/http/httptest" + "net/url" "os" "server/db" + "strings" "testing" "github.com/gin-gonic/gin" @@ -48,16 +50,21 @@ func (s *StoreSuite) SetupTest() { } db.SetupDB() - network := db.Network{Sha: "abcd", Path: "/tmp/network"} + network := db.Network{Sha: "abcd", Path: "/tmp/network", TrainingRunID: 1} if err := db.GetDB().Create(&network).Error; err != nil { log.Fatal(err) } - training_run := db.TrainingRun{Description: "Testing", BestNetwork: network, Active: true} + training_run := db.TrainingRun{Description: "Testing", BestNetworkID: network.ID, Active: true} if err := db.GetDB().Create(&training_run).Error; err != nil { log.Fatal(err) } + user := db.User{Username: "defaut", Password: "1234"} + if err := db.GetDB().Create(&user).Error; err != nil { + log.Fatal(err) + } + s.w = httptest.NewRecorder() } @@ -71,14 +78,87 @@ func TestStoreSuite(t *testing.T) { suite.Run(t, s) } -func (s *StoreSuite) TestNextGame() { +func postParams(params map[string]string) *strings.Reader { + data := url.Values{} + for key, val := range params { + data.Set(key, val) + } + return strings.NewReader(data.Encode()) +} + +func initMatch(matchDone bool) { + candidate_network := db.Network{Sha: "efgh", Path: "/tmp/network2"} + if err := db.GetDB().Create(&candidate_network).Error; err != nil { + log.Fatal(err) + } + + match := db.Match{ + TrainingRunID: 1, + Parameters: `["--visits 10"]`, + CandidateID: candidate_network.ID, + CurrentBestID: 1, + Done: matchDone, + GameCap: 6, + } + if err := db.GetDB().Create(&match).Error; err != nil { + log.Fatal(err) + } +} + +// For backwards compatibility in short term. +func (s *StoreSuite) TestNextGameNoUser() { + req, _ := http.NewRequest("POST", "/next_game", nil) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + s.router.ServeHTTP(s.w, req) + + assert.Equal(s.T(), 200, s.w.Code, s.w.Body.String()) + assert.JSONEqf(s.T(), `{"params":"","type":"train","trainingId":1,"networkId":1,"sha":"abcd"}`, s.w.Body.String(), "Body incorrect") +} + +// Make sure old users don't get match games +func (s *StoreSuite) TestNextGameNoUserMatch() { + initMatch(false) + req, _ := http.NewRequest("POST", "/next_game", nil) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + s.router.ServeHTTP(s.w, req) + + assert.Equal(s.T(), 200, s.w.Code, s.w.Body.String()) + assert.JSONEqf(s.T(), `{"params":"","type":"train","trainingId":1,"networkId":1,"sha":"abcd"}`, s.w.Body.String(), "Body incorrect") +} + +func (s *StoreSuite) TestNextGameUserNoMatch() { + req, _ := http.NewRequest("POST", "/next_game", postParams(map[string]string{"user": "default", "password": "1234", "version": "2"})) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") s.router.ServeHTTP(s.w, req) assert.Equal(s.T(), 200, s.w.Code, s.w.Body.String()) assert.JSONEqf(s.T(), `{"params":"","type":"train","trainingId":1,"networkId":1,"sha":"abcd"}`, s.w.Body.String(), "Body incorrect") } +func (s *StoreSuite) TestNextGameUserMatch() { + initMatch(false) + + req, _ := http.NewRequest("POST", "/next_game", postParams(map[string]string{"user": "default", "password": "1234", "version": "2"})) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + s.router.ServeHTTP(s.w, req) + + assert.Equal(s.T(), 200, s.w.Code, s.w.Body.String()) + assert.JSONEqf(s.T(), `{"params":"[\"--visits 10\"]","type":"match","matchGameId":1,"sha":"abcd","candidateSha":"efgh","flip":true}`, s.w.Body.String(), "Body incorrect") +} + +func (s *StoreSuite) TestNextGameUserMatchDone() { + initMatch(true) + + req, _ := http.NewRequest("POST", "/next_game", postParams(map[string]string{"user": "default", "password": "1234", "version": "2"})) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + s.router.ServeHTTP(s.w, req) + + // Shouldn't get a match back + assert.Equal(s.T(), 200, s.w.Code, s.w.Body.String()) + assert.JSONEqf(s.T(), `{"params":"","type":"train","trainingId":1,"networkId":1,"sha":"abcd"}`, s.w.Body.String(), "Body incorrect") +} + func (s *StoreSuite) TestUploadGameNewUser() { extraParams := map[string]string{ "user": "foo", @@ -148,13 +228,14 @@ func uploadTestNetwork(s *StoreSuite, contentString string, networkId int) { s.router.ServeHTTP(s.w, req) assert.Equal(s.T(), 400, s.w.Code, s.w.Body.String()) - // Now we should be able to query for this network + // /next_game shouldn't return new network now, since it hasn't passed yet. s.w = httptest.NewRecorder() - sha := sha256.Sum256(content) req, _ = http.NewRequest("POST", "/next_game", nil) s.router.ServeHTTP(s.w, req) assert.Equal(s.T(), 200, s.w.Code, s.w.Body.String()) - assert.JSONEqf(s.T(), fmt.Sprintf(`{"params":"", "type":"train","trainingId":1,"networkId":%d,"sha":"%x"}`, networkId, sha), s.w.Body.String(), "Body incorrect") + assert.JSONEqf(s.T(), `{"params":"", "type":"train","trainingId":1,"networkId":1,"sha":"abcd"}`, s.w.Body.String(), "Body incorrect") + + sha := sha256.Sum256(content) // And let's download it now. s.w = httptest.NewRecorder() @@ -176,5 +257,100 @@ func uploadTestNetwork(s *StoreSuite, contentString string, networkId int) { func (s *StoreSuite) TestUploadNetwork() { uploadTestNetwork(s, "this_is_a_network", 2) + + // We should get a match game. + s.w = httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/next_game", postParams(map[string]string{"user": "default", "password": "1234", "version": "2"})) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + s.router.ServeHTTP(s.w, req) + assert.Equal(s.T(), 200, s.w.Code, s.w.Body.String()) + sha := sha256.Sum256([]byte("this_is_a_network")) + assert.JSONEqf(s.T(), fmt.Sprintf(`{"params":"","type":"match","matchGameId":1,"sha":"abcd","candidateSha":"%x","flip":true}`, sha), s.w.Body.String(), "Body incorrect") + uploadTestNetwork(s, "network2", 3) } + +func testMatchResult(s *StoreSuite, promote bool) { + initMatch(false) + + for i := 0; i < 6; i++ { + // get the match game + s.w = httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/next_game", postParams(map[string]string{"user": "default", "password": "1234", "version": "2"})) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + s.router.ServeHTTP(s.w, req) + + match_game_id := fmt.Sprintf("%d", i+1) + flip := (i & 1) == 0 + assert.Equal(s.T(), 200, s.w.Code, s.w.Body.String()) + assert.JSONEqf(s.T(), fmt.Sprintf(`{"params":"[\"--visits 10\"]","type":"match","matchGameId":%s,"sha":"abcd","candidateSha":"efgh","flip":%t}`, match_game_id, flip), s.w.Body.String(), "Body incorrect") + + // Now, post a result from the match + s.w = httptest.NewRecorder() + + result := -1 + if promote { + result = 1 + } + + req, _ = http.NewRequest("POST", "/match_result", postParams(map[string]string{ + "user": "default", + "password": "1234", + "version": "2", + "match_game_id": match_game_id, + "result": fmt.Sprintf("%d", result), + "pgn": "asdf", + })) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + s.router.ServeHTTP(s.w, req) + + assert.Equal(s.T(), 200, s.w.Code, s.w.Body.String()) + + // Check that the match game is present now. + match_game := db.MatchGame{} + err := db.GetDB().Where("id = ?", 1).First(&match_game).Error + if err != nil { + log.Fatal(err) + } + + assert.Equal(s.T(), result, match_game.Result) + assert.Equal(s.T(), "asdf", match_game.Pgn) + assert.Equal(s.T(), true, match_game.Done) + + // And now that the match is updated. + match := db.Match{} + err = db.GetDB().Where("id = ?", 1).First(&match).Error + if err != nil { + log.Fatal(err) + } + } + + // Match should be done now + match := db.Match{} + err := db.GetDB().Where("id = ?", 1).First(&match).Error + if err != nil { + log.Fatal(err) + } + assert.Equal(s.T(), true, match.Done) + + // And, now, we shouldn't get a match game back + s.w = httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/next_game", postParams(map[string]string{"user": "default", "password": "1234", "version": "2"})) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + s.router.ServeHTTP(s.w, req) + + assert.Equal(s.T(), 200, s.w.Code, s.w.Body.String()) + if promote { + assert.JSONEqf(s.T(), `{"params":"","type":"train","trainingId":1,"networkId":2,"sha":"efgh"}`, s.w.Body.String(), "Body incorrect") + } else { + assert.JSONEqf(s.T(), `{"params":"","type":"train","trainingId":1,"networkId":1,"sha":"abcd"}`, s.w.Body.String(), "Body incorrect") + } +} + +func (s *StoreSuite) TestPostMatchResultFailed() { + testMatchResult(s, false) +} + +func (s *StoreSuite) TestPostMatchResultSuccess() { + testMatchResult(s, true) +} diff --git a/go/src/server/templates/base.tmpl b/go/src/server/templates/base.tmpl index 7c4d75891..8fcdeaa70 100755 --- a/go/src/server/templates/base.tmpl +++ b/go/src/server/templates/base.tmpl @@ -51,6 +51,9 @@ + diff --git a/go/src/server/templates/match.tmpl b/go/src/server/templates/match.tmpl new file mode 100755 index 000000000..d5084648a --- /dev/null +++ b/go/src/server/templates/match.tmpl @@ -0,0 +1,28 @@ +{{define "content"}} +

Match

+
+ + + + + + + + + + + {{range .games}} + + + + + + + {{end}} + +
Game IdResultFinishedTime
{{.id}}{{.result}}{{.done}}{{.created_at}}
+
+{{end}} + +{{define "scripts"}} +{{end}} diff --git a/go/src/server/templates/matches.tmpl b/go/src/server/templates/matches.tmpl new file mode 100755 index 000000000..eb2743794 --- /dev/null +++ b/go/src/server/templates/matches.tmpl @@ -0,0 +1,30 @@ +{{define "content"}} +

Matches

+
+ + + + + + + + + + + + {{range .matches}} + + + + + + + + {{end}} + +
IdCurrent IDCandidate IDScoreDone
{{.id}}{{.current_id}}{{.candidate_id}}{{.score}}{{.done}}
+
+{{end}} + +{{define "scripts"}} +{{end}}