diff --git a/.gitignore b/.gitignore index 0333209..e8750bc 100644 --- a/.gitignore +++ b/.gitignore @@ -5,16 +5,9 @@ *.o # executables -castro -castro-* -chex -chex-* -moy -moy-* -pentagod -pentagod-* -trex -trex-* +morat-* +test +test-* # valgrind output callgrind.out.* diff --git a/Makefile b/Makefile index 1964445..376b31b 100644 --- a/Makefile +++ b/Makefile @@ -25,24 +25,74 @@ else endif -all: castro chex moy trex pentagod +all: morat-havannah morat-hex morat-pentago morat-rex morat-y -castro: \ - havannah/castro.o \ +test: \ + lib/test.o \ + lib/fileio.o \ + lib/move_test.o \ + lib/outcome.o \ + lib/outcome_test.o \ + lib/sgf_test.o \ + lib/string.o \ + lib/string_test.o \ + lib/zobrist.o \ havannah/agentmcts.o \ havannah/agentmctsthread.o \ + havannah/agentmcts_test.o \ havannah/agentpns.o \ + havannah/agentpns_test.o \ + havannah/board.o \ + havannah/board_test.o \ + hex/agentmcts.o \ + hex/agentmctsthread.o \ + hex/agentmcts_test.o \ + hex/agentpns.o \ + hex/agentpns_test.o \ + hex/board.o \ + hex/board_test.o \ + pentago/agentmcts.o \ + pentago/agentmctsthread.o \ + pentago/agentmcts_test.o \ + pentago/agentpns.o \ + pentago/agentpns_test.o \ + pentago/board.o \ + rex/agentmcts.o \ + rex/agentmctsthread.o \ + rex/agentmcts_test.o \ + rex/agentpns.o \ + rex/agentpns_test.o \ + rex/board.o \ + rex/board_test.o \ + y/agentmcts.o \ + y/agentmctsthread.o \ + y/agentmcts_test.o \ + y/agentpns.o \ + y/agentpns_test.o \ + y/board.o \ + y/board_test.o \ + $(ALARM) + $(CXX) $(LDFLAGS) -o $@ $^ $(LOADLIBES) $(LDLIBS) + ./test + +morat-havannah: \ + havannah/main.o \ + havannah/agentmcts.o \ + havannah/agentmctsthread.o \ + havannah/agentpns.o \ + havannah/board.o \ havannah/gtpgeneral.o \ havannah/gtpagent.o \ lib/fileio.o \ lib/gtpcommon.o \ + lib/outcome.o \ lib/string.o \ lib/zobrist.o \ $(ALARM) $(CXX) $(LDFLAGS) -o $@ $^ $(LOADLIBES) $(LDLIBS) -pentagod: \ - pentago/pentagod.o \ +morat-pentago: \ + pentago/main.o \ pentago/agentab.o \ pentago/agentmcts.o \ pentago/agentmctsthread.o \ @@ -50,58 +100,64 @@ pentagod: \ pentago/board.o \ pentago/gtpgeneral.o \ pentago/gtpagent.o \ - pentago/move.o \ pentago/moveiterator.o \ lib/fileio.o \ lib/gtpcommon.o \ + lib/outcome.o \ lib/string.o \ $(ALARM) $(CXX) $(LDFLAGS) -o $@ $^ $(LOADLIBES) $(LDLIBS) -moy: \ - y/moy.o \ +morat-y: \ + y/main.o \ y/agentmcts.o \ y/agentmctsthread.o \ y/agentpns.o \ + y/board.o \ y/gtpagent.o \ y/gtpgeneral.o \ lib/fileio.o \ lib/gtpcommon.o \ + lib/outcome.o \ lib/string.o \ lib/zobrist.o \ $(ALARM) $(CXX) $(LDFLAGS) -o $@ $^ $(LOADLIBES) $(LDLIBS) -chex: \ - hex/chex.o \ +morat-hex: \ + hex/main.o \ hex/agentmcts.o \ hex/agentmctsthread.o \ hex/agentpns.o \ + hex/board.o \ hex/gtpagent.o \ hex/gtpgeneral.o \ lib/fileio.o \ lib/gtpcommon.o \ + lib/outcome.o \ lib/string.o \ lib/zobrist.o \ $(ALARM) $(CXX) $(LDFLAGS) -o $@ $^ $(LOADLIBES) $(LDLIBS) - -trex: \ - rex/trex.o \ + +morat-rex: \ + rex/main.o \ rex/agentmcts.o \ rex/agentmctsthread.o \ rex/agentpns.o \ + rex/board.o \ rex/gtpagent.o \ rex/gtpgeneral.o \ lib/fileio.o \ lib/gtpcommon.o \ + lib/outcome.o \ lib/string.o \ lib/zobrist.o \ $(ALARM) $(CXX) $(LDFLAGS) -o $@ $^ $(LOADLIBES) $(LDLIBS) clean: - rm -f */*.o castro moy pentagod chex trex .Makefile + rm -f */*.o test morat-havannah morat-hex morat-pentago morat-rex morat-y .Makefile fresh: clean all diff --git a/README.md b/README.md index 108691d..e4587fa 100644 --- a/README.md +++ b/README.md @@ -2,16 +2,30 @@ Morat is a game playing framework, along with implementations of several games. It includes some general purpose libraries (alarm, time, thread, random), and some game specific libraries (compacting tree, gtp, time controls). -So far it supports 3 games: +It specializes in 2-player, perfect information, zero sum, deterministic games, especially placement games (where RAVE works). + +So far it supports 5 games: * [Havannah](https://en.wikipedia.org/wiki/Havannah) +* [Hex](https://en.wikipedia.org/wiki/Hex_%28board_game%29) +* Rex - Reverse Hex (ie try to force the opponent to connect their edges). * [Y](https://en.wikipedia.org/wiki/Y_%28game%29) * [Pentago](https://en.wikipedia.org/wiki/Pentago) +Potential games: +* [Gomoku](https://en.wikipedia.org/wiki/Gomoku) or more generally [M,n,k](https://en.wikipedia.org/wiki/M,n,k-game) +* [Star](https://en.wikipedia.org/wiki/Star_%28board_game%29) or [*Star](https://en.wikipedia.org/wiki/*Star) +* [Domineering or Cram](https://en.wikipedia.org/wiki/Domineering) +* [Dots and Boxes](https://en.wikipedia.org/wiki/Dots_and_Boxes) + So far it supports 3 algorithms: * [MCTS: Monte-Carlo Tree Search](https://en.wikipedia.org/wiki/Monte-Carlo_tree_search) * [PNS: Proof Number Search](https://en.wikipedia.org/wiki/Proof-number_search) * [AB: Alpha-Beta](https://en.wikipedia.org/wiki/Alpha%E2%80%93beta_pruning) +Potential algorithms: +* [Probability Search](http://www.lamsade.dauphine.fr/~cazenave/papers/probabilitySearch.pdf) +* Conspiracy Number Search + The goal is to make the algorithms game independent, and make it easier to implement new games with strong players. There is quite a bit of work left to make this a reality, so the current work is just to make the game code more similar and then move the code into common libraries. The primary interface is [GTP (Go Text Protocol)](https://en.wikipedia.org/wiki/Go_Text_Protocol), which can be used from: @@ -30,9 +44,13 @@ The primary interface is [GTP (Go Text Protocol)](https://en.wikipedia.org/wiki/ * Check out the code from github * Run ```make``` to compile the code * Run: - * ```./castro``` for Havannah - * ```./moy``` for Y - * ```./pentagod``` for pentago + * ```./morat-havannah``` for Havannah + * ```./morat-hex``` for Hex + * ```./morat-rex``` for Reverse Hex + * ```./morat-y``` for Y + * ```./morat-pentago``` for pentago + +Run ```make test``` to run the test suite. Current test coverage is pretty bad. If you make any changes to the code and want to update the dependencies, just ```make clean```, or ```rm .Makefile```. diff --git a/havannah/agent.h b/havannah/agent.h index 6adecd2..941bb7a 100644 --- a/havannah/agent.h +++ b/havannah/agent.h @@ -3,11 +3,19 @@ //Interface for the various agents: players and solvers +#include "../lib/outcome.h" +#include "../lib/sgf.h" #include "../lib/types.h" #include "board.h" + +namespace Morat { +namespace Havannah { + class Agent { +protected: + typedef std::vector vecmove; public: Agent() { } virtual ~Agent() { } @@ -19,51 +27,57 @@ class Agent { virtual void set_memlimit(uint64_t lim) = 0; // in bytes virtual void clear_mem() = 0; - virtual vector get_pv() const = 0; - string move_stats() const { return move_stats(vector()); } - virtual string move_stats(const vector moves) const = 0; + virtual vecmove get_pv() const = 0; + std::string move_stats() const { return move_stats(vecmove()); } + virtual std::string move_stats(const vecmove moves) const = 0; virtual double gamelen() const = 0; virtual void timedout(){ timeout = true; } + virtual void gen_sgf(SGFPrinter & sgf, int limit) const = 0; + virtual void load_sgf(SGFParser & sgf) = 0; + protected: volatile bool timeout; Board rootboard; - static int solve1ply(const Board & board, unsigned int & nodes) { - int outcome = -3; - int turn = board.toplay(); + static Outcome solve1ply(const Board & board, unsigned int & nodes) { + Outcome outcome = Outcome::UNKNOWN; + Side turn = board.toplay(); for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ ++nodes; - int won = board.test_win(*move, turn); + Outcome won = board.test_outcome(*move, turn); - if(won == turn) + if(won == +turn) return won; - if(won == 0) - outcome = 0; + if(won == Outcome::DRAW) + outcome = Outcome::DRAW; } return outcome; } - static int solve2ply(const Board & board, unsigned int & nodes) { + static Outcome solve2ply(const Board & board, unsigned int & nodes) { int losses = 0; - int outcome = -3; - int turn = board.toplay(), opponent = 3 - turn; + Outcome outcome = Outcome::UNKNOWN; + Side turn = board.toplay(); + Side op = ~turn; for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ ++nodes; - int won = board.test_win(*move, turn); + Outcome won = board.test_outcome(*move, turn); - if(won == turn) + if(won == +turn) return won; - if(won == 0) - outcome = 0; + if(won == Outcome::DRAW) + outcome = Outcome::DRAW; - if(board.test_win(*move, opponent) > 0) + if(board.test_outcome(*move, op) == +op) losses++; } if(losses >= 2) - return opponent; + return (Outcome)op; return outcome; } - }; + +}; // namespace Havannah +}; // namespace Morat diff --git a/havannah/agentab.cpp b/havannah/agentab.cpp index 2c66bce..787563e 100644 --- a/havannah/agentab.cpp +++ b/havannah/agentab.cpp @@ -6,6 +6,10 @@ #include "agentab.h" + +namespace Morat { +namespace Havannah { + void AgentAB::search(double time, uint64_t maxiters, int verbose) { reset(); if(rootboard.won() >= 0) @@ -41,8 +45,8 @@ void AgentAB::search(double time, uint64_t maxiters, int verbose) { if(verbose){ logerr("Finished: " + to_str(nodes_seen) + " nodes in " + to_str(time_used*1000, 0) + " msec: " + to_str((uint64_t)((double)nodes_seen/time_used)) + " Nodes/s\n"); - vector pv = get_pv(); - string pvstr; + vecmove pv = get_pv(); + std::string pvstr; for(auto m : pv) pvstr += " " + m.to_s(); logerr("PV: " + pvstr + "\n"); @@ -56,11 +60,11 @@ void AgentAB::search(double time, uint64_t maxiters, int verbose) { int16_t AgentAB::negamax(const Board & board, int16_t alpha, int16_t beta, int depth) { nodes_seen++; - int won = board.won(); - if(won >= 0){ - if(won == 0) + Outcome won = board.won(); + if(won >= Outcome::DRAW){ + if(won == Outcome::DRAW) return SCORE_DRAW; - if(won == board.toplay()) + if(won == +board.toplay()) return SCORE_WIN; return SCORE_LOSS; } @@ -81,8 +85,8 @@ int16_t AgentAB::negamax(const Board & board, int16_t alpha, int16_t beta, int d if(TT && (node = tt_get(board)) && node->depth >= depth){ switch(node->flag){ case VALID: return node->score; - case LBOUND: alpha = max(alpha, node->score); break; - case UBOUND: beta = min(beta, node->score); break; + case LBOUND: alpha = std::max(alpha, node->score); break; + case UBOUND: beta = std::min(beta, node->score); break; default: assert(false && "Unknown flag!"); } if(alpha >= beta) @@ -125,11 +129,11 @@ int16_t AgentAB::negamax(const Board & board, int16_t alpha, int16_t beta, int d return score; } -string AgentAB::move_stats(vector moves) const { - string s = ""; +std::string AgentAB::move_stats(vecmove moves) const { + std::string s = ""; Board b = rootboard; - for(vector::iterator m = moves.begin(); m != moves.end(); ++m) + for(vecmove::iterator m = moves.begin(); m != moves.end(); ++m) b.move(*m); for(MoveIterator move(b); !move.done(); ++move){ @@ -162,8 +166,8 @@ Move AgentAB::return_move(const Board & board, int verbose) const { return best; } -vector AgentAB::get_pv() const { - vector pv; +std::vector AgentAB::get_pv() const { + vecmove pv; Board b = rootboard; int i = 20; @@ -197,3 +201,6 @@ AgentAB::Node * AgentAB::tt_get(uint64_t h) const { void AgentAB::tt_set(const Node & n) { *(tt(n.hash)) = n; } + +}; // namespace Havannah +}; // namespace Morat diff --git a/havannah/agentab.h b/havannah/agentab.h index 646043f..c7411aa 100644 --- a/havannah/agentab.h +++ b/havannah/agentab.h @@ -7,6 +7,10 @@ #include "agent.h" + +namespace Morat { +namespace Havannah { + class AgentAB : public Agent { static const int16_t SCORE_WIN = 32767; static const int16_t SCORE_LOSS = -32767; @@ -30,7 +34,7 @@ class AgentAB : public Agent { Node(uint64_t h = ~0ull, int16_t s = 0, Move b = M_UNKNOWN, int8_t d = 0, int8_t f = 0) : //. int8_t o = -3 hash(h), score(s), bestmove(b), depth(d), flag(f), padding(0xDEAD) { } //, outcome(o) - string to_s() const { + std::string to_s() const { return "score " + to_str(score) + ", depth " + to_str((int)depth) + ", flag " + to_str((int)flag) + @@ -93,8 +97,16 @@ class AgentAB : public Agent { void search(double time, uint64_t maxiters, int verbose); Move return_move(int verbose) const { return return_move(rootboard, verbose); } double gamelen() const { return rootboard.movesremain(); } - vector get_pv() const; - string move_stats(vector moves) const; + vecmove get_pv() const; + std::string move_stats(vecmove moves) const; + + void gen_sgf(SGFPrinter & sgf, int limit) const { + log("gen_sgf not supported in the ab agent."); + } + + void load_sgf(SGFParser & sgf) { + log("load_sgf not supported in the ab agent."); + } private: int16_t negamax(const Board & board, int16_t alpha, int16_t beta, int depth); @@ -105,3 +117,6 @@ class AgentAB : public Agent { Node * tt_get(const Board & b) const ; void tt_set(const Node & n) ; }; + +}; // namespace Havannah +}; // namespace Morat diff --git a/havannah/agentmcts.cpp b/havannah/agentmcts.cpp index 3cd5ae0..e5d4aa6 100644 --- a/havannah/agentmcts.cpp +++ b/havannah/agentmcts.cpp @@ -10,12 +10,45 @@ #include "agentmcts.h" #include "board.h" + +namespace Morat { +namespace Havannah { + const float AgentMCTS::min_rave = 0.1; +std::string AgentMCTS::Node::to_s() const { + return "AgentMCTS::Node" + ", move " + move.to_s() + + ", exp " + exp.to_s() + + ", rave " + rave.to_s() + + ", know " + to_str(know) + + ", outcome " + to_str((int)outcome.to_i()) + + ", depth " + to_str((int)proofdepth) + + ", best " + bestmove.to_s() + + ", children " + to_str(children.num()); +} + +bool AgentMCTS::Node::from_s(std::string s) { + auto dict = parse_dict(s, ", ", " "); + + if(dict.size() == 9){ + move = Move(dict["move"]); + exp = ExpPair(dict["exp"]); + rave = ExpPair(dict["rave"]); + know = from_str(dict["know"]); + outcome = Outcome(from_str(dict["outcome"])); + proofdepth = from_str(dict["depth"]); + bestmove = Move(dict["best"]); + // ignore children + return true; + } + return false; +} + void AgentMCTS::search(double time, uint64_t max_runs, int verbose){ - int toplay = rootboard.toplay(); + Side toplay = rootboard.toplay(); - if(rootboard.won() >= 0 || (time <= 0 && max_runs == 0)) + if(rootboard.won() >= Outcome::DRAW || (time <= 0 && max_runs == 0)) return; Time starttime; @@ -78,30 +111,23 @@ void AgentMCTS::search(double time, uint64_t max_runs, int verbose){ } } - if(root.outcome != -3){ - logerr("Solved as a "); - if( root.outcome == 0) logerr("draw\n"); - else if(root.outcome == 3) logerr("draw by simultaneous win\n"); - else if(root.outcome == toplay) logerr("win\n"); - else if(root.outcome == 3-toplay) logerr("loss\n"); - else if(root.outcome == -toplay) logerr("win or draw\n"); - else if(root.outcome == toplay-3) logerr("loss or draw\n"); - } + if(root.outcome != Outcome::UNKNOWN) + logerr("Solved as a " + root.outcome.to_s_rel(toplay) + "\n"); - string pvstr; + std::string pvstr; for(auto m : get_pv()) pvstr += " " + m.to_s(); logerr("PV: " + pvstr + "\n"); if(verbose >= 3 && !root.children.empty()) - logerr("Move stats:\n" + move_stats(vector())); + logerr("Move stats:\n" + move_stats(vecmove())); } pool.reset(); runs = 0; - if(ponder && root.outcome < 0) + if(ponder && root.outcome < Outcome::DRAW) pool.resume(); } @@ -219,8 +245,8 @@ void AgentMCTS::move(const Move & m){ rootboard.move(m); root.exp.addwins(visitexpand+1); //+1 to compensate for the virtual loss - if(rootboard.won() < 0) - root.outcome = -3; + if(rootboard.won() < Outcome::DRAW) + root.outcome = Outcome::UNKNOWN; if(ponder) pool.resume(); @@ -233,16 +259,16 @@ double AgentMCTS::gamelen() const { return len.avg(); } -vector AgentMCTS::get_pv() const { - vector pv; +std::vector AgentMCTS::get_pv() const { + vecmove pv; const Node * n = & root; - char turn = rootboard.toplay(); + Side turn = rootboard.toplay(); while(n && !n->children.empty()){ Move m = return_move(n, turn); pv.push_back(m); n = find_child(n, m); - turn = 3 - turn; + turn = ~turn; } if(pv.size() == 0) @@ -251,8 +277,8 @@ vector AgentMCTS::get_pv() const { return pv; } -string AgentMCTS::move_stats(vector moves) const { - string s = ""; +std::string AgentMCTS::move_stats(vecmove moves) const { + std::string s = ""; const Node * node = & root; if(moves.size()){ @@ -273,8 +299,8 @@ string AgentMCTS::move_stats(vector moves) const { return s; } -Move AgentMCTS::return_move(const Node * node, int toplay, int verbose) const { - if(node->outcome >= 0) +Move AgentMCTS::return_move(const Node * node, Side toplay, int verbose) const { + if(node->outcome >= Outcome::DRAW) return node->bestmove; double val, maxval = -1000000000000.0; //1 trillion @@ -284,10 +310,10 @@ Move AgentMCTS::return_move(const Node * node, int toplay, int verbose) const { * end = node->children.end(); for( ; child != end; child++){ - if(child->outcome >= 0){ - if(child->outcome == toplay) val = 800000000000.0 - child->exp.num(); //shortest win - else if(child->outcome == 0) val = -400000000000.0 + child->exp.num(); //longest tie - else val = -800000000000.0 + child->exp.num(); //longest loss + if(child->outcome >= Outcome::DRAW){ + if(child->outcome == toplay) val = 800000000000.0 - child->exp.num(); //shortest win + else if(child->outcome == Outcome::DRAW) val = -400000000000.0 + child->exp.num(); //longest tie + else val = -800000000000.0 + child->exp.num(); //longest loss }else{ //not proven if(msrave == -1) //num simulations val = child->exp.num(); @@ -315,13 +341,13 @@ void AgentMCTS::garbage_collect(Board & board, Node * node){ Node * child = node->children.begin(), * end = node->children.end(); - int toplay = board.toplay(); + Side toplay = board.toplay(); for( ; child != end; child++){ if(child->children.num() == 0) continue; - if( (node->outcome >= 0 && child->exp.num() > gcsolved && (node->outcome != toplay || child->outcome == toplay || child->outcome == 0)) || //parent is solved, only keep the proof tree, plus heavy draws - (node->outcome < 0 && child->exp.num() > (child->outcome >= 0 ? gcsolved : gclimit)) ){ // only keep heavy nodes, with different cutoffs for solved and unsolved + if( (node->outcome >= Outcome::DRAW && child->exp.num() > gcsolved && (node->outcome != toplay || child->outcome == toplay || child->outcome == Outcome::DRAW)) || //parent is solved, only keep the proof tree, plus heavy draws + (node->outcome < Outcome::DRAW && child->exp.num() > (child->outcome >= Outcome::DRAW ? gcsolved : gclimit)) ){ // only keep heavy nodes, with different cutoffs for solved and unsolved board.set(child->move); garbage_collect(board, child); board.unset(child->move); @@ -339,29 +365,16 @@ AgentMCTS::Node * AgentMCTS::find_child(const Node * node, const Move & move) co return NULL; } -void AgentMCTS::gen_hgf(Board & board, Node * node, unsigned int limit, unsigned int depth, FILE * fd){ - string s = string("\n") + string(depth, ' ') + "(;" + (board.toplay() == 2 ? "W" : "B") + "[" + node->move.to_s() + "]" + - "C[mcts, sims:" + to_str(node->exp.num()) + ", avg:" + to_str(node->exp.avg(), 4) + ", outcome:" + to_str((int)(node->outcome)) + ", best:" + node->bestmove.to_s() + "]"; - fprintf(fd, "%s", s.c_str()); - - Node * child = node->children.begin(), - * end = node->children.end(); - - int toplay = board.toplay(); - - bool children = false; - for( ; child != end; child++){ - if(child->exp.num() >= limit && (toplay != node->outcome || child->outcome == node->outcome) ){ - board.set(child->move); - gen_hgf(board, child, limit, depth+1, fd); - board.unset(child->move); - children = true; +void AgentMCTS::gen_sgf(SGFPrinter & sgf, unsigned int limit, const Node & node, Side side) const { + for(auto & child : node.children){ + if(child.exp.num() >= limit && (side != node.outcome || child.outcome == node.outcome)){ + sgf.child_start(); + sgf.move(side, child.move); + sgf.comment(child.to_s()); + gen_sgf(sgf, limit, child, ~side); + sgf.child_end(); } } - - if(children) - fprintf(fd, "\n%s", string(depth, ' ').c_str()); - fprintf(fd, ")"); } void AgentMCTS::create_children_simple(const Board & board, Node * node){ @@ -386,64 +399,25 @@ void AgentMCTS::create_children_simple(const Board & board, Node * node){ PLUS(nodes, node->children.num()); } -//reads the format from gen_hgf. -void AgentMCTS::load_hgf(Board board, Node * node, FILE * fd){ - char c, buf[101]; - - eat_whitespace(fd); - - assert(fscanf(fd, "(;%c[%100[^]]]", &c, buf) > 0); +void AgentMCTS::load_sgf(SGFParser & sgf, const Board & board, Node & node) { + assert(sgf.has_children()); + create_children_simple(board, & node); - assert(board.toplay() == (c == 'W' ? 1 : 2)); - node->move = Move(buf); - board.move(node->move); - - assert(fscanf(fd, "C[%100[^]]]", buf) > 0); - - vecstr entry, parts = explode(string(buf), ", "); - assert(parts[0] == "mcts"); - - entry = explode(parts[1], ":"); - assert(entry[0] == "sims"); - uword sims = from_str(entry[1]); - - entry = explode(parts[2], ":"); - assert(entry[0] == "avg"); - double avg = from_str(entry[1]); - - uword wins = sims*avg; - node->exp.addwins(wins); - node->exp.addlosses(sims - wins); - - entry = explode(parts[3], ":"); - assert(entry[0] == "outcome"); - node->outcome = from_str(entry[1]); - - entry = explode(parts[4], ":"); - assert(entry[0] == "best"); - node->bestmove = Move(entry[1]); - - - eat_whitespace(fd); - - if(fpeek(fd) != ')'){ - create_children_simple(board, node); - - while(fpeek(fd) != ')'){ - Node child; - load_hgf(board, & child, fd); - - Node * i = find_child(node, child.move); - *i = child; //copy the child experience to the tree - i->swap_tree(child); //move the child subtree to the tree - - assert(child.children.empty()); - - eat_whitespace(fd); + while(sgf.next_child()){ + Move m = sgf.move(); + Node & child = *find_child(&node, m); + child.from_s(sgf.comment()); + if(sgf.done_child()){ + continue; + }else{ + // has children! + Board b = board; + b.move(m); + load_sgf(sgf, b, child); + assert(sgf.done_child()); } } - - eat_char(fd, ')'); - - return; } + +}; // namespace Havannah +}; // namespace Morat diff --git a/havannah/agentmcts.h b/havannah/agentmcts.h index 5751ad8..21b5538 100644 --- a/havannah/agentmcts.h +++ b/havannah/agentmcts.h @@ -11,6 +11,12 @@ #include "../lib/depthstats.h" #include "../lib/exppair.h" #include "../lib/log.h" +#include "../lib/move.h" +#include "../lib/movelist.h" +#include "../lib/policy_bridge.h" +#include "../lib/policy_instantwin.h" +#include "../lib/policy_lastgoodreply.h" +#include "../lib/policy_random.h" #include "../lib/thread.h" #include "../lib/time.h" #include "../lib/types.h" @@ -19,14 +25,11 @@ #include "agent.h" #include "board.h" #include "lbdist.h" -#include "move.h" -#include "movelist.h" -#include "policy_bridge.h" -#include "policy_instantwin.h" -#include "policy_lastgoodreply.h" -#include "policy_random.h" +namespace Morat { +namespace Havannah { + class AgentMCTS : public Agent{ public: @@ -35,7 +38,7 @@ class AgentMCTS : public Agent{ ExpPair rave; ExpPair exp; int16_t know; - int8_t outcome; + Outcome outcome; uint8_t proofdepth; Move move; Move bestmove; //if outcome is set, then bestmove is the way to get there @@ -44,8 +47,8 @@ class AgentMCTS : public Agent{ //seems to need padding to multiples of 8 bytes or it segfaults? //don't forget to update the copy constructor/operator - Node() : know(0), outcome(-3), proofdepth(0) { } - Node(const Move & m, char o = -3) : know(0), outcome( o), proofdepth(0), move(m) { } + Node() : know(0), outcome(Outcome::UNKNOWN), proofdepth(0), move(M_NONE) { } + Node(const Move & m, Outcome o = Outcome::UNKNOWN) : know(0), outcome(o), proofdepth(0), move(m) { } Node(const Node & n) { *this = n; } Node & operator = (const Node & n){ if(this != & n){ //don't copy to self @@ -68,18 +71,8 @@ class AgentMCTS : public Agent{ children.swap(n.children); } - void print() const { - printf("%s\n", to_s().c_str()); - } - string to_s() const { - return "Node: move " + move.to_s() + - ", exp " + to_str(exp.avg(), 2) + "/" + to_str(exp.num()) + - ", rave " + to_str(rave.avg(), 2) + "/" + to_str(rave.num()) + - ", know " + to_str(know) + - ", outcome " + to_str((int)outcome) + "/" + to_str((int)proofdepth) + - ", best " + bestmove.to_s() + - ", children " + to_str(children.num()); - } + std::string to_s() const ; + bool from_s(std::string s); unsigned int size() const { unsigned int num = children.num(); @@ -142,16 +135,16 @@ class AgentMCTS : public Agent{ class AgentThread : public AgentThreadBase { mutable XORShift_float unitrand; - LastGoodReply last_good_reply; - RandomPolicy random_policy; - ProtectBridge protect_bridge; - InstantWin instant_wins; + LastGoodReply last_good_reply; + RandomPolicy random_policy; + ProtectBridge protect_bridge; + InstantWin instant_wins; bool use_rave; //whether to use rave for this simulation bool use_explore; //whether to use exploration for this simulation LBDists dists; //holds the distances to the various non-ring wins as a heuristic for the minimum moves needed to win - MoveList movelist; + MoveList movelist; int stage; //which of the four MCTS stages is it on public: @@ -184,11 +177,11 @@ class AgentMCTS : public Agent{ void walk_tree(Board & board, Node * node, int depth); bool create_children(const Board & board, Node * node); void add_knowledge(const Board & board, Node * node, Node * child); - Node * choose_move(const Node * node, int toplay, int remain) const; - void update_rave(const Node * node, int toplay); + Node * choose_move(const Node * node, Side toplay, int remain) const; + void update_rave(const Node * node, Side toplay); bool test_bridge_probe(const Board & board, const Move & move, const Move & test) const; - int rollout(Board & board, Move move, int depth); + Outcome rollout(Board & board, Move move, int depth); Move rollout_choose_move(Board & board, const Move & prev); Move rollout_pattern(const Board & board, const Move & move); }; @@ -269,12 +262,12 @@ class AgentMCTS : public Agent{ Move return_move(int verbose) const { return return_move(& root, rootboard.toplay(), verbose); } double gamelen() const; - vector get_pv() const; - string move_stats(const vector moves) const; + vecmove get_pv() const; + std::string move_stats(const vecmove moves) const; bool done() { //solved or finished runs - return (rootboard.won() >= 0 || root.outcome >= 0 || (maxruns > 0 && runs >= maxruns)); + return (rootboard.won() >= Outcome::DRAW || root.outcome >= Outcome::DRAW || (maxruns > 0 && runs >= maxruns)); } bool need_gc() { @@ -300,16 +293,28 @@ class AgentMCTS : public Agent{ gclimit = (int)(gclimit*0.9); //slowly decay to a minimum of 5 } + void gen_sgf(SGFPrinter & sgf, int limit) const { + if(limit < 0) + limit = root.exp.num()/1000; + gen_sgf(sgf, limit, root, rootboard.toplay()); + } + + void load_sgf(SGFParser & sgf) { + load_sgf(sgf, rootboard, root); + } protected: void garbage_collect(Board & board, Node * node); //destroys the board, so pass in a copy - bool do_backup(Node * node, Node * backup, int toplay); - Move return_move(const Node * node, int toplay, int verbose = 0) const; + bool do_backup(Node * node, Node * backup, Side toplay); + Move return_move(const Node * node, Side toplay, int verbose = 0) const; Node * find_child(const Node * node, const Move & move) const ; void create_children_simple(const Board & board, Node * node); - void gen_hgf(Board & board, Node * node, unsigned int limit, unsigned int depth, FILE * fd); - void load_hgf(Board board, Node * node, FILE * fd); + void gen_sgf(SGFPrinter & sgf, unsigned int limit, const Node & node, Side side) const ; + void load_sgf(SGFParser & sgf, const Board & board, Node & node); }; + +}; // namespace Havannah +}; // namespace Morat diff --git a/havannah/agentmcts_test.cpp b/havannah/agentmcts_test.cpp new file mode 100644 index 0000000..a634369 --- /dev/null +++ b/havannah/agentmcts_test.cpp @@ -0,0 +1,16 @@ + +#include "../lib/catch.hpp" + +#include "agentmcts.h" + + +using namespace Morat; +using namespace Havannah; + +TEST_CASE("Havannah::AgentMCTS::Node::to_s/from_s", "[havannah][agentmcts]") { + AgentMCTS::Node n(Move("a1")); + auto s = n.to_s(); + AgentMCTS::Node k; + REQUIRE(k.from_s(s)); + REQUIRE(n.to_s() == k.to_s()); +} diff --git a/havannah/agentmctsthread.cpp b/havannah/agentmctsthread.cpp index 71c1940..5611c6c 100644 --- a/havannah/agentmctsthread.cpp +++ b/havannah/agentmctsthread.cpp @@ -6,6 +6,10 @@ #include "agentmcts.h" + +namespace Morat { +namespace Havannah { + void AgentMCTS::AgentThread::iterate(){ INCR(agent->runs); if(agent->profile){ @@ -19,7 +23,7 @@ void AgentMCTS::AgentThread::iterate(){ use_rave = (unitrand() < agent->userave); use_explore = (unitrand() < agent->useexplore); walk_tree(copy, & agent->root, 0); - agent->root.exp.addv(movelist.getexp(3-agent->rootboard.toplay())); + agent->root.exp.addv(movelist.getexp(~agent->rootboard.toplay())); if(agent->profile){ times[0] += timestamps[1] - timestamps[0]; @@ -30,16 +34,16 @@ void AgentMCTS::AgentThread::iterate(){ } void AgentMCTS::AgentThread::walk_tree(Board & board, Node * node, int depth){ - int toplay = board.toplay(); + Side toplay = board.toplay(); - if(!node->children.empty() && node->outcome < 0){ + if(!node->children.empty() && node->outcome < Outcome::DRAW){ //choose a child and recurse Node * child; do{ int remain = board.movesremain(); child = choose_move(node, toplay, remain); - if(child->outcome < 0){ + if(child->outcome < Outcome::DRAW){ movelist.addtree(child->move, toplay); if(!board.move(child->move)){ @@ -71,10 +75,10 @@ void AgentMCTS::AgentThread::walk_tree(Board & board, Node * node, int depth){ timestamps[1] = Time(); } - int won = (agent->minimax ? node->outcome : board.won()); + Outcome won = (agent->minimax ? node->outcome : board.won()); //if it's not already decided - if(won < 0){ + if(won < Outcome::DRAW){ //create children if valid if(node->exp.num() >= agent->visitexpand+1 && create_children(board, node)){ walk_tree(board, node, depth); @@ -119,13 +123,13 @@ bool AgentMCTS::AgentThread::create_children(const Board & board, Node * node){ return false; if(agent->dists || agent->detectdraw){ - dists.run(&board, (agent->dists > 0), (agent->detectdraw ? 0 : board.toplay())); + dists.run(&board, (agent->dists > 0), (agent->detectdraw ? Side::NONE : board.toplay())); if(agent->detectdraw){ -// assert(node->outcome == -3); +// assert(node->outcome < Outcome::DRAW); node->outcome = dists.isdraw(); //could be winnable by only one side - if(node->outcome == 0){ //proven draw, neither side can influence the outcome + if(node->outcome == Outcome::DRAW){ //proven draw, neither side can influence the outcome node->bestmove = *(board.moveit()); //just choose the first move since all are equal at this point node->children.unlock(); return true; @@ -136,6 +140,8 @@ bool AgentMCTS::AgentThread::create_children(const Board & board, Node * node){ CompactTree::Children temp; temp.alloc(board.movesremain(), agent->ctmem); + Side toplay = board.toplay(); + Side opponent = ~toplay; int losses = 0; Node * child = temp.begin(), @@ -147,14 +153,14 @@ bool AgentMCTS::AgentThread::create_children(const Board & board, Node * node){ *child = Node(*move); if(agent->minimax){ - child->outcome = board.test_win(*move); + child->outcome = board.test_outcome(*move); - if(agent->minimax >= 2 && board.test_win(*move, 3 - board.toplay()) > 0){ + if(agent->minimax >= 2 && board.test_outcome(*move, opponent) == +opponent){ losses++; loss = child; } - if(child->outcome == board.toplay()){ //proven win from here, don't need children + if(child->outcome == +toplay){ //proven win from here, don't need children node->outcome = child->outcome; node->proofdepth = 1; node->bestmove = *move; @@ -182,7 +188,7 @@ bool AgentMCTS::AgentThread::create_children(const Board & board, Node * node){ macro.exp.addwins(agent->visitexpand); *(temp.begin()) = macro; }else if(losses >= 2){ //proven loss, but at least try to block one of them - node->outcome = 3 - board.toplay(); + node->outcome = +opponent; node->proofdepth = 2; node->bestmove = loss->move; node->children.unlock(); @@ -191,7 +197,7 @@ bool AgentMCTS::AgentThread::create_children(const Board & board, Node * node){ } if(agent->dynwiden > 0) //sort in decreasing order by knowledge - sort(temp.begin(), temp.end(), sort_node_know); + std::sort(temp.begin(), temp.end(), sort_node_know); PLUS(agent->nodes, temp.num()); node->children.swap(temp); @@ -200,7 +206,7 @@ bool AgentMCTS::AgentThread::create_children(const Board & board, Node * node){ return true; } -AgentMCTS::Node * AgentMCTS::AgentThread::choose_move(const Node * node, int toplay, int remain) const { +AgentMCTS::Node * AgentMCTS::AgentThread::choose_move(const Node * node, Side toplay, int remain) const { float val, maxval = -1000000000; float logvisits = log(node->exp.num()); int dynwidenlim = (agent->dynwiden > 0 ? (int)(logvisits/agent->logdynwiden)+2 : Board::max_vecsize); @@ -215,11 +221,11 @@ AgentMCTS::Node * AgentMCTS::AgentThread::choose_move(const Node * node, int top * end = node->children.end(); for(; child != end && dynwidenlim >= 0; child++){ - if(child->outcome >= 0){ + if(child->outcome >= Outcome::DRAW){ if(child->outcome == toplay) //return a win immediately return child; - val = (child->outcome == 0 ? -1 : -2); //-1 for tie so any unknown is better, -2 for loss so it's even worse + val = (child->outcome == Outcome::DRAW ? -1 : -2); //-1 for tie so any unknown is better, -2 for loss so it's even worse }else{ val = child->value(raveval, agent->knowledge, agent->fpurgency); if(explore > 0) @@ -248,80 +254,80 @@ backup in this order: 0 lose return true if fully solved, false if it's unknown or partially unknown */ -bool AgentMCTS::do_backup(Node * node, Node * backup, int toplay){ - int nodeoutcome = node->outcome; - if(nodeoutcome >= 0) //already proven, probably by a different thread +bool AgentMCTS::do_backup(Node * node, Node * backup, Side toplay){ + Outcome node_outcome = node->outcome; + if(node_outcome >= Outcome::DRAW) //already proven, probably by a different thread return true; - if(backup->outcome == -3) //nothing proven by this child, so no chance + if(backup->outcome == Outcome::UNKNOWN) //nothing proven by this child, so no chance return false; uint8_t proofdepth = backup->proofdepth; if(backup->outcome != toplay){ - uint64_t sims = 0, bestsims = 0, outcome = 0, bestoutcome = 0; + uint64_t sims = 0, bestsims = 0, outcome = 0, best_outcome = 0; backup = NULL; Node * child = node->children.begin(), * end = node->children.end(); for( ; child != end; child++){ - int childoutcome = child->outcome; //save a copy to avoid race conditions + Outcome child_outcome = child->outcome; //save a copy to avoid race conditions if(proofdepth < child->proofdepth+1) proofdepth = child->proofdepth+1; //these should be sorted in likelyness of matching, most likely first - if(childoutcome == -3){ // win/draw/loss + if(child_outcome == Outcome::UNKNOWN){ // win/draw/loss outcome = 3; - }else if(childoutcome == toplay){ //win + }else if(child_outcome == toplay){ //win backup = child; outcome = 6; proofdepth = child->proofdepth+1; break; - }else if(childoutcome == 3-toplay){ //loss + }else if(child_outcome == ~toplay){ //loss outcome = 0; - }else if(childoutcome == 0){ //draw - if(nodeoutcome == toplay-3) //draw/loss + }else if(child_outcome == Outcome::DRAW){ //draw + if(node_outcome == -toplay) //draw/loss, ie I can't win outcome = 4; else outcome = 2; - }else if(childoutcome == -toplay){ //win/draw + }else if(child_outcome == -~toplay){ //win/draw, ie opponent can't win outcome = 5; - }else if(childoutcome == toplay-3){ //draw/loss + }else if(child_outcome == -toplay){ //draw/loss, ie I can't win outcome = 1; }else{ - logerr("childoutcome == " + to_str(childoutcome) + "\n"); + logerr("child_outcome == " + child_outcome.to_s() + "\n"); assert(false && "How'd I get here? All outcomes should be tested above"); } sims = child->exp.num(); - if(bestoutcome < outcome){ //better outcome is always preferable - bestoutcome = outcome; + if(best_outcome < outcome){ //better outcome is always preferable + best_outcome = outcome; bestsims = sims; backup = child; - }else if(bestoutcome == outcome && ((outcome == 0 && bestsims < sims) || bestsims > sims)){ + }else if(best_outcome == outcome && ((outcome == 0 && bestsims < sims) || bestsims > sims)){ //find long losses or easy wins/draws bestsims = sims; backup = child; } } - if(bestoutcome == 3) //no win, but found an unknown + if(best_outcome == 3) //no win, but found an unknown return false; } - if(CAS(node->outcome, nodeoutcome, backup->outcome)){ + if(node->outcome.cas(node_outcome, backup->outcome)){ node->bestmove = backup->move; node->proofdepth = proofdepth; }else //if it was in a race, try again, might promote a partial solve to full solve return do_backup(node, backup, toplay); - return (node->outcome >= 0); + return (node->outcome >= Outcome::DRAW); } //update the rave score of all children that were played -void AgentMCTS::AgentThread::update_rave(const Node * node, int toplay){ +void AgentMCTS::AgentThread::update_rave(const Node * node, Side toplay){ Node * child = node->children.begin(), * childend = node->children.end(); @@ -332,7 +338,7 @@ void AgentMCTS::AgentThread::update_rave(const Node * node, int toplay){ void AgentMCTS::AgentThread::add_knowledge(const Board & board, Node * node, Node * child){ if(agent->localreply){ //boost for moves near the previous move - int dist = node->move.dist(child->move); + int dist = board.dist(node->move, child->move); if(dist < 4) child->know += agent->localreply * (4 - dist); } @@ -354,24 +360,24 @@ void AgentMCTS::AgentThread::add_knowledge(const Board & board, Node * node, Nod child->know += agent->bridge; if(agent->dists) - child->know += abs(agent->dists) * max(0, board.get_size_d() - dists.get(child->move, board.toplay())); + child->know += abs(agent->dists) * std::max(0, board.get_size_d() - dists.get(child->move, board.toplay())); } //test whether this move is a forced reply to the opponent probing your virtual connections bool AgentMCTS::AgentThread::test_bridge_probe(const Board & board, const Move & move, const Move & test) const { //TODO: switch to the same method as policy_bridge.h, maybe even share code - if(move.dist(test) != 1) + if(board.dist(move, test) != 1) return false; bool equals = false; int state = 0; - int piece = 3 - board.get(move); + Side piece = ~board.get(move); for(int i = 0; i < 8; i++){ Move cur = move + neighbours[i % 6]; bool on = board.onboard(cur); - int v = 0; + Side v = Side::NONE; if(on) v = board.get(cur); @@ -382,7 +388,7 @@ bool AgentMCTS::AgentThread::test_bridge_probe(const Board & board, const Move & //else state = 0; }else if(state == 1){ if(on){ - if(v == 0){ + if(v == Side::NONE){ state = 2; equals = (test == cur); }else if(v != piece) @@ -407,8 +413,8 @@ bool AgentMCTS::AgentThread::test_bridge_probe(const Board & board, const Move & //play a random game starting from a board state, and return the results of who won -int AgentMCTS::AgentThread::rollout(Board & board, Move move, int depth){ - int won; +Outcome AgentMCTS::AgentThread::rollout(Board & board, Move move, int depth){ + Outcome won; if(agent->instantwin) instant_wins.rollout_start(board, agent->instantwin); @@ -423,8 +429,8 @@ int AgentMCTS::AgentThread::rollout(Board & board, Move move, int depth){ board.perm_rings = agent->ringperm; - while((won = board.won()) < 0){ - int turn = board.toplay(); + while((won = board.won()) < Outcome::DRAW){ + Side turn = board.toplay(); board.check_rings = (depth < checkdepth); @@ -438,8 +444,8 @@ int AgentMCTS::AgentThread::rollout(Board & board, Move move, int depth){ gamelen.add(depth); - if(won > 0) - wintypes[won-1][(int)board.getwintype()].add(depth); + if(won > Outcome::DRAW) + wintypes[won.to_i() - 1][(int)board.getwintype()].add(depth); //update the last good reply table if(agent->lastgoodreply) @@ -473,3 +479,6 @@ Move AgentMCTS::AgentThread::rollout_choose_move(Board & board, const Move & pre return random_policy.choose_move(board, prev); } + +}; // namespace Havannah +}; // namespace Morat diff --git a/havannah/agentpns.cpp b/havannah/agentpns.cpp index ec270ff..973e80e 100644 --- a/havannah/agentpns.cpp +++ b/havannah/agentpns.cpp @@ -5,6 +5,40 @@ #include "agentpns.h" + +namespace Morat { +namespace Havannah { + +std::string AgentPNS::Node::to_s() const { + return "AgentPNS::Node" + ", move " + move.to_s() + + ", phi " + to_str(phi) + + ", delta " + to_str(delta) + + ", work " + to_str(work) + + ", children " + to_str(children.num()); +} + +bool AgentPNS::Node::from_s(std::string s) { + auto dict = parse_dict(s, ", ", " "); + + if(dict.size() == 6){ + move = Move(dict["move"]); + phi = from_str(dict["phi"]); + delta = from_str(dict["delta"]); + work = from_str(dict["work"]); + // ignore children + return true; + } + return false; +} + +void AgentPNS::test() { + Node n(Move("a1")); + auto s = n.to_s(); + Node k; + assert(k.from_s(s)); +} + void AgentPNS::search(double time, uint64_t maxiters, int verbose){ max_nodes_seen = maxiters; @@ -32,27 +66,20 @@ void AgentPNS::search(double time, uint64_t maxiters, int verbose){ logerr("Tree depth: " + treelen.to_s() + "\n"); } - int toplay = rootboard.toplay(); + Side toplay = rootboard.toplay(); logerr("Root: " + root.to_s() + "\n"); - int outcome = root.to_outcome(3-toplay); - if(outcome != -3){ - logerr("Solved as a "); - if( outcome == 0) logerr("draw\n"); - else if(outcome == 3) logerr("draw by simultaneous win\n"); - else if(outcome == toplay) logerr("win\n"); - else if(outcome == 3-toplay) logerr("loss\n"); - else if(outcome == -toplay) logerr("win or draw\n"); - else if(outcome == toplay-3) logerr("loss or draw\n"); - } + Outcome outcome = root.to_outcome(~toplay); + if(outcome != Outcome::UNKNOWN) + logerr("Solved as a " + outcome.to_s_rel(toplay) + "\n"); - string pvstr; + std::string pvstr; for(auto m : get_pv()) pvstr += " " + m.to_s(); logerr("PV: " + pvstr + "\n"); if(verbose >= 3 && !root.children.empty()) - logerr("Move stats:\n" + move_stats(vector())); + logerr("Move stats:\n" + move_stats(vecmove())); } } @@ -83,8 +110,8 @@ bool AgentPNS::AgentThread::pns(const Board & board, Node * node, int depth, uin unsigned int i = 0; for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - unsigned int pd = 1; - int outcome; + unsigned int pd; + Outcome outcome; if(agent->ab){ Board next = board; @@ -94,10 +121,10 @@ bool AgentPNS::AgentThread::pns(const Board & board, Node * node, int depth, uin outcome = (agent->ab == 1 ? solve1ply(next, pd) : solve2ply(next, pd)); }else{ pd = 1; - outcome = board.test_win(*move); + outcome = board.test_outcome(*move); } - if(agent->lbdist && outcome < 0) + if(agent->lbdist && outcome != Outcome::UNKNOWN) pd = dists.get(*move); temp[i] = Node(*move).outcome(outcome, board.toplay(), agent->ties, pd); @@ -132,8 +159,8 @@ bool AgentPNS::AgentThread::pns(const Board & board, Node * node, int depth, uin } } - tpc = min(INF32/2, (td + child->phi - node->delta)); - tdc = min(tp, (uint32_t)(child2->delta*(1.0 + agent->epsilon) + 1)); + tpc = std::min(INF32/2, (td + child->phi - node->delta)); + tdc = std::min(tp, (uint32_t)(child2->delta*(1.0 + agent->epsilon) + 1)); }else{ tpc = tdc = 0; for(auto & i : node->children) @@ -198,16 +225,16 @@ double AgentPNS::gamelen() const { return rootboard.movesremain(); } -vector AgentPNS::get_pv() const { - vector pv; +std::vector AgentPNS::get_pv() const { + vecmove pv; const Node * n = & root; - char turn = rootboard.toplay(); + Side turn = rootboard.toplay(); while(n && !n->children.empty()){ Move m = return_move(n, turn); pv.push_back(m); n = find_child(n, m); - turn = 3 - turn; + turn = ~turn; } if(pv.size() == 0) @@ -216,8 +243,8 @@ vector AgentPNS::get_pv() const { return pv; } -string AgentPNS::move_stats(vector moves) const { - string s = ""; +std::string AgentPNS::move_stats(vecmove moves) const { + std::string s = ""; const Node * node = & root; if(moves.size()){ @@ -238,7 +265,7 @@ string AgentPNS::move_stats(vector moves) const { return s; } -Move AgentPNS::return_move(const Node * node, int toplay, int verbose) const { +Move AgentPNS::return_move(const Node * node, Side toplay, int verbose) const { double val, maxval = -1000000000000.0; //1 trillion Node * ret = NULL, @@ -246,11 +273,11 @@ Move AgentPNS::return_move(const Node * node, int toplay, int verbose) const { * end = node->children.end(); for( ; child != end; child++){ - int outcome = child->to_outcome(toplay); - if(outcome >= 0){ - if(outcome == toplay) val = 800000000000.0 - (double)child->work; //shortest win - else if(outcome == 0) val = -400000000000.0 + (double)child->work; //longest tie - else val = -800000000000.0 + (double)child->work; //longest loss + Outcome outcome = child->to_outcome(toplay); + if(outcome >= Outcome::DRAW){ + if( outcome == +toplay) val = 800000000000.0 - (double)child->work; //shortest win + else if(outcome == Outcome::DRAW) val = -400000000000.0 + (double)child->work; //longest tie + else val = -800000000000.0 + (double)child->work; //longest loss }else{ //not proven val = child->work; } @@ -290,3 +317,51 @@ void AgentPNS::garbage_collect(Node * node){ } } } + +void AgentPNS::create_children_simple(const Board & board, Node * node){ + assert(node->children.empty()); + node->children.alloc(board.movesremain(), ctmem); + unsigned int i = 0; + for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ + Outcome outcome = board.test_outcome(*move); + node->children[i] = Node(*move).outcome(outcome, board.toplay(), ties, 1); + i++; + } + PLUS(nodes, i); + node->children.shrink(i); //if symmetry, there may be extra moves to ignore +} + +void AgentPNS::gen_sgf(SGFPrinter & sgf, unsigned int limit, const Node & node, Side side) const { + for(auto & child : node.children){ + if(child.work >= limit && (side != node.to_outcome(~side) || child.to_outcome(side) == node.to_outcome(~side))){ + sgf.child_start(); + sgf.move(side, child.move); + sgf.comment(child.to_s()); + gen_sgf(sgf, limit, child, ~side); + sgf.child_end(); + } + } +} + +void AgentPNS::load_sgf(SGFParser & sgf, const Board & board, Node & node) { + assert(sgf.has_children()); + create_children_simple(board, &node); + + while(sgf.next_child()){ + Move m = sgf.move(); + Node & child = *find_child(&node, m); + child.from_s(sgf.comment()); + if(sgf.done_child()){ + continue; + }else{ + // has children! + Board b = board; + b.move(m); + load_sgf(sgf, b, child); + assert(sgf.done_child()); + } + } +} + +}; // namespace Havannah +}; // namespace Morat diff --git a/havannah/agentpns.h b/havannah/agentpns.h index ad33042..b1734b7 100644 --- a/havannah/agentpns.h +++ b/havannah/agentpns.h @@ -3,15 +3,21 @@ //A multi-threaded, tree based, proof number search solver. +#include + #include "../lib/agentpool.h" #include "../lib/compacttree.h" #include "../lib/depthstats.h" #include "../lib/log.h" +#include "../lib/string.h" #include "agent.h" #include "lbdist.h" +namespace Morat { +namespace Havannah { + class AgentPNS : public Agent { static const uint32_t LOSS = (1<<30)-1; static const uint32_t DRAW = (1<<30)-2; @@ -51,33 +57,33 @@ class AgentPNS : public Agent { assert(children.empty()); } - Node & abval(int outcome, int toplay, int assign, int value = 1){ - if(assign && (outcome == 1 || outcome == -1)) - outcome = (toplay == assign ? 2 : -2); + Node & abval(int ab_outcome, Side toplay, Side assign, int value = 1){ + if(assign != Side::NONE && (ab_outcome == 1 || ab_outcome == -1)) + ab_outcome = (toplay == assign ? 2 : -2); - if( outcome == 0) { phi = value; delta = value; } - else if(outcome == 2) { phi = LOSS; delta = 0; } - else if(outcome == -2) { phi = 0; delta = LOSS; } - else /*(outcome 1||-1)*/ { phi = 0; delta = DRAW; } + if( ab_outcome == 0) { phi = value; delta = value; } + else if(ab_outcome == 2) { phi = LOSS; delta = 0; } + else if(ab_outcome == -2) { phi = 0; delta = LOSS; } + else /*(ab_outcome 1||-1)*/ { phi = 0; delta = DRAW; } return *this; } - Node & outcome(int outcome, int toplay, int assign, int value = 1){ - if(assign && outcome == 0) - outcome = assign; + Node & outcome(Outcome outcome, Side toplay, Side assign, int value = 1){ + if(assign != Side::NONE && outcome == Outcome::DRAW) + outcome = +assign; - if( outcome == -3) { phi = value; delta = value; } - else if(outcome == toplay) { phi = LOSS; delta = 0; } - else if(outcome == 3-toplay) { phi = 0; delta = LOSS; } - else /*(outcome == 0)*/ { phi = 0; delta = DRAW; } + if( outcome == Outcome::UNKNOWN) { phi = value; delta = value; } + else if(outcome == +toplay) { phi = LOSS; delta = 0; } + else if(outcome == +~toplay) { phi = 0; delta = LOSS; } + else /*(outcome == Outcome::DRAW)*/ { phi = 0; delta = DRAW; } return *this; } - int to_outcome(int toplay) const { - if(phi == LOSS) return toplay; - if(delta == LOSS) return 3 - toplay; - if(delta == DRAW) return 0; - return -3; + Outcome to_outcome(Side toplay) const { + if(phi == LOSS) return +toplay; + if(delta == LOSS) return +~toplay; + if(delta == DRAW) return Outcome::DRAW; + return Outcome::UNKNOWN; } bool terminal(){ return (phi == 0 || delta == 0); } @@ -98,15 +104,8 @@ class AgentPNS : public Agent { return num; } - string to_s() const { - return "Node: move " + move.to_s() + - ", phi " + to_str(phi) + - ", delta " + to_str(delta) + - ", work " + to_str(work) + -// ", outcome " + to_str((int)outcome) + "/" + to_str((int)proofdepth) + -// ", best " + bestmove.to_s() + - ", children " + to_str(children.num()); - } + std::string to_s() const ; + bool from_s(std::string s); void swap_tree(Node & n){ children.swap(n.children); @@ -162,7 +161,7 @@ class AgentPNS : public Agent { int ab; // how deep of an alpha-beta search to run at each leaf node bool df; // go depth first? float epsilon; //if depth first, how wide should the threshold be? - int ties; //which player to assign ties to: 0 handle ties, 1 assign p1, 2 assign p2 + Side ties; //which player to assign ties to: 0 handle ties, 1 assign p1, 2 assign p2 bool lbdist; int numthreads; @@ -172,7 +171,7 @@ class AgentPNS : public Agent { ab = 2; df = true; epsilon = 0.25; - ties = 0; + ties = Side::NONE; lbdist = false; numthreads = 1; pool.set_num_threads(numthreads); @@ -228,7 +227,7 @@ class AgentPNS : public Agent { root.swap_tree(child); if(nodesbefore > 0) - logerr(string("PNS Nodes before: ") + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); + logerr(std::string("PNS Nodes before: ") + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); assert(nodes == root.size()); @@ -280,12 +279,36 @@ class AgentPNS : public Agent { void search(double time, uint64_t maxiters, int verbose); Move return_move(int verbose) const { return return_move(& root, rootboard.toplay(), verbose); } double gamelen() const; - vector get_pv() const; - string move_stats(const vector moves) const; + vecmove get_pv() const; + std::string move_stats(const vecmove moves) const; + + void gen_sgf(SGFPrinter & sgf, int limit) const { + if(limit < 0){ + limit = 0; + //TODO: Set the root.work properly + for(auto & child : root.children) + limit += child.work; + limit /= 1000; + } + gen_sgf(sgf, limit, root, rootboard.toplay()); + } + + void load_sgf(SGFParser & sgf) { + load_sgf(sgf, rootboard, root); + } + + static void test(); private: //remove all the nodes with little work to free up some memory void garbage_collect(Node * node); - Move return_move(const Node * node, int toplay, int verbose = 0) const; + Move return_move(const Node * node, Side toplay, int verbose = 0) const; Node * find_child(const Node * node, const Move & move) const ; + void create_children_simple(const Board & board, Node * node); + + void gen_sgf(SGFPrinter & sgf, unsigned int limit, const Node & node, Side side) const; + void load_sgf(SGFParser & sgf, const Board & board, Node & node); }; + +}; // namespace Havannah +}; // namespace Morat diff --git a/havannah/agentpns_test.cpp b/havannah/agentpns_test.cpp new file mode 100644 index 0000000..d840267 --- /dev/null +++ b/havannah/agentpns_test.cpp @@ -0,0 +1,16 @@ + +#include "../lib/catch.hpp" + +#include "agentpns.h" + + +using namespace Morat; +using namespace Havannah; + +TEST_CASE("Havannah::AgentPNS::Node::to_s/from_s", "[havannah][agentpns]") { + AgentPNS::Node n(Move("a1")); + auto s = n.to_s(); + AgentPNS::Node k; + REQUIRE(k.from_s(s)); + REQUIRE(n.to_s() == k.to_s()); +} diff --git a/havannah/board.cpp b/havannah/board.cpp new file mode 100644 index 0000000..34a5981 --- /dev/null +++ b/havannah/board.cpp @@ -0,0 +1,276 @@ + +#include "board.h" + +namespace Morat { +namespace Havannah { + +std::string Board::Cell::to_s(int i) const { + return "Cell " + to_str(i) +": " + "piece: " + to_str(piece.to_i())+ + ", size: " + to_str((int)size) + + ", parent: " + to_str((int)parent) + + ", corner: " + to_str((int)corner) + "/" + to_str(numcorners()) + + ", edge: " + to_str((int)edge) + "/" + to_str(numedges()) + + ", perm: " + to_str((int)perm) + + ", pattern: " + to_str((int)pattern); +} + +std::string empty(Move m) { return "."; } + +std::string Board::to_s(bool color) const { + return to_s(color, empty); +} +std::string Board::to_s(bool color, std::function func) const { + using std::string; + string white = "O", + black = "@", + coord = "", + reset = ""; + if(color){ + string esc = "\033"; + reset = esc + "[0m"; + coord = esc + "[1;37m"; + white = esc + "[1;33m" + "@"; //yellow + black = esc + "[1;34m" + "@"; //blue + } + + string s; + s += string(size + 3, ' '); + for(int i = 0; i < size; i++) + s += " " + coord + to_str(i+1); + s += "\n"; + + for(int y = 0; y < size_d; y++){ + s += string(abs(sizem1 - y) + 2, ' '); + s += coord + char('A' + y); + int end = lineend(y); + for(int x = linestart(y); x < end; x++){ + s += (last == Move(x, y) ? coord + "[" : + last == Move(x-1, y) ? coord + "]" : " "); + Side p = get(x, y); + if( p == Side::NONE) s += reset + func(Move(x, y)); + else if(p == Side::P1) s += white; + else if(p == Side::P2) s += black; + else s += "?"; + } + s += (last == Move(end-1, y) ? coord + "]" : " "); + if(y < sizem1) + s += coord + to_str(size + y + 1); + s += '\n'; + } + + s += reset; + return s; +} + + +int Board::iscorner(int x, int y) const { + if(!onboard(x,y)) + return -1; + + int m = sizem1, e = size_d-1; + + if(x == 0 && y == 0) return 0; + if(x == m && y == 0) return 1; + if(x == e && y == m) return 2; + if(x == e && y == e) return 3; + if(x == m && y == e) return 4; + if(x == 0 && y == m) return 5; + + return -1; +} + +int Board::isedge(int x, int y) const { + if(!onboard(x,y)) + return -1; + + int m = sizem1, e = size_d-1; + + if(y == 0 && x != 0 && x != m) return 0; + if(x-y == m && x != m && x != e) return 1; + if(x == e && y != m && y != e) return 2; + if(y == e && x != e && x != m) return 3; + if(y-x == m && x != m && x != 0) return 4; + if(x == 0 && y != m && y != 0) return 5; + + return -1; +} + +// do a depth first search for a ring +bool Board::checkring_df(const MoveValid & pos, const Side turn) const { + const Cell * start = cell(pos); + start->mark = 1; + bool success = false; + for(int i = 0; i < 4; i++){ //4 instead of 6 since any ring must have its first endpoint in the first 4 + MoveValid loc = nb_begin(pos)[i]; + + if(!loc.onboard()) + continue; + + const Cell * g = cell(loc); + + if(turn != g->piece) + continue; + + g->mark = 1; + success = followring(loc, i, turn, (perm_rings - g->perm)); + g->mark = 0; + + if(success) + break; + } + start->mark = 0; + return success; +} +// only take the 3 directions that are valid in a ring +// the backwards directions are either invalid or not part of the shortest loop +bool Board::followring(const MoveValid & cur, const int & dir, const Side & turn, const int & permsneeded) const { + for(int i = 5; i <= 7; i++){ + int nd = (dir + i) % 6; + MoveValid next = nb_begin(cur)[nd]; + + if(!next.onboard()) + continue; + + const Cell * g = cell(next); + + if(g->mark) + return (permsneeded <= 0); + + if(turn != g->piece) + continue; + + g->mark = 1; + bool success = followring(next, nd, turn, (permsneeded - g->perm)); + g->mark = 0; + + if(success) + return true; + } + return false; +} + +// do an O(1) ring check +// must be done before placing the stone and joining it with the neighbouring groups +bool Board::checkring_o1(const MoveValid & pos, const Side turn) const { + static const unsigned char ringdata[64][10] = { + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //000000 + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //000001 + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //000010 + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //000011 + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //000100 + {1, 3, 5, 0, 0, 0, 0, 0, 0, 0}, //000101 + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //000110 + {3,16,10, 9, 0, 0, 0, 0, 0, 0}, //000111 + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //001000 + {1, 2, 5, 0, 0, 0, 0, 0, 0, 0}, //001001 + {1, 2, 4, 0, 0, 0, 0, 0, 0, 0}, //001010 + {1, 2, 4, 0, 0, 0, 0, 0, 0, 0}, //001011 + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //001100 + {1, 2, 5, 0, 0, 0, 0, 0, 0, 0}, //001101 + {3,15, 9, 8, 0, 0, 0, 0, 0, 0}, //001110 + {4,16,10, 9,15, 8, 9, 0, 0, 0}, //001111 + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //010000 + {1, 1, 5, 0, 0, 0, 0, 0, 0, 0}, //010001 + {1, 1, 4, 0, 0, 0, 0, 0, 0, 0}, //010010 + {1, 1, 4, 0, 0, 0, 0, 0, 0, 0}, //010011 + {1, 1, 3, 0, 0, 0, 0, 0, 0, 0}, //010100 + {2, 1, 3, 5, 0, 0, 0, 0, 0, 0}, //010101 + {1, 1, 3, 0, 0, 0, 0, 0, 0, 0}, //010110 + {7,16,10, 9, 1, 3, 0, 0, 0, 0}, //010111 + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //011000 + {1, 1, 5, 0, 0, 0, 0, 0, 0, 0}, //011001 + {1, 1, 4, 0, 0, 0, 0, 0, 0, 0}, //011010 + {1, 1, 4, 0, 0, 0, 0, 0, 0, 0}, //011011 + {3,14, 8, 7, 0, 0, 0, 0, 0, 0}, //011100 + {7,14, 8, 7, 1, 5, 0, 0, 0, 0}, //011101 + {4,15, 9, 8,14, 7, 8, 0, 0, 0}, //011110 + {5,16,10, 9,15, 8, 9,14, 8, 7}, //011111 + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //100000 + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //100001 + {1, 0, 4, 0, 0, 0, 0, 0, 0, 0}, //100010 + {3,17,11,10, 0, 0, 0, 0, 0, 0}, //100011 + {1, 0, 3, 0, 0, 0, 0, 0, 0, 0}, //100100 + {1, 0, 3, 0, 0, 0, 0, 0, 0, 0}, //100101 + {1, 0, 3, 0, 0, 0, 0, 0, 0, 0}, //100110 + {4,17,11,10,16, 9,10, 0, 0, 0}, //100111 + {1, 0, 2, 0, 0, 0, 0, 0, 0, 0}, //101000 + {1, 0, 2, 0, 0, 0, 0, 0, 0, 0}, //101001 + {2, 0, 2, 4, 0, 0, 0, 0, 0, 0}, //101010 + {7,17,11,10, 0, 2, 0, 0, 0, 0}, //101011 + {1, 0, 2, 0, 0, 0, 0, 0, 0, 0}, //101100 + {1, 0, 2, 0, 0, 0, 0, 0, 0, 0}, //101101 + {7,15, 9, 8, 0, 2, 0, 0, 0, 0}, //101110 + {5,17,11,10,16, 9,10,15, 9, 8}, //101111 + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //110000 + {3,12, 6,11, 0, 0, 0, 0, 0, 0}, //110001 + {1, 0, 4, 0, 0, 0, 0, 0, 0, 0}, //110010 + {4,12, 6,11,17,10,11, 0, 0, 0}, //110011 + {1, 0, 3, 0, 0, 0, 0, 0, 0, 0}, //110100 + {7,12, 6,11, 0, 3, 0, 0, 0, 0}, //110101 + {1, 0, 3, 0, 0, 0, 0, 0, 0, 0}, //110110 + {5,12, 6,11,17,10,11,16,10, 9}, //110111 + {3,13, 7, 6, 0, 0, 0, 0, 0, 0}, //111000 + {4,13, 7, 6,12,11, 6, 0, 0, 0}, //111001 + {7,13, 7, 6, 0, 4, 0, 0, 0, 0}, //111010 + {5,13, 7, 6,12,11, 6,17,11,10}, //111011 + {4,14, 8, 7,13, 6, 7, 0, 0, 0}, //111100 + {5,14, 8, 7,13, 6, 7,12, 6,11}, //111101 + {5,15, 9, 8,14, 7, 8,13, 7, 6}, //111110 + {6, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //111111 + }; + + int bitpattern = 0; + const MoveValid * s = nb_begin(pos); + for(const MoveValid * i = s, *e = nb_end(i); i < e; i++){ + bitpattern <<= 1; + if(i->onboard() && turn == get(i->xy)) + bitpattern |= 1; + } + + const unsigned char * d = ringdata[bitpattern]; + + switch(d[0]){ + case 0: //no ring (000000, 000001, 000011) + return false; + + case 1: //simple case (000101, 001101, 001011, 011011) + return (find_group(s[d[1]]) == find_group(s[d[2]])); + + case 2:{ //3 non-neighbours (010101) + int a = find_group(s[d[1]]), b = find_group(s[d[2]]), c = find_group(s[d[3]]); + return (a == b || a == c || b == c); + } + + case 7: //case 1 and 3 (010111) + if(find_group(s[d[4]]) == find_group(s[d[5]])) + return true; + //fall through + + case 3: // 3 neighbours (000111) + return checkring_back(s[d[1]], s[d[2]], s[d[3]], turn); + + case 4: // 4 neighbours (001111) + return checkring_back(s[d[1]], s[d[2]], s[d[3]], turn) || + checkring_back(s[d[4]], s[d[5]], s[d[6]], turn); + + case 5: // 5 neighbours (011111) + return checkring_back(s[d[1]], s[d[2]], s[d[3]], turn) || + checkring_back(s[d[4]], s[d[5]], s[d[6]], turn) || + checkring_back(s[d[7]], s[d[8]], s[d[9]], turn); + + case 6: // 6 neighbours (111111) + return true; //a ring around this position? how'd that happen + + default: + return false; + } +} +//checks for 3 more stones, a should be the corner +bool Board::checkring_back(const MoveValid & a, const MoveValid & b, const MoveValid & c, Side turn) const { + return (a.onboard() && get(a) == turn && get(b) == turn && get(c) == turn); +} + + +}; // namespace Havannah +}; // namespace Morat diff --git a/havannah/board.h b/havannah/board.h index e901bb2..ff84adb 100644 --- a/havannah/board.h +++ b/havannah/board.h @@ -4,40 +4,42 @@ #include #include #include +#include +#include #include #include #include "../lib/bitcount.h" #include "../lib/hashset.h" +#include "../lib/move.h" +#include "../lib/outcome.h" #include "../lib/string.h" #include "../lib/types.h" #include "../lib/zobrist.h" -#include "move.h" - -using namespace std; +namespace Morat { +namespace Havannah { /* * the board is represented as a flattened 2d array of the form: * 1 2 3 - * A 0 1 2 0 1 0 1 - * B 3 4 5 => 3 4 5 => 3 4 5 - * C 6 7 8 7 8 7 8 - * This follows the H-Gui convention, not the 'standard' convention + * A 0 1 2 0 1 0 1 + * B 3 4 5 <=> 3 4 5 <=> 3 4 5 + * C 6 7 8 7 8 7 8 */ /* neighbours are laid out in this pattern: - * 6 12 7 - * 17 0 1 13 - * 11 5 X 2 8 - * 16 4 3 14 - * 10 15 9 + * 12 6 13 12 6 13 + * 11 0 1 7 11 0 1 7 + * 17 5 X 2 14 <=> 17 5 X 2 14 + * 10 4 3 8 10 4 3 8 + * 16 9 15 16 9 15 */ const MoveScore neighbours[18] = { MoveScore(-1,-1, 3), MoveScore(0,-1, 3), MoveScore(1, 0, 3), MoveScore(1, 1, 3), MoveScore( 0, 1, 3), MoveScore(-1, 0, 3), //direct neighbours, clockwise - MoveScore(-2,-2, 1), MoveScore(0,-2, 1), MoveScore(2, 0, 1), MoveScore(2, 2, 1), MoveScore( 0, 2, 1), MoveScore(-2, 0, 1), //corners of ring 2, easy to block MoveScore(-1,-2, 2), MoveScore(1,-1, 2), MoveScore(2, 1, 2), MoveScore(1, 2, 2), MoveScore(-1, 1, 2), MoveScore(-2,-1, 2), //sides of ring 2, virtual connections + MoveScore(-2,-2, 1), MoveScore(0,-2, 1), MoveScore(2, 0, 1), MoveScore(2, 2, 1), MoveScore( 0, 2, 1), MoveScore(-2, 0, 1), //corners of ring 2, easy to block }; static MoveValid * staticneighbourlist[11] = {NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL}; //one per boardsize @@ -46,16 +48,18 @@ static MoveValid * staticneighbourlist[11] = {NULL,NULL,NULL,NULL,NULL,NULL,NULL class Board{ public: + static constexpr const char * const name = "havannah"; static const int default_size = 8; static const int min_size = 3; static const int max_size = 10; static const int max_vecsize = 19*19; + static const int num_win_types = 3; static const int pattern_cells = 18; typedef uint64_t Pattern; struct Cell { - uint8_t piece; //who controls this cell, 0 for none, 1,2 for players + Side piece; //who controls this cell, 0 for none, 1,2 for players uint8_t size; //size of this group of cells mutable uint16_t parent; //parent for this group of cells uint8_t corner; //which corners are this group connected to @@ -64,23 +68,14 @@ mutable uint8_t mark; //when doing a ring search, has this position been seen uint8_t perm; //is this a permanent piece or a randomly placed piece? Pattern pattern; //the pattern of pieces for neighbours, but from their perspective. Rotate 180 for my perpective - Cell() : piece(0), size(0), parent(0), corner(0), edge(0), mark(0), perm(0), pattern(0) { } - Cell(unsigned int p, unsigned int a, unsigned int s, unsigned int c, unsigned int e, Pattern t) : + Cell() : piece(Side::NONE), size(0), parent(0), corner(0), edge(0), mark(0), perm(0), pattern(0) { } + Cell(Side p, unsigned int a, unsigned int s, unsigned int c, unsigned int e, Pattern t) : piece(p), size(s), parent(a), corner(c), edge(e), mark(0), perm(0), pattern(t) { } int numcorners() const { return BitsSetTable256[corner]; } int numedges() const { return BitsSetTable256[edge]; } - string to_s(int i) const { - return "Cell " + to_str(i) +": " - "piece: " + to_str((int)piece)+ - ", size: " + to_str((int)size) + - ", parent: " + to_str((int)parent) + - ", corner: " + to_str((int)corner) + "/" + to_str(numcorners()) + - ", edge: " + to_str((int)edge) + "/" + to_str(numedges()) + - ", perm: " + to_str((int)perm) + - ", pattern: " + to_str((int)pattern); - } + std::string to_s(int i) const; }; class MoveIterator { //only returns valid moves... @@ -91,7 +86,7 @@ mutable uint8_t mark; //when doing a ring search, has this position been seen HashSet hashes; public: MoveIterator(const Board & b, bool Unique) : board(b), lineend(0), move(Move(M_SWAP), -1), unique(Unique) { - if(board.outcome >= 0){ + if(board.outcome >= Outcome::DRAW){ move = MoveValid(0, board.get_size_d(), -1); //already done } else { if(unique) @@ -143,11 +138,11 @@ mutable uint8_t mark; //when doing a ring search, has this position been seen short nummoves; short unique_depth; //update and test rotations/symmetry with less than this many pieces on the board Move last; - char toPlay; - char outcome; //-3 = unknown, 0 = tie, 1,2 = player win + Side toPlay; + Outcome outcome; char wintype; //0 no win, 1 = edge, 2 = corner, 3 = ring - vector cells; + std::vector cells; Zobrist<12> hash; const MoveValid * neighbourlist; @@ -166,8 +161,8 @@ mutable uint8_t mark; //when doing a ring search, has this position been seen last = M_NONE; nummoves = 0; unique_depth = 5; - toPlay = 1; - outcome = -3; + toPlay = Side::P1; + outcome = Outcome::UNKNOWN; wintype = 0; check_rings = true; perm_rings = 0; @@ -185,11 +180,16 @@ mutable uint8_t mark; //when doing a ring search, has this position been seen p |= j; j <<= 2; } - cells[posxy] = Cell(0, posxy, 1, (1 << iscorner(x, y)), (1 << isedge(x, y)), pattern_reverse(p)); + Side s = (onboard(x, y) ? Side::NONE : Side::UNDEF); + cells[posxy] = Cell(s, posxy, 1, (1 << iscorner(x, y)), (1 << isedge(x, y)), pattern_reverse(p)); } } } +/* ~Board(){ + printf("~Board"); + } +*/ int memsize() const { return sizeof(Board) + sizeof(Cell)*vecsize(); } int get_size_d() const { return size_d; } @@ -199,7 +199,7 @@ mutable uint8_t mark; //when doing a ring search, has this position been seen int numcells() const { return num_cells; } int num_moves() const { return nummoves; } - int movesremain() const { return (won() >= 0 ? 0 : num_cells - nummoves); } + int movesremain() const { return (won() >= Outcome::DRAW ? 0 : num_cells - nummoves); } int xy(int x, int y) const { return y*size_d + x; } int xy(const Move & m) const { return m.y*size_d + m.x; } @@ -210,6 +210,10 @@ mutable uint8_t mark; //when doing a ring search, has this position been seen MoveValid yx(int i) const { return MoveValid(i % size, i / size, i); } + int dist(const Move & a, const Move & b) const { + return (abs(a.x - b.x) + abs(a.y - b.y) + abs((a.x - a.y) - (b.x - b.y)) )/2; + } + const Cell * cell(int i) const { return & cells[i]; } const Cell * cell(int x, int y) const { return cell(xy(x,y)); } const Cell * cell(const Move & m) const { return cell(xy(m)); } @@ -217,18 +221,18 @@ mutable uint8_t mark; //when doing a ring search, has this position been seen //assumes valid x,y - int get(int i) const { return cells[i].piece; } - int get(int x, int y) const { return get(xy(x, y)); } - int get(const Move & m) const { return get(xy(m)); } - int get(const MoveValid & m) const { return get(m.xy); } + Side get(int i) const { return cells[i].piece; } + Side get(int x, int y) const { return get(xy(x, y)); } + Side get(const Move & m) const { return get(xy(m)); } + Side get(const MoveValid & m) const { return get(m.xy); } - int geton(const MoveValid & m) const { return (m.onboard() ? get(m.xy) : 0); } + Side geton(const MoveValid & m) const { return (m.onboard() ? get(m.xy) : Side::UNDEF); } - int local(const Move & m, char turn) const { return local(xy(m), turn); } - int local(int i, char turn) const { + int local(const Move & m, Side turn) const { return local(xy(m), turn); } + int local(int i, Side turn) const { Pattern p = pattern(i); Pattern x = ((p & 0xAAAAAAAAAull) >> 1) ^ (p & 0x555555555ull); // p1 is now when p1 or p2 but not both (ie off the board) - p = x & (turn == 1 ? p : p >> 1); // now just the selected player + p = x & (turn == Side::P1 ? p : p >> 1); // now just the selected player return (p & 0x000000FFF ? 3 : 0) | (p & 0x000FFF000 ? 2 : 0) | (p & 0xFFF000000 ? 1 : 0); @@ -244,13 +248,14 @@ mutable uint8_t mark; //when doing a ring search, has this position been seen bool onboard(const MoveValid & m) const { return m.onboard(); } //assumes x, y are in bounds and the game isn't already finished - bool valid_move_fast(int x, int y) const { return !get(x,y); } - bool valid_move_fast(const Move & m) const { return !get(m); } - bool valid_move_fast(const MoveValid & m) const { return !get(m.xy); } + bool valid_move_fast(int i) const { return get(i) == Side::NONE; } + bool valid_move_fast(int x, int y) const { return valid_move_fast(xy(x, y)); } + bool valid_move_fast(const Move & m) const { return valid_move_fast(xy(m)); } + bool valid_move_fast(const MoveValid & m) const { return valid_move_fast(m.xy); } //checks array bounds too - bool valid_move(int x, int y) const { return (outcome == -3 && onboard(x, y) && !get(x, y)); } - bool valid_move(const Move & m) const { return (outcome == -3 && onboard(m) && !get(m)); } - bool valid_move(const MoveValid & m) const { return (outcome == -3 && m.onboard() && !get(m)); } + bool valid_move(int x, int y) const { return (outcome < Outcome::DRAW && onboard(x, y) && valid_move_fast(x, y)); } + bool valid_move(const Move & m) const { return (outcome < Outcome::DRAW && onboard(m) && valid_move_fast(m)); } + bool valid_move(const MoveValid & m) const { return (outcome < Outcome::DRAW && m.onboard() && valid_move_fast(m)); } //iterator through neighbours of a position const MoveValid * nb_begin(int x, int y) const { return nb_begin(xy(x, y)); } @@ -264,39 +269,7 @@ mutable uint8_t mark; //when doing a ring search, has this position been seen const MoveValid * nb_end_small_hood(const MoveValid * m) const { return m + 12; } const MoveValid * nb_end_big_hood(const MoveValid * m) const { return m + 18; } - int iscorner(int x, int y) const { - if(!onboard(x,y)) - return -1; - - int m = sizem1, e = size_d-1; - - if(x == 0 && y == 0) return 0; - if(x == m && y == 0) return 1; - if(x == e && y == m) return 2; - if(x == e && y == e) return 3; - if(x == m && y == e) return 4; - if(x == 0 && y == m) return 5; - - return -1; - } - - int isedge(int x, int y) const { - if(!onboard(x,y)) - return -1; - - int m = sizem1, e = size_d-1; - - if(y == 0 && x != 0 && x != m) return 0; - if(x-y == m && x != m && x != e) return 1; - if(x == e && y != m && y != e) return 2; - if(y == e && x != e && x != m) return 3; - if(y-x == m && x != m && x != 0) return 4; - if(x == 0 && y != m && y != 0) return 5; - - return -1; - } - - MoveValid * get_neighbour_list(){ + MoveValid * get_neighbour_list() { if(!staticneighbourlist[(int)size]){ MoveValid * list = new MoveValid[vecsize()*18]; MoveValid * a = list; @@ -318,96 +291,26 @@ mutable uint8_t mark; //when doing a ring search, has this position been seen return staticneighbourlist[(int)size]; } - int linestart(int y) const { return (y < size ? 0 : y - sizem1); } int lineend(int y) const { return (y < size ? size + y : size_d); } int linelen(int y) const { return size_d - abs(sizem1 - y); } - string to_s(bool color) const { - string white = "O", - black = "@", - empty = ".", - coord = "", - reset = ""; - if(color){ - string esc = "\033"; - reset = esc + "[0m"; - coord = esc + "[1;37m"; - empty = reset + "."; - white = esc + "[1;33m" + "@"; //yellow - black = esc + "[1;34m" + "@"; //blue - } + std::string to_s(bool color) const; + std::string to_s(bool color, std::function func) const; - string s; - s += string(size + 3, ' '); - for(int i = 0; i < size; i++) - s += " " + coord + to_str(i+1); - s += "\n"; - - for(int y = 0; y < size_d; y++){ - s += string(abs(sizem1 - y) + 2, ' '); - s += coord + char('A' + y); - int end = lineend(y); - for(int x = linestart(y); x < end; x++){ - s += (last == Move(x, y) ? coord + "[" : - last == Move(x-1, y) ? coord + "]" : " "); - int p = get(x, y); - if(p == 0) s += empty; - if(p == 1) s += white; - if(p == 2) s += black; - if(p >= 3) s += "?"; - } - s += (last == Move(end-1, y) ? coord + "]" : " "); - if(y < sizem1) - s += coord + to_str(size + y + 1); - s += '\n'; - } - - s += reset; - return s; - } + friend std::ostream& operator<< (std::ostream &out, const Board & b) { return out << b.to_s(true); } void print(bool color = true) const { printf("%s", to_s(color).c_str()); } - string boardstr() const { - string white, black; - for(int y = 0; y < size_d; y++){ - for(int x = linestart(y); x < lineend(y); x++){ - int p = get(x, y); - if(p == 1) white += Move(x, y).to_s(); - if(p == 2) black += Move(x, y).to_s(); - } - } - return white + ";" + black; - } - - string won_str() const { - switch(outcome){ - case -3: return "none"; - case -2: return "black_or_draw"; - case -1: return "white_or_draw"; - case 0: return "draw"; - case 1: return "white"; - case 2: return "black"; - } - return "unknown"; - } - - char won() const { + Outcome won() const { return outcome; } - int win() const{ // 0 for draw or unknown, 1 for win, -1 for loss - if(outcome <= 0) - return 0; - return (outcome == toplay() ? 1 : -1); - } - char getwintype() const { return wintype; } - char toplay() const { + Side toplay() const { return toPlay; } @@ -415,22 +318,22 @@ mutable uint8_t mark; //when doing a ring search, has this position been seen return MoveIterator(*this, (unique ? nummoves <= unique_depth : false)); } - void set(const Move & m, bool perm = true){ + void set(const Move & m, bool perm = true) { last = m; Cell * cell = & cells[xy(m)]; cell->piece = toPlay; cell->perm = perm; nummoves++; update_hash(m, toPlay); //depends on nummoves - toPlay = 3 - toPlay; + toPlay = ~toPlay; } - void unset(const Move & m){ //break win checks, but is a poor mans undo if all you care about is the hash - toPlay = 3 - toPlay; + void unset(const Move & m) { //break win checks, but is a poor mans undo if all you care about is the hash + toPlay = ~toPlay; update_hash(m, toPlay); nummoves--; Cell * cell = & cells[xy(m)]; - cell->piece = 0; + cell->piece = Side::NONE; cell->perm = 0; } @@ -460,7 +363,7 @@ mutable uint8_t mark; //when doing a ring search, has this position been seen return true; if(cells[i].size < cells[j].size) //force i's subtree to be bigger - swap(i, j); + std::swap(i, j); cells[j].parent = i; cells[i].size += cells[j].size; @@ -471,7 +374,7 @@ mutable uint8_t mark; //when doing a ring search, has this position been seen } Cell test_cell(const Move & pos) const { - char turn = toplay(); + Side turn = toplay(); int posxy = xy(pos); Cell testcell = cells[find_group(pos)]; @@ -499,8 +402,8 @@ mutable uint8_t mark; //when doing a ring search, has this position been seen //check if a position is encirclable by a given player //false if it or one of its neighbours are the opponent's and connected to an edge or corner - bool encirclable(const Move pos, int player) const { - int otherplayer = 3-player; + bool encirclable(const Move pos, Side player) const { + Side otherplayer = ~player; int posxy = xy(pos); const Cell * g = & cells[find_group(posxy)]; @@ -511,193 +414,18 @@ mutable uint8_t mark; //when doing a ring search, has this position been seen if(!i->onboard()) return false; - const Cell * g = & cells[find_group(i->xy)]; + const Cell * g = cell(find_group(i->xy)); if(g->piece == otherplayer && (g->edge || g->corner)) return false; } return true; } - // do a depth first search for a ring - bool checkring_df(const Move & pos, const int turn) const { - const Cell * start = & cells[xy(pos)]; - start->mark = 1; - bool success = false; - for(int i = 0; i < 4; i++){ //4 instead of 6 since any ring must have its first endpoint in the first 4 - Move loc = pos + neighbours[i]; - - if(!onboard(loc)) - continue; - - const Cell * g = & cells[xy(loc)]; - - if(turn != g->piece) - continue; - - g->mark = 1; - success = followring(loc, i, turn, (perm_rings - g->perm)); - g->mark = 0; - - if(success) - break; - } - start->mark = 0; - return success; - } - // only take the 3 directions that are valid in a ring - // the backwards directions are either invalid or not part of the shortest loop - bool followring(const Move & cur, const int & dir, const int & turn, const int & permsneeded) const { - for(int i = 5; i <= 7; i++){ - int nd = (dir + i) % 6; - Move next = cur + neighbours[nd]; - - if(!onboard(next)) - continue; - - const Cell * g = & cells[xy(next)]; - - if(g->mark) - return (permsneeded <= 0); - - if(turn != g->piece) - continue; - - g->mark = 1; - bool success = followring(next, nd, turn, (permsneeded - g->perm)); - g->mark = 0; - - if(success) - return true; - } - return false; - } - - // do an O(1) ring check - // must be done before placing the stone and joining it with the neighbouring groups - bool checkring_o1(const Move & pos, const int turn) const { - static const unsigned char ringdata[64][10] = { - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //000000 - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //000001 - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //000010 - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //000011 - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //000100 - {1, 3, 5, 0, 0, 0, 0, 0, 0, 0}, //000101 - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //000110 - {3,10,16,15, 0, 0, 0, 0, 0, 0}, //000111 - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //001000 - {1, 2, 5, 0, 0, 0, 0, 0, 0, 0}, //001001 - {1, 2, 4, 0, 0, 0, 0, 0, 0, 0}, //001010 - {1, 2, 4, 0, 0, 0, 0, 0, 0, 0}, //001011 - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //001100 - {1, 2, 5, 0, 0, 0, 0, 0, 0, 0}, //001101 - {3, 9,15,14, 0, 0, 0, 0, 0, 0}, //001110 - {4,10,16,15, 9,14,15, 0, 0, 0}, //001111 - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //010000 - {1, 1, 5, 0, 0, 0, 0, 0, 0, 0}, //010001 - {1, 1, 4, 0, 0, 0, 0, 0, 0, 0}, //010010 - {1, 1, 4, 0, 0, 0, 0, 0, 0, 0}, //010011 - {1, 1, 3, 0, 0, 0, 0, 0, 0, 0}, //010100 - {2, 1, 3, 5, 0, 0, 0, 0, 0, 0}, //010101 - {1, 1, 3, 0, 0, 0, 0, 0, 0, 0}, //010110 - {7,10,16,15, 1, 3, 0, 0, 0, 0}, //010111 - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //011000 - {1, 1, 5, 0, 0, 0, 0, 0, 0, 0}, //011001 - {1, 1, 4, 0, 0, 0, 0, 0, 0, 0}, //011010 - {1, 1, 4, 0, 0, 0, 0, 0, 0, 0}, //011011 - {3, 8,14,13, 0, 0, 0, 0, 0, 0}, //011100 - {7, 8,14,13, 1, 5, 0, 0, 0, 0}, //011101 - {4, 9,15,14, 8,13,14, 0, 0, 0}, //011110 - {5,10,16,15, 9,14,15, 8,14,13}, //011111 - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //100000 - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //100001 - {1, 0, 4, 0, 0, 0, 0, 0, 0, 0}, //100010 - {3,11,17,16, 0, 0, 0, 0, 0, 0}, //100011 - {1, 0, 3, 0, 0, 0, 0, 0, 0, 0}, //100100 - {1, 0, 3, 0, 0, 0, 0, 0, 0, 0}, //100101 - {1, 0, 3, 0, 0, 0, 0, 0, 0, 0}, //100110 - {4,11,17,16,10,15,16, 0, 0, 0}, //100111 - {1, 0, 2, 0, 0, 0, 0, 0, 0, 0}, //101000 - {1, 0, 2, 0, 0, 0, 0, 0, 0, 0}, //101001 - {2, 0, 2, 4, 0, 0, 0, 0, 0, 0}, //101010 - {7,11,17,16, 0, 2, 0, 0, 0, 0}, //101011 - {1, 0, 2, 0, 0, 0, 0, 0, 0, 0}, //101100 - {1, 0, 2, 0, 0, 0, 0, 0, 0, 0}, //101101 - {7, 9,15,14, 0, 2, 0, 0, 0, 0}, //101110 - {5,11,17,16,10,15,16, 9,15,14}, //101111 - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //110000 - {3, 6,12,17, 0, 0, 0, 0, 0, 0}, //110001 - {1, 0, 4, 0, 0, 0, 0, 0, 0, 0}, //110010 - {4, 6,12,17,11,16,17, 0, 0, 0}, //110011 - {1, 0, 3, 0, 0, 0, 0, 0, 0, 0}, //110100 - {7, 6,12,17, 0, 3, 0, 0, 0, 0}, //110101 - {1, 0, 3, 0, 0, 0, 0, 0, 0, 0}, //110110 - {5, 6,12,17,11,16,17,10,16,15}, //110111 - {3, 7,13,12, 0, 0, 0, 0, 0, 0}, //111000 - {4, 7,13,12, 6,17,12, 0, 0, 0}, //111001 - {7, 7,13,12, 0, 4, 0, 0, 0, 0}, //111010 - {5, 7,13,12, 6,17,12,11,17,16}, //111011 - {4, 8,14,13, 7,12,13, 0, 0, 0}, //111100 - {5, 8,14,13, 7,12,13, 6,12,17}, //111101 - {5, 9,15,14, 8,13,14, 7,13,12}, //111110 - {6, 0, 0, 0, 0, 0, 0, 0, 0, 0}, //111111 - }; - - int bitpattern = 0; - const MoveValid * s = nb_begin(pos); - for(const MoveValid * i = s, *e = nb_end(i); i < e; i++){ - bitpattern <<= 1; - if(i->onboard() && turn == get(i->xy)) - bitpattern |= 1; - } - - const unsigned char * d = ringdata[bitpattern]; - - switch(d[0]){ - case 0: //no ring (000000, 000001, 000011) - return false; - - case 1: //simple case (000101, 001101, 001011, 011011) - return (find_group(s[d[1]]) == find_group(s[d[2]])); - - case 2:{ //3 non-neighbours (010101) - int a = find_group(s[d[1]]), b = find_group(s[d[2]]), c = find_group(s[d[3]]); - return (a == b || a == c || b == c); - } - - case 7: //case 1 and 3 (010111) - if(find_group(s[d[4]]) == find_group(s[d[5]])) - return true; - //fall through - - case 3: // 3 neighbours (000111) - return checkring_back(s[d[1]], s[d[2]], s[d[3]], turn); - - case 4: // 4 neighbours (001111) - return checkring_back(s[d[1]], s[d[2]], s[d[3]], turn) || - checkring_back(s[d[4]], s[d[5]], s[d[6]], turn); - - case 5: // 5 neighbours (011111) - return checkring_back(s[d[1]], s[d[2]], s[d[3]], turn) || - checkring_back(s[d[4]], s[d[5]], s[d[6]], turn) || - checkring_back(s[d[7]], s[d[8]], s[d[9]], turn); - - case 6: // 6 neighbours (111111) - return true; //a ring around this position? how'd that happen - - default: - return false; - } - } - //checks for 3 more stones, a should be the corner - bool checkring_back(const MoveValid & a, const MoveValid & b, const MoveValid & c, int turn) const { - return (a.onboard() && get(a) == turn && get(b) == turn && get(c) == turn); - } - hash_t gethash() const { return (nummoves > unique_depth ? hash.get(0) : hash.get()); } - string hashstr() const { + std::string hashstr() const { static const char hexlookup[] = "0123456789abcdef"; char buf[19] = "0x"; hash_t val = gethash(); @@ -709,7 +437,8 @@ mutable uint8_t mark; //when doing a ring search, has this position been seen return (char *)buf; } - void update_hash(const Move & pos, int turn){ + void update_hash(const Move & pos, Side side) { + int turn = side.to_i(); if(nummoves > unique_depth){ //simple update, no rotations/symmetry hash.update(0, 3*xy(pos) + turn); return; @@ -741,7 +470,8 @@ mutable uint8_t mark; //when doing a ring search, has this position been seen return test_hash(pos, toplay()); } - hash_t test_hash(const Move & pos, int turn) const { + hash_t test_hash(const Move & pos, Side side) const { + int turn = side.to_i(); if(nummoves >= unique_depth) //simple test, no rotations/symmetry return hash.test(0, 3*xy(pos) + turn); @@ -750,17 +480,17 @@ mutable uint8_t mark; //when doing a ring search, has this position been seen z = y - x; hash_t m = hash.test(0, 3*xyc( x, y) + turn); - m = min(m, hash.test(1, 3*xyc( y, z) + turn)); - m = min(m, hash.test(2, 3*xyc( z, -x) + turn)); - m = min(m, hash.test(3, 3*xyc(-x, -y) + turn)); - m = min(m, hash.test(4, 3*xyc(-y, -z) + turn)); - m = min(m, hash.test(5, 3*xyc(-z, x) + turn)); - m = min(m, hash.test(6, 3*xyc( y, x) + turn)); - m = min(m, hash.test(7, 3*xyc( z, y) + turn)); - m = min(m, hash.test(8, 3*xyc(-x, z) + turn)); - m = min(m, hash.test(9, 3*xyc(-y, -x) + turn)); - m = min(m, hash.test(10, 3*xyc(-z, -y) + turn)); - m = min(m, hash.test(11, 3*xyc( x, -z) + turn)); + m = std::min(m, hash.test(1, 3*xyc( y, z) + turn)); + m = std::min(m, hash.test(2, 3*xyc( z, -x) + turn)); + m = std::min(m, hash.test(3, 3*xyc(-x, -y) + turn)); + m = std::min(m, hash.test(4, 3*xyc(-y, -z) + turn)); + m = std::min(m, hash.test(5, 3*xyc(-z, x) + turn)); + m = std::min(m, hash.test(6, 3*xyc( y, x) + turn)); + m = std::min(m, hash.test(7, 3*xyc( z, y) + turn)); + m = std::min(m, hash.test(8, 3*xyc(-x, z) + turn)); + m = std::min(m, hash.test(9, 3*xyc(-y, -x) + turn)); + m = std::min(m, hash.test(10, 3*xyc(-z, -y) + turn)); + m = std::min(m, hash.test(11, 3*xyc( x, -z) + turn)); return m; } @@ -792,13 +522,13 @@ mutable uint8_t mark; //when doing a ring search, has this position been seen return (((p & 0x03F03F03Full) << 6) | ((p & 0xFC0FC0FC0ull) >> 6)); } - static Pattern pattern_invert(Pattern p){ //switch players + static Pattern pattern_invert(Pattern p) { //switch players return ((p & 0xAAAAAAAAAull) >> 1) | ((p & 0x555555555ull) << 1); } - static Pattern pattern_rotate(Pattern p){ + static Pattern pattern_rotate(Pattern p) { return (((p & 0x003003003ull) << 10) | ((p & 0xFFCFFCFFCull) >> 2)); } - static Pattern pattern_mirror(Pattern p){ + static Pattern pattern_mirror(Pattern p) { // HGFEDC BA9876 543210 -> DEFGHC 6789AB 123450 return ((p & (3ull << 6)) ) | ((p & (3ull << 0)) ) | // 0,3 stay in place ((p & (3ull << 10)) >> 8) | ((p & (3ull << 2)) << 8) | // 1,5 swap @@ -810,36 +540,36 @@ mutable uint8_t mark; //when doing a ring search, has this position been seen ((p & (3ull << 34)) >> 8) | ((p & (3ull << 26)) << 8) | // H,D swap ((p & (3ull << 32)) >> 4) | ((p & (3ull << 28)) << 4); // G,E swap } - static Pattern pattern_symmetry(Pattern p){ //takes a pattern and returns the representative version + static Pattern pattern_symmetry(Pattern p) { //takes a pattern and returns the representative version Pattern m = p; //012345 - m = min(m, (p = pattern_rotate(p)));//501234 - m = min(m, (p = pattern_rotate(p)));//450123 - m = min(m, (p = pattern_rotate(p)));//345012 - m = min(m, (p = pattern_rotate(p)));//234501 - m = min(m, (p = pattern_rotate(p)));//123450 - m = min(m, (p = pattern_mirror(pattern_rotate(p))));//012345 -> 054321 - m = min(m, (p = pattern_rotate(p)));//105432 - m = min(m, (p = pattern_rotate(p)));//210543 - m = min(m, (p = pattern_rotate(p)));//321054 - m = min(m, (p = pattern_rotate(p)));//432105 - m = min(m, (p = pattern_rotate(p)));//543210 + m = std::min(m, (p = pattern_rotate(p)));//501234 + m = std::min(m, (p = pattern_rotate(p)));//450123 + m = std::min(m, (p = pattern_rotate(p)));//345012 + m = std::min(m, (p = pattern_rotate(p)));//234501 + m = std::min(m, (p = pattern_rotate(p)));//123450 + m = std::min(m, (p = pattern_mirror(pattern_rotate(p))));//012345 -> 054321 + m = std::min(m, (p = pattern_rotate(p)));//105432 + m = std::min(m, (p = pattern_rotate(p)));//210543 + m = std::min(m, (p = pattern_rotate(p)));//321054 + m = std::min(m, (p = pattern_rotate(p)));//432105 + m = std::min(m, (p = pattern_rotate(p)));//543210 return m; } - bool move(const Move & pos, bool checkwin = true, bool permanent = true){ + bool move(const Move & pos, bool checkwin = true, bool permanent = true) { return move(MoveValid(pos, xy(pos)), checkwin, permanent); } - bool move(const MoveValid & pos, bool checkwin = true, bool permanent = true){ - assert(outcome < 0); + bool move(const MoveValid & pos, bool checkwin = true, bool permanent = true) { + assert(outcome < Outcome::DRAW); if(!valid_move(pos)) return false; - char turn = toplay(); + Side turn = toplay(); set(pos, permanent); // update the nearby patterns - Pattern p = turn; + Pattern p = turn.to_i(); for(const MoveValid * i = nb_begin(pos.xy), *e = nb_end_big_hood(i); i < e; i++){ if(i->onboard()){ cells[i->xy].pattern |= p; @@ -860,35 +590,35 @@ mutable uint8_t mark; //when doing a ring search, has this position been seen if(checkwin){ Cell * g = & cells[find_group(pos.xy)]; if(g->numedges() >= 3){ - outcome = turn; + outcome = +turn; wintype = 1; }else if(g->numcorners() >= 2){ - outcome = turn; + outcome = +turn; wintype = 2; }else if(check_rings && alreadyjoined && g->size >= 6 && checkring_df(pos, turn)){ - outcome = turn; + outcome = +turn; wintype = 3; }else if(nummoves == num_cells){ - outcome = 0; + outcome = Outcome::DRAW; } } return true; } - bool test_local(const Move & pos, char turn) const { + bool test_local(const Move & pos, Side turn) const { return test_local(MoveValid(pos, xy(pos)), turn); } + bool test_local(const MoveValid & pos, Side turn) const { return (local(pos, turn) == 3); } //test if making this move would win, but don't actually make the move - int test_win(const Move & pos, char turn = 0) const { - if(turn == 0) - turn = toplay(); - + Outcome test_outcome(const Move & pos) const { return test_outcome(pos, toplay()); } + Outcome test_outcome(const Move & pos, Side turn) const { return test_outcome(MoveValid(pos, xy(pos)), turn); } + Outcome test_outcome(const MoveValid & pos) const { return test_outcome(pos, toplay()); } + Outcome test_outcome(const MoveValid & pos, Side turn) const { if(test_local(pos, turn)){ - int posxy = xy(pos); - Cell testcell = cells[find_group(posxy)]; + Cell testcell = cells[find_group(pos.xy)]; int numgroups = 0; - for(const MoveValid * i = nb_begin(posxy), *e = nb_end(i); i < e; i++){ + for(const MoveValid * i = nb_begin(pos), *e = nb_end(i); i < e; i++){ if(i->onboard() && turn == get(i->xy)){ const Cell * g = & cells[find_group(i->xy)]; testcell.corner |= g->corner; @@ -900,12 +630,25 @@ mutable uint8_t mark; //when doing a ring search, has this position been seen } if(testcell.numcorners() >= 2 || testcell.numedges() >= 3 || (check_rings && numgroups >= 2 && testcell.size >= 6 && checkring_o1(pos, turn))) - return turn; + return +turn; } if(nummoves+1 == num_cells) - return 0; + return Outcome::DRAW; - return -3; + return Outcome::UNKNOWN; } + +private: + int iscorner(int x, int y) const; + int isedge(int x, int y) const; + + bool checkring_df(const MoveValid & pos, const Side turn) const; + bool checkring_o1(const MoveValid & pos, const Side turn) const; + bool followring(const MoveValid & cur, const int & dir, const Side & turn, const int & permsneeded) const; + bool checkring_back(const MoveValid & a, const MoveValid & b, const MoveValid & c, Side turn) const; + }; + +}; // namespace Havannah +}; // namespace Morat diff --git a/havannah/board_test.cpp b/havannah/board_test.cpp new file mode 100644 index 0000000..85a37ad --- /dev/null +++ b/havannah/board_test.cpp @@ -0,0 +1,181 @@ + +#include "../lib/catch.hpp" + +#include "board.h" + + +using namespace Morat; +using namespace Havannah; + +void test_game(Board b, std::vector moves, Outcome outcome) { + REQUIRE(b.num_moves() == 0); + Side side = Side::P1; + int made = 0, remain = 37; + for(auto s : moves) { + Outcome expected = (s == moves.back() ? outcome : Outcome::UNKNOWN); + Move move(s); + CAPTURE(move); + CAPTURE(b); + REQUIRE(b.num_moves() == made); + REQUIRE(b.movesremain() == remain); + REQUIRE(b.valid_move(move)); + REQUIRE(b.toplay() == side); + REQUIRE(b.test_outcome(move) == expected); + REQUIRE(b.move(move)); + REQUIRE(b.won() == expected); + side = ~side; + made++; + remain--; + } + REQUIRE(b.num_moves() == made); + REQUIRE(b.movesremain() == (outcome == Outcome::UNKNOWN ? remain : 0)); +} + +TEST_CASE("Havannah::Board", "[havannah][board]") { + Board b(4); + + SECTION("Basics") { + REQUIRE(b.get_size() == 4); + REQUIRE(b.get_size_d() == 7); + REQUIRE(b.movesremain() == 37); + } + + SECTION("valid moves") { + std::string valid[] = {"A1", "D4", + "a1", "a2", "a3", "a4", + "b1", "b2", "b3", "b4", "b5", + "c1", "c2", "c3", "c4", "c5", "c6", + "d1", "d2", "d3", "d4", "d5", "d6", "d7", + "e2", "e3", "e4", "e5", "e6", "e7", + "f3", "f4", "f5", "f6", "f7", + "g4", "g5", "g6", "g7", + }; + for(auto m : valid){ + REQUIRE(b.onboard(m)); + REQUIRE(b.valid_move(m)); + } + } + + SECTION("invalid moves") { + std::string invalid[] = {"a0", "a5", "a10", "b6", "c7", "e1", "e8", "f1", "f2", "h1", "f0"}; + for(auto m : invalid){ + REQUIRE_FALSE(b.onboard(m)); + REQUIRE_FALSE(b.valid_move(m)); + } + } + + SECTION("edges") { + std::string edges[] = {"a2", "a3", "b5", "c6", "e7", "f7", "g6", "g5", "f3", "e2", "c1", "b1"}; + std::string corners[] = {"a1", "a4", "d7", "g7", "g4", "d1"}; + std::string middle[] = { + "b2", "b3", "b4", + "c2", "c3", "c4", "c5", + "d2", "d3", "d4", "d5", "d6", + "e3", "e4", "e5", "e6", + "f4", "f5", "f6", + }; + + for(auto m : edges){ + auto * c = b.cell(Move(m)); + REQUIRE(c->numedges() == 1); + REQUIRE(c->numcorners() == 0); + } + + for(auto m : corners){ + auto * c = b.cell(Move(m)); + REQUIRE(c->numedges() == 0); + REQUIRE(c->numcorners() == 1); + } + + for(auto m : middle){ + auto * c = b.cell(Move(m)); + REQUIRE(c->numedges() == 0); + REQUIRE(c->numcorners() == 0); + } + } + + SECTION("duplicate moves") { + Move m("a1"); + REQUIRE(b.valid_move(m)); + REQUIRE(b.move(m)); + REQUIRE_FALSE(b.valid_move(m)); + REQUIRE_FALSE(b.move(m)); + } + + SECTION("num moves, moves remain") { + std::string moves[] = { + "a1", "a2", "a3", "a4", + "b1", "b2", "b3", "b4", "b5", + "c1", "c2", "c3", "c4", "c5", "c6", + "d1", "d2", "d3", "d4", "d5", "d6", "d7", + "e2", "e3", "e4", "e5", "e6", "e7", + "f3", "f4", "f5", "f6", "f7", + "g4", "g5", "g6", "g7", + }; + int made = 0, remain = 37; + for(auto m : moves) { + REQUIRE(b.num_moves() == made); + REQUIRE(b.movesremain() == remain); + b.move(Move(m)); + made++; + remain--; + } + REQUIRE(b.num_moves() == made); + REQUIRE(b.movesremain() == remain); + } + + SECTION("move distance") { + SECTION("x") { + REQUIRE(b.dist(Move("b2"), Move("b1")) == 1); + REQUIRE(b.dist(Move("b2"), Move("b3")) == 1); + } + SECTION("y") { + REQUIRE(b.dist(Move("b2"), Move("a2")) == 1); + REQUIRE(b.dist(Move("b2"), Move("c2")) == 1); + } + SECTION("z") { + REQUIRE(b.dist(Move("b2"), Move("a1")) == 1); + REQUIRE(b.dist(Move("b2"), Move("c3")) == 1); + } + SECTION("farther") { + REQUIRE(b.dist(Move("b2"), Move("a3")) == 2); + REQUIRE(b.dist(Move("b2"), Move("c1")) == 2); + REQUIRE(b.dist(Move("b2"), Move("d4")) == 2); + REQUIRE(b.dist(Move("b2"), Move("d3")) == 2); + REQUIRE(b.dist(Move("b2"), Move("e3")) == 3); + REQUIRE(b.dist(Move("b2"), Move("d1")) == 3); + } + } + + SECTION("bridge") { + test_game(b, { "a1", "b1", "a2", "b2", "a3", "b3", "a4"}, Outcome::P1); + test_game(b, {"d4", "a1", "b1", "a2", "b2", "a3", "b3", "a4"}, Outcome::P2); + } + + SECTION("fork") { + test_game(b, { "b1", "c1", "b2", "c2", "b3", "c3", "b4", "c4", "b5", "c5", "a2"}, Outcome::P1); + test_game(b, {"d4", "b1", "c1", "b2", "c2", "b3", "c3", "b4", "c4", "b5", "c5", "a2"}, Outcome::P2); + } + + SECTION("ring") { + test_game(b, { "b2", "f3", "b3", "f4", "c2", "f5", "c4", "f6", "d3", "f7", "d4"}, Outcome::P1); + test_game(b, {"d7", "b2", "f3", "b3", "f4", "c2", "f5", "c4", "f6", "d3", "f7", "d4"}, Outcome::P2); + } + + SECTION("filled ring") { + test_game(b, { "b2", "f3", "b3", "f4", "c2", "f5", "c4", "f6", "d3", "f7", "c3", "e6", "d4"}, Outcome::P1); + test_game(b, {"d7", "b2", "f3", "b3", "f4", "c2", "f5", "c4", "f6", "d3", "f7", "c3", "e6", "d4"}, Outcome::P2); + } + + SECTION("draw") { + test_game(b, { + "a1", "a2", "a3", "a4", + "b1", "b2", "b3", "b4", "b5", + "c1", "c2", "c3", "c4", "c5", "c6", + "d1", "d2", "d3", "d4", "d5", "d6", "d7", + "e2", "e3", "e4", "e5", "e6", "e7", + "f3", "f4", "f5", "f6", "f7", + "g4", "g5", "g6", "g7", + }, Outcome::DRAW); + } +} diff --git a/havannah/castro.cpp b/havannah/castro.cpp deleted file mode 100644 index 4ecb85d..0000000 --- a/havannah/castro.cpp +++ /dev/null @@ -1,61 +0,0 @@ - -#include -#include - -#include "../lib/time.h" - -#include "gtp.h" - -using namespace std; - -void die(int code, const string & str){ - printf("%s\n", str.c_str()); - exit(code); -} - -int main(int argc, char **argv){ - srand((unsigned int)(Time().to_f()*1000)); - GTP gtp; - - gtp.colorboard = isatty(fileno(stdout)); - - for(int i = 1; i < argc; i++) { - string arg = argv[i]; - if(arg == "-h" || arg == "--help"){ - die(255, "Usage:\n" - "\t-h --help Show this help\n" - "\t-v --verbose Give more output over gtp\n" - "\t-n --nocolor Don't output the board in color\n" - "\t-c --cmd Pass a gtp command from the command line\n" - "\t-f --file Run this gtp file before reading from stdin\n" -// "\t-s --server Run in server mode\n" - ); - }else if(arg == "-v" || arg == "--verbose"){ - gtp.verbose = true; - }else if(arg == "-n" || arg == "--nocolor"){ - gtp.colorboard = false; - }else if(arg == "-c" || arg == "--cmd"){ - char * ptr = argv[++i]; - if(ptr == NULL) die(255, "Missing a command"); - gtp.cmd(ptr); - }else if(arg == "-f" || arg == "--file"){ - char * ptr = argv[++i]; - if(ptr == NULL) die(255, "Missing a file to run"); - FILE * fd = fopen(ptr, "r"); - gtp.setinfile(fd); - gtp.setoutfile(NULL); - if(!gtp.run()) - return 0; - fclose(fd); -// }else if(arg == "-s" || arg == "--server"){ -// gtp.setservermode(true); - }else{ - die(255, "Unknown argument: " + arg + ", try --help"); - } - } - - gtp.setinfile(stdin); - gtp.setoutfile(stdout); - gtp.run(); - return 0; -} diff --git a/havannah/gtp.h b/havannah/gtp.h index c32bd47..42e8134 100644 --- a/havannah/gtp.h +++ b/havannah/gtp.h @@ -2,6 +2,8 @@ #pragma once #include "../lib/gtpcommon.h" +#include "../lib/history.h" +#include "../lib/move.h" #include "../lib/string.h" #include "agent.h" @@ -9,11 +11,13 @@ #include "agentmcts.h" #include "agentpns.h" #include "board.h" -#include "history.h" -#include "move.h" + + +namespace Morat { +namespace Havannah { class GTP : public GTPCommon { - History hist; + History hist; public: int verbose; @@ -35,46 +39,47 @@ class GTP : public GTPCommon { set_board(); - newcallback("name", bind(>P::gtp_name, this, _1), "Name of the program"); - newcallback("version", bind(>P::gtp_version, this, _1), "Version of the program"); - newcallback("verbose", bind(>P::gtp_verbose, this, _1), "Set verbosity, 0 for quiet, 1 for normal, 2+ for more output"); - newcallback("extended", bind(>P::gtp_extended, this, _1), "Output extra stats from genmove in the response"); - newcallback("debug", bind(>P::gtp_debug, this, _1), "Enable debug mode"); - newcallback("colorboard", bind(>P::gtp_colorboard, this, _1), "Turn on or off the colored board"); - newcallback("showboard", bind(>P::gtp_print, this, _1), "Show the board"); - newcallback("print", bind(>P::gtp_print, this, _1), "Alias for showboard"); - newcallback("dists", bind(>P::gtp_dists, this, _1), "Similar to print, but shows minimum win distances"); - newcallback("zobrist", bind(>P::gtp_zobrist, this, _1), "Output the zobrist hash for the current move"); - newcallback("clear_board", bind(>P::gtp_clearboard, this, _1), "Clear the board, but keep the size"); - newcallback("clear", bind(>P::gtp_clearboard, this, _1), "Alias for clear_board"); - newcallback("boardsize", bind(>P::gtp_boardsize, this, _1), "Clear the board, set the board size"); - newcallback("size", bind(>P::gtp_boardsize, this, _1), "Alias for board_size"); - newcallback("play", bind(>P::gtp_play, this, _1), "Place a stone: play "); - newcallback("white", bind(>P::gtp_playwhite, this, _1), "Place a white stone: white "); - newcallback("black", bind(>P::gtp_playblack, this, _1), "Place a black stone: black "); - newcallback("undo", bind(>P::gtp_undo, this, _1), "Undo one or more moves: undo [amount to undo]"); - newcallback("time", bind(>P::gtp_time, this, _1), "Set the time limits and the algorithm for per game time"); - newcallback("genmove", bind(>P::gtp_genmove, this, _1), "Generate a move: genmove [color] [time]"); - newcallback("solve", bind(>P::gtp_solve, this, _1), "Try to solve this position"); - -// newcallback("ab", bind(>P::gtp_ab, this, _1), "Switch to use the Alpha/Beta agent to play/solve"); - newcallback("mcts", bind(>P::gtp_mcts, this, _1), "Switch to use the Monte Carlo Tree Search agent to play/solve"); - newcallback("pns", bind(>P::gtp_pns, this, _1), "Switch to use the Proof Number Search agent to play/solve"); - - newcallback("all_legal", bind(>P::gtp_all_legal, this, _1), "List all legal moves"); - newcallback("history", bind(>P::gtp_history, this, _1), "List of played moves"); - newcallback("playgame", bind(>P::gtp_playgame, this, _1), "Play a list of moves"); - newcallback("winner", bind(>P::gtp_winner, this, _1), "Check the winner of the game"); - newcallback("patterns", bind(>P::gtp_patterns, this, _1), "List all legal moves plus their local pattern"); - - newcallback("pv", bind(>P::gtp_pv, this, _1), "Output the principle variation for the player tree as it stands now"); - newcallback("move_stats", bind(>P::gtp_move_stats, this, _1), "Output the move stats for the player tree as it stands now"); - - newcallback("params", bind(>P::gtp_params, this, _1), "Set the options for the player, no args gives options"); - -// newcallback("player_hgf", bind(>P::gtp_player_hgf, this, _1), "Output an hgf of the current tree"); -// newcallback("player_load_hgf", bind(>P::gtp_player_load_hgf,this, _1), "Load an hgf generated by player_hgf"); -// newcallback("player_gammas", bind(>P::gtp_player_gammas, this, _1), "Load the gammas for weighted random from a file"); + newcallback("name", std::bind(>P::gtp_name, this, _1), "Name of the program"); + newcallback("version", std::bind(>P::gtp_version, this, _1), "Version of the program"); + newcallback("verbose", std::bind(>P::gtp_verbose, this, _1), "Set verbosity, 0 for quiet, 1 for normal, 2+ for more output"); + newcallback("extended", std::bind(>P::gtp_extended, this, _1), "Output extra stats from genmove in the response"); + newcallback("debug", std::bind(>P::gtp_debug, this, _1), "Enable debug mode"); + newcallback("colorboard", std::bind(>P::gtp_colorboard, this, _1), "Turn on or off the colored board"); + newcallback("showboard", std::bind(>P::gtp_print, this, _1), "Show the board"); + newcallback("print", std::bind(>P::gtp_print, this, _1), "Alias for showboard"); + newcallback("dists", std::bind(>P::gtp_dists, this, _1), "Similar to print, but shows minimum win distances"); + newcallback("zobrist", std::bind(>P::gtp_zobrist, this, _1), "Output the zobrist hash for the current move"); + newcallback("clear_board", std::bind(>P::gtp_clearboard, this, _1), "Clear the board, but keep the size"); + newcallback("clear", std::bind(>P::gtp_clearboard, this, _1), "Alias for clear_board"); + newcallback("reset", std::bind(>P::gtp_clearboard, this, _1), "Alias for clear_board"); + newcallback("boardsize", std::bind(>P::gtp_boardsize, this, _1), "Clear the board, set the board size"); + newcallback("size", std::bind(>P::gtp_boardsize, this, _1), "Alias for board_size"); + newcallback("play", std::bind(>P::gtp_play, this, _1), "Place a stone: play "); + newcallback("white", std::bind(>P::gtp_playwhite, this, _1), "Place a white stone: white "); + newcallback("black", std::bind(>P::gtp_playblack, this, _1), "Place a black stone: black "); + newcallback("undo", std::bind(>P::gtp_undo, this, _1), "Undo one or more moves: undo [amount to undo]"); + newcallback("time", std::bind(>P::gtp_time, this, _1), "Set the time limits and the algorithm for per game time"); + newcallback("genmove", std::bind(>P::gtp_genmove, this, _1), "Generate a move: genmove [color] [time]"); + newcallback("solve", std::bind(>P::gtp_solve, this, _1), "Try to solve this position"); + +// newcallback("ab", std::bind(>P::gtp_ab, this, _1), "Switch to use the Alpha/Beta agent to play/solve"); + newcallback("mcts", std::bind(>P::gtp_mcts, this, _1), "Switch to use the Monte Carlo Tree Search agent to play/solve"); + newcallback("pns", std::bind(>P::gtp_pns, this, _1), "Switch to use the Proof Number Search agent to play/solve"); + + newcallback("all_legal", std::bind(>P::gtp_all_legal, this, _1), "List all legal moves"); + newcallback("history", std::bind(>P::gtp_history, this, _1), "List of played moves"); + newcallback("playgame", std::bind(>P::gtp_playgame, this, _1), "Play a list of moves"); + newcallback("winner", std::bind(>P::gtp_winner, this, _1), "Check the winner of the game"); + newcallback("patterns", std::bind(>P::gtp_patterns, this, _1), "List all legal moves plus their local pattern"); + + newcallback("pv", std::bind(>P::gtp_pv, this, _1), "Output the principle variation for the player tree as it stands now"); + newcallback("move_stats", std::bind(>P::gtp_move_stats, this, _1), "Output the move stats for the player tree as it stands now"); + + newcallback("params", std::bind(>P::gtp_params, this, _1), "Set the options for the player, no args gives options"); + + newcallback("save_sgf", std::bind(>P::gtp_save_sgf, this, _1), "Output an sgf of the current tree"); + newcallback("load_sgf", std::bind(>P::gtp_load_sgf, this, _1), "Load an sgf generated by save_sgf"); +// newcallback("player_gammas", std::bind(>P::gtp_player_gammas, this, _1), "Load the gammas for weighted random from a file"); } void set_board(bool clear = true){ @@ -94,7 +99,7 @@ class GTP : public GTPCommon { GTPResponse gtp_all_legal(vecstr args); GTPResponse gtp_history(vecstr args); GTPResponse gtp_patterns(vecstr args); - GTPResponse play(const string & pos, int toplay); + GTPResponse play(const std::string & pos, Side toplay); GTPResponse gtp_playgame(vecstr args); GTPResponse gtp_play(vecstr args); GTPResponse gtp_playwhite(vecstr args); @@ -124,8 +129,11 @@ class GTP : public GTPCommon { GTPResponse gtp_pns_params(vecstr args); // GTPResponse gtp_player_gammas(vecstr args); -// GTPResponse gtp_player_hgf(vecstr args); -// GTPResponse gtp_player_load_hgf(vecstr args); + GTPResponse gtp_save_sgf(vecstr args); + GTPResponse gtp_load_sgf(vecstr args); - string solve_str(int outcome) const; + std::string solve_str(int outcome) const; }; + +}; // namespace Havannah +}; // namespace Morat diff --git a/havannah/gtpagent.cpp b/havannah/gtpagent.cpp index 4f6903c..36916a6 100644 --- a/havannah/gtpagent.cpp +++ b/havannah/gtpagent.cpp @@ -1,13 +1,12 @@ -#include +#include "gtp.h" -#include "../lib/fileio.h" -#include "gtp.h" +namespace Morat { +namespace Havannah { using namespace std; - GTPResponse GTP::gtp_move_stats(vecstr args){ vector moves; for(auto s : args) @@ -258,7 +257,7 @@ GTPResponse GTP::gtp_pns_params(vecstr args){ "Update the pns solver settings, eg: pns_params -m 100 -s 0 -d 1 -e 0.25 -a 2 -l 0\n" " -m --memory Memory limit in Mb [" + to_str(pns->memlimit/(1024*1024)) + "]\n" " -t --threads How many threads to run [" + to_str(pns->numthreads) + "]\n" - " -s --ties Which side to assign ties to, 0 = handle, 1 = p1, 2 = p2 [" + to_str(pns->ties) + "]\n" + " -s --ties Which side to assign ties to, 0 = handle, 1 = p1, 2 = p2 [" + to_str(pns->ties.to_i()) + "]\n" " -d --df Use depth-first thresholds [" + to_str(pns->df) + "]\n" " -e --epsilon How big should the threshold be [" + to_str(pns->epsilon) + "]\n" " -a --abdepth Run an alpha-beta search of this size at each leaf [" + to_str(pns->ab) + "]\n" @@ -276,7 +275,7 @@ GTPResponse GTP::gtp_pns_params(vecstr args){ if(mem < 1) return GTPResponse(false, "Memory can't be less than 1mb"); pns->set_memlimit(mem*1024*1024); }else if((arg == "-s" || arg == "--ties") && i+1 < args.size()){ - pns->ties = from_str(args[++i]); + pns->ties = Side(from_str(args[++i])); pns->clear_mem(); }else if((arg == "-d" || arg == "--df") && i+1 < args.size()){ pns->df = from_str(args[++i]); @@ -291,3 +290,6 @@ GTPResponse GTP::gtp_pns_params(vecstr args){ return GTPResponse(true, errs); } + +}; // namespace Havannah +}; // namespace Morat diff --git a/havannah/gtpgeneral.cpp b/havannah/gtpgeneral.cpp index c532b29..e4bc1e9 100644 --- a/havannah/gtpgeneral.cpp +++ b/havannah/gtpgeneral.cpp @@ -1,7 +1,15 @@ +#include + +#include "../lib/sgf.h" + #include "gtp.h" #include "lbdist.h" + +namespace Morat { +namespace Havannah { + GTPResponse GTP::gtp_mcts(vecstr args){ delete agent; agent = new AgentMCTS(); @@ -39,7 +47,7 @@ GTPResponse GTP::gtp_boardsize(vecstr args){ if(size < Board::min_size || size > Board::max_size) return GTPResponse(false, "Size " + to_str(size) + " is out of range."); - hist = History(size); + hist = History(size); set_board(); time_control.new_game(); @@ -69,14 +77,14 @@ GTPResponse GTP::gtp_undo(vecstr args){ GTPResponse GTP::gtp_patterns(vecstr args){ bool symmetric = true; bool invert = true; - string ret; + std::string ret; const Board & board = *hist; for(Board::MoveIterator move = board.moveit(); !move.done(); ++move){ ret += move->to_s() + " "; unsigned int p = board.pattern(*move); if(symmetric) p = board.pattern_symmetry(p); - if(invert && board.toplay() == 2) + if(invert && board.toplay() == Side::P2) p = board.pattern_invert(p); ret += to_str(p); ret += "\n"; @@ -85,24 +93,24 @@ GTPResponse GTP::gtp_patterns(vecstr args){ } GTPResponse GTP::gtp_all_legal(vecstr args){ - string ret; + std::string ret; for(Board::MoveIterator move = hist->moveit(); !move.done(); ++move) ret += move->to_s() + " "; return GTPResponse(true, ret); } GTPResponse GTP::gtp_history(vecstr args){ - string ret; + std::string ret; for(auto m : hist) ret += m.to_s() + " "; return GTPResponse(true, ret); } -GTPResponse GTP::play(const string & pos, int toplay){ +GTPResponse GTP::play(const std::string & pos, Side toplay){ if(toplay != hist->toplay()) return GTPResponse(false, "It is the other player's turn!"); - if(hist->won() >= 0) + if(hist->won() >= Outcome::DRAW) return GTPResponse(false, "The game is already over."); Move m(pos); @@ -113,7 +121,7 @@ GTPResponse GTP::play(const string & pos, int toplay){ move(m); if(verbose >= 2) - logerr("Placement: " + m.to_s() + ", outcome: " + hist->won_str() + "\n" + hist->to_s(colorboard)); + logerr("Placement: " + m.to_s() + ", outcome: " + hist->won().to_s() + "\n" + hist->to_s(colorboard)); return GTPResponse(true); } @@ -131,37 +139,33 @@ GTPResponse GTP::gtp_play(vecstr args){ if(args.size() != 2) return GTPResponse(false, "Wrong number of arguments"); - char toplay = 0; switch(tolower(args[0][0])){ - case 'w': toplay = 1; break; - case 'b': toplay = 2; break; - default: - return GTPResponse(false, "Invalid player selection"); + case 'w': return play(args[1], Side::P1); + case 'b': return play(args[1], Side::P2); + default: return GTPResponse(false, "Invalid player selection"); } - - return play(args[1], toplay); } GTPResponse GTP::gtp_playwhite(vecstr args){ if(args.size() != 1) return GTPResponse(false, "Wrong number of arguments"); - return play(args[0], 1); + return play(args[0], Side::P1); } GTPResponse GTP::gtp_playblack(vecstr args){ if(args.size() != 1) return GTPResponse(false, "Wrong number of arguments"); - return play(args[0], 2); + return play(args[0], Side::P2); } GTPResponse GTP::gtp_winner(vecstr args){ - return GTPResponse(true, hist->won_str()); + return GTPResponse(true, hist->won().to_s()); } GTPResponse GTP::gtp_name(vecstr args){ - return GTPResponse(true, "Castro"); + return GTPResponse(true, std::string("morat-") + Board::name); } GTPResponse GTP::gtp_version(vecstr args){ @@ -193,7 +197,7 @@ GTPResponse GTP::gtp_extended(vecstr args){ } GTPResponse GTP::gtp_debug(vecstr args){ - string str = "\n"; + std::string str = "\n"; str += "Board size: " + to_str(hist->get_size()) + "\n"; str += "Board cells: " + to_str(hist->numcells()) + "\n"; str += "Board vec: " + to_str(hist->vecsize()) + "\n"; @@ -203,60 +207,115 @@ GTPResponse GTP::gtp_debug(vecstr args){ } GTPResponse GTP::gtp_dists(vecstr args){ + using std::string; Board board = *hist; LBDists dists(&board); - int side = 0; + Side side = Side::NONE; if(args.size() >= 1){ switch(tolower(args[0][0])){ - case 'w': side = 1; break; - case 'b': side = 2; break; + case 'w': side = Side::P1; break; + case 'b': side = Side::P2; break; default: return GTPResponse(false, "Invalid player selection"); } } - int size = board.get_size(); - int size_d = board.get_size_d(); + if(args.size() >= 2) { + return GTPResponse(true, to_str(dists.get(Move(args[1]), side))); + } - string s = "\n"; - s += string(size + 3, ' '); - for(int i = 0; i < size; i++) - s += " " + to_str(i+1); - s += "\n"; + return GTPResponse(true, "\n" + board.to_s(colorboard, bind(&LBDists::get_s, &dists, _1, side))); +} - string white = "O", black = "@"; - if(colorboard){ - string esc = "\033", reset = esc + "[0m"; - white = esc + "[1;33m" + "@" + reset; //yellow - black = esc + "[1;34m" + "@" + reset; //blue +GTPResponse GTP::gtp_zobrist(vecstr args){ + return GTPResponse(true, hist->hashstr()); +} + +GTPResponse GTP::gtp_save_sgf(vecstr args){ + int limit = -1; + if(args.size() == 0) + return GTPResponse(true, "save_sgf [work limit]"); + + std::ifstream infile(args[0].c_str()); + + if(infile) { + infile.close(); + return GTPResponse(false, "File " + args[0] + " already exists"); } - for(int y = 0; y < size_d; y++){ - s += string(abs(size-1 - y) + 2, ' '); - s += char('A' + y); - for(int x = board.linestart(y); x < board.lineend(y); x++){ - int p = board.get(x, y); - s += ' '; - if(p == 0){ - int d = (side ? dists.get(Move(x, y), side) : dists.get(Move(x, y))); - if(d < 10) - s += to_str(d); - else - s += '.'; - }else if(p == 1){ - s += white; - }else if(p == 2){ - s += black; - } - } - if(y < size-1) - s += " " + to_str(1 + size + y); - s += '\n'; + std::ofstream outfile(args[0].c_str()); + + if(!outfile) + return GTPResponse(false, "Opening file " + args[0] + " for writing failed"); + + if(args.size() > 1) + limit = from_str(args[1]); + + SGFPrinter sgf(outfile); + sgf.game(Board::name); + sgf.program(gtp_name(vecstr()).response, gtp_version(vecstr()).response); + sgf.size(hist->get_size()); + + sgf.end_root(); + + Side s = Side::P1; + for(auto m : hist){ + sgf.move(s, m); + s = ~s; } - return GTPResponse(true, s); + + agent->gen_sgf(sgf, limit); + + sgf.end(); + outfile.close(); + return true; } -GTPResponse GTP::gtp_zobrist(vecstr args){ - return GTPResponse(true, hist->hashstr()); + +GTPResponse GTP::gtp_load_sgf(vecstr args){ + if(args.size() == 0) + return GTPResponse(true, "load_sgf "); + + std::ifstream infile(args[0].c_str()); + + if(!infile) { + return GTPResponse(false, "Error opening file " + args[0] + " for reading"); + } + + SGFParser sgf(infile); + if(sgf.game() != Board::name){ + infile.close(); + return GTPResponse(false, "File is for the wrong game: " + sgf.game()); + } + + int size = sgf.size(); + if(size != hist->get_size()){ + if(hist.len() == 0){ + hist = History(size); + set_board(); + time_control.new_game(); + }else{ + infile.close(); + return GTPResponse(false, "File has the wrong boardsize to match the existing game"); + } + } + + Side s = Side::P1; + + while(sgf.next_node()){ + Move m = sgf.move(); + move(m); // push the game forward + s = ~s; + } + + if(sgf.has_children()) + agent->load_sgf(sgf); + + assert(sgf.done_child()); + infile.close(); + return true; } + +}; // namespace Havannah +}; // namespace Morat diff --git a/havannah/gtpplayer.cpp b/havannah/gtpplayer.cpp deleted file mode 100644 index 406db58..0000000 --- a/havannah/gtpplayer.cpp +++ /dev/null @@ -1,594 +0,0 @@ - - -#include - -#include "../lib/fileio.h" - -#include "gtp.h" - -using namespace std; - - -GTPResponse GTP::gtp_move_stats(vecstr args){ - string s = ""; - - Player::Node * node = &(player.root); - - for(unsigned int i = 0; i < args.size(); i++){ - Move m(args[i]); - Player::Node * c = node->children.begin(), - * cend = node->children.end(); - for(; c != cend; c++){ - if(c->move == m){ - node = c; - break; - } - } - } - - Player::Node * child = node->children.begin(), - * childend = node->children.end(); - for( ; child != childend; child++){ - if(child->move == M_NONE) - continue; - - s += child->move.to_s(); - s += "," + to_str((child->exp.num() ? child->exp.avg() : 0.0), 4) + "," + to_str(child->exp.num()); - s += "," + to_str((child->rave.num() ? child->rave.avg() : 0.0), 4) + "," + to_str(child->rave.num()); - s += "," + to_str(child->know); - if(child->outcome >= 0) - s += "," + won_str(child->outcome); - s += "\n"; - } - return GTPResponse(true, s); -} - -GTPResponse GTP::gtp_player_solve(vecstr args){ - double use_time = (args.size() >= 1 ? - from_str(args[0]) : - time_control.get_time(hist.len(), hist->movesremain(), player.gamelen())); - - if(verbose) - logerr("time remain: " + to_str(time_control.remain, 1) + ", time: " + to_str(use_time, 3) + ", sims: " + to_str(time_control.max_sims) + "\n"); - - Player::Node * ret = player.genmove(use_time, time_control.max_sims, time_control.flexible); - Move best = M_RESIGN; - if(ret) - best = ret->move; - - time_control.use(player.time_used); - - int toplay = player.rootboard.toplay(); - - DepthStats gamelen, treelen; - uint64_t runs = player.runs; - DepthStats wintypes[2][4]; - double times[4] = {0,0,0,0}; - for(unsigned int i = 0; i < player.threads.size(); i++){ - gamelen += player.threads[i]->gamelen; - treelen += player.threads[i]->treelen; - - for(int a = 0; a < 2; a++) - for(int b = 0; b < 4; b++) - wintypes[a][b] += player.threads[i]->wintypes[a][b]; - - for(int a = 0; a < 4; a++) - times[a] += player.threads[i]->times[a]; - - player.threads[i]->reset(); - } - player.runs = 0; - - string stats = "Finished " + to_str(runs) + " runs in " + to_str(player.time_used*1000, 0) + " msec: " + to_str(runs/player.time_used, 0) + " Games/s\n"; - if(runs > 0){ - stats += "Game length: " + gamelen.to_s() + "\n"; - stats += "Tree depth: " + treelen.to_s() + "\n"; - if(player.profile) - stats += "Times: " + to_str(times[0], 3) + ", " + to_str(times[1], 3) + ", " + to_str(times[2], 3) + ", " + to_str(times[3], 3) + "\n"; - stats += "Win Types: "; - stats += "P1: f " + to_str(wintypes[0][1].num) + ", b " + to_str(wintypes[0][2].num) + ", r " + to_str(wintypes[0][3].num) + "; "; - stats += "P2: f " + to_str(wintypes[1][1].num) + ", b " + to_str(wintypes[1][2].num) + ", r " + to_str(wintypes[1][3].num) + "\n"; - - if(verbose >= 2){ - stats += "P1 fork: " + wintypes[0][1].to_s() + "\n"; - stats += "P1 bridge: " + wintypes[0][2].to_s() + "\n"; - stats += "P1 ring: " + wintypes[0][3].to_s() + "\n"; - stats += "P2 fork: " + wintypes[1][1].to_s() + "\n"; - stats += "P2 bridge: " + wintypes[1][2].to_s() + "\n"; - stats += "P2 ring: " + wintypes[1][3].to_s() + "\n"; - } - } - - if(ret){ - stats += "Move Score: " + to_str(ret->exp.avg()) + "\n"; - - if(ret->outcome >= 0){ - stats += "Solved as a "; - if(ret->outcome == toplay) stats += "win"; - else if(ret->outcome == 0) stats += "draw"; - else stats += "loss"; - stats += "\n"; - } - } - - stats += "PV: " + gtp_pv(vecstr()).response + "\n"; - - if(verbose >= 3 && !player.root.children.empty()) - stats += "Exp-Rave:\n" + gtp_move_stats(vecstr()).response + "\n"; - - if(verbose) - logerr(stats); - - Solver s; - if(ret){ - s.outcome = (ret->outcome >= 0 ? ret->outcome : -3); - s.bestmove = ret->move; - s.maxdepth = gamelen.maxdepth; - s.nodes_seen = runs; - }else{ - s.outcome = 3-toplay; - s.bestmove = M_RESIGN; - s.maxdepth = 0; - s.nodes_seen = 0; - } - - return GTPResponse(true, solve_str(s)); -} - - -GTPResponse GTP::gtp_player_solved(vecstr args){ - string s = ""; - Player::Node * child = player.root.children.begin(), - * childend = player.root.children.end(); - int toplay = player.rootboard.toplay(); - int best = 0; - for( ; child != childend; child++){ - if(child->move == M_NONE) - continue; - - if(child->outcome == toplay) - return GTPResponse(true, won_str(toplay)); - else if(child->outcome < 0) - best = 2; - else if(child->outcome == 0) - best = 1; - } - if(best == 2) return GTPResponse(true, won_str(-3)); - if(best == 1) return GTPResponse(true, won_str(0)); - return GTPResponse(true, won_str(3 - toplay)); -} - -GTPResponse GTP::gtp_pv(vecstr args){ - string pvstr = ""; - vector pv = player.get_pv(); - for(unsigned int i = 0; i < pv.size(); i++) - pvstr += pv[i].to_s() + " "; - return GTPResponse(true, pvstr); -} - -GTPResponse GTP::gtp_player_hgf(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, "player_hgf [sims limit]"); - - FILE * fd = fopen(args[0].c_str(), "r"); - - if(fd){ - fclose(fd); - return GTPResponse(false, "File " + args[0] + " already exists"); - } - - fd = fopen(args[0].c_str(), "w"); - - if(!fd) - return GTPResponse(false, "Opening file " + args[0] + " for writing failed"); - - unsigned int limit = 10000; - if(args.size() > 1) - limit = from_str(args[1]); - - Board board = *hist; - - - fprintf(fd, "(;FF[4]SZ[%i]\n", board.get_size()); - int p = 1; - for(auto m : hist){ - fprintf(fd, ";%c[%s]", (p == 1 ? 'W' : 'B'), m.to_s().c_str()); - p = 3-p; - } - - - Player::Node * child = player.root.children.begin(), - * end = player.root.children.end(); - - for( ; child != end; child++){ - if(child->exp.num() >= limit){ - board.set(child->move); - player.gen_hgf(board, child, limit, 1, fd); - board.unset(child->move); - } - } - - fprintf(fd, ")\n"); - - fclose(fd); - - return true; -} - -GTPResponse GTP::gtp_player_load_hgf(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, "player_load_hgf "); - - FILE * fd = fopen(args[0].c_str(), "r"); - - if(!fd) - return GTPResponse(false, "Opening file " + args[0] + " for reading failed"); - - int size; - assert(fscanf(fd, "(;FF[4]SZ[%i]", & size) > 0); - if(size != hist->get_size()){ - if(hist.len() == 0){ - hist = History(Board(size)); - set_board(); - }else{ - fclose(fd); - return GTPResponse(false, "File has the wrong boardsize to match the existing game"); - } - } - - eat_whitespace(fd); - - Board board(size); - Player::Node * node = & player.root; - vector prefix; - - char side, movestr[5]; - while(fscanf(fd, ";%c[%5[^]]]", &side, movestr) > 0){ - Move move(movestr); - - if(board.num_moves() >= (int)hist.len()){ - if(node->children.empty()) - player.create_children_simple(board, node); - - prefix.push_back(node); - node = player.find_child(node, move); - }else if(hist[board.num_moves()] != move){ - fclose(fd); - return GTPResponse(false, "The current game is deeper than this file"); - } - board.move(move); - - eat_whitespace(fd); - } - prefix.push_back(node); - - - if(fpeek(fd) != ')'){ - if(node->children.empty()) - player.create_children_simple(board, node); - - while(fpeek(fd) != ')'){ - Player::Node child; - player.load_hgf(board, & child, fd); - - Player::Node * i = player.find_child(node, child.move); - *i = child; //copy the child experience to the tree - i->swap_tree(child); //move the child subtree to the tree - - assert(child.children.empty()); - - eat_whitespace(fd); - } - } - - eat_whitespace(fd); - assert(fgetc(fd) == ')'); - fclose(fd); - - while(!prefix.empty()){ - Player::Node * node = prefix.back(); - prefix.pop_back(); - - Player::Node * child = node->children.begin(), - * end = node->children.end(); - - int toplay = hist->toplay(); - if(prefix.size() % 2 == 1) - toplay = 3 - toplay; - - Player::Node * backup = child; - - node->exp.clear(); - for( ; child != end; child++){ - node->exp += child->exp.invert(); - if(child->outcome == toplay || child->exp.num() > backup->exp.num()) - backup = child; - } - player.do_backup(node, backup, toplay); - } - - return true; -} - - -GTPResponse GTP::gtp_genmove(vecstr args){ - if(player.rootboard.won() >= 0) - return GTPResponse(true, "resign"); - - double use_time = (args.size() >= 2 ? - from_str(args[1]) : - time_control.get_time(hist.len(), hist->movesremain(), player.gamelen())); - - if(args.size() >= 2) - use_time = from_str(args[1]); - - if(verbose) - logerr("time remain: " + to_str(time_control.remain, 1) + ", time: " + to_str(use_time, 3) + ", sims: " + to_str(time_control.max_sims) + "\n"); - - uword nodesbefore = player.nodes; - - Player::Node * ret = player.genmove(use_time, time_control.max_sims, time_control.flexible); - Move best = player.root.bestmove; - - time_control.use(player.time_used); - - int toplay = player.rootboard.toplay(); - - DepthStats gamelen, treelen; - uint64_t runs = player.runs; - uint64_t games = 0; - DepthStats wintypes[2][4]; - double times[4] = {0,0,0,0}; - for(unsigned int i = 0; i < player.threads.size(); i++){ - gamelen += player.threads[i]->gamelen; - treelen += player.threads[i]->treelen; - - for(int a = 0; a < 2; a++){ - for(int b = 0; b < 4; b++){ - wintypes[a][b] += player.threads[i]->wintypes[a][b]; - games += player.threads[i]->wintypes[a][b].num; - } - } - - for(int a = 0; a < 4; a++) - times[a] += player.threads[i]->times[a]; - - player.threads[i]->reset(); - } - player.runs = 0; - - string stats = "Finished " + to_str(runs) + " runs in " + to_str(player.time_used*1000, 0) + " msec: " + to_str(runs/player.time_used, 0) + " Games/s\n"; - if(runs > 0){ - stats += "Game length: " + gamelen.to_s() + "\n"; - stats += "Tree depth: " + treelen.to_s() + "\n"; - if(player.profile) - stats += "Times: " + to_str(times[0], 3) + ", " + to_str(times[1], 3) + ", " + to_str(times[2], 3) + ", " + to_str(times[3], 3) + "\n"; - stats += "Win Types: "; - stats += "W: f " + to_str(wintypes[0][1].num*100.0/games,0) + "%, b " + to_str(wintypes[0][2].num*100.0/games,0) + "%, r " + to_str(wintypes[0][3].num*100.0/games,0) + "%; "; - stats += "B: f " + to_str(wintypes[1][1].num*100.0/games,0) + "%, b " + to_str(wintypes[1][2].num*100.0/games,0) + "%, r " + to_str(wintypes[1][3].num*100.0/games,0) + "%\n"; - - if(verbose >= 2){ - stats += "W fork: " + wintypes[0][1].to_s() + "\n"; - stats += "W bridge: " + wintypes[0][2].to_s() + "\n"; - stats += "W ring: " + wintypes[0][3].to_s() + "\n"; - stats += "B fork: " + wintypes[1][1].to_s() + "\n"; - stats += "B bridge: " + wintypes[1][2].to_s() + "\n"; - stats += "B ring: " + wintypes[1][3].to_s() + "\n"; - } - } - - if(ret) - stats += "Move Score: " + to_str(ret->exp.avg()) + "\n"; - - if(player.root.outcome != -3){ - stats += "Solved as a "; - if(player.root.outcome == 0) stats += "draw"; - else if(player.root.outcome == toplay) stats += "win"; - else if(player.root.outcome == 3-toplay) stats += "loss"; - else if(player.root.outcome == -toplay) stats += "win or draw"; - else if(player.root.outcome == toplay-3) stats += "loss or draw"; - stats += "\n"; - } - - stats += "PV: " + gtp_pv(vecstr()).response + "\n"; - - if(verbose >= 3 && !player.root.children.empty()) - stats += "Exp-Rave:\n" + gtp_move_stats(vecstr()).response + "\n"; - - string extended; - if(genmoveextended){ - //move score - if(ret) extended += " " + to_str(ret->exp.avg()); - else extended += " 0"; - //outcome - extended += " " + won_str(player.root.outcome); - //work - extended += " " + to_str(runs); - //nodes - extended += " " + to_str(player.nodes - nodesbefore); - } - - move(best); - - if(verbose >= 2){ - stats += "history: "; - for(auto m : hist) - stats += m.to_s() + " "; - stats += "\n"; - stats += hist->to_s(colorboard) + "\n"; - } - - if(verbose) - logerr(stats); - - return GTPResponse(true, best.to_s() + extended); -} - -GTPResponse GTP::gtp_player_params(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, string("\n") + - "Set player parameters, eg: player_params -e 1 -f 0 -t 2 -o 1 -p 0\n" + - "Processing:\n" + -#ifndef SINGLE_THREAD - " -t --threads Number of MCTS threads [" + to_str(player.numthreads) + "]\n" + -#endif - " -o --ponder Continue to ponder during the opponents time [" + to_str(player.ponder) + "]\n" + - " -M --maxmem Max memory in Mb to use for the tree [" + to_str(player.maxmem/(1024*1024)) + "]\n" + - " --profile Output the time used by each phase of MCTS [" + to_str(player.profile) + "]\n" + - "Final move selection:\n" + - " -E --msexplore Lower bound constant in final move selection [" + to_str(player.msexplore) + "]\n" + - " -F --msrave Rave factor, 0 for pure exp, -1 # sims, -2 # wins [" + to_str(player.msrave) + "]\n" + - "Tree traversal:\n" + - " -e --explore Exploration rate for UCT [" + to_str(player.explore) + "]\n" + - " -A --parexplore Multiply the explore rate by parents experience [" + to_str(player.parentexplore) + "]\n" + - " -f --ravefactor The rave factor: alpha = rf/(rf + visits) [" + to_str(player.ravefactor) + "]\n" + - " -d --decrrave Decrease the rave factor over time: rf += d*empty [" + to_str(player.decrrave) + "]\n" + - " -a --knowledge Use knowledge: 0.01*know/sqrt(visits+1) [" + to_str(player.knowledge) + "]\n" + - " -r --userave Use rave with this probability [0-1] [" + to_str(player.userave) + "]\n" + - " -X --useexplore Use exploration with this probability [0-1] [" + to_str(player.useexplore) + "]\n" + - " -u --fpurgency Value to assign to an unplayed move [" + to_str(player.fpurgency) + "]\n" + - " -O --rollouts Number of rollouts to run per simulation [" + to_str(player.rollouts) + "]\n" + - " -I --dynwiden Dynamic widening, consider log_wid(exp) children [" + to_str(player.dynwiden) + "]\n" + - "Tree building:\n" + - " -s --shortrave Only use moves from short rollouts for rave [" + to_str(player.shortrave) + "]\n" + - " -k --keeptree Keep the tree from the previous move [" + to_str(player.keeptree) + "]\n" + - " -m --minimax Backup the minimax proof in the UCT tree [" + to_str(player.minimax) + "]\n" + - " -T --detectdraw Detect draws once no win is possible at all [" + to_str(player.detectdraw) + "]\n" + - " -x --visitexpand Number of visits before expanding a node [" + to_str(player.visitexpand) + "]\n" + - " -P --symmetry Prune symmetric moves, good for proof, not play [" + to_str(player.prunesymmetry) + "]\n" + - " --gcsolved Garbage collect solved nodes with fewer sims than [" + to_str(player.gcsolved) + "]\n" + - "Node initialization knowledge, Give a bonus:\n" + - " -l --localreply based on the distance to the previous move [" + to_str(player.localreply) + "]\n" + - " -y --locality to stones near other stones of the same color [" + to_str(player.locality) + "]\n" + - " -c --connect to stones connected to edges/corners [" + to_str(player.connect) + "]\n" + - " -S --size based on the size of the group [" + to_str(player.size) + "]\n" + - " -b --bridge to maintaining a 2-bridge after the op probes [" + to_str(player.bridge) + "]\n" + - " -D --distance to low minimum distance to win (<0 avoid VCs) [" + to_str(player.dists) + "]\n" + - "Rollout policy:\n" + - " -h --weightrand Weight the moves according to computed gammas [" + to_str(player.weightedrandom) + "]\n" + - " -R --ringdepth Check for rings for this depth, < 0 for % moves [" + to_str(player.checkringdepth) + "]\n" + - " -G --ringperm Num stones placed before rollout to form a ring [" + to_str(player.ringperm) + "]\n" + - " -p --pattern Maintain the virtual connection pattern [" + to_str(player.rolloutpattern) + "]\n" + - " -g --goodreply Reuse the last good reply (1), remove losses (2) [" + to_str(player.lastgoodreply) + "]\n" + - " -w --instantwin Look for instant wins to this depth [" + to_str(player.instantwin) + "]\n" - ); - - string errs; - for(unsigned int i = 0; i < args.size(); i++) { - string arg = args[i]; - - if((arg == "-t" || arg == "--threads") && i+1 < args.size()){ - player.numthreads = from_str(args[++i]); - bool p = player.ponder; - player.set_ponder(false); //stop the threads while resetting them - player.reset_threads(); - player.set_ponder(p); - }else if((arg == "-o" || arg == "--ponder") && i+1 < args.size()){ - player.set_ponder(from_str(args[++i])); - }else if((arg == "--profile") && i+1 < args.size()){ - player.profile = from_str(args[++i]); - }else if((arg == "-M" || arg == "--maxmem") && i+1 < args.size()){ - player.maxmem = from_str(args[++i])*1024*1024; - }else if((arg == "-E" || arg == "--msexplore") && i+1 < args.size()){ - player.msexplore = from_str(args[++i]); - }else if((arg == "-F" || arg == "--msrave") && i+1 < args.size()){ - player.msrave = from_str(args[++i]); - }else if((arg == "-e" || arg == "--explore") && i+1 < args.size()){ - player.explore = from_str(args[++i]); - }else if((arg == "-A" || arg == "--parexplore") && i+1 < args.size()){ - player.parentexplore = from_str(args[++i]); - }else if((arg == "-f" || arg == "--ravefactor") && i+1 < args.size()){ - player.ravefactor = from_str(args[++i]); - }else if((arg == "-d" || arg == "--decrrave") && i+1 < args.size()){ - player.decrrave = from_str(args[++i]); - }else if((arg == "-a" || arg == "--knowledge") && i+1 < args.size()){ - player.knowledge = from_str(args[++i]); - }else if((arg == "-s" || arg == "--shortrave") && i+1 < args.size()){ - player.shortrave = from_str(args[++i]); - }else if((arg == "-k" || arg == "--keeptree") && i+1 < args.size()){ - player.keeptree = from_str(args[++i]); - }else if((arg == "-m" || arg == "--minimax") && i+1 < args.size()){ - player.minimax = from_str(args[++i]); - }else if((arg == "-T" || arg == "--detectdraw") && i+1 < args.size()){ - player.detectdraw = from_str(args[++i]); - }else if((arg == "-P" || arg == "--symmetry") && i+1 < args.size()){ - player.prunesymmetry = from_str(args[++i]); - }else if(( arg == "--gcsolved") && i+1 < args.size()){ - player.gcsolved = from_str(args[++i]); - }else if((arg == "-r" || arg == "--userave") && i+1 < args.size()){ - player.userave = from_str(args[++i]); - }else if((arg == "-X" || arg == "--useexplore") && i+1 < args.size()){ - player.useexplore = from_str(args[++i]); - }else if((arg == "-u" || arg == "--fpurgency") && i+1 < args.size()){ - player.fpurgency = from_str(args[++i]); - }else if((arg == "-O" || arg == "--rollouts") && i+1 < args.size()){ - player.rollouts = from_str(args[++i]); - if(player.gclimit < player.rollouts*5) - player.gclimit = player.rollouts*5; - }else if((arg == "-I" || arg == "--dynwiden") && i+1 < args.size()){ - player.dynwiden = from_str(args[++i]); - player.logdynwiden = std::log(player.dynwiden); - }else if((arg == "-x" || arg == "--visitexpand") && i+1 < args.size()){ - player.visitexpand = from_str(args[++i]); - }else if((arg == "-l" || arg == "--localreply") && i+1 < args.size()){ - player.localreply = from_str(args[++i]); - }else if((arg == "-y" || arg == "--locality") && i+1 < args.size()){ - player.locality = from_str(args[++i]); - }else if((arg == "-c" || arg == "--connect") && i+1 < args.size()){ - player.connect = from_str(args[++i]); - }else if((arg == "-S" || arg == "--size") && i+1 < args.size()){ - player.size = from_str(args[++i]); - }else if((arg == "-b" || arg == "--bridge") && i+1 < args.size()){ - player.bridge = from_str(args[++i]); - }else if((arg == "-D" || arg == "--distance") && i+1 < args.size()){ - player.dists = from_str(args[++i]); - }else if((arg == "-h" || arg == "--weightrand") && i+1 < args.size()){ - player.weightedrandom = from_str(args[++i]); - }else if((arg == "-R" || arg == "--ringdepth") && i+1 < args.size()){ - player.checkringdepth = from_str(args[++i]); - }else if((arg == "-G" || arg == "--ringperm") && i+1 < args.size()){ - player.ringperm = from_str(args[++i]); - }else if((arg == "-p" || arg == "--pattern") && i+1 < args.size()){ - player.rolloutpattern = from_str(args[++i]); - }else if((arg == "-g" || arg == "--goodreply") && i+1 < args.size()){ - player.lastgoodreply = from_str(args[++i]); - }else if((arg == "-w" || arg == "--instantwin") && i+1 < args.size()){ - player.instantwin = from_str(args[++i]); - }else{ - return GTPResponse(false, "Missing or unknown parameter"); - } - } - return GTPResponse(true, errs); -} - -GTPResponse GTP::gtp_player_gammas(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, "Must pass the filename of a set of gammas"); - - ifstream ifs(args[0].c_str()); - - if(!ifs.good()) - return GTPResponse(false, "Failed to open file for reading"); - - Board board = *hist; - - for(int i = 0; i < 4096; i++){ - int a; - float f; - ifs >> a >> f; - - if(i != a){ - ifs.close(); - return GTPResponse(false, "Line " + to_str(i) + " doesn't match the expected value"); - } - - int s = board.pattern_symmetry(i); - if(s == i) - player.gammas[i] = exp(f); - else - player.gammas[i] = player.gammas[s]; - } - - ifs.close(); - return GTPResponse(true); -} diff --git a/havannah/gtpsolver.cpp b/havannah/gtpsolver.cpp deleted file mode 100644 index 8f15594..0000000 --- a/havannah/gtpsolver.cpp +++ /dev/null @@ -1,343 +0,0 @@ - - -#include "gtp.h" - -string GTP::solve_str(int outcome) const { - switch(outcome){ - case -2: return "black_or_draw"; - case -1: return "white_or_draw"; - case 0: return "draw"; - case 1: return "white"; - case 2: return "black"; - default: return "unknown"; - } -} - -string GTP::solve_str(const Solver & solve){ - string ret = ""; - ret += solve_str(solve.outcome) + " "; - ret += solve.bestmove.to_s() + " "; - ret += to_str(solve.maxdepth) + " "; - ret += to_str(solve.nodes_seen); - return ret; -} - - -GTPResponse GTP::gtp_solve_ab(vecstr args){ - double time = 60; - - if(args.size() >= 1) - time = from_str(args[0]); - - solverab.solve(time); - - logerr("Finished in " + to_str(solverab.time_used*1000, 0) + " msec\n"); - - return GTPResponse(true, solve_str(solverab)); -} - -GTPResponse GTP::gtp_solve_ab_params(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, string("\n") + - "Update the alpha-beta solver settings, eg: ab_params -m 100 -s 1 -d 3\n" - " -m --memory Memory limit in Mb (0 to disable the TT) [" + to_str(solverab.memlimit/(1024*1024)) + "]\n" - " -s --scout Whether to scout ahead for the true minimax value [" + to_str(solverab.scout) + "]\n" - " -d --depth Starting depth [" + to_str(solverab.startdepth) + "]\n" - ); - - for(unsigned int i = 0; i < args.size(); i++) { - string arg = args[i]; - - if((arg == "-m" || arg == "--memory") && i+1 < args.size()){ - int mem = from_str(args[++i]); - solverab.set_memlimit(mem); - }else if((arg == "-s" || arg == "--scout") && i+1 < args.size()){ - solverab.scout = from_str(args[++i]); - }else if((arg == "-d" || arg == "--depth") && i+1 < args.size()){ - solverab.startdepth = from_str(args[++i]); - }else{ - return GTPResponse(false, "Missing or unknown parameter"); - } - } - - return true; -} - -GTPResponse GTP::gtp_solve_ab_stats(vecstr args){ - string s = ""; - - Board board = *hist; - for(auto arg : args) - board.move(Move(arg)); - - int value; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - value = solverab.tt_get(board.test_hash(*move)); - - s += move->to_s() + "," + to_str(value) + "\n"; - } - return GTPResponse(true, s); -} - -GTPResponse GTP::gtp_solve_ab_clear(vecstr args){ - solverab.clear_mem(); - return true; -} - - - -GTPResponse GTP::gtp_solve_pns(vecstr args){ - double time = 60; - - if(args.size() >= 1) - time = from_str(args[0]); - - solverpns.solve(time); - - logerr("Finished in " + to_str(solverpns.time_used*1000, 0) + " msec\n"); - - return GTPResponse(true, solve_str(solverpns)); -} - -GTPResponse GTP::gtp_solve_pns_params(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, string("\n") + - "Update the pns solver settings, eg: pns_params -m 100 -s 0 -d 1 -e 0.25 -a 2 -l 0\n" - " -m --memory Memory limit in Mb [" + to_str(solverpns.memlimit/(1024*1024)) + "]\n" - " -s --ties Which side to assign ties to, 0 = handle, 1 = p1, 2 = p2 [" + to_str(solverpns.ties) + "]\n" -// " -t --threads How many threads to run -// " -o --ponder Ponder in the background - " -d --df Use depth-first thresholds [" + to_str(solverpns.df) + "]\n" - " -e --epsilon How big should the threshold be [" + to_str(solverpns.epsilon) + "]\n" - " -a --abdepth Run an alpha-beta search of this size at each leaf [" + to_str(solverpns.ab) + "]\n" - " -l --lbdist Initialize with the lower bound on distance to win [" + to_str(solverpns.lbdist) + "]\n" - ); - - for(unsigned int i = 0; i < args.size(); i++) { - string arg = args[i]; - - if((arg == "-m" || arg == "--memory") && i+1 < args.size()){ - uint64_t mem = from_str(args[++i]); - if(mem < 1) return GTPResponse(false, "Memory can't be less than 1mb"); - solverpns.set_memlimit(mem*1024*1024); - }else if((arg == "-s" || arg == "--ties") && i+1 < args.size()){ - solverpns.ties = from_str(args[++i]); - solverpns.clear_mem(); - }else if((arg == "-d" || arg == "--df") && i+1 < args.size()){ - solverpns.df = from_str(args[++i]); - }else if((arg == "-e" || arg == "--epsilon") && i+1 < args.size()){ - solverpns.epsilon = from_str(args[++i]); - }else if((arg == "-a" || arg == "--abdepth") && i+1 < args.size()){ - solverpns.ab = from_str(args[++i]); - }else if((arg == "-l" || arg == "--lbdist") && i+1 < args.size()){ - solverpns.lbdist = from_str(args[++i]); - }else{ - return GTPResponse(false, "Missing or unknown parameter"); - } - } - - return true; -} - -GTPResponse GTP::gtp_solve_pns_stats(vecstr args){ - string s = ""; - - SolverPNS::PNSNode * node = &(solverpns.root); - - for(unsigned int i = 0; i < args.size(); i++){ - Move m(args[i]); - SolverPNS::PNSNode * c = node->children.begin(), - * cend = node->children.end(); - for(; c != cend; c++){ - if(c->move == m){ - node = c; - break; - } - } - } - - SolverPNS::PNSNode * child = node->children.begin(), - * childend = node->children.end(); - for( ; child != childend; child++){ - if(child->move == M_NONE) - continue; - - s += child->move.to_s() + "," + to_str(child->phi) + "," + to_str(child->delta) + "\n"; - } - return GTPResponse(true, s); -} - -GTPResponse GTP::gtp_solve_pns_clear(vecstr args){ - solverpns.clear_mem(); - return true; -} - - -GTPResponse GTP::gtp_solve_pns2(vecstr args){ - double time = 60; - - if(args.size() >= 1) - time = from_str(args[0]); - - solverpns2.solve(time); - - logerr("Finished in " + to_str(solverpns2.time_used*1000, 0) + " msec\n"); - - return GTPResponse(true, solve_str(solverpns2)); -} - -GTPResponse GTP::gtp_solve_pns2_params(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, string("\n") + - "Update the pns solver settings, eg: pns_params -m 100 -s 0 -d 1 -e 0.25 -a 2 -l 0\n" - " -m --memory Memory limit in Mb [" + to_str(solverpns2.memlimit/(1024*1024)) + "]\n" - " -s --ties Which side to assign ties to, 0 = handle, 1 = p1, 2 = p2 [" + to_str(solverpns2.ties) + "]\n" - " -t --threads How many threads to run [" + to_str(solverpns2.numthreads) + "]\n" -// " -o --ponder Ponder in the background - " -d --df Use depth-first thresholds [" + to_str(solverpns2.df) + "]\n" - " -e --epsilon How big should the threshold be [" + to_str(solverpns2.epsilon) + "]\n" - " -a --abdepth Run an alpha-beta search of this size at each leaf [" + to_str(solverpns2.ab) + "]\n" - " -l --lbdist Initialize with the lower bound on distance to win [" + to_str(solverpns2.lbdist) + "]\n" - ); - - for(unsigned int i = 0; i < args.size(); i++) { - string arg = args[i]; - - if((arg == "-t" || arg == "--threads") && i+1 < args.size()){ - solverpns2.numthreads = from_str(args[++i]); - solverpns2.reset_threads(); - }else if((arg == "-m" || arg == "--memory") && i+1 < args.size()){ - uint64_t mem = from_str(args[++i]); - if(mem < 1) return GTPResponse(false, "Memory can't be less than 1mb"); - solverpns2.set_memlimit(mem*1024*1024); - }else if((arg == "-s" || arg == "--ties") && i+1 < args.size()){ - solverpns2.ties = from_str(args[++i]); - solverpns2.clear_mem(); - }else if((arg == "-d" || arg == "--df") && i+1 < args.size()){ - solverpns2.df = from_str(args[++i]); - }else if((arg == "-e" || arg == "--epsilon") && i+1 < args.size()){ - solverpns2.epsilon = from_str(args[++i]); - }else if((arg == "-a" || arg == "--abdepth") && i+1 < args.size()){ - solverpns2.ab = from_str(args[++i]); - }else if((arg == "-l" || arg == "--lbdist") && i+1 < args.size()){ - solverpns2.lbdist = from_str(args[++i]); - }else{ - return GTPResponse(false, "Missing or unknown parameter"); - } - } - - return true; -} - -GTPResponse GTP::gtp_solve_pns2_stats(vecstr args){ - string s = ""; - - SolverPNS2::PNSNode * node = &(solverpns2.root); - - for(unsigned int i = 0; i < args.size(); i++){ - Move m(args[i]); - SolverPNS2::PNSNode * c = node->children.begin(), - * cend = node->children.end(); - for(; c != cend; c++){ - if(c->move == m){ - node = c; - break; - } - } - } - - SolverPNS2::PNSNode * child = node->children.begin(), - * childend = node->children.end(); - for( ; child != childend; child++){ - if(child->move == M_NONE) - continue; - - s += child->move.to_s() + "," + to_str(child->phi) + "," + to_str(child->delta) + "," + to_str(child->work) + "\n"; - } - return GTPResponse(true, s); -} - -GTPResponse GTP::gtp_solve_pns2_clear(vecstr args){ - solverpns2.clear_mem(); - return true; -} - - - - -GTPResponse GTP::gtp_solve_pnstt(vecstr args){ - double time = 60; - - if(args.size() >= 1) - time = from_str(args[0]); - - solverpnstt.solve(time); - - logerr("Finished in " + to_str(solverpnstt.time_used*1000, 0) + " msec\n"); - - return GTPResponse(true, solve_str(solverpnstt)); -} - -GTPResponse GTP::gtp_solve_pnstt_params(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, string("\n") + - "Update the pnstt solver settings, eg: pnstt_params -m 100 -s 0 -d 1 -e 0.25 -a 2 -l 0\n" - " -m --memory Memory limit in Mb [" + to_str(solverpnstt.memlimit/(1024*1024)) + "]\n" - " -s --ties Which side to assign ties to, 0 = handle, 1 = p1, 2 = p2 [" + to_str(solverpnstt.ties) + "]\n" -// " -t --threads How many threads to run -// " -o --ponder Ponder in the background - " -d --df Use depth-first thresholds [" + to_str(solverpnstt.df) + "]\n" - " -e --epsilon How big should the threshold be [" + to_str(solverpnstt.epsilon) + "]\n" - " -a --abdepth Run an alpha-beta search of this size at each leaf [" + to_str(solverpnstt.ab) + "]\n" - " -c --copy Try to copy a proof to this many siblings, <0 quit early [" + to_str(solverpnstt.copyproof) + "]\n" -// " -l --lbdist Initialize with the lower bound on distance to win [" + to_str(solverpnstt.lbdist) + "]\n" - ); - - for(unsigned int i = 0; i < args.size(); i++) { - string arg = args[i]; - - if((arg == "-m" || arg == "--memory") && i+1 < args.size()){ - int mem = from_str(args[++i]); - if(mem < 1) return GTPResponse(false, "Memory can't be less than 1mb"); - solverpnstt.set_memlimit(mem*1024*1024); - }else if((arg == "-s" || arg == "--ties") && i+1 < args.size()){ - solverpnstt.ties = from_str(args[++i]); - solverpnstt.clear_mem(); - }else if((arg == "-d" || arg == "--df") && i+1 < args.size()){ - solverpnstt.df = from_str(args[++i]); - }else if((arg == "-e" || arg == "--epsilon") && i+1 < args.size()){ - solverpnstt.epsilon = from_str(args[++i]); - }else if((arg == "-a" || arg == "--abdepth") && i+1 < args.size()){ - solverpnstt.ab = from_str(args[++i]); - }else if((arg == "-c" || arg == "--copy") && i+1 < args.size()){ - solverpnstt.copyproof = from_str(args[++i]); -// }else if((arg == "-l" || arg == "--lbdist") && i+1 < args.size()){ -// solverpnstt.lbdist = from_str(args[++i]); - }else{ - return GTPResponse(false, "Missing or unknown parameter"); - } - } - - return true; -} - -GTPResponse GTP::gtp_solve_pnstt_stats(vecstr args){ - string s = ""; - - Board board = *hist; - for(auto arg : args) - board.move(Move(arg)); - - SolverPNSTT::PNSNode * child = NULL; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - child = solverpnstt.tt(board, *move); - - s += move->to_s() + "," + to_str(child->phi) + "," + to_str(child->delta) + "\n"; - } - return GTPResponse(true, s); -} - -GTPResponse GTP::gtp_solve_pnstt_clear(vecstr args){ - solverpnstt.clear_mem(); - return true; -} diff --git a/havannah/history.h b/havannah/history.h deleted file mode 100644 index 00ccd06..0000000 --- a/havannah/history.h +++ /dev/null @@ -1,70 +0,0 @@ - -#pragma once - -#include - -#include "../lib/string.h" - -#include "board.h" -#include "move.h" - -class History { - std::vector hist; - Board board; - -public: - - History() { } - History(const Board & b) : board(b) { } - - const Move & operator [] (int i) const { - return hist[i]; - } - - Move last() const { - if(hist.size() == 0) - return M_NONE; - - return hist.back(); - } - - const Board & operator * () const { return board; } - const Board * operator -> () const { return & board; } - - std::vector::const_iterator begin() const { return hist.begin(); } - std::vector::const_iterator end() const { return hist.end(); } - - const Board get_board() const { - Board b(board.get_size()); - for(auto m : hist) - b.move(m); - return b; - } - - int len() const { - return hist.size(); - } - - void clear() { - hist.clear(); - board = get_board(); - } - - bool undo() { - if(hist.size() <= 0) - return false; - - hist.pop_back(); - board = get_board(); - return true; - } - - bool move(const Move & m) { - if(board.valid_move(m)){ - board.move(m); - hist.push_back(m); - return true; - } - return false; - } -}; diff --git a/havannah/lbdist.h b/havannah/lbdist.h index 4e8b000..28709c7 100644 --- a/havannah/lbdist.h +++ b/havannah/lbdist.h @@ -10,19 +10,22 @@ Increase distance when crossing an opponent virtual connection? Decrease distance when crossing your own virtual connection? */ +#include "../lib/move.h" #include "board.h" -#include "move.h" + + +namespace Morat { +namespace Havannah { class LBDists { struct MoveDist { - Move pos; + MoveValid pos; int dist; int dir; MoveDist() { } - MoveDist(Move p, int d, int r) : pos(p), dist(d), dir(r) { } - MoveDist(int x, int y, int d, int r) : pos(Move(x,y)), dist(d), dir(r) { } + MoveDist(MoveValid p, int d, int r) : pos(p), dist(d), dir(r) { } }; //a specialized priority queue @@ -70,15 +73,18 @@ class LBDists { IntPQueue Q; const Board * board; - int & dist(int edge, int player, int i) { return dists[edge][player-1][i]; } - int & dist(int edge, int player, const Move & m) { return dist(edge, player, board->xy(m)); } - int & dist(int edge, int player, int x, int y) { return dist(edge, player, board->xy(x, y)); } - - void init(int x, int y, int edge, int player, int dir){ - int val = board->get(x, y); - if(val != 3 - player){ - Q.push(MoveDist(x, y, (val == 0), dir)); - dist(edge, player, x, y) = (val == 0); + int & dist(int edge, Side player, int i) { return dists[edge][player.to_i() - 1][i]; } + int & dist(int edge, Side player, const MoveValid & m) { return dist(edge, player, m.xy); } + int & dist(int edge, Side player, const Move & m) { return dist(edge, player, board->xy(m)); } + int & dist(int edge, Side player, int x, int y) { return dist(edge, player, board->xy(x, y)); } + + void init(int x, int y, int edge, Side player, int dir){ + Side val = board->get(x, y); + if(val != ~player){ + bool empty = (val == Side::NONE); + MoveValid move(x, y, board->xy(x, y)); + Q.push(MoveDist(move, empty, dir)); + dist(edge, player, move) = empty; } } @@ -87,7 +93,7 @@ class LBDists { LBDists() : board(NULL) {} LBDists(const Board * b) { run(b); } - void run(const Board * b, bool crossvcs = true, int side = 0) { + void run(const Board * b, bool crossvcs = true, Side side = Side::BOTH) { board = b; for(int i = 0; i < 12; i++) @@ -95,56 +101,54 @@ class LBDists { for(int k = 0; k < board->vecsize(); k++) dists[i][j][k] = maxdist; //far far away! + if((side & Side::P1) == Side::P1) init_player(crossvcs, Side::P1); + if((side & Side::P2) == Side::P2) init_player(crossvcs, Side::P2); + } + + void init_player(bool crossvcs, Side player){ int m = board->get_size()-1, e = board->get_size_d()-1; - int start, end; - if(side){ start = end = side; } - else { start = 1; end = 2; } - - for(int player = start; player <= end; player++){ - init(0, 0, 0, player, 3); flood(0, player, crossvcs); //corner 0 - init(m, 0, 1, player, 4); flood(1, player, crossvcs); //corner 1 - init(e, m, 2, player, 5); flood(2, player, crossvcs); //corner 2 - init(e, e, 3, player, 0); flood(3, player, crossvcs); //corner 3 - init(m, e, 4, player, 1); flood(4, player, crossvcs); //corner 4 - init(0, m, 5, player, 2); flood(5, player, crossvcs); //corner 5 - - for(int x = 1; x < m; x++) { init(x, 0, 6, player, 3+(x==1)); } flood(6, player, crossvcs); //edge 0 - for(int y = 1; y < m; y++) { init(m+y, y, 7, player, 4+(y==1)); } flood(7, player, crossvcs); //edge 1 - for(int y = m+1; y < e; y++) { init(e, y, 8, player, 5+(y==m+1)); } flood(8, player, crossvcs); //edge 2 - for(int x = m+1; x < e; x++) { init(x, e, 9, player, 0+(x==e-1)); } flood(9, player, crossvcs); //edge 3 - for(int x = 1; x < m; x++) { init(x, m+x, 10, player, 1+(x==m-1)); } flood(10, player, crossvcs); //edge 4 - for(int y = 1; y < m; y++) { init(0, y, 11, player, 2+(y==m-1)); } flood(11, player, crossvcs); //edge 5 - } + init(0, 0, 0, player, 3); flood(0, player, crossvcs); //corner 0 + init(m, 0, 1, player, 4); flood(1, player, crossvcs); //corner 1 + init(e, m, 2, player, 5); flood(2, player, crossvcs); //corner 2 + init(e, e, 3, player, 0); flood(3, player, crossvcs); //corner 3 + init(m, e, 4, player, 1); flood(4, player, crossvcs); //corner 4 + init(0, m, 5, player, 2); flood(5, player, crossvcs); //corner 5 + + for(int x = 1; x < m; x++) { init(x, 0, 6, player, 3+(x==1)); } flood(6, player, crossvcs); //edge 0 + for(int y = 1; y < m; y++) { init(m+y, y, 7, player, 4+(y==1)); } flood(7, player, crossvcs); //edge 1 + for(int y = m+1; y < e; y++) { init(e, y, 8, player, 5+(y==m+1)); } flood(8, player, crossvcs); //edge 2 + for(int x = m+1; x < e; x++) { init(x, e, 9, player, 0+(x==e-1)); } flood(9, player, crossvcs); //edge 3 + for(int x = 1; x < m; x++) { init(x, m+x, 10, player, 1+(x==m-1)); } flood(10, player, crossvcs); //edge 4 + for(int y = 1; y < m; y++) { init(0, y, 11, player, 2+(y==m-1)); } flood(11, player, crossvcs); //edge 5 } - void flood(int edge, int player, bool crossvcs){ - int otherplayer = 3 - player; + void flood(int edge, Side player, bool crossvcs){ + Side otherplayer = ~player; MoveDist cur; while(Q.pop(cur)){ for(int i = 5; i <= 7; i++){ int nd = (cur.dir + i) % 6; - MoveDist next(cur.pos + neighbours[nd], cur.dist, nd); + MoveDist next(board->nb_begin(cur.pos)[nd], cur.dist, nd); if(board->onboard(next.pos)){ - int pos = board->xy(next.pos); - int colour = board->get(pos); + Side colour = board->get(next.pos); if(colour == otherplayer) continue; - if(colour == 0){ + if(colour == Side::NONE){ if(!crossvcs && //forms a vc - board->get(cur.pos + neighbours[(nd - 1) % 6]) == otherplayer && - board->get(cur.pos + neighbours[(nd + 1) % 6]) == otherplayer) + board->get(board->nb_begin(cur.pos)[(nd - 1) % 6]) == otherplayer && + board->get(board->nb_begin(cur.pos)[(nd + 1) % 6]) == otherplayer) continue; next.dist++; } - if( dist(edge, player, pos) > next.dist){ - dist(edge, player, pos) = next.dist; + if( dist(edge, player, next.pos) > next.dist){ + dist(edge, player, next.pos) = next.dist; if(next.dist < board->get_size()) Q.push(next); } @@ -153,27 +157,34 @@ class LBDists { } } - int isdraw(){ - int outcome = 0; + Outcome isdraw(){ + Outcome outcome = Outcome::DRAW; // assume neither side can win for(int y = 0; y < board->get_size_d(); y++){ for(int x = board->linestart(y); x < board->lineend(y); x++){ - Move pos(x,y); + MoveValid pos(x, y, board->xy(x, y)); - if(board->encirclable(pos, 1) || get(pos, 1) < maxdist-5) - outcome |= 1; - if(board->encirclable(pos, 2) || get(pos, 2) < maxdist-5) - outcome |= 2; + if(board->encirclable(pos, Side::P1) || get(pos, Side::P1) < maxdist-5) + outcome |= Outcome::P1; // P1 can win + if(board->encirclable(pos, Side::P2) || get(pos, Side::P2) < maxdist-5) + outcome |= Outcome::P2; // P2 can win - if(outcome == 3) - return -3; + if(outcome == Outcome::DRAW2) // both can win + return Outcome::UNKNOWN; // so nothing is known } } - return -outcome; + return -outcome; // this isn't certainty, so negate + } + + std::string get_s(Move pos, Side side) { // for use by Board::to_s + int dist = (side == Side::NONE ? get(pos) : get(pos, side)); + return (dist < 10 ? to_str(dist) : "."); } - int get(Move pos){ return min(get(pos, 1), get(pos, 2)); } - int get(Move pos, int player){ return get(board->xy(pos), player); } - int get(int pos, int player){ + int get(Move pos){ return get(MoveValid(pos, board->xy(pos))); } + int get(MoveValid pos){ return std::min(get(pos, Side::P1), get(pos, Side::P2)); } + int get(Move pos, Side player) { return get(board->xy(pos), player); } + int get(MoveValid pos, Side player) { return get(pos.xy, player); } + int get(int pos, Side player){ int list[6]; for(int i = 0; i < 6; i++) list[i] = dist(i, player, pos); @@ -185,7 +196,7 @@ class LBDists { partialsort(list, 3); int edges = list[0] + list[1] + list[2] - 2; - return min(corners, edges); + return std::min(corners, edges); } //partially sort the list with selection sort @@ -205,3 +216,6 @@ class LBDists { } } }; + +}; // namespace Havannah +}; // namespace Morat diff --git a/havannah/main.cpp b/havannah/main.cpp new file mode 100644 index 0000000..094b467 --- /dev/null +++ b/havannah/main.cpp @@ -0,0 +1,63 @@ + +#include +#include + +#include "../lib/time.h" + +#include "gtp.h" + + +using namespace Morat; +using namespace Havannah; + +using namespace std; + +void die(int code, const string & str){ + printf("%s\n", str.c_str()); + exit(code); +} + +int main(int argc, char **argv){ + + srand(Time().in_usec()); + GTP gtp; + + gtp.colorboard = isatty(fileno(stdout)); + + for(int i = 1; i < argc; i++) { + string arg = argv[i]; + if(arg == "-h" || arg == "--help"){ + die(255, "Usage:\n" + "\t-h --help Show this help\n" + "\t-v --verbose Give more output over gtp\n" + "\t-n --nocolor Don't output the board in color\n" + "\t-c --cmd Pass a gtp command from the command line\n" + "\t-f --file Run this gtp file before reading from stdin\n" + ); + }else if(arg == "-v" || arg == "--verbose"){ + gtp.verbose = true; + }else if(arg == "-n" || arg == "--nocolor"){ + gtp.colorboard = false; + }else if(arg == "-c" || arg == "--cmd"){ + char * ptr = argv[++i]; + if(ptr == NULL) die(255, "Missing a command"); + gtp.cmd(ptr); + }else if(arg == "-f" || arg == "--file"){ + char * ptr = argv[++i]; + if(ptr == NULL) die(255, "Missing a file to run"); + FILE * fd = fopen(ptr, "r"); + gtp.setinfile(fd); + gtp.setoutfile(NULL); + if(!gtp.run()) + return 0; + fclose(fd); + }else{ + die(255, "Unknown argument: " + arg + ", try --help"); + } + } + + gtp.setinfile(stdin); + gtp.setoutfile(stdout); + gtp.run(); + return 0; +} diff --git a/havannah/move.h b/havannah/move.h deleted file mode 100644 index 84cf035..0000000 --- a/havannah/move.h +++ /dev/null @@ -1,91 +0,0 @@ - -#pragma once - -#include -#include - -#include "../lib/string.h" - -enum MoveSpecial { - M_SWAP = -1, //-1 so that adding 1 makes it into a valid move - M_RESIGN = -2, - M_NONE = -3, - M_UNKNOWN = -4, -}; - -struct Move { - int8_t y, x; - - Move(MoveSpecial a = M_UNKNOWN) : y(a), x(120) { } //big x so it will always wrap to y=0 with swap - Move(int X, int Y) : y(Y), x(X) { } - - Move(const std::string & str){ - if( str == "swap" ){ y = M_SWAP; x = 120; } - else if(str == "resign" ){ y = M_RESIGN; x = 120; } - else if(str == "none" ){ y = M_NONE; x = 120; } - else if(str == "unknown"){ y = M_UNKNOWN; x = 120; } - else{ - y = tolower(str[0]) - 'a'; - x = atoi(str.c_str() + 1) - 1; - } - } - - std::string to_s() const { - if(y == M_UNKNOWN) return "unknown"; - if(y == M_NONE) return "none"; - if(y == M_SWAP) return "swap"; - if(y == M_RESIGN) return "resign"; - - return std::string() + char(y + 'a') + to_str(x + 1); - } - - bool operator< (const Move & b) const { return (y == b.y ? x < b.x : y < b.y); } - bool operator<=(const Move & b) const { return (y == b.y ? x <= b.x : y <= b.y); } - bool operator> (const Move & b) const { return (y == b.y ? x > b.x : y > b.y); } - bool operator>=(const Move & b) const { return (y == b.y ? x >= b.x : y >= b.y); } - bool operator==(const MoveSpecial & b) const { return (y == b); } - bool operator==(const Move & b) const { return (y == b.y && x == b.x); } - bool operator!=(const Move & b) const { return (y != b.y || x != b.x); } - bool operator!=(const MoveSpecial & b) const { return (y != b); } - Move operator+ (const Move & b) const { return Move(x + b.x, y + b.y); } - Move & operator+=(const Move & b) { y += b.y; x += b.x; return *this; } - Move operator- (const Move & b) const { return Move(x - b.x, y - b.y); } - Move & operator-=(const Move & b) { y -= b.y; x -= b.x; return *this; } - - int z() const { return (x - y); } - int dist(const Move & b) const { - return (abs(x - b.x) + abs(y - b.y) + abs(z() - b.z()))/2; - } -}; - -struct MoveScore : public Move { - int16_t score; - - MoveScore() : score(0) { } - MoveScore(MoveSpecial a) : Move(a), score(0) { } - MoveScore(int X, int Y, int s) : Move(X, Y), score(s) { } - MoveScore operator+ (const Move & b) const { return MoveScore(x + b.x, y + b.y, score); } -}; - -struct MoveValid : public Move { - int16_t xy; - - MoveValid() : Move(), xy(-1) { } - MoveValid(int x, int y, int XY) : Move(x,y), xy(XY) { } - MoveValid(const Move & m, int XY) : Move(m), xy(XY) { } - bool onboard() const { return xy != -1; } -}; - -struct MovePlayer : public Move { - char player; - - MovePlayer() : Move(M_UNKNOWN), player(0) { } - MovePlayer(const Move & m, char p = 0) : Move(m), player(p) { } -}; - - -struct PairMove { - Move a, b; - PairMove(Move A = M_UNKNOWN, Move B = M_UNKNOWN) : a(A), b(B) { } - PairMove(MoveSpecial A) : a(Move(A)), b(M_UNKNOWN) { } -}; diff --git a/havannah/movelist.h b/havannah/movelist.h deleted file mode 100644 index 27c22de..0000000 --- a/havannah/movelist.h +++ /dev/null @@ -1,76 +0,0 @@ - -#pragma once - -#include "../lib/exppair.h" - -#include "board.h" -#include "move.h" - -struct MoveList { - ExpPair exp[2]; //aggregated outcomes overall - ExpPair rave[2][Board::max_vecsize]; //aggregated outcomes per move - MovePlayer moves[Board::max_vecsize]; //moves made in order - int tree; //number of moves in the tree - int rollout; //number of moves in the rollout - Board * board; //reference to rootboard for xy() - - MoveList() : tree(0), rollout(0), board(NULL) { } - - void addtree(const Move & move, char player){ - moves[tree++] = MovePlayer(move, player); - } - void addrollout(const Move & move, char player){ - moves[tree + rollout++] = MovePlayer(move, player); - } - void reset(Board * b){ - tree = 0; - rollout = 0; - board = b; - exp[0].clear(); - exp[1].clear(); - for(int i = 0; i < b->vecsize(); i++){ - rave[0][i].clear(); - rave[1][i].clear(); - } - } - void finishrollout(int won){ - exp[0].addloss(); - exp[1].addloss(); - if(won == 0){ - exp[0].addtie(); - exp[1].addtie(); - }else{ - exp[won-1].addwin(); - - for(MovePlayer * i = begin(), * e = end(); i != e; i++){ - ExpPair & r = rave[i->player-1][board->xy(*i)]; - r.addloss(); - if(i->player == won) - r.addwin(); - } - } - rollout = 0; - } - const MovePlayer * begin() const { - return moves; - } - MovePlayer * begin() { - return moves; - } - const MovePlayer * end() const { - return moves + tree + rollout; - } - MovePlayer * end() { - return moves + tree + rollout; - } - void subvlosses(int n){ - exp[0].addlosses(-n); - exp[1].addlosses(-n); - } - const ExpPair & getrave(int player, const Move & move) const { - return rave[player-1][board->xy(move)]; - } - const ExpPair & getexp(int player) const { - return exp[player-1]; - } -}; diff --git a/havannah/player.cpp b/havannah/player.cpp deleted file mode 100644 index 8276834..0000000 --- a/havannah/player.cpp +++ /dev/null @@ -1,509 +0,0 @@ - -#include -#include - -#include "../lib/alarm.h" -#include "../lib/fileio.h" -#include "../lib/string.h" -#include "../lib/time.h" - -#include "board.h" -#include "player.h" - -const float Player::min_rave = 0.1; - -void Player::PlayerThread::run(){ - while(true){ - switch(player->threadstate){ - case Thread_Cancelled: //threads should exit - return; - - case Thread_Wait_Start: //threads are waiting to start - case Thread_Wait_Start_Cancelled: - player->runbarrier.wait(); - CAS(player->threadstate, Thread_Wait_Start, Thread_Running); - CAS(player->threadstate, Thread_Wait_Start_Cancelled, Thread_Cancelled); - break; - - case Thread_Wait_End: //threads are waiting to end - player->runbarrier.wait(); - CAS(player->threadstate, Thread_Wait_End, Thread_Wait_Start); - break; - - case Thread_Running: //threads are running - if(player->rootboard.won() >= 0 || player->root.outcome >= 0 || (player->maxruns > 0 && player->runs >= player->maxruns)){ //solved or finished runs - if(CAS(player->threadstate, Thread_Running, Thread_Wait_End) && player->root.outcome >= 0) - logerr("Solved as " + to_str((int)player->root.outcome) + "\n"); - break; - } - if(player->ctmem.memalloced() >= player->maxmem){ //out of memory, start garbage collection - CAS(player->threadstate, Thread_Running, Thread_GC); - break; - } - - INCR(player->runs); - iterate(); - break; - - case Thread_GC: //one thread is running garbage collection, the rest are waiting - case Thread_GC_End: //once done garbage collecting, go to wait_end instead of back to running - if(player->gcbarrier.wait()){ - Time starttime; - logerr("Starting player GC with limit " + to_str(player->gclimit) + " ... "); - uint64_t nodesbefore = player->nodes; - Board copy = player->rootboard; - player->garbage_collect(copy, & player->root); - Time gctime; - player->ctmem.compact(1.0, 0.75); - Time compacttime; - logerr(to_str(100.0*player->nodes/nodesbefore, 1) + " % of tree remains - " + - to_str((gctime - starttime)*1000, 0) + " msec gc, " + to_str((compacttime - gctime)*1000, 0) + " msec compact\n"); - - if(player->ctmem.meminuse() >= player->maxmem/2) - player->gclimit = (int)(player->gclimit*1.3); - else if(player->gclimit > player->rollouts*5) - player->gclimit = (int)(player->gclimit*0.9); //slowly decay to a minimum of 5 - - CAS(player->threadstate, Thread_GC, Thread_Running); - CAS(player->threadstate, Thread_GC_End, Thread_Wait_End); - } - player->gcbarrier.wait(); - break; - } - } -} - -Player::Node * Player::genmove(double time, int max_runs, bool flexible){ - time_used = 0; - int toplay = rootboard.toplay(); - - if(rootboard.won() >= 0 || (time <= 0 && max_runs == 0)) - return NULL; - - Time starttime; - - stop_threads(); - - if(runs) - logerr("Pondered " + to_str(runs) + " runs\n"); - - runs = 0; - maxruns = max_runs; - for(unsigned int i = 0; i < threads.size(); i++) - threads[i]->reset(); - - // if the move is forced and the time can be added to the clock, don't bother running at all - if(!flexible || root.children.num() != 1){ - //let them run! - start_threads(); - - Alarm timer; - if(time > 0) - timer(time - (Time() - starttime), std::bind(&Player::timedout, this)); - - //wait for the timer to stop them - runbarrier.wait(); - CAS(threadstate, Thread_Wait_End, Thread_Wait_Start); - assert(threadstate == Thread_Wait_Start); - } - - if(ponder && root.outcome < 0) - start_threads(); - - time_used = Time() - starttime; - -//return the best one - return return_move(& root, toplay); -} - - - -Player::Player() { - nodes = 0; - gclimit = 5; - time_used = 0; - - profile = false; - ponder = false; -//#ifdef SINGLE_THREAD ... make sure only 1 thread - numthreads = 1; - maxmem = 1000*1024*1024; - - msrave = -2; - msexplore = 0; - - explore = 0; - parentexplore = false; - ravefactor = 500; - decrrave = 200; - knowledge = true; - userave = 1; - useexplore = 1; - fpurgency = 1; - rollouts = 1; - dynwiden = 0; - logdynwiden = (dynwiden ? std::log(dynwiden) : 0); - - shortrave = false; - keeptree = true; - minimax = 2; - detectdraw = false; - visitexpand = 1; - prunesymmetry = false; - gcsolved = 100000; - - localreply = 0; - locality = 0; - connect = 20; - size = 0; - bridge = 25; - dists = 0; - - weightedrandom = false; - checkringdepth = 1000; - ringperm = 0; - rolloutpattern = false; - lastgoodreply = false; - instantwin = 0; - - for(int i = 0; i < 4096; i++) - gammas[i] = 1; - - //no threads started until a board is set - threadstate = Thread_Wait_Start; -} -Player::~Player(){ - stop_threads(); - - numthreads = 0; - reset_threads(); //shut down the theads properly - - root.dealloc(ctmem); - ctmem.compact(); -} -void Player::timedout() { - CAS(threadstate, Thread_Running, Thread_Wait_End); - CAS(threadstate, Thread_GC, Thread_GC_End); -} - -string Player::statestring(){ - switch(threadstate){ - case Thread_Cancelled: return "Thread_Wait_Cancelled"; - case Thread_Wait_Start: return "Thread_Wait_Start"; - case Thread_Wait_Start_Cancelled: return "Thread_Wait_Start_Cancelled"; - case Thread_Running: return "Thread_Running"; - case Thread_GC: return "Thread_GC"; - case Thread_GC_End: return "Thread_GC_End"; - case Thread_Wait_End: return "Thread_Wait_End"; - } - return "Thread_State_Unknown!!!"; -} - -void Player::stop_threads(){ - if(threadstate != Thread_Wait_Start){ - timedout(); - runbarrier.wait(); - CAS(threadstate, Thread_Wait_End, Thread_Wait_Start); - assert(threadstate == Thread_Wait_Start); - } -} - -void Player::start_threads(){ - assert(threadstate == Thread_Wait_Start); - runbarrier.wait(); - CAS(threadstate, Thread_Wait_Start, Thread_Running); -} - -void Player::reset_threads(){ //start and end with threadstate = Thread_Wait_Start - assert(threadstate == Thread_Wait_Start); - -//wait for them to all get to the barrier - assert(CAS(threadstate, Thread_Wait_Start, Thread_Wait_Start_Cancelled)); - runbarrier.wait(); - -//make sure they exited cleanly - for(unsigned int i = 0; i < threads.size(); i++){ - threads[i]->join(); - delete threads[i]; - } - - threads.clear(); - - threadstate = Thread_Wait_Start; - - runbarrier.reset(numthreads + 1); - gcbarrier.reset(numthreads); - -//start new threads - for(int i = 0; i < numthreads; i++) - threads.push_back(new PlayerUCT(this)); -} - -void Player::set_ponder(bool p){ - if(ponder != p){ - ponder = p; - stop_threads(); - - if(ponder) - start_threads(); - } -} - -void Player::set_board(const Board & board){ - stop_threads(); - - nodes -= root.dealloc(ctmem); - root = Node(); - root.exp.addwins(visitexpand+1); - - rootboard = board; - - reset_threads(); //needed since the threads aren't started before a board it set - - if(ponder) - start_threads(); -} -void Player::move(const Move & m){ - stop_threads(); - - uword nodesbefore = nodes; - - if(keeptree && root.children.num() > 0){ - Node child; - - for(Node * i = root.children.begin(); i != root.children.end(); i++){ - if(i->move == m){ - child = *i; //copy the child experience to temp - child.swap_tree(*i); //move the child tree to temp - break; - } - } - - nodes -= root.dealloc(ctmem); - root = child; - root.swap_tree(child); - - if(nodesbefore > 0) - logerr("Nodes before: " + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); - }else{ - nodes -= root.dealloc(ctmem); - root = Node(); - root.move = m; - } - assert(nodes == root.size()); - - rootboard.move(m); - - root.exp.addwins(visitexpand+1); //+1 to compensate for the virtual loss - if(rootboard.won() < 0) - root.outcome = -3; - - if(ponder) - start_threads(); -} - -double Player::gamelen(){ - DepthStats len; - for(unsigned int i = 0; i < threads.size(); i++) - len += threads[i]->gamelen; - return len.avg(); -} - -vector Player::get_pv(){ - vector pv; - - Node * r, * n = & root; - char turn = rootboard.toplay(); - while(!n->children.empty()){ - r = return_move(n, turn); - if(!r) break; - pv.push_back(r->move); - turn = 3 - turn; - n = r; - } - - if(pv.size() == 0) - pv.push_back(Move(M_RESIGN)); - - return pv; -} - -Player::Node * Player::return_move(Node * node, int toplay) const { - double val, maxval = -1000000000000.0; //1 trillion - - Node * ret = NULL, - * child = node->children.begin(), - * end = node->children.end(); - - for( ; child != end; child++){ - if(child->outcome >= 0){ - if(child->outcome == toplay) val = 800000000000.0 - child->exp.num(); //shortest win - else if(child->outcome == 0) val = -400000000000.0 + child->exp.num(); //longest tie - else val = -800000000000.0 + child->exp.num(); //longest loss - }else{ //not proven - if(msrave == -1) //num simulations - val = child->exp.num(); - else if(msrave == -2) //num wins - val = child->exp.sum(); - else - val = child->value(msrave, 0, 0) - msexplore*sqrt(log(node->exp.num())/(child->exp.num() + 1)); - } - - if(maxval < val){ - maxval = val; - ret = child; - } - } - -//set bestmove, but don't touch outcome, if it's solved that will already be set, otherwise it shouldn't be set - if(ret){ - node->bestmove = ret->move; - }else if(node->bestmove == M_UNKNOWN){ - // TODO: Is this needed? -// SolverAB solver; -// solver.set_board(rootboard); -// solver.solve(0.1); -// node->bestmove = solver.bestmove; - } - - assert(node->bestmove != M_UNKNOWN); - - return ret; -} - -void Player::garbage_collect(Board & board, Node * node){ - Node * child = node->children.begin(), - * end = node->children.end(); - - int toplay = board.toplay(); - for( ; child != end; child++){ - if(child->children.num() == 0) - continue; - - if( (node->outcome >= 0 && child->exp.num() > gcsolved && (node->outcome != toplay || child->outcome == toplay || child->outcome == 0)) || //parent is solved, only keep the proof tree, plus heavy draws - (node->outcome < 0 && child->exp.num() > (child->outcome >= 0 ? gcsolved : gclimit)) ){ // only keep heavy nodes, with different cutoffs for solved and unsolved - board.set(child->move); - garbage_collect(board, child); - board.unset(child->move); - }else{ - nodes -= child->dealloc(ctmem); - } - } -} - -Player::Node * Player::find_child(Node * node, const Move & move){ - for(Node * i = node->children.begin(); i != node->children.end(); i++) - if(i->move == move) - return i; - - return NULL; -} - -void Player::gen_hgf(Board & board, Node * node, unsigned int limit, unsigned int depth, FILE * fd){ - string s = string("\n") + string(depth, ' ') + "(;" + (board.toplay() == 2 ? "W" : "B") + "[" + node->move.to_s() + "]" + - "C[mcts, sims:" + to_str(node->exp.num()) + ", avg:" + to_str(node->exp.avg(), 4) + ", outcome:" + to_str((int)(node->outcome)) + ", best:" + node->bestmove.to_s() + "]"; - fprintf(fd, "%s", s.c_str()); - - Node * child = node->children.begin(), - * end = node->children.end(); - - int toplay = board.toplay(); - - bool children = false; - for( ; child != end; child++){ - if(child->exp.num() >= limit && (toplay != node->outcome || child->outcome == node->outcome) ){ - board.set(child->move); - gen_hgf(board, child, limit, depth+1, fd); - board.unset(child->move); - children = true; - } - } - - if(children) - fprintf(fd, "\n%s", string(depth, ' ').c_str()); - fprintf(fd, ")"); -} - -void Player::create_children_simple(const Board & board, Node * node){ - assert(node->children.empty()); - - node->children.alloc(board.movesremain(), ctmem); - - Node * child = node->children.begin(), - * end = node->children.end(); - Board::MoveIterator moveit = board.moveit(prunesymmetry); - int nummoves = 0; - for(; !moveit.done() && child != end; ++moveit, ++child){ - *child = Node(*moveit); - nummoves++; - } - - if(prunesymmetry) - node->children.shrink(nummoves); //shrink the node to ignore the extra moves - else //both end conditions should happen in parallel - assert(moveit.done() && child == end); - - PLUS(nodes, node->children.num()); -} - -//reads the format from gen_hgf. -void Player::load_hgf(Board board, Node * node, FILE * fd){ - char c, buf[101]; - - eat_whitespace(fd); - - assert(fscanf(fd, "(;%c[%100[^]]]", &c, buf) > 0); - - assert(board.toplay() == (c == 'W' ? 1 : 2)); - node->move = Move(buf); - board.move(node->move); - - assert(fscanf(fd, "C[%100[^]]]", buf) > 0); - - vecstr entry, parts = explode(string(buf), ", "); - assert(parts[0] == "mcts"); - - entry = explode(parts[1], ":"); - assert(entry[0] == "sims"); - uword sims = from_str(entry[1]); - - entry = explode(parts[2], ":"); - assert(entry[0] == "avg"); - double avg = from_str(entry[1]); - - uword wins = sims*avg; - node->exp.addwins(wins); - node->exp.addlosses(sims - wins); - - entry = explode(parts[3], ":"); - assert(entry[0] == "outcome"); - node->outcome = from_str(entry[1]); - - entry = explode(parts[4], ":"); - assert(entry[0] == "best"); - node->bestmove = Move(entry[1]); - - - eat_whitespace(fd); - - if(fpeek(fd) != ')'){ - create_children_simple(board, node); - - while(fpeek(fd) != ')'){ - Node child; - load_hgf(board, & child, fd); - - Node * i = find_child(node, child.move); - *i = child; //copy the child experience to the tree - i->swap_tree(child); //move the child subtree to the tree - - assert(child.children.empty()); - - eat_whitespace(fd); - } - } - - eat_char(fd, ')'); - - return; -} diff --git a/havannah/player.h b/havannah/player.h deleted file mode 100644 index eaa9b9c..0000000 --- a/havannah/player.h +++ /dev/null @@ -1,313 +0,0 @@ - -#pragma once - -//A Monte-Carlo Tree Search based player - -#include -#include - -#include "../lib/compacttree.h" -#include "../lib/depthstats.h" -#include "../lib/exppair.h" -#include "../lib/log.h" -#include "../lib/thread.h" -#include "../lib/time.h" -#include "../lib/types.h" -#include "../lib/xorshift.h" - -#include "board.h" -#include "lbdist.h" -#include "move.h" -#include "movelist.h" -#include "policy_bridge.h" -#include "policy_instantwin.h" -#include "policy_lastgoodreply.h" -#include "policy_random.h" - - -class Player { -public: - - struct Node { - public: - ExpPair rave; - ExpPair exp; - int16_t know; - int8_t outcome; - uint8_t proofdepth; - Move move; - Move bestmove; //if outcome is set, then bestmove is the way to get there - CompactTree::Children children; -// int padding; - //seems to need padding to multiples of 8 bytes or it segfaults? - //don't forget to update the copy constructor/operator - - Node() : know(0), outcome(-3), proofdepth(0) { } - Node(const Move & m, char o = -3) : know(0), outcome( o), proofdepth(0), move(m) { } - Node(const Node & n) { *this = n; } - Node & operator = (const Node & n){ - if(this != & n){ //don't copy to self - //don't copy to a node that already has children - assert(children.empty()); - - rave = n.rave; - exp = n.exp; - know = n.know; - move = n.move; - bestmove = n.bestmove; - outcome = n.outcome; - proofdepth = n.proofdepth; - //children = n.children; ignore the children, they need to be swap_tree'd in - } - return *this; - } - - void swap_tree(Node & n){ - children.swap(n.children); - } - - void print() const { - printf("%s\n", to_s().c_str()); - } - string to_s() const { - return "Node: move " + move.to_s() + - ", exp " + to_str(exp.avg(), 2) + "/" + to_str(exp.num()) + - ", rave " + to_str(rave.avg(), 2) + "/" + to_str(rave.num()) + - ", know " + to_str(know) + - ", outcome " + to_str(outcome) + "/" + to_str(proofdepth) + - ", best " + bestmove.to_s() + - ", children " + to_str(children.num()); - } - - unsigned int size() const { - unsigned int num = children.num(); - - if(children.num()) - for(Node * i = children.begin(); i != children.end(); i++) - num += i->size(); - - return num; - } - - ~Node(){ - assert(children.empty()); - } - - unsigned int alloc(unsigned int num, CompactTree & ct){ - return children.alloc(num, ct); - } - unsigned int dealloc(CompactTree & ct){ - unsigned int num = 0; - - if(children.num()) - for(Node * i = children.begin(); i != children.end(); i++) - num += i->dealloc(ct); - num += children.dealloc(ct); - - return num; - } - - //new way, more standard way of changing over from rave scores to real scores - float value(float ravefactor, bool knowledge, float fpurgency){ - float val = fpurgency; - float expnum = exp.num(); - float ravenum = rave.num(); - - if(ravefactor <= min_rave){ - if(expnum > 0) - val = exp.avg(); - }else if(ravenum > 0 || expnum > 0){ - float alpha = ravefactor/(ravefactor + expnum); -// float alpha = sqrt(ravefactor/(ravefactor + 3.0f*expnum)); -// float alpha = ravenum/(expnum + ravenum + expnum*ravenum*ravefactor); - - val = 0; - if(ravenum > 0) val += alpha*rave.avg(); - if(expnum > 0) val += (1.0f-alpha)*exp.avg(); - } - - if(knowledge && know > 0){ - if(expnum <= 1) - val += 0.01f * know; - else if(expnum < 1000) //knowledge is only useful with little experience - val += 0.01f * know / sqrt(expnum); - } - - return val; - } - }; - - class PlayerThread { - protected: - public: - mutable XORShift_uint32 rand32; - mutable XORShift_float unitrand; - Thread thread; - Player * player; - public: - DepthStats treelen, gamelen; - DepthStats wintypes[2][4]; //player,wintype - double times[4]; //time spent in each of the stages - - PlayerThread() {} - virtual ~PlayerThread() { } - virtual void reset() { } - int join(){ return thread.join(); } - void run(); //thread runner, calls iterate on each iteration - virtual void iterate() { } //handles each iteration - }; - - class PlayerUCT : public PlayerThread { - LastGoodReply last_good_reply; - RandomPolicy random_policy; - ProtectBridge protect_bridge; - InstantWin instant_wins; - - bool use_rave; //whether to use rave for this simulation - bool use_explore; //whether to use exploration for this simulation - LBDists dists; //holds the distances to the various non-ring wins as a heuristic for the minimum moves needed to win - MoveList movelist; - int stage; //which of the four MCTS stages is it on - Time timestamps[4]; //timestamps for the beginning, before child creation, before rollout, after rollout - - public: - PlayerUCT(Player * p) : PlayerThread() { - player = p; - reset(); - thread(bind(&PlayerUCT::run, this)); - } - - void reset(){ - treelen.reset(); - gamelen.reset(); - - use_rave = false; - use_explore = false; - - for(int a = 0; a < 2; a++) - for(int b = 0; b < 4; b++) - wintypes[a][b].reset(); - - for(int a = 0; a < 4; a++) - times[a] = 0; - } - - private: - void iterate(); - void walk_tree(Board & board, Node * node, int depth); - bool create_children(Board & board, Node * node, int toplay); - void add_knowledge(Board & board, Node * node, Node * child); - Node * choose_move(const Node * node, int toplay, int remain) const; - void update_rave(const Node * node, int toplay); - bool test_bridge_probe(const Board & board, const Move & move, const Move & test) const; - - int rollout(Board & board, Move move, int depth); - Move rollout_choose_move(Board & board, const Move & prev); - Move rollout_pattern(const Board & board, const Move & move); - }; - - -public: - - static const float min_rave; - - bool ponder; //think during opponents time? - int numthreads; //number of player threads to run - u64 maxmem; //maximum memory for the tree in bytes - bool profile; //count how long is spent in each stage of MCTS -//final move selection - float msrave; //rave factor in final move selection, -1 means use number instead of value - float msexplore; //the UCT constant in final move selection -//tree traversal - bool parentexplore; // whether to multiple exploration by the parents winrate - float explore; //greater than one favours exploration, smaller than one favours exploitation - float ravefactor; //big numbers favour rave scores, small ignore it - float decrrave; //decrease rave over time, add this value for each empty position on the board - bool knowledge; //whether to include knowledge - float userave; //what probability to use rave - float useexplore; //what probability to use UCT exploration - float fpurgency; //what value to return for a move that hasn't been played yet - int rollouts; //number of rollouts to run after the tree traversal - float dynwiden; //dynamic widening, look at first log_dynwiden(experience) number of children, 0 to disable - float logdynwiden; // = log(dynwiden), cached for performance -//tree building - bool shortrave; //only update rave values on short rollouts - bool keeptree; //reuse the tree from the previous move - int minimax; //solve the minimax tree within the uct tree - bool detectdraw; //look for draws early, slow - uint visitexpand;//number of visits before expanding a node - bool prunesymmetry; //prune symmetric children from the move list, useful for proving but likely not for playing - uint gcsolved; //garbage collect solved nodes or keep them in the tree, assuming they meet the required amount of work -//knowledge - int localreply; //boost for a local reply, ie a move near the previous move - int locality; //boost for playing near previous stones - int connect; //boost for having connections to edges and corners - int size; //boost for large groups - int bridge; //boost replying to a probe at a bridge - int dists; //boost based on minimum number of stones needed to finish a non-ring win -//rollout - int weightedrandom; //use weighted random for move ordering based on gammas - float checkringdepth; //how deep to allow rings as a win condition in a rollout - int ringperm; //how many stones in a ring must be in place before the rollout begins - bool rolloutpattern; //play the response to a virtual connection threat in rollouts - int lastgoodreply; //use the last-good-reply rollout heuristic - int instantwin; //how deep to look for instant wins in rollouts - - float gammas[4096]; //pattern weights for weighted random - - Board rootboard; - Node root; - uword nodes; - int gclimit; //the minimum experience needed to not be garbage collected - - uint64_t runs, maxruns; - - CompactTree ctmem; - - enum ThreadState { - Thread_Cancelled, //threads should exit - Thread_Wait_Start, //threads are waiting to start - Thread_Wait_Start_Cancelled, //once done waiting, go to cancelled instead of running - Thread_Running, //threads are running - Thread_GC, //one thread is running garbage collection, the rest are waiting - Thread_GC_End, //once done garbage collecting, go to wait_end instead of back to running - Thread_Wait_End, //threads are waiting to end - }; - volatile ThreadState threadstate; - vector threads; - Barrier runbarrier, gcbarrier; - - double time_used; - - Player(); - ~Player(); - - void timedout(); - - string statestring(); - - void stop_threads(); - void start_threads(); - void reset_threads(); - - void set_ponder(bool p); - void set_board(const Board & board); - - void move(const Move & m); - - double gamelen(); - - Node * genmove(double time, int max_runs, bool flexible); - vector get_pv(); - void garbage_collect(Board & board, Node * node); //destroys the board, so pass in a copy - - bool do_backup(Node * node, Node * backup, int toplay); - - Node * find_child(Node * node, const Move & move); - void create_children_simple(const Board & board, Node * node); - void gen_hgf(Board & board, Node * node, unsigned int limit, unsigned int depth, FILE * fd); - void load_hgf(Board board, Node * node, FILE * fd); - -protected: - Node * return_move(Node * node, int toplay) const; -}; diff --git a/havannah/playeruct.cpp b/havannah/playeruct.cpp deleted file mode 100644 index 57b3e99..0000000 --- a/havannah/playeruct.cpp +++ /dev/null @@ -1,473 +0,0 @@ - -#include -#include - -#include "../lib/string.h" - -#include "player.h" - -void Player::PlayerUCT::iterate(){ - if(player->profile){ - timestamps[0] = Time(); - stage = 0; - } - - movelist.reset(&(player->rootboard)); - player->root.exp.addvloss(); - Board copy = player->rootboard; - use_rave = (unitrand() < player->userave); - use_explore = (unitrand() < player->useexplore); - walk_tree(copy, & player->root, 0); - player->root.exp.addv(movelist.getexp(3-player->rootboard.toplay())); - - if(player->profile){ - times[0] += timestamps[1] - timestamps[0]; - times[1] += timestamps[2] - timestamps[1]; - times[2] += timestamps[3] - timestamps[2]; - times[3] += Time() - timestamps[3]; - } -} - -void Player::PlayerUCT::walk_tree(Board & board, Node * node, int depth){ - int toplay = board.toplay(); - - if(!node->children.empty() && node->outcome < 0){ - //choose a child and recurse - Node * child; - do{ - int remain = board.movesremain(); - child = choose_move(node, toplay, remain); - - if(child->outcome < 0){ - movelist.addtree(child->move, toplay); - - if(!board.move(child->move)){ - logerr("move failed: " + child->move.to_s() + "\n" + board.to_s(false)); - assert(false && "move failed"); - } - - child->exp.addvloss(); //balanced out after rollouts - - walk_tree(board, child, depth+1); - - child->exp.addv(movelist.getexp(toplay)); - - if(!player->do_backup(node, child, toplay) && //not solved - player->ravefactor > min_rave && //using rave - node->children.num() > 1 && //not a macro move - 50*remain*(player->ravefactor + player->decrrave*remain) > node->exp.num()) //rave is still significant - update_rave(node, toplay); - - return; - } - }while(!player->do_backup(node, child, toplay)); - - return; - } - - if(player->profile && stage == 0){ - stage = 1; - timestamps[1] = Time(); - } - - int won = (player->minimax ? node->outcome : board.won()); - - //if it's not already decided - if(won < 0){ - //create children if valid - if(node->exp.num() >= player->visitexpand+1 && create_children(board, node, toplay)){ - walk_tree(board, node, depth); - return; - } - - if(player->profile){ - stage = 2; - timestamps[2] = Time(); - } - - //do random game on this node - random_policy.prepare(board); - for(int i = 0; i < player->rollouts; i++){ - Board copy = board; - rollout(copy, node->move, depth); - } - }else{ - movelist.finishrollout(won); //got to a terminal state, it's worth recording - } - - treelen.add(depth); - - movelist.subvlosses(1); - - if(player->profile){ - timestamps[3] = Time(); - if(stage == 1) - timestamps[2] = timestamps[3]; - stage = 3; - } - - return; -} - -bool sort_node_know(const Player::Node & a, const Player::Node & b){ - return (a.know > b.know); -} - -bool Player::PlayerUCT::create_children(Board & board, Node * node, int toplay){ - if(!node->children.lock()) - return false; - - if(player->dists || player->detectdraw){ - dists.run(&board, (player->dists > 0), (player->detectdraw ? 0 : toplay)); - - if(player->detectdraw){ -// assert(node->outcome == -3); - node->outcome = dists.isdraw(); //could be winnable by only one side - - if(node->outcome == 0){ //proven draw, neither side can influence the outcome - node->bestmove = *(board.moveit()); //just choose the first move since all are equal at this point - node->children.unlock(); - return true; - } - } - } - - CompactTree::Children temp; - temp.alloc(board.movesremain(), player->ctmem); - - int losses = 0; - - Node * child = temp.begin(), - * end = temp.end(), - * loss = NULL; - Board::MoveIterator move = board.moveit(player->prunesymmetry); - int nummoves = 0; - for(; !move.done() && child != end; ++move, ++child){ - *child = Node(*move); - - if(player->minimax){ - child->outcome = board.test_win(*move); - - if(player->minimax >= 2 && board.test_win(*move, 3 - board.toplay()) > 0){ - losses++; - loss = child; - } - - if(child->outcome == toplay){ //proven win from here, don't need children - node->outcome = child->outcome; - node->proofdepth = 1; - node->bestmove = *move; - node->children.unlock(); - temp.dealloc(player->ctmem); - return true; - } - } - - if(player->knowledge) - add_knowledge(board, node, child); - nummoves++; - } - - if(player->prunesymmetry) - temp.shrink(nummoves); //shrink the node to ignore the extra moves - else //both end conditions should happen in parallel - assert(move.done() && child == end); - - //Make a macro move, add experience to the move so the current simulation continues past this move - if(losses == 1){ - Node macro = *loss; - temp.dealloc(player->ctmem); - temp.alloc(1, player->ctmem); - macro.exp.addwins(player->visitexpand); - *(temp.begin()) = macro; - }else if(losses >= 2){ //proven loss, but at least try to block one of them - node->outcome = 3 - toplay; - node->proofdepth = 2; - node->bestmove = loss->move; - node->children.unlock(); - temp.dealloc(player->ctmem); - return true; - } - - if(player->dynwiden > 0) //sort in decreasing order by knowledge - sort(temp.begin(), temp.end(), sort_node_know); - - PLUS(player->nodes, temp.num()); - node->children.swap(temp); - assert(temp.unlock()); - - return true; -} - -Player::Node * Player::PlayerUCT::choose_move(const Node * node, int toplay, int remain) const { - float val, maxval = -1000000000; - float logvisits = log(node->exp.num()); - int dynwidenlim = (player->dynwiden > 0 ? (int)(logvisits/player->logdynwiden)+2 : 361); - - float raveval = use_rave * (player->ravefactor + player->decrrave*remain); - float explore = use_explore * player->explore; - if(player->parentexplore) - explore *= node->exp.avg(); - - Node * ret = NULL, - * child = node->children.begin(), - * end = node->children.end(); - - for(; child != end && dynwidenlim >= 0; child++){ - if(child->outcome >= 0){ - if(child->outcome == toplay) //return a win immediately - return child; - - val = (child->outcome == 0 ? -1 : -2); //-1 for tie so any unknown is better, -2 for loss so it's even worse - }else{ - val = child->value(raveval, player->knowledge, player->fpurgency); - if(explore > 0) - val += explore*sqrt(logvisits/(child->exp.num() + 1)); - dynwidenlim--; - } - - if(maxval < val){ - maxval = val; - ret = child; - } - } - - return ret; -} - -/* -backup in this order: - -6 win -5 win/draw -4 draw if draw/loss -3 win/draw/loss -2 draw -1 draw/loss -0 lose -return true if fully solved, false if it's unknown or partially unknown -*/ -bool Player::do_backup(Node * node, Node * backup, int toplay){ - int nodeoutcome = node->outcome; - if(nodeoutcome >= 0) //already proven, probably by a different thread - return true; - - if(backup->outcome == -3) //nothing proven by this child, so no chance - return false; - - - uint8_t proofdepth = backup->proofdepth; - if(backup->outcome != toplay){ - uint64_t sims = 0, bestsims = 0, outcome = 0, bestoutcome = 0; - backup = NULL; - - Node * child = node->children.begin(), - * end = node->children.end(); - - for( ; child != end; child++){ - int childoutcome = child->outcome; //save a copy to avoid race conditions - - if(proofdepth < child->proofdepth+1) - proofdepth = child->proofdepth+1; - - //these should be sorted in likelyness of matching, most likely first - if(childoutcome == -3){ // win/draw/loss - outcome = 3; - }else if(childoutcome == toplay){ //win - backup = child; - outcome = 6; - proofdepth = child->proofdepth+1; - break; - }else if(childoutcome == 3-toplay){ //loss - outcome = 0; - }else if(childoutcome == 0){ //draw - if(nodeoutcome == toplay-3) //draw/loss - outcome = 4; - else - outcome = 2; - }else if(childoutcome == -toplay){ //win/draw - outcome = 5; - }else if(childoutcome == toplay-3){ //draw/loss - outcome = 1; - }else{ - logerr("childoutcome == " + to_str(childoutcome) + "\n"); - assert(false && "How'd I get here? All outcomes should be tested above"); - } - - sims = child->exp.num(); - if(bestoutcome < outcome){ //better outcome is always preferable - bestoutcome = outcome; - bestsims = sims; - backup = child; - }else if(bestoutcome == outcome && ((outcome == 0 && bestsims < sims) || bestsims > sims)){ - //find long losses or easy wins/draws - bestsims = sims; - backup = child; - } - } - - if(bestoutcome == 3) //no win, but found an unknown - return false; - } - - if(CAS(node->outcome, nodeoutcome, backup->outcome)){ - node->bestmove = backup->move; - node->proofdepth = proofdepth; - }else //if it was in a race, try again, might promote a partial solve to full solve - return do_backup(node, backup, toplay); - - return (node->outcome >= 0); -} - -//update the rave score of all children that were played -void Player::PlayerUCT::update_rave(const Node * node, int toplay){ - Node * child = node->children.begin(), - * childend = node->children.end(); - - for( ; child != childend; ++child) - child->rave.addv(movelist.getrave(toplay, child->move)); -} - -void Player::PlayerUCT::add_knowledge(Board & board, Node * node, Node * child){ - if(player->localreply){ //boost for moves near the previous move - int dist = node->move.dist(child->move); - if(dist < 4) - child->know += player->localreply * (4 - dist); - } - - if(player->locality) //boost for moves near previous stones - child->know += player->locality * board.local(child->move, board.toplay()); - - Board::Cell cell; - if(player->connect || player->size) - cell = board.test_cell(child->move); - - if(player->connect) //boost for moves that connect to edges/corners - child->know += player->connect * (cell.numcorners() + cell.numedges()); - - if(player->size) //boost for size of the group - child->know += player->size * cell.size; - - if(player->bridge && test_bridge_probe(board, node->move, child->move)) //boost for maintaining a virtual connection - child->know += player->bridge; - - if(player->dists) - child->know += abs(player->dists) * max(0, board.get_size_d() - dists.get(child->move, board.toplay())); -} - -//test whether this move is a forced reply to the opponent probing your virtual connections -bool Player::PlayerUCT::test_bridge_probe(const Board & board, const Move & move, const Move & test) const { - //TODO: switch to the same method as policy_bridge.h, maybe even share code - if(move.dist(test) != 1) - return false; - - bool equals = false; - - int state = 0; - int piece = 3 - board.get(move); - for(int i = 0; i < 8; i++){ - Move cur = move + neighbours[i % 6]; - - bool on = board.onboard(cur); - int v = 0; - if(on) - v = board.get(cur); - - //state machine that progresses when it see the pattern, but counting borders as part of the pattern - if(state == 0){ - if(!on || v == piece) - state = 1; - //else state = 0; - }else if(state == 1){ - if(on){ - if(v == 0){ - state = 2; - equals = (test == cur); - }else if(v != piece) - state = 0; - //else (v==piece) => state = 1; - } - //else state = 1; - }else{ // state == 2 - if(!on || v == piece){ - if(equals) - return true; - state = 1; - }else{ - state = 0; - } - } - } - return false; -} - -/////////////////////////////////////////// - - -//play a random game starting from a board state, and return the results of who won -int Player::PlayerUCT::rollout(Board & board, Move move, int depth){ - int won; - - if(player->instantwin) - instant_wins.rollout_start(board, player->instantwin); - - random_policy.rollout_start(board); - - //only check rings to the specified depth - int checkdepth = (int)player->checkringdepth; - //if it's negative, check for that fraction of the remaining moves - if(player->checkringdepth < 0) - checkdepth = (int)ceil(board.movesremain() * player->checkringdepth * -1); - - board.perm_rings = player->ringperm; - - while((won = board.won()) < 0){ - int turn = board.toplay(); - - board.check_rings = (depth < checkdepth); - - move = rollout_choose_move(board, move); - - movelist.addrollout(move, turn); - - assert2(board.move(move, true, false), "\n" + board.to_s(true) + "\n" + move.to_s()); - depth++; - } - - gamelen.add(depth); - - if(won > 0) - wintypes[won-1][(int)board.getwintype()].add(depth); - - //update the last good reply table - if(player->lastgoodreply) - last_good_reply.rollout_end(board, movelist, won); - - movelist.finishrollout(won); - return won; -} - -Move Player::PlayerUCT::rollout_choose_move(Board & board, const Move & prev){ - //look for instant wins - if(player->instantwin){ - Move move = instant_wins.choose_move(board, prev); - if(move != M_UNKNOWN) - return move; - } - - //force a bridge reply - if(player->rolloutpattern){ - Move move = protect_bridge.choose_move(board, prev); - if(move != M_UNKNOWN) - return move; - } - - //reuse the last good reply - if(player->lastgoodreply){ - Move move = last_good_reply.choose_move(board, prev); - if(move != M_UNKNOWN) - return move; - } - - return random_policy.choose_move(board, prev); -} diff --git a/havannah/policy_bridge.h b/havannah/policy_bridge.h deleted file mode 100644 index c6f2b8d..0000000 --- a/havannah/policy_bridge.h +++ /dev/null @@ -1,51 +0,0 @@ - - -#pragma once - -#include "../lib/bits.h" - -#include "board.h" -#include "move.h" -#include "policy.h" - - -class ProtectBridge : public Policy { - int offset; - uint8_t lookup[2][1<<12]; - -public: - - ProtectBridge() : offset(0) { - // precompute the valid moves around a pattern for all possible 6-patterns. - for(unsigned int i = 0; i < 1<<12; i++){ - lookup[0][i] = lookup[1][i] = 0; - unsigned int p = i; - for(unsigned int d = 0; d < 6; d++){ - if((p & 0x1D) == 0x11) // 01 11 01 -> 01 00 01 - lookup[0][i] |= (1 << ((d+1)%6)); // +1 because we want to play in the empty spot - if((p & 0x2E) == 0x22) // 10 11 10 -> 10 00 10 - lookup[1][i] |= (1 << ((d+1)%6)); - p = ((p & 0xFFC)>>2) | ((p & 0x3) << 10); - } - } - } - - Move choose_move(const Board & board, const Move & prev) { - uint32_t p = board.pattern_small(prev); - uint16_t r = lookup[board.toplay()-1][p]; - - if(!r) // nothing to save - return M_UNKNOWN; - - unsigned int i; - if((r & (r - 1)) == 0){ // only one bit set - i = trailing_zeros(r); - } else { // multiple choices of bridges to save - offset = (offset + 1) % 6; // rotate the starting offset to avoid directional bias - r |= (r << 6); - r >>= offset; - i = (offset + trailing_zeros(r)) % 6; - } - return board.nb_begin(prev)[i]; - } -}; diff --git a/havannah/policy_lastgoodreply.h b/havannah/policy_lastgoodreply.h deleted file mode 100644 index 144da69..0000000 --- a/havannah/policy_lastgoodreply.h +++ /dev/null @@ -1,42 +0,0 @@ - -# pragma once - -#include "board.h" -#include "move.h" -#include "policy.h" - -class LastGoodReply : public Policy { - Move goodreply[2][Board::max_vecsize]; - int enabled; -public: - - LastGoodReply(int _enabled = 2) : enabled(_enabled) { - for(int p = 0; p < 2; p++) - for(int i = 0; i < Board::max_vecsize; i++) - goodreply[p][i] = M_UNKNOWN; - } - - Move choose_move(const Board & board, const Move & prev) const { - if (enabled && prev != M_SWAP) { - Move move = goodreply[board.toplay()-1][board.xy(prev)]; - if(move != M_UNKNOWN && board.valid_move_fast(move)) - return move; - } - return M_UNKNOWN; - } - - void rollout_end(const Board & board, const MoveList & movelist, int won) { - if(!enabled || won <= 0) - return; - int m = -1; - for(const MovePlayer * i = movelist.begin(), * e = movelist.end(); i != e; i++){ - if(m >= 0){ - if(i->player == won && *i != M_SWAP) - goodreply[i->player - 1][m] = *i; - else if(enabled == 2) - goodreply[i->player - 1][m] = M_UNKNOWN; - } - m = board.xy(*i); - } - } -}; diff --git a/havannah/policy_random.h b/havannah/policy_random.h deleted file mode 100644 index d84a82a..0000000 --- a/havannah/policy_random.h +++ /dev/null @@ -1,45 +0,0 @@ - -#pragma once - -#include - -#include "../lib/xorshift.h" - -#include "board.h" -#include "move.h" -#include "policy.h" - -class RandomPolicy : public Policy { - XORShift_uint32 rand; - Move moves[Board::max_vecsize]; - int num; - int cur; -public: - - RandomPolicy() : num(0), cur(0) { - } - - // only need to save the valid moves once since all the rollouts start from the same position - void prepare(const Board & board) { - num = 0; - for(Board::MoveIterator m = board.moveit(false); !m.done(); ++m) - moves[num++] = *m; - } - - // reset the set of moves to make from above. Since they're used in random order they don't need to be in iterator order - void rollout_start(Board & board) { - cur = num; - } - - Move choose_move(const Board & board, const Move & prev) { - while(true){ - int r = rand() % cur; - cur--; - Move m = moves[r]; - moves[r] = moves[cur]; - moves[cur] = m; - if(board.valid_move_fast(m)) - return m; - } - } -}; diff --git a/havannah/solver.h b/havannah/solver.h deleted file mode 100644 index d6e6240..0000000 --- a/havannah/solver.h +++ /dev/null @@ -1,68 +0,0 @@ - -#pragma once - -//Interface for the various solvers - -#include "../lib/types.h" - -#include "board.h" - -class Solver { -public: - int outcome; // 0 = tie, 1 = white, 2 = black, -1 = white or tie, -2 = black or tie, anything else unknown - int maxdepth; - uint64_t nodes_seen; - double time_used; - Move bestmove; - - Solver() : outcome(-3), maxdepth(0), nodes_seen(0), time_used(0) { } - virtual ~Solver() { } - - virtual void solve(double time) { } - virtual void set_board(const Board & board, bool clear = true) { } - virtual void move(const Move & m) { } - virtual void set_memlimit(uint64_t lim) { } // in bytes - virtual void clear_mem() { } - -protected: - volatile bool timeout; - void timedout(){ timeout = true; } - Board rootboard; - - static int solve1ply(const Board & board, int & nodes) { - int outcome = -3; - int turn = board.toplay(); - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - ++nodes; - int won = board.test_win(*move, turn); - - if(won == turn) - return won; - if(won == 0) - outcome = 0; - } - return outcome; - } - - static int solve2ply(const Board & board, int & nodes) { - int losses = 0; - int outcome = -3; - int turn = board.toplay(), opponent = 3 - turn; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - ++nodes; - int won = board.test_win(*move, turn); - - if(won == turn) - return won; - if(won == 0) - outcome = 0; - - if(board.test_win(*move, opponent) > 0) - losses++; - } - if(losses >= 2) - return opponent; - return outcome; - } - -}; diff --git a/havannah/solverab.cpp b/havannah/solverab.cpp deleted file mode 100644 index 1abdf47..0000000 --- a/havannah/solverab.cpp +++ /dev/null @@ -1,137 +0,0 @@ - -#include "../lib/alarm.h" -#include "../lib/log.h" -#include "../lib/time.h" - -#include "solverab.h" - -void SolverAB::solve(double time){ - reset(); - if(rootboard.won() >= 0){ - outcome = rootboard.won(); - return; - } - - if(TT == NULL && maxnodes) - TT = new ABTTNode[maxnodes]; - - Alarm timer(time, std::bind(&SolverAB::timedout, this)); - Time start; - - int turn = rootboard.toplay(); - - for(maxdepth = startdepth; !timeout; maxdepth++){ -// logerr("Starting depth " + to_str(maxdepth) + "\n"); - - //the first depth of negamax - int ret, alpha = -2, beta = 2; - for(Board::MoveIterator move = rootboard.moveit(true); !move.done(); ++move){ - nodes_seen++; - - Board next = rootboard; - next.move(*move); - - int value = -negamax(next, maxdepth - 1, -beta, -alpha); - - if(value > alpha){ - alpha = value; - bestmove = *move; - } - - if(alpha >= beta){ - ret = beta; - break; - } - } - ret = alpha; - - - if(ret){ - if( ret == -2){ outcome = (turn == 1 ? 2 : 1); bestmove = Move(M_NONE); } - else if(ret == 2){ outcome = turn; } - else /*-1 || 1*/ { outcome = 0; } - - break; - } - } - - time_used = Time() - start; -} - - -int SolverAB::negamax(const Board & board, const int depth, int alpha, int beta){ - if(board.won() >= 0) - return (board.won() ? -2 : -1); - - if(depth <= 0 || timeout) - return 0; - - int b = beta; - int first = true; - int value, losses = 0; - static const int lookup[6] = {0, 0, 0, 1, 2, 2}; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - nodes_seen++; - - hash_t hash = board.test_hash(*move); - if(int ttval = tt_get(hash)){ - value = ttval; - }else if(depth <= 2){ - value = lookup[board.test_win(*move)+3]; - - if(board.test_win(*move, 3 - board.toplay()) > 0) - losses++; - }else{ - Board next = board; - next.move(*move); - - value = -negamax(next, depth - 1, -b, -alpha); - - if(scout && value > alpha && value < beta && !first) // re-search - value = -negamax(next, depth - 1, -beta, -alpha); - } - tt_set(hash, value); - - if(value > alpha) - alpha = value; - - if(alpha >= beta) - return beta; - - if(scout){ - b = alpha + 1; // set up null window - first = false; - } - } - - if(losses >= 2) - return -2; - - return alpha; -} - -int SolverAB::negamax_outcome(const Board & board, const int depth){ - int abval = negamax(board, depth, -2, 2); - if( abval == 0) return -3; //unknown - else if(abval == 2) return board.toplay(); //win - else if(abval == -2) return 3 - board.toplay(); //loss - else return 0; //draw -} - -int SolverAB::tt_get(const Board & board){ - return tt_get(board.gethash()); -} -int SolverAB::tt_get(const hash_t & hash){ - if(!TT) return 0; - ABTTNode * node = & TT[hash % maxnodes]; - return (node->hash == hash ? node->value : 0); -} -void SolverAB::tt_set(const Board & board, int value){ - tt_set(board.gethash(), value); -} -void SolverAB::tt_set(const hash_t & hash, int value){ - if(!TT || value == 0) return; - ABTTNode * node = & TT[hash % maxnodes]; - node->hash = hash; - node->value = value; -} diff --git a/havannah/solverab.h b/havannah/solverab.h deleted file mode 100644 index 35ca7b9..0000000 --- a/havannah/solverab.h +++ /dev/null @@ -1,72 +0,0 @@ - -#pragma once - -//An Alpha-beta solver, single threaded with an optional transposition table. - -#include "solver.h" - -class SolverAB : public Solver { - struct ABTTNode { - hash_t hash; - char value; - ABTTNode(hash_t h = 0, char v = 0) : hash(h), value(v) { } - }; - -public: - bool scout; - int startdepth; - - ABTTNode * TT; - uint64_t maxnodes, memlimit; - - SolverAB(bool Scout = false) { - scout = Scout; - startdepth = 2; - TT = NULL; - set_memlimit(100*1024*1024); - } - ~SolverAB() { } - - void set_board(const Board & board, bool clear = true){ - rootboard = board; - reset(); - } - void move(const Move & m){ - rootboard.move(m); - reset(); - } - void set_memlimit(uint64_t lim){ - memlimit = lim; - maxnodes = memlimit/sizeof(ABTTNode); - clear_mem(); - } - - void clear_mem(){ - reset(); - if(TT){ - delete[] TT; - TT = NULL; - } - } - void reset(){ - outcome = -3; - maxdepth = 0; - nodes_seen = 0; - time_used = 0; - bestmove = Move(M_UNKNOWN); - - timeout = false; - } - - void solve(double time); - -//return -2 for loss, -1,1 for tie, 0 for unknown, 2 for win, all from toplay's perspective - int negamax(const Board & board, const int depth, int alpha, int beta); - int negamax_outcome(const Board & board, const int depth); - - int tt_get(const hash_t & hash); - int tt_get(const Board & board); - void tt_set(const hash_t & hash, int val); - void tt_set(const Board & board, int val); -}; - diff --git a/havannah/solverpns.cpp b/havannah/solverpns.cpp deleted file mode 100644 index 7f11a1a..0000000 --- a/havannah/solverpns.cpp +++ /dev/null @@ -1,213 +0,0 @@ - - -#include "../lib/alarm.h" -#include "../lib/log.h" -#include "../lib/time.h" - -#include "solverpns.h" - -void SolverPNS::solve(double time){ - if(rootboard.won() >= 0){ - outcome = rootboard.won(); - return; - } - - timeout = false; - Alarm timer(time, std::bind(&SolverPNS::timedout, this)); - Time start; - -// logerr("max nodes: " + to_str(memlimit/sizeof(PNSNode)) + ", max memory: " + to_str(memlimit/(1024*1024)) + " Mb\n"); - - run_pns(); - - if(root.phi == 0 && root.delta == LOSS){ //look for the winning move - for(PNSNode * i = root.children.begin() ; i != root.children.end(); i++){ - if(i->delta == 0){ - bestmove = i->move; - break; - } - } - outcome = rootboard.toplay(); - }else if(root.phi == 0 && root.delta == DRAW){ //look for the move to tie - for(PNSNode * i = root.children.begin() ; i != root.children.end(); i++){ - if(i->delta == DRAW){ - bestmove = i->move; - break; - } - } - outcome = 0; - }else if(root.delta == 0){ //loss - bestmove = M_NONE; - outcome = 3 - rootboard.toplay(); - }else{ //unknown - bestmove = M_UNKNOWN; - outcome = -3; - } - - time_used = Time() - start; -} - -void SolverPNS::run_pns(){ - while(!timeout && root.phi != 0 && root.delta != 0){ - if(!pns(rootboard, &root, 0, INF32/2, INF32/2)){ - logerr("Starting solver GC with limit " + to_str(gclimit) + " ... "); - - Time starttime; - garbage_collect(& root); - - Time gctime; - ctmem.compact(1.0, 0.75); - - Time compacttime; - logerr(to_str(100.0*ctmem.meminuse()/memlimit, 1) + " % of tree remains - " + - to_str((gctime - starttime)*1000, 0) + " msec gc, " + to_str((compacttime - gctime)*1000, 0) + " msec compact\n"); - - if(ctmem.meminuse() >= memlimit/2) - gclimit = (unsigned int)(gclimit*1.3); - else if(gclimit > 5) - gclimit = (unsigned int)(gclimit*0.9); //slowly decay to a minimum of 5 - } - } -} - -bool SolverPNS::pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td){ - iters++; - if(maxdepth < depth) - maxdepth = depth; - - if(node->children.empty()){ - if(ctmem.memalloced() >= memlimit) - return false; - - int numnodes = board.movesremain(); - nodes += node->alloc(numnodes, ctmem); - - if(lbdist) - dists.run(&board); - - int i = 0; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - int outcome, pd; - - if(ab){ - Board next = board; - next.move(*move); - - pd = 0; - outcome = (ab == 1 ? solve1ply(next, pd) : solve2ply(next, pd)); - nodes_seen += pd; - }else{ - outcome = board.test_win(*move); - pd = 1; - } - - if(lbdist && outcome < 0) - pd = dists.get(*move); - - node->children[i] = PNSNode(*move).outcome(outcome, board.toplay(), ties, pd); - - i++; - } - node->children.shrink(i); //if symmetry, there may be extra moves to ignore - - nodes_seen += i; - - updatePDnum(node); - - return true; - } - - bool mem; - do{ - PNSNode * child = node->children.begin(), - * child2 = node->children.begin(), - * childend = node->children.end(); - - uint32_t tpc, tdc; - - if(df){ - for(PNSNode * i = node->children.begin(); i != childend; i++){ - if(i->delta <= child->delta){ - child2 = child; - child = i; - }else if(i->delta < child2->delta){ - child2 = i; - } - } - - tpc = min(INF32/2, (td + child->phi - node->delta)); - tdc = min(tp, (uint32_t)(child2->delta*(1.0 + epsilon) + 1)); - }else{ - tpc = tdc = 0; - while(child->delta != node->phi) - child++; - } - - Board next = board; - next.move(child->move); - - uint64_t itersbefore = iters; - mem = pns(next, child, depth + 1, tpc, tdc); - child->work += iters - itersbefore; - - if(child->phi == 0 || child->delta == 0) //clear child's children - nodes -= child->dealloc(ctmem); - - if(updatePDnum(node) && !df) - break; - - }while(!timeout && mem && (!df || (node->phi < tp && node->delta < td))); - - return mem; -} - -bool SolverPNS::updatePDnum(PNSNode * node){ - PNSNode * i = node->children.begin(); - PNSNode * end = node->children.end(); - - uint32_t min = i->delta; - uint64_t sum = 0; - - bool win = false; - for( ; i != end; i++){ - win |= (i->phi == LOSS); - sum += i->phi; - if( min > i->delta) - min = i->delta; - } - - if(win) - sum = LOSS; - else if(sum >= INF32) - sum = INF32; - - if(min == node->phi && sum == node->delta){ - return false; - }else{ - if(sum == 0 && min == DRAW){ - node->phi = 0; - node->delta = DRAW; - }else{ - node->phi = min; - node->delta = sum; - } - return true; - } -} - -//removes the children of any node with less than limit work -void SolverPNS::garbage_collect(PNSNode * node){ - PNSNode * child = node->children.begin(); - PNSNode * end = node->children.end(); - - for( ; child != end; child++){ - if(child->terminal()){ //solved - //log heavy nodes? - nodes -= child->dealloc(ctmem); - }else if(child->work < gclimit){ //low work, ignore solvedness since it's trivial to re-solve - nodes -= child->dealloc(ctmem); - }else if(child->children.num() > 0){ - garbage_collect(child); - } - } -} diff --git a/havannah/solverpns.h b/havannah/solverpns.h deleted file mode 100644 index b040d82..0000000 --- a/havannah/solverpns.h +++ /dev/null @@ -1,206 +0,0 @@ - -#pragma once - -//A single-threaded, tree based, proof number search solver. - -#include "../lib/compacttree.h" -#include "../lib/log.h" - -#include "lbdist.h" -#include "solver.h" - - -class SolverPNS : public Solver { - static const uint32_t LOSS = (1<<30)-1; - static const uint32_t DRAW = (1<<30)-2; - static const uint32_t INF32 = (1<<30)-3; -public: - - struct PNSNode { - uint32_t phi, delta; - uint64_t work; - Move move; - CompactTree::Children children; - - PNSNode() { } - PNSNode(int x, int y, int v = 1) : phi(v), delta(v), work(0), move(Move(x,y)) { } - PNSNode(const Move & m, int v = 1) : phi(v), delta(v), work(0), move(m) { } - PNSNode(int x, int y, int p, int d) : phi(p), delta(d), work(0), move(Move(x,y)) { } - PNSNode(const Move & m, int p, int d) : phi(p), delta(d), work(0), move(m) { } - - PNSNode(const PNSNode & n) { *this = n; } - PNSNode & operator = (const PNSNode & n){ - if(this != & n){ //don't copy to self - //don't copy to a node that already has children - assert(children.empty()); - - phi = n.phi; - delta = n.delta; - work = n.work; - move = n.move; - //don't copy the children - } - return *this; - } - - ~PNSNode(){ - assert(children.empty()); - } - - PNSNode & abval(int outcome, int toplay, int assign, int value = 1){ - if(assign && (outcome == 1 || outcome == -1)) - outcome = (toplay == assign ? 2 : -2); - - if( outcome == 0) { phi = value; delta = value; } - else if(outcome == 2) { phi = LOSS; delta = 0; } - else if(outcome == -2) { phi = 0; delta = LOSS; } - else /*(outcome 1||-1)*/ { phi = 0; delta = DRAW; } - return *this; - } - - PNSNode & outcome(int outcome, int toplay, int assign, int value = 1){ - if(assign && outcome == 0) - outcome = assign; - - if( outcome == -3) { phi = value; delta = value; } - else if(outcome == toplay) { phi = LOSS; delta = 0; } - else if(outcome == 3-toplay) { phi = 0; delta = LOSS; } - else /*(outcome == 0)*/ { phi = 0; delta = DRAW; } - return *this; - } - - bool terminal(){ return (phi == 0 || delta == 0); } - - unsigned int size() const { - unsigned int num = children.num(); - - for(PNSNode * i = children.begin(); i != children.end(); i++) - num += i->size(); - - return num; - } - - void swap_tree(PNSNode & n){ - children.swap(n.children); - } - - unsigned int alloc(unsigned int num, CompactTree & ct){ - return children.alloc(num, ct); - } - unsigned int dealloc(CompactTree & ct){ - unsigned int num = 0; - - for(PNSNode * i = children.begin(); i != children.end(); i++) - num += i->dealloc(ct); - num += children.dealloc(ct); - - return num; - } - }; - - -//memory management for PNS which uses a tree to store the nodes - uint64_t nodes, memlimit; - unsigned int gclimit; - CompactTree ctmem; - - uint64_t iters; - - int ab; // how deep of an alpha-beta search to run at each leaf node - bool df; // go depth first? - float epsilon; //if depth first, how wide should the threshold be? - int ties; //which player to assign ties to: 0 handle ties, 1 assign p1, 2 assign p2 - bool lbdist; - - PNSNode root; - LBDists dists; - - SolverPNS() { - ab = 2; - df = true; - epsilon = 0.25; - ties = 0; - lbdist = false; - gclimit = 5; - iters = 0; - - reset(); - - set_memlimit(100*1024*1024); - } - - ~SolverPNS(){ - root.dealloc(ctmem); - ctmem.compact(); - } - - void reset(){ - outcome = -3; - maxdepth = 0; - nodes_seen = 0; - time_used = 0; - bestmove = Move(M_UNKNOWN); - - timeout = false; - } - - void set_board(const Board & board, bool clear = true){ - rootboard = board; - reset(); - if(clear) - clear_mem(); - } - void move(const Move & m){ - rootboard.move(m); - reset(); - - - uint64_t nodesbefore = nodes; - - PNSNode child; - - for(PNSNode * i = root.children.begin(); i != root.children.end(); i++){ - if(i->move == m){ - child = *i; //copy the child experience to temp - child.swap_tree(*i); //move the child tree to temp - break; - } - } - - nodes -= root.dealloc(ctmem); - root = child; - root.swap_tree(child); - - if(nodesbefore > 0) - logerr(string("PNS Nodes before: ") + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); - - assert(nodes == root.size()); - - if(nodes == 0) - clear_mem(); - } - - void set_memlimit(uint64_t lim){ - memlimit = lim; - } - - void clear_mem(){ - reset(); - root.dealloc(ctmem); - ctmem.compact(); - root = PNSNode(0, 0, 1); - nodes = 0; - } - - void solve(double time); - -//basic proof number search building a tree - void run_pns(); - bool pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td); - -//update the phi and delta for the node - bool updatePDnum(PNSNode * node); - -//remove all the nodes with little work to free up some memory - void garbage_collect(PNSNode * node); -}; diff --git a/havannah/solverpns2.cpp b/havannah/solverpns2.cpp deleted file mode 100644 index 4995fc5..0000000 --- a/havannah/solverpns2.cpp +++ /dev/null @@ -1,323 +0,0 @@ - - -#include "../lib/alarm.h" -#include "../lib/log.h" -#include "../lib/time.h" - -#include "solverpns2.h" - -void SolverPNS2::solve(double time){ - if(rootboard.won() >= 0){ - outcome = rootboard.won(); - return; - } - - start_threads(); - - timeout = false; - Alarm timer(time, std::bind(&SolverPNS2::timedout, this)); - Time start; - -// logerr("max memory: " + to_str(memlimit/(1024*1024)) + " Mb\n"); - - //wait for the timer to stop them - runbarrier.wait(); - CAS(threadstate, Thread_Wait_End, Thread_Wait_Start); - assert(threadstate == Thread_Wait_Start); - - if(root.phi == 0 && root.delta == LOSS){ //look for the winning move - for(PNSNode * i = root.children.begin() ; i != root.children.end(); i++){ - if(i->delta == 0){ - bestmove = i->move; - break; - } - } - outcome = rootboard.toplay(); - }else if(root.phi == 0 && root.delta == DRAW){ //look for the move to tie - for(PNSNode * i = root.children.begin() ; i != root.children.end(); i++){ - if(i->delta == DRAW){ - bestmove = i->move; - break; - } - } - outcome = 0; - }else if(root.delta == 0){ //loss - bestmove = M_NONE; - outcome = 3 - rootboard.toplay(); - }else{ //unknown - bestmove = M_UNKNOWN; - outcome = -3; - } - - time_used = Time() - start; -} - -void SolverPNS2::SolverThread::run(){ - while(true){ - switch(solver->threadstate){ - case Thread_Cancelled: //threads should exit - return; - - case Thread_Wait_Start: //threads are waiting to start - case Thread_Wait_Start_Cancelled: - solver->runbarrier.wait(); - CAS(solver->threadstate, Thread_Wait_Start, Thread_Running); - CAS(solver->threadstate, Thread_Wait_Start_Cancelled, Thread_Cancelled); - break; - - case Thread_Wait_End: //threads are waiting to end - solver->runbarrier.wait(); - CAS(solver->threadstate, Thread_Wait_End, Thread_Wait_Start); - break; - - case Thread_Running: //threads are running - if(solver->root.terminal()){ //solved - CAS(solver->threadstate, Thread_Running, Thread_Wait_End); - break; - } - if(solver->ctmem.memalloced() >= solver->memlimit){ //out of memory, start garbage collection - CAS(solver->threadstate, Thread_Running, Thread_GC); - break; - } - - pns(solver->rootboard, &solver->root, 0, INF32/2, INF32/2); - break; - - case Thread_GC: //one thread is running garbage collection, the rest are waiting - case Thread_GC_End: //once done garbage collecting, go to wait_end instead of back to running - if(solver->gcbarrier.wait()){ - logerr("Starting solver GC with limit " + to_str(solver->gclimit) + " ... "); - - Time starttime; - solver->garbage_collect(& solver->root); - - Time gctime; - solver->ctmem.compact(1.0, 0.75); - - Time compacttime; - logerr(to_str(100.0*solver->ctmem.meminuse()/solver->memlimit, 1) + " % of tree remains - " + - to_str((gctime - starttime)*1000, 0) + " msec gc, " + to_str((compacttime - gctime)*1000, 0) + " msec compact\n"); - - if(solver->ctmem.meminuse() >= solver->memlimit/2) - solver->gclimit = (unsigned int)(solver->gclimit*1.3); - else if(solver->gclimit > 5) - solver->gclimit = (unsigned int)(solver->gclimit*0.9); //slowly decay to a minimum of 5 - - CAS(solver->threadstate, Thread_GC, Thread_Running); - CAS(solver->threadstate, Thread_GC_End, Thread_Wait_End); - } - solver->gcbarrier.wait(); - break; - } - } -} - -void SolverPNS2::timedout() { - CAS(threadstate, Thread_Running, Thread_Wait_End); - CAS(threadstate, Thread_GC, Thread_GC_End); - timeout = true; -} - -string SolverPNS2::statestring(){ - switch(threadstate){ - case Thread_Cancelled: return "Thread_Wait_Cancelled"; - case Thread_Wait_Start: return "Thread_Wait_Start"; - case Thread_Wait_Start_Cancelled: return "Thread_Wait_Start_Cancelled"; - case Thread_Running: return "Thread_Running"; - case Thread_GC: return "Thread_GC"; - case Thread_GC_End: return "Thread_GC_End"; - case Thread_Wait_End: return "Thread_Wait_End"; - } - return "Thread_State_Unknown!!!"; -} - -void SolverPNS2::stop_threads(){ - if(threadstate != Thread_Wait_Start){ - timedout(); - runbarrier.wait(); - CAS(threadstate, Thread_Wait_End, Thread_Wait_Start); - assert(threadstate == Thread_Wait_Start); - } -} - -void SolverPNS2::start_threads(){ - assert(threadstate == Thread_Wait_Start); - runbarrier.wait(); - CAS(threadstate, Thread_Wait_Start, Thread_Running); -} - -void SolverPNS2::reset_threads(){ //start and end with threadstate = Thread_Wait_Start - assert(threadstate == Thread_Wait_Start); - -//wait for them to all get to the barrier - assert(CAS(threadstate, Thread_Wait_Start, Thread_Wait_Start_Cancelled)); - runbarrier.wait(); - -//make sure they exited cleanly - for(unsigned int i = 0; i < threads.size(); i++) - threads[i]->join(); - - threads.clear(); - - threadstate = Thread_Wait_Start; - - runbarrier.reset(numthreads + 1); - gcbarrier.reset(numthreads); - -//start new threads - for(int i = 0; i < numthreads; i++) - threads.push_back(new SolverThread(this)); -} - - -bool SolverPNS2::SolverThread::pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td){ - iters++; - if(solver->maxdepth < depth) - solver->maxdepth = depth; - - if(node->children.empty()){ - if(node->terminal()) - return true; - - if(solver->ctmem.memalloced() >= solver->memlimit) - return false; - - if(!node->children.lock()) - return false; - - int numnodes = board.movesremain(); - CompactTree::Children temp; - temp.alloc(numnodes, solver->ctmem); - PLUS(solver->nodes, numnodes); - - if(solver->lbdist) - dists.run(&board); - - int i = 0; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - int outcome, pd; - - if(solver->ab){ - Board next = board; - next.move(*move); - - pd = 0; - outcome = (solver->ab == 1 ? solve1ply(next, pd) : solve2ply(next, pd)); - PLUS(solver->nodes_seen, pd); - }else{ - outcome = board.test_win(*move); - pd = 1; - } - - if(solver->lbdist && outcome < 0) - pd = dists.get(*move); - - temp[i] = PNSNode(*move).outcome(outcome, board.toplay(), solver->ties, pd); - - i++; - } - temp.shrink(i); //if symmetry, there may be extra moves to ignore - node->children.swap(temp); - assert(temp.unlock()); - - PLUS(solver->nodes_seen, i); - - updatePDnum(node); - - return true; - } - - bool mem; - do{ - PNSNode * child = node->children.begin(), - * child2 = node->children.begin(), - * childend = node->children.end(); - - uint32_t tpc, tdc; - - if(solver->df){ - for(PNSNode * i = node->children.begin(); i != childend; i++){ - if(i->refdelta() <= child->refdelta()){ - child2 = child; - child = i; - }else if(i->refdelta() < child2->refdelta()){ - child2 = i; - } - } - - tpc = min(INF32/2, (td + child->phi - node->delta)); - tdc = min(tp, (uint32_t)(child2->delta*(1.0 + solver->epsilon) + 1)); - }else{ - tpc = tdc = 0; - for(PNSNode * i = node->children.begin(); i != childend; i++) - if(child->refdelta() > i->refdelta()) - child = i; - } - - Board next = board; - next.move(child->move); - - child->ref(); - uint64_t itersbefore = iters; - mem = pns(next, child, depth + 1, tpc, tdc); - child->deref(); - PLUS(child->work, iters - itersbefore); - - if(updatePDnum(node) && !solver->df) - break; - - }while(!solver->timeout && mem && (!solver->df || (node->phi < tp && node->delta < td))); - - return mem; -} - -bool SolverPNS2::SolverThread::updatePDnum(PNSNode * node){ - PNSNode * i = node->children.begin(); - PNSNode * end = node->children.end(); - - uint32_t min = i->delta; - uint64_t sum = 0; - - bool win = false; - for( ; i != end; i++){ - win |= (i->phi == LOSS); - sum += i->phi; - if( min > i->delta) - min = i->delta; - } - - if(win) - sum = LOSS; - else if(sum >= INF32) - sum = INF32; - - if(min == node->phi && sum == node->delta){ - return false; - }else{ - if(sum == 0 && min == DRAW){ - node->phi = 0; - node->delta = DRAW; - }else{ - node->phi = min; - node->delta = sum; - } - return true; - } -} - -//removes the children of any node with less than limit work -void SolverPNS2::garbage_collect(PNSNode * node){ - PNSNode * child = node->children.begin(); - PNSNode * end = node->children.end(); - - for( ; child != end; child++){ - if(child->terminal()){ //solved - //log heavy nodes? - PLUS(nodes, -child->dealloc(ctmem)); - }else if(child->work < gclimit){ //low work, ignore solvedness since it's trivial to re-solve - PLUS(nodes, -child->dealloc(ctmem)); - }else if(child->children.num() > 0){ - garbage_collect(child); - } - } -} diff --git a/havannah/solverpns2.h b/havannah/solverpns2.h deleted file mode 100644 index 5af5d1d..0000000 --- a/havannah/solverpns2.h +++ /dev/null @@ -1,265 +0,0 @@ - -#pragma once - -//A multi-threaded, tree based, proof number search solver. - -#include "../lib/compacttree.h" -#include "../lib/log.h" - -#include "lbdist.h" -#include "solver.h" - - -class SolverPNS2 : public Solver { - static const uint32_t LOSS = (1<<30)-1; - static const uint32_t DRAW = (1<<30)-2; - static const uint32_t INF32 = (1<<30)-3; -public: - - struct PNSNode { - static const uint16_t reflock = 1<<15; - uint32_t phi, delta; - uint64_t work; - uint16_t refcount; //how many threads are down this node - Move move; - CompactTree::Children children; - - PNSNode() { } - PNSNode(int x, int y, int v = 1) : phi(v), delta(v), work(0), refcount(0), move(Move(x,y)) { } - PNSNode(const Move & m, int v = 1) : phi(v), delta(v), work(0), refcount(0), move(m) { } - PNSNode(int x, int y, int p, int d) : phi(p), delta(d), work(0), refcount(0), move(Move(x,y)) { } - PNSNode(const Move & m, int p, int d) : phi(p), delta(d), work(0), refcount(0), move(m) { } - - PNSNode(const PNSNode & n) { *this = n; } - PNSNode & operator = (const PNSNode & n){ - if(this != & n){ //don't copy to self - //don't copy to a node that already has children - assert(children.empty()); - - phi = n.phi; - delta = n.delta; - work = n.work; - move = n.move; - //don't copy the children - } - return *this; - } - - ~PNSNode(){ - assert(children.empty()); - } - - PNSNode & abval(int outcome, int toplay, int assign, int value = 1){ - if(assign && (outcome == 1 || outcome == -1)) - outcome = (toplay == assign ? 2 : -2); - - if( outcome == 0) { phi = value; delta = value; } - else if(outcome == 2) { phi = LOSS; delta = 0; } - else if(outcome == -2) { phi = 0; delta = LOSS; } - else /*(outcome 1||-1)*/ { phi = 0; delta = DRAW; } - return *this; - } - - PNSNode & outcome(int outcome, int toplay, int assign, int value = 1){ - if(assign && outcome == 0) - outcome = assign; - - if( outcome == -3) { phi = value; delta = value; } - else if(outcome == toplay) { phi = LOSS; delta = 0; } - else if(outcome == 3-toplay) { phi = 0; delta = LOSS; } - else /*(outcome == 0)*/ { phi = 0; delta = DRAW; } - return *this; - } - - bool terminal(){ return (phi == 0 || delta == 0); } - - uint32_t refdelta() const { - return delta + refcount; - } - - void ref() { PLUS(refcount, 1); } - void deref(){ PLUS(refcount, -1); } - - unsigned int size() const { - unsigned int num = children.num(); - - for(PNSNode * i = children.begin(); i != children.end(); i++) - num += i->size(); - - return num; - } - - void swap_tree(PNSNode & n){ - children.swap(n.children); - } - - unsigned int alloc(unsigned int num, CompactTree & ct){ - return children.alloc(num, ct); - } - unsigned int dealloc(CompactTree & ct){ - unsigned int num = 0; - - for(PNSNode * i = children.begin(); i != children.end(); i++) - num += i->dealloc(ct); - num += children.dealloc(ct); - - return num; - } - }; - - class SolverThread { - protected: - public: - Thread thread; - SolverPNS2 * solver; - public: - uint64_t iters; - LBDists dists; //holds the distances to the various non-ring wins as a heuristic for the minimum moves needed to win - - SolverThread(SolverPNS2 * s) : solver(s), iters(0) { - thread(bind(&SolverThread::run, this)); - } - virtual ~SolverThread() { } - void reset(){ - iters = 0; - } - int join(){ return thread.join(); } - void run(); //thread runner - - //basic proof number search building a tree - bool pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td); - - //update the phi and delta for the node - bool updatePDnum(PNSNode * node); - }; - - -//memory management for PNS which uses a tree to store the nodes - uint64_t nodes, memlimit; - unsigned int gclimit; - CompactTree ctmem; - - enum ThreadState { - Thread_Cancelled, //threads should exit - Thread_Wait_Start, //threads are waiting to start - Thread_Wait_Start_Cancelled, //once done waiting, go to cancelled instead of running - Thread_Running, //threads are running - Thread_GC, //one thread is running garbage collection, the rest are waiting - Thread_GC_End, //once done garbage collecting, go to wait_end instead of back to running - Thread_Wait_End, //threads are waiting to end - }; - volatile ThreadState threadstate; - vector threads; - Barrier runbarrier, gcbarrier; - - - int ab; // how deep of an alpha-beta search to run at each leaf node - bool df; // go depth first? - float epsilon; //if depth first, how wide should the threshold be? - int ties; //which player to assign ties to: 0 handle ties, 1 assign p1, 2 assign p2 - bool lbdist; - int numthreads; - - PNSNode root; - LBDists dists; - - SolverPNS2() { - ab = 2; - df = true; - epsilon = 0.25; - ties = 0; - lbdist = false; - numthreads = 1; - gclimit = 5; - - reset(); - - set_memlimit(100*1024*1024); - - //no threads started until a board is set - threadstate = Thread_Wait_Start; - } - - ~SolverPNS2(){ - stop_threads(); - - numthreads = 0; - reset_threads(); //shut down the theads properly - - root.dealloc(ctmem); - ctmem.compact(); - } - - void reset(){ - outcome = -3; - maxdepth = 0; - nodes_seen = 0; - time_used = 0; - bestmove = Move(M_UNKNOWN); - - timeout = false; - } - - string statestring(); - void stop_threads(); - void start_threads(); - void reset_threads(); - void timedout(); - - void set_board(const Board & board, bool clear = true){ - rootboard = board; - reset(); - if(clear) - clear_mem(); - - reset_threads(); //needed since the threads aren't started before a board it set - } - void move(const Move & m){ - stop_threads(); - - rootboard.move(m); - reset(); - - - uint64_t nodesbefore = nodes; - - PNSNode child; - - for(PNSNode * i = root.children.begin(); i != root.children.end(); i++){ - if(i->move == m){ - child = *i; //copy the child experience to temp - child.swap_tree(*i); //move the child tree to temp - break; - } - } - - nodes -= root.dealloc(ctmem); - root = child; - root.swap_tree(child); - - if(nodesbefore > 0) - logerr(string("PNS Nodes before: ") + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); - - assert(nodes == root.size()); - - if(nodes == 0) - clear_mem(); - } - - void set_memlimit(uint64_t lim){ - memlimit = lim; - } - - void clear_mem(){ - reset(); - root.dealloc(ctmem); - ctmem.compact(); - root = PNSNode(0, 0, 1); - nodes = 0; - } - - void solve(double time); - -//remove all the nodes with little work to free up some memory - void garbage_collect(PNSNode * node); -}; diff --git a/havannah/solverpns_tt.cpp b/havannah/solverpns_tt.cpp deleted file mode 100644 index 0818e8c..0000000 --- a/havannah/solverpns_tt.cpp +++ /dev/null @@ -1,282 +0,0 @@ - -#include "../lib/alarm.h" -#include "../lib/log.h" -#include "../lib/time.h" - -#include "solverpns_tt.h" - -void SolverPNSTT::solve(double time){ - if(rootboard.won() >= 0){ - outcome = rootboard.won(); - return; - } - - timeout = false; - Alarm timer(time, std::bind(&SolverPNSTT::timedout, this)); - Time start; - -// logerr("max nodes: " + to_str(maxnodes) + ", max memory: " + to_str(memlimit) + " Mb\n"); - - run_pns(); - - if(root.phi == 0 && root.delta == LOSS){ //look for the winning move - PNSNode * i = NULL; - for(Board::MoveIterator move = rootboard.moveit(true); !move.done(); ++move){ - i = tt(rootboard, *move); - if(i->delta == 0){ - bestmove = *move; - break; - } - } - outcome = rootboard.toplay(); - }else if(root.phi == 0 && root.delta == DRAW){ //look for the move to tie - PNSNode * i = NULL; - for(Board::MoveIterator move = rootboard.moveit(true); !move.done(); ++move){ - i = tt(rootboard, *move); - if(i->delta == DRAW){ - bestmove = *move; - break; - } - } - outcome = 0; - }else if(root.delta == 0){ //loss - bestmove = M_NONE; - outcome = 3 - rootboard.toplay(); - }else{ //unknown - bestmove = M_UNKNOWN; - outcome = -3; - } - - time_used = Time() - start; -} - -void SolverPNSTT::run_pns(){ - if(TT == NULL) - TT = new PNSNode[maxnodes]; - - while(!timeout && root.phi != 0 && root.delta != 0) - pns(rootboard, &root, 0, INF32/2, INF32/2); -} - -void SolverPNSTT::pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td){ - if(depth > maxdepth) - maxdepth = depth; - - do{ - PNSNode * child = NULL, - * child2 = NULL; - - Move move1, move2; - - uint32_t tpc, tdc; - - PNSNode * i = NULL; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - i = tt(board, *move); - if(child == NULL){ - child = child2 = i; - move1 = move2 = *move; - }else if(i->delta <= child->delta){ - child2 = child; - child = i; - move2 = move1; - move1 = *move; - }else if(i->delta < child2->delta){ - child2 = i; - move2 = *move; - } - } - - if(child->delta && child->phi){ //unsolved - if(df){ - tpc = min(INF32/2, (td + child->phi - node->delta)); - tdc = min(tp, (uint32_t)(child2->delta*(1.0 + epsilon) + 1)); - }else{ - tpc = tdc = 0; - } - - Board next = board; - next.move(move1); - pns(next, child, depth + 1, tpc, tdc); - - //just found a loss, try to copy proof to siblings - if(copyproof && child->delta == LOSS){ -// logerr("!" + move1.to_s() + " "); - int count = abs(copyproof); - for(Board::MoveIterator move = board.moveit(true); count-- && !move.done(); ++move){ - if(!tt(board, *move)->terminal()){ -// logerr("?" + move->to_s() + " "); - Board sibling = board; - sibling.move(*move); - copy_proof(next, sibling, move1, *move); - updatePDnum(sibling); - - if(copyproof < 0 && !tt(sibling)->terminal()) - break; - } - } - } - } - - if(updatePDnum(board, node) && !df) //must pass node to updatePDnum since it may refer to the root which isn't in the TT - break; - - }while(!timeout && node->phi && node->delta && (!df || (node->phi < tp && node->delta < td))); -} - -bool SolverPNSTT::updatePDnum(const Board & board, PNSNode * node){ - hash_t hash = board.gethash(); - - if(node == NULL) - node = TT + (hash % maxnodes); - - uint32_t min = LOSS; - uint64_t sum = 0; - - bool win = false; - PNSNode * i = NULL; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - i = tt(board, *move); - - win |= (i->phi == LOSS); - sum += i->phi; - if( min > i->delta) - min = i->delta; - } - - if(win) - sum = LOSS; - else if(sum >= INF32) - sum = INF32; - - if(hash == node->hash && min == node->phi && sum == node->delta){ - return false; - }else{ - node->hash = hash; //just in case it was overwritten by something else - if(sum == 0 && min == DRAW){ - node->phi = 0; - node->delta = DRAW; - }else{ - node->phi = min; - node->delta = sum; - } - return true; - } -} - -//source is a move that is a proven loss, and dest is an unproven sibling -//each has one move that the other doesn't, which are stored in smove and dmove -//if either move is used but only available in one board, the other is substituted -void SolverPNSTT::copy_proof(const Board & source, const Board & dest, Move smove, Move dmove){ - if(timeout || tt(source)->delta != LOSS || tt(dest)->terminal()) - return; - - //find winning move from the source tree - Move bestmove = M_UNKNOWN; - for(Board::MoveIterator move = source.moveit(true); !move.done(); ++move){ - if(tt(source, *move)->phi == LOSS){ - bestmove = *move; - break; - } - } - - if(bestmove == M_UNKNOWN) //due to transposition table collision - return; - - Board dest2 = dest; - - if(bestmove == dmove){ - assert(dest2.move(smove)); - smove = dmove = M_UNKNOWN; - }else{ - assert(dest2.move(bestmove)); - if(bestmove == smove) - smove = dmove = M_UNKNOWN; - } - - if(tt(dest2)->terminal()) - return; - - Board source2 = source; - assert(source2.move(bestmove)); - - if(source2.won() >= 0) - return; - - //test all responses - for(Board::MoveIterator move = dest2.moveit(true); !move.done(); ++move){ - if(tt(dest2, *move)->terminal()) - continue; - - Move csmove = smove, cdmove = dmove; - - Board source3 = source2, dest3 = dest2; - - if(*move == csmove){ - assert(source3.move(cdmove)); - csmove = cdmove = M_UNKNOWN; - }else{ - assert(source3.move(*move)); - if(*move == csmove) - csmove = cdmove = M_UNKNOWN; - } - - assert(dest3.move(*move)); - - copy_proof(source3, dest3, csmove, cdmove); - - updatePDnum(dest3); - } - - updatePDnum(dest2); -} - -SolverPNSTT::PNSNode * SolverPNSTT::tt(const Board & board){ - hash_t hash = board.gethash(); - - PNSNode * node = TT + (hash % maxnodes); - - if(node->hash != hash){ - int outcome, pd; - - if(ab){ - pd = 0; - outcome = (ab == 1 ? solve1ply(board, pd) : solve2ply(board, pd)); - nodes_seen += pd; - }else{ - outcome = board.won(); - pd = 1; - } - - *node = PNSNode(hash).outcome(outcome, board.toplay(), ties, pd); - nodes_seen++; - } - - return node; -} - -SolverPNSTT::PNSNode * SolverPNSTT::tt(const Board & board, Move move){ - hash_t hash = board.test_hash(move, board.toplay()); - - PNSNode * node = TT + (hash % maxnodes); - - if(node->hash != hash){ - int outcome, pd; - - if(ab){ - Board next = board; - next.move(move); - pd = 0; - outcome = (ab == 1 ? solve1ply(next, pd) : solve2ply(next, pd)); - nodes_seen += pd; - }else{ - outcome = board.test_win(move); - pd = 1; - } - - *node = PNSNode(hash).outcome(outcome, board.toplay(), ties, pd); - nodes_seen++; - } - - return node; -} diff --git a/havannah/solverpns_tt.h b/havannah/solverpns_tt.h deleted file mode 100644 index 95d344e..0000000 --- a/havannah/solverpns_tt.h +++ /dev/null @@ -1,129 +0,0 @@ - -#pragma once - -//A single-threaded, transposition table based, proof number search solver. - -#include "../lib/zobrist.h" - -#include "solver.h" - -class SolverPNSTT : public Solver { - static const uint32_t LOSS = (1<<30)-1; - static const uint32_t DRAW = (1<<30)-2; - static const uint32_t INF32 = (1<<30)-3; -public: - - struct PNSNode { - hash_t hash; - uint32_t phi, delta; - - PNSNode() : hash(0), phi(0), delta(0) { } - PNSNode(hash_t h, int v = 1) : hash(h), phi(v), delta(v) { } - PNSNode(hash_t h, int p, int d) : hash(h), phi(p), delta(d) { } - - PNSNode & abval(int outcome, int toplay, int assign, int value = 1){ - if(assign && (outcome == 1 || outcome == -1)) - outcome = (toplay == assign ? 2 : -2); - - if( outcome == 0) { phi = value; delta = value; } //unknown - else if(outcome == 2) { phi = LOSS; delta = 0; } //win - else if(outcome == -2) { phi = 0; delta = LOSS; } //loss - else /*(outcome 1||-1)*/ { phi = 0; delta = DRAW; } //draw - return *this; - } - - PNSNode & outcome(int outcome, int toplay, int assign, int value = 1){ - if(assign && outcome == 0) - outcome = assign; - - if( outcome == -3) { phi = value; delta = value; } - else if(outcome == toplay) { phi = LOSS; delta = 0; } - else if(outcome == 3-toplay) { phi = 0; delta = LOSS; } - else /*(outcome == 0)*/ { phi = 0; delta = DRAW; } - return *this; - } - - bool terminal(){ return (phi == 0 || delta == 0); } - }; - - PNSNode root; - PNSNode * TT; - uint64_t maxnodes, memlimit; - - int ab; // how deep of an alpha-beta search to run at each leaf node - bool df; // go depth first? - float epsilon; //if depth first, how wide should the threshold be? - int ties; //which player to assign ties to: 0 handle ties, 1 assign p1, 2 assign p2 - int copyproof; //how many siblings to try to copy a proof to - - - SolverPNSTT() { - ab = 2; - df = true; - epsilon = 0.25; - ties = 0; - copyproof = 0; - - TT = NULL; - reset(); - - set_memlimit(100*1024*1024); - } - - ~SolverPNSTT(){ - if(TT){ - delete[] TT; - TT = NULL; - } - } - - void reset(){ - outcome = -3; - maxdepth = 0; - nodes_seen = 0; - time_used = 0; - bestmove = Move(M_UNKNOWN); - - timeout = false; - - root = PNSNode(rootboard.gethash(), 1); - } - - void set_board(const Board & board, bool clear = true){ - rootboard = board; - reset(); - if(clear) - clear_mem(); - } - void move(const Move & m){ - rootboard.move(m); - reset(); - } - void set_memlimit(uint64_t lim){ - memlimit = lim; - maxnodes = memlimit/sizeof(PNSNode); - clear_mem(); - } - - void clear_mem(){ - reset(); - if(TT){ - delete[] TT; - TT = NULL; - } - } - - void solve(double time); - -//basic proof number search building a tree - void run_pns(); - void pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td); - - void copy_proof(const Board & source, const Board & dest, Move smove, Move dmove); - -//update the phi and delta for the node - bool updatePDnum(const Board & board, PNSNode * node = NULL); - - PNSNode * tt(const Board & board); - PNSNode * tt(const Board & board, Move move); -}; diff --git a/hex/agent.h b/hex/agent.h index 6adecd2..917b769 100644 --- a/hex/agent.h +++ b/hex/agent.h @@ -3,11 +3,19 @@ //Interface for the various agents: players and solvers +#include "../lib/outcome.h" +#include "../lib/sgf.h" #include "../lib/types.h" #include "board.h" + +namespace Morat { +namespace Hex { + class Agent { +protected: + typedef std::vector vecmove; public: Agent() { } virtual ~Agent() { } @@ -19,51 +27,57 @@ class Agent { virtual void set_memlimit(uint64_t lim) = 0; // in bytes virtual void clear_mem() = 0; - virtual vector get_pv() const = 0; - string move_stats() const { return move_stats(vector()); } - virtual string move_stats(const vector moves) const = 0; + virtual vecmove get_pv() const = 0; + std::string move_stats() const { return move_stats(vecmove()); } + virtual std::string move_stats(const vecmove moves) const = 0; virtual double gamelen() const = 0; virtual void timedout(){ timeout = true; } + virtual void gen_sgf(SGFPrinter & sgf, int limit) const = 0; + virtual void load_sgf(SGFParser & sgf) = 0; + protected: volatile bool timeout; Board rootboard; - static int solve1ply(const Board & board, unsigned int & nodes) { - int outcome = -3; - int turn = board.toplay(); + static Outcome solve1ply(const Board & board, unsigned int & nodes) { + Outcome outcome = Outcome::UNKNOWN; + Side turn = board.toplay(); for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ ++nodes; - int won = board.test_win(*move, turn); + Outcome won = board.test_outcome(*move, turn); - if(won == turn) + if(won == +turn) return won; - if(won == 0) - outcome = 0; + if(won == Outcome::DRAW) + outcome = Outcome::DRAW; } return outcome; } - static int solve2ply(const Board & board, unsigned int & nodes) { + static Outcome solve2ply(const Board & board, unsigned int & nodes) { int losses = 0; - int outcome = -3; - int turn = board.toplay(), opponent = 3 - turn; + Outcome outcome = Outcome::UNKNOWN; + Side turn = board.toplay(); + Side op = ~turn; for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ ++nodes; - int won = board.test_win(*move, turn); + Outcome won = board.test_outcome(*move, turn); - if(won == turn) + if(won == +turn) return won; - if(won == 0) - outcome = 0; + if(won == Outcome::DRAW) + outcome = Outcome::DRAW; - if(board.test_win(*move, opponent) > 0) + if(board.test_outcome(*move, op) == +op) losses++; } if(losses >= 2) - return opponent; + return (Outcome)op; return outcome; } - }; + +}; // namespace Hex +}; // namespace Morat diff --git a/hex/agentab.cpp b/hex/agentab.cpp index 2c66bce..f76fd2d 100644 --- a/hex/agentab.cpp +++ b/hex/agentab.cpp @@ -6,6 +6,10 @@ #include "agentab.h" + +namespace Morat { +namespace Hex { + void AgentAB::search(double time, uint64_t maxiters, int verbose) { reset(); if(rootboard.won() >= 0) @@ -41,8 +45,8 @@ void AgentAB::search(double time, uint64_t maxiters, int verbose) { if(verbose){ logerr("Finished: " + to_str(nodes_seen) + " nodes in " + to_str(time_used*1000, 0) + " msec: " + to_str((uint64_t)((double)nodes_seen/time_used)) + " Nodes/s\n"); - vector pv = get_pv(); - string pvstr; + vecmove pv = get_pv(); + std::string pvstr; for(auto m : pv) pvstr += " " + m.to_s(); logerr("PV: " + pvstr + "\n"); @@ -56,11 +60,11 @@ void AgentAB::search(double time, uint64_t maxiters, int verbose) { int16_t AgentAB::negamax(const Board & board, int16_t alpha, int16_t beta, int depth) { nodes_seen++; - int won = board.won(); - if(won >= 0){ - if(won == 0) + Outcome won = board.won(); + if(won >= Outcome::DRAW){ + if(won == Outcome::DRAW) return SCORE_DRAW; - if(won == board.toplay()) + if(won == +board.toplay()) return SCORE_WIN; return SCORE_LOSS; } @@ -81,8 +85,8 @@ int16_t AgentAB::negamax(const Board & board, int16_t alpha, int16_t beta, int d if(TT && (node = tt_get(board)) && node->depth >= depth){ switch(node->flag){ case VALID: return node->score; - case LBOUND: alpha = max(alpha, node->score); break; - case UBOUND: beta = min(beta, node->score); break; + case LBOUND: alpha = std::max(alpha, node->score); break; + case UBOUND: beta = std::min(beta, node->score); break; default: assert(false && "Unknown flag!"); } if(alpha >= beta) @@ -125,11 +129,11 @@ int16_t AgentAB::negamax(const Board & board, int16_t alpha, int16_t beta, int d return score; } -string AgentAB::move_stats(vector moves) const { - string s = ""; +std::string AgentAB::move_stats(vecmove moves) const { + std::string s = ""; Board b = rootboard; - for(vector::iterator m = moves.begin(); m != moves.end(); ++m) + for(vecmove::iterator m = moves.begin(); m != moves.end(); ++m) b.move(*m); for(MoveIterator move(b); !move.done(); ++move){ @@ -162,8 +166,8 @@ Move AgentAB::return_move(const Board & board, int verbose) const { return best; } -vector AgentAB::get_pv() const { - vector pv; +std::vector AgentAB::get_pv() const { + vecmove pv; Board b = rootboard; int i = 20; @@ -197,3 +201,6 @@ AgentAB::Node * AgentAB::tt_get(uint64_t h) const { void AgentAB::tt_set(const Node & n) { *(tt(n.hash)) = n; } + +}; // namespace Hex +}; // namespace Morat diff --git a/hex/agentab.h b/hex/agentab.h index 646043f..4c90787 100644 --- a/hex/agentab.h +++ b/hex/agentab.h @@ -7,6 +7,10 @@ #include "agent.h" + +namespace Morat { +namespace Hex { + class AgentAB : public Agent { static const int16_t SCORE_WIN = 32767; static const int16_t SCORE_LOSS = -32767; @@ -30,7 +34,7 @@ class AgentAB : public Agent { Node(uint64_t h = ~0ull, int16_t s = 0, Move b = M_UNKNOWN, int8_t d = 0, int8_t f = 0) : //. int8_t o = -3 hash(h), score(s), bestmove(b), depth(d), flag(f), padding(0xDEAD) { } //, outcome(o) - string to_s() const { + std::string to_s() const { return "score " + to_str(score) + ", depth " + to_str((int)depth) + ", flag " + to_str((int)flag) + @@ -93,8 +97,16 @@ class AgentAB : public Agent { void search(double time, uint64_t maxiters, int verbose); Move return_move(int verbose) const { return return_move(rootboard, verbose); } double gamelen() const { return rootboard.movesremain(); } - vector get_pv() const; - string move_stats(vector moves) const; + vecmove get_pv() const; + std::string move_stats(vecmove moves) const; + + void gen_sgf(SGFPrinter & sgf, int limit) const { + log("gen_sgf not supported in the ab agent."); + } + + void load_sgf(SGFParser & sgf) { + log("load_sgf not supported in the ab agent."); + } private: int16_t negamax(const Board & board, int16_t alpha, int16_t beta, int depth); @@ -105,3 +117,6 @@ class AgentAB : public Agent { Node * tt_get(const Board & b) const ; void tt_set(const Node & n) ; }; + +}; // namespace Hex +}; // namespace Morat diff --git a/hex/agentmcts.cpp b/hex/agentmcts.cpp index 6f4822c..80d6b6a 100644 --- a/hex/agentmcts.cpp +++ b/hex/agentmcts.cpp @@ -10,12 +10,45 @@ #include "agentmcts.h" #include "board.h" + +namespace Morat { +namespace Hex { + const float AgentMCTS::min_rave = 0.1; +std::string AgentMCTS::Node::to_s() const { + return "AgentMCTS::Node" + ", move " + move.to_s() + + ", exp " + exp.to_s() + + ", rave " + rave.to_s() + + ", know " + to_str(know) + + ", outcome " + to_str((int)outcome.to_i()) + + ", depth " + to_str((int)proofdepth) + + ", best " + bestmove.to_s() + + ", children " + to_str(children.num()); +} + +bool AgentMCTS::Node::from_s(std::string s) { + auto dict = parse_dict(s, ", ", " "); + + if(dict.size() == 9){ + move = Move(dict["move"]); + exp = ExpPair(dict["exp"]); + rave = ExpPair(dict["rave"]); + know = from_str(dict["know"]); + outcome = Outcome(from_str(dict["outcome"])); + proofdepth = from_str(dict["depth"]); + bestmove = Move(dict["best"]); + // ignore children + return true; + } + return false; +} + void AgentMCTS::search(double time, uint64_t max_runs, int verbose){ - int toplay = rootboard.toplay(); + Side toplay = rootboard.toplay(); - if(rootboard.won() >= 0 || (time <= 0 && max_runs == 0)) + if(rootboard.won() >= Outcome::DRAW || (time <= 0 && max_runs == 0)) return; Time starttime; @@ -56,30 +89,23 @@ void AgentMCTS::search(double time, uint64_t max_runs, int verbose){ logerr("Times: " + to_str(times[0], 3) + ", " + to_str(times[1], 3) + ", " + to_str(times[2], 3) + ", " + to_str(times[3], 3) + "\n"); } - if(root.outcome != -3){ - logerr("Solved as a "); - if( root.outcome == 0) logerr("draw\n"); - else if(root.outcome == 3) logerr("draw by simultaneous win\n"); - else if(root.outcome == toplay) logerr("win\n"); - else if(root.outcome == 3-toplay) logerr("loss\n"); - else if(root.outcome == -toplay) logerr("win or draw\n"); - else if(root.outcome == toplay-3) logerr("loss or draw\n"); - } + if(root.outcome != Outcome::UNKNOWN) + logerr("Solved as a " + root.outcome.to_s_rel(toplay) + "\n"); - string pvstr; + std::string pvstr; for(auto m : get_pv()) pvstr += " " + m.to_s(); logerr("PV: " + pvstr + "\n"); if(verbose >= 3 && !root.children.empty()) - logerr("Move stats:\n" + move_stats(vector())); + logerr("Move stats:\n" + move_stats(vecmove())); } pool.reset(); runs = 0; - if(ponder && root.outcome < 0) + if(ponder && root.outcome < Outcome::DRAW) pool.resume(); } @@ -194,8 +220,8 @@ void AgentMCTS::move(const Move & m){ rootboard.move(m); root.exp.addwins(visitexpand+1); //+1 to compensate for the virtual loss - if(rootboard.won() < 0) - root.outcome = -3; + if(rootboard.won() < Outcome::DRAW) + root.outcome = Outcome::UNKNOWN; if(ponder) pool.resume(); @@ -208,16 +234,16 @@ double AgentMCTS::gamelen() const { return len.avg(); } -vector AgentMCTS::get_pv() const { - vector pv; +std::vector AgentMCTS::get_pv() const { + vecmove pv; const Node * n = & root; - char turn = rootboard.toplay(); + Side turn = rootboard.toplay(); while(n && !n->children.empty()){ Move m = return_move(n, turn); pv.push_back(m); n = find_child(n, m); - turn = 3 - turn; + turn = ~turn; } if(pv.size() == 0) @@ -226,8 +252,8 @@ vector AgentMCTS::get_pv() const { return pv; } -string AgentMCTS::move_stats(vector moves) const { - string s = ""; +std::string AgentMCTS::move_stats(vecmove moves) const { + std::string s = ""; const Node * node = & root; if(moves.size()){ @@ -248,8 +274,8 @@ string AgentMCTS::move_stats(vector moves) const { return s; } -Move AgentMCTS::return_move(const Node * node, int toplay, int verbose) const { - if(node->outcome >= 0) +Move AgentMCTS::return_move(const Node * node, Side toplay, int verbose) const { + if(node->outcome >= Outcome::DRAW) return node->bestmove; double val, maxval = -1000000000000.0; //1 trillion @@ -259,10 +285,10 @@ Move AgentMCTS::return_move(const Node * node, int toplay, int verbose) const { * end = node->children.end(); for( ; child != end; child++){ - if(child->outcome >= 0){ - if(child->outcome == toplay) val = 800000000000.0 - child->exp.num(); //shortest win - else if(child->outcome == 0) val = -400000000000.0 + child->exp.num(); //longest tie - else val = -800000000000.0 + child->exp.num(); //longest loss + if(child->outcome >= Outcome::DRAW){ + if(child->outcome == toplay) val = 800000000000.0 - child->exp.num(); //shortest win + else if(child->outcome == Outcome::DRAW) val = -400000000000.0 + child->exp.num(); //longest tie + else val = -800000000000.0 + child->exp.num(); //longest loss }else{ //not proven if(msrave == -1) //num simulations val = child->exp.num(); @@ -290,13 +316,13 @@ void AgentMCTS::garbage_collect(Board & board, Node * node){ Node * child = node->children.begin(), * end = node->children.end(); - int toplay = board.toplay(); + Side toplay = board.toplay(); for( ; child != end; child++){ if(child->children.num() == 0) continue; - if( (node->outcome >= 0 && child->exp.num() > gcsolved && (node->outcome != toplay || child->outcome == toplay || child->outcome == 0)) || //parent is solved, only keep the proof tree, plus heavy draws - (node->outcome < 0 && child->exp.num() > (child->outcome >= 0 ? gcsolved : gclimit)) ){ // only keep heavy nodes, with different cutoffs for solved and unsolved + if( (node->outcome >= Outcome::DRAW && child->exp.num() > gcsolved && (node->outcome != toplay || child->outcome == toplay || child->outcome == Outcome::DRAW)) || //parent is solved, only keep the proof tree, plus heavy draws + (node->outcome < Outcome::DRAW && child->exp.num() > (child->outcome >= Outcome::DRAW ? gcsolved : gclimit)) ){ // only keep heavy nodes, with different cutoffs for solved and unsolved board.set(child->move); garbage_collect(board, child); board.unset(child->move); @@ -307,36 +333,22 @@ void AgentMCTS::garbage_collect(Board & board, Node * node){ } AgentMCTS::Node * AgentMCTS::find_child(const Node * node, const Move & move) const { - for(Node * i = node->children.begin(); i != node->children.end(); i++) - if(i->move == move) - return i; - + for(auto & c : node->children) + if(c.move == move) + return &c; return NULL; } -void AgentMCTS::gen_hgf(Board & board, Node * node, unsigned int limit, unsigned int depth, FILE * fd){ - string s = string("\n") + string(depth, ' ') + "(;" + (board.toplay() == 2 ? "W" : "B") + "[" + node->move.to_s() + "]" + - "C[mcts, sims:" + to_str(node->exp.num()) + ", avg:" + to_str(node->exp.avg(), 4) + ", outcome:" + to_str((int)(node->outcome)) + ", best:" + node->bestmove.to_s() + "]"; - fprintf(fd, "%s", s.c_str()); - - Node * child = node->children.begin(), - * end = node->children.end(); - - int toplay = board.toplay(); - - bool children = false; - for( ; child != end; child++){ - if(child->exp.num() >= limit && (toplay != node->outcome || child->outcome == node->outcome) ){ - board.set(child->move); - gen_hgf(board, child, limit, depth+1, fd); - board.unset(child->move); - children = true; +void AgentMCTS::gen_sgf(SGFPrinter & sgf, unsigned int limit, const Node & node, Side side) const { + for(auto & child : node.children){ + if(child.exp.num() >= limit && (side != node.outcome || child.outcome == node.outcome)){ + sgf.child_start(); + sgf.move(side, child.move); + sgf.comment(child.to_s()); + gen_sgf(sgf, limit, child, ~side); + sgf.child_end(); } } - - if(children) - fprintf(fd, "\n%s", string(depth, ' ').c_str()); - fprintf(fd, ")"); } void AgentMCTS::create_children_simple(const Board & board, Node * node){ @@ -361,64 +373,25 @@ void AgentMCTS::create_children_simple(const Board & board, Node * node){ PLUS(nodes, node->children.num()); } -//reads the format from gen_hgf. -void AgentMCTS::load_hgf(Board board, Node * node, FILE * fd){ - char c, buf[101]; - - eat_whitespace(fd); - - assert(fscanf(fd, "(;%c[%100[^]]]", &c, buf) > 0); +void AgentMCTS::load_sgf(SGFParser & sgf, const Board & board, Node & node) { + assert(sgf.has_children()); + create_children_simple(board, & node); - assert(board.toplay() == (c == 'W' ? 1 : 2)); - node->move = Move(buf); - board.move(node->move); - - assert(fscanf(fd, "C[%100[^]]]", buf) > 0); - - vecstr entry, parts = explode(string(buf), ", "); - assert(parts[0] == "mcts"); - - entry = explode(parts[1], ":"); - assert(entry[0] == "sims"); - uword sims = from_str(entry[1]); - - entry = explode(parts[2], ":"); - assert(entry[0] == "avg"); - double avg = from_str(entry[1]); - - uword wins = sims*avg; - node->exp.addwins(wins); - node->exp.addlosses(sims - wins); - - entry = explode(parts[3], ":"); - assert(entry[0] == "outcome"); - node->outcome = from_str(entry[1]); - - entry = explode(parts[4], ":"); - assert(entry[0] == "best"); - node->bestmove = Move(entry[1]); - - - eat_whitespace(fd); - - if(fpeek(fd) != ')'){ - create_children_simple(board, node); - - while(fpeek(fd) != ')'){ - Node child; - load_hgf(board, & child, fd); - - Node * i = find_child(node, child.move); - *i = child; //copy the child experience to the tree - i->swap_tree(child); //move the child subtree to the tree - - assert(child.children.empty()); - - eat_whitespace(fd); + while(sgf.next_child()){ + Move m = sgf.move(); + Node & child = *find_child(&node, m); + child.from_s(sgf.comment()); + if(sgf.done_child()){ + continue; + }else{ + // has children! + Board b = board; + b.move(m); + load_sgf(sgf, b, child); + assert(sgf.done_child()); } } - - eat_char(fd, ')'); - - return; } + +}; // namespace Hex +}; // namespace Morat diff --git a/hex/agentmcts.h b/hex/agentmcts.h index 2da03fc..61776ef 100644 --- a/hex/agentmcts.h +++ b/hex/agentmcts.h @@ -11,6 +11,12 @@ #include "../lib/depthstats.h" #include "../lib/exppair.h" #include "../lib/log.h" +#include "../lib/move.h" +#include "../lib/movelist.h" +#include "../lib/policy_bridge.h" +#include "../lib/policy_instantwin.h" +#include "../lib/policy_lastgoodreply.h" +#include "../lib/policy_random.h" #include "../lib/thread.h" #include "../lib/time.h" #include "../lib/types.h" @@ -19,14 +25,11 @@ #include "agent.h" #include "board.h" #include "lbdist.h" -#include "move.h" -#include "movelist.h" -#include "policy_bridge.h" -#include "policy_instantwin.h" -#include "policy_lastgoodreply.h" -#include "policy_random.h" +namespace Morat { +namespace Hex { + class AgentMCTS : public Agent{ public: @@ -35,7 +38,7 @@ class AgentMCTS : public Agent{ ExpPair rave; ExpPair exp; int16_t know; - int8_t outcome; + Outcome outcome; uint8_t proofdepth; Move move; Move bestmove; //if outcome is set, then bestmove is the way to get there @@ -44,8 +47,8 @@ class AgentMCTS : public Agent{ //seems to need padding to multiples of 8 bytes or it segfaults? //don't forget to update the copy constructor/operator - Node() : know(0), outcome(-3), proofdepth(0) { } - Node(const Move & m, char o = -3) : know(0), outcome( o), proofdepth(0), move(m) { } + Node() : know(0), outcome(Outcome::UNKNOWN), proofdepth(0), move(M_NONE) { } + Node(const Move & m, Outcome o = Outcome::UNKNOWN) : know(0), outcome(o), proofdepth(0), move(m) { } Node(const Node & n) { *this = n; } Node & operator = (const Node & n){ if(this != & n){ //don't copy to self @@ -68,18 +71,8 @@ class AgentMCTS : public Agent{ children.swap(n.children); } - void print() const { - printf("%s\n", to_s().c_str()); - } - string to_s() const { - return "Node: move " + move.to_s() + - ", exp " + to_str(exp.avg(), 2) + "/" + to_str(exp.num()) + - ", rave " + to_str(rave.avg(), 2) + "/" + to_str(rave.num()) + - ", know " + to_str(know) + - ", outcome " + to_str((int)outcome) + "/" + to_str((int)proofdepth) + - ", best " + bestmove.to_s() + - ", children " + to_str(children.num()); - } + std::string to_s() const ; + bool from_s(std::string s); unsigned int size() const { unsigned int num = children.num(); @@ -142,16 +135,16 @@ class AgentMCTS : public Agent{ class AgentThread : public AgentThreadBase { mutable XORShift_float unitrand; - LastGoodReply last_good_reply; - RandomPolicy random_policy; - ProtectBridge protect_bridge; - InstantWin instant_wins; + LastGoodReply last_good_reply; + RandomPolicy random_policy; + ProtectBridge protect_bridge; + InstantWin instant_wins; bool use_rave; //whether to use rave for this simulation bool use_explore; //whether to use exploration for this simulation LBDists dists; //holds the distances to the various non-ring wins as a heuristic for the minimum moves needed to win - MoveList movelist; + MoveList movelist; int stage; //which of the four MCTS stages is it on public: @@ -179,11 +172,11 @@ class AgentMCTS : public Agent{ void walk_tree(Board & board, Node * node, int depth); bool create_children(const Board & board, Node * node); void add_knowledge(const Board & board, Node * node, Node * child); - Node * choose_move(const Node * node, int toplay, int remain) const; - void update_rave(const Node * node, int toplay); + Node * choose_move(const Node * node, Side toplay, int remain) const; + void update_rave(const Node * node, Side toplay); bool test_bridge_probe(const Board & board, const Move & move, const Move & test) const; - int rollout(Board & board, Move move, int depth); + Outcome rollout(Board & board, Move move, int depth); Move rollout_choose_move(Board & board, const Move & prev); Move rollout_pattern(const Board & board, const Move & move); }; @@ -261,12 +254,12 @@ class AgentMCTS : public Agent{ Move return_move(int verbose) const { return return_move(& root, rootboard.toplay(), verbose); } double gamelen() const; - vector get_pv() const; - string move_stats(const vector moves) const; + vecmove get_pv() const; + std::string move_stats(const vecmove moves) const; bool done() { //solved or finished runs - return (rootboard.won() >= 0 || root.outcome >= 0 || (maxruns > 0 && runs >= maxruns)); + return (rootboard.won() >= Outcome::DRAW || root.outcome >= Outcome::DRAW || (maxruns > 0 && runs >= maxruns)); } bool need_gc() { @@ -292,16 +285,28 @@ class AgentMCTS : public Agent{ gclimit = (int)(gclimit*0.9); //slowly decay to a minimum of 5 } + void gen_sgf(SGFPrinter & sgf, int limit) const { + if(limit < 0) + limit = root.exp.num()/1000; + gen_sgf(sgf, limit, root, rootboard.toplay()); + } + + void load_sgf(SGFParser & sgf) { + load_sgf(sgf, rootboard, root); + } protected: void garbage_collect(Board & board, Node * node); //destroys the board, so pass in a copy - bool do_backup(Node * node, Node * backup, int toplay); - Move return_move(const Node * node, int toplay, int verbose = 0) const; + bool do_backup(Node * node, Node * backup, Side toplay); + Move return_move(const Node * node, Side toplay, int verbose = 0) const; Node * find_child(const Node * node, const Move & move) const ; void create_children_simple(const Board & board, Node * node); - void gen_hgf(Board & board, Node * node, unsigned int limit, unsigned int depth, FILE * fd); - void load_hgf(Board board, Node * node, FILE * fd); + void gen_sgf(SGFPrinter & sgf, unsigned int limit, const Node & node, Side side) const ; + void load_sgf(SGFParser & sgf, const Board & board, Node & node); }; + +}; // namespace Hex +}; // namespace Morat diff --git a/hex/agentmcts_test.cpp b/hex/agentmcts_test.cpp new file mode 100644 index 0000000..03a5869 --- /dev/null +++ b/hex/agentmcts_test.cpp @@ -0,0 +1,16 @@ + +#include "../lib/catch.hpp" + +#include "agentmcts.h" + + +using namespace Morat; +using namespace Hex; + +TEST_CASE("Hex::AgentMCTS::Node::to_s/from_s", "[hex][agentmcts]") { + AgentMCTS::Node n(Move("a1")); + auto s = n.to_s(); + AgentMCTS::Node k; + REQUIRE(k.from_s(s)); + REQUIRE(n.to_s() == k.to_s()); +} diff --git a/hex/agentmctsthread.cpp b/hex/agentmctsthread.cpp index 8231d5f..e0ecfa0 100644 --- a/hex/agentmctsthread.cpp +++ b/hex/agentmctsthread.cpp @@ -6,6 +6,10 @@ #include "agentmcts.h" + +namespace Morat { +namespace Hex { + void AgentMCTS::AgentThread::iterate(){ INCR(agent->runs); if(agent->profile){ @@ -19,7 +23,7 @@ void AgentMCTS::AgentThread::iterate(){ use_rave = (unitrand() < agent->userave); use_explore = (unitrand() < agent->useexplore); walk_tree(copy, & agent->root, 0); - agent->root.exp.addv(movelist.getexp(3-agent->rootboard.toplay())); + agent->root.exp.addv(movelist.getexp(~agent->rootboard.toplay())); if(agent->profile){ times[0] += timestamps[1] - timestamps[0]; @@ -30,16 +34,16 @@ void AgentMCTS::AgentThread::iterate(){ } void AgentMCTS::AgentThread::walk_tree(Board & board, Node * node, int depth){ - int toplay = board.toplay(); + Side toplay = board.toplay(); - if(!node->children.empty() && node->outcome < 0){ + if(!node->children.empty() && node->outcome < Outcome::DRAW){ //choose a child and recurse Node * child; do{ int remain = board.movesremain(); child = choose_move(node, toplay, remain); - if(child->outcome < 0){ + if(child->outcome < Outcome::DRAW){ movelist.addtree(child->move, toplay); if(!board.move(child->move)){ @@ -71,10 +75,10 @@ void AgentMCTS::AgentThread::walk_tree(Board & board, Node * node, int depth){ timestamps[1] = Time(); } - int won = (agent->minimax ? node->outcome : board.won()); + Outcome won = (agent->minimax ? node->outcome : board.won()); //if it's not already decided - if(won < 0){ + if(won < Outcome::DRAW){ //create children if valid if(node->exp.num() >= agent->visitexpand+1 && create_children(board, node)){ walk_tree(board, node, depth); @@ -125,6 +129,8 @@ bool AgentMCTS::AgentThread::create_children(const Board & board, Node * node){ CompactTree::Children temp; temp.alloc(board.movesremain(), agent->ctmem); + Side toplay = board.toplay(); + Side opponent = ~toplay; int losses = 0; Node * child = temp.begin(), @@ -136,14 +142,14 @@ bool AgentMCTS::AgentThread::create_children(const Board & board, Node * node){ *child = Node(*move); if(agent->minimax){ - child->outcome = board.test_win(*move); + child->outcome = board.test_outcome(*move); - if(agent->minimax >= 2 && board.test_win(*move, 3 - board.toplay()) > 0){ + if(agent->minimax >= 2 && board.test_outcome(*move, opponent) == +opponent){ losses++; loss = child; } - if(child->outcome == board.toplay()){ //proven win from here, don't need children + if(child->outcome == +toplay){ //proven win from here, don't need children node->outcome = child->outcome; node->proofdepth = 1; node->bestmove = *move; @@ -171,7 +177,7 @@ bool AgentMCTS::AgentThread::create_children(const Board & board, Node * node){ macro.exp.addwins(agent->visitexpand); *(temp.begin()) = macro; }else if(losses >= 2){ //proven loss, but at least try to block one of them - node->outcome = 3 - board.toplay(); + node->outcome = +opponent; node->proofdepth = 2; node->bestmove = loss->move; node->children.unlock(); @@ -180,7 +186,7 @@ bool AgentMCTS::AgentThread::create_children(const Board & board, Node * node){ } if(agent->dynwiden > 0) //sort in decreasing order by knowledge - sort(temp.begin(), temp.end(), sort_node_know); + std::sort(temp.begin(), temp.end(), sort_node_know); PLUS(agent->nodes, temp.num()); node->children.swap(temp); @@ -189,7 +195,7 @@ bool AgentMCTS::AgentThread::create_children(const Board & board, Node * node){ return true; } -AgentMCTS::Node * AgentMCTS::AgentThread::choose_move(const Node * node, int toplay, int remain) const { +AgentMCTS::Node * AgentMCTS::AgentThread::choose_move(const Node * node, Side toplay, int remain) const { float val, maxval = -1000000000; float logvisits = log(node->exp.num()); int dynwidenlim = (agent->dynwiden > 0 ? (int)(logvisits/agent->logdynwiden)+2 : Board::max_vecsize); @@ -204,11 +210,11 @@ AgentMCTS::Node * AgentMCTS::AgentThread::choose_move(const Node * node, int top * end = node->children.end(); for(; child != end && dynwidenlim >= 0; child++){ - if(child->outcome >= 0){ + if(child->outcome >= Outcome::DRAW){ if(child->outcome == toplay) //return a win immediately return child; - val = (child->outcome == 0 ? -1 : -2); //-1 for tie so any unknown is better, -2 for loss so it's even worse + val = (child->outcome == Outcome::DRAW ? -1 : -2); //-1 for tie so any unknown is better, -2 for loss so it's even worse }else{ val = child->value(raveval, agent->knowledge, agent->fpurgency); if(explore > 0) @@ -237,80 +243,80 @@ backup in this order: 0 lose return true if fully solved, false if it's unknown or partially unknown */ -bool AgentMCTS::do_backup(Node * node, Node * backup, int toplay){ - int nodeoutcome = node->outcome; - if(nodeoutcome >= 0) //already proven, probably by a different thread +bool AgentMCTS::do_backup(Node * node, Node * backup, Side toplay){ + Outcome node_outcome = node->outcome; + if(node_outcome >= Outcome::DRAW) //already proven, probably by a different thread return true; - if(backup->outcome == -3) //nothing proven by this child, so no chance + if(backup->outcome == Outcome::UNKNOWN) //nothing proven by this child, so no chance return false; uint8_t proofdepth = backup->proofdepth; if(backup->outcome != toplay){ - uint64_t sims = 0, bestsims = 0, outcome = 0, bestoutcome = 0; + uint64_t sims = 0, bestsims = 0, outcome = 0, best_outcome = 0; backup = NULL; Node * child = node->children.begin(), * end = node->children.end(); for( ; child != end; child++){ - int childoutcome = child->outcome; //save a copy to avoid race conditions + Outcome child_outcome = child->outcome; //save a copy to avoid race conditions if(proofdepth < child->proofdepth+1) proofdepth = child->proofdepth+1; //these should be sorted in likelyness of matching, most likely first - if(childoutcome == -3){ // win/draw/loss + if(child_outcome == Outcome::UNKNOWN){ // win/draw/loss outcome = 3; - }else if(childoutcome == toplay){ //win + }else if(child_outcome == toplay){ //win backup = child; outcome = 6; proofdepth = child->proofdepth+1; break; - }else if(childoutcome == 3-toplay){ //loss + }else if(child_outcome == ~toplay){ //loss outcome = 0; - }else if(childoutcome == 0){ //draw - if(nodeoutcome == toplay-3) //draw/loss + }else if(child_outcome == Outcome::DRAW){ //draw + if(node_outcome == -toplay) //draw/loss, ie I can't win outcome = 4; else outcome = 2; - }else if(childoutcome == -toplay){ //win/draw + }else if(child_outcome == -~toplay){ //win/draw, ie opponent can't win outcome = 5; - }else if(childoutcome == toplay-3){ //draw/loss + }else if(child_outcome == -toplay){ //draw/loss, ie I can't win outcome = 1; }else{ - logerr("childoutcome == " + to_str(childoutcome) + "\n"); + logerr("child_outcome == " + child_outcome.to_s() + "\n"); assert(false && "How'd I get here? All outcomes should be tested above"); } sims = child->exp.num(); - if(bestoutcome < outcome){ //better outcome is always preferable - bestoutcome = outcome; + if(best_outcome < outcome){ //better outcome is always preferable + best_outcome = outcome; bestsims = sims; backup = child; - }else if(bestoutcome == outcome && ((outcome == 0 && bestsims < sims) || bestsims > sims)){ + }else if(best_outcome == outcome && ((outcome == 0 && bestsims < sims) || bestsims > sims)){ //find long losses or easy wins/draws bestsims = sims; backup = child; } } - if(bestoutcome == 3) //no win, but found an unknown + if(best_outcome == 3) //no win, but found an unknown return false; } - if(CAS(node->outcome, nodeoutcome, backup->outcome)){ + if(node->outcome.cas(node_outcome, backup->outcome)){ node->bestmove = backup->move; node->proofdepth = proofdepth; }else //if it was in a race, try again, might promote a partial solve to full solve return do_backup(node, backup, toplay); - return (node->outcome >= 0); + return (node->outcome >= Outcome::DRAW); } //update the rave score of all children that were played -void AgentMCTS::AgentThread::update_rave(const Node * node, int toplay){ +void AgentMCTS::AgentThread::update_rave(const Node * node, Side toplay){ Node * child = node->children.begin(), * childend = node->children.end(); @@ -321,7 +327,7 @@ void AgentMCTS::AgentThread::update_rave(const Node * node, int toplay){ void AgentMCTS::AgentThread::add_knowledge(const Board & board, Node * node, Node * child){ if(agent->localreply){ //boost for moves near the previous move - int dist = node->move.dist(child->move); + int dist = board.dist(node->move, child->move); if(dist < 4) child->know += agent->localreply * (4 - dist); } @@ -343,24 +349,24 @@ void AgentMCTS::AgentThread::add_knowledge(const Board & board, Node * node, Nod child->know += agent->bridge; if(agent->dists) - child->know += abs(agent->dists) * max(0, board.get_size() - dists.get(child->move, board.toplay())); + child->know += abs(agent->dists) * std::max(0, board.get_size() - dists.get(child->move, board.toplay())); } //test whether this move is a forced reply to the opponent probing your virtual connections bool AgentMCTS::AgentThread::test_bridge_probe(const Board & board, const Move & move, const Move & test) const { //TODO: switch to the same method as policy_bridge.h, maybe even share code - if(move.dist(test) != 1) + if(board.dist(move, test) != 1) return false; bool equals = false; int state = 0; - int piece = 3 - board.get(move); + Side piece = ~board.get(move); for(int i = 0; i < 8; i++){ Move cur = move + neighbours[i % 6]; bool on = board.onboard(cur); - int v = 0; + Side v = Side::NONE; if(on) v = board.get(cur); @@ -371,7 +377,7 @@ bool AgentMCTS::AgentThread::test_bridge_probe(const Board & board, const Move & //else state = 0; }else if(state == 1){ if(on){ - if(v == 0){ + if(v == Side::NONE){ state = 2; equals = (test == cur); }else if(v != piece) @@ -396,16 +402,16 @@ bool AgentMCTS::AgentThread::test_bridge_probe(const Board & board, const Move & //play a random game starting from a board state, and return the results of who won -int AgentMCTS::AgentThread::rollout(Board & board, Move move, int depth){ - int won; +Outcome AgentMCTS::AgentThread::rollout(Board & board, Move move, int depth){ + Outcome won; if(agent->instantwin) instant_wins.rollout_start(board, agent->instantwin); random_policy.rollout_start(board); - while((won = board.won()) < 0){ - int turn = board.toplay(); + while((won = board.won()) < Outcome::DRAW){ + Side turn = board.toplay(); move = rollout_choose_move(board, move); @@ -449,3 +455,6 @@ Move AgentMCTS::AgentThread::rollout_choose_move(Board & board, const Move & pre return random_policy.choose_move(board, prev); } + +}; // namespace Hex +}; // namespace Morat diff --git a/hex/agentpns.cpp b/hex/agentpns.cpp index ec270ff..8222101 100644 --- a/hex/agentpns.cpp +++ b/hex/agentpns.cpp @@ -5,6 +5,40 @@ #include "agentpns.h" + +namespace Morat { +namespace Hex { + +std::string AgentPNS::Node::to_s() const { + return "AgentPNS::Node" + ", move " + move.to_s() + + ", phi " + to_str(phi) + + ", delta " + to_str(delta) + + ", work " + to_str(work) + + ", children " + to_str(children.num()); +} + +bool AgentPNS::Node::from_s(std::string s) { + auto dict = parse_dict(s, ", ", " "); + + if(dict.size() == 6){ + move = Move(dict["move"]); + phi = from_str(dict["phi"]); + delta = from_str(dict["delta"]); + work = from_str(dict["work"]); + // ignore children + return true; + } + return false; +} + +void AgentPNS::test() { + Node n(Move("a1")); + auto s = n.to_s(); + Node k; + assert(k.from_s(s)); +} + void AgentPNS::search(double time, uint64_t maxiters, int verbose){ max_nodes_seen = maxiters; @@ -32,27 +66,20 @@ void AgentPNS::search(double time, uint64_t maxiters, int verbose){ logerr("Tree depth: " + treelen.to_s() + "\n"); } - int toplay = rootboard.toplay(); + Side toplay = rootboard.toplay(); logerr("Root: " + root.to_s() + "\n"); - int outcome = root.to_outcome(3-toplay); - if(outcome != -3){ - logerr("Solved as a "); - if( outcome == 0) logerr("draw\n"); - else if(outcome == 3) logerr("draw by simultaneous win\n"); - else if(outcome == toplay) logerr("win\n"); - else if(outcome == 3-toplay) logerr("loss\n"); - else if(outcome == -toplay) logerr("win or draw\n"); - else if(outcome == toplay-3) logerr("loss or draw\n"); - } + Outcome outcome = root.to_outcome(~toplay); + if(outcome != Outcome::UNKNOWN) + logerr("Solved as a " + outcome.to_s_rel(toplay) + "\n"); - string pvstr; + std::string pvstr; for(auto m : get_pv()) pvstr += " " + m.to_s(); logerr("PV: " + pvstr + "\n"); if(verbose >= 3 && !root.children.empty()) - logerr("Move stats:\n" + move_stats(vector())); + logerr("Move stats:\n" + move_stats(vecmove())); } } @@ -83,8 +110,8 @@ bool AgentPNS::AgentThread::pns(const Board & board, Node * node, int depth, uin unsigned int i = 0; for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - unsigned int pd = 1; - int outcome; + unsigned int pd; + Outcome outcome; if(agent->ab){ Board next = board; @@ -94,10 +121,10 @@ bool AgentPNS::AgentThread::pns(const Board & board, Node * node, int depth, uin outcome = (agent->ab == 1 ? solve1ply(next, pd) : solve2ply(next, pd)); }else{ pd = 1; - outcome = board.test_win(*move); + outcome = board.test_outcome(*move); } - if(agent->lbdist && outcome < 0) + if(agent->lbdist && outcome != Outcome::UNKNOWN) pd = dists.get(*move); temp[i] = Node(*move).outcome(outcome, board.toplay(), agent->ties, pd); @@ -132,8 +159,8 @@ bool AgentPNS::AgentThread::pns(const Board & board, Node * node, int depth, uin } } - tpc = min(INF32/2, (td + child->phi - node->delta)); - tdc = min(tp, (uint32_t)(child2->delta*(1.0 + agent->epsilon) + 1)); + tpc = std::min(INF32/2, (td + child->phi - node->delta)); + tdc = std::min(tp, (uint32_t)(child2->delta*(1.0 + agent->epsilon) + 1)); }else{ tpc = tdc = 0; for(auto & i : node->children) @@ -198,16 +225,16 @@ double AgentPNS::gamelen() const { return rootboard.movesremain(); } -vector AgentPNS::get_pv() const { - vector pv; +std::vector AgentPNS::get_pv() const { + vecmove pv; const Node * n = & root; - char turn = rootboard.toplay(); + Side turn = rootboard.toplay(); while(n && !n->children.empty()){ Move m = return_move(n, turn); pv.push_back(m); n = find_child(n, m); - turn = 3 - turn; + turn = ~turn; } if(pv.size() == 0) @@ -216,8 +243,8 @@ vector AgentPNS::get_pv() const { return pv; } -string AgentPNS::move_stats(vector moves) const { - string s = ""; +std::string AgentPNS::move_stats(vecmove moves) const { + std::string s = ""; const Node * node = & root; if(moves.size()){ @@ -238,7 +265,7 @@ string AgentPNS::move_stats(vector moves) const { return s; } -Move AgentPNS::return_move(const Node * node, int toplay, int verbose) const { +Move AgentPNS::return_move(const Node * node, Side toplay, int verbose) const { double val, maxval = -1000000000000.0; //1 trillion Node * ret = NULL, @@ -246,11 +273,11 @@ Move AgentPNS::return_move(const Node * node, int toplay, int verbose) const { * end = node->children.end(); for( ; child != end; child++){ - int outcome = child->to_outcome(toplay); - if(outcome >= 0){ - if(outcome == toplay) val = 800000000000.0 - (double)child->work; //shortest win - else if(outcome == 0) val = -400000000000.0 + (double)child->work; //longest tie - else val = -800000000000.0 + (double)child->work; //longest loss + Outcome outcome = child->to_outcome(toplay); + if(outcome >= Outcome::DRAW){ + if( outcome == +toplay) val = 800000000000.0 - (double)child->work; //shortest win + else if(outcome == Outcome::DRAW) val = -400000000000.0 + (double)child->work; //longest tie + else val = -800000000000.0 + (double)child->work; //longest loss }else{ //not proven val = child->work; } @@ -290,3 +317,51 @@ void AgentPNS::garbage_collect(Node * node){ } } } + +void AgentPNS::create_children_simple(const Board & board, Node * node){ + assert(node->children.empty()); + node->children.alloc(board.movesremain(), ctmem); + unsigned int i = 0; + for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ + Outcome outcome = board.test_outcome(*move); + node->children[i] = Node(*move).outcome(outcome, board.toplay(), ties, 1); + i++; + } + PLUS(nodes, i); + node->children.shrink(i); //if symmetry, there may be extra moves to ignore +} + +void AgentPNS::gen_sgf(SGFPrinter & sgf, unsigned int limit, const Node & node, Side side) const { + for(auto & child : node.children){ + if(child.work >= limit && (side != node.to_outcome(~side) || child.to_outcome(side) == node.to_outcome(~side))){ + sgf.child_start(); + sgf.move(side, child.move); + sgf.comment(child.to_s()); + gen_sgf(sgf, limit, child, ~side); + sgf.child_end(); + } + } +} + +void AgentPNS::load_sgf(SGFParser & sgf, const Board & board, Node & node) { + assert(sgf.has_children()); + create_children_simple(board, &node); + + while(sgf.next_child()){ + Move m = sgf.move(); + Node & child = *find_child(&node, m); + child.from_s(sgf.comment()); + if(sgf.done_child()){ + continue; + }else{ + // has children! + Board b = board; + b.move(m); + load_sgf(sgf, b, child); + assert(sgf.done_child()); + } + } +} + +}; // namespace Hex +}; // namespace Morat diff --git a/hex/agentpns.h b/hex/agentpns.h index ad33042..89098ee 100644 --- a/hex/agentpns.h +++ b/hex/agentpns.h @@ -3,15 +3,21 @@ //A multi-threaded, tree based, proof number search solver. +#include + #include "../lib/agentpool.h" #include "../lib/compacttree.h" #include "../lib/depthstats.h" #include "../lib/log.h" +#include "../lib/string.h" #include "agent.h" #include "lbdist.h" +namespace Morat { +namespace Hex { + class AgentPNS : public Agent { static const uint32_t LOSS = (1<<30)-1; static const uint32_t DRAW = (1<<30)-2; @@ -51,33 +57,33 @@ class AgentPNS : public Agent { assert(children.empty()); } - Node & abval(int outcome, int toplay, int assign, int value = 1){ - if(assign && (outcome == 1 || outcome == -1)) - outcome = (toplay == assign ? 2 : -2); + Node & abval(int ab_outcome, Side toplay, Side assign, int value = 1){ + if(assign != Side::NONE && (ab_outcome == 1 || ab_outcome == -1)) + ab_outcome = (toplay == assign ? 2 : -2); - if( outcome == 0) { phi = value; delta = value; } - else if(outcome == 2) { phi = LOSS; delta = 0; } - else if(outcome == -2) { phi = 0; delta = LOSS; } - else /*(outcome 1||-1)*/ { phi = 0; delta = DRAW; } + if( ab_outcome == 0) { phi = value; delta = value; } + else if(ab_outcome == 2) { phi = LOSS; delta = 0; } + else if(ab_outcome == -2) { phi = 0; delta = LOSS; } + else /*(ab_outcome 1||-1)*/ { phi = 0; delta = DRAW; } return *this; } - Node & outcome(int outcome, int toplay, int assign, int value = 1){ - if(assign && outcome == 0) - outcome = assign; + Node & outcome(Outcome outcome, Side toplay, Side assign, int value = 1){ + if(assign != Side::NONE && outcome == Outcome::DRAW) + outcome = +assign; - if( outcome == -3) { phi = value; delta = value; } - else if(outcome == toplay) { phi = LOSS; delta = 0; } - else if(outcome == 3-toplay) { phi = 0; delta = LOSS; } - else /*(outcome == 0)*/ { phi = 0; delta = DRAW; } + if( outcome == Outcome::UNKNOWN) { phi = value; delta = value; } + else if(outcome == +toplay) { phi = LOSS; delta = 0; } + else if(outcome == +~toplay) { phi = 0; delta = LOSS; } + else /*(outcome == Outcome::DRAW)*/ { phi = 0; delta = DRAW; } return *this; } - int to_outcome(int toplay) const { - if(phi == LOSS) return toplay; - if(delta == LOSS) return 3 - toplay; - if(delta == DRAW) return 0; - return -3; + Outcome to_outcome(Side toplay) const { + if(phi == LOSS) return +toplay; + if(delta == LOSS) return +~toplay; + if(delta == DRAW) return Outcome::DRAW; + return Outcome::UNKNOWN; } bool terminal(){ return (phi == 0 || delta == 0); } @@ -98,15 +104,8 @@ class AgentPNS : public Agent { return num; } - string to_s() const { - return "Node: move " + move.to_s() + - ", phi " + to_str(phi) + - ", delta " + to_str(delta) + - ", work " + to_str(work) + -// ", outcome " + to_str((int)outcome) + "/" + to_str((int)proofdepth) + -// ", best " + bestmove.to_s() + - ", children " + to_str(children.num()); - } + std::string to_s() const ; + bool from_s(std::string s); void swap_tree(Node & n){ children.swap(n.children); @@ -162,7 +161,7 @@ class AgentPNS : public Agent { int ab; // how deep of an alpha-beta search to run at each leaf node bool df; // go depth first? float epsilon; //if depth first, how wide should the threshold be? - int ties; //which player to assign ties to: 0 handle ties, 1 assign p1, 2 assign p2 + Side ties; //which player to assign ties to: 0 handle ties, 1 assign p1, 2 assign p2 bool lbdist; int numthreads; @@ -172,7 +171,7 @@ class AgentPNS : public Agent { ab = 2; df = true; epsilon = 0.25; - ties = 0; + ties = Side::NONE; lbdist = false; numthreads = 1; pool.set_num_threads(numthreads); @@ -228,7 +227,7 @@ class AgentPNS : public Agent { root.swap_tree(child); if(nodesbefore > 0) - logerr(string("PNS Nodes before: ") + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); + logerr(std::string("PNS Nodes before: ") + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); assert(nodes == root.size()); @@ -280,12 +279,36 @@ class AgentPNS : public Agent { void search(double time, uint64_t maxiters, int verbose); Move return_move(int verbose) const { return return_move(& root, rootboard.toplay(), verbose); } double gamelen() const; - vector get_pv() const; - string move_stats(const vector moves) const; + vecmove get_pv() const; + std::string move_stats(const vecmove moves) const; + + void gen_sgf(SGFPrinter & sgf, int limit) const { + if(limit < 0){ + limit = 0; + //TODO: Set the root.work properly + for(auto & child : root.children) + limit += child.work; + limit /= 1000; + } + gen_sgf(sgf, limit, root, rootboard.toplay()); + } + + void load_sgf(SGFParser & sgf) { + load_sgf(sgf, rootboard, root); + } + + static void test(); private: //remove all the nodes with little work to free up some memory void garbage_collect(Node * node); - Move return_move(const Node * node, int toplay, int verbose = 0) const; + Move return_move(const Node * node, Side toplay, int verbose = 0) const; Node * find_child(const Node * node, const Move & move) const ; + void create_children_simple(const Board & board, Node * node); + + void gen_sgf(SGFPrinter & sgf, unsigned int limit, const Node & node, Side side) const; + void load_sgf(SGFParser & sgf, const Board & board, Node & node); }; + +}; // namespace Hex +}; // namespace Morat diff --git a/hex/agentpns_test.cpp b/hex/agentpns_test.cpp new file mode 100644 index 0000000..ab95efe --- /dev/null +++ b/hex/agentpns_test.cpp @@ -0,0 +1,16 @@ + +#include "../lib/catch.hpp" + +#include "agentpns.h" + + +using namespace Morat; +using namespace Hex; + +TEST_CASE("Hex::AgentPNS::Node::to_s/from_s", "[hex][agentpns]") { + AgentPNS::Node n(Move("a1")); + auto s = n.to_s(); + AgentPNS::Node k; + REQUIRE(k.from_s(s)); + REQUIRE(n.to_s() == k.to_s()); +} diff --git a/hex/board.cpp b/hex/board.cpp new file mode 100644 index 0000000..56565cd --- /dev/null +++ b/hex/board.cpp @@ -0,0 +1,73 @@ + +#include "board.h" + +namespace Morat { +namespace Hex { + +std::string Board::Cell::to_s(int i) const { + return "Cell " + to_str(i) +": " + "piece: " + to_str(piece.to_i())+ + ", size: " + to_str((int)size) + + ", parent: " + to_str((int)parent) + + ", edge: " + to_str((int)edge) + "/" + to_str(numedges()) + + ", perm: " + to_str((int)perm) + + ", pattern: " + to_str((int)pattern); +} + +std::string Board::to_s(bool color) const { + using std::string; + string white = "O", + black = "@", + empty = ".", + coord = "", + reset = ""; + if(color){ + string esc = "\033"; + reset = esc + "[0m"; + coord = esc + "[1;37m"; + empty = reset + "."; + white = esc + "[1;33m" + "@"; //yellow + black = esc + "[1;34m" + "@"; //blue + } + + string s; + for(int i = 0; i < size; i++) + s += " " + coord + to_str(i+1); + s += "\n"; + + for(int y = 0; y < size; y++){ + s += string(y, ' '); + s += coord + char('A' + y); + int end = lineend(y); + for(int x = 0; x < end; x++){ + s += (last == Move(x, y) ? coord + "[" : + last == Move(x-1, y) ? coord + "]" : " "); + Side p = get(x, y); + if( p == Side::NONE) s += empty; + else if(p == Side::P1) s += white; + else if(p == Side::P2) s += black; + else s += "?"; + } + s += (last == Move(end-1, y) ? coord + "]" : " "); + s += white + reset; + s += '\n'; + } + s += string(size + 2, ' '); + for(int i = 0; i < size; i++) + s += black + " "; + s += "\n"; + + s += reset; + return s; +} + +int Board::edges(int x, int y) const { + return (x == 0 ? 1 : 0) | + (x == sizem1 ? 2 : 0) | + (y == 0 ? 4 : 0) | + (y == sizem1 ? 8 : 0); +} + + +}; // namespace Hex +}; // namespace Morat diff --git a/hex/board.h b/hex/board.h index f9e9487..ab469c9 100644 --- a/hex/board.h +++ b/hex/board.h @@ -4,17 +4,21 @@ #include #include #include +#include #include #include +#include "../lib/bitcount.h" #include "../lib/hashset.h" +#include "../lib/move.h" +#include "../lib/outcome.h" #include "../lib/string.h" #include "../lib/types.h" #include "../lib/zobrist.h" -#include "move.h" -using namespace std; +namespace Morat { +namespace Hex { /* * the board is represented as a flattened 2d array of the form: @@ -48,35 +52,31 @@ static MoveValid * staticneighbourlist[17] = { class Board{ public: + static constexpr const char * const name = "hex"; static const int default_size = 8; static const int min_size = 3; static const int max_size = 16; static const int max_vecsize = max_size * max_size; + static const int num_win_types = 1; static const int pattern_cells = 18; typedef uint64_t Pattern; struct Cell { - uint16_t piece; //who controls this cell, 0 for none, 1,2 for players + Side piece; //who controls this cell, 0 for none, 1,2 for players uint16_t size; //size of this group of cells -mutable uint16_t parent; //parent for this group of cells. 8 bits limits board size to 16 until it's no longer stored as a square +mutable uint16_t parent; //parent for this group of cells uint8_t edge; //which edges are this group connected to uint8_t perm; //is this a permanent piece or a randomly placed piece? Pattern pattern; //the pattern of pieces for neighbours, but from their perspective. Rotate 180 for my perpective - Cell() : piece(73), size(0), parent(0), edge(0), perm(0), pattern(0) { } - Cell(unsigned int p, unsigned int a, unsigned int s, unsigned int e, Pattern t) : + Cell() : piece(Side::NONE), size(0), parent(0), edge(0), perm(0), pattern(0) { } + Cell(Side p, unsigned int a, unsigned int s, unsigned int e, Pattern t) : piece(p), size(s), parent(a), edge(e), perm(0), pattern(t) { } - string to_s(int i) const { - return "Cell " + to_str((int)i) +": " - "piece: " + to_str((int)piece)+ - ", size: " + to_str((int)size) + - ", parent: " + to_str((int)parent) + - ", edge: " + to_str((int)edge) + - ", perm: " + to_str((int)perm) + - ", pattern: " + to_str((int)pattern); - } + int numedges() const { return BitsSetTable256[edge]; } + + std::string to_s(int i) const; }; class MoveIterator { //only returns valid moves... @@ -87,7 +87,7 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board HashSet hashes; public: MoveIterator(const Board & b, bool Unique) : board(b), lineend(0), move(Move(M_SWAP), -1), unique(Unique) { - if(board.outcome >= 0){ + if(board.outcome >= Outcome::DRAW){ move = MoveValid(0, board.size, -1); //already done } else { if(unique) @@ -113,9 +113,8 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board move.xy = -1; return *this; } - - move.x = 0; - move.xy = move.y * board.size; + move.x = board.linestart(move.y); + move.xy = board.xy(move.x, move.y); lineend = board.lineend(move.y); } }while(!board.valid_move_fast(move)); @@ -139,10 +138,10 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board short nummoves; short unique_depth; //update and test rotations/symmetry with less than this many pieces on the board Move last; - char toPlay; - char outcome; //-3 = unknown, 0 = tie, 1,2 = player win + Side toPlay; + Outcome outcome; - vector cells; + std::vector cells; Zobrist<6> hash; const MoveValid * neighbourlist; @@ -157,15 +156,15 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board last = M_NONE; nummoves = 0; unique_depth = 5; - toPlay = 1; - outcome = -3; + toPlay = Side::P1; + outcome = Outcome::UNKNOWN; neighbourlist = get_neighbour_list(); num_cells = vecsize(); cells.resize(vecsize()); for(int y = 0; y < size; y++){ - for(int x = 0; x < lineend(y); x++){ + for(int x = 0; x < size; x++){ int posxy = xy(x, y); Pattern p = 0, j = 3; for(const MoveValid * i = nb_begin(posxy), *e = nb_end_big_hood(i); i < e; i++){ @@ -173,7 +172,8 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board p |= j; j <<= 2; } - cells[posxy] = Cell(0, posxy, 1, edges(x, y), pattern_reverse(p)); + Side s = (onboard(x, y) ? Side::NONE : Side::UNDEF); + cells[posxy] = Cell(s, posxy, 1, edges(x, y), pattern_reverse(p)); } } } @@ -190,7 +190,7 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board int numcells() const { return num_cells; } int num_moves() const { return nummoves; } - int movesremain() const { return (won() >= 0 ? 0 : num_cells - nummoves); } + int movesremain() const { return (won() >= Outcome::DRAW ? 0 : num_cells - nummoves); } int xy(int x, int y) const { return y*size + x; } int xy(const Move & m) const { return m.y*size + m.x; } @@ -198,6 +198,10 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board MoveValid yx(int i) const { return MoveValid(i % size, i / size, i); } + int dist(const Move & a, const Move & b) const { + return (abs(a.x - b.x) + abs(a.y - b.y) + abs((a.x + a.y) - (b.x + b.y)) )/2; + } + const Cell * cell(int i) const { return & cells[i]; } const Cell * cell(int x, int y) const { return cell(xy(x,y)); } const Cell * cell(const Move & m) const { return cell(xy(m)); } @@ -205,40 +209,41 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board //assumes valid x,y - int get(int i) const { return cells[i].piece; } - int get(int x, int y) const { return get(xy(x, y)); } - int get(const Move & m) const { return get(xy(m)); } - int get(const MoveValid & m) const { return get(m.xy); } + Side get(int i) const { return cells[i].piece; } + Side get(int x, int y) const { return get(xy(x, y)); } + Side get(const Move & m) const { return get(xy(m)); } + Side get(const MoveValid & m) const { return get(m.xy); } - int geton(const MoveValid & m) const { return (m.onboard() ? get(m.xy) : 0); } + Side geton(const MoveValid & m) const { return (m.onboard() ? get(m.xy) : Side::UNDEF); } - int local(const Move & m, char turn) const { return local(xy(m), turn); } - int local(int i, char turn) const { + int local(const Move & m, Side turn) const { return local(xy(m), turn); } + int local(int i, Side turn) const { Pattern p = pattern(i); Pattern x = ((p & 0xAAAAAAAAAull) >> 1) ^ (p & 0x555555555ull); // p1 is now when p1 or p2 but not both (ie off the board) - p = x & (turn == 1 ? p : p >> 1); // now just the selected player + p = x & (turn == Side::P1 ? p : p >> 1); // now just the selected player return (p & 0x000000FFF ? 3 : 0) | (p & 0x000FFF000 ? 2 : 0) | (p & 0xFFF000000 ? 1 : 0); } - //assumes x, y are in array bounds - bool onboard_fast(int x, int y) const { return ( y < size && x < size); } - bool onboard_fast(const Move & m) const { return (m.y < size && m.x < size); } + //assumes x, y are in array bounds, and all moves within array bounds are valid + bool onboard_fast(int x, int y) const { return true; } + bool onboard_fast(const Move & m) const { return true; } //checks array bounds too - bool onboard(int x, int y) const { return ( x >= 0 && y >= 0 && onboard_fast(x, y) ); } - bool onboard(const Move & m)const { return (m.x >= 0 && m.y >= 0 && onboard_fast(m) ); } + bool onboard(int x, int y) const { return ( x >= 0 && y >= 0 && x < size && y < size && onboard_fast(x, y) ); } + bool onboard(const Move & m)const { return (m.x >= 0 && m.y >= 0 && m.x < size && m.y < size && onboard_fast(m) ); } bool onboard(const MoveValid & m) const { return m.onboard(); } //assumes x, y are in bounds and the game isn't already finished - bool valid_move_fast(int x, int y) const { return !get(x,y); } - bool valid_move_fast(const Move & m) const { return !get(m); } - bool valid_move_fast(const MoveValid & m) const { return !get(m.xy); } + bool valid_move_fast(int i) const { return get(i) == Side::NONE; } + bool valid_move_fast(int x, int y) const { return valid_move_fast(xy(x, y)); } + bool valid_move_fast(const Move & m) const { return valid_move_fast(xy(m)); } + bool valid_move_fast(const MoveValid & m) const { return valid_move_fast(m.xy); } //checks array bounds too - bool valid_move(int x, int y) const { return (outcome == -3 && onboard(x, y) && !get(x, y)); } - bool valid_move(const Move & m) const { return (outcome == -3 && onboard(m) && !get(m)); } - bool valid_move(const MoveValid & m) const { return (outcome == -3 && m.onboard() && !get(m)); } + bool valid_move(int x, int y) const { return (outcome < Outcome::DRAW && onboard(x, y) && valid_move_fast(x, y)); } + bool valid_move(const Move & m) const { return (outcome < Outcome::DRAW && onboard(m) && valid_move_fast(m)); } + bool valid_move(const MoveValid & m) const { return (outcome < Outcome::DRAW && m.onboard() && valid_move_fast(m)); } //iterator through neighbours of a position const MoveValid * nb_begin(int x, int y) const { return nb_begin(xy(x, y)); } @@ -252,12 +257,7 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board const MoveValid * nb_end_small_hood(const MoveValid * m) const { return m + 12; } const MoveValid * nb_end_big_hood(const MoveValid * m) const { return m + 18; } - int edges(int x, int y) const { - return (x == 0 ? 1 : 0) | - (x == sizem1 ? 2 : 0) | - (y == 0 ? 4 : 0) | - (y == sizem1 ? 8 : 0); - } + int edges(int x, int y) const; MoveValid * get_neighbour_list(){ if(!staticneighbourlist[(int)size]){ @@ -281,94 +281,24 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board return staticneighbourlist[(int)size]; } - + int linestart(int y) const { return 0; } int lineend(int y) const { return size; } + int linelen(int y) const { return lineend(y) - linestart(y); } - string to_s(bool color) const { - string white = "O", - black = "@", - empty = ".", - coord = "", - reset = ""; - if(color){ - string esc = "\033"; - reset = esc + "[0m"; - coord = esc + "[1;37m"; - empty = reset + "."; - white = esc + "[1;33m" + "@"; //yellow - black = esc + "[1;34m" + "@"; //blue - } - - string s; - for(int i = 0; i < size; i++) - s += " " + coord + to_str(i+1); - s += "\n"; - - for(int y = 0; y < size; y++){ - s += string(y, ' '); - s += coord + char('A' + y); - int end = lineend(y); - for(int x = 0; x < size; x++){ - s += (last == Move(x, y) ? coord + "[" : - last == Move(x-1, y) ? coord + "]" : " "); - int p = get(x, y); - if(p == 0) s += empty; - if(p == 1) s += white; - if(p == 2) s += black; - if(p >= 3) s += "?"; - } - s += (last == Move(end-1, y) ? coord + "]" : " "); - s += white + reset; - s += '\n'; - } - s += string(size + 2, ' '); - for(int i = 0; i < size; i++) - s += black + " "; - s += "\n"; - - s += reset; - return s; - } + std::string to_s(bool color) const; + friend std::ostream& operator<< (std::ostream &out, const Board & b) { return out << b.to_s(true); } void print(bool color = true) const { printf("%s", to_s(color).c_str()); } - string boardstr() const { - string white, black; - for(int y = 0; y < size; y++){ - for(int x = 0; x < lineend(y); x++){ - int p = get(x, y); - if(p == 1) white += Move(x, y).to_s(); - if(p == 2) black += Move(x, y).to_s(); - } - } - return white + ";" + black; - } - - string won_str() const { - switch(outcome){ - case -3: return "none"; - case -2: return "black_or_draw"; - case -1: return "white_or_draw"; - case 0: return "draw"; - case 1: return "white"; - case 2: return "black"; - } - return "unknown"; - } - - char won() const { + Outcome won() const { return outcome; } - int win() const{ // 0 for draw or unknown, 1 for win, -1 for loss - if(outcome <= 0) - return 0; - return (outcome == toplay() ? 1 : -1); - } + char getwintype() const { return outcome > Outcome::DRAW; } - char toplay() const { + Side toplay() const { return toPlay; } @@ -376,22 +306,22 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board return MoveIterator(*this, (unique ? nummoves <= unique_depth : false)); } - void set(const Move & m, bool perm = true){ + void set(const Move & m, bool perm = true) { last = m; Cell * cell = & cells[xy(m)]; cell->piece = toPlay; cell->perm = perm; nummoves++; update_hash(m, toPlay); //depends on nummoves - toPlay = 3 - toPlay; + toPlay = ~toPlay; } - void unset(const Move & m){ //break win checks, but is a poor mans undo if all you care about is the hash - toPlay = 3 - toPlay; + void unset(const Move & m) { //break win checks, but is a poor mans undo if all you care about is the hash + toPlay = ~toPlay; update_hash(m, toPlay); nummoves--; Cell * cell = & cells[xy(m)]; - cell->piece = 0; + cell->piece = Side::NONE; cell->perm = 0; } @@ -421,7 +351,7 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board return true; if(cells[i].size < cells[j].size) //force i's subtree to be bigger - swap(i, j); + std::swap(i, j); cells[j].parent = i; cells[i].size += cells[j].size; @@ -431,7 +361,7 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board } Cell test_cell(const Move & pos) const { - char turn = toplay(); + Side turn = toplay(); int posxy = xy(pos); Cell testcell = cells[find_group(pos)]; @@ -463,7 +393,7 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board return (nummoves > unique_depth ? hash.get(0) : hash.get()); } - string hashstr() const { + std::string hashstr() const { static const char hexlookup[] = "0123456789abcdef"; char buf[19] = "0x"; hash_t val = gethash(); @@ -475,7 +405,8 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board return (char *)buf; } - void update_hash(const Move & pos, int turn){ + void update_hash(const Move & pos, Side side) { + int turn = side.to_i(); if(nummoves > unique_depth){ //simple update, no rotations/symmetry hash.update(0, 3*xy(pos) + turn); return; @@ -498,7 +429,8 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board return test_hash(pos, toplay()); } - hash_t test_hash(const Move & pos, int turn) const { + hash_t test_hash(const Move & pos, Side side) const { + int turn = side.to_i(); if(nummoves >= unique_depth) //simple test, no rotations/symmetry return hash.test(0, 3*xy(pos) + turn); @@ -507,11 +439,11 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board z = sizem1 - x - y; hash_t m = hash.test(0, 3*xy(x, y) + turn); - m = min(m, hash.test(1, 3*xy(z, y) + turn)); - m = min(m, hash.test(2, 3*xy(z, x) + turn)); - m = min(m, hash.test(3, 3*xy(x, z) + turn)); - m = min(m, hash.test(4, 3*xy(y, z) + turn)); - m = min(m, hash.test(5, 3*xy(y, x) + turn)); + m = std::min(m, hash.test(1, 3*xy(z, y) + turn)); + m = std::min(m, hash.test(2, 3*xy(z, x) + turn)); + m = std::min(m, hash.test(3, 3*xy(x, z) + turn)); + m = std::min(m, hash.test(4, 3*xy(y, z) + turn)); + m = std::min(m, hash.test(5, 3*xy(y, x) + turn)); return m; } @@ -543,13 +475,13 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board return (((p & 0x03F03F03Full) << 6) | ((p & 0xFC0FC0FC0ull) >> 6)); } - static Pattern pattern_invert(Pattern p){ //switch players + static Pattern pattern_invert(Pattern p) { //switch players return ((p & 0xAAAAAAAAAull) >> 1) | ((p & 0x555555555ull) << 1); } - static Pattern pattern_rotate(Pattern p){ + static Pattern pattern_rotate(Pattern p) { return (((p & 0x003003003ull) << 10) | ((p & 0xFFCFFCFFCull) >> 2)); } - static Pattern pattern_mirror(Pattern p){ + static Pattern pattern_mirror(Pattern p) { // HGFEDC BA9876 543210 -> DEFGHC 6789AB 123450 return ((p & (3ull << 6)) ) | ((p & (3ull << 0)) ) | // 0,3 stay in place ((p & (3ull << 10)) >> 8) | ((p & (3ull << 2)) << 8) | // 1,5 swap @@ -561,36 +493,36 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board ((p & (3ull << 34)) >> 8) | ((p & (3ull << 26)) << 8) | // H,D swap ((p & (3ull << 32)) >> 4) | ((p & (3ull << 28)) << 4); // G,E swap } - static Pattern pattern_symmetry(Pattern p){ //takes a pattern and returns the representative version + static Pattern pattern_symmetry(Pattern p) { //takes a pattern and returns the representative version Pattern m = p; //012345 - m = min(m, (p = pattern_rotate(p)));//501234 - m = min(m, (p = pattern_rotate(p)));//450123 - m = min(m, (p = pattern_rotate(p)));//345012 - m = min(m, (p = pattern_rotate(p)));//234501 - m = min(m, (p = pattern_rotate(p)));//123450 - m = min(m, (p = pattern_mirror(pattern_rotate(p))));//012345 -> 054321 - m = min(m, (p = pattern_rotate(p)));//105432 - m = min(m, (p = pattern_rotate(p)));//210543 - m = min(m, (p = pattern_rotate(p)));//321054 - m = min(m, (p = pattern_rotate(p)));//432105 - m = min(m, (p = pattern_rotate(p)));//543210 + m = std::min(m, (p = pattern_rotate(p)));//501234 + m = std::min(m, (p = pattern_rotate(p)));//450123 + m = std::min(m, (p = pattern_rotate(p)));//345012 + m = std::min(m, (p = pattern_rotate(p)));//234501 + m = std::min(m, (p = pattern_rotate(p)));//123450 + m = std::min(m, (p = pattern_mirror(pattern_rotate(p))));//012345 -> 054321 + m = std::min(m, (p = pattern_rotate(p)));//105432 + m = std::min(m, (p = pattern_rotate(p)));//210543 + m = std::min(m, (p = pattern_rotate(p)));//321054 + m = std::min(m, (p = pattern_rotate(p)));//432105 + m = std::min(m, (p = pattern_rotate(p)));//543210 return m; } - bool move(const Move & pos, bool checkwin = true, bool permanent = true){ + bool move(const Move & pos, bool checkwin = true, bool permanent = true) { return move(MoveValid(pos, xy(pos)), checkwin, permanent); } - bool move(const MoveValid & pos, bool checkwin = true, bool permanent = true){ - assert(outcome < 0); + bool move(const MoveValid & pos, bool checkwin = true, bool permanent = true) { + assert(outcome < Outcome::DRAW); if(!valid_move(pos)) return false; - char turn = toplay(); + Side turn = toplay(); set(pos, permanent); // update the nearby patterns - Pattern p = turn; + Pattern p = turn.to_i(); for(const MoveValid * i = nb_begin(pos.xy), *e = nb_end_big_hood(i); i < e; i++){ if(i->onboard()){ cells[i->xy].pattern |= p; @@ -609,27 +541,27 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board // did I win? Cell * g = & cells[find_group(pos.xy)]; - uint8_t winmask = (turn == 1 ? 3 : 0xC); + uint8_t winmask = (turn == Side::P1 ? 3 : 0xC); if((g->edge & winmask) == winmask){ outcome = turn; } return true; } - bool test_local(const Move & pos, char turn) const { + bool test_local(const Move & pos, Side turn) const { return test_local(MoveValid(pos, xy(pos)), turn); } + bool test_local(const MoveValid & pos, Side turn) const { return (local(pos, turn) == 3); } //test if making this move would win, but don't actually make the move - int test_win(const Move & pos, char turn = 0) const { - if(turn == 0) - turn = toplay(); - + Outcome test_outcome(const Move & pos) const { return test_outcome(pos, toplay()); } + Outcome test_outcome(const Move & pos, Side turn) const { return test_outcome(MoveValid(pos, xy(pos)), turn); } + Outcome test_outcome(const MoveValid & pos) const { return test_outcome(pos, toplay()); } + Outcome test_outcome(const MoveValid & pos, Side turn) const { if(test_local(pos, turn)){ - int posxy = xy(pos); - Cell testcell = cells[find_group(posxy)]; + Cell testcell = cells[find_group(pos.xy)]; int numgroups = 0; - for(const MoveValid * i = nb_begin(posxy), *e = nb_end(i); i < e; i++){ + for(const MoveValid * i = nb_begin(pos), *e = nb_end(i); i < e; i++){ if(i->onboard() && turn == get(i->xy)){ const Cell * g = & cells[find_group(i->xy)]; testcell.edge |= g->edge; @@ -639,11 +571,14 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board } } - int winmask = (turn == 1 ? 3 : 0xC); + int winmask = (turn == Side::P1 ? 3 : 0xC); if((testcell.edge & winmask) == winmask) return turn; } - return -3; + return Outcome::UNKNOWN; } }; + +}; // namespace Hex +}; // namespace Morat diff --git a/hex/board_test.cpp b/hex/board_test.cpp new file mode 100644 index 0000000..7c0f24d --- /dev/null +++ b/hex/board_test.cpp @@ -0,0 +1,123 @@ + +#include "../lib/catch.hpp" +#include "../lib/string.h" + +#include "board.h" + + +using namespace Morat; +using namespace Hex; + +void test_game(Board b, std::vector moves, Outcome outcome) { + REQUIRE(b.num_moves() == 0); + Side side = Side::P1; + for(auto s : moves) { + Outcome expected = (s == moves.back() ? outcome : Outcome::UNKNOWN); + Move move(s); + CAPTURE(move); + CAPTURE(b); + REQUIRE(b.valid_move(move)); + REQUIRE(b.toplay() == side); + REQUIRE(b.test_outcome(move) == expected); + REQUIRE(b.move(move)); + REQUIRE(b.won() == expected); + side = ~side; + } +} +void test_game(Board b, std::string moves, Outcome outcome) { + test_game(b, explode(moves, " "), outcome); +} + +TEST_CASE("Hex::Board", "[hex][board]") { + Board b(7); + + SECTION("Basics") { + REQUIRE(b.get_size() == 7); + REQUIRE(b.movesremain() == 49); + } + + SECTION("valid moves") { + std::string valid[] = {"A1", "D4", + "a1", "a2", "a3", "a4", "a5", "a6", "a7", + "b1", "b2", "b3", "b4", "b5", "b6", "b7", + "c1", "c2", "c3", "c4", "c5", "c6", "c7", + "d1", "d2", "d3", "d4", "d5", "d6", "d7", + "e1", "e2", "e3", "e4", "e5", "e6", "e7", + "f1", "f2", "f3", "f4", "f5", "f6", "f7", + "g1", "g2", "g3", "g4", "g5", "g6", "g7", + }; + for(auto m : valid){ + REQUIRE(b.onboard(m)); + REQUIRE(b.valid_move(m)); + } + } + + SECTION("invalid moves") { + std::string invalid[] = {"a0", "a8", "a10", "b8", "c8", "e0", "e8", "f8", "f0", "h1", "f0"}; + for(auto m : invalid){ + REQUIRE_FALSE(b.onboard(m)); + REQUIRE_FALSE(b.valid_move(m)); + } + } + + SECTION("duplicate moves") { + Move m("a1"); + REQUIRE(b.valid_move(m)); + REQUIRE(b.move(m)); + REQUIRE_FALSE(b.valid_move(m)); + REQUIRE_FALSE(b.move(m)); + } + + SECTION("move distance") { + SECTION("x") { + REQUIRE(b.dist(Move("b2"), Move("b1")) == 1); + REQUIRE(b.dist(Move("b2"), Move("b3")) == 1); + } + SECTION("y") { + REQUIRE(b.dist(Move("b2"), Move("a2")) == 1); + REQUIRE(b.dist(Move("b2"), Move("c2")) == 1); + } + SECTION("z") { + REQUIRE(b.dist(Move("b2"), Move("a3")) == 1); + REQUIRE(b.dist(Move("b2"), Move("c1")) == 1); + } + SECTION("farther") { + REQUIRE(b.dist(Move("b2"), Move("a1")) == 2); + REQUIRE(b.dist(Move("b2"), Move("c3")) == 2); + REQUIRE(b.dist(Move("b2"), Move("d4")) == 4); + REQUIRE(b.dist(Move("b2"), Move("d3")) == 3); + REQUIRE(b.dist(Move("b2"), Move("d1")) == 2); + REQUIRE(b.dist(Move("b2"), Move("e3")) == 4); + } + } + + SECTION("Unknown_1") { + test_game(b, { "a1", "b1", "a2", "b2", "a3", "b3", "a4"}, Outcome::UNKNOWN); + test_game(b, {"d4", "a1", "b1", "a2", "b2", "a3", "b3", "a4"}, Outcome::UNKNOWN); + } + + SECTION("Unknown_2") { + test_game(b, { "b1", "c1", "b2", "c2", "b3", "c3", "b4", "c4", "b5", "c5", "a2"}, Outcome::UNKNOWN); + test_game(b, {"d4", "b1", "c1", "b2", "c2", "b3", "c3", "b4", "c4", "b5", "c5", "a2"}, Outcome::UNKNOWN); + } + + SECTION("Unknown_3") { + test_game(b, { "b2", "f3", "b3", "f4", "c2", "f5", "c4", "f6", "d3", "f7", "d4"}, Outcome::UNKNOWN); + test_game(b, {"d7", "b2", "f3", "b3", "f4", "c2", "f5", "c4", "f6", "d3", "f7", "d4"}, Outcome::UNKNOWN); + + test_game(b, { "b2", "f3", "b3", "f4", "c2", "f5", "c4", "f6", "d3", "f7", "c3", "e6", "d4"}, Outcome::UNKNOWN); + test_game(b, {"d7", "b2", "f3", "b3", "f4", "c2", "f5", "c4", "f6", "d3", "f7", "c3", "e6", "d4"}, Outcome::UNKNOWN); + } + + SECTION("White Connects") { + test_game(b, + "c2 c5 e4 d6 c6 d5 f5 e3 d4 d3 b4 d1 a6 c1 d2 b3 c3 e1 f2 a7 b6 b7 c7 a5 f3 e7 g6 g7 f7 g4 f1 g2 b5 e2 c4", + Outcome::P1); + } + + SECTION("Black Connects") { + test_game(b, + "d1 d4 e4 e3 g2 f3 g3 f5 g5 g4 a6 b5 f4 d6 e5 g1 a5 b3 c5 c3 a3 b6 g7 a7 f7 f2 c4 d3 a4 b4", + Outcome::P2); + } +} diff --git a/hex/gtp.h b/hex/gtp.h index f53c9e9..cfb2ee9 100644 --- a/hex/gtp.h +++ b/hex/gtp.h @@ -2,6 +2,8 @@ #pragma once #include "../lib/gtpcommon.h" +#include "../lib/history.h" +#include "../lib/move.h" #include "../lib/string.h" #include "agent.h" @@ -9,11 +11,13 @@ #include "agentmcts.h" #include "agentpns.h" #include "board.h" -#include "history.h" -#include "move.h" + + +namespace Morat { +namespace Hex { class GTP : public GTPCommon { - History hist; + History hist; public: int verbose; @@ -35,46 +39,46 @@ class GTP : public GTPCommon { set_board(); - newcallback("name", bind(>P::gtp_name, this, _1), "Name of the program"); - newcallback("version", bind(>P::gtp_version, this, _1), "Version of the program"); - newcallback("verbose", bind(>P::gtp_verbose, this, _1), "Set verbosity, 0 for quiet, 1 for normal, 2+ for more output"); - newcallback("extended", bind(>P::gtp_extended, this, _1), "Output extra stats from genmove in the response"); - newcallback("debug", bind(>P::gtp_debug, this, _1), "Enable debug mode"); - newcallback("colorboard", bind(>P::gtp_colorboard, this, _1), "Turn on or off the colored board"); - newcallback("showboard", bind(>P::gtp_print, this, _1), "Show the board"); - newcallback("print", bind(>P::gtp_print, this, _1), "Alias for showboard"); - newcallback("dists", bind(>P::gtp_dists, this, _1), "Similar to print, but shows minimum win distances"); -// newcallback("zobrist", bind(>P::gtp_zobrist, this, _1), "Output the zobrist hash for the current move"); - newcallback("clear_board", bind(>P::gtp_clearboard, this, _1), "Clear the board, but keep the size"); - newcallback("clear", bind(>P::gtp_clearboard, this, _1), "Alias for clear_board"); - newcallback("boardsize", bind(>P::gtp_boardsize, this, _1), "Clear the board, set the board size"); - newcallback("size", bind(>P::gtp_boardsize, this, _1), "Alias for board_size"); - newcallback("play", bind(>P::gtp_play, this, _1), "Place a stone: play "); - newcallback("white", bind(>P::gtp_playwhite, this, _1), "Place a white stone: white "); - newcallback("black", bind(>P::gtp_playblack, this, _1), "Place a black stone: black "); - newcallback("undo", bind(>P::gtp_undo, this, _1), "Undo one or more moves: undo [amount to undo]"); - newcallback("time", bind(>P::gtp_time, this, _1), "Set the time limits and the algorithm for per game time"); - newcallback("genmove", bind(>P::gtp_genmove, this, _1), "Generate a move: genmove [color] [time]"); - newcallback("solve", bind(>P::gtp_solve, this, _1), "Try to solve this position"); - -// newcallback("ab", bind(>P::gtp_ab, this, _1), "Switch to use the Alpha/Beta agent to play/solve"); - newcallback("mcts", bind(>P::gtp_mcts, this, _1), "Switch to use the Monte Carlo Tree Search agent to play/solve"); - newcallback("pns", bind(>P::gtp_pns, this, _1), "Switch to use the Proof Number Search agent to play/solve"); - - newcallback("all_legal", bind(>P::gtp_all_legal, this, _1), "List all legal moves"); - newcallback("history", bind(>P::gtp_history, this, _1), "List of played moves"); - newcallback("playgame", bind(>P::gtp_playgame, this, _1), "Play a list of moves"); - newcallback("winner", bind(>P::gtp_winner, this, _1), "Check the winner of the game"); - newcallback("patterns", bind(>P::gtp_patterns, this, _1), "List all legal moves plus their local pattern"); - - newcallback("pv", bind(>P::gtp_pv, this, _1), "Output the principle variation for the player tree as it stands now"); - newcallback("move_stats", bind(>P::gtp_move_stats, this, _1), "Output the move stats for the player tree as it stands now"); - - newcallback("params", bind(>P::gtp_params, this, _1), "Set the options for the player, no args gives options"); - -// newcallback("player_hgf", bind(>P::gtp_player_hgf, this, _1), "Output an hgf of the current tree"); -// newcallback("player_load_hgf", bind(>P::gtp_player_load_hgf,this, _1), "Load an hgf generated by player_hgf"); -// newcallback("player_gammas", bind(>P::gtp_player_gammas, this, _1), "Load the gammas for weighted random from a file"); + newcallback("name", std::bind(>P::gtp_name, this, _1), "Name of the program"); + newcallback("version", std::bind(>P::gtp_version, this, _1), "Version of the program"); + newcallback("verbose", std::bind(>P::gtp_verbose, this, _1), "Set verbosity, 0 for quiet, 1 for normal, 2+ for more output"); + newcallback("extended", std::bind(>P::gtp_extended, this, _1), "Output extra stats from genmove in the response"); + newcallback("debug", std::bind(>P::gtp_debug, this, _1), "Enable debug mode"); + newcallback("colorboard", std::bind(>P::gtp_colorboard, this, _1), "Turn on or off the colored board"); + newcallback("showboard", std::bind(>P::gtp_print, this, _1), "Show the board"); + newcallback("print", std::bind(>P::gtp_print, this, _1), "Alias for showboard"); + newcallback("dists", std::bind(>P::gtp_dists, this, _1), "Similar to print, but shows minimum win distances"); + newcallback("zobrist", std::bind(>P::gtp_zobrist, this, _1), "Output the zobrist hash for the current move"); + newcallback("clear_board", std::bind(>P::gtp_clearboard, this, _1), "Clear the board, but keep the size"); + newcallback("clear", std::bind(>P::gtp_clearboard, this, _1), "Alias for clear_board"); + newcallback("boardsize", std::bind(>P::gtp_boardsize, this, _1), "Clear the board, set the board size"); + newcallback("size", std::bind(>P::gtp_boardsize, this, _1), "Alias for board_size"); + newcallback("play", std::bind(>P::gtp_play, this, _1), "Place a stone: play "); + newcallback("white", std::bind(>P::gtp_playwhite, this, _1), "Place a white stone: white "); + newcallback("black", std::bind(>P::gtp_playblack, this, _1), "Place a black stone: black "); + newcallback("undo", std::bind(>P::gtp_undo, this, _1), "Undo one or more moves: undo [amount to undo]"); + newcallback("time", std::bind(>P::gtp_time, this, _1), "Set the time limits and the algorithm for per game time"); + newcallback("genmove", std::bind(>P::gtp_genmove, this, _1), "Generate a move: genmove [color] [time]"); + newcallback("solve", std::bind(>P::gtp_solve, this, _1), "Try to solve this position"); + +// newcallback("ab", std::bind(>P::gtp_ab, this, _1), "Switch to use the Alpha/Beta agent to play/solve"); + newcallback("mcts", std::bind(>P::gtp_mcts, this, _1), "Switch to use the Monte Carlo Tree Search agent to play/solve"); + newcallback("pns", std::bind(>P::gtp_pns, this, _1), "Switch to use the Proof Number Search agent to play/solve"); + + newcallback("all_legal", std::bind(>P::gtp_all_legal, this, _1), "List all legal moves"); + newcallback("history", std::bind(>P::gtp_history, this, _1), "List of played moves"); + newcallback("playgame", std::bind(>P::gtp_playgame, this, _1), "Play a list of moves"); + newcallback("winner", std::bind(>P::gtp_winner, this, _1), "Check the winner of the game"); + newcallback("patterns", std::bind(>P::gtp_patterns, this, _1), "List all legal moves plus their local pattern"); + + newcallback("pv", std::bind(>P::gtp_pv, this, _1), "Output the principle variation for the player tree as it stands now"); + newcallback("move_stats", std::bind(>P::gtp_move_stats, this, _1), "Output the move stats for the player tree as it stands now"); + + newcallback("params", std::bind(>P::gtp_params, this, _1), "Set the options for the player, no args gives options"); + + newcallback("save_sgf", std::bind(>P::gtp_save_sgf, this, _1), "Output an sgf of the current tree"); + newcallback("load_sgf", std::bind(>P::gtp_load_sgf, this, _1), "Load an sgf generated by save_sgf"); +// newcallback("player_gammas", std::bind(>P::gtp_player_gammas, this, _1), "Load the gammas for weighted random from a file"); } void set_board(bool clear = true){ @@ -94,7 +98,7 @@ class GTP : public GTPCommon { GTPResponse gtp_all_legal(vecstr args); GTPResponse gtp_history(vecstr args); GTPResponse gtp_patterns(vecstr args); - GTPResponse play(const string & pos, int toplay); + GTPResponse play(const std::string & pos, Side toplay); GTPResponse gtp_playgame(vecstr args); GTPResponse gtp_play(vecstr args); GTPResponse gtp_playwhite(vecstr args); @@ -124,8 +128,11 @@ class GTP : public GTPCommon { GTPResponse gtp_pns_params(vecstr args); // GTPResponse gtp_player_gammas(vecstr args); -// GTPResponse gtp_player_hgf(vecstr args); -// GTPResponse gtp_player_load_hgf(vecstr args); + GTPResponse gtp_save_sgf(vecstr args); + GTPResponse gtp_load_sgf(vecstr args); - string solve_str(int outcome) const; + std::string solve_str(int outcome) const; }; + +}; // namespace Hex +}; // namespace Morat diff --git a/hex/gtpagent.cpp b/hex/gtpagent.cpp index d32178a..4a58c9d 100644 --- a/hex/gtpagent.cpp +++ b/hex/gtpagent.cpp @@ -1,13 +1,12 @@ -#include +#include "gtp.h" -#include "../lib/fileio.h" -#include "gtp.h" +namespace Morat { +namespace Hex { using namespace std; - GTPResponse GTP::gtp_move_stats(vecstr args){ vector moves; for(auto s : args) @@ -249,7 +248,7 @@ GTPResponse GTP::gtp_pns_params(vecstr args){ "Update the pns solver settings, eg: pns_params -m 100 -s 0 -d 1 -e 0.25 -a 2 -l 0\n" " -m --memory Memory limit in Mb [" + to_str(pns->memlimit/(1024*1024)) + "]\n" " -t --threads How many threads to run [" + to_str(pns->numthreads) + "]\n" - " -s --ties Which side to assign ties to, 0 = handle, 1 = p1, 2 = p2 [" + to_str(pns->ties) + "]\n" + " -s --ties Which side to assign ties to, 0 = handle, 1 = p1, 2 = p2 [" + to_str(pns->ties.to_i()) + "]\n" " -d --df Use depth-first thresholds [" + to_str(pns->df) + "]\n" " -e --epsilon How big should the threshold be [" + to_str(pns->epsilon) + "]\n" " -a --abdepth Run an alpha-beta search of this size at each leaf [" + to_str(pns->ab) + "]\n" @@ -267,7 +266,7 @@ GTPResponse GTP::gtp_pns_params(vecstr args){ if(mem < 1) return GTPResponse(false, "Memory can't be less than 1mb"); pns->set_memlimit(mem*1024*1024); }else if((arg == "-s" || arg == "--ties") && i+1 < args.size()){ - pns->ties = from_str(args[++i]); + pns->ties = Side(from_str(args[++i])); pns->clear_mem(); }else if((arg == "-d" || arg == "--df") && i+1 < args.size()){ pns->df = from_str(args[++i]); @@ -282,3 +281,6 @@ GTPResponse GTP::gtp_pns_params(vecstr args){ return GTPResponse(true, errs); } + +}; // namespace Hex +}; // namespace Morat diff --git a/hex/gtpgeneral.cpp b/hex/gtpgeneral.cpp index 60f0f73..6fa9d98 100644 --- a/hex/gtpgeneral.cpp +++ b/hex/gtpgeneral.cpp @@ -1,7 +1,15 @@ +#include + +#include "../lib/sgf.h" + #include "gtp.h" #include "lbdist.h" + +namespace Morat { +namespace Hex { + GTPResponse GTP::gtp_mcts(vecstr args){ delete agent; agent = new AgentMCTS(); @@ -39,7 +47,7 @@ GTPResponse GTP::gtp_boardsize(vecstr args){ if(size < Board::min_size || size > Board::max_size) return GTPResponse(false, "Size " + to_str(size) + " is out of range."); - hist = History(size); + hist = History(size); set_board(); time_control.new_game(); @@ -69,14 +77,14 @@ GTPResponse GTP::gtp_undo(vecstr args){ GTPResponse GTP::gtp_patterns(vecstr args){ bool symmetric = true; bool invert = true; - string ret; + std::string ret; const Board & board = *hist; for(Board::MoveIterator move = board.moveit(); !move.done(); ++move){ ret += move->to_s() + " "; unsigned int p = board.pattern(*move); if(symmetric) p = board.pattern_symmetry(p); - if(invert && board.toplay() == 2) + if(invert && board.toplay() == Side::P2) p = board.pattern_invert(p); ret += to_str(p); ret += "\n"; @@ -85,24 +93,24 @@ GTPResponse GTP::gtp_patterns(vecstr args){ } GTPResponse GTP::gtp_all_legal(vecstr args){ - string ret; + std::string ret; for(Board::MoveIterator move = hist->moveit(); !move.done(); ++move) ret += move->to_s() + " "; return GTPResponse(true, ret); } GTPResponse GTP::gtp_history(vecstr args){ - string ret; + std::string ret; for(auto m : hist) ret += m.to_s() + " "; return GTPResponse(true, ret); } -GTPResponse GTP::play(const string & pos, int toplay){ +GTPResponse GTP::play(const std::string & pos, Side toplay){ if(toplay != hist->toplay()) return GTPResponse(false, "It is the other player's turn!"); - if(hist->won() >= 0) + if(hist->won() >= Outcome::DRAW) return GTPResponse(false, "The game is already over."); Move m(pos); @@ -113,7 +121,7 @@ GTPResponse GTP::play(const string & pos, int toplay){ move(m); if(verbose >= 2) - logerr("Placement: " + m.to_s() + ", outcome: " + hist->won_str() + "\n" + hist->to_s(colorboard)); + logerr("Placement: " + m.to_s() + ", outcome: " + hist->won().to_s() + "\n" + hist->to_s(colorboard)); return GTPResponse(true); } @@ -131,37 +139,33 @@ GTPResponse GTP::gtp_play(vecstr args){ if(args.size() != 2) return GTPResponse(false, "Wrong number of arguments"); - char toplay = 0; switch(tolower(args[0][0])){ - case 'w': toplay = 1; break; - case 'b': toplay = 2; break; - default: - return GTPResponse(false, "Invalid player selection"); + case 'w': return play(args[1], Side::P1); + case 'b': return play(args[1], Side::P2); + default: return GTPResponse(false, "Invalid player selection"); } - - return play(args[1], toplay); } GTPResponse GTP::gtp_playwhite(vecstr args){ if(args.size() != 1) return GTPResponse(false, "Wrong number of arguments"); - return play(args[0], 1); + return play(args[0], Side::P1); } GTPResponse GTP::gtp_playblack(vecstr args){ if(args.size() != 1) return GTPResponse(false, "Wrong number of arguments"); - return play(args[0], 2); + return play(args[0], Side::P2); } GTPResponse GTP::gtp_winner(vecstr args){ - return GTPResponse(true, hist->won_str()); + return GTPResponse(true, hist->won().to_s()); } GTPResponse GTP::gtp_name(vecstr args){ - return GTPResponse(true, "Castro"); + return GTPResponse(true, std::string("morat-") + Board::name); } GTPResponse GTP::gtp_version(vecstr args){ @@ -193,7 +197,7 @@ GTPResponse GTP::gtp_extended(vecstr args){ } GTPResponse GTP::gtp_debug(vecstr args){ - string str = "\n"; + std::string str = "\n"; str += "Board size: " + to_str(hist->get_size()) + "\n"; str += "Board cells: " + to_str(hist->numcells()) + "\n"; str += "Board vec: " + to_str(hist->vecsize()) + "\n"; @@ -203,14 +207,15 @@ GTPResponse GTP::gtp_debug(vecstr args){ } GTPResponse GTP::gtp_dists(vecstr args){ + using std::string; Board board = *hist; LBDists dists(&board); - int side = 0; + Side side = Side::NONE; if(args.size() >= 1){ switch(tolower(args[0][0])){ - case 'w': side = 1; break; - case 'b': side = 2; break; + case 'w': side = Side::P1; break; + case 'b': side = Side::P2; break; default: return GTPResponse(false, "Invalid player selection"); } @@ -243,17 +248,17 @@ GTPResponse GTP::gtp_dists(vecstr args){ s += coord + char('A' + y); int end = board.lineend(y); for(int x = 0; x < end; x++){ - int p = board.get(x, y); + Side p = board.get(x, y); s += ' '; - if(p == 0){ - int d = (side ? dists.get(Move(x, y), side) : dists.get(Move(x, y))); - if(d < 30) + if(p == Side::NONE){ + int d = (side == Side::NONE ? dists.get(Move(x, y)) : dists.get(Move(x, y), side)); + if(d < 10) s += reset + to_str(d); else s += empty; - }else if(p == 1){ + }else if(p == Side::P1){ s += white; - }else if(p == 2){ + }else if(p == Side::P2){ s += black; } } @@ -265,3 +270,91 @@ GTPResponse GTP::gtp_dists(vecstr args){ GTPResponse GTP::gtp_zobrist(vecstr args){ return GTPResponse(true, hist->hashstr()); } + +GTPResponse GTP::gtp_save_sgf(vecstr args){ + int limit = -1; + if(args.size() == 0) + return GTPResponse(true, "save_sgf [work limit]"); + + std::ifstream infile(args[0].c_str()); + + if(infile) { + infile.close(); + return GTPResponse(false, "File " + args[0] + " already exists"); + } + + std::ofstream outfile(args[0].c_str()); + + if(!outfile) + return GTPResponse(false, "Opening file " + args[0] + " for writing failed"); + + if(args.size() > 1) + limit = from_str(args[1]); + + SGFPrinter sgf(outfile); + sgf.game(Board::name); + sgf.program(gtp_name(vecstr()).response, gtp_version(vecstr()).response); + sgf.size(hist->get_size()); + + sgf.end_root(); + + Side s = Side::P1; + for(auto m : hist){ + sgf.move(s, m); + s = ~s; + } + + agent->gen_sgf(sgf, limit); + + sgf.end(); + outfile.close(); + return true; +} + + +GTPResponse GTP::gtp_load_sgf(vecstr args){ + if(args.size() == 0) + return GTPResponse(true, "load_sgf "); + + std::ifstream infile(args[0].c_str()); + + if(!infile) { + return GTPResponse(false, "Error opening file " + args[0] + " for reading"); + } + + SGFParser sgf(infile); + if(sgf.game() != Board::name){ + infile.close(); + return GTPResponse(false, "File is for the wrong game: " + sgf.game()); + } + + int size = sgf.size(); + if(size != hist->get_size()){ + if(hist.len() == 0){ + hist = History(size); + set_board(); + time_control.new_game(); + }else{ + infile.close(); + return GTPResponse(false, "File has the wrong boardsize to match the existing game"); + } + } + + Side s = Side::P1; + + while(sgf.next_node()){ + Move m = sgf.move(); + move(m); // push the game forward + s = ~s; + } + + if(sgf.has_children()) + agent->load_sgf(sgf); + + assert(sgf.done_child()); + infile.close(); + return true; +} + +}; // namespace Hex +}; // namespace Morat diff --git a/hex/gtpplayer.cpp b/hex/gtpplayer.cpp deleted file mode 100644 index 1d9f89b..0000000 --- a/hex/gtpplayer.cpp +++ /dev/null @@ -1,547 +0,0 @@ - - -#include - -#include "../lib/fileio.h" - -#include "gtp.h" - -using namespace std; - - -GTPResponse GTP::gtp_move_stats(vecstr args){ - string s = ""; - - Player::Node * node = &(player.root); - - for(unsigned int i = 0; i < args.size(); i++){ - Move m(args[i]); - Player::Node * c = node->children.begin(), - * cend = node->children.end(); - for(; c != cend; c++){ - if(c->move == m){ - node = c; - break; - } - } - } - - Player::Node * child = node->children.begin(), - * childend = node->children.end(); - for( ; child != childend; child++){ - if(child->move == M_NONE) - continue; - - s += child->move.to_s(); - s += "," + to_str((child->exp.num() ? child->exp.avg() : 0.0), 4) + "," + to_str(child->exp.num()); - s += "," + to_str((child->rave.num() ? child->rave.avg() : 0.0), 4) + "," + to_str(child->rave.num()); - s += "," + to_str(child->know); - if(child->outcome >= 0) - s += "," + won_str(child->outcome); - s += "\n"; - } - return GTPResponse(true, s); -} - -GTPResponse GTP::gtp_player_solve(vecstr args){ - double use_time = (args.size() >= 1 ? - from_str(args[0]) : - time_control.get_time(hist.len(), hist->movesremain(), player.gamelen())); - - if(verbose) - logerr("time remain: " + to_str(time_control.remain, 1) + ", time: " + to_str(use_time, 3) + ", sims: " + to_str(time_control.max_sims) + "\n"); - - Player::Node * ret = player.genmove(use_time, time_control.max_sims, time_control.flexible); - Move best = M_RESIGN; - if(ret) - best = ret->move; - - time_control.use(player.time_used); - - int toplay = player.rootboard.toplay(); - - DepthStats gamelen, treelen; - uint64_t runs = player.runs; - double times[4] = {0,0,0,0}; - for(unsigned int i = 0; i < player.threads.size(); i++){ - gamelen += player.threads[i]->gamelen; - treelen += player.threads[i]->treelen; - - for(int a = 0; a < 4; a++) - times[a] += player.threads[i]->times[a]; - - player.threads[i]->reset(); - } - player.runs = 0; - - string stats = "Finished " + to_str(runs) + " runs in " + to_str(player.time_used*1000, 0) + " msec: " + to_str(runs/player.time_used, 0) + " Games/s\n"; - if(runs > 0){ - stats += "Game length: " + gamelen.to_s() + "\n"; - stats += "Tree depth: " + treelen.to_s() + "\n"; - if(player.profile) - stats += "Times: " + to_str(times[0], 3) + ", " + to_str(times[1], 3) + ", " + to_str(times[2], 3) + ", " + to_str(times[3], 3) + "\n"; - } - - if(ret){ - stats += "Move Score: " + to_str(ret->exp.avg()) + "\n"; - - if(ret->outcome >= 0){ - stats += "Solved as a "; - if(ret->outcome == toplay) stats += "win"; - else if(ret->outcome == 0) stats += "draw"; - else stats += "loss"; - stats += "\n"; - } - } - - stats += "PV: " + gtp_pv(vecstr()).response + "\n"; - - if(verbose >= 3 && !player.root.children.empty()) - stats += "Exp-Rave:\n" + gtp_move_stats(vecstr()).response + "\n"; - - if(verbose) - logerr(stats); - - Solver s; - if(ret){ - s.outcome = (ret->outcome >= 0 ? ret->outcome : -3); - s.bestmove = ret->move; - s.maxdepth = gamelen.maxdepth; - s.nodes_seen = runs; - }else{ - s.outcome = 3-toplay; - s.bestmove = M_RESIGN; - s.maxdepth = 0; - s.nodes_seen = 0; - } - - return GTPResponse(true, solve_str(s)); -} - - -GTPResponse GTP::gtp_player_solved(vecstr args){ - string s = ""; - Player::Node * child = player.root.children.begin(), - * childend = player.root.children.end(); - int toplay = player.rootboard.toplay(); - int best = 0; - for( ; child != childend; child++){ - if(child->move == M_NONE) - continue; - - if(child->outcome == toplay) - return GTPResponse(true, won_str(toplay)); - else if(child->outcome < 0) - best = 2; - else if(child->outcome == 0) - best = 1; - } - if(best == 2) return GTPResponse(true, won_str(-3)); - if(best == 1) return GTPResponse(true, won_str(0)); - return GTPResponse(true, won_str(3 - toplay)); -} - -GTPResponse GTP::gtp_pv(vecstr args){ - string pvstr = ""; - vector pv = player.get_pv(); - for(unsigned int i = 0; i < pv.size(); i++) - pvstr += pv[i].to_s() + " "; - return GTPResponse(true, pvstr); -} - -GTPResponse GTP::gtp_player_hgf(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, "player_hgf [sims limit]"); - - FILE * fd = fopen(args[0].c_str(), "r"); - - if(fd){ - fclose(fd); - return GTPResponse(false, "File " + args[0] + " already exists"); - } - - fd = fopen(args[0].c_str(), "w"); - - if(!fd) - return GTPResponse(false, "Opening file " + args[0] + " for writing failed"); - - unsigned int limit = 10000; - if(args.size() > 1) - limit = from_str(args[1]); - - Board board = *hist; - - - fprintf(fd, "(;FF[4]SZ[%i]\n", board.get_size()); - int p = 1; - for(auto m : hist){ - fprintf(fd, ";%c[%s]", (p == 1 ? 'W' : 'B'), m.to_s().c_str()); - p = 3-p; - } - - - Player::Node * child = player.root.children.begin(), - * end = player.root.children.end(); - - for( ; child != end; child++){ - if(child->exp.num() >= limit){ - board.set(child->move); - player.gen_hgf(board, child, limit, 1, fd); - board.unset(child->move); - } - } - - fprintf(fd, ")\n"); - - fclose(fd); - - return true; -} - -GTPResponse GTP::gtp_player_load_hgf(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, "player_load_hgf "); - - FILE * fd = fopen(args[0].c_str(), "r"); - - if(!fd) - return GTPResponse(false, "Opening file " + args[0] + " for reading failed"); - - int size; - assert(fscanf(fd, "(;FF[4]SZ[%i]", & size) > 0); - if(size != hist->get_size()){ - if(hist.len() == 0){ - hist = History(Board(size)); - set_board(); - }else{ - fclose(fd); - return GTPResponse(false, "File has the wrong boardsize to match the existing game"); - } - } - - eat_whitespace(fd); - - Board board(size); - Player::Node * node = & player.root; - vector prefix; - - char side, movestr[5]; - while(fscanf(fd, ";%c[%5[^]]]", &side, movestr) > 0){ - Move move(movestr); - - if(board.num_moves() >= (int)hist.len()){ - if(node->children.empty()) - player.create_children_simple(board, node); - - prefix.push_back(node); - node = player.find_child(node, move); - }else if(hist[board.num_moves()] != move){ - fclose(fd); - return GTPResponse(false, "The current game is deeper than this file"); - } - board.move(move); - - eat_whitespace(fd); - } - prefix.push_back(node); - - - if(fpeek(fd) != ')'){ - if(node->children.empty()) - player.create_children_simple(board, node); - - while(fpeek(fd) != ')'){ - Player::Node child; - player.load_hgf(board, & child, fd); - - Player::Node * i = player.find_child(node, child.move); - *i = child; //copy the child experience to the tree - i->swap_tree(child); //move the child subtree to the tree - - assert(child.children.empty()); - - eat_whitespace(fd); - } - } - - eat_whitespace(fd); - assert(fgetc(fd) == ')'); - fclose(fd); - - while(!prefix.empty()){ - Player::Node * node = prefix.back(); - prefix.pop_back(); - - Player::Node * child = node->children.begin(), - * end = node->children.end(); - - int toplay = hist->toplay(); - if(prefix.size() % 2 == 1) - toplay = 3 - toplay; - - Player::Node * backup = child; - - node->exp.clear(); - for( ; child != end; child++){ - node->exp += child->exp.invert(); - if(child->outcome == toplay || child->exp.num() > backup->exp.num()) - backup = child; - } - player.do_backup(node, backup, toplay); - } - - return true; -} - - -GTPResponse GTP::gtp_genmove(vecstr args){ - if(player.rootboard.won() >= 0) - return GTPResponse(true, "resign"); - - double use_time = (args.size() >= 2 ? - from_str(args[1]) : - time_control.get_time(hist.len(), hist->movesremain(), player.gamelen())); - - if(args.size() >= 2) - use_time = from_str(args[1]); - - if(verbose) - logerr("time remain: " + to_str(time_control.remain, 1) + ", time: " + to_str(use_time, 3) + ", sims: " + to_str(time_control.max_sims) + "\n"); - - uword nodesbefore = player.nodes; - - Player::Node * ret = player.genmove(use_time, time_control.max_sims, time_control.flexible); - Move best = player.root.bestmove; - - time_control.use(player.time_used); - - int toplay = player.rootboard.toplay(); - - DepthStats gamelen, treelen; - uint64_t runs = player.runs; - double times[4] = {0,0,0,0}; - for(unsigned int i = 0; i < player.threads.size(); i++){ - gamelen += player.threads[i]->gamelen; - treelen += player.threads[i]->treelen; - - for(int a = 0; a < 4; a++) - times[a] += player.threads[i]->times[a]; - - player.threads[i]->reset(); - } - player.runs = 0; - - string stats = "Finished " + to_str(runs) + " runs in " + to_str(player.time_used*1000, 0) + " msec: " + to_str(runs/player.time_used, 0) + " Games/s\n"; - if(runs > 0){ - stats += "Game length: " + gamelen.to_s() + "\n"; - stats += "Tree depth: " + treelen.to_s() + "\n"; - if(player.profile) - stats += "Times: " + to_str(times[0], 3) + ", " + to_str(times[1], 3) + ", " + to_str(times[2], 3) + ", " + to_str(times[3], 3) + "\n"; - } - - if(ret) - stats += "Move Score: " + to_str(ret->exp.avg()) + "\n"; - - if(player.root.outcome != -3){ - stats += "Solved as a "; - if(player.root.outcome == 0) stats += "draw"; - else if(player.root.outcome == toplay) stats += "win"; - else if(player.root.outcome == 3-toplay) stats += "loss"; - else if(player.root.outcome == -toplay) stats += "win or draw"; - else if(player.root.outcome == toplay-3) stats += "loss or draw"; - stats += "\n"; - } - - stats += "PV: " + gtp_pv(vecstr()).response + "\n"; - - if(verbose >= 3 && !player.root.children.empty()) - stats += "Exp-Rave:\n" + gtp_move_stats(vecstr()).response + "\n"; - - string extended; - if(genmoveextended){ - //move score - if(ret) extended += " " + to_str(ret->exp.avg()); - else extended += " 0"; - //outcome - extended += " " + won_str(player.root.outcome); - //work - extended += " " + to_str(runs); - //nodes - extended += " " + to_str(player.nodes - nodesbefore); - } - - move(best); - - if(verbose >= 2){ - stats += "history: "; - for(auto m : hist) - stats += m.to_s() + " "; - stats += "\n"; - stats += hist->to_s(colorboard) + "\n"; - } - - if(verbose) - logerr(stats); - - return GTPResponse(true, best.to_s() + extended); -} - -GTPResponse GTP::gtp_player_params(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, string("\n") + - "Set player parameters, eg: player_params -e 1 -f 0 -t 2 -o 1 -p 0\n" + - "Processing:\n" + -#ifndef SINGLE_THREAD - " -t --threads Number of MCTS threads [" + to_str(player.numthreads) + "]\n" + -#endif - " -o --ponder Continue to ponder during the opponents time [" + to_str(player.ponder) + "]\n" + - " -M --maxmem Max memory in Mb to use for the tree [" + to_str(player.maxmem/(1024*1024)) + "]\n" + - " --profile Output the time used by each phase of MCTS [" + to_str(player.profile) + "]\n" + - "Final move selection:\n" + - " -E --msexplore Lower bound constant in final move selection [" + to_str(player.msexplore) + "]\n" + - " -F --msrave Rave factor, 0 for pure exp, -1 # sims, -2 # wins [" + to_str(player.msrave) + "]\n" + - "Tree traversal:\n" + - " -e --explore Exploration rate for UCT [" + to_str(player.explore) + "]\n" + - " -A --parexplore Multiply the explore rate by parents experience [" + to_str(player.parentexplore) + "]\n" + - " -f --ravefactor The rave factor: alpha = rf/(rf + visits) [" + to_str(player.ravefactor) + "]\n" + - " -d --decrrave Decrease the rave factor over time: rf += d*empty [" + to_str(player.decrrave) + "]\n" + - " -a --knowledge Use knowledge: 0.01*know/sqrt(visits+1) [" + to_str(player.knowledge) + "]\n" + - " -r --userave Use rave with this probability [0-1] [" + to_str(player.userave) + "]\n" + - " -X --useexplore Use exploration with this probability [0-1] [" + to_str(player.useexplore) + "]\n" + - " -u --fpurgency Value to assign to an unplayed move [" + to_str(player.fpurgency) + "]\n" + - " -O --rollouts Number of rollouts to run per simulation [" + to_str(player.rollouts) + "]\n" + - " -I --dynwiden Dynamic widening, consider log_wid(exp) children [" + to_str(player.dynwiden) + "]\n" + - "Tree building:\n" + - " -s --shortrave Only use moves from short rollouts for rave [" + to_str(player.shortrave) + "]\n" + - " -k --keeptree Keep the tree from the previous move [" + to_str(player.keeptree) + "]\n" + - " -m --minimax Backup the minimax proof in the UCT tree [" + to_str(player.minimax) + "]\n" + - " -x --visitexpand Number of visits before expanding a node [" + to_str(player.visitexpand) + "]\n" + - " -P --symmetry Prune symmetric moves, good for proof, not play [" + to_str(player.prunesymmetry) + "]\n" + - " --gcsolved Garbage collect solved nodes with fewer sims than [" + to_str(player.gcsolved) + "]\n" + - "Node initialization knowledge, Give a bonus:\n" + - " -l --localreply based on the distance to the previous move [" + to_str(player.localreply) + "]\n" + - " -y --locality to stones near other stones of the same color [" + to_str(player.locality) + "]\n" + - " -c --connect to stones connected to edges [" + to_str(player.connect) + "]\n" + - " -S --size based on the size of the group [" + to_str(player.size) + "]\n" + - " -b --bridge to maintaining a 2-bridge after the op probes [" + to_str(player.bridge) + "]\n" + - " -D --distance to low minimum distance to win (<0 avoid VCs) [" + to_str(player.dists) + "]\n" + - "Rollout policy:\n" + - " -h --weightrand Weight the moves according to computed gammas [" + to_str(player.weightedrandom) + "]\n" + - " -p --pattern Maintain the virtual connection pattern [" + to_str(player.rolloutpattern) + "]\n" + - " -g --goodreply Reuse the last good reply (1), remove losses (2) [" + to_str(player.lastgoodreply) + "]\n" + - " -w --instantwin Look for instant wins to this depth [" + to_str(player.instantwin) + "]\n" - ); - - string errs; - for(unsigned int i = 0; i < args.size(); i++) { - string arg = args[i]; - - if((arg == "-t" || arg == "--threads") && i+1 < args.size()){ - player.numthreads = from_str(args[++i]); - bool p = player.ponder; - player.set_ponder(false); //stop the threads while resetting them - player.reset_threads(); - player.set_ponder(p); - }else if((arg == "-o" || arg == "--ponder") && i+1 < args.size()){ - player.set_ponder(from_str(args[++i])); - }else if((arg == "--profile") && i+1 < args.size()){ - player.profile = from_str(args[++i]); - }else if((arg == "-M" || arg == "--maxmem") && i+1 < args.size()){ - player.maxmem = from_str(args[++i])*1024*1024; - }else if((arg == "-E" || arg == "--msexplore") && i+1 < args.size()){ - player.msexplore = from_str(args[++i]); - }else if((arg == "-F" || arg == "--msrave") && i+1 < args.size()){ - player.msrave = from_str(args[++i]); - }else if((arg == "-e" || arg == "--explore") && i+1 < args.size()){ - player.explore = from_str(args[++i]); - }else if((arg == "-A" || arg == "--parexplore") && i+1 < args.size()){ - player.parentexplore = from_str(args[++i]); - }else if((arg == "-f" || arg == "--ravefactor") && i+1 < args.size()){ - player.ravefactor = from_str(args[++i]); - }else if((arg == "-d" || arg == "--decrrave") && i+1 < args.size()){ - player.decrrave = from_str(args[++i]); - }else if((arg == "-a" || arg == "--knowledge") && i+1 < args.size()){ - player.knowledge = from_str(args[++i]); - }else if((arg == "-s" || arg == "--shortrave") && i+1 < args.size()){ - player.shortrave = from_str(args[++i]); - }else if((arg == "-k" || arg == "--keeptree") && i+1 < args.size()){ - player.keeptree = from_str(args[++i]); - }else if((arg == "-m" || arg == "--minimax") && i+1 < args.size()){ - player.minimax = from_str(args[++i]); - }else if((arg == "-P" || arg == "--symmetry") && i+1 < args.size()){ - player.prunesymmetry = from_str(args[++i]); - }else if(( arg == "--gcsolved") && i+1 < args.size()){ - player.gcsolved = from_str(args[++i]); - }else if((arg == "-r" || arg == "--userave") && i+1 < args.size()){ - player.userave = from_str(args[++i]); - }else if((arg == "-X" || arg == "--useexplore") && i+1 < args.size()){ - player.useexplore = from_str(args[++i]); - }else if((arg == "-u" || arg == "--fpurgency") && i+1 < args.size()){ - player.fpurgency = from_str(args[++i]); - }else if((arg == "-O" || arg == "--rollouts") && i+1 < args.size()){ - player.rollouts = from_str(args[++i]); - if(player.gclimit < player.rollouts*5) - player.gclimit = player.rollouts*5; - }else if((arg == "-I" || arg == "--dynwiden") && i+1 < args.size()){ - player.dynwiden = from_str(args[++i]); - player.logdynwiden = std::log(player.dynwiden); - }else if((arg == "-x" || arg == "--visitexpand") && i+1 < args.size()){ - player.visitexpand = from_str(args[++i]); - }else if((arg == "-l" || arg == "--localreply") && i+1 < args.size()){ - player.localreply = from_str(args[++i]); - }else if((arg == "-y" || arg == "--locality") && i+1 < args.size()){ - player.locality = from_str(args[++i]); - }else if((arg == "-c" || arg == "--connect") && i+1 < args.size()){ - player.connect = from_str(args[++i]); - }else if((arg == "-S" || arg == "--size") && i+1 < args.size()){ - player.size = from_str(args[++i]); - }else if((arg == "-b" || arg == "--bridge") && i+1 < args.size()){ - player.bridge = from_str(args[++i]); - }else if((arg == "-D" || arg == "--distance") && i+1 < args.size()){ - player.dists = from_str(args[++i]); - }else if((arg == "-h" || arg == "--weightrand") && i+1 < args.size()){ - player.weightedrandom = from_str(args[++i]); - }else if((arg == "-p" || arg == "--pattern") && i+1 < args.size()){ - player.rolloutpattern = from_str(args[++i]); - }else if((arg == "-g" || arg == "--goodreply") && i+1 < args.size()){ - player.lastgoodreply = from_str(args[++i]); - }else if((arg == "-w" || arg == "--instantwin") && i+1 < args.size()){ - player.instantwin = from_str(args[++i]); - }else{ - return GTPResponse(false, "Missing or unknown parameter"); - } - } - return GTPResponse(true, errs); -} - -GTPResponse GTP::gtp_player_gammas(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, "Must pass the filename of a set of gammas"); - - ifstream ifs(args[0].c_str()); - - if(!ifs.good()) - return GTPResponse(false, "Failed to open file for reading"); - - Board board = *hist; - - for(int i = 0; i < 4096; i++){ - int a; - float f; - ifs >> a >> f; - - if(i != a){ - ifs.close(); - return GTPResponse(false, "Line " + to_str(i) + " doesn't match the expected value"); - } - - int s = board.pattern_symmetry(i); - if(s == i) - player.gammas[i] = f; - else - player.gammas[i] = player.gammas[s]; - } - - ifs.close(); - return GTPResponse(true); -} diff --git a/hex/gtpsolver.cpp b/hex/gtpsolver.cpp deleted file mode 100644 index 1df5ea1..0000000 --- a/hex/gtpsolver.cpp +++ /dev/null @@ -1,331 +0,0 @@ - - -#include "gtp.h" - -string GTP::solve_str(int outcome) const { - switch(outcome){ - case -2: return "black_or_draw"; - case -1: return "white_or_draw"; - case 0: return "draw"; - case 1: return "white"; - case 2: return "black"; - default: return "unknown"; - } -} - -string GTP::solve_str(const Solver & solve){ - string ret = ""; - ret += solve_str(solve.outcome) + " "; - ret += solve.bestmove.to_s() + " "; - ret += to_str(solve.maxdepth) + " "; - ret += to_str(solve.nodes_seen); - return ret; -} - - -GTPResponse GTP::gtp_solve_ab(vecstr args){ - double time = 60; - - if(args.size() >= 1) - time = from_str(args[0]); - - solverab.solve(time); - - logerr("Finished in " + to_str(solverab.time_used*1000, 0) + " msec\n"); - - return GTPResponse(true, solve_str(solverab)); -} - -GTPResponse GTP::gtp_solve_ab_params(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, string("\n") + - "Update the alpha-beta solver settings, eg: ab_params -m 100 -s 1 -d 3\n" - " -m --memory Memory limit in Mb (0 to disable the TT) [" + to_str(solverab.memlimit/(1024*1024)) + "]\n" - " -s --scout Whether to scout ahead for the true minimax value [" + to_str(solverab.scout) + "]\n" - " -d --depth Starting depth [" + to_str(solverab.startdepth) + "]\n" - ); - - for(unsigned int i = 0; i < args.size(); i++) { - string arg = args[i]; - - if((arg == "-m" || arg == "--memory") && i+1 < args.size()){ - int mem = from_str(args[++i]); - solverab.set_memlimit(mem); - }else if((arg == "-s" || arg == "--scout") && i+1 < args.size()){ - solverab.scout = from_str(args[++i]); - }else if((arg == "-d" || arg == "--depth") && i+1 < args.size()){ - solverab.startdepth = from_str(args[++i]); - }else{ - return GTPResponse(false, "Missing or unknown parameter"); - } - } - - return true; -} - -GTPResponse GTP::gtp_solve_ab_stats(vecstr args){ - string s = ""; - - Board board = *hist; - for(auto arg : args) - board.move(Move(arg)); - - int value; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - value = solverab.tt_get(board.test_hash(*move)); - - s += move->to_s() + "," + to_str(value) + "\n"; - } - return GTPResponse(true, s); -} - -GTPResponse GTP::gtp_solve_ab_clear(vecstr args){ - solverab.clear_mem(); - return true; -} - - - -GTPResponse GTP::gtp_solve_pns(vecstr args){ - double time = 60; - - if(args.size() >= 1) - time = from_str(args[0]); - - solverpns.solve(time); - - logerr("Finished in " + to_str(solverpns.time_used*1000, 0) + " msec\n"); - - return GTPResponse(true, solve_str(solverpns)); -} - -GTPResponse GTP::gtp_solve_pns_params(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, string("\n") + - "Update the pns solver settings, eg: pns_params -m 100 -s 0 -d 1 -e 0.25 -a 2 -l 0\n" - " -m --memory Memory limit in Mb [" + to_str(solverpns.memlimit/(1024*1024)) + "]\n" -// " -t --threads How many threads to run -// " -o --ponder Ponder in the background - " -d --df Use depth-first thresholds [" + to_str(solverpns.df) + "]\n" - " -e --epsilon How big should the threshold be [" + to_str(solverpns.epsilon) + "]\n" - " -a --abdepth Run an alpha-beta search of this size at each leaf [" + to_str(solverpns.ab) + "]\n" - " -l --lbdist Initialize with the lower bound on distance to win [" + to_str(solverpns.lbdist) + "]\n" - ); - - for(unsigned int i = 0; i < args.size(); i++) { - string arg = args[i]; - - if((arg == "-m" || arg == "--memory") && i+1 < args.size()){ - uint64_t mem = from_str(args[++i]); - if(mem < 1) return GTPResponse(false, "Memory can't be less than 1mb"); - solverpns.set_memlimit(mem*1024*1024); - }else if((arg == "-d" || arg == "--df") && i+1 < args.size()){ - solverpns.df = from_str(args[++i]); - }else if((arg == "-e" || arg == "--epsilon") && i+1 < args.size()){ - solverpns.epsilon = from_str(args[++i]); - }else if((arg == "-a" || arg == "--abdepth") && i+1 < args.size()){ - solverpns.ab = from_str(args[++i]); - }else if((arg == "-l" || arg == "--lbdist") && i+1 < args.size()){ - solverpns.lbdist = from_str(args[++i]); - }else{ - return GTPResponse(false, "Missing or unknown parameter"); - } - } - - return true; -} - -GTPResponse GTP::gtp_solve_pns_stats(vecstr args){ - string s = ""; - - SolverPNS::PNSNode * node = &(solverpns.root); - - for(unsigned int i = 0; i < args.size(); i++){ - Move m(args[i]); - SolverPNS::PNSNode * c = node->children.begin(), - * cend = node->children.end(); - for(; c != cend; c++){ - if(c->move == m){ - node = c; - break; - } - } - } - - SolverPNS::PNSNode * child = node->children.begin(), - * childend = node->children.end(); - for( ; child != childend; child++){ - if(child->move == M_NONE) - continue; - - s += child->move.to_s() + "," + to_str(child->phi) + "," + to_str(child->delta) + "\n"; - } - return GTPResponse(true, s); -} - -GTPResponse GTP::gtp_solve_pns_clear(vecstr args){ - solverpns.clear_mem(); - return true; -} - - -GTPResponse GTP::gtp_solve_pns2(vecstr args){ - double time = 60; - - if(args.size() >= 1) - time = from_str(args[0]); - - solverpns2.solve(time); - - logerr("Finished in " + to_str(solverpns2.time_used*1000, 0) + " msec\n"); - - return GTPResponse(true, solve_str(solverpns2)); -} - -GTPResponse GTP::gtp_solve_pns2_params(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, string("\n") + - "Update the pns solver settings, eg: pns_params -m 100 -s 0 -d 1 -e 0.25 -a 2 -l 0\n" - " -m --memory Memory limit in Mb [" + to_str(solverpns2.memlimit/(1024*1024)) + "]\n" - " -t --threads How many threads to run [" + to_str(solverpns2.numthreads) + "]\n" -// " -o --ponder Ponder in the background - " -d --df Use depth-first thresholds [" + to_str(solverpns2.df) + "]\n" - " -e --epsilon How big should the threshold be [" + to_str(solverpns2.epsilon) + "]\n" - " -a --abdepth Run an alpha-beta search of this size at each leaf [" + to_str(solverpns2.ab) + "]\n" - " -l --lbdist Initialize with the lower bound on distance to win [" + to_str(solverpns2.lbdist) + "]\n" - ); - - for(unsigned int i = 0; i < args.size(); i++) { - string arg = args[i]; - - if((arg == "-t" || arg == "--threads") && i+1 < args.size()){ - solverpns2.numthreads = from_str(args[++i]); - solverpns2.reset_threads(); - }else if((arg == "-m" || arg == "--memory") && i+1 < args.size()){ - uint64_t mem = from_str(args[++i]); - if(mem < 1) return GTPResponse(false, "Memory can't be less than 1mb"); - solverpns2.set_memlimit(mem*1024*1024); - }else if((arg == "-d" || arg == "--df") && i+1 < args.size()){ - solverpns2.df = from_str(args[++i]); - }else if((arg == "-e" || arg == "--epsilon") && i+1 < args.size()){ - solverpns2.epsilon = from_str(args[++i]); - }else if((arg == "-a" || arg == "--abdepth") && i+1 < args.size()){ - solverpns2.ab = from_str(args[++i]); - }else if((arg == "-l" || arg == "--lbdist") && i+1 < args.size()){ - solverpns2.lbdist = from_str(args[++i]); - }else{ - return GTPResponse(false, "Missing or unknown parameter"); - } - } - - return true; -} - -GTPResponse GTP::gtp_solve_pns2_stats(vecstr args){ - string s = ""; - - SolverPNS2::PNSNode * node = &(solverpns2.root); - - for(unsigned int i = 0; i < args.size(); i++){ - Move m(args[i]); - SolverPNS2::PNSNode * c = node->children.begin(), - * cend = node->children.end(); - for(; c != cend; c++){ - if(c->move == m){ - node = c; - break; - } - } - } - - SolverPNS2::PNSNode * child = node->children.begin(), - * childend = node->children.end(); - for( ; child != childend; child++){ - if(child->move == M_NONE) - continue; - - s += child->move.to_s() + "," + to_str(child->phi) + "," + to_str(child->delta) + "," + to_str(child->work) + "\n"; - } - return GTPResponse(true, s); -} - -GTPResponse GTP::gtp_solve_pns2_clear(vecstr args){ - solverpns2.clear_mem(); - return true; -} - - - - -GTPResponse GTP::gtp_solve_pnstt(vecstr args){ - double time = 60; - - if(args.size() >= 1) - time = from_str(args[0]); - - solverpnstt.solve(time); - - logerr("Finished in " + to_str(solverpnstt.time_used*1000, 0) + " msec\n"); - - return GTPResponse(true, solve_str(solverpnstt)); -} - -GTPResponse GTP::gtp_solve_pnstt_params(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, string("\n") + - "Update the pnstt solver settings, eg: pnstt_params -m 100 -s 0 -d 1 -e 0.25 -a 2 -l 0\n" - " -m --memory Memory limit in Mb [" + to_str(solverpnstt.memlimit/(1024*1024)) + "]\n" -// " -t --threads How many threads to run -// " -o --ponder Ponder in the background - " -d --df Use depth-first thresholds [" + to_str(solverpnstt.df) + "]\n" - " -e --epsilon How big should the threshold be [" + to_str(solverpnstt.epsilon) + "]\n" - " -a --abdepth Run an alpha-beta search of this size at each leaf [" + to_str(solverpnstt.ab) + "]\n" - " -c --copy Try to copy a proof to this many siblings, <0 quit early [" + to_str(solverpnstt.copyproof) + "]\n" -// " -l --lbdist Initialize with the lower bound on distance to win [" + to_str(solverpnstt.lbdist) + "]\n" - ); - - for(unsigned int i = 0; i < args.size(); i++) { - string arg = args[i]; - - if((arg == "-m" || arg == "--memory") && i+1 < args.size()){ - int mem = from_str(args[++i]); - if(mem < 1) return GTPResponse(false, "Memory can't be less than 1mb"); - solverpnstt.set_memlimit(mem*1024*1024); - }else if((arg == "-d" || arg == "--df") && i+1 < args.size()){ - solverpnstt.df = from_str(args[++i]); - }else if((arg == "-e" || arg == "--epsilon") && i+1 < args.size()){ - solverpnstt.epsilon = from_str(args[++i]); - }else if((arg == "-a" || arg == "--abdepth") && i+1 < args.size()){ - solverpnstt.ab = from_str(args[++i]); - }else if((arg == "-c" || arg == "--copy") && i+1 < args.size()){ - solverpnstt.copyproof = from_str(args[++i]); -// }else if((arg == "-l" || arg == "--lbdist") && i+1 < args.size()){ -// solverpnstt.lbdist = from_str(args[++i]); - }else{ - return GTPResponse(false, "Missing or unknown parameter"); - } - } - - return true; -} - -GTPResponse GTP::gtp_solve_pnstt_stats(vecstr args){ - string s = ""; - - Board board = *hist; - for(auto arg : args) - board.move(Move(arg)); - - SolverPNSTT::PNSNode * child = NULL; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - child = solverpnstt.tt(board, *move); - - s += move->to_s() + "," + to_str(child->phi) + "," + to_str(child->delta) + "\n"; - } - return GTPResponse(true, s); -} - -GTPResponse GTP::gtp_solve_pnstt_clear(vecstr args){ - solverpnstt.clear_mem(); - return true; -} diff --git a/hex/history.h b/hex/history.h deleted file mode 100644 index 00ccd06..0000000 --- a/hex/history.h +++ /dev/null @@ -1,70 +0,0 @@ - -#pragma once - -#include - -#include "../lib/string.h" - -#include "board.h" -#include "move.h" - -class History { - std::vector hist; - Board board; - -public: - - History() { } - History(const Board & b) : board(b) { } - - const Move & operator [] (int i) const { - return hist[i]; - } - - Move last() const { - if(hist.size() == 0) - return M_NONE; - - return hist.back(); - } - - const Board & operator * () const { return board; } - const Board * operator -> () const { return & board; } - - std::vector::const_iterator begin() const { return hist.begin(); } - std::vector::const_iterator end() const { return hist.end(); } - - const Board get_board() const { - Board b(board.get_size()); - for(auto m : hist) - b.move(m); - return b; - } - - int len() const { - return hist.size(); - } - - void clear() { - hist.clear(); - board = get_board(); - } - - bool undo() { - if(hist.size() <= 0) - return false; - - hist.pop_back(); - board = get_board(); - return true; - } - - bool move(const Move & m) { - if(board.valid_move(m)){ - board.move(m); - hist.push_back(m); - return true; - } - return false; - } -}; diff --git a/hex/lbdist.h b/hex/lbdist.h index 99ccf30..774ad89 100644 --- a/hex/lbdist.h +++ b/hex/lbdist.h @@ -12,8 +12,13 @@ Decrease distance when crossing your own virtual connection? //TODO: Needs to be fixed for only one direction per player +#include "../lib/move.h" + #include "board.h" -#include "move.h" + + +namespace Morat { +namespace Hex { class LBDists { struct MoveDist { @@ -71,15 +76,16 @@ class LBDists { IntPQueue Q; const Board * board; - int & dist(int edge, int player, int i) { return dists[edge][player-1][i]; } - int & dist(int edge, int player, const Move & m) { return dist(edge, player, board->xy(m)); } - int & dist(int edge, int player, int x, int y) { return dist(edge, player, board->xy(x, y)); } + int & dist(int edge, Side player, int i) { return dists[edge][player.to_i() - 1][i]; } + int & dist(int edge, Side player, const Move & m) { return dist(edge, player, board->xy(m)); } + int & dist(int edge, Side player, int x, int y) { return dist(edge, player, board->xy(x, y)); } - void init(int x, int y, int edge, int player, int dir){ - int val = board->get(x, y); - if(val != 3 - player){ - Q.push(MoveDist(x, y, (val == 0), dir)); - dist(edge, player, x, y) = (val == 0); + void init(int x, int y, int edge, Side player, int dir){ + Side val = board->get(x, y); + if(val != ~player){ + bool empty = (val == Side::NONE); + Q.push(MoveDist(x, y, empty, dir)); + dist(edge, player, x, y) = empty; } } @@ -88,7 +94,7 @@ class LBDists { LBDists() : board(NULL) {} LBDists(const Board * b) { run(b); } - void run(const Board * b, bool crossvcs = true, int side = 0) { + void run(const Board * b, bool crossvcs = true, Side side = Side::BOTH) { board = b; for(int i = 0; i < 3; i++) @@ -96,22 +102,21 @@ class LBDists { for(int k = 0; k < board->vecsize(); k++) dists[i][j][k] = maxdist; //far far away! + if(side == Side::P1 || side == Side::BOTH) init_player(crossvcs, Side::P1); + if(side == Side::P2 || side == Side::BOTH) init_player(crossvcs, Side::P2); + } + + void init_player(bool crossvcs, Side player){ int m = board->get_size(); int m1 = m-1; - int start, end; - if(side){ start = end = side; } - else { start = 1; end = 2; } - - for(int player = start; player <= end; player++){ - for(int x = 0; x < m; x++) { init(x, 0, 0, player, 3); } flood(0, player, crossvcs); //edge 0 - for(int y = 0; y < m; y++) { init(0, y, 1, player, 1); } flood(1, player, crossvcs); //edge 1 - for(int y = 0; y < m; y++) { init(m1-y, y, 2, player, 5); } flood(2, player, crossvcs); //edge 2 - } + for(int x = 0; x < m; x++) { init(x, 0, 0, player, 3); } flood(0, player, crossvcs); //edge 0 + for(int y = 0; y < m; y++) { init(0, y, 1, player, 1); } flood(1, player, crossvcs); //edge 1 + for(int y = 0; y < m; y++) { init(m1-y, y, 2, player, 5); } flood(2, player, crossvcs); //edge 2 } - void flood(int edge, int player, bool crossvcs){ - int otherplayer = 3 - player; + void flood(int edge, Side player, bool crossvcs){ + Side otherplayer = ~player; MoveDist cur; while(Q.pop(cur)){ @@ -121,12 +126,12 @@ class LBDists { if(board->onboard(next.pos)){ int pos = board->xy(next.pos); - int colour = board->get(pos); + Side colour = board->get(pos); if(colour == otherplayer) continue; - if(colour == 0){ + if(colour == Side::NONE){ if(!crossvcs && //forms a vc board->get(cur.pos + neighbours[(nd - 1) % 6]) == otherplayer && board->get(cur.pos + neighbours[(nd + 1) % 6]) == otherplayer) @@ -145,12 +150,15 @@ class LBDists { } } - int get(Move pos){ return min(get(pos, 1), get(pos, 2)); } - int get(Move pos, int player){ return get(board->xy(pos), player); } - int get(int pos, int player){ + int get(Move pos){ return std::min(get(pos, Side::P1), get(pos, Side::P2)); } + int get(Move pos, Side player){ return get(board->xy(pos), player); } + int get(int pos, Side player){ int sum = 0; for(int i = 0; i < 3; i++) sum += dist(i, player, pos); return sum; } }; + +}; // namespace Hex +}; // namespace Morat diff --git a/hex/chex.cpp b/hex/main.cpp similarity index 96% rename from hex/chex.cpp rename to hex/main.cpp index 8a4dd6f..5edd407 100644 --- a/hex/chex.cpp +++ b/hex/main.cpp @@ -1,5 +1,4 @@ - #include #include @@ -7,6 +6,10 @@ #include "gtp.h" + +using namespace Morat; +using namespace Hex; + using namespace std; void die(int code, const string & str){ @@ -15,6 +18,7 @@ void die(int code, const string & str){ } int main(int argc, char **argv){ + srand(Time().in_usec()); GTP gtp; @@ -52,7 +56,6 @@ int main(int argc, char **argv){ } } - gtp.setinfile(stdin); gtp.setoutfile(stdout); gtp.run(); diff --git a/hex/movelist.h b/hex/movelist.h deleted file mode 100644 index 27c22de..0000000 --- a/hex/movelist.h +++ /dev/null @@ -1,76 +0,0 @@ - -#pragma once - -#include "../lib/exppair.h" - -#include "board.h" -#include "move.h" - -struct MoveList { - ExpPair exp[2]; //aggregated outcomes overall - ExpPair rave[2][Board::max_vecsize]; //aggregated outcomes per move - MovePlayer moves[Board::max_vecsize]; //moves made in order - int tree; //number of moves in the tree - int rollout; //number of moves in the rollout - Board * board; //reference to rootboard for xy() - - MoveList() : tree(0), rollout(0), board(NULL) { } - - void addtree(const Move & move, char player){ - moves[tree++] = MovePlayer(move, player); - } - void addrollout(const Move & move, char player){ - moves[tree + rollout++] = MovePlayer(move, player); - } - void reset(Board * b){ - tree = 0; - rollout = 0; - board = b; - exp[0].clear(); - exp[1].clear(); - for(int i = 0; i < b->vecsize(); i++){ - rave[0][i].clear(); - rave[1][i].clear(); - } - } - void finishrollout(int won){ - exp[0].addloss(); - exp[1].addloss(); - if(won == 0){ - exp[0].addtie(); - exp[1].addtie(); - }else{ - exp[won-1].addwin(); - - for(MovePlayer * i = begin(), * e = end(); i != e; i++){ - ExpPair & r = rave[i->player-1][board->xy(*i)]; - r.addloss(); - if(i->player == won) - r.addwin(); - } - } - rollout = 0; - } - const MovePlayer * begin() const { - return moves; - } - MovePlayer * begin() { - return moves; - } - const MovePlayer * end() const { - return moves + tree + rollout; - } - MovePlayer * end() { - return moves + tree + rollout; - } - void subvlosses(int n){ - exp[0].addlosses(-n); - exp[1].addlosses(-n); - } - const ExpPair & getrave(int player, const Move & move) const { - return rave[player-1][board->xy(move)]; - } - const ExpPair & getexp(int player) const { - return exp[player-1]; - } -}; diff --git a/hex/player.cpp b/hex/player.cpp deleted file mode 100644 index b517471..0000000 --- a/hex/player.cpp +++ /dev/null @@ -1,506 +0,0 @@ - -#include -#include - -#include "../lib/alarm.h" -#include "../lib/fileio.h" -#include "../lib/string.h" -#include "../lib/time.h" - -#include "board.h" -#include "player.h" - -const float Player::min_rave = 0.1; - -void Player::PlayerThread::run(){ - while(true){ - switch(player->threadstate){ - case Thread_Cancelled: //threads should exit - return; - - case Thread_Wait_Start: //threads are waiting to start - case Thread_Wait_Start_Cancelled: - player->runbarrier.wait(); - CAS(player->threadstate, Thread_Wait_Start, Thread_Running); - CAS(player->threadstate, Thread_Wait_Start_Cancelled, Thread_Cancelled); - break; - - case Thread_Wait_End: //threads are waiting to end - player->runbarrier.wait(); - CAS(player->threadstate, Thread_Wait_End, Thread_Wait_Start); - break; - - case Thread_Running: //threads are running - if(player->rootboard.won() >= 0 || player->root.outcome >= 0 || (player->maxruns > 0 && player->runs >= player->maxruns)){ //solved or finished runs - if(CAS(player->threadstate, Thread_Running, Thread_Wait_End) && player->root.outcome >= 0) - logerr("Solved as " + to_str((int)player->root.outcome) + "\n"); - break; - } - if(player->ctmem.memalloced() >= player->maxmem){ //out of memory, start garbage collection - CAS(player->threadstate, Thread_Running, Thread_GC); - break; - } - - INCR(player->runs); - iterate(); - break; - - case Thread_GC: //one thread is running garbage collection, the rest are waiting - case Thread_GC_End: //once done garbage collecting, go to wait_end instead of back to running - if(player->gcbarrier.wait()){ - Time starttime; - logerr("Starting player GC with limit " + to_str(player->gclimit) + " ... "); - uint64_t nodesbefore = player->nodes; - Board copy = player->rootboard; - player->garbage_collect(copy, & player->root); - Time gctime; - player->ctmem.compact(1.0, 0.75); - Time compacttime; - logerr(to_str(100.0*player->nodes/nodesbefore, 1) + " % of tree remains - " + - to_str((gctime - starttime)*1000, 0) + " msec gc, " + to_str((compacttime - gctime)*1000, 0) + " msec compact\n"); - - if(player->ctmem.meminuse() >= player->maxmem/2) - player->gclimit = (int)(player->gclimit*1.3); - else if(player->gclimit > player->rollouts*5) - player->gclimit = (int)(player->gclimit*0.9); //slowly decay to a minimum of 5 - - CAS(player->threadstate, Thread_GC, Thread_Running); - CAS(player->threadstate, Thread_GC_End, Thread_Wait_End); - } - player->gcbarrier.wait(); - break; - } - } -} - -Player::Node * Player::genmove(double time, int max_runs, bool flexible){ - time_used = 0; - int toplay = rootboard.toplay(); - - if(rootboard.won() >= 0 || (time <= 0 && max_runs == 0)) - return NULL; - - Time starttime; - - stop_threads(); - - if(runs) - logerr("Pondered " + to_str(runs) + " runs\n"); - - runs = 0; - maxruns = max_runs; - for(unsigned int i = 0; i < threads.size(); i++) - threads[i]->reset(); - - // if the move is forced and the time can be added to the clock, don't bother running at all - if(!flexible || root.children.num() != 1){ - //let them run! - start_threads(); - - Alarm timer; - if(time > 0) - timer(time - (Time() - starttime), std::bind(&Player::timedout, this)); - - //wait for the timer to stop them - runbarrier.wait(); - CAS(threadstate, Thread_Wait_End, Thread_Wait_Start); - assert(threadstate == Thread_Wait_Start); - } - - if(ponder && root.outcome < 0) - start_threads(); - - time_used = Time() - starttime; - -//return the best one - return return_move(& root, toplay); -} - - - -Player::Player() { - nodes = 0; - gclimit = 5; - time_used = 0; - - profile = false; - ponder = false; -//#ifdef SINGLE_THREAD ... make sure only 1 thread - numthreads = 1; - maxmem = 1000*1024*1024; - - msrave = -2; - msexplore = 0; - - explore = 0; - parentexplore = false; - ravefactor = 500; - decrrave = 0; - knowledge = true; - userave = 1; - useexplore = 1; - fpurgency = 1; - rollouts = 5; - dynwiden = 0; - logdynwiden = (dynwiden ? std::log(dynwiden) : 0); - - shortrave = false; - keeptree = true; - minimax = 2; - visitexpand = 1; - prunesymmetry = false; - gcsolved = 100000; - - localreply = 5; - locality = 5; - connect = 20; - size = 0; - bridge = 100; - dists = 0; - - weightedrandom = 0; - rolloutpattern = true; - lastgoodreply = false; - instantwin = 0; - - for(int i = 0; i < 4096; i++) - gammas[i] = 1; - - //no threads started until a board is set - threadstate = Thread_Wait_Start; -} -Player::~Player(){ - stop_threads(); - - numthreads = 0; - reset_threads(); //shut down the theads properly - - root.dealloc(ctmem); - ctmem.compact(); -} -void Player::timedout() { - CAS(threadstate, Thread_Running, Thread_Wait_End); - CAS(threadstate, Thread_GC, Thread_GC_End); -} - -string Player::statestring(){ - switch(threadstate){ - case Thread_Cancelled: return "Thread_Wait_Cancelled"; - case Thread_Wait_Start: return "Thread_Wait_Start"; - case Thread_Wait_Start_Cancelled: return "Thread_Wait_Start_Cancelled"; - case Thread_Running: return "Thread_Running"; - case Thread_GC: return "Thread_GC"; - case Thread_GC_End: return "Thread_GC_End"; - case Thread_Wait_End: return "Thread_Wait_End"; - } - return "Thread_State_Unknown!!!"; -} - -void Player::stop_threads(){ - if(threadstate != Thread_Wait_Start){ - timedout(); - runbarrier.wait(); - CAS(threadstate, Thread_Wait_End, Thread_Wait_Start); - assert(threadstate == Thread_Wait_Start); - } -} - -void Player::start_threads(){ - assert(threadstate == Thread_Wait_Start); - runbarrier.wait(); - CAS(threadstate, Thread_Wait_Start, Thread_Running); -} - -void Player::reset_threads(){ //start and end with threadstate = Thread_Wait_Start - assert(threadstate == Thread_Wait_Start); - -//wait for them to all get to the barrier - assert(CAS(threadstate, Thread_Wait_Start, Thread_Wait_Start_Cancelled)); - runbarrier.wait(); - -//make sure they exited cleanly - for(unsigned int i = 0; i < threads.size(); i++){ - threads[i]->join(); - delete threads[i]; - } - - threads.clear(); - - threadstate = Thread_Wait_Start; - - runbarrier.reset(numthreads + 1); - gcbarrier.reset(numthreads); - -//start new threads - for(int i = 0; i < numthreads; i++) - threads.push_back(new PlayerUCT(this)); -} - -void Player::set_ponder(bool p){ - if(ponder != p){ - ponder = p; - stop_threads(); - - if(ponder) - start_threads(); - } -} - -void Player::set_board(const Board & board){ - stop_threads(); - - nodes -= root.dealloc(ctmem); - root = Node(); - root.exp.addwins(visitexpand+1); - - rootboard = board; - - reset_threads(); //needed since the threads aren't started before a board it set - - if(ponder) - start_threads(); -} -void Player::move(const Move & m){ - stop_threads(); - - uword nodesbefore = nodes; - - if(keeptree && root.children.num() > 0){ - Node child; - - for(Node * i = root.children.begin(); i != root.children.end(); i++){ - if(i->move == m){ - child = *i; //copy the child experience to temp - child.swap_tree(*i); //move the child tree to temp - break; - } - } - - nodes -= root.dealloc(ctmem); - root = child; - root.swap_tree(child); - - if(nodesbefore > 0) - logerr("Nodes before: " + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); - }else{ - nodes -= root.dealloc(ctmem); - root = Node(); - root.move = m; - } - assert(nodes == root.size()); - - rootboard.move(m); - - root.exp.addwins(visitexpand+1); //+1 to compensate for the virtual loss - if(rootboard.won() < 0) - root.outcome = -3; - - if(ponder) - start_threads(); -} - -double Player::gamelen(){ - DepthStats len; - for(unsigned int i = 0; i < threads.size(); i++) - len += threads[i]->gamelen; - return len.avg(); -} - -vector Player::get_pv(){ - vector pv; - - Node * r, * n = & root; - char turn = rootboard.toplay(); - while(!n->children.empty()){ - r = return_move(n, turn); - if(!r) break; - pv.push_back(r->move); - turn = 3 - turn; - n = r; - } - - if(pv.size() == 0) - pv.push_back(Move(M_RESIGN)); - - return pv; -} - -Player::Node * Player::return_move(Node * node, int toplay) const { - double val, maxval = -1000000000000.0; //1 trillion - - Node * ret = NULL, - * child = node->children.begin(), - * end = node->children.end(); - - for( ; child != end; child++){ - if(child->outcome >= 0){ - if(child->outcome == toplay) val = 800000000000.0 - child->exp.num(); //shortest win - else if(child->outcome == 0) val = -400000000000.0 + child->exp.num(); //longest tie - else val = -800000000000.0 + child->exp.num(); //longest loss - }else{ //not proven - if(msrave == -1) //num simulations - val = child->exp.num(); - else if(msrave == -2) //num wins - val = child->exp.sum(); - else - val = child->value(msrave, 0, 0) - msexplore*sqrt(log(node->exp.num())/(child->exp.num() + 1)); - } - - if(maxval < val){ - maxval = val; - ret = child; - } - } - -//set bestmove, but don't touch outcome, if it's solved that will already be set, otherwise it shouldn't be set - if(ret){ - node->bestmove = ret->move; - }else if(node->bestmove == M_UNKNOWN){ - // TODO: Is this needed? -// SolverAB solver; -// solver.set_board(rootboard); -// solver.solve(0.1); -// node->bestmove = solver.bestmove; - } - - assert(node->bestmove != M_UNKNOWN); - - return ret; -} - -void Player::garbage_collect(Board & board, Node * node){ - Node * child = node->children.begin(), - * end = node->children.end(); - - int toplay = board.toplay(); - for( ; child != end; child++){ - if(child->children.num() == 0) - continue; - - if( (node->outcome >= 0 && child->exp.num() > gcsolved && (node->outcome != toplay || child->outcome == toplay || child->outcome == 0)) || //parent is solved, only keep the proof tree, plus heavy draws - (node->outcome < 0 && child->exp.num() > (child->outcome >= 0 ? gcsolved : gclimit)) ){ // only keep heavy nodes, with different cutoffs for solved and unsolved - board.set(child->move); - garbage_collect(board, child); - board.unset(child->move); - }else{ - nodes -= child->dealloc(ctmem); - } - } -} - -Player::Node * Player::find_child(Node * node, const Move & move){ - for(Node * i = node->children.begin(); i != node->children.end(); i++) - if(i->move == move) - return i; - - return NULL; -} - -void Player::gen_hgf(Board & board, Node * node, unsigned int limit, unsigned int depth, FILE * fd){ - string s = string("\n") + string(depth, ' ') + "(;" + (board.toplay() == 2 ? "W" : "B") + "[" + node->move.to_s() + "]" + - "C[mcts, sims:" + to_str(node->exp.num()) + ", avg:" + to_str(node->exp.avg(), 4) + ", outcome:" + to_str((int)(node->outcome)) + ", best:" + node->bestmove.to_s() + "]"; - fprintf(fd, "%s", s.c_str()); - - Node * child = node->children.begin(), - * end = node->children.end(); - - int toplay = board.toplay(); - - bool children = false; - for( ; child != end; child++){ - if(child->exp.num() >= limit && (toplay != node->outcome || child->outcome == node->outcome) ){ - board.set(child->move); - gen_hgf(board, child, limit, depth+1, fd); - board.unset(child->move); - children = true; - } - } - - if(children) - fprintf(fd, "\n%s", string(depth, ' ').c_str()); - fprintf(fd, ")"); -} - -void Player::create_children_simple(const Board & board, Node * node){ - assert(node->children.empty()); - - node->children.alloc(board.movesremain(), ctmem); - - Node * child = node->children.begin(), - * end = node->children.end(); - Board::MoveIterator moveit = board.moveit(prunesymmetry); - int nummoves = 0; - for(; !moveit.done() && child != end; ++moveit, ++child){ - *child = Node(*moveit); - nummoves++; - } - - if(prunesymmetry) - node->children.shrink(nummoves); //shrink the node to ignore the extra moves - else //both end conditions should happen in parallel - assert(moveit.done() && child == end); - - PLUS(nodes, node->children.num()); -} - -//reads the format from gen_hgf. -void Player::load_hgf(Board board, Node * node, FILE * fd){ - char c, buf[101]; - - eat_whitespace(fd); - - assert(fscanf(fd, "(;%c[%100[^]]]", &c, buf) > 0); - - assert(board.toplay() == (c == 'W' ? 1 : 2)); - node->move = Move(buf); - board.move(node->move); - - assert(fscanf(fd, "C[%100[^]]]", buf) > 0); - - vecstr entry, parts = explode(string(buf), ", "); - assert(parts[0] == "mcts"); - - entry = explode(parts[1], ":"); - assert(entry[0] == "sims"); - uword sims = from_str(entry[1]); - - entry = explode(parts[2], ":"); - assert(entry[0] == "avg"); - double avg = from_str(entry[1]); - - uword wins = sims*avg; - node->exp.addwins(wins); - node->exp.addlosses(sims - wins); - - entry = explode(parts[3], ":"); - assert(entry[0] == "outcome"); - node->outcome = from_str(entry[1]); - - entry = explode(parts[4], ":"); - assert(entry[0] == "best"); - node->bestmove = Move(entry[1]); - - - eat_whitespace(fd); - - if(fpeek(fd) != ')'){ - create_children_simple(board, node); - - while(fpeek(fd) != ')'){ - Node child; - load_hgf(board, & child, fd); - - Node * i = find_child(node, child.move); - *i = child; //copy the child experience to the tree - i->swap_tree(child); //move the child subtree to the tree - - assert(child.children.empty()); - - eat_whitespace(fd); - } - } - - eat_char(fd, ')'); - - return; -} diff --git a/hex/player.h b/hex/player.h deleted file mode 100644 index 9741a1a..0000000 --- a/hex/player.h +++ /dev/null @@ -1,304 +0,0 @@ - -#pragma once - -//A Monte-Carlo Tree Search based player - -#include -#include - -#include "../lib/compacttree.h" -#include "../lib/depthstats.h" -#include "../lib/exppair.h" -#include "../lib/log.h" -#include "../lib/thread.h" -#include "../lib/time.h" -#include "../lib/types.h" -#include "../lib/xorshift.h" - -#include "board.h" -#include "lbdist.h" -#include "move.h" -#include "movelist.h" -#include "policy_bridge.h" -#include "policy_instantwin.h" -#include "policy_lastgoodreply.h" -#include "policy_random.h" - - -class Player { -public: - - struct Node { - public: - ExpPair rave; - ExpPair exp; - int16_t know; - int8_t outcome; - uint8_t proofdepth; - Move move; - Move bestmove; //if outcome is set, then bestmove is the way to get there - CompactTree::Children children; -// int padding; - //seems to need padding to multiples of 8 bytes or it segfaults? - //don't forget to update the copy constructor/operator - - Node() : know(0), outcome(-3), proofdepth(0) { } - Node(const Move & m, char o = -3) : know(0), outcome( o), proofdepth(0), move(m) { } - Node(const Node & n) { *this = n; } - Node & operator = (const Node & n){ - if(this != & n){ //don't copy to self - //don't copy to a node that already has children - assert(children.empty()); - - rave = n.rave; - exp = n.exp; - know = n.know; - move = n.move; - bestmove = n.bestmove; - outcome = n.outcome; - proofdepth = n.proofdepth; - //children = n.children; ignore the children, they need to be swap_tree'd in - } - return *this; - } - - void swap_tree(Node & n){ - children.swap(n.children); - } - - void print() const { - printf("%s\n", to_s().c_str()); - } - string to_s() const { - return "Node: move " + move.to_s() + - ", exp " + to_str(exp.avg(), 2) + "/" + to_str(exp.num()) + - ", rave " + to_str(rave.avg(), 2) + "/" + to_str(rave.num()) + - ", know " + to_str(know) + - ", outcome " + to_str(outcome) + "/" + to_str(proofdepth) + - ", best " + bestmove.to_s() + - ", children " + to_str(children.num()); - } - - unsigned int size() const { - unsigned int num = children.num(); - - if(children.num()) - for(Node * i = children.begin(); i != children.end(); i++) - num += i->size(); - - return num; - } - - ~Node(){ - assert(children.empty()); - } - - unsigned int alloc(unsigned int num, CompactTree & ct){ - return children.alloc(num, ct); - } - unsigned int dealloc(CompactTree & ct){ - unsigned int num = 0; - - if(children.num()) - for(Node * i = children.begin(); i != children.end(); i++) - num += i->dealloc(ct); - num += children.dealloc(ct); - - return num; - } - - //new way, more standard way of changing over from rave scores to real scores - float value(float ravefactor, bool knowledge, float fpurgency){ - float val = fpurgency; - float expnum = exp.num(); - float ravenum = rave.num(); - - if(ravefactor <= min_rave){ - if(expnum > 0) - val = exp.avg(); - }else if(ravenum > 0 || expnum > 0){ - float alpha = ravefactor/(ravefactor + expnum); -// float alpha = sqrt(ravefactor/(ravefactor + 3.0f*expnum)); -// float alpha = ravenum/(expnum + ravenum + expnum*ravenum*ravefactor); - - val = 0; - if(ravenum > 0) val += alpha*rave.avg(); - if(expnum > 0) val += (1.0f-alpha)*exp.avg(); - } - - if(knowledge && know > 0){ - if(expnum <= 1) - val += 0.01f * know; - else if(expnum < 1000) //knowledge is only useful with little experience - val += 0.01f * know / sqrt(expnum); - } - - return val; - } - }; - - class PlayerThread { - protected: - public: - mutable XORShift_float unitrand; - Thread thread; - Player * player; - public: - DepthStats treelen, gamelen; - double times[4]; //time spent in each of the stages - - PlayerThread() {} - virtual ~PlayerThread() { } - virtual void reset() { } - int join(){ return thread.join(); } - void run(); //thread runner, calls iterate on each iteration - virtual void iterate() { } //handles each iteration - }; - - class PlayerUCT : public PlayerThread { - LastGoodReply last_good_reply; - RandomPolicy random_policy; - ProtectBridge protect_bridge; - InstantWin instant_wins; - - bool use_rave; //whether to use rave for this simulation - bool use_explore; //whether to use exploration for this simulation - LBDists dists; //holds the distances to the various non-ring wins as a heuristic for the minimum moves needed to win - MoveList movelist; - int stage; //which of the four MCTS stages is it on - Time timestamps[4]; //timestamps for the beginning, before child creation, before rollout, after rollout - - public: - PlayerUCT(Player * p) : PlayerThread() { - player = p; - reset(); - thread(bind(&PlayerUCT::run, this)); - } - - void reset(){ - treelen.reset(); - gamelen.reset(); - - use_rave = false; - use_explore = false; - - for(int a = 0; a < 4; a++) - times[a] = 0; - } - - private: - void iterate(); - void walk_tree(Board & board, Node * node, int depth); - bool create_children(Board & board, Node * node, int toplay); - void add_knowledge(Board & board, Node * node, Node * child); - Node * choose_move(const Node * node, int toplay, int remain) const; - void update_rave(const Node * node, int toplay); - bool test_bridge_probe(const Board & board, const Move & move, const Move & test) const; - - int rollout(Board & board, Move move, int depth); - Move rollout_choose_move(Board & board, const Move & prev); - Move rollout_pattern(const Board & board, const Move & move); - }; - - -public: - - static const float min_rave; - - bool ponder; //think during opponents time? - int numthreads; //number of player threads to run - u64 maxmem; //maximum memory for the tree in bytes - bool profile; //count how long is spent in each stage of MCTS -//final move selection - float msrave; //rave factor in final move selection, -1 means use number instead of value - float msexplore; //the UCT constant in final move selection -//tree traversal - bool parentexplore; // whether to multiple exploration by the parents winrate - float explore; //greater than one favours exploration, smaller than one favours exploitation - float ravefactor; //big numbers favour rave scores, small ignore it - float decrrave; //decrease rave over time, add this value for each empty position on the board - bool knowledge; //whether to include knowledge - float userave; //what probability to use rave - float useexplore; //what probability to use UCT exploration - float fpurgency; //what value to return for a move that hasn't been played yet - int rollouts; //number of rollouts to run after the tree traversal - float dynwiden; //dynamic widening, look at first log_dynwiden(experience) number of children, 0 to disable - float logdynwiden; // = log(dynwiden), cached for performance -//tree building - bool shortrave; //only update rave values on short rollouts - bool keeptree; //reuse the tree from the previous move - int minimax; //solve the minimax tree within the uct tree - uint visitexpand;//number of visits before expanding a node - bool prunesymmetry; //prune symmetric children from the move list, useful for proving but likely not for playing - uint gcsolved; //garbage collect solved nodes or keep them in the tree, assuming they meet the required amount of work -//knowledge - int localreply; //boost for a local reply, ie a move near the previous move - int locality; //boost for playing near previous stones - int connect; //boost for having connections to edges and corners - int size; //boost for large groups - int bridge; //boost replying to a probe at a bridge - int dists; //boost based on minimum number of stones needed to finish a non-ring win -//rollout - int weightedrandom; //use weighted random for move ordering based on gammas - bool rolloutpattern; //play the response to a virtual connection threat in rollouts - int lastgoodreply; //use the last-good-reply rollout heuristic - int instantwin; //how deep to look for instant wins in rollouts - - float gammas[4096]; //pattern weights for weighted random - - Board rootboard; - Node root; - uword nodes; - int gclimit; //the minimum experience needed to not be garbage collected - - uint64_t runs, maxruns; - - CompactTree ctmem; - - enum ThreadState { - Thread_Cancelled, //threads should exit - Thread_Wait_Start, //threads are waiting to start - Thread_Wait_Start_Cancelled, //once done waiting, go to cancelled instead of running - Thread_Running, //threads are running - Thread_GC, //one thread is running garbage collection, the rest are waiting - Thread_GC_End, //once done garbage collecting, go to wait_end instead of back to running - Thread_Wait_End, //threads are waiting to end - }; - volatile ThreadState threadstate; - vector threads; - Barrier runbarrier, gcbarrier; - - double time_used; - - Player(); - ~Player(); - - void timedout(); - - string statestring(); - - void stop_threads(); - void start_threads(); - void reset_threads(); - - void set_ponder(bool p); - void set_board(const Board & board); - - void move(const Move & m); - - double gamelen(); - - Node * genmove(double time, int max_runs, bool flexible); - vector get_pv(); - void garbage_collect(Board & board, Node * node); //destroys the board, so pass in a copy - - bool do_backup(Node * node, Node * backup, int toplay); - - Node * find_child(Node * node, const Move & move); - void create_children_simple(const Board & board, Node * node); - void gen_hgf(Board & board, Node * node, unsigned int limit, unsigned int depth, FILE * fd); - void load_hgf(Board board, Node * node, FILE * fd); - -protected: - Node * return_move(Node * node, int toplay) const; -}; diff --git a/hex/playeruct.cpp b/hex/playeruct.cpp deleted file mode 100644 index 55bc5e2..0000000 --- a/hex/playeruct.cpp +++ /dev/null @@ -1,449 +0,0 @@ - -#include -#include - -#include "../lib/string.h" - -#include "player.h" - -void Player::PlayerUCT::iterate(){ - if(player->profile){ - timestamps[0] = Time(); - stage = 0; - } - - movelist.reset(&(player->rootboard)); - player->root.exp.addvloss(); - Board copy = player->rootboard; - use_rave = (unitrand() < player->userave); - use_explore = (unitrand() < player->useexplore); - walk_tree(copy, & player->root, 0); - player->root.exp.addv(movelist.getexp(3-player->rootboard.toplay())); - - if(player->profile){ - times[0] += timestamps[1] - timestamps[0]; - times[1] += timestamps[2] - timestamps[1]; - times[2] += timestamps[3] - timestamps[2]; - times[3] += Time() - timestamps[3]; - } -} - -void Player::PlayerUCT::walk_tree(Board & board, Node * node, int depth){ - int toplay = board.toplay(); - - if(!node->children.empty() && node->outcome < 0){ - //choose a child and recurse - Node * child; - do{ - int remain = board.movesremain(); - child = choose_move(node, toplay, remain); - - if(child->outcome < 0){ - movelist.addtree(child->move, toplay); - - if(!board.move(child->move)){ - logerr("move failed: " + child->move.to_s() + "\n" + board.to_s(false)); - assert(false && "move failed"); - } - - child->exp.addvloss(); //balanced out after rollouts - - walk_tree(board, child, depth+1); - - child->exp.addv(movelist.getexp(toplay)); - - if(!player->do_backup(node, child, toplay) && //not solved - player->ravefactor > min_rave && //using rave - node->children.num() > 1 && //not a macro move - 50*remain*(player->ravefactor + player->decrrave*remain) > node->exp.num()) //rave is still significant - update_rave(node, toplay); - - return; - } - }while(!player->do_backup(node, child, toplay)); - - return; - } - - if(player->profile && stage == 0){ - stage = 1; - timestamps[1] = Time(); - } - - int won = (player->minimax ? node->outcome : board.won()); - - //if it's not already decided - if(won < 0){ - //create children if valid - if(node->exp.num() >= player->visitexpand+1 && create_children(board, node, toplay)){ - walk_tree(board, node, depth); - return; - } - - if(player->profile){ - stage = 2; - timestamps[2] = Time(); - } - - //do random game on this node - random_policy.prepare(board); - for(int i = 0; i < player->rollouts; i++){ - Board copy = board; - rollout(copy, node->move, depth); - } - }else{ - movelist.finishrollout(won); //got to a terminal state, it's worth recording - } - - treelen.add(depth); - - movelist.subvlosses(1); - - if(player->profile){ - timestamps[3] = Time(); - if(stage == 1) - timestamps[2] = timestamps[3]; - stage = 3; - } - - return; -} - -bool sort_node_know(const Player::Node & a, const Player::Node & b){ - return (a.know > b.know); -} - -bool Player::PlayerUCT::create_children(Board & board, Node * node, int toplay){ - if(!node->children.lock()) - return false; - - if(player->dists){ - dists.run(&board, (player->dists > 0), toplay); - } - - CompactTree::Children temp; - temp.alloc(board.movesremain(), player->ctmem); - - int losses = 0; - - Node * child = temp.begin(), - * end = temp.end(), - * loss = NULL; - Board::MoveIterator move = board.moveit(player->prunesymmetry); - int nummoves = 0; - for(; !move.done() && child != end; ++move, ++child){ - *child = Node(*move); - - if(player->minimax){ - child->outcome = board.test_win(*move); - - if(player->minimax >= 2 && board.test_win(*move, 3 - board.toplay()) > 0){ - losses++; - loss = child; - } - - if(child->outcome == toplay){ //proven win from here, don't need children - node->outcome = child->outcome; - node->proofdepth = 1; - node->bestmove = *move; - node->children.unlock(); - temp.dealloc(player->ctmem); - return true; - } - } - - if(player->knowledge) - add_knowledge(board, node, child); - nummoves++; - } - - if(player->prunesymmetry) - temp.shrink(nummoves); //shrink the node to ignore the extra moves - else //both end conditions should happen in parallel - assert(move.done() && child == end); - - //Make a macro move, add experience to the move so the current simulation continues past this move - if(losses == 1){ - Node macro = *loss; - temp.dealloc(player->ctmem); - temp.alloc(1, player->ctmem); - macro.exp.addwins(player->visitexpand); - *(temp.begin()) = macro; - }else if(losses >= 2){ //proven loss, but at least try to block one of them - node->outcome = 3 - toplay; - node->proofdepth = 2; - node->bestmove = loss->move; - node->children.unlock(); - temp.dealloc(player->ctmem); - return true; - } - - if(player->dynwiden > 0) //sort in decreasing order by knowledge - sort(temp.begin(), temp.end(), sort_node_know); - - PLUS(player->nodes, temp.num()); - node->children.swap(temp); - assert(temp.unlock()); - - return true; -} - -Player::Node * Player::PlayerUCT::choose_move(const Node * node, int toplay, int remain) const { - float val, maxval = -1000000000; - float logvisits = log(node->exp.num()); - int dynwidenlim = (player->dynwiden > 0 ? (int)(logvisits/player->logdynwiden)+2 : 361); - - float raveval = use_rave * (player->ravefactor + player->decrrave*remain); - float explore = use_explore * player->explore; - if(player->parentexplore) - explore *= node->exp.avg(); - - Node * ret = NULL, - * child = node->children.begin(), - * end = node->children.end(); - - for(; child != end && dynwidenlim >= 0; child++){ - if(child->outcome >= 0){ - if(child->outcome == toplay) //return a win immediately - return child; - - val = (child->outcome == 0 ? -1 : -2); //-1 for tie so any unknown is better, -2 for loss so it's even worse - }else{ - val = child->value(raveval, player->knowledge, player->fpurgency); - if(explore > 0) - val += explore*sqrt(logvisits/(child->exp.num() + 1)); - dynwidenlim--; - } - - if(maxval < val){ - maxval = val; - ret = child; - } - } - - return ret; -} - -/* -backup in this order: - -6 win -5 win/draw -4 draw if draw/loss -3 win/draw/loss -2 draw -1 draw/loss -0 lose -return true if fully solved, false if it's unknown or partially unknown -*/ -bool Player::do_backup(Node * node, Node * backup, int toplay){ - int nodeoutcome = node->outcome; - if(nodeoutcome >= 0) //already proven, probably by a different thread - return true; - - if(backup->outcome == -3) //nothing proven by this child, so no chance - return false; - - - uint8_t proofdepth = backup->proofdepth; - if(backup->outcome != toplay){ - uint64_t sims = 0, bestsims = 0, outcome = 0, bestoutcome = 0; - backup = NULL; - - Node * child = node->children.begin(), - * end = node->children.end(); - - for( ; child != end; child++){ - int childoutcome = child->outcome; //save a copy to avoid race conditions - - if(proofdepth < child->proofdepth+1) - proofdepth = child->proofdepth+1; - - //these should be sorted in likelyness of matching, most likely first - if(childoutcome == -3){ // win/draw/loss - outcome = 3; - }else if(childoutcome == toplay){ //win - backup = child; - outcome = 6; - proofdepth = child->proofdepth+1; - break; - }else if(childoutcome == 3-toplay){ //loss - outcome = 0; - }else if(childoutcome == 0){ //draw - if(nodeoutcome == toplay-3) //draw/loss - outcome = 4; - else - outcome = 2; - }else if(childoutcome == -toplay){ //win/draw - outcome = 5; - }else if(childoutcome == toplay-3){ //draw/loss - outcome = 1; - }else{ - logerr("childoutcome == " + to_str(childoutcome) + "\n"); - assert(false && "How'd I get here? All outcomes should be tested above"); - } - - sims = child->exp.num(); - if(bestoutcome < outcome){ //better outcome is always preferable - bestoutcome = outcome; - bestsims = sims; - backup = child; - }else if(bestoutcome == outcome && ((outcome == 0 && bestsims < sims) || bestsims > sims)){ - //find long losses or easy wins/draws - bestsims = sims; - backup = child; - } - } - - if(bestoutcome == 3) //no win, but found an unknown - return false; - } - - if(CAS(node->outcome, nodeoutcome, backup->outcome)){ - node->bestmove = backup->move; - node->proofdepth = proofdepth; - }else //if it was in a race, try again, might promote a partial solve to full solve - return do_backup(node, backup, toplay); - - return (node->outcome >= 0); -} - -//update the rave score of all children that were played -void Player::PlayerUCT::update_rave(const Node * node, int toplay){ - Node * child = node->children.begin(), - * childend = node->children.end(); - - for( ; child != childend; ++child) - child->rave.addv(movelist.getrave(toplay, child->move)); -} - -void Player::PlayerUCT::add_knowledge(Board & board, Node * node, Node * child){ - if(player->localreply){ //boost for moves near the previous move - int dist = node->move.dist(child->move); - if(dist < 4) - child->know += player->localreply * (4 - dist); - } - - if(player->locality) //boost for moves near previous stones - child->know += player->locality * board.local(child->move, board.toplay()); - - Board::Cell cell; - if(player->connect || player->size) - cell = board.test_cell(child->move); - - if(player->connect) //boost for moves that connect to edges - child->know += player->connect * cell.numedges(); - - if(player->size) //boost for size of the group - child->know += player->size * cell.size; - - if(player->bridge && test_bridge_probe(board, node->move, child->move)) //boost for maintaining a virtual connection - child->know += player->bridge; - - if(player->dists) - child->know += abs(player->dists) * max(0, board.get_size() - dists.get(child->move, board.toplay())); -} - -//test whether this move is a forced reply to the opponent probing your virtual connections -bool Player::PlayerUCT::test_bridge_probe(const Board & board, const Move & move, const Move & test) const { - //TODO: switch to the same method as policy_bridge.h, maybe even share code - if(move.dist(test) != 1) - return false; - - bool equals = false; - - int state = 0; - int piece = 3 - board.get(move); - for(int i = 0; i < 8; i++){ - Move cur = move + neighbours[i % 6]; - - bool on = board.onboard(cur); - int v = 0; - if(on) - v = board.get(cur); - - //state machine that progresses when it see the pattern, but counting borders as part of the pattern - if(state == 0){ - if(!on || v == piece) - state = 1; - //else state = 0; - }else if(state == 1){ - if(on){ - if(v == 0){ - state = 2; - equals = (test == cur); - }else if(v != piece) - state = 0; - //else (v==piece) => state = 1; - } - //else state = 1; - }else{ // state == 2 - if(!on || v == piece){ - if(equals) - return true; - state = 1; - }else{ - state = 0; - } - } - } - return false; -} - -/////////////////////////////////////////// - - -//play a random game starting from a board state, and return the results of who won -int Player::PlayerUCT::rollout(Board & board, Move move, int depth){ - int won; - - if(player->instantwin) - instant_wins.rollout_start(board, player->instantwin); - - random_policy.rollout_start(board); - - while((won = board.won()) < 0){ - int turn = board.toplay(); - - move = rollout_choose_move(board, move); - - movelist.addrollout(move, turn); - - assert2(board.move(move), "\n" + board.to_s(true) + "\n" + move.to_s()); - depth++; - } - - gamelen.add(depth); - - //update the last good reply table - if(player->lastgoodreply) - last_good_reply.rollout_end(board, movelist, won); - - movelist.finishrollout(won); - return won; -} - -Move Player::PlayerUCT::rollout_choose_move(Board & board, const Move & prev){ - //look for instant wins - if(player->instantwin){ - Move move = instant_wins.choose_move(board, prev); - if(move != M_UNKNOWN) - return move; - } - - //force a bridge reply - if(player->rolloutpattern){ - Move move = protect_bridge.choose_move(board, prev); - if(move != M_UNKNOWN) - return move; - } - - //reuse the last good reply - if(player->lastgoodreply){ - Move move = last_good_reply.choose_move(board, prev); - if(move != M_UNKNOWN) - return move; - } - - return random_policy.choose_move(board, prev); -} diff --git a/hex/policy.h b/hex/policy.h deleted file mode 100644 index 01309d8..0000000 --- a/hex/policy.h +++ /dev/null @@ -1,28 +0,0 @@ - -#pragma once - -#include "board.h" -#include "move.h" -#include "movelist.h" - -class Policy { -public: - Policy() { } - - // called before all the rollouts start - void prepare(const Board & board) { } - - // called at the beginning of each rollout. - void rollout_start(Board & board) { } - - // Give me a move to make, or M_UNKNOWN - Move choose_move(const Board & board, const Move & prev) { - return M_UNKNOWN; - } - - // A move was just made, here's the updated board - void move_end(const Board & board, const Move & prev) { } - - // Game over, here's who won - void rollout_end(const MoveList & movelist, int won) { } -}; diff --git a/hex/policy_bridge.h b/hex/policy_bridge.h deleted file mode 100644 index c6f2b8d..0000000 --- a/hex/policy_bridge.h +++ /dev/null @@ -1,51 +0,0 @@ - - -#pragma once - -#include "../lib/bits.h" - -#include "board.h" -#include "move.h" -#include "policy.h" - - -class ProtectBridge : public Policy { - int offset; - uint8_t lookup[2][1<<12]; - -public: - - ProtectBridge() : offset(0) { - // precompute the valid moves around a pattern for all possible 6-patterns. - for(unsigned int i = 0; i < 1<<12; i++){ - lookup[0][i] = lookup[1][i] = 0; - unsigned int p = i; - for(unsigned int d = 0; d < 6; d++){ - if((p & 0x1D) == 0x11) // 01 11 01 -> 01 00 01 - lookup[0][i] |= (1 << ((d+1)%6)); // +1 because we want to play in the empty spot - if((p & 0x2E) == 0x22) // 10 11 10 -> 10 00 10 - lookup[1][i] |= (1 << ((d+1)%6)); - p = ((p & 0xFFC)>>2) | ((p & 0x3) << 10); - } - } - } - - Move choose_move(const Board & board, const Move & prev) { - uint32_t p = board.pattern_small(prev); - uint16_t r = lookup[board.toplay()-1][p]; - - if(!r) // nothing to save - return M_UNKNOWN; - - unsigned int i; - if((r & (r - 1)) == 0){ // only one bit set - i = trailing_zeros(r); - } else { // multiple choices of bridges to save - offset = (offset + 1) % 6; // rotate the starting offset to avoid directional bias - r |= (r << 6); - r >>= offset; - i = (offset + trailing_zeros(r)) % 6; - } - return board.nb_begin(prev)[i]; - } -}; diff --git a/hex/policy_instantwin.h b/hex/policy_instantwin.h deleted file mode 100644 index c3c1dfa..0000000 --- a/hex/policy_instantwin.h +++ /dev/null @@ -1,95 +0,0 @@ - -#pragma once - -#include "../lib/assert2.h" - -#include "board.h" -#include "move.h" -#include "policy.h" - - -class InstantWin : public Policy { - int max_rollout_moves; - int cur_rollout_moves; - - Move saved_loss; -public: - - InstantWin() : max_rollout_moves(10), cur_rollout_moves(0), saved_loss(M_UNKNOWN) { - } - - void rollout_start(Board & board, int max) { - if(max < 0) - max *= - board.get_size(); - max_rollout_moves = max; - - cur_rollout_moves = 0; - saved_loss = M_UNKNOWN; - } - - Move choose_move(const Board & board, const Move & prev) { - if(saved_loss != M_UNKNOWN) - return saved_loss; - - if(cur_rollout_moves++ >= max_rollout_moves) - return M_UNKNOWN; - - //must have an edge connection, or it has nothing to offer a group towards a win - const Board::Cell * c = board.cell(prev); - if(c->edge == 0) - return M_UNKNOWN; - - Move start, cur, loss = M_UNKNOWN; - int turn = 3 - board.toplay(); - - //find the first empty cell - int dir = -1; - for(int i = 0; i <= 5; i++){ - start = prev + neighbours[i]; - - if(!board.onboard(start) || board.get(start) != turn){ - dir = (i + 5) % 6; - break; - } - } - - if(dir == -1) //possible if it's in the middle of a ring - return M_UNKNOWN; - - cur = start; - -// logerr(board.to_s(true)); -// logerr(prev.to_s() + ":"); - - //follow contour of the current group looking for wins - do{ -// logerr(" " + cur.to_s()); - //check the current cell - if(board.onboard(cur) && board.get(cur) == 0 && board.test_win(cur, turn) > 0){ -// logerr(" loss"); - if(loss == M_UNKNOWN){ - loss = cur; - }else if(loss != cur){ - saved_loss = loss; - return cur; //game over, two wins found for opponent - } - } - - //advance to the next cell - for(int i = 5; i <= 9; i++){ - int nd = (dir + i) % 6; - Move next = cur + neighbours[nd]; - - if(!board.onboard(next) || board.get(next) != turn){ - cur = next; - dir = nd; - break; - } - } - }while(cur != start); //potentially skips part of it when the start is in a pocket, rare bug - -// logerr("\n"); - - return loss; // usually M_UNKNOWN - } -}; diff --git a/hex/policy_lastgoodreply.h b/hex/policy_lastgoodreply.h deleted file mode 100644 index 11fcc9a..0000000 --- a/hex/policy_lastgoodreply.h +++ /dev/null @@ -1,42 +0,0 @@ - -# pragma once - -#include "board.h" -#include "move.h" -#include "policy.h" - -class LastGoodReply : public Policy { - Move goodreply[2][Board::max_vecsize]; - int enabled; -public: - - LastGoodReply(int _enabled = 2) : enabled(_enabled) { - for(int p = 0; p < 2; p++) - for(int i = 0; i < Board::max_vecsize; i++) - goodreply[p][i] = M_UNKNOWN; - } - - Move choose_move(const Board & board, const Move & prev) const { - if (enabled && prev != M_SWAP) { - Move move = goodreply[board.toplay()-1][board.xy(prev)]; - if(move != M_UNKNOWN && board.valid_move_fast(move)) - return move; - } - return M_UNKNOWN; - } - - void rollout_end(const Board & board, const MoveList & movelist, int won) { - if(!enabled) - return; - int m = -1; - for(const MovePlayer * i = movelist.begin(), * e = movelist.end(); i != e; i++){ - if(m >= 0){ - if(i->player == won && *i != M_SWAP) - goodreply[i->player - 1][m] = *i; - else if(enabled == 2) - goodreply[i->player - 1][m] = M_UNKNOWN; - } - m = board.xy(*i); - } - } -}; diff --git a/hex/policy_random.h b/hex/policy_random.h deleted file mode 100644 index d84a82a..0000000 --- a/hex/policy_random.h +++ /dev/null @@ -1,45 +0,0 @@ - -#pragma once - -#include - -#include "../lib/xorshift.h" - -#include "board.h" -#include "move.h" -#include "policy.h" - -class RandomPolicy : public Policy { - XORShift_uint32 rand; - Move moves[Board::max_vecsize]; - int num; - int cur; -public: - - RandomPolicy() : num(0), cur(0) { - } - - // only need to save the valid moves once since all the rollouts start from the same position - void prepare(const Board & board) { - num = 0; - for(Board::MoveIterator m = board.moveit(false); !m.done(); ++m) - moves[num++] = *m; - } - - // reset the set of moves to make from above. Since they're used in random order they don't need to be in iterator order - void rollout_start(Board & board) { - cur = num; - } - - Move choose_move(const Board & board, const Move & prev) { - while(true){ - int r = rand() % cur; - cur--; - Move m = moves[r]; - moves[r] = moves[cur]; - moves[cur] = m; - if(board.valid_move_fast(m)) - return m; - } - } -}; diff --git a/hex/solver.h b/hex/solver.h deleted file mode 100644 index d6e6240..0000000 --- a/hex/solver.h +++ /dev/null @@ -1,68 +0,0 @@ - -#pragma once - -//Interface for the various solvers - -#include "../lib/types.h" - -#include "board.h" - -class Solver { -public: - int outcome; // 0 = tie, 1 = white, 2 = black, -1 = white or tie, -2 = black or tie, anything else unknown - int maxdepth; - uint64_t nodes_seen; - double time_used; - Move bestmove; - - Solver() : outcome(-3), maxdepth(0), nodes_seen(0), time_used(0) { } - virtual ~Solver() { } - - virtual void solve(double time) { } - virtual void set_board(const Board & board, bool clear = true) { } - virtual void move(const Move & m) { } - virtual void set_memlimit(uint64_t lim) { } // in bytes - virtual void clear_mem() { } - -protected: - volatile bool timeout; - void timedout(){ timeout = true; } - Board rootboard; - - static int solve1ply(const Board & board, int & nodes) { - int outcome = -3; - int turn = board.toplay(); - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - ++nodes; - int won = board.test_win(*move, turn); - - if(won == turn) - return won; - if(won == 0) - outcome = 0; - } - return outcome; - } - - static int solve2ply(const Board & board, int & nodes) { - int losses = 0; - int outcome = -3; - int turn = board.toplay(), opponent = 3 - turn; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - ++nodes; - int won = board.test_win(*move, turn); - - if(won == turn) - return won; - if(won == 0) - outcome = 0; - - if(board.test_win(*move, opponent) > 0) - losses++; - } - if(losses >= 2) - return opponent; - return outcome; - } - -}; diff --git a/hex/solverab.cpp b/hex/solverab.cpp deleted file mode 100644 index 1abdf47..0000000 --- a/hex/solverab.cpp +++ /dev/null @@ -1,137 +0,0 @@ - -#include "../lib/alarm.h" -#include "../lib/log.h" -#include "../lib/time.h" - -#include "solverab.h" - -void SolverAB::solve(double time){ - reset(); - if(rootboard.won() >= 0){ - outcome = rootboard.won(); - return; - } - - if(TT == NULL && maxnodes) - TT = new ABTTNode[maxnodes]; - - Alarm timer(time, std::bind(&SolverAB::timedout, this)); - Time start; - - int turn = rootboard.toplay(); - - for(maxdepth = startdepth; !timeout; maxdepth++){ -// logerr("Starting depth " + to_str(maxdepth) + "\n"); - - //the first depth of negamax - int ret, alpha = -2, beta = 2; - for(Board::MoveIterator move = rootboard.moveit(true); !move.done(); ++move){ - nodes_seen++; - - Board next = rootboard; - next.move(*move); - - int value = -negamax(next, maxdepth - 1, -beta, -alpha); - - if(value > alpha){ - alpha = value; - bestmove = *move; - } - - if(alpha >= beta){ - ret = beta; - break; - } - } - ret = alpha; - - - if(ret){ - if( ret == -2){ outcome = (turn == 1 ? 2 : 1); bestmove = Move(M_NONE); } - else if(ret == 2){ outcome = turn; } - else /*-1 || 1*/ { outcome = 0; } - - break; - } - } - - time_used = Time() - start; -} - - -int SolverAB::negamax(const Board & board, const int depth, int alpha, int beta){ - if(board.won() >= 0) - return (board.won() ? -2 : -1); - - if(depth <= 0 || timeout) - return 0; - - int b = beta; - int first = true; - int value, losses = 0; - static const int lookup[6] = {0, 0, 0, 1, 2, 2}; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - nodes_seen++; - - hash_t hash = board.test_hash(*move); - if(int ttval = tt_get(hash)){ - value = ttval; - }else if(depth <= 2){ - value = lookup[board.test_win(*move)+3]; - - if(board.test_win(*move, 3 - board.toplay()) > 0) - losses++; - }else{ - Board next = board; - next.move(*move); - - value = -negamax(next, depth - 1, -b, -alpha); - - if(scout && value > alpha && value < beta && !first) // re-search - value = -negamax(next, depth - 1, -beta, -alpha); - } - tt_set(hash, value); - - if(value > alpha) - alpha = value; - - if(alpha >= beta) - return beta; - - if(scout){ - b = alpha + 1; // set up null window - first = false; - } - } - - if(losses >= 2) - return -2; - - return alpha; -} - -int SolverAB::negamax_outcome(const Board & board, const int depth){ - int abval = negamax(board, depth, -2, 2); - if( abval == 0) return -3; //unknown - else if(abval == 2) return board.toplay(); //win - else if(abval == -2) return 3 - board.toplay(); //loss - else return 0; //draw -} - -int SolverAB::tt_get(const Board & board){ - return tt_get(board.gethash()); -} -int SolverAB::tt_get(const hash_t & hash){ - if(!TT) return 0; - ABTTNode * node = & TT[hash % maxnodes]; - return (node->hash == hash ? node->value : 0); -} -void SolverAB::tt_set(const Board & board, int value){ - tt_set(board.gethash(), value); -} -void SolverAB::tt_set(const hash_t & hash, int value){ - if(!TT || value == 0) return; - ABTTNode * node = & TT[hash % maxnodes]; - node->hash = hash; - node->value = value; -} diff --git a/hex/solverab.h b/hex/solverab.h deleted file mode 100644 index 35ca7b9..0000000 --- a/hex/solverab.h +++ /dev/null @@ -1,72 +0,0 @@ - -#pragma once - -//An Alpha-beta solver, single threaded with an optional transposition table. - -#include "solver.h" - -class SolverAB : public Solver { - struct ABTTNode { - hash_t hash; - char value; - ABTTNode(hash_t h = 0, char v = 0) : hash(h), value(v) { } - }; - -public: - bool scout; - int startdepth; - - ABTTNode * TT; - uint64_t maxnodes, memlimit; - - SolverAB(bool Scout = false) { - scout = Scout; - startdepth = 2; - TT = NULL; - set_memlimit(100*1024*1024); - } - ~SolverAB() { } - - void set_board(const Board & board, bool clear = true){ - rootboard = board; - reset(); - } - void move(const Move & m){ - rootboard.move(m); - reset(); - } - void set_memlimit(uint64_t lim){ - memlimit = lim; - maxnodes = memlimit/sizeof(ABTTNode); - clear_mem(); - } - - void clear_mem(){ - reset(); - if(TT){ - delete[] TT; - TT = NULL; - } - } - void reset(){ - outcome = -3; - maxdepth = 0; - nodes_seen = 0; - time_used = 0; - bestmove = Move(M_UNKNOWN); - - timeout = false; - } - - void solve(double time); - -//return -2 for loss, -1,1 for tie, 0 for unknown, 2 for win, all from toplay's perspective - int negamax(const Board & board, const int depth, int alpha, int beta); - int negamax_outcome(const Board & board, const int depth); - - int tt_get(const hash_t & hash); - int tt_get(const Board & board); - void tt_set(const hash_t & hash, int val); - void tt_set(const Board & board, int val); -}; - diff --git a/hex/solverpns.cpp b/hex/solverpns.cpp deleted file mode 100644 index 7f11a1a..0000000 --- a/hex/solverpns.cpp +++ /dev/null @@ -1,213 +0,0 @@ - - -#include "../lib/alarm.h" -#include "../lib/log.h" -#include "../lib/time.h" - -#include "solverpns.h" - -void SolverPNS::solve(double time){ - if(rootboard.won() >= 0){ - outcome = rootboard.won(); - return; - } - - timeout = false; - Alarm timer(time, std::bind(&SolverPNS::timedout, this)); - Time start; - -// logerr("max nodes: " + to_str(memlimit/sizeof(PNSNode)) + ", max memory: " + to_str(memlimit/(1024*1024)) + " Mb\n"); - - run_pns(); - - if(root.phi == 0 && root.delta == LOSS){ //look for the winning move - for(PNSNode * i = root.children.begin() ; i != root.children.end(); i++){ - if(i->delta == 0){ - bestmove = i->move; - break; - } - } - outcome = rootboard.toplay(); - }else if(root.phi == 0 && root.delta == DRAW){ //look for the move to tie - for(PNSNode * i = root.children.begin() ; i != root.children.end(); i++){ - if(i->delta == DRAW){ - bestmove = i->move; - break; - } - } - outcome = 0; - }else if(root.delta == 0){ //loss - bestmove = M_NONE; - outcome = 3 - rootboard.toplay(); - }else{ //unknown - bestmove = M_UNKNOWN; - outcome = -3; - } - - time_used = Time() - start; -} - -void SolverPNS::run_pns(){ - while(!timeout && root.phi != 0 && root.delta != 0){ - if(!pns(rootboard, &root, 0, INF32/2, INF32/2)){ - logerr("Starting solver GC with limit " + to_str(gclimit) + " ... "); - - Time starttime; - garbage_collect(& root); - - Time gctime; - ctmem.compact(1.0, 0.75); - - Time compacttime; - logerr(to_str(100.0*ctmem.meminuse()/memlimit, 1) + " % of tree remains - " + - to_str((gctime - starttime)*1000, 0) + " msec gc, " + to_str((compacttime - gctime)*1000, 0) + " msec compact\n"); - - if(ctmem.meminuse() >= memlimit/2) - gclimit = (unsigned int)(gclimit*1.3); - else if(gclimit > 5) - gclimit = (unsigned int)(gclimit*0.9); //slowly decay to a minimum of 5 - } - } -} - -bool SolverPNS::pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td){ - iters++; - if(maxdepth < depth) - maxdepth = depth; - - if(node->children.empty()){ - if(ctmem.memalloced() >= memlimit) - return false; - - int numnodes = board.movesremain(); - nodes += node->alloc(numnodes, ctmem); - - if(lbdist) - dists.run(&board); - - int i = 0; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - int outcome, pd; - - if(ab){ - Board next = board; - next.move(*move); - - pd = 0; - outcome = (ab == 1 ? solve1ply(next, pd) : solve2ply(next, pd)); - nodes_seen += pd; - }else{ - outcome = board.test_win(*move); - pd = 1; - } - - if(lbdist && outcome < 0) - pd = dists.get(*move); - - node->children[i] = PNSNode(*move).outcome(outcome, board.toplay(), ties, pd); - - i++; - } - node->children.shrink(i); //if symmetry, there may be extra moves to ignore - - nodes_seen += i; - - updatePDnum(node); - - return true; - } - - bool mem; - do{ - PNSNode * child = node->children.begin(), - * child2 = node->children.begin(), - * childend = node->children.end(); - - uint32_t tpc, tdc; - - if(df){ - for(PNSNode * i = node->children.begin(); i != childend; i++){ - if(i->delta <= child->delta){ - child2 = child; - child = i; - }else if(i->delta < child2->delta){ - child2 = i; - } - } - - tpc = min(INF32/2, (td + child->phi - node->delta)); - tdc = min(tp, (uint32_t)(child2->delta*(1.0 + epsilon) + 1)); - }else{ - tpc = tdc = 0; - while(child->delta != node->phi) - child++; - } - - Board next = board; - next.move(child->move); - - uint64_t itersbefore = iters; - mem = pns(next, child, depth + 1, tpc, tdc); - child->work += iters - itersbefore; - - if(child->phi == 0 || child->delta == 0) //clear child's children - nodes -= child->dealloc(ctmem); - - if(updatePDnum(node) && !df) - break; - - }while(!timeout && mem && (!df || (node->phi < tp && node->delta < td))); - - return mem; -} - -bool SolverPNS::updatePDnum(PNSNode * node){ - PNSNode * i = node->children.begin(); - PNSNode * end = node->children.end(); - - uint32_t min = i->delta; - uint64_t sum = 0; - - bool win = false; - for( ; i != end; i++){ - win |= (i->phi == LOSS); - sum += i->phi; - if( min > i->delta) - min = i->delta; - } - - if(win) - sum = LOSS; - else if(sum >= INF32) - sum = INF32; - - if(min == node->phi && sum == node->delta){ - return false; - }else{ - if(sum == 0 && min == DRAW){ - node->phi = 0; - node->delta = DRAW; - }else{ - node->phi = min; - node->delta = sum; - } - return true; - } -} - -//removes the children of any node with less than limit work -void SolverPNS::garbage_collect(PNSNode * node){ - PNSNode * child = node->children.begin(); - PNSNode * end = node->children.end(); - - for( ; child != end; child++){ - if(child->terminal()){ //solved - //log heavy nodes? - nodes -= child->dealloc(ctmem); - }else if(child->work < gclimit){ //low work, ignore solvedness since it's trivial to re-solve - nodes -= child->dealloc(ctmem); - }else if(child->children.num() > 0){ - garbage_collect(child); - } - } -} diff --git a/hex/solverpns.h b/hex/solverpns.h deleted file mode 100644 index b040d82..0000000 --- a/hex/solverpns.h +++ /dev/null @@ -1,206 +0,0 @@ - -#pragma once - -//A single-threaded, tree based, proof number search solver. - -#include "../lib/compacttree.h" -#include "../lib/log.h" - -#include "lbdist.h" -#include "solver.h" - - -class SolverPNS : public Solver { - static const uint32_t LOSS = (1<<30)-1; - static const uint32_t DRAW = (1<<30)-2; - static const uint32_t INF32 = (1<<30)-3; -public: - - struct PNSNode { - uint32_t phi, delta; - uint64_t work; - Move move; - CompactTree::Children children; - - PNSNode() { } - PNSNode(int x, int y, int v = 1) : phi(v), delta(v), work(0), move(Move(x,y)) { } - PNSNode(const Move & m, int v = 1) : phi(v), delta(v), work(0), move(m) { } - PNSNode(int x, int y, int p, int d) : phi(p), delta(d), work(0), move(Move(x,y)) { } - PNSNode(const Move & m, int p, int d) : phi(p), delta(d), work(0), move(m) { } - - PNSNode(const PNSNode & n) { *this = n; } - PNSNode & operator = (const PNSNode & n){ - if(this != & n){ //don't copy to self - //don't copy to a node that already has children - assert(children.empty()); - - phi = n.phi; - delta = n.delta; - work = n.work; - move = n.move; - //don't copy the children - } - return *this; - } - - ~PNSNode(){ - assert(children.empty()); - } - - PNSNode & abval(int outcome, int toplay, int assign, int value = 1){ - if(assign && (outcome == 1 || outcome == -1)) - outcome = (toplay == assign ? 2 : -2); - - if( outcome == 0) { phi = value; delta = value; } - else if(outcome == 2) { phi = LOSS; delta = 0; } - else if(outcome == -2) { phi = 0; delta = LOSS; } - else /*(outcome 1||-1)*/ { phi = 0; delta = DRAW; } - return *this; - } - - PNSNode & outcome(int outcome, int toplay, int assign, int value = 1){ - if(assign && outcome == 0) - outcome = assign; - - if( outcome == -3) { phi = value; delta = value; } - else if(outcome == toplay) { phi = LOSS; delta = 0; } - else if(outcome == 3-toplay) { phi = 0; delta = LOSS; } - else /*(outcome == 0)*/ { phi = 0; delta = DRAW; } - return *this; - } - - bool terminal(){ return (phi == 0 || delta == 0); } - - unsigned int size() const { - unsigned int num = children.num(); - - for(PNSNode * i = children.begin(); i != children.end(); i++) - num += i->size(); - - return num; - } - - void swap_tree(PNSNode & n){ - children.swap(n.children); - } - - unsigned int alloc(unsigned int num, CompactTree & ct){ - return children.alloc(num, ct); - } - unsigned int dealloc(CompactTree & ct){ - unsigned int num = 0; - - for(PNSNode * i = children.begin(); i != children.end(); i++) - num += i->dealloc(ct); - num += children.dealloc(ct); - - return num; - } - }; - - -//memory management for PNS which uses a tree to store the nodes - uint64_t nodes, memlimit; - unsigned int gclimit; - CompactTree ctmem; - - uint64_t iters; - - int ab; // how deep of an alpha-beta search to run at each leaf node - bool df; // go depth first? - float epsilon; //if depth first, how wide should the threshold be? - int ties; //which player to assign ties to: 0 handle ties, 1 assign p1, 2 assign p2 - bool lbdist; - - PNSNode root; - LBDists dists; - - SolverPNS() { - ab = 2; - df = true; - epsilon = 0.25; - ties = 0; - lbdist = false; - gclimit = 5; - iters = 0; - - reset(); - - set_memlimit(100*1024*1024); - } - - ~SolverPNS(){ - root.dealloc(ctmem); - ctmem.compact(); - } - - void reset(){ - outcome = -3; - maxdepth = 0; - nodes_seen = 0; - time_used = 0; - bestmove = Move(M_UNKNOWN); - - timeout = false; - } - - void set_board(const Board & board, bool clear = true){ - rootboard = board; - reset(); - if(clear) - clear_mem(); - } - void move(const Move & m){ - rootboard.move(m); - reset(); - - - uint64_t nodesbefore = nodes; - - PNSNode child; - - for(PNSNode * i = root.children.begin(); i != root.children.end(); i++){ - if(i->move == m){ - child = *i; //copy the child experience to temp - child.swap_tree(*i); //move the child tree to temp - break; - } - } - - nodes -= root.dealloc(ctmem); - root = child; - root.swap_tree(child); - - if(nodesbefore > 0) - logerr(string("PNS Nodes before: ") + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); - - assert(nodes == root.size()); - - if(nodes == 0) - clear_mem(); - } - - void set_memlimit(uint64_t lim){ - memlimit = lim; - } - - void clear_mem(){ - reset(); - root.dealloc(ctmem); - ctmem.compact(); - root = PNSNode(0, 0, 1); - nodes = 0; - } - - void solve(double time); - -//basic proof number search building a tree - void run_pns(); - bool pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td); - -//update the phi and delta for the node - bool updatePDnum(PNSNode * node); - -//remove all the nodes with little work to free up some memory - void garbage_collect(PNSNode * node); -}; diff --git a/hex/solverpns2.cpp b/hex/solverpns2.cpp deleted file mode 100644 index 4995fc5..0000000 --- a/hex/solverpns2.cpp +++ /dev/null @@ -1,323 +0,0 @@ - - -#include "../lib/alarm.h" -#include "../lib/log.h" -#include "../lib/time.h" - -#include "solverpns2.h" - -void SolverPNS2::solve(double time){ - if(rootboard.won() >= 0){ - outcome = rootboard.won(); - return; - } - - start_threads(); - - timeout = false; - Alarm timer(time, std::bind(&SolverPNS2::timedout, this)); - Time start; - -// logerr("max memory: " + to_str(memlimit/(1024*1024)) + " Mb\n"); - - //wait for the timer to stop them - runbarrier.wait(); - CAS(threadstate, Thread_Wait_End, Thread_Wait_Start); - assert(threadstate == Thread_Wait_Start); - - if(root.phi == 0 && root.delta == LOSS){ //look for the winning move - for(PNSNode * i = root.children.begin() ; i != root.children.end(); i++){ - if(i->delta == 0){ - bestmove = i->move; - break; - } - } - outcome = rootboard.toplay(); - }else if(root.phi == 0 && root.delta == DRAW){ //look for the move to tie - for(PNSNode * i = root.children.begin() ; i != root.children.end(); i++){ - if(i->delta == DRAW){ - bestmove = i->move; - break; - } - } - outcome = 0; - }else if(root.delta == 0){ //loss - bestmove = M_NONE; - outcome = 3 - rootboard.toplay(); - }else{ //unknown - bestmove = M_UNKNOWN; - outcome = -3; - } - - time_used = Time() - start; -} - -void SolverPNS2::SolverThread::run(){ - while(true){ - switch(solver->threadstate){ - case Thread_Cancelled: //threads should exit - return; - - case Thread_Wait_Start: //threads are waiting to start - case Thread_Wait_Start_Cancelled: - solver->runbarrier.wait(); - CAS(solver->threadstate, Thread_Wait_Start, Thread_Running); - CAS(solver->threadstate, Thread_Wait_Start_Cancelled, Thread_Cancelled); - break; - - case Thread_Wait_End: //threads are waiting to end - solver->runbarrier.wait(); - CAS(solver->threadstate, Thread_Wait_End, Thread_Wait_Start); - break; - - case Thread_Running: //threads are running - if(solver->root.terminal()){ //solved - CAS(solver->threadstate, Thread_Running, Thread_Wait_End); - break; - } - if(solver->ctmem.memalloced() >= solver->memlimit){ //out of memory, start garbage collection - CAS(solver->threadstate, Thread_Running, Thread_GC); - break; - } - - pns(solver->rootboard, &solver->root, 0, INF32/2, INF32/2); - break; - - case Thread_GC: //one thread is running garbage collection, the rest are waiting - case Thread_GC_End: //once done garbage collecting, go to wait_end instead of back to running - if(solver->gcbarrier.wait()){ - logerr("Starting solver GC with limit " + to_str(solver->gclimit) + " ... "); - - Time starttime; - solver->garbage_collect(& solver->root); - - Time gctime; - solver->ctmem.compact(1.0, 0.75); - - Time compacttime; - logerr(to_str(100.0*solver->ctmem.meminuse()/solver->memlimit, 1) + " % of tree remains - " + - to_str((gctime - starttime)*1000, 0) + " msec gc, " + to_str((compacttime - gctime)*1000, 0) + " msec compact\n"); - - if(solver->ctmem.meminuse() >= solver->memlimit/2) - solver->gclimit = (unsigned int)(solver->gclimit*1.3); - else if(solver->gclimit > 5) - solver->gclimit = (unsigned int)(solver->gclimit*0.9); //slowly decay to a minimum of 5 - - CAS(solver->threadstate, Thread_GC, Thread_Running); - CAS(solver->threadstate, Thread_GC_End, Thread_Wait_End); - } - solver->gcbarrier.wait(); - break; - } - } -} - -void SolverPNS2::timedout() { - CAS(threadstate, Thread_Running, Thread_Wait_End); - CAS(threadstate, Thread_GC, Thread_GC_End); - timeout = true; -} - -string SolverPNS2::statestring(){ - switch(threadstate){ - case Thread_Cancelled: return "Thread_Wait_Cancelled"; - case Thread_Wait_Start: return "Thread_Wait_Start"; - case Thread_Wait_Start_Cancelled: return "Thread_Wait_Start_Cancelled"; - case Thread_Running: return "Thread_Running"; - case Thread_GC: return "Thread_GC"; - case Thread_GC_End: return "Thread_GC_End"; - case Thread_Wait_End: return "Thread_Wait_End"; - } - return "Thread_State_Unknown!!!"; -} - -void SolverPNS2::stop_threads(){ - if(threadstate != Thread_Wait_Start){ - timedout(); - runbarrier.wait(); - CAS(threadstate, Thread_Wait_End, Thread_Wait_Start); - assert(threadstate == Thread_Wait_Start); - } -} - -void SolverPNS2::start_threads(){ - assert(threadstate == Thread_Wait_Start); - runbarrier.wait(); - CAS(threadstate, Thread_Wait_Start, Thread_Running); -} - -void SolverPNS2::reset_threads(){ //start and end with threadstate = Thread_Wait_Start - assert(threadstate == Thread_Wait_Start); - -//wait for them to all get to the barrier - assert(CAS(threadstate, Thread_Wait_Start, Thread_Wait_Start_Cancelled)); - runbarrier.wait(); - -//make sure they exited cleanly - for(unsigned int i = 0; i < threads.size(); i++) - threads[i]->join(); - - threads.clear(); - - threadstate = Thread_Wait_Start; - - runbarrier.reset(numthreads + 1); - gcbarrier.reset(numthreads); - -//start new threads - for(int i = 0; i < numthreads; i++) - threads.push_back(new SolverThread(this)); -} - - -bool SolverPNS2::SolverThread::pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td){ - iters++; - if(solver->maxdepth < depth) - solver->maxdepth = depth; - - if(node->children.empty()){ - if(node->terminal()) - return true; - - if(solver->ctmem.memalloced() >= solver->memlimit) - return false; - - if(!node->children.lock()) - return false; - - int numnodes = board.movesremain(); - CompactTree::Children temp; - temp.alloc(numnodes, solver->ctmem); - PLUS(solver->nodes, numnodes); - - if(solver->lbdist) - dists.run(&board); - - int i = 0; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - int outcome, pd; - - if(solver->ab){ - Board next = board; - next.move(*move); - - pd = 0; - outcome = (solver->ab == 1 ? solve1ply(next, pd) : solve2ply(next, pd)); - PLUS(solver->nodes_seen, pd); - }else{ - outcome = board.test_win(*move); - pd = 1; - } - - if(solver->lbdist && outcome < 0) - pd = dists.get(*move); - - temp[i] = PNSNode(*move).outcome(outcome, board.toplay(), solver->ties, pd); - - i++; - } - temp.shrink(i); //if symmetry, there may be extra moves to ignore - node->children.swap(temp); - assert(temp.unlock()); - - PLUS(solver->nodes_seen, i); - - updatePDnum(node); - - return true; - } - - bool mem; - do{ - PNSNode * child = node->children.begin(), - * child2 = node->children.begin(), - * childend = node->children.end(); - - uint32_t tpc, tdc; - - if(solver->df){ - for(PNSNode * i = node->children.begin(); i != childend; i++){ - if(i->refdelta() <= child->refdelta()){ - child2 = child; - child = i; - }else if(i->refdelta() < child2->refdelta()){ - child2 = i; - } - } - - tpc = min(INF32/2, (td + child->phi - node->delta)); - tdc = min(tp, (uint32_t)(child2->delta*(1.0 + solver->epsilon) + 1)); - }else{ - tpc = tdc = 0; - for(PNSNode * i = node->children.begin(); i != childend; i++) - if(child->refdelta() > i->refdelta()) - child = i; - } - - Board next = board; - next.move(child->move); - - child->ref(); - uint64_t itersbefore = iters; - mem = pns(next, child, depth + 1, tpc, tdc); - child->deref(); - PLUS(child->work, iters - itersbefore); - - if(updatePDnum(node) && !solver->df) - break; - - }while(!solver->timeout && mem && (!solver->df || (node->phi < tp && node->delta < td))); - - return mem; -} - -bool SolverPNS2::SolverThread::updatePDnum(PNSNode * node){ - PNSNode * i = node->children.begin(); - PNSNode * end = node->children.end(); - - uint32_t min = i->delta; - uint64_t sum = 0; - - bool win = false; - for( ; i != end; i++){ - win |= (i->phi == LOSS); - sum += i->phi; - if( min > i->delta) - min = i->delta; - } - - if(win) - sum = LOSS; - else if(sum >= INF32) - sum = INF32; - - if(min == node->phi && sum == node->delta){ - return false; - }else{ - if(sum == 0 && min == DRAW){ - node->phi = 0; - node->delta = DRAW; - }else{ - node->phi = min; - node->delta = sum; - } - return true; - } -} - -//removes the children of any node with less than limit work -void SolverPNS2::garbage_collect(PNSNode * node){ - PNSNode * child = node->children.begin(); - PNSNode * end = node->children.end(); - - for( ; child != end; child++){ - if(child->terminal()){ //solved - //log heavy nodes? - PLUS(nodes, -child->dealloc(ctmem)); - }else if(child->work < gclimit){ //low work, ignore solvedness since it's trivial to re-solve - PLUS(nodes, -child->dealloc(ctmem)); - }else if(child->children.num() > 0){ - garbage_collect(child); - } - } -} diff --git a/hex/solverpns2.h b/hex/solverpns2.h deleted file mode 100644 index 5af5d1d..0000000 --- a/hex/solverpns2.h +++ /dev/null @@ -1,265 +0,0 @@ - -#pragma once - -//A multi-threaded, tree based, proof number search solver. - -#include "../lib/compacttree.h" -#include "../lib/log.h" - -#include "lbdist.h" -#include "solver.h" - - -class SolverPNS2 : public Solver { - static const uint32_t LOSS = (1<<30)-1; - static const uint32_t DRAW = (1<<30)-2; - static const uint32_t INF32 = (1<<30)-3; -public: - - struct PNSNode { - static const uint16_t reflock = 1<<15; - uint32_t phi, delta; - uint64_t work; - uint16_t refcount; //how many threads are down this node - Move move; - CompactTree::Children children; - - PNSNode() { } - PNSNode(int x, int y, int v = 1) : phi(v), delta(v), work(0), refcount(0), move(Move(x,y)) { } - PNSNode(const Move & m, int v = 1) : phi(v), delta(v), work(0), refcount(0), move(m) { } - PNSNode(int x, int y, int p, int d) : phi(p), delta(d), work(0), refcount(0), move(Move(x,y)) { } - PNSNode(const Move & m, int p, int d) : phi(p), delta(d), work(0), refcount(0), move(m) { } - - PNSNode(const PNSNode & n) { *this = n; } - PNSNode & operator = (const PNSNode & n){ - if(this != & n){ //don't copy to self - //don't copy to a node that already has children - assert(children.empty()); - - phi = n.phi; - delta = n.delta; - work = n.work; - move = n.move; - //don't copy the children - } - return *this; - } - - ~PNSNode(){ - assert(children.empty()); - } - - PNSNode & abval(int outcome, int toplay, int assign, int value = 1){ - if(assign && (outcome == 1 || outcome == -1)) - outcome = (toplay == assign ? 2 : -2); - - if( outcome == 0) { phi = value; delta = value; } - else if(outcome == 2) { phi = LOSS; delta = 0; } - else if(outcome == -2) { phi = 0; delta = LOSS; } - else /*(outcome 1||-1)*/ { phi = 0; delta = DRAW; } - return *this; - } - - PNSNode & outcome(int outcome, int toplay, int assign, int value = 1){ - if(assign && outcome == 0) - outcome = assign; - - if( outcome == -3) { phi = value; delta = value; } - else if(outcome == toplay) { phi = LOSS; delta = 0; } - else if(outcome == 3-toplay) { phi = 0; delta = LOSS; } - else /*(outcome == 0)*/ { phi = 0; delta = DRAW; } - return *this; - } - - bool terminal(){ return (phi == 0 || delta == 0); } - - uint32_t refdelta() const { - return delta + refcount; - } - - void ref() { PLUS(refcount, 1); } - void deref(){ PLUS(refcount, -1); } - - unsigned int size() const { - unsigned int num = children.num(); - - for(PNSNode * i = children.begin(); i != children.end(); i++) - num += i->size(); - - return num; - } - - void swap_tree(PNSNode & n){ - children.swap(n.children); - } - - unsigned int alloc(unsigned int num, CompactTree & ct){ - return children.alloc(num, ct); - } - unsigned int dealloc(CompactTree & ct){ - unsigned int num = 0; - - for(PNSNode * i = children.begin(); i != children.end(); i++) - num += i->dealloc(ct); - num += children.dealloc(ct); - - return num; - } - }; - - class SolverThread { - protected: - public: - Thread thread; - SolverPNS2 * solver; - public: - uint64_t iters; - LBDists dists; //holds the distances to the various non-ring wins as a heuristic for the minimum moves needed to win - - SolverThread(SolverPNS2 * s) : solver(s), iters(0) { - thread(bind(&SolverThread::run, this)); - } - virtual ~SolverThread() { } - void reset(){ - iters = 0; - } - int join(){ return thread.join(); } - void run(); //thread runner - - //basic proof number search building a tree - bool pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td); - - //update the phi and delta for the node - bool updatePDnum(PNSNode * node); - }; - - -//memory management for PNS which uses a tree to store the nodes - uint64_t nodes, memlimit; - unsigned int gclimit; - CompactTree ctmem; - - enum ThreadState { - Thread_Cancelled, //threads should exit - Thread_Wait_Start, //threads are waiting to start - Thread_Wait_Start_Cancelled, //once done waiting, go to cancelled instead of running - Thread_Running, //threads are running - Thread_GC, //one thread is running garbage collection, the rest are waiting - Thread_GC_End, //once done garbage collecting, go to wait_end instead of back to running - Thread_Wait_End, //threads are waiting to end - }; - volatile ThreadState threadstate; - vector threads; - Barrier runbarrier, gcbarrier; - - - int ab; // how deep of an alpha-beta search to run at each leaf node - bool df; // go depth first? - float epsilon; //if depth first, how wide should the threshold be? - int ties; //which player to assign ties to: 0 handle ties, 1 assign p1, 2 assign p2 - bool lbdist; - int numthreads; - - PNSNode root; - LBDists dists; - - SolverPNS2() { - ab = 2; - df = true; - epsilon = 0.25; - ties = 0; - lbdist = false; - numthreads = 1; - gclimit = 5; - - reset(); - - set_memlimit(100*1024*1024); - - //no threads started until a board is set - threadstate = Thread_Wait_Start; - } - - ~SolverPNS2(){ - stop_threads(); - - numthreads = 0; - reset_threads(); //shut down the theads properly - - root.dealloc(ctmem); - ctmem.compact(); - } - - void reset(){ - outcome = -3; - maxdepth = 0; - nodes_seen = 0; - time_used = 0; - bestmove = Move(M_UNKNOWN); - - timeout = false; - } - - string statestring(); - void stop_threads(); - void start_threads(); - void reset_threads(); - void timedout(); - - void set_board(const Board & board, bool clear = true){ - rootboard = board; - reset(); - if(clear) - clear_mem(); - - reset_threads(); //needed since the threads aren't started before a board it set - } - void move(const Move & m){ - stop_threads(); - - rootboard.move(m); - reset(); - - - uint64_t nodesbefore = nodes; - - PNSNode child; - - for(PNSNode * i = root.children.begin(); i != root.children.end(); i++){ - if(i->move == m){ - child = *i; //copy the child experience to temp - child.swap_tree(*i); //move the child tree to temp - break; - } - } - - nodes -= root.dealloc(ctmem); - root = child; - root.swap_tree(child); - - if(nodesbefore > 0) - logerr(string("PNS Nodes before: ") + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); - - assert(nodes == root.size()); - - if(nodes == 0) - clear_mem(); - } - - void set_memlimit(uint64_t lim){ - memlimit = lim; - } - - void clear_mem(){ - reset(); - root.dealloc(ctmem); - ctmem.compact(); - root = PNSNode(0, 0, 1); - nodes = 0; - } - - void solve(double time); - -//remove all the nodes with little work to free up some memory - void garbage_collect(PNSNode * node); -}; diff --git a/hex/solverpns_tt.cpp b/hex/solverpns_tt.cpp deleted file mode 100644 index 0818e8c..0000000 --- a/hex/solverpns_tt.cpp +++ /dev/null @@ -1,282 +0,0 @@ - -#include "../lib/alarm.h" -#include "../lib/log.h" -#include "../lib/time.h" - -#include "solverpns_tt.h" - -void SolverPNSTT::solve(double time){ - if(rootboard.won() >= 0){ - outcome = rootboard.won(); - return; - } - - timeout = false; - Alarm timer(time, std::bind(&SolverPNSTT::timedout, this)); - Time start; - -// logerr("max nodes: " + to_str(maxnodes) + ", max memory: " + to_str(memlimit) + " Mb\n"); - - run_pns(); - - if(root.phi == 0 && root.delta == LOSS){ //look for the winning move - PNSNode * i = NULL; - for(Board::MoveIterator move = rootboard.moveit(true); !move.done(); ++move){ - i = tt(rootboard, *move); - if(i->delta == 0){ - bestmove = *move; - break; - } - } - outcome = rootboard.toplay(); - }else if(root.phi == 0 && root.delta == DRAW){ //look for the move to tie - PNSNode * i = NULL; - for(Board::MoveIterator move = rootboard.moveit(true); !move.done(); ++move){ - i = tt(rootboard, *move); - if(i->delta == DRAW){ - bestmove = *move; - break; - } - } - outcome = 0; - }else if(root.delta == 0){ //loss - bestmove = M_NONE; - outcome = 3 - rootboard.toplay(); - }else{ //unknown - bestmove = M_UNKNOWN; - outcome = -3; - } - - time_used = Time() - start; -} - -void SolverPNSTT::run_pns(){ - if(TT == NULL) - TT = new PNSNode[maxnodes]; - - while(!timeout && root.phi != 0 && root.delta != 0) - pns(rootboard, &root, 0, INF32/2, INF32/2); -} - -void SolverPNSTT::pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td){ - if(depth > maxdepth) - maxdepth = depth; - - do{ - PNSNode * child = NULL, - * child2 = NULL; - - Move move1, move2; - - uint32_t tpc, tdc; - - PNSNode * i = NULL; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - i = tt(board, *move); - if(child == NULL){ - child = child2 = i; - move1 = move2 = *move; - }else if(i->delta <= child->delta){ - child2 = child; - child = i; - move2 = move1; - move1 = *move; - }else if(i->delta < child2->delta){ - child2 = i; - move2 = *move; - } - } - - if(child->delta && child->phi){ //unsolved - if(df){ - tpc = min(INF32/2, (td + child->phi - node->delta)); - tdc = min(tp, (uint32_t)(child2->delta*(1.0 + epsilon) + 1)); - }else{ - tpc = tdc = 0; - } - - Board next = board; - next.move(move1); - pns(next, child, depth + 1, tpc, tdc); - - //just found a loss, try to copy proof to siblings - if(copyproof && child->delta == LOSS){ -// logerr("!" + move1.to_s() + " "); - int count = abs(copyproof); - for(Board::MoveIterator move = board.moveit(true); count-- && !move.done(); ++move){ - if(!tt(board, *move)->terminal()){ -// logerr("?" + move->to_s() + " "); - Board sibling = board; - sibling.move(*move); - copy_proof(next, sibling, move1, *move); - updatePDnum(sibling); - - if(copyproof < 0 && !tt(sibling)->terminal()) - break; - } - } - } - } - - if(updatePDnum(board, node) && !df) //must pass node to updatePDnum since it may refer to the root which isn't in the TT - break; - - }while(!timeout && node->phi && node->delta && (!df || (node->phi < tp && node->delta < td))); -} - -bool SolverPNSTT::updatePDnum(const Board & board, PNSNode * node){ - hash_t hash = board.gethash(); - - if(node == NULL) - node = TT + (hash % maxnodes); - - uint32_t min = LOSS; - uint64_t sum = 0; - - bool win = false; - PNSNode * i = NULL; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - i = tt(board, *move); - - win |= (i->phi == LOSS); - sum += i->phi; - if( min > i->delta) - min = i->delta; - } - - if(win) - sum = LOSS; - else if(sum >= INF32) - sum = INF32; - - if(hash == node->hash && min == node->phi && sum == node->delta){ - return false; - }else{ - node->hash = hash; //just in case it was overwritten by something else - if(sum == 0 && min == DRAW){ - node->phi = 0; - node->delta = DRAW; - }else{ - node->phi = min; - node->delta = sum; - } - return true; - } -} - -//source is a move that is a proven loss, and dest is an unproven sibling -//each has one move that the other doesn't, which are stored in smove and dmove -//if either move is used but only available in one board, the other is substituted -void SolverPNSTT::copy_proof(const Board & source, const Board & dest, Move smove, Move dmove){ - if(timeout || tt(source)->delta != LOSS || tt(dest)->terminal()) - return; - - //find winning move from the source tree - Move bestmove = M_UNKNOWN; - for(Board::MoveIterator move = source.moveit(true); !move.done(); ++move){ - if(tt(source, *move)->phi == LOSS){ - bestmove = *move; - break; - } - } - - if(bestmove == M_UNKNOWN) //due to transposition table collision - return; - - Board dest2 = dest; - - if(bestmove == dmove){ - assert(dest2.move(smove)); - smove = dmove = M_UNKNOWN; - }else{ - assert(dest2.move(bestmove)); - if(bestmove == smove) - smove = dmove = M_UNKNOWN; - } - - if(tt(dest2)->terminal()) - return; - - Board source2 = source; - assert(source2.move(bestmove)); - - if(source2.won() >= 0) - return; - - //test all responses - for(Board::MoveIterator move = dest2.moveit(true); !move.done(); ++move){ - if(tt(dest2, *move)->terminal()) - continue; - - Move csmove = smove, cdmove = dmove; - - Board source3 = source2, dest3 = dest2; - - if(*move == csmove){ - assert(source3.move(cdmove)); - csmove = cdmove = M_UNKNOWN; - }else{ - assert(source3.move(*move)); - if(*move == csmove) - csmove = cdmove = M_UNKNOWN; - } - - assert(dest3.move(*move)); - - copy_proof(source3, dest3, csmove, cdmove); - - updatePDnum(dest3); - } - - updatePDnum(dest2); -} - -SolverPNSTT::PNSNode * SolverPNSTT::tt(const Board & board){ - hash_t hash = board.gethash(); - - PNSNode * node = TT + (hash % maxnodes); - - if(node->hash != hash){ - int outcome, pd; - - if(ab){ - pd = 0; - outcome = (ab == 1 ? solve1ply(board, pd) : solve2ply(board, pd)); - nodes_seen += pd; - }else{ - outcome = board.won(); - pd = 1; - } - - *node = PNSNode(hash).outcome(outcome, board.toplay(), ties, pd); - nodes_seen++; - } - - return node; -} - -SolverPNSTT::PNSNode * SolverPNSTT::tt(const Board & board, Move move){ - hash_t hash = board.test_hash(move, board.toplay()); - - PNSNode * node = TT + (hash % maxnodes); - - if(node->hash != hash){ - int outcome, pd; - - if(ab){ - Board next = board; - next.move(move); - pd = 0; - outcome = (ab == 1 ? solve1ply(next, pd) : solve2ply(next, pd)); - nodes_seen += pd; - }else{ - outcome = board.test_win(move); - pd = 1; - } - - *node = PNSNode(hash).outcome(outcome, board.toplay(), ties, pd); - nodes_seen++; - } - - return node; -} diff --git a/hex/solverpns_tt.h b/hex/solverpns_tt.h deleted file mode 100644 index 95d344e..0000000 --- a/hex/solverpns_tt.h +++ /dev/null @@ -1,129 +0,0 @@ - -#pragma once - -//A single-threaded, transposition table based, proof number search solver. - -#include "../lib/zobrist.h" - -#include "solver.h" - -class SolverPNSTT : public Solver { - static const uint32_t LOSS = (1<<30)-1; - static const uint32_t DRAW = (1<<30)-2; - static const uint32_t INF32 = (1<<30)-3; -public: - - struct PNSNode { - hash_t hash; - uint32_t phi, delta; - - PNSNode() : hash(0), phi(0), delta(0) { } - PNSNode(hash_t h, int v = 1) : hash(h), phi(v), delta(v) { } - PNSNode(hash_t h, int p, int d) : hash(h), phi(p), delta(d) { } - - PNSNode & abval(int outcome, int toplay, int assign, int value = 1){ - if(assign && (outcome == 1 || outcome == -1)) - outcome = (toplay == assign ? 2 : -2); - - if( outcome == 0) { phi = value; delta = value; } //unknown - else if(outcome == 2) { phi = LOSS; delta = 0; } //win - else if(outcome == -2) { phi = 0; delta = LOSS; } //loss - else /*(outcome 1||-1)*/ { phi = 0; delta = DRAW; } //draw - return *this; - } - - PNSNode & outcome(int outcome, int toplay, int assign, int value = 1){ - if(assign && outcome == 0) - outcome = assign; - - if( outcome == -3) { phi = value; delta = value; } - else if(outcome == toplay) { phi = LOSS; delta = 0; } - else if(outcome == 3-toplay) { phi = 0; delta = LOSS; } - else /*(outcome == 0)*/ { phi = 0; delta = DRAW; } - return *this; - } - - bool terminal(){ return (phi == 0 || delta == 0); } - }; - - PNSNode root; - PNSNode * TT; - uint64_t maxnodes, memlimit; - - int ab; // how deep of an alpha-beta search to run at each leaf node - bool df; // go depth first? - float epsilon; //if depth first, how wide should the threshold be? - int ties; //which player to assign ties to: 0 handle ties, 1 assign p1, 2 assign p2 - int copyproof; //how many siblings to try to copy a proof to - - - SolverPNSTT() { - ab = 2; - df = true; - epsilon = 0.25; - ties = 0; - copyproof = 0; - - TT = NULL; - reset(); - - set_memlimit(100*1024*1024); - } - - ~SolverPNSTT(){ - if(TT){ - delete[] TT; - TT = NULL; - } - } - - void reset(){ - outcome = -3; - maxdepth = 0; - nodes_seen = 0; - time_used = 0; - bestmove = Move(M_UNKNOWN); - - timeout = false; - - root = PNSNode(rootboard.gethash(), 1); - } - - void set_board(const Board & board, bool clear = true){ - rootboard = board; - reset(); - if(clear) - clear_mem(); - } - void move(const Move & m){ - rootboard.move(m); - reset(); - } - void set_memlimit(uint64_t lim){ - memlimit = lim; - maxnodes = memlimit/sizeof(PNSNode); - clear_mem(); - } - - void clear_mem(){ - reset(); - if(TT){ - delete[] TT; - TT = NULL; - } - } - - void solve(double time); - -//basic proof number search building a tree - void run_pns(); - void pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td); - - void copy_proof(const Board & source, const Board & dest, Move smove, Move dmove); - -//update the phi and delta for the node - bool updatePDnum(const Board & board, PNSNode * node = NULL); - - PNSNode * tt(const Board & board); - PNSNode * tt(const Board & board, Move move); -}; diff --git a/lib/agentpool.h b/lib/agentpool.h index f70894e..ff9466f 100644 --- a/lib/agentpool.h +++ b/lib/agentpool.h @@ -32,6 +32,8 @@ The main thread and the worker threads are coordinated with a simple state machi two barriers. */ +namespace Morat { + enum ThreadState { Thread_Cancelled, //threads should exit Thread_Wait_Start, //threads are waiting to start @@ -80,7 +82,7 @@ class AgentThreadPool { return *(threads[i]); } - string state_string() const { + std::string state_string() const { switch(thread_state){ case Thread_Cancelled: return "Thread_Wait_Cancelled"; case Thread_Wait_Start: return "Thread_Wait_Start"; @@ -173,7 +175,7 @@ class AgentThreadBase { AgentThreadBase(AgentThreadPool * p, AgentType * a) : pool(p), agent(a) { reset(); - thread(bind(&AgentThreadBase::run, this)); + thread(std::bind(&AgentThreadBase::run, this)); } virtual ~AgentThreadBase() { } @@ -229,3 +231,5 @@ class AgentThreadBase { virtual void iterate() = 0; //handles each iteration }; + +}; // namespace Morat diff --git a/lib/alarm-timer.cpp b/lib/alarm-timer.cpp index fa8b92d..00d0af5 100644 --- a/lib/alarm-timer.cpp +++ b/lib/alarm-timer.cpp @@ -2,6 +2,8 @@ #include "alarm.h" #include "timer.h" +namespace Morat { + void alarm_triggered(int signum){ Alarm::Handler::inst().reset(signum); } @@ -72,3 +74,5 @@ void Alarm::Handler::reset(int signum){ if(next > 0) timer().set(next, timer_triggered); } + +}; // namespace Morat diff --git a/lib/alarm.cpp b/lib/alarm.cpp index 7c94a93..e19b37f 100644 --- a/lib/alarm.cpp +++ b/lib/alarm.cpp @@ -1,6 +1,8 @@ #include "alarm.h" +namespace Morat { + void alarm_triggered(int signum){ Alarm::Handler::inst().reset(signum); } @@ -70,3 +72,4 @@ void Alarm::Handler::reset(int signum){ } } +}; // namespace Morat diff --git a/lib/alarm.h b/lib/alarm.h index 6948fda..5f68267 100644 --- a/lib/alarm.h +++ b/lib/alarm.h @@ -24,6 +24,8 @@ * timer.cancel(); */ +namespace Morat { + class Alarm { public: typedef std::function callback_t; @@ -84,3 +86,5 @@ class Alarm { }; friend void alarm_triggered(int); }; + +}; // namespace Morat diff --git a/lib/assert2.h b/lib/assert2.h index dbb2f83..0fdf1c8 100644 --- a/lib/assert2.h +++ b/lib/assert2.h @@ -4,7 +4,17 @@ #include #include + +#if __GNUC__ #define assert2(expr, str) \ ((expr) \ ? __ASSERT_VOID_CAST (0)\ : __assert_fail ((std::string(__STRING(expr)) + "; " + (str)).c_str(), __FILE__, __LINE__, __ASSERT_FUNCTION)) +#elif __clang__ +#define assert2(expr, str) \ + ((expr) \ + ? (void)(0)\ + : __assert_rtn ((std::string(__STRING(expr)) + "; " + (str)).c_str(), __FILE__, __LINE__, __func__)) +#else +#define assert2(expr, str) assert(expr) +#endif diff --git a/lib/bitcount.h b/lib/bitcount.h index c54649f..8a788b8 100644 --- a/lib/bitcount.h +++ b/lib/bitcount.h @@ -1,4 +1,8 @@ +#pragma once + +namespace Morat { + #define TWO(c) (0x1ull << (c)) #define MASK(c) ((~0ull) / (TWO(TWO(c)) + 1ull)) #define COUNT(x,c) ((x) & MASK(c)) + (((x) >> (TWO(c))) & MASK(c)) @@ -59,7 +63,7 @@ static const unsigned char BitsSetTable256[] = { // Both of these are limited to count the bits in the first 5 bytes, which is all that is needed here // That limitation is easy to fix should it be needed inline int precomputed_bitcount(uint64_t n){ - return + return BitsSetTable256[(n >> 0) & 0xff] + BitsSetTable256[(n >> 8) & 0xff] + BitsSetTable256[(n >> 16) & 0xff] + @@ -68,7 +72,7 @@ inline int precomputed_bitcount(uint64_t n){ } inline int precomputed_bitcount2(uint64_t n){ unsigned char * p = (unsigned char *) & n; - return + return BitsSetTable256[p[0]] + BitsSetTable256[p[1]] + BitsSetTable256[p[2]] + @@ -76,6 +80,4 @@ inline int precomputed_bitcount2(uint64_t n){ BitsSetTable256[p[4]]; } - - - +}; // namespace Morat diff --git a/lib/bits.h b/lib/bits.h index 04fcd9a..d3ec7b3 100644 --- a/lib/bits.h +++ b/lib/bits.h @@ -3,19 +3,27 @@ #include +namespace Morat { + #define trailing_zeros(n) __builtin_ctz(n) +// https://code.google.com/p/smhasher/wiki/MurmurHash3 inline uint32_t mix_bits(uint32_t h){ - h ^= (h << 13); - h ^= (h >> 17); - h ^= (h << 5); + h ^= h >> 16; + h *= 0x85ebca6b; + h ^= h >> 13; + h *= 0xc2b2ae35; + h ^= h >> 16; return h; } +// https://code.google.com/p/smhasher/wiki/MurmurHash3 inline uint64_t mix_bits(uint64_t h){ - h ^= (h >> 17); - h ^= (h << 31); - h ^= (h >> 8); + h ^= h >> 33; + h *= 0xff51afd7ed558ccdull; + h ^= h >> 33; + h *= 0xc4ceb9fe1a85ec53ull; + h ^= h >> 33; return h; } @@ -31,3 +39,5 @@ NumType roundup(NumType v) { v++; return v; } + +}; // namespace Morat diff --git a/lib/catch.hpp b/lib/catch.hpp new file mode 100644 index 0000000..7f061ff --- /dev/null +++ b/lib/catch.hpp @@ -0,0 +1,8732 @@ +/* + * CATCH v1.0 build 48 (master branch) + * Generated: 2014-06-02 07:47:30.155371 + * ---------------------------------------------------------- + * This file has been merged from multiple headers. Please don't edit it directly + * Copyright (c) 2012 Two Blue Cubes Ltd. All rights reserved. + * + * Distributed under the Boost Software License, Version 1.0. (See accompanying + * file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) + */ +#ifndef TWOBLUECUBES_SINGLE_INCLUDE_CATCH_HPP_INCLUDED +#define TWOBLUECUBES_SINGLE_INCLUDE_CATCH_HPP_INCLUDED + +#define TWOBLUECUBES_CATCH_HPP_INCLUDED + +// #included from: internal/catch_suppress_warnings.h + +#define TWOBLUECUBES_CATCH_SUPPRESS_WARNINGS_H_INCLUDED + +#ifdef __clang__ +#pragma clang diagnostic ignored "-Wglobal-constructors" +#pragma clang diagnostic ignored "-Wvariadic-macros" +#pragma clang diagnostic ignored "-Wc99-extensions" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpadded" +#pragma clang diagnostic ignored "-Wc++98-compat" +#pragma clang diagnostic ignored "-Wc++98-compat-pedantic" +#endif + +#ifdef CATCH_CONFIG_MAIN +# define CATCH_CONFIG_RUNNER +#endif + +#ifdef CATCH_CONFIG_RUNNER +# ifndef CLARA_CONFIG_MAIN +# define CLARA_CONFIG_MAIN_NOT_DEFINED +# define CLARA_CONFIG_MAIN +# endif +#endif + +// #included from: internal/catch_notimplemented_exception.h +#define TWOBLUECUBES_CATCH_NOTIMPLEMENTED_EXCEPTION_H_INCLUDED + +// #included from: catch_common.h +#define TWOBLUECUBES_CATCH_COMMON_H_INCLUDED + +#define INTERNAL_CATCH_UNIQUE_NAME_LINE2( name, line ) name##line +#define INTERNAL_CATCH_UNIQUE_NAME_LINE( name, line ) INTERNAL_CATCH_UNIQUE_NAME_LINE2( name, line ) +#define INTERNAL_CATCH_UNIQUE_NAME( name ) INTERNAL_CATCH_UNIQUE_NAME_LINE( name, __LINE__ ) + +#define INTERNAL_CATCH_STRINGIFY2( expr ) #expr +#define INTERNAL_CATCH_STRINGIFY( expr ) INTERNAL_CATCH_STRINGIFY2( expr ) + +#include +#include +#include + +// #included from: catch_compiler_capabilities.h +#define TWOBLUECUBES_CATCH_COMPILER_CAPABILITIES_HPP_INCLUDED + +// Much of the following code is based on Boost (1.53) + +#ifdef __clang__ + +# if __has_feature(cxx_nullptr) +# define CATCH_CONFIG_CPP11_NULLPTR +# endif + +# if __has_feature(cxx_noexcept) +# define CATCH_CONFIG_CPP11_NOEXCEPT +# endif + +#endif // __clang__ + +//////////////////////////////////////////////////////////////////////////////// +// Borland +#ifdef __BORLANDC__ + +#if (__BORLANDC__ > 0x582 ) +//#define CATCH_CONFIG_SFINAE // Not confirmed +#endif + +#endif // __BORLANDC__ + +//////////////////////////////////////////////////////////////////////////////// +// EDG +#ifdef __EDG_VERSION__ + +#if (__EDG_VERSION__ > 238 ) +//#define CATCH_CONFIG_SFINAE // Not confirmed +#endif + +#endif // __EDG_VERSION__ + +//////////////////////////////////////////////////////////////////////////////// +// Digital Mars +#ifdef __DMC__ + +#if (__DMC__ > 0x840 ) +//#define CATCH_CONFIG_SFINAE // Not confirmed +#endif + +#endif // __DMC__ + +//////////////////////////////////////////////////////////////////////////////// +// GCC +#ifdef __GNUC__ + +#if __GNUC__ < 3 + +#if (__GNUC_MINOR__ >= 96 ) +//#define CATCH_CONFIG_SFINAE +#endif + +#elif __GNUC__ >= 3 + +// #define CATCH_CONFIG_SFINAE // Taking this out completely for now + +#endif // __GNUC__ < 3 + +#if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6 && defined(__GXX_EXPERIMENTAL_CXX0X__) ) + +#define CATCH_CONFIG_CPP11_NULLPTR +#endif + +#endif // __GNUC__ + +//////////////////////////////////////////////////////////////////////////////// +// Visual C++ +#ifdef _MSC_VER + +#if (_MSC_VER >= 1310 ) // (VC++ 7.0+) +//#define CATCH_CONFIG_SFINAE // Not confirmed +#endif + +#endif // _MSC_VER + +// Use variadic macros if the compiler supports them +#if ( defined _MSC_VER && _MSC_VER > 1400 && !defined __EDGE__) || \ + ( defined __WAVE__ && __WAVE_HAS_VARIADICS ) || \ + ( defined __GNUC__ && __GNUC__ >= 3 ) || \ + ( !defined __cplusplus && __STDC_VERSION__ >= 199901L || __cplusplus >= 201103L ) + +#ifndef CATCH_CONFIG_NO_VARIADIC_MACROS +#define CATCH_CONFIG_VARIADIC_MACROS +#endif + +#endif + +//////////////////////////////////////////////////////////////////////////////// +// C++ language feature support + +// detect language version: +#if (__cplusplus == 201103L) +# define CATCH_CPP11 +# define CATCH_CPP11_OR_GREATER +#elif (__cplusplus >= 201103L) +# define CATCH_CPP11_OR_GREATER +#endif + +// noexcept support: +#if defined(CATCH_CONFIG_CPP11_NOEXCEPT) && !defined(CATCH_NOEXCEPT) +# define CATCH_NOEXCEPT noexcept +# define CATCH_NOEXCEPT_IS(x) noexcept(x) +#else +# define CATCH_NOEXCEPT throw() +# define CATCH_NOEXCEPT_IS(x) +#endif + +namespace Catch { + + class NonCopyable { + NonCopyable( NonCopyable const& ); + void operator = ( NonCopyable const& ); + protected: + NonCopyable() {} + virtual ~NonCopyable(); + }; + + class SafeBool { + public: + typedef void (SafeBool::*type)() const; + + static type makeSafe( bool value ) { + return value ? &SafeBool::trueValue : 0; + } + private: + void trueValue() const {} + }; + + template + inline void deleteAll( ContainerT& container ) { + typename ContainerT::const_iterator it = container.begin(); + typename ContainerT::const_iterator itEnd = container.end(); + for(; it != itEnd; ++it ) + delete *it; + } + template + inline void deleteAllValues( AssociativeContainerT& container ) { + typename AssociativeContainerT::const_iterator it = container.begin(); + typename AssociativeContainerT::const_iterator itEnd = container.end(); + for(; it != itEnd; ++it ) + delete it->second; + } + + bool startsWith( std::string const& s, std::string const& prefix ); + bool endsWith( std::string const& s, std::string const& suffix ); + bool contains( std::string const& s, std::string const& infix ); + void toLowerInPlace( std::string& s ); + std::string toLower( std::string const& s ); + std::string trim( std::string const& str ); + + struct pluralise { + pluralise( std::size_t count, std::string const& label ); + + friend std::ostream& operator << ( std::ostream& os, pluralise const& pluraliser ); + + std::size_t m_count; + std::string m_label; + }; + + struct SourceLineInfo { + + SourceLineInfo(); + SourceLineInfo( char const* _file, std::size_t _line ); + SourceLineInfo( SourceLineInfo const& other ); +# ifdef CATCH_CPP11_OR_GREATER + SourceLineInfo( SourceLineInfo && ) = default; + SourceLineInfo& operator = ( SourceLineInfo const& ) = default; + SourceLineInfo& operator = ( SourceLineInfo && ) = default; +# endif + bool empty() const; + bool operator == ( SourceLineInfo const& other ) const; + + std::string file; + std::size_t line; + }; + + std::ostream& operator << ( std::ostream& os, SourceLineInfo const& info ); + + // This is just here to avoid compiler warnings with macro constants and boolean literals + inline bool isTrue( bool value ){ return value; } + inline bool alwaysTrue() { return true; } + inline bool alwaysFalse() { return false; } + + void throwLogicError( std::string const& message, SourceLineInfo const& locationInfo ); + + // Use this in variadic streaming macros to allow + // >> +StreamEndStop + // as well as + // >> stuff +StreamEndStop + struct StreamEndStop { + std::string operator+() { + return std::string(); + } + }; + template + T const& operator + ( T const& value, StreamEndStop ) { + return value; + } +} + +#define CATCH_INTERNAL_LINEINFO ::Catch::SourceLineInfo( __FILE__, static_cast( __LINE__ ) ) +#define CATCH_INTERNAL_ERROR( msg ) ::Catch::throwLogicError( msg, CATCH_INTERNAL_LINEINFO ); + +#include + +namespace Catch { + + class NotImplementedException : public std::exception + { + public: + NotImplementedException( SourceLineInfo const& lineInfo ); + NotImplementedException( NotImplementedException const& ) {} + + virtual ~NotImplementedException() CATCH_NOEXCEPT {} + + virtual const char* what() const CATCH_NOEXCEPT; + + private: + std::string m_what; + SourceLineInfo m_lineInfo; + }; + +} // end namespace Catch + +/////////////////////////////////////////////////////////////////////////////// +#define CATCH_NOT_IMPLEMENTED throw Catch::NotImplementedException( CATCH_INTERNAL_LINEINFO ) + +// #included from: internal/catch_context.h +#define TWOBLUECUBES_CATCH_CONTEXT_H_INCLUDED + +// #included from: catch_interfaces_generators.h +#define TWOBLUECUBES_CATCH_INTERFACES_GENERATORS_H_INCLUDED + +#include + +namespace Catch { + + struct IGeneratorInfo { + virtual ~IGeneratorInfo(); + virtual bool moveNext() = 0; + virtual std::size_t getCurrentIndex() const = 0; + }; + + struct IGeneratorsForTest { + virtual ~IGeneratorsForTest(); + + virtual IGeneratorInfo& getGeneratorInfo( std::string const& fileInfo, std::size_t size ) = 0; + virtual bool moveNext() = 0; + }; + + IGeneratorsForTest* createGeneratorsForTest(); + +} // end namespace Catch + +// #included from: catch_ptr.hpp +#define TWOBLUECUBES_CATCH_PTR_HPP_INCLUDED + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpadded" +#endif + +namespace Catch { + + // An intrusive reference counting smart pointer. + // T must implement addRef() and release() methods + // typically implementing the IShared interface + template + class Ptr { + public: + Ptr() : m_p( NULL ){} + Ptr( T* p ) : m_p( p ){ + if( m_p ) + m_p->addRef(); + } + Ptr( Ptr const& other ) : m_p( other.m_p ){ + if( m_p ) + m_p->addRef(); + } + ~Ptr(){ + if( m_p ) + m_p->release(); + } + void reset() { + if( m_p ) + m_p->release(); + m_p = NULL; + } + Ptr& operator = ( T* p ){ + Ptr temp( p ); + swap( temp ); + return *this; + } + Ptr& operator = ( Ptr const& other ){ + Ptr temp( other ); + swap( temp ); + return *this; + } + void swap( Ptr& other ) { std::swap( m_p, other.m_p ); } + T* get() { return m_p; } + const T* get() const{ return m_p; } + T& operator*() const { return *m_p; } + T* operator->() const { return m_p; } + bool operator !() const { return m_p == NULL; } + operator SafeBool::type() const { return SafeBool::makeSafe( m_p != NULL ); } + + private: + T* m_p; + }; + + struct IShared : NonCopyable { + virtual ~IShared(); + virtual void addRef() const = 0; + virtual void release() const = 0; + }; + + template + struct SharedImpl : T { + + SharedImpl() : m_rc( 0 ){} + + virtual void addRef() const { + ++m_rc; + } + virtual void release() const { + if( --m_rc == 0 ) + delete this; + } + + mutable unsigned int m_rc; + }; + +} // end namespace Catch + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +#include +#include +#include + +namespace Catch { + + class TestCase; + class Stream; + struct IResultCapture; + struct IRunner; + struct IGeneratorsForTest; + struct IConfig; + + struct IContext + { + virtual ~IContext(); + + virtual IResultCapture* getResultCapture() = 0; + virtual IRunner* getRunner() = 0; + virtual size_t getGeneratorIndex( std::string const& fileInfo, size_t totalSize ) = 0; + virtual bool advanceGeneratorsForCurrentTest() = 0; + virtual Ptr getConfig() const = 0; + }; + + struct IMutableContext : IContext + { + virtual ~IMutableContext(); + virtual void setResultCapture( IResultCapture* resultCapture ) = 0; + virtual void setRunner( IRunner* runner ) = 0; + virtual void setConfig( Ptr const& config ) = 0; + }; + + IContext& getCurrentContext(); + IMutableContext& getCurrentMutableContext(); + void cleanUpContext(); + Stream createStream( std::string const& streamName ); + +} + +// #included from: internal/catch_test_registry.hpp +#define TWOBLUECUBES_CATCH_TEST_REGISTRY_HPP_INCLUDED + +// #included from: catch_interfaces_testcase.h +#define TWOBLUECUBES_CATCH_INTERFACES_TESTCASE_H_INCLUDED + +#include + +namespace Catch { + + class TestSpec; + + struct ITestCase : IShared { + virtual void invoke () const = 0; + protected: + virtual ~ITestCase(); + }; + + class TestCase; + struct IConfig; + + struct ITestCaseRegistry { + virtual ~ITestCaseRegistry(); + virtual std::vector const& getAllTests() const = 0; + virtual void getFilteredTests( TestSpec const& testSpec, IConfig const& config, std::vector& matchingTestCases ) const = 0; + + }; +} + +namespace Catch { + +template +class MethodTestCase : public SharedImpl { + +public: + MethodTestCase( void (C::*method)() ) : m_method( method ) {} + + virtual void invoke() const { + C obj; + (obj.*m_method)(); + } + +private: + virtual ~MethodTestCase() {} + + void (C::*m_method)(); +}; + +typedef void(*TestFunction)(); + +struct NameAndDesc { + NameAndDesc( const char* _name = "", const char* _description= "" ) + : name( _name ), description( _description ) + {} + + const char* name; + const char* description; +}; + +struct AutoReg { + + AutoReg( TestFunction function, + SourceLineInfo const& lineInfo, + NameAndDesc const& nameAndDesc ); + + template + AutoReg( void (C::*method)(), + char const* className, + NameAndDesc const& nameAndDesc, + SourceLineInfo const& lineInfo ) { + registerTestCase( new MethodTestCase( method ), + className, + nameAndDesc, + lineInfo ); + } + + void registerTestCase( ITestCase* testCase, + char const* className, + NameAndDesc const& nameAndDesc, + SourceLineInfo const& lineInfo ); + + ~AutoReg(); + +private: + AutoReg( AutoReg const& ); + void operator= ( AutoReg const& ); +}; + +} // end namespace Catch + +#ifdef CATCH_CONFIG_VARIADIC_MACROS + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_TESTCASE( ... ) \ + static void INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )(); \ + namespace{ Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar )( &INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), CATCH_INTERNAL_LINEINFO, Catch::NameAndDesc( __VA_ARGS__ ) ); }\ + static void INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )() + + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_METHOD_AS_TEST_CASE( QualifiedMethod, ... ) \ + namespace{ Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar )( &QualifiedMethod, "&" #QualifiedMethod, Catch::NameAndDesc( __VA_ARGS__ ), CATCH_INTERNAL_LINEINFO ); } + + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_TEST_CASE_METHOD( ClassName, ... )\ + namespace{ \ + struct INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ) : ClassName{ \ + void test(); \ + }; \ + Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar ) ( &INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )::test, #ClassName, Catch::NameAndDesc( __VA_ARGS__ ), CATCH_INTERNAL_LINEINFO ); \ + } \ + void INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )::test() + +#else + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_TESTCASE( Name, Desc ) \ + static void INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )(); \ + namespace{ Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar )( &INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), CATCH_INTERNAL_LINEINFO, Catch::NameAndDesc( Name, Desc ) ); }\ + static void INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )() + + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_METHOD_AS_TEST_CASE( QualifiedMethod, Name, Desc ) \ + namespace{ Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar )( &QualifiedMethod, "&" #QualifiedMethod, Catch::NameAndDesc( Name, Desc ), CATCH_INTERNAL_LINEINFO ); } + + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_TEST_CASE_METHOD( ClassName, TestName, Desc )\ + namespace{ \ + struct INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ) : ClassName{ \ + void test(); \ + }; \ + Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar ) ( &INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )::test, #ClassName, Catch::NameAndDesc( TestName, Desc ), CATCH_INTERNAL_LINEINFO ); \ + } \ + void INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )::test() + +#endif + +// #included from: internal/catch_capture.hpp +#define TWOBLUECUBES_CATCH_CAPTURE_HPP_INCLUDED + +// #included from: catch_result_builder.h +#define TWOBLUECUBES_CATCH_RESULT_BUILDER_H_INCLUDED + +// #included from: catch_result_type.h +#define TWOBLUECUBES_CATCH_RESULT_TYPE_H_INCLUDED + +namespace Catch { + + // ResultWas::OfType enum + struct ResultWas { enum OfType { + Unknown = -1, + Ok = 0, + Info = 1, + Warning = 2, + + FailureBit = 0x10, + + ExpressionFailed = FailureBit | 1, + ExplicitFailure = FailureBit | 2, + + Exception = 0x100 | FailureBit, + + ThrewException = Exception | 1, + DidntThrowException = Exception | 2 + + }; }; + + inline bool isOk( ResultWas::OfType resultType ) { + return ( resultType & ResultWas::FailureBit ) == 0; + } + inline bool isJustInfo( int flags ) { + return flags == ResultWas::Info; + } + + // ResultDisposition::Flags enum + struct ResultDisposition { enum Flags { + Normal = 0x00, + + ContinueOnFailure = 0x01, // Failures fail test, but execution continues + FalseTest = 0x02, // Prefix expression with ! + SuppressFail = 0x04 // Failures are reported but do not fail the test + }; }; + + inline ResultDisposition::Flags operator | ( ResultDisposition::Flags lhs, ResultDisposition::Flags rhs ) { + return static_cast( static_cast( lhs ) | static_cast( rhs ) ); + } + + inline bool shouldContinueOnFailure( int flags ) { return ( flags & ResultDisposition::ContinueOnFailure ) != 0; } + inline bool isFalseTest( int flags ) { return ( flags & ResultDisposition::FalseTest ) != 0; } + inline bool shouldSuppressFailure( int flags ) { return ( flags & ResultDisposition::SuppressFail ) != 0; } + +} // end namespace Catch + +// #included from: catch_assertionresult.h +#define TWOBLUECUBES_CATCH_ASSERTIONRESULT_H_INCLUDED + +#include + +namespace Catch { + + struct AssertionInfo + { + AssertionInfo() {} + AssertionInfo( std::string const& _macroName, + SourceLineInfo const& _lineInfo, + std::string const& _capturedExpression, + ResultDisposition::Flags _resultDisposition ); + + std::string macroName; + SourceLineInfo lineInfo; + std::string capturedExpression; + ResultDisposition::Flags resultDisposition; + }; + + struct AssertionResultData + { + AssertionResultData() : resultType( ResultWas::Unknown ) {} + + std::string reconstructedExpression; + std::string message; + ResultWas::OfType resultType; + }; + + class AssertionResult { + public: + AssertionResult(); + AssertionResult( AssertionInfo const& info, AssertionResultData const& data ); + ~AssertionResult(); +# ifdef CATCH_CPP11_OR_GREATER + AssertionResult( AssertionResult const& ) = default; + AssertionResult( AssertionResult && ) = default; + AssertionResult& operator = ( AssertionResult const& ) = default; + AssertionResult& operator = ( AssertionResult && ) = default; +# endif + + bool isOk() const; + bool succeeded() const; + ResultWas::OfType getResultType() const; + bool hasExpression() const; + bool hasMessage() const; + std::string getExpression() const; + std::string getExpressionInMacro() const; + bool hasExpandedExpression() const; + std::string getExpandedExpression() const; + std::string getMessage() const; + SourceLineInfo getSourceInfo() const; + std::string getTestMacroName() const; + + protected: + AssertionInfo m_info; + AssertionResultData m_resultData; + }; + +} // end namespace Catch + +namespace Catch { + + struct TestFailureException{}; + + template class ExpressionLhs; + + struct STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison; + + struct CopyableStream { + CopyableStream() {} + CopyableStream( CopyableStream const& other ) { + oss << other.oss.str(); + } + CopyableStream& operator=( CopyableStream const& other ) { + oss.str(""); + oss << other.oss.str(); + return *this; + } + std::ostringstream oss; + }; + + class ResultBuilder { + public: + ResultBuilder( char const* macroName, + SourceLineInfo const& lineInfo, + char const* capturedExpression, + ResultDisposition::Flags resultDisposition ); + + template + ExpressionLhs operator->* ( T const& operand ); + ExpressionLhs operator->* ( bool value ); + + template + ResultBuilder& operator << ( T const& value ) { + m_stream.oss << value; + return *this; + } + + template STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison& operator && ( RhsT const& ); + template STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison& operator || ( RhsT const& ); + + ResultBuilder& setResultType( ResultWas::OfType result ); + ResultBuilder& setResultType( bool result ); + ResultBuilder& setLhs( std::string const& lhs ); + ResultBuilder& setRhs( std::string const& rhs ); + ResultBuilder& setOp( std::string const& op ); + + void endExpression(); + + std::string reconstructExpression() const; + AssertionResult build() const; + + void useActiveException( ResultDisposition::Flags resultDisposition = ResultDisposition::Normal ); + void captureResult( ResultWas::OfType resultType ); + void captureExpression(); + void react(); + bool shouldDebugBreak() const; + bool allowThrows() const; + + private: + AssertionInfo m_assertionInfo; + AssertionResultData m_data; + struct ExprComponents { + ExprComponents() : testFalse( false ) {} + bool testFalse; + std::string lhs, rhs, op; + } m_exprComponents; + CopyableStream m_stream; + + bool m_shouldDebugBreak; + bool m_shouldThrow; + }; + +} // namespace Catch + +// Include after due to circular dependency: +// #included from: catch_expression_lhs.hpp +#define TWOBLUECUBES_CATCH_EXPRESSION_LHS_HPP_INCLUDED + +// #included from: catch_evaluate.hpp +#define TWOBLUECUBES_CATCH_EVALUATE_HPP_INCLUDED + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable:4389) // '==' : signed/unsigned mismatch +#endif + +#include + +namespace Catch { +namespace Internal { + + enum Operator { + IsEqualTo, + IsNotEqualTo, + IsLessThan, + IsGreaterThan, + IsLessThanOrEqualTo, + IsGreaterThanOrEqualTo + }; + + template struct OperatorTraits { static const char* getName(){ return "*error*"; } }; + template<> struct OperatorTraits { static const char* getName(){ return "=="; } }; + template<> struct OperatorTraits { static const char* getName(){ return "!="; } }; + template<> struct OperatorTraits { static const char* getName(){ return "<"; } }; + template<> struct OperatorTraits { static const char* getName(){ return ">"; } }; + template<> struct OperatorTraits { static const char* getName(){ return "<="; } }; + template<> struct OperatorTraits{ static const char* getName(){ return ">="; } }; + + template + inline T& opCast(T const& t) { return const_cast(t); } + +// nullptr_t support based on pull request #154 from Konstantin Baumann +#ifdef CATCH_CONFIG_CPP11_NULLPTR + inline std::nullptr_t opCast(std::nullptr_t) { return nullptr; } +#endif // CATCH_CONFIG_CPP11_NULLPTR + + // So the compare overloads can be operator agnostic we convey the operator as a template + // enum, which is used to specialise an Evaluator for doing the comparison. + template + class Evaluator{}; + + template + struct Evaluator { + static bool evaluate( T1 const& lhs, T2 const& rhs) { + return opCast( lhs ) == opCast( rhs ); + } + }; + template + struct Evaluator { + static bool evaluate( T1 const& lhs, T2 const& rhs ) { + return opCast( lhs ) != opCast( rhs ); + } + }; + template + struct Evaluator { + static bool evaluate( T1 const& lhs, T2 const& rhs ) { + return opCast( lhs ) < opCast( rhs ); + } + }; + template + struct Evaluator { + static bool evaluate( T1 const& lhs, T2 const& rhs ) { + return opCast( lhs ) > opCast( rhs ); + } + }; + template + struct Evaluator { + static bool evaluate( T1 const& lhs, T2 const& rhs ) { + return opCast( lhs ) >= opCast( rhs ); + } + }; + template + struct Evaluator { + static bool evaluate( T1 const& lhs, T2 const& rhs ) { + return opCast( lhs ) <= opCast( rhs ); + } + }; + + template + bool applyEvaluator( T1 const& lhs, T2 const& rhs ) { + return Evaluator::evaluate( lhs, rhs ); + } + + // This level of indirection allows us to specialise for integer types + // to avoid signed/ unsigned warnings + + // "base" overload + template + bool compare( T1 const& lhs, T2 const& rhs ) { + return Evaluator::evaluate( lhs, rhs ); + } + + // unsigned X to int + template bool compare( unsigned int lhs, int rhs ) { + return applyEvaluator( lhs, static_cast( rhs ) ); + } + template bool compare( unsigned long lhs, int rhs ) { + return applyEvaluator( lhs, static_cast( rhs ) ); + } + template bool compare( unsigned char lhs, int rhs ) { + return applyEvaluator( lhs, static_cast( rhs ) ); + } + + // unsigned X to long + template bool compare( unsigned int lhs, long rhs ) { + return applyEvaluator( lhs, static_cast( rhs ) ); + } + template bool compare( unsigned long lhs, long rhs ) { + return applyEvaluator( lhs, static_cast( rhs ) ); + } + template bool compare( unsigned char lhs, long rhs ) { + return applyEvaluator( lhs, static_cast( rhs ) ); + } + + // int to unsigned X + template bool compare( int lhs, unsigned int rhs ) { + return applyEvaluator( static_cast( lhs ), rhs ); + } + template bool compare( int lhs, unsigned long rhs ) { + return applyEvaluator( static_cast( lhs ), rhs ); + } + template bool compare( int lhs, unsigned char rhs ) { + return applyEvaluator( static_cast( lhs ), rhs ); + } + + // long to unsigned X + template bool compare( long lhs, unsigned int rhs ) { + return applyEvaluator( static_cast( lhs ), rhs ); + } + template bool compare( long lhs, unsigned long rhs ) { + return applyEvaluator( static_cast( lhs ), rhs ); + } + template bool compare( long lhs, unsigned char rhs ) { + return applyEvaluator( static_cast( lhs ), rhs ); + } + + // pointer to long (when comparing against NULL) + template bool compare( long lhs, T* rhs ) { + return Evaluator::evaluate( reinterpret_cast( lhs ), rhs ); + } + template bool compare( T* lhs, long rhs ) { + return Evaluator::evaluate( lhs, reinterpret_cast( rhs ) ); + } + + // pointer to int (when comparing against NULL) + template bool compare( int lhs, T* rhs ) { + return Evaluator::evaluate( reinterpret_cast( lhs ), rhs ); + } + template bool compare( T* lhs, int rhs ) { + return Evaluator::evaluate( lhs, reinterpret_cast( rhs ) ); + } + +#ifdef CATCH_CONFIG_CPP11_NULLPTR + // pointer to nullptr_t (when comparing against nullptr) + template bool compare( std::nullptr_t, T* rhs ) { + return Evaluator::evaluate( NULL, rhs ); + } + template bool compare( T* lhs, std::nullptr_t ) { + return Evaluator::evaluate( lhs, NULL ); + } +#endif // CATCH_CONFIG_CPP11_NULLPTR + +} // end of namespace Internal +} // end of namespace Catch + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +// #included from: catch_tostring.h +#define TWOBLUECUBES_CATCH_TOSTRING_H_INCLUDED + +// #included from: catch_sfinae.hpp +#define TWOBLUECUBES_CATCH_SFINAE_HPP_INCLUDED + +// Try to detect if the current compiler supports SFINAE + +namespace Catch { + + struct TrueType { + static const bool value = true; + typedef void Enable; + char sizer[1]; + }; + struct FalseType { + static const bool value = false; + typedef void Disable; + char sizer[2]; + }; + +#ifdef CATCH_CONFIG_SFINAE + + template struct NotABooleanExpression; + + template struct If : NotABooleanExpression {}; + template<> struct If : TrueType {}; + template<> struct If : FalseType {}; + + template struct SizedIf; + template<> struct SizedIf : TrueType {}; + template<> struct SizedIf : FalseType {}; + +#endif // CATCH_CONFIG_SFINAE + +} // end namespace Catch + +#include +#include +#include +#include +#include + +#ifdef __OBJC__ +// #included from: catch_objc_arc.hpp +#define TWOBLUECUBES_CATCH_OBJC_ARC_HPP_INCLUDED + +#import + +#ifdef __has_feature +#define CATCH_ARC_ENABLED __has_feature(objc_arc) +#else +#define CATCH_ARC_ENABLED 0 +#endif + +void arcSafeRelease( NSObject* obj ); +id performOptionalSelector( id obj, SEL sel ); + +#if !CATCH_ARC_ENABLED +inline void arcSafeRelease( NSObject* obj ) { + [obj release]; +} +inline id performOptionalSelector( id obj, SEL sel ) { + if( [obj respondsToSelector: sel] ) + return [obj performSelector: sel]; + return nil; +} +#define CATCH_UNSAFE_UNRETAINED +#define CATCH_ARC_STRONG +#else +inline void arcSafeRelease( NSObject* ){} +inline id performOptionalSelector( id obj, SEL sel ) { +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Warc-performSelector-leaks" +#endif + if( [obj respondsToSelector: sel] ) + return [obj performSelector: sel]; +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + return nil; +} +#define CATCH_UNSAFE_UNRETAINED __unsafe_unretained +#define CATCH_ARC_STRONG __strong +#endif + +#endif + +namespace Catch { +namespace Detail { + +// SFINAE is currently disabled by default for all compilers. +// If the non SFINAE version of IsStreamInsertable is ambiguous for you +// and your compiler supports SFINAE, try #defining CATCH_CONFIG_SFINAE +#ifdef CATCH_CONFIG_SFINAE + + template + class IsStreamInsertableHelper { + template struct TrueIfSizeable : TrueType {}; + + template + static TrueIfSizeable dummy(T2*); + static FalseType dummy(...); + + public: + typedef SizedIf type; + }; + + template + struct IsStreamInsertable : IsStreamInsertableHelper::type {}; + +#else + + struct BorgType { + template BorgType( T const& ); + }; + + TrueType& testStreamable( std::ostream& ); + FalseType testStreamable( FalseType ); + + FalseType operator<<( std::ostream const&, BorgType const& ); + + template + struct IsStreamInsertable { + static std::ostream &s; + static T const&t; + enum { value = sizeof( testStreamable(s << t) ) == sizeof( TrueType ) }; + }; + +#endif + + template + struct StringMakerBase { + template + static std::string convert( T const& ) { return "{?}"; } + }; + + template<> + struct StringMakerBase { + template + static std::string convert( T const& _value ) { + std::ostringstream oss; + oss << _value; + return oss.str(); + } + }; + + std::string rawMemoryToString( const void *object, std::size_t size ); + + template + inline std::string rawMemoryToString( const T& object ) { + return rawMemoryToString( &object, sizeof(object) ); + } + +} // end namespace Detail + +template +std::string toString( T const& value ); + +template +struct StringMaker : + Detail::StringMakerBase::value> {}; + +template +struct StringMaker { + template + static std::string convert( U* p ) { + if( !p ) + return INTERNAL_CATCH_STRINGIFY( NULL ); + else + return Detail::rawMemoryToString( p ); + } +}; + +template +struct StringMaker { + static std::string convert( R C::* p ) { + if( !p ) + return INTERNAL_CATCH_STRINGIFY( NULL ); + else + return Detail::rawMemoryToString( p ); + } +}; + +namespace Detail { + template + std::string rangeToString( InputIterator first, InputIterator last ); +} + +template +struct StringMaker > { + static std::string convert( std::vector const& v ) { + return Detail::rangeToString( v.begin(), v.end() ); + } +}; + +namespace Detail { + template + std::string makeString( T const& value ) { + return StringMaker::convert( value ); + } +} // end namespace Detail + +/// \brief converts any type to a string +/// +/// The default template forwards on to ostringstream - except when an +/// ostringstream overload does not exist - in which case it attempts to detect +/// that and writes {?}. +/// Overload (not specialise) this template for custom typs that you don't want +/// to provide an ostream overload for. +template +std::string toString( T const& value ) { + return StringMaker::convert( value ); +} + +// Built in overloads + +std::string toString( std::string const& value ); +std::string toString( std::wstring const& value ); +std::string toString( const char* const value ); +std::string toString( char* const value ); +std::string toString( int value ); +std::string toString( unsigned long value ); +std::string toString( unsigned int value ); +std::string toString( const double value ); +std::string toString( bool value ); +std::string toString( char value ); +std::string toString( signed char value ); +std::string toString( unsigned char value ); + +#ifdef CATCH_CONFIG_CPP11_NULLPTR +std::string toString( std::nullptr_t ); +#endif + +#ifdef __OBJC__ + std::string toString( NSString const * const& nsstring ); + std::string toString( NSString * CATCH_ARC_STRONG const& nsstring ); + std::string toString( NSObject* const& nsObject ); +#endif + + namespace Detail { + template + std::string rangeToString( InputIterator first, InputIterator last ) { + std::ostringstream oss; + oss << "{ "; + if( first != last ) { + oss << toString( *first ); + for( ++first ; first != last ; ++first ) { + oss << ", " << toString( *first ); + } + } + oss << " }"; + return oss.str(); + } +} + +} // end namespace Catch + +namespace Catch { + +// Wraps the LHS of an expression and captures the operator and RHS (if any) - +// wrapping them all in a ResultBuilder object +template +class ExpressionLhs { + ExpressionLhs& operator = ( ExpressionLhs const& ); +# ifdef CATCH_CPP11_OR_GREATER + ExpressionLhs& operator = ( ExpressionLhs && ) = delete; +# endif + +public: + ExpressionLhs( ResultBuilder& rb, T lhs ) : m_rb( rb ), m_lhs( lhs ) {} +# ifdef CATCH_CPP11_OR_GREATER + ExpressionLhs( ExpressionLhs const& ) = default; + ExpressionLhs( ExpressionLhs && ) = default; +# endif + + template + ResultBuilder& operator == ( RhsT const& rhs ) { + return captureExpression( rhs ); + } + + template + ResultBuilder& operator != ( RhsT const& rhs ) { + return captureExpression( rhs ); + } + + template + ResultBuilder& operator < ( RhsT const& rhs ) { + return captureExpression( rhs ); + } + + template + ResultBuilder& operator > ( RhsT const& rhs ) { + return captureExpression( rhs ); + } + + template + ResultBuilder& operator <= ( RhsT const& rhs ) { + return captureExpression( rhs ); + } + + template + ResultBuilder& operator >= ( RhsT const& rhs ) { + return captureExpression( rhs ); + } + + ResultBuilder& operator == ( bool rhs ) { + return captureExpression( rhs ); + } + + ResultBuilder& operator != ( bool rhs ) { + return captureExpression( rhs ); + } + + void endExpression() { + bool value = m_lhs ? true : false; + m_rb + .setLhs( Catch::toString( value ) ) + .setResultType( value ) + .endExpression(); + } + + // Only simple binary expressions are allowed on the LHS. + // If more complex compositions are required then place the sub expression in parentheses + template STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison& operator + ( RhsT const& ); + template STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison& operator - ( RhsT const& ); + template STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison& operator / ( RhsT const& ); + template STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison& operator * ( RhsT const& ); + template STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison& operator && ( RhsT const& ); + template STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison& operator || ( RhsT const& ); + +private: + template + ResultBuilder& captureExpression( RhsT const& rhs ) { + return m_rb + .setResultType( Internal::compare( m_lhs, rhs ) ) + .setLhs( Catch::toString( m_lhs ) ) + .setRhs( Catch::toString( rhs ) ) + .setOp( Internal::OperatorTraits::getName() ); + } + +private: + ResultBuilder& m_rb; + T m_lhs; +}; + +} // end namespace Catch + + +namespace Catch { + + template + inline ExpressionLhs ResultBuilder::operator->* ( T const& operand ) { + return ExpressionLhs( *this, operand ); + } + + inline ExpressionLhs ResultBuilder::operator->* ( bool value ) { + return ExpressionLhs( *this, value ); + } + +} // namespace Catch + +// #included from: catch_message.h +#define TWOBLUECUBES_CATCH_MESSAGE_H_INCLUDED + +#include + +namespace Catch { + + struct MessageInfo { + MessageInfo( std::string const& _macroName, + SourceLineInfo const& _lineInfo, + ResultWas::OfType _type ); + + std::string macroName; + SourceLineInfo lineInfo; + ResultWas::OfType type; + std::string message; + unsigned int sequence; + + bool operator == ( MessageInfo const& other ) const { + return sequence == other.sequence; + } + bool operator < ( MessageInfo const& other ) const { + return sequence < other.sequence; + } + private: + static unsigned int globalCount; + }; + + struct MessageBuilder { + MessageBuilder( std::string const& macroName, + SourceLineInfo const& lineInfo, + ResultWas::OfType type ) + : m_info( macroName, lineInfo, type ) + {} + + template + MessageBuilder& operator << ( T const& value ) { + m_stream << value; + return *this; + } + + MessageInfo m_info; + std::ostringstream m_stream; + }; + + class ScopedMessage { + public: + ScopedMessage( MessageBuilder const& builder ); + ScopedMessage( ScopedMessage const& other ); + ~ScopedMessage(); + + MessageInfo m_info; + }; + +} // end namespace Catch + +// #included from: catch_interfaces_capture.h +#define TWOBLUECUBES_CATCH_INTERFACES_CAPTURE_H_INCLUDED + +#include + +namespace Catch { + + class TestCase; + class AssertionResult; + struct AssertionInfo; + struct SectionInfo; + struct MessageInfo; + class ScopedMessageBuilder; + struct Counts; + + struct IResultCapture { + + virtual ~IResultCapture(); + + virtual void assertionEnded( AssertionResult const& result ) = 0; + virtual bool sectionStarted( SectionInfo const& sectionInfo, + Counts& assertions ) = 0; + virtual void sectionEnded( SectionInfo const& name, Counts const& assertions, double _durationInSeconds ) = 0; + virtual void pushScopedMessage( MessageInfo const& message ) = 0; + virtual void popScopedMessage( MessageInfo const& message ) = 0; + + virtual std::string getCurrentTestName() const = 0; + virtual const AssertionResult* getLastResult() const = 0; + }; + + IResultCapture& getResultCapture(); +} + +// #included from: catch_debugger.h +#define TWOBLUECUBES_CATCH_DEBUGGER_H_INCLUDED + +// #included from: catch_platform.h +#define TWOBLUECUBES_CATCH_PLATFORM_H_INCLUDED + +#if defined(__MAC_OS_X_VERSION_MIN_REQUIRED) +#define CATCH_PLATFORM_MAC +#elif defined(__IPHONE_OS_VERSION_MIN_REQUIRED) +#define CATCH_PLATFORM_IPHONE +#elif defined(WIN32) || defined(__WIN32__) || defined(_WIN32) || defined(_MSC_VER) +#define CATCH_PLATFORM_WINDOWS +#endif + +#include + +namespace Catch{ + + bool isDebuggerActive(); + void writeToDebugConsole( std::string const& text ); +} + +#ifdef CATCH_PLATFORM_MAC + + // The following code snippet based on: + // http://cocoawithlove.com/2008/03/break-into-debugger.html + #ifdef DEBUG + #if defined(__ppc64__) || defined(__ppc__) + #define CATCH_BREAK_INTO_DEBUGGER() \ + if( Catch::isDebuggerActive() ) { \ + __asm__("li r0, 20\nsc\nnop\nli r0, 37\nli r4, 2\nsc\nnop\n" \ + : : : "memory","r0","r3","r4" ); \ + } + #else + #define CATCH_BREAK_INTO_DEBUGGER() if( Catch::isDebuggerActive() ) {__asm__("int $3\n" : : );} + #endif + #endif + +#elif defined(_MSC_VER) + #define CATCH_BREAK_INTO_DEBUGGER() if( Catch::isDebuggerActive() ) { __debugbreak(); } +#elif defined(__MINGW32__) + extern "C" __declspec(dllimport) void __stdcall DebugBreak(); + #define CATCH_BREAK_INTO_DEBUGGER() if( Catch::isDebuggerActive() ) { DebugBreak(); } +#endif + +#ifndef CATCH_BREAK_INTO_DEBUGGER +#define CATCH_BREAK_INTO_DEBUGGER() Catch::alwaysTrue(); +#endif + +// #included from: catch_interfaces_runner.h +#define TWOBLUECUBES_CATCH_INTERFACES_RUNNER_H_INCLUDED + +namespace Catch { + class TestCase; + + struct IRunner { + virtual ~IRunner(); + virtual bool aborting() const = 0; + }; +} + +/////////////////////////////////////////////////////////////////////////////// +// In the event of a failure works out if the debugger needs to be invoked +// and/or an exception thrown and takes appropriate action. +// This needs to be done as a macro so the debugger will stop in the user +// source code rather than in Catch library code +#define INTERNAL_CATCH_REACT( resultBuilder ) \ + if( resultBuilder.shouldDebugBreak() ) CATCH_BREAK_INTO_DEBUGGER(); \ + resultBuilder.react(); + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_TEST( expr, resultDisposition, macroName ) \ + do { \ + Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, #expr, resultDisposition ); \ + try { \ + ( __catchResult->*expr ).endExpression(); \ + } \ + catch( ... ) { \ + __catchResult.useActiveException( Catch::ResultDisposition::Normal ); \ + } \ + INTERNAL_CATCH_REACT( __catchResult ) \ + } while( Catch::isTrue( false && (expr) ) ) // expr here is never evaluated at runtime but it forces the compiler to give it a look + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_IF( expr, resultDisposition, macroName ) \ + INTERNAL_CATCH_TEST( expr, resultDisposition, macroName ); \ + if( Catch::getResultCapture().getLastResult()->succeeded() ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_ELSE( expr, resultDisposition, macroName ) \ + INTERNAL_CATCH_TEST( expr, resultDisposition, macroName ); \ + if( !Catch::getResultCapture().getLastResult()->succeeded() ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_NO_THROW( expr, resultDisposition, macroName ) \ + do { \ + Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, #expr, resultDisposition ); \ + try { \ + expr; \ + __catchResult.captureResult( Catch::ResultWas::Ok ); \ + } \ + catch( ... ) { \ + __catchResult.useActiveException( resultDisposition ); \ + } \ + INTERNAL_CATCH_REACT( __catchResult ) \ + } while( Catch::alwaysFalse() ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_THROWS( expr, resultDisposition, macroName ) \ + do { \ + Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, #expr, resultDisposition ); \ + try { \ + if( __catchResult.allowThrows() ) \ + expr; \ + __catchResult.captureResult( Catch::ResultWas::DidntThrowException ); \ + } \ + catch( ... ) { \ + __catchResult.captureResult( Catch::ResultWas::Ok ); \ + } \ + INTERNAL_CATCH_REACT( __catchResult ) \ + } while( Catch::alwaysFalse() ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_THROWS_AS( expr, exceptionType, resultDisposition, macroName ) \ + do { \ + Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, #expr, resultDisposition ); \ + try { \ + if( __catchResult.allowThrows() ) \ + expr; \ + __catchResult.captureResult( Catch::ResultWas::DidntThrowException ); \ + } \ + catch( exceptionType ) { \ + __catchResult.captureResult( Catch::ResultWas::Ok ); \ + } \ + catch( ... ) { \ + __catchResult.useActiveException( resultDisposition ); \ + } \ + INTERNAL_CATCH_REACT( __catchResult ) \ + } while( Catch::alwaysFalse() ) + +/////////////////////////////////////////////////////////////////////////////// +#ifdef CATCH_CONFIG_VARIADIC_MACROS + #define INTERNAL_CATCH_MSG( messageType, resultDisposition, macroName, ... ) \ + do { \ + Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, "", resultDisposition ); \ + __catchResult << __VA_ARGS__ + ::Catch::StreamEndStop(); \ + __catchResult.captureResult( messageType ); \ + INTERNAL_CATCH_REACT( __catchResult ) \ + } while( Catch::alwaysFalse() ) +#else + #define INTERNAL_CATCH_MSG( messageType, resultDisposition, macroName, log ) \ + do { \ + Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, "", resultDisposition ); \ + __catchResult << log + ::Catch::StreamEndStop(); \ + __catchResult.captureResult( messageType ); \ + INTERNAL_CATCH_REACT( __catchResult ) \ + } while( Catch::alwaysFalse() ) +#endif + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_INFO( log, macroName ) \ + Catch::ScopedMessage INTERNAL_CATCH_UNIQUE_NAME( scopedMessage ) = Catch::MessageBuilder( macroName, CATCH_INTERNAL_LINEINFO, Catch::ResultWas::Info ) << log; + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CHECK_THAT( arg, matcher, resultDisposition, macroName ) \ + do { \ + Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, #arg " " #matcher, resultDisposition ); \ + try { \ + std::string matcherAsString = ::Catch::Matchers::matcher.toString(); \ + __catchResult \ + .setLhs( Catch::toString( arg ) ) \ + .setRhs( matcherAsString == "{?}" ? #matcher : matcherAsString ) \ + .setOp( "matches" ) \ + .setResultType( ::Catch::Matchers::matcher.match( arg ) ); \ + __catchResult.captureExpression(); \ + } catch( ... ) { \ + __catchResult.useActiveException( resultDisposition | Catch::ResultDisposition::ContinueOnFailure ); \ + } \ + INTERNAL_CATCH_REACT( __catchResult ) \ + } while( Catch::alwaysFalse() ) + +// #included from: internal/catch_section.h +#define TWOBLUECUBES_CATCH_SECTION_H_INCLUDED + +// #included from: catch_section_info.h +#define TWOBLUECUBES_CATCH_SECTION_INFO_H_INCLUDED + +namespace Catch { + + struct SectionInfo { + SectionInfo( std::string const& _name, + std::string const& _description, + SourceLineInfo const& _lineInfo ) + : name( _name ), + description( _description ), + lineInfo( _lineInfo ) + {} + + std::string name; + std::string description; + SourceLineInfo lineInfo; + }; + +} // end namespace Catch + +// #included from: catch_totals.hpp +#define TWOBLUECUBES_CATCH_TOTALS_HPP_INCLUDED + +#include + +namespace Catch { + + struct Counts { + Counts() : passed( 0 ), failed( 0 ) {} + + Counts operator - ( Counts const& other ) const { + Counts diff; + diff.passed = passed - other.passed; + diff.failed = failed - other.failed; + return diff; + } + Counts& operator += ( Counts const& other ) { + passed += other.passed; + failed += other.failed; + return *this; + } + + std::size_t total() const { + return passed + failed; + } + + std::size_t passed; + std::size_t failed; + }; + + struct Totals { + + Totals operator - ( Totals const& other ) const { + Totals diff; + diff.assertions = assertions - other.assertions; + diff.testCases = testCases - other.testCases; + return diff; + } + + Totals delta( Totals const& prevTotals ) const { + Totals diff = *this - prevTotals; + if( diff.assertions.failed > 0 ) + ++diff.testCases.failed; + else + ++diff.testCases.passed; + return diff; + } + + Totals& operator += ( Totals const& other ) { + assertions += other.assertions; + testCases += other.testCases; + return *this; + } + + Counts assertions; + Counts testCases; + }; +} + +// #included from: catch_timer.h +#define TWOBLUECUBES_CATCH_TIMER_H_INCLUDED + +#ifdef CATCH_PLATFORM_WINDOWS +typedef unsigned long long uint64_t; +#else +#include +#endif + +namespace Catch { + + class Timer { + public: + Timer() : m_ticks( 0 ) {} + void start(); + unsigned int getElapsedNanoseconds() const; + unsigned int getElapsedMilliseconds() const; + double getElapsedSeconds() const; + + private: + uint64_t m_ticks; + }; + +} // namespace Catch + +#include + +namespace Catch { + + class Section { + public: + Section( SourceLineInfo const& lineInfo, + std::string const& name, + std::string const& description = "" ); + ~Section(); +# ifdef CATCH_CPP11_OR_GREATER + Section( Section const& ) = default; + Section( Section && ) = default; + Section& operator = ( Section const& ) = default; + Section& operator = ( Section && ) = default; +# endif + + // This indicates whether the section should be executed or not + operator bool(); + + private: + + SectionInfo m_info; + + std::string m_name; + Counts m_assertions; + bool m_sectionIncluded; + Timer m_timer; + }; + +} // end namespace Catch + +#ifdef CATCH_CONFIG_VARIADIC_MACROS + #define INTERNAL_CATCH_SECTION( ... ) \ + if( Catch::Section INTERNAL_CATCH_UNIQUE_NAME( catch_internal_Section ) = Catch::Section( CATCH_INTERNAL_LINEINFO, __VA_ARGS__ ) ) +#else + #define INTERNAL_CATCH_SECTION( name, desc ) \ + if( Catch::Section INTERNAL_CATCH_UNIQUE_NAME( catch_internal_Section ) = Catch::Section( CATCH_INTERNAL_LINEINFO, name, desc ) ) +#endif + +// #included from: internal/catch_generators.hpp +#define TWOBLUECUBES_CATCH_GENERATORS_HPP_INCLUDED + +#include +#include +#include +#include + +namespace Catch { + +template +struct IGenerator { + virtual ~IGenerator() {} + virtual T getValue( std::size_t index ) const = 0; + virtual std::size_t size () const = 0; +}; + +template +class BetweenGenerator : public IGenerator { +public: + BetweenGenerator( T from, T to ) : m_from( from ), m_to( to ){} + + virtual T getValue( std::size_t index ) const { + return m_from+static_cast( index ); + } + + virtual std::size_t size() const { + return static_cast( 1+m_to-m_from ); + } + +private: + + T m_from; + T m_to; +}; + +template +class ValuesGenerator : public IGenerator { +public: + ValuesGenerator(){} + + void add( T value ) { + m_values.push_back( value ); + } + + virtual T getValue( std::size_t index ) const { + return m_values[index]; + } + + virtual std::size_t size() const { + return m_values.size(); + } + +private: + std::vector m_values; +}; + +template +class CompositeGenerator { +public: + CompositeGenerator() : m_totalSize( 0 ) {} + + // *** Move semantics, similar to auto_ptr *** + CompositeGenerator( CompositeGenerator& other ) + : m_fileInfo( other.m_fileInfo ), + m_totalSize( 0 ) + { + move( other ); + } + + CompositeGenerator& setFileInfo( const char* fileInfo ) { + m_fileInfo = fileInfo; + return *this; + } + + ~CompositeGenerator() { + deleteAll( m_composed ); + } + + operator T () const { + size_t overallIndex = getCurrentContext().getGeneratorIndex( m_fileInfo, m_totalSize ); + + typename std::vector*>::const_iterator it = m_composed.begin(); + typename std::vector*>::const_iterator itEnd = m_composed.end(); + for( size_t index = 0; it != itEnd; ++it ) + { + const IGenerator* generator = *it; + if( overallIndex >= index && overallIndex < index + generator->size() ) + { + return generator->getValue( overallIndex-index ); + } + index += generator->size(); + } + CATCH_INTERNAL_ERROR( "Indexed past end of generated range" ); + return T(); // Suppress spurious "not all control paths return a value" warning in Visual Studio - if you know how to fix this please do so + } + + void add( const IGenerator* generator ) { + m_totalSize += generator->size(); + m_composed.push_back( generator ); + } + + CompositeGenerator& then( CompositeGenerator& other ) { + move( other ); + return *this; + } + + CompositeGenerator& then( T value ) { + ValuesGenerator* valuesGen = new ValuesGenerator(); + valuesGen->add( value ); + add( valuesGen ); + return *this; + } + +private: + + void move( CompositeGenerator& other ) { + std::copy( other.m_composed.begin(), other.m_composed.end(), std::back_inserter( m_composed ) ); + m_totalSize += other.m_totalSize; + other.m_composed.clear(); + } + + std::vector*> m_composed; + std::string m_fileInfo; + size_t m_totalSize; +}; + +namespace Generators +{ + template + CompositeGenerator between( T from, T to ) { + CompositeGenerator generators; + generators.add( new BetweenGenerator( from, to ) ); + return generators; + } + + template + CompositeGenerator values( T val1, T val2 ) { + CompositeGenerator generators; + ValuesGenerator* valuesGen = new ValuesGenerator(); + valuesGen->add( val1 ); + valuesGen->add( val2 ); + generators.add( valuesGen ); + return generators; + } + + template + CompositeGenerator values( T val1, T val2, T val3 ){ + CompositeGenerator generators; + ValuesGenerator* valuesGen = new ValuesGenerator(); + valuesGen->add( val1 ); + valuesGen->add( val2 ); + valuesGen->add( val3 ); + generators.add( valuesGen ); + return generators; + } + + template + CompositeGenerator values( T val1, T val2, T val3, T val4 ) { + CompositeGenerator generators; + ValuesGenerator* valuesGen = new ValuesGenerator(); + valuesGen->add( val1 ); + valuesGen->add( val2 ); + valuesGen->add( val3 ); + valuesGen->add( val4 ); + generators.add( valuesGen ); + return generators; + } + +} // end namespace Generators + +using namespace Generators; + +} // end namespace Catch + +#define INTERNAL_CATCH_LINESTR2( line ) #line +#define INTERNAL_CATCH_LINESTR( line ) INTERNAL_CATCH_LINESTR2( line ) + +#define INTERNAL_CATCH_GENERATE( expr ) expr.setFileInfo( __FILE__ "(" INTERNAL_CATCH_LINESTR( __LINE__ ) ")" ) + +// #included from: internal/catch_interfaces_exception.h +#define TWOBLUECUBES_CATCH_INTERFACES_EXCEPTION_H_INCLUDED + +#include +// #included from: catch_interfaces_registry_hub.h +#define TWOBLUECUBES_CATCH_INTERFACES_REGISTRY_HUB_H_INCLUDED + +#include + +namespace Catch { + + class TestCase; + struct ITestCaseRegistry; + struct IExceptionTranslatorRegistry; + struct IExceptionTranslator; + struct IReporterRegistry; + struct IReporterFactory; + + struct IRegistryHub { + virtual ~IRegistryHub(); + + virtual IReporterRegistry const& getReporterRegistry() const = 0; + virtual ITestCaseRegistry const& getTestCaseRegistry() const = 0; + virtual IExceptionTranslatorRegistry& getExceptionTranslatorRegistry() = 0; + }; + + struct IMutableRegistryHub { + virtual ~IMutableRegistryHub(); + virtual void registerReporter( std::string const& name, IReporterFactory* factory ) = 0; + virtual void registerTest( TestCase const& testInfo ) = 0; + virtual void registerTranslator( const IExceptionTranslator* translator ) = 0; + }; + + IRegistryHub& getRegistryHub(); + IMutableRegistryHub& getMutableRegistryHub(); + void cleanUp(); + std::string translateActiveException(); + +} + + +namespace Catch { + + typedef std::string(*exceptionTranslateFunction)(); + + struct IExceptionTranslator { + virtual ~IExceptionTranslator(); + virtual std::string translate() const = 0; + }; + + struct IExceptionTranslatorRegistry { + virtual ~IExceptionTranslatorRegistry(); + + virtual std::string translateActiveException() const = 0; + }; + + class ExceptionTranslatorRegistrar { + template + class ExceptionTranslator : public IExceptionTranslator { + public: + + ExceptionTranslator( std::string(*translateFunction)( T& ) ) + : m_translateFunction( translateFunction ) + {} + + virtual std::string translate() const { + try { + throw; + } + catch( T& ex ) { + return m_translateFunction( ex ); + } + } + + protected: + std::string(*m_translateFunction)( T& ); + }; + + public: + template + ExceptionTranslatorRegistrar( std::string(*translateFunction)( T& ) ) { + getMutableRegistryHub().registerTranslator + ( new ExceptionTranslator( translateFunction ) ); + } + }; +} + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_TRANSLATE_EXCEPTION( signature ) \ + static std::string INTERNAL_CATCH_UNIQUE_NAME( catch_internal_ExceptionTranslator )( signature ); \ + namespace{ Catch::ExceptionTranslatorRegistrar INTERNAL_CATCH_UNIQUE_NAME( catch_internal_ExceptionRegistrar )( &INTERNAL_CATCH_UNIQUE_NAME( catch_internal_ExceptionTranslator ) ); }\ + static std::string INTERNAL_CATCH_UNIQUE_NAME( catch_internal_ExceptionTranslator )( signature ) + +// #included from: internal/catch_approx.hpp +#define TWOBLUECUBES_CATCH_APPROX_HPP_INCLUDED + +#include +#include + +namespace Catch { +namespace Detail { + + class Approx { + public: + explicit Approx ( double value ) + : m_epsilon( std::numeric_limits::epsilon()*100 ), + m_scale( 1.0 ), + m_value( value ) + {} + + Approx( Approx const& other ) + : m_epsilon( other.m_epsilon ), + m_scale( other.m_scale ), + m_value( other.m_value ) + {} + + static Approx custom() { + return Approx( 0 ); + } + + Approx operator()( double value ) { + Approx approx( value ); + approx.epsilon( m_epsilon ); + approx.scale( m_scale ); + return approx; + } + + friend bool operator == ( double lhs, Approx const& rhs ) { + // Thanks to Richard Harris for his help refining this formula + return fabs( lhs - rhs.m_value ) < rhs.m_epsilon * (rhs.m_scale + (std::max)( fabs(lhs), fabs(rhs.m_value) ) ); + } + + friend bool operator == ( Approx const& lhs, double rhs ) { + return operator==( rhs, lhs ); + } + + friend bool operator != ( double lhs, Approx const& rhs ) { + return !operator==( lhs, rhs ); + } + + friend bool operator != ( Approx const& lhs, double rhs ) { + return !operator==( rhs, lhs ); + } + + Approx& epsilon( double newEpsilon ) { + m_epsilon = newEpsilon; + return *this; + } + + Approx& scale( double newScale ) { + m_scale = newScale; + return *this; + } + + std::string toString() const { + std::ostringstream oss; + oss << "Approx( " << Catch::toString( m_value ) << " )"; + return oss.str(); + } + + private: + double m_epsilon; + double m_scale; + double m_value; + }; +} + +template<> +inline std::string toString( Detail::Approx const& value ) { + return value.toString(); +} + +} // end namespace Catch + +// #included from: internal/catch_matchers.hpp +#define TWOBLUECUBES_CATCH_MATCHERS_HPP_INCLUDED + +namespace Catch { +namespace Matchers { + namespace Impl { + + template + struct Matcher : SharedImpl + { + typedef ExpressionT ExpressionType; + + virtual ~Matcher() {} + virtual Ptr clone() const = 0; + virtual bool match( ExpressionT const& expr ) const = 0; + virtual std::string toString() const = 0; + }; + + template + struct MatcherImpl : Matcher { + + virtual Ptr > clone() const { + return Ptr >( new DerivedT( static_cast( *this ) ) ); + } + }; + + namespace Generic { + + template + class AllOf : public MatcherImpl, ExpressionT> { + public: + + AllOf() {} + AllOf( AllOf const& other ) : m_matchers( other.m_matchers ) {} + + AllOf& add( Matcher const& matcher ) { + m_matchers.push_back( matcher.clone() ); + return *this; + } + virtual bool match( ExpressionT const& expr ) const + { + for( std::size_t i = 0; i < m_matchers.size(); ++i ) + if( !m_matchers[i]->match( expr ) ) + return false; + return true; + } + virtual std::string toString() const { + std::ostringstream oss; + oss << "( "; + for( std::size_t i = 0; i < m_matchers.size(); ++i ) { + if( i != 0 ) + oss << " and "; + oss << m_matchers[i]->toString(); + } + oss << " )"; + return oss.str(); + } + + private: + std::vector > > m_matchers; + }; + + template + class AnyOf : public MatcherImpl, ExpressionT> { + public: + + AnyOf() {} + AnyOf( AnyOf const& other ) : m_matchers( other.m_matchers ) {} + + AnyOf& add( Matcher const& matcher ) { + m_matchers.push_back( matcher.clone() ); + return *this; + } + virtual bool match( ExpressionT const& expr ) const + { + for( std::size_t i = 0; i < m_matchers.size(); ++i ) + if( m_matchers[i]->match( expr ) ) + return true; + return false; + } + virtual std::string toString() const { + std::ostringstream oss; + oss << "( "; + for( std::size_t i = 0; i < m_matchers.size(); ++i ) { + if( i != 0 ) + oss << " or "; + oss << m_matchers[i]->toString(); + } + oss << " )"; + return oss.str(); + } + + private: + std::vector > > m_matchers; + }; + + } + + namespace StdString { + + inline std::string makeString( std::string const& str ) { return str; } + inline std::string makeString( const char* str ) { return str ? std::string( str ) : std::string(); } + + struct Equals : MatcherImpl { + Equals( std::string const& str ) : m_str( str ){} + Equals( Equals const& other ) : m_str( other.m_str ){} + + virtual ~Equals(); + + virtual bool match( std::string const& expr ) const { + return m_str == expr; + } + virtual std::string toString() const { + return "equals: \"" + m_str + "\""; + } + + std::string m_str; + }; + + struct Contains : MatcherImpl { + Contains( std::string const& substr ) : m_substr( substr ){} + Contains( Contains const& other ) : m_substr( other.m_substr ){} + + virtual ~Contains(); + + virtual bool match( std::string const& expr ) const { + return expr.find( m_substr ) != std::string::npos; + } + virtual std::string toString() const { + return "contains: \"" + m_substr + "\""; + } + + std::string m_substr; + }; + + struct StartsWith : MatcherImpl { + StartsWith( std::string const& substr ) : m_substr( substr ){} + StartsWith( StartsWith const& other ) : m_substr( other.m_substr ){} + + virtual ~StartsWith(); + + virtual bool match( std::string const& expr ) const { + return expr.find( m_substr ) == 0; + } + virtual std::string toString() const { + return "starts with: \"" + m_substr + "\""; + } + + std::string m_substr; + }; + + struct EndsWith : MatcherImpl { + EndsWith( std::string const& substr ) : m_substr( substr ){} + EndsWith( EndsWith const& other ) : m_substr( other.m_substr ){} + + virtual ~EndsWith(); + + virtual bool match( std::string const& expr ) const { + return expr.find( m_substr ) == expr.size() - m_substr.size(); + } + virtual std::string toString() const { + return "ends with: \"" + m_substr + "\""; + } + + std::string m_substr; + }; + } // namespace StdString + } // namespace Impl + + // The following functions create the actual matcher objects. + // This allows the types to be inferred + template + inline Impl::Generic::AllOf AllOf( Impl::Matcher const& m1, + Impl::Matcher const& m2 ) { + return Impl::Generic::AllOf().add( m1 ).add( m2 ); + } + template + inline Impl::Generic::AllOf AllOf( Impl::Matcher const& m1, + Impl::Matcher const& m2, + Impl::Matcher const& m3 ) { + return Impl::Generic::AllOf().add( m1 ).add( m2 ).add( m3 ); + } + template + inline Impl::Generic::AnyOf AnyOf( Impl::Matcher const& m1, + Impl::Matcher const& m2 ) { + return Impl::Generic::AnyOf().add( m1 ).add( m2 ); + } + template + inline Impl::Generic::AnyOf AnyOf( Impl::Matcher const& m1, + Impl::Matcher const& m2, + Impl::Matcher const& m3 ) { + return Impl::Generic::AnyOf().add( m1 ).add( m2 ).add( m3 ); + } + + inline Impl::StdString::Equals Equals( std::string const& str ) { + return Impl::StdString::Equals( str ); + } + inline Impl::StdString::Equals Equals( const char* str ) { + return Impl::StdString::Equals( Impl::StdString::makeString( str ) ); + } + inline Impl::StdString::Contains Contains( std::string const& substr ) { + return Impl::StdString::Contains( substr ); + } + inline Impl::StdString::Contains Contains( const char* substr ) { + return Impl::StdString::Contains( Impl::StdString::makeString( substr ) ); + } + inline Impl::StdString::StartsWith StartsWith( std::string const& substr ) { + return Impl::StdString::StartsWith( substr ); + } + inline Impl::StdString::StartsWith StartsWith( const char* substr ) { + return Impl::StdString::StartsWith( Impl::StdString::makeString( substr ) ); + } + inline Impl::StdString::EndsWith EndsWith( std::string const& substr ) { + return Impl::StdString::EndsWith( substr ); + } + inline Impl::StdString::EndsWith EndsWith( const char* substr ) { + return Impl::StdString::EndsWith( Impl::StdString::makeString( substr ) ); + } + +} // namespace Matchers + +using namespace Matchers; + +} // namespace Catch + +// These files are included here so the single_include script doesn't put them +// in the conditionally compiled sections +// #included from: internal/catch_test_case_info.h +#define TWOBLUECUBES_CATCH_TEST_CASE_INFO_H_INCLUDED + +#include +#include + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpadded" +#endif + +namespace Catch { + + struct ITestCase; + + struct TestCaseInfo { + TestCaseInfo( std::string const& _name, + std::string const& _className, + std::string const& _description, + std::set const& _tags, + bool _isHidden, + SourceLineInfo const& _lineInfo ); + + TestCaseInfo( TestCaseInfo const& other ); + + std::string name; + std::string className; + std::string description; + std::set tags; + std::set lcaseTags; + std::string tagsAsString; + SourceLineInfo lineInfo; + bool isHidden; + bool throws; + }; + + class TestCase : public TestCaseInfo { + public: + + TestCase( ITestCase* testCase, TestCaseInfo const& info ); + TestCase( TestCase const& other ); + + TestCase withName( std::string const& _newName ) const; + + void invoke() const; + + TestCaseInfo const& getTestCaseInfo() const; + + bool isHidden() const; + bool throws() const; + + void swap( TestCase& other ); + bool operator == ( TestCase const& other ) const; + bool operator < ( TestCase const& other ) const; + TestCase& operator = ( TestCase const& other ); + + private: + Ptr test; + }; + + TestCase makeTestCase( ITestCase* testCase, + std::string const& className, + std::string const& name, + std::string const& description, + SourceLineInfo const& lineInfo ); +} + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + + +#ifdef __OBJC__ +// #included from: internal/catch_objc.hpp +#define TWOBLUECUBES_CATCH_OBJC_HPP_INCLUDED + +#import + +#include + +// NB. Any general catch headers included here must be included +// in catch.hpp first to make sure they are included by the single +// header for non obj-usage + +/////////////////////////////////////////////////////////////////////////////// +// This protocol is really only here for (self) documenting purposes, since +// all its methods are optional. +@protocol OcFixture + +@optional + +-(void) setUp; +-(void) tearDown; + +@end + +namespace Catch { + + class OcMethod : public SharedImpl { + + public: + OcMethod( Class cls, SEL sel ) : m_cls( cls ), m_sel( sel ) {} + + virtual void invoke() const { + id obj = [[m_cls alloc] init]; + + performOptionalSelector( obj, @selector(setUp) ); + performOptionalSelector( obj, m_sel ); + performOptionalSelector( obj, @selector(tearDown) ); + + arcSafeRelease( obj ); + } + private: + virtual ~OcMethod() {} + + Class m_cls; + SEL m_sel; + }; + + namespace Detail{ + + inline std::string getAnnotation( Class cls, + std::string const& annotationName, + std::string const& testCaseName ) { + NSString* selStr = [[NSString alloc] initWithFormat:@"Catch_%s_%s", annotationName.c_str(), testCaseName.c_str()]; + SEL sel = NSSelectorFromString( selStr ); + arcSafeRelease( selStr ); + id value = performOptionalSelector( cls, sel ); + if( value ) + return [(NSString*)value UTF8String]; + return ""; + } + } + + inline size_t registerTestMethods() { + size_t noTestMethods = 0; + int noClasses = objc_getClassList( NULL, 0 ); + + Class* classes = (CATCH_UNSAFE_UNRETAINED Class *)malloc( sizeof(Class) * noClasses); + objc_getClassList( classes, noClasses ); + + for( int c = 0; c < noClasses; c++ ) { + Class cls = classes[c]; + { + u_int count; + Method* methods = class_copyMethodList( cls, &count ); + for( u_int m = 0; m < count ; m++ ) { + SEL selector = method_getName(methods[m]); + std::string methodName = sel_getName(selector); + if( startsWith( methodName, "Catch_TestCase_" ) ) { + std::string testCaseName = methodName.substr( 15 ); + std::string name = Detail::getAnnotation( cls, "Name", testCaseName ); + std::string desc = Detail::getAnnotation( cls, "Description", testCaseName ); + const char* className = class_getName( cls ); + + getMutableRegistryHub().registerTest( makeTestCase( new OcMethod( cls, selector ), className, name.c_str(), desc.c_str(), SourceLineInfo() ) ); + noTestMethods++; + } + } + free(methods); + } + } + return noTestMethods; + } + + namespace Matchers { + namespace Impl { + namespace NSStringMatchers { + + template + struct StringHolder : MatcherImpl{ + StringHolder( NSString* substr ) : m_substr( [substr copy] ){} + StringHolder( StringHolder const& other ) : m_substr( [other.m_substr copy] ){} + StringHolder() { + arcSafeRelease( m_substr ); + } + + NSString* m_substr; + }; + + struct Equals : StringHolder { + Equals( NSString* substr ) : StringHolder( substr ){} + + virtual bool match( ExpressionType const& str ) const { + return (str != nil || m_substr == nil ) && + [str isEqualToString:m_substr]; + } + + virtual std::string toString() const { + return "equals string: \"" + Catch::toString( m_substr ) + "\""; + } + }; + + struct Contains : StringHolder { + Contains( NSString* substr ) : StringHolder( substr ){} + + virtual bool match( ExpressionType const& str ) const { + return (str != nil || m_substr == nil ) && + [str rangeOfString:m_substr].location != NSNotFound; + } + + virtual std::string toString() const { + return "contains string: \"" + Catch::toString( m_substr ) + "\""; + } + }; + + struct StartsWith : StringHolder { + StartsWith( NSString* substr ) : StringHolder( substr ){} + + virtual bool match( ExpressionType const& str ) const { + return (str != nil || m_substr == nil ) && + [str rangeOfString:m_substr].location == 0; + } + + virtual std::string toString() const { + return "starts with: \"" + Catch::toString( m_substr ) + "\""; + } + }; + struct EndsWith : StringHolder { + EndsWith( NSString* substr ) : StringHolder( substr ){} + + virtual bool match( ExpressionType const& str ) const { + return (str != nil || m_substr == nil ) && + [str rangeOfString:m_substr].location == [str length] - [m_substr length]; + } + + virtual std::string toString() const { + return "ends with: \"" + Catch::toString( m_substr ) + "\""; + } + }; + + } // namespace NSStringMatchers + } // namespace Impl + + inline Impl::NSStringMatchers::Equals + Equals( NSString* substr ){ return Impl::NSStringMatchers::Equals( substr ); } + + inline Impl::NSStringMatchers::Contains + Contains( NSString* substr ){ return Impl::NSStringMatchers::Contains( substr ); } + + inline Impl::NSStringMatchers::StartsWith + StartsWith( NSString* substr ){ return Impl::NSStringMatchers::StartsWith( substr ); } + + inline Impl::NSStringMatchers::EndsWith + EndsWith( NSString* substr ){ return Impl::NSStringMatchers::EndsWith( substr ); } + + } // namespace Matchers + + using namespace Matchers; + +} // namespace Catch + +/////////////////////////////////////////////////////////////////////////////// +#define OC_TEST_CASE( name, desc )\ ++(NSString*) INTERNAL_CATCH_UNIQUE_NAME( Catch_Name_test ) \ +{\ +return @ name; \ +}\ ++(NSString*) INTERNAL_CATCH_UNIQUE_NAME( Catch_Description_test ) \ +{ \ +return @ desc; \ +} \ +-(void) INTERNAL_CATCH_UNIQUE_NAME( Catch_TestCase_test ) + +#endif + +#ifdef CATCH_CONFIG_RUNNER +// #included from: internal/catch_impl.hpp +#define TWOBLUECUBES_CATCH_IMPL_HPP_INCLUDED + +// Collect all the implementation files together here +// These are the equivalent of what would usually be cpp files + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wweak-vtables" +#endif + +// #included from: catch_runner.hpp +#define TWOBLUECUBES_CATCH_RUNNER_HPP_INCLUDED + +// #included from: internal/catch_commandline.hpp +#define TWOBLUECUBES_CATCH_COMMANDLINE_HPP_INCLUDED + +// #included from: catch_config.hpp +#define TWOBLUECUBES_CATCH_CONFIG_HPP_INCLUDED + +// #included from: catch_test_spec_parser.hpp +#define TWOBLUECUBES_CATCH_TEST_SPEC_PARSER_HPP_INCLUDED + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpadded" +#endif + +// #included from: catch_test_spec.hpp +#define TWOBLUECUBES_CATCH_TEST_SPEC_HPP_INCLUDED + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpadded" +#endif + +#include +#include + +namespace Catch { + + class TestSpec { + struct Pattern : SharedImpl<> { + virtual ~Pattern(); + virtual bool matches( TestCaseInfo const& testCase ) const = 0; + }; + class NamePattern : public Pattern { + enum WildcardPosition { + NoWildcard = 0, + WildcardAtStart = 1, + WildcardAtEnd = 2, + WildcardAtBothEnds = WildcardAtStart | WildcardAtEnd + }; + + public: + NamePattern( std::string const& name ) : m_name( toLower( name ) ), m_wildcard( NoWildcard ) { + if( startsWith( m_name, "*" ) ) { + m_name = m_name.substr( 1 ); + m_wildcard = WildcardAtStart; + } + if( endsWith( m_name, "*" ) ) { + m_name = m_name.substr( 0, m_name.size()-1 ); + m_wildcard = (WildcardPosition)( m_wildcard | WildcardAtEnd ); + } + } + virtual ~NamePattern(); + virtual bool matches( TestCaseInfo const& testCase ) const { + switch( m_wildcard ) { + case NoWildcard: + return m_name == toLower( testCase.name ); + case WildcardAtStart: + return endsWith( toLower( testCase.name ), m_name ); + case WildcardAtEnd: + return startsWith( toLower( testCase.name ), m_name ); + case WildcardAtBothEnds: + return contains( toLower( testCase.name ), m_name ); + } + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunreachable-code" +#endif + throw std::logic_error( "Unknown enum" ); +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + } + private: + std::string m_name; + WildcardPosition m_wildcard; + }; + class TagPattern : public Pattern { + public: + TagPattern( std::string const& tag ) : m_tag( toLower( tag ) ) {} + virtual ~TagPattern(); + virtual bool matches( TestCaseInfo const& testCase ) const { + return testCase.lcaseTags.find( m_tag ) != testCase.lcaseTags.end(); + } + private: + std::string m_tag; + }; + class ExcludedPattern : public Pattern { + public: + ExcludedPattern( Ptr const& underlyingPattern ) : m_underlyingPattern( underlyingPattern ) {} + virtual ~ExcludedPattern(); + virtual bool matches( TestCaseInfo const& testCase ) const { return !m_underlyingPattern->matches( testCase ); } + private: + Ptr m_underlyingPattern; + }; + + struct Filter { + std::vector > m_patterns; + + bool matches( TestCaseInfo const& testCase ) const { + // All patterns in a filter must match for the filter to be a match + for( std::vector >::const_iterator it = m_patterns.begin(), itEnd = m_patterns.end(); it != itEnd; ++it ) + if( !(*it)->matches( testCase ) ) + return false; + return true; + } + }; + + public: + bool hasFilters() const { + return !m_filters.empty(); + } + bool matches( TestCaseInfo const& testCase ) const { + // A TestSpec matches if any filter matches + for( std::vector::const_iterator it = m_filters.begin(), itEnd = m_filters.end(); it != itEnd; ++it ) + if( it->matches( testCase ) ) + return true; + return false; + } + + private: + std::vector m_filters; + + friend class TestSpecParser; + }; +} + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +namespace Catch { + + class TestSpecParser { + enum Mode{ None, Name, QuotedName, Tag }; + Mode m_mode; + bool m_exclusion; + std::size_t m_start, m_pos; + std::string m_arg; + TestSpec::Filter m_currentFilter; + TestSpec m_testSpec; + + public: + TestSpecParser parse( std::string const& arg ) { + m_mode = None; + m_exclusion = false; + m_start = std::string::npos; + m_arg = arg; + for( m_pos = 0; m_pos < m_arg.size(); ++m_pos ) + visitChar( m_arg[m_pos] ); + if( m_mode == Name ) + addPattern(); + return *this; + } + TestSpec testSpec() { + addFilter(); + return m_testSpec; + } + private: + void visitChar( char c ) { + if( m_mode == None ) { + switch( c ) { + case ' ': return; + case '~': m_exclusion = true; return; + case '[': return startNewMode( Tag, ++m_pos ); + case '"': return startNewMode( QuotedName, ++m_pos ); + default: startNewMode( Name, m_pos ); break; + } + } + if( m_mode == Name ) { + if( c == ',' ) { + addPattern(); + addFilter(); + } + else if( c == '[' ) { + if( subString() == "exclude:" ) + m_exclusion = true; + else + addPattern(); + startNewMode( Tag, ++m_pos ); + } + } + else if( m_mode == QuotedName && c == '"' ) + addPattern(); + else if( m_mode == Tag && c == ']' ) + addPattern(); + } + void startNewMode( Mode mode, std::size_t start ) { + m_mode = mode; + m_start = start; + } + std::string subString() const { return m_arg.substr( m_start, m_pos - m_start ); } + template + void addPattern() { + std::string token = subString(); + if( startsWith( token, "exclude:" ) ) { + m_exclusion = true; + token = token.substr( 8 ); + } + if( !token.empty() ) { + Ptr pattern = new T( token ); + if( m_exclusion ) + pattern = new TestSpec::ExcludedPattern( pattern ); + m_currentFilter.m_patterns.push_back( pattern ); + } + m_exclusion = false; + m_mode = None; + } + void addFilter() { + if( !m_currentFilter.m_patterns.empty() ) { + m_testSpec.m_filters.push_back( m_currentFilter ); + m_currentFilter = TestSpec::Filter(); + } + } + }; + inline TestSpec parseTestSpec( std::string const& arg ) { + return TestSpecParser().parse( arg ).testSpec(); + } + +} // namespace Catch + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +// #included from: catch_interfaces_config.h +#define TWOBLUECUBES_CATCH_INTERFACES_CONFIG_H_INCLUDED + +#include +#include +#include + +namespace Catch { + + struct Verbosity { enum Level { + NoOutput = 0, + Quiet, + Normal + }; }; + + struct WarnAbout { enum What { + Nothing = 0x00, + NoAssertions = 0x01 + }; }; + + struct ShowDurations { enum OrNot { + DefaultForReporter, + Always, + Never + }; }; + + class TestSpec; + + struct IConfig : IShared { + + virtual ~IConfig(); + + virtual bool allowThrows() const = 0; + virtual std::ostream& stream() const = 0; + virtual std::string name() const = 0; + virtual bool includeSuccessfulResults() const = 0; + virtual bool shouldDebugBreak() const = 0; + virtual bool warnAboutMissingAssertions() const = 0; + virtual int abortAfter() const = 0; + virtual bool showInvisibles() const = 0; + virtual ShowDurations::OrNot showDurations() const = 0; + virtual TestSpec const& testSpec() const = 0; + }; +} + +// #included from: catch_stream.h +#define TWOBLUECUBES_CATCH_STREAM_H_INCLUDED + +#include + +#ifdef __clang__ +#pragma clang diagnostic ignored "-Wpadded" +#endif + +namespace Catch { + + class Stream { + public: + Stream(); + Stream( std::streambuf* _streamBuf, bool _isOwned ); + void release(); + + std::streambuf* streamBuf; + + private: + bool isOwned; + }; +} + +#include +#include +#include +#include + +#ifndef CATCH_CONFIG_CONSOLE_WIDTH +#define CATCH_CONFIG_CONSOLE_WIDTH 80 +#endif + +namespace Catch { + + struct ConfigData { + + ConfigData() + : listTests( false ), + listTags( false ), + listReporters( false ), + listTestNamesOnly( false ), + showSuccessfulTests( false ), + shouldDebugBreak( false ), + noThrow( false ), + showHelp( false ), + showInvisibles( false ), + abortAfter( -1 ), + verbosity( Verbosity::Normal ), + warnings( WarnAbout::Nothing ), + showDurations( ShowDurations::DefaultForReporter ) + {} + + bool listTests; + bool listTags; + bool listReporters; + bool listTestNamesOnly; + + bool showSuccessfulTests; + bool shouldDebugBreak; + bool noThrow; + bool showHelp; + bool showInvisibles; + + int abortAfter; + + Verbosity::Level verbosity; + WarnAbout::What warnings; + ShowDurations::OrNot showDurations; + + std::string reporterName; + std::string outputFilename; + std::string name; + std::string processName; + + std::vector testsOrTags; + }; + + class Config : public SharedImpl { + private: + Config( Config const& other ); + Config& operator = ( Config const& other ); + virtual void dummy(); + public: + + Config() + : m_os( std::cout.rdbuf() ) + {} + + Config( ConfigData const& data ) + : m_data( data ), + m_os( std::cout.rdbuf() ) + { + if( !data.testsOrTags.empty() ) { + TestSpecParser parser; + for( std::size_t i = 0; i < data.testsOrTags.size(); ++i ) + parser.parse( data.testsOrTags[i] ); + m_testSpec = parser.testSpec(); + } + } + + virtual ~Config() { + m_os.rdbuf( std::cout.rdbuf() ); + m_stream.release(); + } + + void setFilename( std::string const& filename ) { + m_data.outputFilename = filename; + } + + std::string const& getFilename() const { + return m_data.outputFilename ; + } + + bool listTests() const { return m_data.listTests; } + bool listTestNamesOnly() const { return m_data.listTestNamesOnly; } + bool listTags() const { return m_data.listTags; } + bool listReporters() const { return m_data.listReporters; } + + std::string getProcessName() const { return m_data.processName; } + + bool shouldDebugBreak() const { return m_data.shouldDebugBreak; } + + void setStreamBuf( std::streambuf* buf ) { + m_os.rdbuf( buf ? buf : std::cout.rdbuf() ); + } + + void useStream( std::string const& streamName ) { + Stream stream = createStream( streamName ); + setStreamBuf( stream.streamBuf ); + m_stream.release(); + m_stream = stream; + } + + std::string getReporterName() const { return m_data.reporterName; } + + int abortAfter() const { return m_data.abortAfter; } + + TestSpec const& testSpec() const { return m_testSpec; } + + bool showHelp() const { return m_data.showHelp; } + bool showInvisibles() const { return m_data.showInvisibles; } + + // IConfig interface + virtual bool allowThrows() const { return !m_data.noThrow; } + virtual std::ostream& stream() const { return m_os; } + virtual std::string name() const { return m_data.name.empty() ? m_data.processName : m_data.name; } + virtual bool includeSuccessfulResults() const { return m_data.showSuccessfulTests; } + virtual bool warnAboutMissingAssertions() const { return m_data.warnings & WarnAbout::NoAssertions; } + virtual ShowDurations::OrNot showDurations() const { return m_data.showDurations; } + + private: + ConfigData m_data; + + Stream m_stream; + mutable std::ostream m_os; + TestSpec m_testSpec; + }; + +} // end namespace Catch + +// #included from: catch_clara.h +#define TWOBLUECUBES_CATCH_CLARA_H_INCLUDED + +// Use Catch's value for console width (store Clara's off to the side, if present) +#ifdef CLARA_CONFIG_CONSOLE_WIDTH +#define CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH CLARA_CONFIG_CONSOLE_WIDTH +#undef CLARA_CONFIG_CONSOLE_WIDTH +#endif +#define CLARA_CONFIG_CONSOLE_WIDTH CATCH_CONFIG_CONSOLE_WIDTH + +// Declare Clara inside the Catch namespace +#define STITCH_CLARA_OPEN_NAMESPACE namespace Catch { +// #included from: ../external/clara.h + +// Only use header guard if we are not using an outer namespace +#if !defined(TWOBLUECUBES_CLARA_H_INCLUDED) || defined(STITCH_CLARA_OPEN_NAMESPACE) + +#ifndef STITCH_CLARA_OPEN_NAMESPACE +#define TWOBLUECUBES_CLARA_H_INCLUDED +#define STITCH_CLARA_OPEN_NAMESPACE +#define STITCH_CLARA_CLOSE_NAMESPACE +#else +#define STITCH_CLARA_CLOSE_NAMESPACE } +#endif + +#define STITCH_TBC_TEXT_FORMAT_OPEN_NAMESPACE STITCH_CLARA_OPEN_NAMESPACE + +// ----------- #included from tbc_text_format.h ----------- + +// Only use header guard if we are not using an outer namespace +#if !defined(TBC_TEXT_FORMAT_H_INCLUDED) || defined(STITCH_TBC_TEXT_FORMAT_OUTER_NAMESPACE) +#ifndef STITCH_TBC_TEXT_FORMAT_OUTER_NAMESPACE +#define TBC_TEXT_FORMAT_H_INCLUDED +#endif + +#include +#include +#include + +// Use optional outer namespace +#ifdef STITCH_TBC_TEXT_FORMAT_OUTER_NAMESPACE +namespace STITCH_TBC_TEXT_FORMAT_OUTER_NAMESPACE { +#endif + +namespace Tbc { + +#ifdef TBC_TEXT_FORMAT_CONSOLE_WIDTH + const unsigned int consoleWidth = TBC_TEXT_FORMAT_CONSOLE_WIDTH; +#else + const unsigned int consoleWidth = 80; +#endif + + struct TextAttributes { + TextAttributes() + : initialIndent( std::string::npos ), + indent( 0 ), + width( consoleWidth-1 ), + tabChar( '\t' ) + {} + + TextAttributes& setInitialIndent( std::size_t _value ) { initialIndent = _value; return *this; } + TextAttributes& setIndent( std::size_t _value ) { indent = _value; return *this; } + TextAttributes& setWidth( std::size_t _value ) { width = _value; return *this; } + TextAttributes& setTabChar( char _value ) { tabChar = _value; return *this; } + + std::size_t initialIndent; // indent of first line, or npos + std::size_t indent; // indent of subsequent lines, or all if initialIndent is npos + std::size_t width; // maximum width of text, including indent. Longer text will wrap + char tabChar; // If this char is seen the indent is changed to current pos + }; + + class Text { + public: + Text( std::string const& _str, TextAttributes const& _attr = TextAttributes() ) + : attr( _attr ) + { + std::string wrappableChars = " [({.,/|\\-"; + std::size_t indent = _attr.initialIndent != std::string::npos + ? _attr.initialIndent + : _attr.indent; + std::string remainder = _str; + + while( !remainder.empty() ) { + if( lines.size() >= 1000 ) { + lines.push_back( "... message truncated due to excessive size" ); + return; + } + std::size_t tabPos = std::string::npos; + std::size_t width = (std::min)( remainder.size(), _attr.width - indent ); + std::size_t pos = remainder.find_first_of( '\n' ); + if( pos <= width ) { + width = pos; + } + pos = remainder.find_last_of( _attr.tabChar, width ); + if( pos != std::string::npos ) { + tabPos = pos; + if( remainder[width] == '\n' ) + width--; + remainder = remainder.substr( 0, tabPos ) + remainder.substr( tabPos+1 ); + } + + if( width == remainder.size() ) { + spliceLine( indent, remainder, width ); + } + else if( remainder[width] == '\n' ) { + spliceLine( indent, remainder, width ); + if( width <= 1 || remainder.size() != 1 ) + remainder = remainder.substr( 1 ); + indent = _attr.indent; + } + else { + pos = remainder.find_last_of( wrappableChars, width ); + if( pos != std::string::npos && pos > 0 ) { + spliceLine( indent, remainder, pos ); + if( remainder[0] == ' ' ) + remainder = remainder.substr( 1 ); + } + else { + spliceLine( indent, remainder, width-1 ); + lines.back() += "-"; + } + if( lines.size() == 1 ) + indent = _attr.indent; + if( tabPos != std::string::npos ) + indent += tabPos; + } + } + } + + void spliceLine( std::size_t _indent, std::string& _remainder, std::size_t _pos ) { + lines.push_back( std::string( _indent, ' ' ) + _remainder.substr( 0, _pos ) ); + _remainder = _remainder.substr( _pos ); + } + + typedef std::vector::const_iterator const_iterator; + + const_iterator begin() const { return lines.begin(); } + const_iterator end() const { return lines.end(); } + std::string const& last() const { return lines.back(); } + std::size_t size() const { return lines.size(); } + std::string const& operator[]( std::size_t _index ) const { return lines[_index]; } + std::string toString() const { + std::ostringstream oss; + oss << *this; + return oss.str(); + } + + inline friend std::ostream& operator << ( std::ostream& _stream, Text const& _text ) { + for( Text::const_iterator it = _text.begin(), itEnd = _text.end(); + it != itEnd; ++it ) { + if( it != _text.begin() ) + _stream << "\n"; + _stream << *it; + } + return _stream; + } + + private: + std::string str; + TextAttributes attr; + std::vector lines; + }; + +} // end namespace Tbc + +#ifdef STITCH_TBC_TEXT_FORMAT_OUTER_NAMESPACE +} // end outer namespace +#endif + +#endif // TBC_TEXT_FORMAT_H_INCLUDED + +// ----------- end of #include from tbc_text_format.h ----------- +// ........... back in /Users/philnash/Dev/OSS/Clara/srcs/clara.h + +#undef STITCH_TBC_TEXT_FORMAT_OPEN_NAMESPACE + +#include +#include +#include +#include + +// Use optional outer namespace +#ifdef STITCH_CLARA_OPEN_NAMESPACE +STITCH_CLARA_OPEN_NAMESPACE +#endif + +namespace Clara { + + struct UnpositionalTag {}; + + extern UnpositionalTag _; + +#ifdef CLARA_CONFIG_MAIN + UnpositionalTag _; +#endif + + namespace Detail { + +#ifdef CLARA_CONSOLE_WIDTH + const unsigned int consoleWidth = CLARA_CONFIG_CONSOLE_WIDTH; +#else + const unsigned int consoleWidth = 80; +#endif + + using namespace Tbc; + + inline bool startsWith( std::string const& str, std::string const& prefix ) { + return str.size() >= prefix.size() && str.substr( 0, prefix.size() ) == prefix; + } + + template struct RemoveConstRef{ typedef T type; }; + template struct RemoveConstRef{ typedef T type; }; + template struct RemoveConstRef{ typedef T type; }; + template struct RemoveConstRef{ typedef T type; }; + + template struct IsBool { static const bool value = false; }; + template<> struct IsBool { static const bool value = true; }; + + template + void convertInto( std::string const& _source, T& _dest ) { + std::stringstream ss; + ss << _source; + ss >> _dest; + if( ss.fail() ) + throw std::runtime_error( "Unable to convert " + _source + " to destination type" ); + } + inline void convertInto( std::string const& _source, std::string& _dest ) { + _dest = _source; + } + inline void convertInto( std::string const& _source, bool& _dest ) { + std::string sourceLC = _source; + std::transform( sourceLC.begin(), sourceLC.end(), sourceLC.begin(), ::tolower ); + if( sourceLC == "y" || sourceLC == "1" || sourceLC == "true" || sourceLC == "yes" || sourceLC == "on" ) + _dest = true; + else if( sourceLC == "n" || sourceLC == "0" || sourceLC == "false" || sourceLC == "no" || sourceLC == "off" ) + _dest = false; + else + throw std::runtime_error( "Expected a boolean value but did not recognise:\n '" + _source + "'" ); + } + inline void convertInto( bool _source, bool& _dest ) { + _dest = _source; + } + template + inline void convertInto( bool, T& ) { + throw std::runtime_error( "Invalid conversion" ); + } + + template + struct IArgFunction { + virtual ~IArgFunction() {} +# ifdef CATCH_CPP11_OR_GREATER + IArgFunction() = default; + IArgFunction( IArgFunction const& ) = default; +# endif + virtual void set( ConfigT& config, std::string const& value ) const = 0; + virtual void setFlag( ConfigT& config ) const = 0; + virtual bool takesArg() const = 0; + virtual IArgFunction* clone() const = 0; + }; + + template + class BoundArgFunction { + public: + BoundArgFunction() : functionObj( NULL ) {} + BoundArgFunction( IArgFunction* _functionObj ) : functionObj( _functionObj ) {} + BoundArgFunction( BoundArgFunction const& other ) : functionObj( other.functionObj ? other.functionObj->clone() : NULL ) {} + BoundArgFunction& operator = ( BoundArgFunction const& other ) { + IArgFunction* newFunctionObj = other.functionObj ? other.functionObj->clone() : NULL; + delete functionObj; + functionObj = newFunctionObj; + return *this; + } + ~BoundArgFunction() { delete functionObj; } + + void set( ConfigT& config, std::string const& value ) const { + functionObj->set( config, value ); + } + void setFlag( ConfigT& config ) const { + functionObj->setFlag( config ); + } + bool takesArg() const { return functionObj->takesArg(); } + + bool isSet() const { + return functionObj != NULL; + } + private: + IArgFunction* functionObj; + }; + + template + struct NullBinder : IArgFunction{ + virtual void set( C&, std::string const& ) const {} + virtual void setFlag( C& ) const {} + virtual bool takesArg() const { return true; } + virtual IArgFunction* clone() const { return new NullBinder( *this ); } + }; + + template + struct BoundDataMember : IArgFunction{ + BoundDataMember( M C::* _member ) : member( _member ) {} + virtual void set( C& p, std::string const& stringValue ) const { + convertInto( stringValue, p.*member ); + } + virtual void setFlag( C& p ) const { + convertInto( true, p.*member ); + } + virtual bool takesArg() const { return !IsBool::value; } + virtual IArgFunction* clone() const { return new BoundDataMember( *this ); } + M C::* member; + }; + template + struct BoundUnaryMethod : IArgFunction{ + BoundUnaryMethod( void (C::*_member)( M ) ) : member( _member ) {} + virtual void set( C& p, std::string const& stringValue ) const { + typename RemoveConstRef::type value; + convertInto( stringValue, value ); + (p.*member)( value ); + } + virtual void setFlag( C& p ) const { + typename RemoveConstRef::type value; + convertInto( true, value ); + (p.*member)( value ); + } + virtual bool takesArg() const { return !IsBool::value; } + virtual IArgFunction* clone() const { return new BoundUnaryMethod( *this ); } + void (C::*member)( M ); + }; + template + struct BoundNullaryMethod : IArgFunction{ + BoundNullaryMethod( void (C::*_member)() ) : member( _member ) {} + virtual void set( C& p, std::string const& stringValue ) const { + bool value; + convertInto( stringValue, value ); + if( value ) + (p.*member)(); + } + virtual void setFlag( C& p ) const { + (p.*member)(); + } + virtual bool takesArg() const { return false; } + virtual IArgFunction* clone() const { return new BoundNullaryMethod( *this ); } + void (C::*member)(); + }; + + template + struct BoundUnaryFunction : IArgFunction{ + BoundUnaryFunction( void (*_function)( C& ) ) : function( _function ) {} + virtual void set( C& obj, std::string const& stringValue ) const { + bool value; + convertInto( stringValue, value ); + if( value ) + function( obj ); + } + virtual void setFlag( C& p ) const { + function( p ); + } + virtual bool takesArg() const { return false; } + virtual IArgFunction* clone() const { return new BoundUnaryFunction( *this ); } + void (*function)( C& ); + }; + + template + struct BoundBinaryFunction : IArgFunction{ + BoundBinaryFunction( void (*_function)( C&, T ) ) : function( _function ) {} + virtual void set( C& obj, std::string const& stringValue ) const { + typename RemoveConstRef::type value; + convertInto( stringValue, value ); + function( obj, value ); + } + virtual void setFlag( C& obj ) const { + typename RemoveConstRef::type value; + convertInto( true, value ); + function( obj, value ); + } + virtual bool takesArg() const { return !IsBool::value; } + virtual IArgFunction* clone() const { return new BoundBinaryFunction( *this ); } + void (*function)( C&, T ); + }; + + } // namespace Detail + + struct Parser { + Parser() : separators( " \t=:" ) {} + + struct Token { + enum Type { Positional, ShortOpt, LongOpt }; + Token( Type _type, std::string const& _data ) : type( _type ), data( _data ) {} + Type type; + std::string data; + }; + + void parseIntoTokens( int argc, char const * const * argv, std::vector& tokens ) const { + const std::string doubleDash = "--"; + for( int i = 1; i < argc && argv[i] != doubleDash; ++i ) + parseIntoTokens( argv[i] , tokens); + } + void parseIntoTokens( std::string arg, std::vector& tokens ) const { + while( !arg.empty() ) { + Parser::Token token( Parser::Token::Positional, arg ); + arg = ""; + if( token.data[0] == '-' ) { + if( token.data.size() > 1 && token.data[1] == '-' ) { + token = Parser::Token( Parser::Token::LongOpt, token.data.substr( 2 ) ); + } + else { + token = Parser::Token( Parser::Token::ShortOpt, token.data.substr( 1 ) ); + if( token.data.size() > 1 && separators.find( token.data[1] ) == std::string::npos ) { + arg = "-" + token.data.substr( 1 ); + token.data = token.data.substr( 0, 1 ); + } + } + } + if( token.type != Parser::Token::Positional ) { + std::size_t pos = token.data.find_first_of( separators ); + if( pos != std::string::npos ) { + arg = token.data.substr( pos+1 ); + token.data = token.data.substr( 0, pos ); + } + } + tokens.push_back( token ); + } + } + std::string separators; + }; + + template + struct CommonArgProperties { + CommonArgProperties() {} + CommonArgProperties( Detail::BoundArgFunction const& _boundField ) : boundField( _boundField ) {} + + Detail::BoundArgFunction boundField; + std::string description; + std::string detail; + std::string placeholder; // Only value if boundField takes an arg + + bool takesArg() const { + return !placeholder.empty(); + } + void validate() const { + if( !boundField.isSet() ) + throw std::logic_error( "option not bound" ); + } + }; + struct OptionArgProperties { + std::vector shortNames; + std::string longName; + + bool hasShortName( std::string const& shortName ) const { + return std::find( shortNames.begin(), shortNames.end(), shortName ) != shortNames.end(); + } + bool hasLongName( std::string const& _longName ) const { + return _longName == longName; + } + }; + struct PositionalArgProperties { + PositionalArgProperties() : position( -1 ) {} + int position; // -1 means non-positional (floating) + + bool isFixedPositional() const { + return position != -1; + } + }; + + template + class CommandLine { + + struct Arg : CommonArgProperties, OptionArgProperties, PositionalArgProperties { + Arg() {} + Arg( Detail::BoundArgFunction const& _boundField ) : CommonArgProperties( _boundField ) {} + + using CommonArgProperties::placeholder; // !TBD + + std::string dbgName() const { + if( !longName.empty() ) + return "--" + longName; + if( !shortNames.empty() ) + return "-" + shortNames[0]; + return "positional args"; + } + std::string commands() const { + std::ostringstream oss; + bool first = true; + std::vector::const_iterator it = shortNames.begin(), itEnd = shortNames.end(); + for(; it != itEnd; ++it ) { + if( first ) + first = false; + else + oss << ", "; + oss << "-" << *it; + } + if( !longName.empty() ) { + if( !first ) + oss << ", "; + oss << "--" << longName; + } + if( !placeholder.empty() ) + oss << " <" << placeholder << ">"; + return oss.str(); + } + }; + + // NOTE: std::auto_ptr is deprecated in c++11/c++0x +#if defined(__cplusplus) && __cplusplus > 199711L + typedef std::unique_ptr ArgAutoPtr; +#else + typedef std::auto_ptr ArgAutoPtr; +#endif + + friend void addOptName( Arg& arg, std::string const& optName ) + { + if( optName.empty() ) + return; + if( Detail::startsWith( optName, "--" ) ) { + if( !arg.longName.empty() ) + throw std::logic_error( "Only one long opt may be specified. '" + + arg.longName + + "' already specified, now attempting to add '" + + optName + "'" ); + arg.longName = optName.substr( 2 ); + } + else if( Detail::startsWith( optName, "-" ) ) + arg.shortNames.push_back( optName.substr( 1 ) ); + else + throw std::logic_error( "option must begin with - or --. Option was: '" + optName + "'" ); + } + friend void setPositionalArg( Arg& arg, int position ) + { + arg.position = position; + } + + class ArgBuilder { + public: + ArgBuilder( Arg* arg ) : m_arg( arg ) {} + + // Bind a non-boolean data member (requires placeholder string) + template + void bind( M C::* field, std::string const& placeholder ) { + m_arg->boundField = new Detail::BoundDataMember( field ); + m_arg->placeholder = placeholder; + } + // Bind a boolean data member (no placeholder required) + template + void bind( bool C::* field ) { + m_arg->boundField = new Detail::BoundDataMember( field ); + } + + // Bind a method taking a single, non-boolean argument (requires a placeholder string) + template + void bind( void (C::* unaryMethod)( M ), std::string const& placeholder ) { + m_arg->boundField = new Detail::BoundUnaryMethod( unaryMethod ); + m_arg->placeholder = placeholder; + } + + // Bind a method taking a single, boolean argument (no placeholder string required) + template + void bind( void (C::* unaryMethod)( bool ) ) { + m_arg->boundField = new Detail::BoundUnaryMethod( unaryMethod ); + } + + // Bind a method that takes no arguments (will be called if opt is present) + template + void bind( void (C::* nullaryMethod)() ) { + m_arg->boundField = new Detail::BoundNullaryMethod( nullaryMethod ); + } + + // Bind a free function taking a single argument - the object to operate on (no placeholder string required) + template + void bind( void (* unaryFunction)( C& ) ) { + m_arg->boundField = new Detail::BoundUnaryFunction( unaryFunction ); + } + + // Bind a free function taking a single argument - the object to operate on (requires a placeholder string) + template + void bind( void (* binaryFunction)( C&, T ), std::string const& placeholder ) { + m_arg->boundField = new Detail::BoundBinaryFunction( binaryFunction ); + m_arg->placeholder = placeholder; + } + + ArgBuilder& describe( std::string const& description ) { + m_arg->description = description; + return *this; + } + ArgBuilder& detail( std::string const& detail ) { + m_arg->detail = detail; + return *this; + } + + protected: + Arg* m_arg; + }; + + class OptBuilder : public ArgBuilder { + public: + OptBuilder( Arg* arg ) : ArgBuilder( arg ) {} + OptBuilder( OptBuilder& other ) : ArgBuilder( other ) {} + + OptBuilder& operator[]( std::string const& optName ) { + addOptName( *ArgBuilder::m_arg, optName ); + return *this; + } + }; + + public: + + CommandLine() + : m_boundProcessName( new Detail::NullBinder() ), + m_highestSpecifiedArgPosition( 0 ), + m_throwOnUnrecognisedTokens( false ) + {} + CommandLine( CommandLine const& other ) + : m_boundProcessName( other.m_boundProcessName ), + m_options ( other.m_options ), + m_positionalArgs( other.m_positionalArgs ), + m_highestSpecifiedArgPosition( other.m_highestSpecifiedArgPosition ), + m_throwOnUnrecognisedTokens( other.m_throwOnUnrecognisedTokens ) + { + if( other.m_floatingArg.get() ) + m_floatingArg = ArgAutoPtr( new Arg( *other.m_floatingArg ) ); + } + + CommandLine& setThrowOnUnrecognisedTokens( bool shouldThrow = true ) { + m_throwOnUnrecognisedTokens = shouldThrow; + return *this; + } + + OptBuilder operator[]( std::string const& optName ) { + m_options.push_back( Arg() ); + addOptName( m_options.back(), optName ); + OptBuilder builder( &m_options.back() ); + return builder; + } + + ArgBuilder operator[]( int position ) { + m_positionalArgs.insert( std::make_pair( position, Arg() ) ); + if( position > m_highestSpecifiedArgPosition ) + m_highestSpecifiedArgPosition = position; + setPositionalArg( m_positionalArgs[position], position ); + ArgBuilder builder( &m_positionalArgs[position] ); + return builder; + } + + // Invoke this with the _ instance + ArgBuilder operator[]( UnpositionalTag ) { + if( m_floatingArg.get() ) + throw std::logic_error( "Only one unpositional argument can be added" ); + m_floatingArg = ArgAutoPtr( new Arg() ); + ArgBuilder builder( m_floatingArg.get() ); + return builder; + } + + template + void bindProcessName( M C::* field ) { + m_boundProcessName = new Detail::BoundDataMember( field ); + } + template + void bindProcessName( void (C::*_unaryMethod)( M ) ) { + m_boundProcessName = new Detail::BoundUnaryMethod( _unaryMethod ); + } + + void optUsage( std::ostream& os, std::size_t indent = 0, std::size_t width = Detail::consoleWidth ) const { + typename std::vector::const_iterator itBegin = m_options.begin(), itEnd = m_options.end(), it; + std::size_t maxWidth = 0; + for( it = itBegin; it != itEnd; ++it ) + maxWidth = (std::max)( maxWidth, it->commands().size() ); + + for( it = itBegin; it != itEnd; ++it ) { + Detail::Text usage( it->commands(), Detail::TextAttributes() + .setWidth( maxWidth+indent ) + .setIndent( indent ) ); + Detail::Text desc( it->description, Detail::TextAttributes() + .setWidth( width - maxWidth - 3 ) ); + + for( std::size_t i = 0; i < (std::max)( usage.size(), desc.size() ); ++i ) { + std::string usageCol = i < usage.size() ? usage[i] : ""; + os << usageCol; + + if( i < desc.size() && !desc[i].empty() ) + os << std::string( indent + 2 + maxWidth - usageCol.size(), ' ' ) + << desc[i]; + os << "\n"; + } + } + } + std::string optUsage() const { + std::ostringstream oss; + optUsage( oss ); + return oss.str(); + } + + void argSynopsis( std::ostream& os ) const { + for( int i = 1; i <= m_highestSpecifiedArgPosition; ++i ) { + if( i > 1 ) + os << " "; + typename std::map::const_iterator it = m_positionalArgs.find( i ); + if( it != m_positionalArgs.end() ) + os << "<" << it->second.placeholder << ">"; + else if( m_floatingArg.get() ) + os << "<" << m_floatingArg->placeholder << ">"; + else + throw std::logic_error( "non consecutive positional arguments with no floating args" ); + } + // !TBD No indication of mandatory args + if( m_floatingArg.get() ) { + if( m_highestSpecifiedArgPosition > 1 ) + os << " "; + os << "[<" << m_floatingArg->placeholder << "> ...]"; + } + } + std::string argSynopsis() const { + std::ostringstream oss; + argSynopsis( oss ); + return oss.str(); + } + + void usage( std::ostream& os, std::string const& procName ) const { + validate(); + os << "usage:\n " << procName << " "; + argSynopsis( os ); + if( !m_options.empty() ) { + os << " [options]\n\nwhere options are: \n"; + optUsage( os, 2 ); + } + os << "\n"; + } + std::string usage( std::string const& procName ) const { + std::ostringstream oss; + usage( oss, procName ); + return oss.str(); + } + + ConfigT parse( int argc, char const * const * argv ) const { + ConfigT config; + parseInto( argc, argv, config ); + return config; + } + + std::vector parseInto( int argc, char const * const * argv, ConfigT& config ) const { + std::string processName = argv[0]; + std::size_t lastSlash = processName.find_last_of( "/\\" ); + if( lastSlash != std::string::npos ) + processName = processName.substr( lastSlash+1 ); + m_boundProcessName.set( config, processName ); + std::vector tokens; + Parser parser; + parser.parseIntoTokens( argc, argv, tokens ); + return populate( tokens, config ); + } + + std::vector populate( std::vector const& tokens, ConfigT& config ) const { + validate(); + std::vector unusedTokens = populateOptions( tokens, config ); + unusedTokens = populateFixedArgs( unusedTokens, config ); + unusedTokens = populateFloatingArgs( unusedTokens, config ); + return unusedTokens; + } + + std::vector populateOptions( std::vector const& tokens, ConfigT& config ) const { + std::vector unusedTokens; + std::vector errors; + for( std::size_t i = 0; i < tokens.size(); ++i ) { + Parser::Token const& token = tokens[i]; + typename std::vector::const_iterator it = m_options.begin(), itEnd = m_options.end(); + for(; it != itEnd; ++it ) { + Arg const& arg = *it; + + try { + if( ( token.type == Parser::Token::ShortOpt && arg.hasShortName( token.data ) ) || + ( token.type == Parser::Token::LongOpt && arg.hasLongName( token.data ) ) ) { + if( arg.takesArg() ) { + if( i == tokens.size()-1 || tokens[i+1].type != Parser::Token::Positional ) + errors.push_back( "Expected argument to option: " + token.data ); + else + arg.boundField.set( config, tokens[++i].data ); + } + else { + arg.boundField.setFlag( config ); + } + break; + } + } + catch( std::exception& ex ) { + errors.push_back( std::string( ex.what() ) + "\n- while parsing: (" + arg.commands() + ")" ); + } + } + if( it == itEnd ) { + if( token.type == Parser::Token::Positional || !m_throwOnUnrecognisedTokens ) + unusedTokens.push_back( token ); + else if( m_throwOnUnrecognisedTokens ) + errors.push_back( "unrecognised option: " + token.data ); + } + } + if( !errors.empty() ) { + std::ostringstream oss; + for( std::vector::const_iterator it = errors.begin(), itEnd = errors.end(); + it != itEnd; + ++it ) { + if( it != errors.begin() ) + oss << "\n"; + oss << *it; + } + throw std::runtime_error( oss.str() ); + } + return unusedTokens; + } + std::vector populateFixedArgs( std::vector const& tokens, ConfigT& config ) const { + std::vector unusedTokens; + int position = 1; + for( std::size_t i = 0; i < tokens.size(); ++i ) { + Parser::Token const& token = tokens[i]; + typename std::map::const_iterator it = m_positionalArgs.find( position ); + if( it != m_positionalArgs.end() ) + it->second.boundField.set( config, token.data ); + else + unusedTokens.push_back( token ); + if( token.type == Parser::Token::Positional ) + position++; + } + return unusedTokens; + } + std::vector populateFloatingArgs( std::vector const& tokens, ConfigT& config ) const { + if( !m_floatingArg.get() ) + return tokens; + std::vector unusedTokens; + for( std::size_t i = 0; i < tokens.size(); ++i ) { + Parser::Token const& token = tokens[i]; + if( token.type == Parser::Token::Positional ) + m_floatingArg->boundField.set( config, token.data ); + else + unusedTokens.push_back( token ); + } + return unusedTokens; + } + + void validate() const + { + if( m_options.empty() && m_positionalArgs.empty() && !m_floatingArg.get() ) + throw std::logic_error( "No options or arguments specified" ); + + for( typename std::vector::const_iterator it = m_options.begin(), + itEnd = m_options.end(); + it != itEnd; ++it ) + it->validate(); + } + + private: + Detail::BoundArgFunction m_boundProcessName; + std::vector m_options; + std::map m_positionalArgs; + ArgAutoPtr m_floatingArg; + int m_highestSpecifiedArgPosition; + bool m_throwOnUnrecognisedTokens; + }; + +} // end namespace Clara + +STITCH_CLARA_CLOSE_NAMESPACE +#undef STITCH_CLARA_OPEN_NAMESPACE +#undef STITCH_CLARA_CLOSE_NAMESPACE + +#endif // TWOBLUECUBES_CLARA_H_INCLUDED +#undef STITCH_CLARA_OPEN_NAMESPACE + +// Restore Clara's value for console width, if present +#ifdef CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH +#define CLARA_CONFIG_CONSOLE_WIDTH CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH +#undef CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH +#endif + +#include + +namespace Catch { + + inline void abortAfterFirst( ConfigData& config ) { config.abortAfter = 1; } + inline void abortAfterX( ConfigData& config, int x ) { + if( x < 1 ) + throw std::runtime_error( "Value after -x or --abortAfter must be greater than zero" ); + config.abortAfter = x; + } + inline void addTestOrTags( ConfigData& config, std::string const& _testSpec ) { config.testsOrTags.push_back( _testSpec ); } + + inline void addWarning( ConfigData& config, std::string const& _warning ) { + if( _warning == "NoAssertions" ) + config.warnings = (WarnAbout::What)( config.warnings | WarnAbout::NoAssertions ); + else + throw std::runtime_error( "Unrecognised warning: '" + _warning + "'" ); + + } + inline void setVerbosity( ConfigData& config, int level ) { + // !TBD: accept strings? + config.verbosity = (Verbosity::Level)level; + } + inline void setShowDurations( ConfigData& config, bool _showDurations ) { + config.showDurations = _showDurations + ? ShowDurations::Always + : ShowDurations::Never; + } + inline void loadTestNamesFromFile( ConfigData& config, std::string const& _filename ) { + std::ifstream f( _filename.c_str() ); + if( !f.is_open() ) + throw std::domain_error( "Unable to load input file: " + _filename ); + + std::string line; + while( std::getline( f, line ) ) { + line = trim(line); + if( !line.empty() && !startsWith( line, "#" ) ) + addTestOrTags( config, "\"" + line + "\"," ); + } + } + + inline Clara::CommandLine makeCommandLineParser() { + + using namespace Clara; + CommandLine cli; + + cli.bindProcessName( &ConfigData::processName ); + + cli["-?"]["-h"]["--help"] + .describe( "display usage information" ) + .bind( &ConfigData::showHelp ); + + cli["-l"]["--list-tests"] + .describe( "list all/matching test cases" ) + .bind( &ConfigData::listTests ); + + cli["-t"]["--list-tags"] + .describe( "list all/matching tags" ) + .bind( &ConfigData::listTags ); + + cli["-s"]["--success"] + .describe( "include successful tests in output" ) + .bind( &ConfigData::showSuccessfulTests ); + + cli["-b"]["--break"] + .describe( "break into debugger on failure" ) + .bind( &ConfigData::shouldDebugBreak ); + + cli["-e"]["--nothrow"] + .describe( "skip exception tests" ) + .bind( &ConfigData::noThrow ); + + cli["-i"]["--invisibles"] + .describe( "show invisibles (tabs, newlines)" ) + .bind( &ConfigData::showInvisibles ); + + cli["-o"]["--out"] + .describe( "output filename" ) + .bind( &ConfigData::outputFilename, "filename" ); + + cli["-r"]["--reporter"] +// .placeholder( "name[:filename]" ) + .describe( "reporter to use (defaults to console)" ) + .bind( &ConfigData::reporterName, "name" ); + + cli["-n"]["--name"] + .describe( "suite name" ) + .bind( &ConfigData::name, "name" ); + + cli["-a"]["--abort"] + .describe( "abort at first failure" ) + .bind( &abortAfterFirst ); + + cli["-x"]["--abortx"] + .describe( "abort after x failures" ) + .bind( &abortAfterX, "no. failures" ); + + cli["-w"]["--warn"] + .describe( "enable warnings" ) + .bind( &addWarning, "warning name" ); + +// - needs updating if reinstated +// cli.into( &setVerbosity ) +// .describe( "level of verbosity (0=no output)" ) +// .shortOpt( "v") +// .longOpt( "verbosity" ) +// .placeholder( "level" ); + + cli[_] + .describe( "which test or tests to use" ) + .bind( &addTestOrTags, "test name, pattern or tags" ); + + cli["-d"]["--durations"] + .describe( "show test durations" ) + .bind( &setShowDurations, "yes/no" ); + + cli["-f"]["--input-file"] + .describe( "load test names to run from a file" ) + .bind( &loadTestNamesFromFile, "filename" ); + + // Less common commands which don't have a short form + cli["--list-test-names-only"] + .describe( "list all/matching test cases names only" ) + .bind( &ConfigData::listTestNamesOnly ); + + cli["--list-reporters"] + .describe( "list all reporters" ) + .bind( &ConfigData::listReporters ); + + return cli; + } + +} // end namespace Catch + +// #included from: internal/catch_list.hpp +#define TWOBLUECUBES_CATCH_LIST_HPP_INCLUDED + +// #included from: catch_text.h +#define TWOBLUECUBES_CATCH_TEXT_H_INCLUDED + +#define TBC_TEXT_FORMAT_CONSOLE_WIDTH CATCH_CONFIG_CONSOLE_WIDTH + +#define CLICHE_TBC_TEXT_FORMAT_OUTER_NAMESPACE Catch +// #included from: ../external/tbc_text_format.h +// Only use header guard if we are not using an outer namespace +#ifndef CLICHE_TBC_TEXT_FORMAT_OUTER_NAMESPACE +# ifdef TWOBLUECUBES_TEXT_FORMAT_H_INCLUDED +# ifndef TWOBLUECUBES_TEXT_FORMAT_H_ALREADY_INCLUDED +# define TWOBLUECUBES_TEXT_FORMAT_H_ALREADY_INCLUDED +# endif +# else +# define TWOBLUECUBES_TEXT_FORMAT_H_INCLUDED +# endif +#endif +#ifndef TWOBLUECUBES_TEXT_FORMAT_H_ALREADY_INCLUDED +#include +#include +#include + +// Use optional outer namespace +#ifdef CLICHE_TBC_TEXT_FORMAT_OUTER_NAMESPACE +namespace CLICHE_TBC_TEXT_FORMAT_OUTER_NAMESPACE { +#endif + +namespace Tbc { + +#ifdef TBC_TEXT_FORMAT_CONSOLE_WIDTH + const unsigned int consoleWidth = TBC_TEXT_FORMAT_CONSOLE_WIDTH; +#else + const unsigned int consoleWidth = 80; +#endif + + struct TextAttributes { + TextAttributes() + : initialIndent( std::string::npos ), + indent( 0 ), + width( consoleWidth-1 ), + tabChar( '\t' ) + {} + + TextAttributes& setInitialIndent( std::size_t _value ) { initialIndent = _value; return *this; } + TextAttributes& setIndent( std::size_t _value ) { indent = _value; return *this; } + TextAttributes& setWidth( std::size_t _value ) { width = _value; return *this; } + TextAttributes& setTabChar( char _value ) { tabChar = _value; return *this; } + + std::size_t initialIndent; // indent of first line, or npos + std::size_t indent; // indent of subsequent lines, or all if initialIndent is npos + std::size_t width; // maximum width of text, including indent. Longer text will wrap + char tabChar; // If this char is seen the indent is changed to current pos + }; + + class Text { + public: + Text( std::string const& _str, TextAttributes const& _attr = TextAttributes() ) + : attr( _attr ) + { + std::string wrappableChars = " [({.,/|\\-"; + std::size_t indent = _attr.initialIndent != std::string::npos + ? _attr.initialIndent + : _attr.indent; + std::string remainder = _str; + + while( !remainder.empty() ) { + if( lines.size() >= 1000 ) { + lines.push_back( "... message truncated due to excessive size" ); + return; + } + std::size_t tabPos = std::string::npos; + std::size_t width = (std::min)( remainder.size(), _attr.width - indent ); + std::size_t pos = remainder.find_first_of( '\n' ); + if( pos <= width ) { + width = pos; + } + pos = remainder.find_last_of( _attr.tabChar, width ); + if( pos != std::string::npos ) { + tabPos = pos; + if( remainder[width] == '\n' ) + width--; + remainder = remainder.substr( 0, tabPos ) + remainder.substr( tabPos+1 ); + } + + if( width == remainder.size() ) { + spliceLine( indent, remainder, width ); + } + else if( remainder[width] == '\n' ) { + spliceLine( indent, remainder, width ); + if( width <= 1 || remainder.size() != 1 ) + remainder = remainder.substr( 1 ); + indent = _attr.indent; + } + else { + pos = remainder.find_last_of( wrappableChars, width ); + if( pos != std::string::npos && pos > 0 ) { + spliceLine( indent, remainder, pos ); + if( remainder[0] == ' ' ) + remainder = remainder.substr( 1 ); + } + else { + spliceLine( indent, remainder, width-1 ); + lines.back() += "-"; + } + if( lines.size() == 1 ) + indent = _attr.indent; + if( tabPos != std::string::npos ) + indent += tabPos; + } + } + } + + void spliceLine( std::size_t _indent, std::string& _remainder, std::size_t _pos ) { + lines.push_back( std::string( _indent, ' ' ) + _remainder.substr( 0, _pos ) ); + _remainder = _remainder.substr( _pos ); + } + + typedef std::vector::const_iterator const_iterator; + + const_iterator begin() const { return lines.begin(); } + const_iterator end() const { return lines.end(); } + std::string const& last() const { return lines.back(); } + std::size_t size() const { return lines.size(); } + std::string const& operator[]( std::size_t _index ) const { return lines[_index]; } + std::string toString() const { + std::ostringstream oss; + oss << *this; + return oss.str(); + } + + inline friend std::ostream& operator << ( std::ostream& _stream, Text const& _text ) { + for( Text::const_iterator it = _text.begin(), itEnd = _text.end(); + it != itEnd; ++it ) { + if( it != _text.begin() ) + _stream << "\n"; + _stream << *it; + } + return _stream; + } + + private: + std::string str; + TextAttributes attr; + std::vector lines; + }; + +} // end namespace Tbc + +#ifdef CLICHE_TBC_TEXT_FORMAT_OUTER_NAMESPACE +} // end outer namespace +#endif + +#endif // TWOBLUECUBES_TEXT_FORMAT_H_ALREADY_INCLUDED +#undef CLICHE_TBC_TEXT_FORMAT_OUTER_NAMESPACE + +namespace Catch { + using Tbc::Text; + using Tbc::TextAttributes; +} + +// #included from: catch_console_colour.hpp +#define TWOBLUECUBES_CATCH_CONSOLE_COLOUR_HPP_INCLUDED + +namespace Catch { + + namespace Detail { + struct IColourImpl; + } + + struct Colour { + enum Code { + None = 0, + + White, + Red, + Green, + Blue, + Cyan, + Yellow, + Grey, + + Bright = 0x10, + + BrightRed = Bright | Red, + BrightGreen = Bright | Green, + LightGrey = Bright | Grey, + BrightWhite = Bright | White, + + // By intention + FileName = LightGrey, + ResultError = BrightRed, + ResultSuccess = BrightGreen, + + Error = BrightRed, + Success = Green, + + OriginalExpression = Cyan, + ReconstructedExpression = Yellow, + + SecondaryText = LightGrey, + Headers = White + }; + + // Use constructed object for RAII guard + Colour( Code _colourCode ); + ~Colour(); + + // Use static method for one-shot changes + static void use( Code _colourCode ); + + private: + Colour( Colour const& other ); + static Detail::IColourImpl* impl(); + }; + +} // end namespace Catch + +// #included from: catch_interfaces_reporter.h +#define TWOBLUECUBES_CATCH_INTERFACES_REPORTER_H_INCLUDED + +// #included from: catch_option.hpp +#define TWOBLUECUBES_CATCH_OPTION_HPP_INCLUDED + +namespace Catch { + + // An optional type + template + class Option { + public: + Option() : nullableValue( NULL ) {} + Option( T const& _value ) + : nullableValue( new( storage ) T( _value ) ) + {} + Option( Option const& _other ) + : nullableValue( _other ? new( storage ) T( *_other ) : NULL ) + {} + + ~Option() { + reset(); + } + + Option& operator= ( Option const& _other ) { + if( &_other != this ) { + reset(); + if( _other ) + nullableValue = new( storage ) T( *_other ); + } + return *this; + } + Option& operator = ( T const& _value ) { + reset(); + nullableValue = new( storage ) T( _value ); + return *this; + } + + void reset() { + if( nullableValue ) + nullableValue->~T(); + nullableValue = NULL; + } + + T& operator*() { return *nullableValue; } + T const& operator*() const { return *nullableValue; } + T* operator->() { return nullableValue; } + const T* operator->() const { return nullableValue; } + + T valueOr( T const& defaultValue ) const { + return nullableValue ? *nullableValue : defaultValue; + } + + bool some() const { return nullableValue != NULL; } + bool none() const { return nullableValue == NULL; } + + bool operator !() const { return nullableValue == NULL; } + operator SafeBool::type() const { + return SafeBool::makeSafe( some() ); + } + + private: + T* nullableValue; + char storage[sizeof(T)]; + }; + +} // end namespace Catch + +#include +#include +#include +#include + +namespace Catch +{ + struct ReporterConfig { + explicit ReporterConfig( Ptr const& _fullConfig ) + : m_stream( &_fullConfig->stream() ), m_fullConfig( _fullConfig ) {} + + ReporterConfig( Ptr const& _fullConfig, std::ostream& _stream ) + : m_stream( &_stream ), m_fullConfig( _fullConfig ) {} + + std::ostream& stream() const { return *m_stream; } + Ptr fullConfig() const { return m_fullConfig; } + + private: + std::ostream* m_stream; + Ptr m_fullConfig; + }; + + struct ReporterPreferences { + ReporterPreferences() + : shouldRedirectStdOut( false ) + {} + + bool shouldRedirectStdOut; + }; + + template + struct LazyStat : Option { + LazyStat() : used( false ) {} + LazyStat& operator=( T const& _value ) { + Option::operator=( _value ); + used = false; + return *this; + } + void reset() { + Option::reset(); + used = false; + } + bool used; + }; + + struct TestRunInfo { + TestRunInfo( std::string const& _name ) : name( _name ) {} + std::string name; + }; + struct GroupInfo { + GroupInfo( std::string const& _name, + std::size_t _groupIndex, + std::size_t _groupsCount ) + : name( _name ), + groupIndex( _groupIndex ), + groupsCounts( _groupsCount ) + {} + + std::string name; + std::size_t groupIndex; + std::size_t groupsCounts; + }; + + struct AssertionStats { + AssertionStats( AssertionResult const& _assertionResult, + std::vector const& _infoMessages, + Totals const& _totals ) + : assertionResult( _assertionResult ), + infoMessages( _infoMessages ), + totals( _totals ) + { + if( assertionResult.hasMessage() ) { + // Copy message into messages list. + // !TBD This should have been done earlier, somewhere + MessageBuilder builder( assertionResult.getTestMacroName(), assertionResult.getSourceInfo(), assertionResult.getResultType() ); + builder << assertionResult.getMessage(); + builder.m_info.message = builder.m_stream.str(); + + infoMessages.push_back( builder.m_info ); + } + } + virtual ~AssertionStats(); + +# ifdef CATCH_CPP11_OR_GREATER + AssertionStats( AssertionStats const& ) = default; + AssertionStats( AssertionStats && ) = default; + AssertionStats& operator = ( AssertionStats const& ) = default; + AssertionStats& operator = ( AssertionStats && ) = default; +# endif + + AssertionResult assertionResult; + std::vector infoMessages; + Totals totals; + }; + + struct SectionStats { + SectionStats( SectionInfo const& _sectionInfo, + Counts const& _assertions, + double _durationInSeconds, + bool _missingAssertions ) + : sectionInfo( _sectionInfo ), + assertions( _assertions ), + durationInSeconds( _durationInSeconds ), + missingAssertions( _missingAssertions ) + {} + virtual ~SectionStats(); +# ifdef CATCH_CPP11_OR_GREATER + SectionStats( SectionStats const& ) = default; + SectionStats( SectionStats && ) = default; + SectionStats& operator = ( SectionStats const& ) = default; + SectionStats& operator = ( SectionStats && ) = default; +# endif + + SectionInfo sectionInfo; + Counts assertions; + double durationInSeconds; + bool missingAssertions; + }; + + struct TestCaseStats { + TestCaseStats( TestCaseInfo const& _testInfo, + Totals const& _totals, + std::string const& _stdOut, + std::string const& _stdErr, + bool _aborting ) + : testInfo( _testInfo ), + totals( _totals ), + stdOut( _stdOut ), + stdErr( _stdErr ), + aborting( _aborting ) + {} + virtual ~TestCaseStats(); + +# ifdef CATCH_CPP11_OR_GREATER + TestCaseStats( TestCaseStats const& ) = default; + TestCaseStats( TestCaseStats && ) = default; + TestCaseStats& operator = ( TestCaseStats const& ) = default; + TestCaseStats& operator = ( TestCaseStats && ) = default; +# endif + + TestCaseInfo testInfo; + Totals totals; + std::string stdOut; + std::string stdErr; + bool aborting; + }; + + struct TestGroupStats { + TestGroupStats( GroupInfo const& _groupInfo, + Totals const& _totals, + bool _aborting ) + : groupInfo( _groupInfo ), + totals( _totals ), + aborting( _aborting ) + {} + TestGroupStats( GroupInfo const& _groupInfo ) + : groupInfo( _groupInfo ), + aborting( false ) + {} + virtual ~TestGroupStats(); + +# ifdef CATCH_CPP11_OR_GREATER + TestGroupStats( TestGroupStats const& ) = default; + TestGroupStats( TestGroupStats && ) = default; + TestGroupStats& operator = ( TestGroupStats const& ) = default; + TestGroupStats& operator = ( TestGroupStats && ) = default; +# endif + + GroupInfo groupInfo; + Totals totals; + bool aborting; + }; + + struct TestRunStats { + TestRunStats( TestRunInfo const& _runInfo, + Totals const& _totals, + bool _aborting ) + : runInfo( _runInfo ), + totals( _totals ), + aborting( _aborting ) + {} + virtual ~TestRunStats(); + +# ifndef CATCH_CPP11_OR_GREATER + TestRunStats( TestRunStats const& _other ) + : runInfo( _other.runInfo ), + totals( _other.totals ), + aborting( _other.aborting ) + {} +# else + TestRunStats( TestRunStats const& ) = default; + TestRunStats( TestRunStats && ) = default; + TestRunStats& operator = ( TestRunStats const& ) = default; + TestRunStats& operator = ( TestRunStats && ) = default; +# endif + + TestRunInfo runInfo; + Totals totals; + bool aborting; + }; + + struct IStreamingReporter : IShared { + virtual ~IStreamingReporter(); + + // Implementing class must also provide the following static method: + // static std::string getDescription(); + + virtual ReporterPreferences getPreferences() const = 0; + + virtual void noMatchingTestCases( std::string const& spec ) = 0; + + virtual void testRunStarting( TestRunInfo const& testRunInfo ) = 0; + virtual void testGroupStarting( GroupInfo const& groupInfo ) = 0; + + virtual void testCaseStarting( TestCaseInfo const& testInfo ) = 0; + virtual void sectionStarting( SectionInfo const& sectionInfo ) = 0; + + virtual void assertionStarting( AssertionInfo const& assertionInfo ) = 0; + + virtual bool assertionEnded( AssertionStats const& assertionStats ) = 0; + virtual void sectionEnded( SectionStats const& sectionStats ) = 0; + virtual void testCaseEnded( TestCaseStats const& testCaseStats ) = 0; + virtual void testGroupEnded( TestGroupStats const& testGroupStats ) = 0; + virtual void testRunEnded( TestRunStats const& testRunStats ) = 0; + }; + + struct IReporterFactory { + virtual ~IReporterFactory(); + virtual IStreamingReporter* create( ReporterConfig const& config ) const = 0; + virtual std::string getDescription() const = 0; + }; + + struct IReporterRegistry { + typedef std::map FactoryMap; + + virtual ~IReporterRegistry(); + virtual IStreamingReporter* create( std::string const& name, Ptr const& config ) const = 0; + virtual FactoryMap const& getFactories() const = 0; + }; + +} + +#include +#include + +namespace Catch { + + inline std::size_t listTests( Config const& config ) { + + TestSpec testSpec = config.testSpec(); + if( config.testSpec().hasFilters() ) + std::cout << "Matching test cases:\n"; + else { + std::cout << "All available test cases:\n"; + testSpec = TestSpecParser().parse( "*" ).testSpec(); + } + + std::size_t matchedTests = 0; + TextAttributes nameAttr, tagsAttr; + nameAttr.setInitialIndent( 2 ).setIndent( 4 ); + tagsAttr.setIndent( 6 ); + + std::vector matchedTestCases; + getRegistryHub().getTestCaseRegistry().getFilteredTests( testSpec, config, matchedTestCases ); + for( std::vector::const_iterator it = matchedTestCases.begin(), itEnd = matchedTestCases.end(); + it != itEnd; + ++it ) { + matchedTests++; + TestCaseInfo const& testCaseInfo = it->getTestCaseInfo(); + Colour::Code colour = testCaseInfo.isHidden + ? Colour::SecondaryText + : Colour::None; + Colour colourGuard( colour ); + + std::cout << Text( testCaseInfo.name, nameAttr ) << std::endl; + if( !testCaseInfo.tags.empty() ) + std::cout << Text( testCaseInfo.tagsAsString, tagsAttr ) << std::endl; + } + + if( !config.testSpec().hasFilters() ) + std::cout << pluralise( matchedTests, "test case" ) << "\n" << std::endl; + else + std::cout << pluralise( matchedTests, "matching test case" ) << "\n" << std::endl; + return matchedTests; + } + + inline std::size_t listTestsNamesOnly( Config const& config ) { + TestSpec testSpec = config.testSpec(); + if( !config.testSpec().hasFilters() ) + testSpec = TestSpecParser().parse( "*" ).testSpec(); + std::size_t matchedTests = 0; + std::vector matchedTestCases; + getRegistryHub().getTestCaseRegistry().getFilteredTests( testSpec, config, matchedTestCases ); + for( std::vector::const_iterator it = matchedTestCases.begin(), itEnd = matchedTestCases.end(); + it != itEnd; + ++it ) { + matchedTests++; + TestCaseInfo const& testCaseInfo = it->getTestCaseInfo(); + std::cout << testCaseInfo.name << std::endl; + } + return matchedTests; + } + + struct TagInfo { + TagInfo() : count ( 0 ) {} + void add( std::string const& spelling ) { + ++count; + spellings.insert( spelling ); + } + std::string all() const { + std::string out; + for( std::set::const_iterator it = spellings.begin(), itEnd = spellings.end(); + it != itEnd; + ++it ) + out += "[" + *it + "]"; + return out; + } + std::set spellings; + std::size_t count; + }; + + inline std::size_t listTags( Config const& config ) { + TestSpec testSpec = config.testSpec(); + if( config.testSpec().hasFilters() ) + std::cout << "Tags for matching test cases:\n"; + else { + std::cout << "All available tags:\n"; + testSpec = TestSpecParser().parse( "*" ).testSpec(); + } + + std::map tagCounts; + + std::vector matchedTestCases; + getRegistryHub().getTestCaseRegistry().getFilteredTests( testSpec, config, matchedTestCases ); + for( std::vector::const_iterator it = matchedTestCases.begin(), itEnd = matchedTestCases.end(); + it != itEnd; + ++it ) { + for( std::set::const_iterator tagIt = it->getTestCaseInfo().tags.begin(), + tagItEnd = it->getTestCaseInfo().tags.end(); + tagIt != tagItEnd; + ++tagIt ) { + std::string tagName = *tagIt; + std::string lcaseTagName = toLower( tagName ); + std::map::iterator countIt = tagCounts.find( lcaseTagName ); + if( countIt == tagCounts.end() ) + countIt = tagCounts.insert( std::make_pair( lcaseTagName, TagInfo() ) ).first; + countIt->second.add( tagName ); + } + } + + for( std::map::const_iterator countIt = tagCounts.begin(), + countItEnd = tagCounts.end(); + countIt != countItEnd; + ++countIt ) { + std::ostringstream oss; + oss << " " << std::setw(2) << countIt->second.count << " "; + Text wrapper( countIt->second.all(), TextAttributes() + .setInitialIndent( 0 ) + .setIndent( oss.str().size() ) + .setWidth( CATCH_CONFIG_CONSOLE_WIDTH-10 ) ); + std::cout << oss.str() << wrapper << "\n"; + } + std::cout << pluralise( tagCounts.size(), "tag" ) << "\n" << std::endl; + return tagCounts.size(); + } + + inline std::size_t listReporters( Config const& /*config*/ ) { + std::cout << "Available reports:\n"; + IReporterRegistry::FactoryMap const& factories = getRegistryHub().getReporterRegistry().getFactories(); + IReporterRegistry::FactoryMap::const_iterator itBegin = factories.begin(), itEnd = factories.end(), it; + std::size_t maxNameLen = 0; + for(it = itBegin; it != itEnd; ++it ) + maxNameLen = (std::max)( maxNameLen, it->first.size() ); + + for(it = itBegin; it != itEnd; ++it ) { + Text wrapper( it->second->getDescription(), TextAttributes() + .setInitialIndent( 0 ) + .setIndent( 7+maxNameLen ) + .setWidth( CATCH_CONFIG_CONSOLE_WIDTH - maxNameLen-8 ) ); + std::cout << " " + << it->first + << ":" + << std::string( maxNameLen - it->first.size() + 2, ' ' ) + << wrapper << "\n"; + } + std::cout << std::endl; + return factories.size(); + } + + inline Option list( Config const& config ) { + Option listedCount; + if( config.listTests() ) + listedCount = listedCount.valueOr(0) + listTests( config ); + if( config.listTestNamesOnly() ) + listedCount = listedCount.valueOr(0) + listTestsNamesOnly( config ); + if( config.listTags() ) + listedCount = listedCount.valueOr(0) + listTags( config ); + if( config.listReporters() ) + listedCount = listedCount.valueOr(0) + listReporters( config ); + return listedCount; + } + +} // end namespace Catch + +// #included from: internal/catch_runner_impl.hpp +#define TWOBLUECUBES_CATCH_RUNNER_IMPL_HPP_INCLUDED + +// #included from: catch_test_case_tracker.hpp +#define TWOBLUECUBES_CATCH_TEST_CASE_TRACKER_HPP_INCLUDED + +#include +#include +#include + +namespace Catch { +namespace SectionTracking { + + class TrackedSection { + + typedef std::map TrackedSections; + + public: + enum RunState { + NotStarted, + Executing, + ExecutingChildren, + Completed + }; + + TrackedSection( std::string const& name, TrackedSection* parent ) + : m_name( name ), m_runState( NotStarted ), m_parent( parent ) + {} + + RunState runState() const { return m_runState; } + + TrackedSection* findChild( std::string const& childName ) { + TrackedSections::iterator it = m_children.find( childName ); + return it != m_children.end() + ? &it->second + : NULL; + } + TrackedSection* acquireChild( std::string const& childName ) { + if( TrackedSection* child = findChild( childName ) ) + return child; + m_children.insert( std::make_pair( childName, TrackedSection( childName, this ) ) ); + return findChild( childName ); + } + void enter() { + if( m_runState == NotStarted ) + m_runState = Executing; + } + void leave() { + for( TrackedSections::const_iterator it = m_children.begin(), itEnd = m_children.end(); + it != itEnd; + ++it ) + if( it->second.runState() != Completed ) { + m_runState = ExecutingChildren; + return; + } + m_runState = Completed; + } + TrackedSection* getParent() { + return m_parent; + } + bool hasChildren() const { + return !m_children.empty(); + } + + private: + std::string m_name; + RunState m_runState; + TrackedSections m_children; + TrackedSection* m_parent; + + }; + + class TestCaseTracker { + public: + TestCaseTracker( std::string const& testCaseName ) + : m_testCase( testCaseName, NULL ), + m_currentSection( &m_testCase ), + m_completedASectionThisRun( false ) + {} + + bool enterSection( std::string const& name ) { + TrackedSection* child = m_currentSection->acquireChild( name ); + if( m_completedASectionThisRun || child->runState() == TrackedSection::Completed ) + return false; + + m_currentSection = child; + m_currentSection->enter(); + return true; + } + void leaveSection() { + m_currentSection->leave(); + m_currentSection = m_currentSection->getParent(); + assert( m_currentSection != NULL ); + m_completedASectionThisRun = true; + } + + bool currentSectionHasChildren() const { + return m_currentSection->hasChildren(); + } + bool isCompleted() const { + return m_testCase.runState() == TrackedSection::Completed; + } + + class Guard { + public: + Guard( TestCaseTracker& tracker ) : m_tracker( tracker ) { + m_tracker.enterTestCase(); + } + ~Guard() { + m_tracker.leaveTestCase(); + } + private: + Guard( Guard const& ); + void operator = ( Guard const& ); + TestCaseTracker& m_tracker; + }; + + private: + void enterTestCase() { + m_currentSection = &m_testCase; + m_completedASectionThisRun = false; + m_testCase.enter(); + } + void leaveTestCase() { + m_testCase.leave(); + } + + TrackedSection m_testCase; + TrackedSection* m_currentSection; + bool m_completedASectionThisRun; + }; + +} // namespace SectionTracking + +using SectionTracking::TestCaseTracker; + +} // namespace Catch + +#include +#include + +namespace Catch { + + class StreamRedirect { + + public: + StreamRedirect( std::ostream& stream, std::string& targetString ) + : m_stream( stream ), + m_prevBuf( stream.rdbuf() ), + m_targetString( targetString ) + { + stream.rdbuf( m_oss.rdbuf() ); + } + + ~StreamRedirect() { + m_targetString += m_oss.str(); + m_stream.rdbuf( m_prevBuf ); + } + + private: + std::ostream& m_stream; + std::streambuf* m_prevBuf; + std::ostringstream m_oss; + std::string& m_targetString; + }; + + /////////////////////////////////////////////////////////////////////////// + + class RunContext : public IResultCapture, public IRunner { + + RunContext( RunContext const& ); + void operator =( RunContext const& ); + + public: + + explicit RunContext( Ptr const& config, Ptr const& reporter ) + : m_runInfo( config->name() ), + m_context( getCurrentMutableContext() ), + m_activeTestCase( NULL ), + m_config( config ), + m_reporter( reporter ), + m_prevRunner( m_context.getRunner() ), + m_prevResultCapture( m_context.getResultCapture() ), + m_prevConfig( m_context.getConfig() ) + { + m_context.setRunner( this ); + m_context.setConfig( m_config ); + m_context.setResultCapture( this ); + m_reporter->testRunStarting( m_runInfo ); + } + + virtual ~RunContext() { + m_reporter->testRunEnded( TestRunStats( m_runInfo, m_totals, aborting() ) ); + m_context.setRunner( m_prevRunner ); + m_context.setConfig( NULL ); + m_context.setResultCapture( m_prevResultCapture ); + m_context.setConfig( m_prevConfig ); + } + + void testGroupStarting( std::string const& testSpec, std::size_t groupIndex, std::size_t groupsCount ) { + m_reporter->testGroupStarting( GroupInfo( testSpec, groupIndex, groupsCount ) ); + } + void testGroupEnded( std::string const& testSpec, Totals const& totals, std::size_t groupIndex, std::size_t groupsCount ) { + m_reporter->testGroupEnded( TestGroupStats( GroupInfo( testSpec, groupIndex, groupsCount ), totals, aborting() ) ); + } + + Totals runTest( TestCase const& testCase ) { + Totals prevTotals = m_totals; + + std::string redirectedCout; + std::string redirectedCerr; + + TestCaseInfo testInfo = testCase.getTestCaseInfo(); + + m_reporter->testCaseStarting( testInfo ); + + m_activeTestCase = &testCase; + m_testCaseTracker = TestCaseTracker( testInfo.name ); + + do { + do { + runCurrentTest( redirectedCout, redirectedCerr ); + } + while( !m_testCaseTracker->isCompleted() && !aborting() ); + } + while( getCurrentContext().advanceGeneratorsForCurrentTest() && !aborting() ); + + Totals deltaTotals = m_totals.delta( prevTotals ); + m_totals.testCases += deltaTotals.testCases; + m_reporter->testCaseEnded( TestCaseStats( testInfo, + deltaTotals, + redirectedCout, + redirectedCerr, + aborting() ) ); + + m_activeTestCase = NULL; + m_testCaseTracker.reset(); + + return deltaTotals; + } + + Ptr config() const { + return m_config; + } + + private: // IResultCapture + + virtual void assertionEnded( AssertionResult const& result ) { + if( result.getResultType() == ResultWas::Ok ) { + m_totals.assertions.passed++; + } + else if( !result.isOk() ) { + m_totals.assertions.failed++; + } + + if( m_reporter->assertionEnded( AssertionStats( result, m_messages, m_totals ) ) ) + m_messages.clear(); + + // Reset working state + m_lastAssertionInfo = AssertionInfo( "", m_lastAssertionInfo.lineInfo, "{Unknown expression after the reported line}" , m_lastAssertionInfo.resultDisposition ); + m_lastResult = result; + } + + virtual bool sectionStarted ( + SectionInfo const& sectionInfo, + Counts& assertions + ) + { + std::ostringstream oss; + oss << sectionInfo.name << "@" << sectionInfo.lineInfo; + + if( !m_testCaseTracker->enterSection( oss.str() ) ) + return false; + + m_lastAssertionInfo.lineInfo = sectionInfo.lineInfo; + + m_reporter->sectionStarting( sectionInfo ); + + assertions = m_totals.assertions; + + return true; + } + bool testForMissingAssertions( Counts& assertions ) { + if( assertions.total() != 0 || + !m_config->warnAboutMissingAssertions() || + m_testCaseTracker->currentSectionHasChildren() ) + return false; + m_totals.assertions.failed++; + assertions.failed++; + return true; + } + + virtual void sectionEnded( SectionInfo const& info, Counts const& prevAssertions, double _durationInSeconds ) { + if( std::uncaught_exception() ) { + m_unfinishedSections.push_back( UnfinishedSections( info, prevAssertions, _durationInSeconds ) ); + return; + } + + Counts assertions = m_totals.assertions - prevAssertions; + bool missingAssertions = testForMissingAssertions( assertions ); + + m_testCaseTracker->leaveSection(); + + m_reporter->sectionEnded( SectionStats( info, assertions, _durationInSeconds, missingAssertions ) ); + m_messages.clear(); + } + + virtual void pushScopedMessage( MessageInfo const& message ) { + m_messages.push_back( message ); + } + + virtual void popScopedMessage( MessageInfo const& message ) { + m_messages.erase( std::remove( m_messages.begin(), m_messages.end(), message ), m_messages.end() ); + } + + virtual std::string getCurrentTestName() const { + return m_activeTestCase + ? m_activeTestCase->getTestCaseInfo().name + : ""; + } + + virtual const AssertionResult* getLastResult() const { + return &m_lastResult; + } + + public: + // !TBD We need to do this another way! + bool aborting() const { + return m_totals.assertions.failed == static_cast( m_config->abortAfter() ); + } + + private: + + void runCurrentTest( std::string& redirectedCout, std::string& redirectedCerr ) { + TestCaseInfo const& testCaseInfo = m_activeTestCase->getTestCaseInfo(); + SectionInfo testCaseSection( testCaseInfo.name, testCaseInfo.description, testCaseInfo.lineInfo ); + m_reporter->sectionStarting( testCaseSection ); + Counts prevAssertions = m_totals.assertions; + double duration = 0; + try { + m_lastAssertionInfo = AssertionInfo( "TEST_CASE", testCaseInfo.lineInfo, "", ResultDisposition::Normal ); + TestCaseTracker::Guard guard( *m_testCaseTracker ); + + Timer timer; + timer.start(); + if( m_reporter->getPreferences().shouldRedirectStdOut ) { + StreamRedirect coutRedir( std::cout, redirectedCout ); + StreamRedirect cerrRedir( std::cerr, redirectedCerr ); + m_activeTestCase->invoke(); + } + else { + m_activeTestCase->invoke(); + } + duration = timer.getElapsedSeconds(); + } + catch( TestFailureException& ) { + // This just means the test was aborted due to failure + } + catch(...) { + ResultBuilder exResult( m_lastAssertionInfo.macroName.c_str(), + m_lastAssertionInfo.lineInfo, + m_lastAssertionInfo.capturedExpression.c_str(), + m_lastAssertionInfo.resultDisposition ); + exResult.useActiveException(); + } + // If sections ended prematurely due to an exception we stored their + // infos here so we can tear them down outside the unwind process. + for( std::vector::const_reverse_iterator it = m_unfinishedSections.rbegin(), + itEnd = m_unfinishedSections.rend(); + it != itEnd; + ++it ) + sectionEnded( it->info, it->prevAssertions, it->durationInSeconds ); + m_unfinishedSections.clear(); + m_messages.clear(); + + Counts assertions = m_totals.assertions - prevAssertions; + bool missingAssertions = testForMissingAssertions( assertions ); + + SectionStats testCaseSectionStats( testCaseSection, assertions, duration, missingAssertions ); + m_reporter->sectionEnded( testCaseSectionStats ); + } + + private: + struct UnfinishedSections { + UnfinishedSections( SectionInfo const& _info, Counts const& _prevAssertions, double _durationInSeconds ) + : info( _info ), prevAssertions( _prevAssertions ), durationInSeconds( _durationInSeconds ) + {} + + SectionInfo info; + Counts prevAssertions; + double durationInSeconds; + }; + + TestRunInfo m_runInfo; + IMutableContext& m_context; + TestCase const* m_activeTestCase; + Option m_testCaseTracker; + AssertionResult m_lastResult; + + Ptr m_config; + Totals m_totals; + Ptr m_reporter; + std::vector m_messages; + IRunner* m_prevRunner; + IResultCapture* m_prevResultCapture; + Ptr m_prevConfig; + AssertionInfo m_lastAssertionInfo; + std::vector m_unfinishedSections; + }; + + IResultCapture& getResultCapture() { + if( IResultCapture* capture = getCurrentContext().getResultCapture() ) + return *capture; + else + throw std::logic_error( "No result capture instance" ); + } + +} // end namespace Catch + +// #included from: internal/catch_version.h +#define TWOBLUECUBES_CATCH_VERSION_H_INCLUDED + +namespace Catch { + + // Versioning information + struct Version { + Version( unsigned int _majorVersion, + unsigned int _minorVersion, + unsigned int _buildNumber, + char const* const _branchName ) + : majorVersion( _majorVersion ), + minorVersion( _minorVersion ), + buildNumber( _buildNumber ), + branchName( _branchName ) + {} + + unsigned int const majorVersion; + unsigned int const minorVersion; + unsigned int const buildNumber; + char const* const branchName; + + private: + void operator=( Version const& ); + }; + + extern Version libraryVersion; +} + +#include +#include +#include + +namespace Catch { + + class Runner { + + public: + Runner( Ptr const& config ) + : m_config( config ) + { + openStream(); + makeReporter(); + } + + Totals runTests() { + + RunContext context( m_config.get(), m_reporter ); + + Totals totals; + + context.testGroupStarting( "", 1, 1 ); // deprecated? + + TestSpec testSpec = m_config->testSpec(); + if( !testSpec.hasFilters() ) + testSpec = TestSpecParser().parse( "~[.]" ).testSpec(); // All not hidden tests + + std::vector testCases; + getRegistryHub().getTestCaseRegistry().getFilteredTests( testSpec, *m_config, testCases ); + + int testsRunForGroup = 0; + for( std::vector::const_iterator it = testCases.begin(), itEnd = testCases.end(); + it != itEnd; + ++it ) { + testsRunForGroup++; + if( m_testsAlreadyRun.find( *it ) == m_testsAlreadyRun.end() ) { + + if( context.aborting() ) + break; + + totals += context.runTest( *it ); + m_testsAlreadyRun.insert( *it ); + } + } + context.testGroupEnded( "", totals, 1, 1 ); + return totals; + } + + private: + void openStream() { + // Open output file, if specified + if( !m_config->getFilename().empty() ) { + m_ofs.open( m_config->getFilename().c_str() ); + if( m_ofs.fail() ) { + std::ostringstream oss; + oss << "Unable to open file: '" << m_config->getFilename() << "'"; + throw std::domain_error( oss.str() ); + } + m_config->setStreamBuf( m_ofs.rdbuf() ); + } + } + void makeReporter() { + std::string reporterName = m_config->getReporterName().empty() + ? "console" + : m_config->getReporterName(); + + m_reporter = getRegistryHub().getReporterRegistry().create( reporterName, m_config.get() ); + if( !m_reporter ) { + std::ostringstream oss; + oss << "No reporter registered with name: '" << reporterName << "'"; + throw std::domain_error( oss.str() ); + } + } + + private: + Ptr m_config; + std::ofstream m_ofs; + Ptr m_reporter; + std::set m_testsAlreadyRun; + }; + + class Session { + static bool alreadyInstantiated; + + public: + + struct OnUnusedOptions { enum DoWhat { Ignore, Fail }; }; + + Session() + : m_cli( makeCommandLineParser() ) { + if( alreadyInstantiated ) { + std::string msg = "Only one instance of Catch::Session can ever be used"; + std::cerr << msg << std::endl; + throw std::logic_error( msg ); + } + alreadyInstantiated = true; + } + ~Session() { + Catch::cleanUp(); + } + + void showHelp( std::string const& processName ) { + std::cout << "\nCatch v" << libraryVersion.majorVersion << "." + << libraryVersion.minorVersion << " build " + << libraryVersion.buildNumber; + if( libraryVersion.branchName != std::string( "master" ) ) + std::cout << " (" << libraryVersion.branchName << " branch)"; + std::cout << "\n"; + + m_cli.usage( std::cout, processName ); + std::cout << "For more detail usage please see the project docs\n" << std::endl; + } + + int applyCommandLine( int argc, char* const argv[], OnUnusedOptions::DoWhat unusedOptionBehaviour = OnUnusedOptions::Fail ) { + try { + m_cli.setThrowOnUnrecognisedTokens( unusedOptionBehaviour == OnUnusedOptions::Fail ); + m_unusedTokens = m_cli.parseInto( argc, argv, m_configData ); + if( m_configData.showHelp ) + showHelp( m_configData.processName ); + m_config.reset(); + } + catch( std::exception& ex ) { + { + Colour colourGuard( Colour::Red ); + std::cerr << "\nError(s) in input:\n" + << Text( ex.what(), TextAttributes().setIndent(2) ) + << "\n\n"; + } + m_cli.usage( std::cout, m_configData.processName ); + return (std::numeric_limits::max)(); + } + return 0; + } + + void useConfigData( ConfigData const& _configData ) { + m_configData = _configData; + m_config.reset(); + } + + int run( int argc, char* const argv[] ) { + + int returnCode = applyCommandLine( argc, argv ); + if( returnCode == 0 ) + returnCode = run(); + return returnCode; + } + + int run() { + if( m_configData.showHelp ) + return 0; + + try + { + config(); // Force config to be constructed + Runner runner( m_config ); + + // Handle list request + if( Option listed = list( config() ) ) + return static_cast( *listed ); + + return static_cast( runner.runTests().assertions.failed ); + } + catch( std::exception& ex ) { + std::cerr << ex.what() << std::endl; + return (std::numeric_limits::max)(); + } + } + + Clara::CommandLine const& cli() const { + return m_cli; + } + std::vector const& unusedTokens() const { + return m_unusedTokens; + } + ConfigData& configData() { + return m_configData; + } + Config& config() { + if( !m_config ) + m_config = new Config( m_configData ); + return *m_config; + } + + private: + Clara::CommandLine m_cli; + std::vector m_unusedTokens; + ConfigData m_configData; + Ptr m_config; + }; + + bool Session::alreadyInstantiated = false; + +} // end namespace Catch + +// #included from: catch_registry_hub.hpp +#define TWOBLUECUBES_CATCH_REGISTRY_HUB_HPP_INCLUDED + +// #included from: catch_test_case_registry_impl.hpp +#define TWOBLUECUBES_CATCH_TEST_CASE_REGISTRY_IMPL_HPP_INCLUDED + +#include +#include +#include +#include + +namespace Catch { + + class TestRegistry : public ITestCaseRegistry { + public: + TestRegistry() : m_unnamedCount( 0 ) {} + virtual ~TestRegistry(); + + virtual void registerTest( TestCase const& testCase ) { + std::string name = testCase.getTestCaseInfo().name; + if( name == "" ) { + std::ostringstream oss; + oss << "Anonymous test case " << ++m_unnamedCount; + return registerTest( testCase.withName( oss.str() ) ); + } + + if( m_functions.find( testCase ) == m_functions.end() ) { + m_functions.insert( testCase ); + m_functionsInOrder.push_back( testCase ); + if( !testCase.isHidden() ) + m_nonHiddenFunctions.push_back( testCase ); + } + else { + TestCase const& prev = *m_functions.find( testCase ); + { + Colour colourGuard( Colour::Red ); + std::cerr << "error: TEST_CASE( \"" << name << "\" ) already defined.\n" + << "\tFirst seen at " << prev.getTestCaseInfo().lineInfo << "\n" + << "\tRedefined at " << testCase.getTestCaseInfo().lineInfo << std::endl; + } + exit(1); + } + } + + virtual std::vector const& getAllTests() const { + return m_functionsInOrder; + } + + virtual std::vector const& getAllNonHiddenTests() const { + return m_nonHiddenFunctions; + } + + virtual void getFilteredTests( TestSpec const& testSpec, IConfig const& config, std::vector& matchingTestCases ) const { + for( std::vector::const_iterator it = m_functionsInOrder.begin(), + itEnd = m_functionsInOrder.end(); + it != itEnd; + ++it ) { + if( testSpec.matches( *it ) && ( config.allowThrows() || !it->throws() ) ) + matchingTestCases.push_back( *it ); + } + } + + private: + + std::set m_functions; + std::vector m_functionsInOrder; + std::vector m_nonHiddenFunctions; + size_t m_unnamedCount; + }; + + /////////////////////////////////////////////////////////////////////////// + + class FreeFunctionTestCase : public SharedImpl { + public: + + FreeFunctionTestCase( TestFunction fun ) : m_fun( fun ) {} + + virtual void invoke() const { + m_fun(); + } + + private: + virtual ~FreeFunctionTestCase(); + + TestFunction m_fun; + }; + + inline std::string extractClassName( std::string const& classOrQualifiedMethodName ) { + std::string className = classOrQualifiedMethodName; + if( startsWith( className, "&" ) ) + { + std::size_t lastColons = className.rfind( "::" ); + std::size_t penultimateColons = className.rfind( "::", lastColons-1 ); + if( penultimateColons == std::string::npos ) + penultimateColons = 1; + className = className.substr( penultimateColons, lastColons-penultimateColons ); + } + return className; + } + + /////////////////////////////////////////////////////////////////////////// + + AutoReg::AutoReg( TestFunction function, + SourceLineInfo const& lineInfo, + NameAndDesc const& nameAndDesc ) { + registerTestCase( new FreeFunctionTestCase( function ), "", nameAndDesc, lineInfo ); + } + + AutoReg::~AutoReg() {} + + void AutoReg::registerTestCase( ITestCase* testCase, + char const* classOrQualifiedMethodName, + NameAndDesc const& nameAndDesc, + SourceLineInfo const& lineInfo ) { + + getMutableRegistryHub().registerTest + ( makeTestCase( testCase, + extractClassName( classOrQualifiedMethodName ), + nameAndDesc.name, + nameAndDesc.description, + lineInfo ) ); + } + +} // end namespace Catch + +// #included from: catch_reporter_registry.hpp +#define TWOBLUECUBES_CATCH_REPORTER_REGISTRY_HPP_INCLUDED + +#include + +namespace Catch { + + class ReporterRegistry : public IReporterRegistry { + + public: + + virtual ~ReporterRegistry() { + deleteAllValues( m_factories ); + } + + virtual IStreamingReporter* create( std::string const& name, Ptr const& config ) const { + FactoryMap::const_iterator it = m_factories.find( name ); + if( it == m_factories.end() ) + return NULL; + return it->second->create( ReporterConfig( config ) ); + } + + void registerReporter( std::string const& name, IReporterFactory* factory ) { + m_factories.insert( std::make_pair( name, factory ) ); + } + + FactoryMap const& getFactories() const { + return m_factories; + } + + private: + FactoryMap m_factories; + }; +} + +// #included from: catch_exception_translator_registry.hpp +#define TWOBLUECUBES_CATCH_EXCEPTION_TRANSLATOR_REGISTRY_HPP_INCLUDED + +#ifdef __OBJC__ +#import "Foundation/Foundation.h" +#endif + +namespace Catch { + + class ExceptionTranslatorRegistry : public IExceptionTranslatorRegistry { + public: + ~ExceptionTranslatorRegistry() { + deleteAll( m_translators ); + } + + virtual void registerTranslator( const IExceptionTranslator* translator ) { + m_translators.push_back( translator ); + } + + virtual std::string translateActiveException() const { + try { +#ifdef __OBJC__ + // In Objective-C try objective-c exceptions first + @try { + throw; + } + @catch (NSException *exception) { + return toString( [exception description] ); + } +#else + throw; +#endif + } + catch( TestFailureException& ) { + throw; + } + catch( std::exception& ex ) { + return ex.what(); + } + catch( std::string& msg ) { + return msg; + } + catch( const char* msg ) { + return msg; + } + catch(...) { + return tryTranslators( m_translators.begin() ); + } + } + + std::string tryTranslators( std::vector::const_iterator it ) const { + if( it == m_translators.end() ) + return "Unknown exception"; + + try { + return (*it)->translate(); + } + catch(...) { + return tryTranslators( it+1 ); + } + } + + private: + std::vector m_translators; + }; +} + +namespace Catch { + + namespace { + + class RegistryHub : public IRegistryHub, public IMutableRegistryHub { + + RegistryHub( RegistryHub const& ); + void operator=( RegistryHub const& ); + + public: // IRegistryHub + RegistryHub() { + } + virtual IReporterRegistry const& getReporterRegistry() const { + return m_reporterRegistry; + } + virtual ITestCaseRegistry const& getTestCaseRegistry() const { + return m_testCaseRegistry; + } + virtual IExceptionTranslatorRegistry& getExceptionTranslatorRegistry() { + return m_exceptionTranslatorRegistry; + } + + public: // IMutableRegistryHub + virtual void registerReporter( std::string const& name, IReporterFactory* factory ) { + m_reporterRegistry.registerReporter( name, factory ); + } + virtual void registerTest( TestCase const& testInfo ) { + m_testCaseRegistry.registerTest( testInfo ); + } + virtual void registerTranslator( const IExceptionTranslator* translator ) { + m_exceptionTranslatorRegistry.registerTranslator( translator ); + } + + private: + TestRegistry m_testCaseRegistry; + ReporterRegistry m_reporterRegistry; + ExceptionTranslatorRegistry m_exceptionTranslatorRegistry; + }; + + // Single, global, instance + inline RegistryHub*& getTheRegistryHub() { + static RegistryHub* theRegistryHub = NULL; + if( !theRegistryHub ) + theRegistryHub = new RegistryHub(); + return theRegistryHub; + } + } + + IRegistryHub& getRegistryHub() { + return *getTheRegistryHub(); + } + IMutableRegistryHub& getMutableRegistryHub() { + return *getTheRegistryHub(); + } + void cleanUp() { + delete getTheRegistryHub(); + getTheRegistryHub() = NULL; + cleanUpContext(); + } + std::string translateActiveException() { + return getRegistryHub().getExceptionTranslatorRegistry().translateActiveException(); + } + +} // end namespace Catch + +// #included from: catch_notimplemented_exception.hpp +#define TWOBLUECUBES_CATCH_NOTIMPLEMENTED_EXCEPTION_HPP_INCLUDED + +#include + +namespace Catch { + + NotImplementedException::NotImplementedException( SourceLineInfo const& lineInfo ) + : m_lineInfo( lineInfo ) { + std::ostringstream oss; + oss << lineInfo << ": function "; + oss << "not implemented"; + m_what = oss.str(); + } + + const char* NotImplementedException::what() const CATCH_NOEXCEPT { + return m_what.c_str(); + } + +} // end namespace Catch + +// #included from: catch_context_impl.hpp +#define TWOBLUECUBES_CATCH_CONTEXT_IMPL_HPP_INCLUDED + +// #included from: catch_stream.hpp +#define TWOBLUECUBES_CATCH_STREAM_HPP_INCLUDED + +// #included from: catch_streambuf.h +#define TWOBLUECUBES_CATCH_STREAMBUF_H_INCLUDED + +#include + +namespace Catch { + + class StreamBufBase : public std::streambuf { + public: + virtual ~StreamBufBase() CATCH_NOEXCEPT; + }; +} + +#include +#include + +namespace Catch { + + template + class StreamBufImpl : public StreamBufBase { + char data[bufferSize]; + WriterF m_writer; + + public: + StreamBufImpl() { + setp( data, data + sizeof(data) ); + } + + ~StreamBufImpl() CATCH_NOEXCEPT { + sync(); + } + + private: + int overflow( int c ) { + sync(); + + if( c != EOF ) { + if( pbase() == epptr() ) + m_writer( std::string( 1, static_cast( c ) ) ); + else + sputc( static_cast( c ) ); + } + return 0; + } + + int sync() { + if( pbase() != pptr() ) { + m_writer( std::string( pbase(), static_cast( pptr() - pbase() ) ) ); + setp( pbase(), epptr() ); + } + return 0; + } + }; + + /////////////////////////////////////////////////////////////////////////// + + struct OutputDebugWriter { + + void operator()( std::string const&str ) { + writeToDebugConsole( str ); + } + }; + + Stream::Stream() + : streamBuf( NULL ), isOwned( false ) + {} + + Stream::Stream( std::streambuf* _streamBuf, bool _isOwned ) + : streamBuf( _streamBuf ), isOwned( _isOwned ) + {} + + void Stream::release() { + if( isOwned ) { + delete streamBuf; + streamBuf = NULL; + isOwned = false; + } + } +} + +namespace Catch { + + class Context : public IMutableContext { + + Context() : m_config( NULL ), m_runner( NULL ), m_resultCapture( NULL ) {} + Context( Context const& ); + void operator=( Context const& ); + + public: // IContext + virtual IResultCapture* getResultCapture() { + return m_resultCapture; + } + virtual IRunner* getRunner() { + return m_runner; + } + virtual size_t getGeneratorIndex( std::string const& fileInfo, size_t totalSize ) { + return getGeneratorsForCurrentTest() + .getGeneratorInfo( fileInfo, totalSize ) + .getCurrentIndex(); + } + virtual bool advanceGeneratorsForCurrentTest() { + IGeneratorsForTest* generators = findGeneratorsForCurrentTest(); + return generators && generators->moveNext(); + } + + virtual Ptr getConfig() const { + return m_config; + } + + public: // IMutableContext + virtual void setResultCapture( IResultCapture* resultCapture ) { + m_resultCapture = resultCapture; + } + virtual void setRunner( IRunner* runner ) { + m_runner = runner; + } + virtual void setConfig( Ptr const& config ) { + m_config = config; + } + + friend IMutableContext& getCurrentMutableContext(); + + private: + IGeneratorsForTest* findGeneratorsForCurrentTest() { + std::string testName = getResultCapture()->getCurrentTestName(); + + std::map::const_iterator it = + m_generatorsByTestName.find( testName ); + return it != m_generatorsByTestName.end() + ? it->second + : NULL; + } + + IGeneratorsForTest& getGeneratorsForCurrentTest() { + IGeneratorsForTest* generators = findGeneratorsForCurrentTest(); + if( !generators ) { + std::string testName = getResultCapture()->getCurrentTestName(); + generators = createGeneratorsForTest(); + m_generatorsByTestName.insert( std::make_pair( testName, generators ) ); + } + return *generators; + } + + private: + Ptr m_config; + IRunner* m_runner; + IResultCapture* m_resultCapture; + std::map m_generatorsByTestName; + }; + + namespace { + Context* currentContext = NULL; + } + IMutableContext& getCurrentMutableContext() { + if( !currentContext ) + currentContext = new Context(); + return *currentContext; + } + IContext& getCurrentContext() { + return getCurrentMutableContext(); + } + + Stream createStream( std::string const& streamName ) { + if( streamName == "stdout" ) return Stream( std::cout.rdbuf(), false ); + if( streamName == "stderr" ) return Stream( std::cerr.rdbuf(), false ); + if( streamName == "debug" ) return Stream( new StreamBufImpl, true ); + + throw std::domain_error( "Unknown stream: " + streamName ); + } + + void cleanUpContext() { + delete currentContext; + currentContext = NULL; + } +} + +// #included from: catch_console_colour_impl.hpp +#define TWOBLUECUBES_CATCH_CONSOLE_COLOUR_IMPL_HPP_INCLUDED + +namespace Catch { namespace Detail { + struct IColourImpl { + virtual ~IColourImpl() {} + virtual void use( Colour::Code _colourCode ) = 0; + }; +}} + +#if defined ( CATCH_PLATFORM_WINDOWS ) ///////////////////////////////////////// + +#ifndef NOMINMAX +#define NOMINMAX +#endif + +#ifdef __AFXDLL +#include +#else +#include +#endif + +namespace Catch { +namespace { + + class Win32ColourImpl : public Detail::IColourImpl { + public: + Win32ColourImpl() : stdoutHandle( GetStdHandle(STD_OUTPUT_HANDLE) ) + { + CONSOLE_SCREEN_BUFFER_INFO csbiInfo; + GetConsoleScreenBufferInfo( stdoutHandle, &csbiInfo ); + originalAttributes = csbiInfo.wAttributes; + } + + virtual void use( Colour::Code _colourCode ) { + switch( _colourCode ) { + case Colour::None: return setTextAttribute( originalAttributes ); + case Colour::White: return setTextAttribute( FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE ); + case Colour::Red: return setTextAttribute( FOREGROUND_RED ); + case Colour::Green: return setTextAttribute( FOREGROUND_GREEN ); + case Colour::Blue: return setTextAttribute( FOREGROUND_BLUE ); + case Colour::Cyan: return setTextAttribute( FOREGROUND_BLUE | FOREGROUND_GREEN ); + case Colour::Yellow: return setTextAttribute( FOREGROUND_RED | FOREGROUND_GREEN ); + case Colour::Grey: return setTextAttribute( 0 ); + + case Colour::LightGrey: return setTextAttribute( FOREGROUND_INTENSITY ); + case Colour::BrightRed: return setTextAttribute( FOREGROUND_INTENSITY | FOREGROUND_RED ); + case Colour::BrightGreen: return setTextAttribute( FOREGROUND_INTENSITY | FOREGROUND_GREEN ); + case Colour::BrightWhite: return setTextAttribute( FOREGROUND_INTENSITY | FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE ); + + case Colour::Bright: throw std::logic_error( "not a colour" ); + } + } + + private: + void setTextAttribute( WORD _textAttribute ) { + SetConsoleTextAttribute( stdoutHandle, _textAttribute ); + } + HANDLE stdoutHandle; + WORD originalAttributes; + }; + + inline bool shouldUseColourForPlatform() { + return true; + } + + static Detail::IColourImpl* platformColourInstance() { + static Win32ColourImpl s_instance; + return &s_instance; + } + +} // end anon namespace +} // end namespace Catch + +#else // Not Windows - assumed to be POSIX compatible ////////////////////////// + +#include + +namespace Catch { +namespace { + + // use POSIX/ ANSI console terminal codes + // Thanks to Adam Strzelecki for original contribution + // (http://github.com/nanoant) + // https://github.com/philsquared/Catch/pull/131 + class PosixColourImpl : public Detail::IColourImpl { + public: + virtual void use( Colour::Code _colourCode ) { + switch( _colourCode ) { + case Colour::None: + case Colour::White: return setColour( "[0m" ); + case Colour::Red: return setColour( "[0;31m" ); + case Colour::Green: return setColour( "[0;32m" ); + case Colour::Blue: return setColour( "[0:34m" ); + case Colour::Cyan: return setColour( "[0;36m" ); + case Colour::Yellow: return setColour( "[0;33m" ); + case Colour::Grey: return setColour( "[1;30m" ); + + case Colour::LightGrey: return setColour( "[0;37m" ); + case Colour::BrightRed: return setColour( "[1;31m" ); + case Colour::BrightGreen: return setColour( "[1;32m" ); + case Colour::BrightWhite: return setColour( "[1;37m" ); + + case Colour::Bright: throw std::logic_error( "not a colour" ); + } + } + private: + void setColour( const char* _escapeCode ) { + std::cout << '\033' << _escapeCode; + } + }; + + inline bool shouldUseColourForPlatform() { + return isatty(STDOUT_FILENO); + } + + static Detail::IColourImpl* platformColourInstance() { + static PosixColourImpl s_instance; + return &s_instance; + } + +} // end anon namespace +} // end namespace Catch + +#endif // not Windows + +namespace Catch { + + namespace { + struct NoColourImpl : Detail::IColourImpl { + void use( Colour::Code ) {} + + static IColourImpl* instance() { + static NoColourImpl s_instance; + return &s_instance; + } + }; + static bool shouldUseColour() { + return shouldUseColourForPlatform() && !isDebuggerActive(); + } + } + + Colour::Colour( Code _colourCode ){ use( _colourCode ); } + Colour::~Colour(){ use( None ); } + void Colour::use( Code _colourCode ) { + impl()->use( _colourCode ); + } + + Detail::IColourImpl* Colour::impl() { + return shouldUseColour() + ? platformColourInstance() + : NoColourImpl::instance(); + } + +} // end namespace Catch + +// #included from: catch_generators_impl.hpp +#define TWOBLUECUBES_CATCH_GENERATORS_IMPL_HPP_INCLUDED + +#include +#include +#include + +namespace Catch { + + struct GeneratorInfo : IGeneratorInfo { + + GeneratorInfo( std::size_t size ) + : m_size( size ), + m_currentIndex( 0 ) + {} + + bool moveNext() { + if( ++m_currentIndex == m_size ) { + m_currentIndex = 0; + return false; + } + return true; + } + + std::size_t getCurrentIndex() const { + return m_currentIndex; + } + + std::size_t m_size; + std::size_t m_currentIndex; + }; + + /////////////////////////////////////////////////////////////////////////// + + class GeneratorsForTest : public IGeneratorsForTest { + + public: + ~GeneratorsForTest() { + deleteAll( m_generatorsInOrder ); + } + + IGeneratorInfo& getGeneratorInfo( std::string const& fileInfo, std::size_t size ) { + std::map::const_iterator it = m_generatorsByName.find( fileInfo ); + if( it == m_generatorsByName.end() ) { + IGeneratorInfo* info = new GeneratorInfo( size ); + m_generatorsByName.insert( std::make_pair( fileInfo, info ) ); + m_generatorsInOrder.push_back( info ); + return *info; + } + return *it->second; + } + + bool moveNext() { + std::vector::const_iterator it = m_generatorsInOrder.begin(); + std::vector::const_iterator itEnd = m_generatorsInOrder.end(); + for(; it != itEnd; ++it ) { + if( (*it)->moveNext() ) + return true; + } + return false; + } + + private: + std::map m_generatorsByName; + std::vector m_generatorsInOrder; + }; + + IGeneratorsForTest* createGeneratorsForTest() + { + return new GeneratorsForTest(); + } + +} // end namespace Catch + +// #included from: catch_assertionresult.hpp +#define TWOBLUECUBES_CATCH_ASSERTIONRESULT_HPP_INCLUDED + +namespace Catch { + + AssertionInfo::AssertionInfo( std::string const& _macroName, + SourceLineInfo const& _lineInfo, + std::string const& _capturedExpression, + ResultDisposition::Flags _resultDisposition ) + : macroName( _macroName ), + lineInfo( _lineInfo ), + capturedExpression( _capturedExpression ), + resultDisposition( _resultDisposition ) + {} + + AssertionResult::AssertionResult() {} + + AssertionResult::AssertionResult( AssertionInfo const& info, AssertionResultData const& data ) + : m_info( info ), + m_resultData( data ) + {} + + AssertionResult::~AssertionResult() {} + + // Result was a success + bool AssertionResult::succeeded() const { + return Catch::isOk( m_resultData.resultType ); + } + + // Result was a success, or failure is suppressed + bool AssertionResult::isOk() const { + return Catch::isOk( m_resultData.resultType ) || shouldSuppressFailure( m_info.resultDisposition ); + } + + ResultWas::OfType AssertionResult::getResultType() const { + return m_resultData.resultType; + } + + bool AssertionResult::hasExpression() const { + return !m_info.capturedExpression.empty(); + } + + bool AssertionResult::hasMessage() const { + return !m_resultData.message.empty(); + } + + std::string AssertionResult::getExpression() const { + if( isFalseTest( m_info.resultDisposition ) ) + return "!" + m_info.capturedExpression; + else + return m_info.capturedExpression; + } + std::string AssertionResult::getExpressionInMacro() const { + if( m_info.macroName.empty() ) + return m_info.capturedExpression; + else + return m_info.macroName + "( " + m_info.capturedExpression + " )"; + } + + bool AssertionResult::hasExpandedExpression() const { + return hasExpression() && getExpandedExpression() != getExpression(); + } + + std::string AssertionResult::getExpandedExpression() const { + return m_resultData.reconstructedExpression; + } + + std::string AssertionResult::getMessage() const { + return m_resultData.message; + } + SourceLineInfo AssertionResult::getSourceInfo() const { + return m_info.lineInfo; + } + + std::string AssertionResult::getTestMacroName() const { + return m_info.macroName; + } + +} // end namespace Catch + +// #included from: catch_test_case_info.hpp +#define TWOBLUECUBES_CATCH_TEST_CASE_INFO_HPP_INCLUDED + +namespace Catch { + + inline bool isSpecialTag( std::string const& tag ) { + return tag == "." || + tag == "hide" || + tag == "!hide" || + tag == "!throws"; + } + inline bool isReservedTag( std::string const& tag ) { + return !isSpecialTag( tag ) && tag.size() > 0 && !isalnum( tag[0] ); + } + inline void enforceNotReservedTag( std::string const& tag, SourceLineInfo const& _lineInfo ) { + if( isReservedTag( tag ) ) { + { + Colour colourGuard( Colour::Red ); + std::cerr + << "Tag name [" << tag << "] not allowed.\n" + << "Tag names starting with non alpha-numeric characters are reserved\n"; + } + { + Colour colourGuard( Colour::FileName ); + std::cerr << _lineInfo << std::endl; + } + exit(1); + } + } + + TestCase makeTestCase( ITestCase* _testCase, + std::string const& _className, + std::string const& _name, + std::string const& _descOrTags, + SourceLineInfo const& _lineInfo ) + { + bool isHidden( startsWith( _name, "./" ) ); // Legacy support + + // Parse out tags + std::set tags; + std::string desc, tag; + bool inTag = false; + for( std::size_t i = 0; i < _descOrTags.size(); ++i ) { + char c = _descOrTags[i]; + if( !inTag ) { + if( c == '[' ) + inTag = true; + else + desc += c; + } + else { + if( c == ']' ) { + enforceNotReservedTag( tag, _lineInfo ); + + inTag = false; + if( tag == "hide" || tag == "." ) + isHidden = true; + else + tags.insert( tag ); + tag.clear(); + } + else + tag += c; + } + } + if( isHidden ) { + tags.insert( "hide" ); + tags.insert( "." ); + } + + TestCaseInfo info( _name, _className, desc, tags, isHidden, _lineInfo ); + return TestCase( _testCase, info ); + } + + TestCaseInfo::TestCaseInfo( std::string const& _name, + std::string const& _className, + std::string const& _description, + std::set const& _tags, + bool _isHidden, + SourceLineInfo const& _lineInfo ) + : name( _name ), + className( _className ), + description( _description ), + tags( _tags ), + lineInfo( _lineInfo ), + isHidden( _isHidden ), + throws( false ) + { + std::ostringstream oss; + for( std::set::const_iterator it = _tags.begin(), itEnd = _tags.end(); it != itEnd; ++it ) { + oss << "[" << *it << "]"; + if( *it == "!throws" ) + throws = true; + lcaseTags.insert( toLower( *it ) ); + } + tagsAsString = oss.str(); + } + + TestCaseInfo::TestCaseInfo( TestCaseInfo const& other ) + : name( other.name ), + className( other.className ), + description( other.description ), + tags( other.tags ), + lcaseTags( other.lcaseTags ), + tagsAsString( other.tagsAsString ), + lineInfo( other.lineInfo ), + isHidden( other.isHidden ), + throws( other.throws ) + {} + + TestCase::TestCase( ITestCase* testCase, TestCaseInfo const& info ) : TestCaseInfo( info ), test( testCase ) {} + + TestCase::TestCase( TestCase const& other ) + : TestCaseInfo( other ), + test( other.test ) + {} + + TestCase TestCase::withName( std::string const& _newName ) const { + TestCase other( *this ); + other.name = _newName; + return other; + } + + void TestCase::swap( TestCase& other ) { + test.swap( other.test ); + name.swap( other.name ); + className.swap( other.className ); + description.swap( other.description ); + tags.swap( other.tags ); + lcaseTags.swap( other.lcaseTags ); + tagsAsString.swap( other.tagsAsString ); + std::swap( TestCaseInfo::isHidden, static_cast( other ).isHidden ); + std::swap( TestCaseInfo::throws, static_cast( other ).throws ); + std::swap( lineInfo, other.lineInfo ); + } + + void TestCase::invoke() const { + test->invoke(); + } + + bool TestCase::isHidden() const { + return TestCaseInfo::isHidden; + } + bool TestCase::throws() const { + return TestCaseInfo::throws; + } + + bool TestCase::operator == ( TestCase const& other ) const { + return test.get() == other.test.get() && + name == other.name && + className == other.className; + } + + bool TestCase::operator < ( TestCase const& other ) const { + return name < other.name; + } + TestCase& TestCase::operator = ( TestCase const& other ) { + TestCase temp( other ); + swap( temp ); + return *this; + } + + TestCaseInfo const& TestCase::getTestCaseInfo() const + { + return *this; + } + +} // end namespace Catch + +// #included from: catch_version.hpp +#define TWOBLUECUBES_CATCH_VERSION_HPP_INCLUDED + +namespace Catch { + + // These numbers are maintained by a script + Version libraryVersion( 1, 0, 48, "master" ); +} + +// #included from: catch_message.hpp +#define TWOBLUECUBES_CATCH_MESSAGE_HPP_INCLUDED + +namespace Catch { + + MessageInfo::MessageInfo( std::string const& _macroName, + SourceLineInfo const& _lineInfo, + ResultWas::OfType _type ) + : macroName( _macroName ), + lineInfo( _lineInfo ), + type( _type ), + sequence( ++globalCount ) + {} + + // This may need protecting if threading support is added + unsigned int MessageInfo::globalCount = 0; + + //////////////////////////////////////////////////////////////////////////// + + ScopedMessage::ScopedMessage( MessageBuilder const& builder ) + : m_info( builder.m_info ) + { + m_info.message = builder.m_stream.str(); + getResultCapture().pushScopedMessage( m_info ); + } + ScopedMessage::ScopedMessage( ScopedMessage const& other ) + : m_info( other.m_info ) + {} + + ScopedMessage::~ScopedMessage() { + getResultCapture().popScopedMessage( m_info ); + } + +} // end namespace Catch + +// #included from: catch_legacy_reporter_adapter.hpp +#define TWOBLUECUBES_CATCH_LEGACY_REPORTER_ADAPTER_HPP_INCLUDED + +// #included from: catch_legacy_reporter_adapter.h +#define TWOBLUECUBES_CATCH_LEGACY_REPORTER_ADAPTER_H_INCLUDED + +namespace Catch +{ + // Deprecated + struct IReporter : IShared { + virtual ~IReporter(); + + virtual bool shouldRedirectStdout() const = 0; + + virtual void StartTesting() = 0; + virtual void EndTesting( Totals const& totals ) = 0; + virtual void StartGroup( std::string const& groupName ) = 0; + virtual void EndGroup( std::string const& groupName, Totals const& totals ) = 0; + virtual void StartTestCase( TestCaseInfo const& testInfo ) = 0; + virtual void EndTestCase( TestCaseInfo const& testInfo, Totals const& totals, std::string const& stdOut, std::string const& stdErr ) = 0; + virtual void StartSection( std::string const& sectionName, std::string const& description ) = 0; + virtual void EndSection( std::string const& sectionName, Counts const& assertions ) = 0; + virtual void NoAssertionsInSection( std::string const& sectionName ) = 0; + virtual void NoAssertionsInTestCase( std::string const& testName ) = 0; + virtual void Aborted() = 0; + virtual void Result( AssertionResult const& result ) = 0; + }; + + class LegacyReporterAdapter : public SharedImpl + { + public: + LegacyReporterAdapter( Ptr const& legacyReporter ); + virtual ~LegacyReporterAdapter(); + + virtual ReporterPreferences getPreferences() const; + virtual void noMatchingTestCases( std::string const& ); + virtual void testRunStarting( TestRunInfo const& ); + virtual void testGroupStarting( GroupInfo const& groupInfo ); + virtual void testCaseStarting( TestCaseInfo const& testInfo ); + virtual void sectionStarting( SectionInfo const& sectionInfo ); + virtual void assertionStarting( AssertionInfo const& ); + virtual bool assertionEnded( AssertionStats const& assertionStats ); + virtual void sectionEnded( SectionStats const& sectionStats ); + virtual void testCaseEnded( TestCaseStats const& testCaseStats ); + virtual void testGroupEnded( TestGroupStats const& testGroupStats ); + virtual void testRunEnded( TestRunStats const& testRunStats ); + + private: + Ptr m_legacyReporter; + }; +} + +namespace Catch +{ + LegacyReporterAdapter::LegacyReporterAdapter( Ptr const& legacyReporter ) + : m_legacyReporter( legacyReporter ) + {} + LegacyReporterAdapter::~LegacyReporterAdapter() {} + + ReporterPreferences LegacyReporterAdapter::getPreferences() const { + ReporterPreferences prefs; + prefs.shouldRedirectStdOut = m_legacyReporter->shouldRedirectStdout(); + return prefs; + } + + void LegacyReporterAdapter::noMatchingTestCases( std::string const& ) {} + void LegacyReporterAdapter::testRunStarting( TestRunInfo const& ) { + m_legacyReporter->StartTesting(); + } + void LegacyReporterAdapter::testGroupStarting( GroupInfo const& groupInfo ) { + m_legacyReporter->StartGroup( groupInfo.name ); + } + void LegacyReporterAdapter::testCaseStarting( TestCaseInfo const& testInfo ) { + m_legacyReporter->StartTestCase( testInfo ); + } + void LegacyReporterAdapter::sectionStarting( SectionInfo const& sectionInfo ) { + m_legacyReporter->StartSection( sectionInfo.name, sectionInfo.description ); + } + void LegacyReporterAdapter::assertionStarting( AssertionInfo const& ) { + // Not on legacy interface + } + + bool LegacyReporterAdapter::assertionEnded( AssertionStats const& assertionStats ) { + if( assertionStats.assertionResult.getResultType() != ResultWas::Ok ) { + for( std::vector::const_iterator it = assertionStats.infoMessages.begin(), itEnd = assertionStats.infoMessages.end(); + it != itEnd; + ++it ) { + if( it->type == ResultWas::Info ) { + ResultBuilder rb( it->macroName.c_str(), it->lineInfo, "", ResultDisposition::Normal ); + rb << it->message; + rb.setResultType( ResultWas::Info ); + AssertionResult result = rb.build(); + m_legacyReporter->Result( result ); + } + } + } + m_legacyReporter->Result( assertionStats.assertionResult ); + return true; + } + void LegacyReporterAdapter::sectionEnded( SectionStats const& sectionStats ) { + if( sectionStats.missingAssertions ) + m_legacyReporter->NoAssertionsInSection( sectionStats.sectionInfo.name ); + m_legacyReporter->EndSection( sectionStats.sectionInfo.name, sectionStats.assertions ); + } + void LegacyReporterAdapter::testCaseEnded( TestCaseStats const& testCaseStats ) { + m_legacyReporter->EndTestCase + ( testCaseStats.testInfo, + testCaseStats.totals, + testCaseStats.stdOut, + testCaseStats.stdErr ); + } + void LegacyReporterAdapter::testGroupEnded( TestGroupStats const& testGroupStats ) { + if( testGroupStats.aborting ) + m_legacyReporter->Aborted(); + m_legacyReporter->EndGroup( testGroupStats.groupInfo.name, testGroupStats.totals ); + } + void LegacyReporterAdapter::testRunEnded( TestRunStats const& testRunStats ) { + m_legacyReporter->EndTesting( testRunStats.totals ); + } +} + +// #included from: catch_timer.hpp + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wc++11-long-long" +#endif + +#ifdef CATCH_PLATFORM_WINDOWS +#include +#else +#include +#endif + +namespace Catch { + + namespace { +#ifdef CATCH_PLATFORM_WINDOWS + uint64_t getCurrentTicks() { + static uint64_t hz=0, hzo=0; + if (!hz) { + QueryPerformanceFrequency((LARGE_INTEGER*)&hz); + QueryPerformanceCounter((LARGE_INTEGER*)&hzo); + } + uint64_t t; + QueryPerformanceCounter((LARGE_INTEGER*)&t); + return ((t-hzo)*1000000)/hz; + } +#else + uint64_t getCurrentTicks() { + timeval t; + gettimeofday(&t,NULL); + return (uint64_t)t.tv_sec * 1000000ull + (uint64_t)t.tv_usec; + } +#endif + } + + void Timer::start() { + m_ticks = getCurrentTicks(); + } + unsigned int Timer::getElapsedNanoseconds() const { + return (unsigned int)(getCurrentTicks() - m_ticks); + } + unsigned int Timer::getElapsedMilliseconds() const { + return (unsigned int)((getCurrentTicks() - m_ticks)/1000); + } + double Timer::getElapsedSeconds() const { + return (getCurrentTicks() - m_ticks)/1000000.0; + } + +} // namespace Catch + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif +// #included from: catch_common.hpp +#define TWOBLUECUBES_CATCH_COMMON_HPP_INCLUDED + +namespace Catch { + + bool startsWith( std::string const& s, std::string const& prefix ) { + return s.size() >= prefix.size() && s.substr( 0, prefix.size() ) == prefix; + } + bool endsWith( std::string const& s, std::string const& suffix ) { + return s.size() >= suffix.size() && s.substr( s.size()-suffix.size(), suffix.size() ) == suffix; + } + bool contains( std::string const& s, std::string const& infix ) { + return s.find( infix ) != std::string::npos; + } + void toLowerInPlace( std::string& s ) { + std::transform( s.begin(), s.end(), s.begin(), ::tolower ); + } + std::string toLower( std::string const& s ) { + std::string lc = s; + toLowerInPlace( lc ); + return lc; + } + std::string trim( std::string const& str ) { + static char const* whitespaceChars = "\n\r\t "; + std::string::size_type start = str.find_first_not_of( whitespaceChars ); + std::string::size_type end = str.find_last_not_of( whitespaceChars ); + + return start != std::string::npos ? str.substr( start, 1+end-start ) : ""; + } + + pluralise::pluralise( std::size_t count, std::string const& label ) + : m_count( count ), + m_label( label ) + {} + + std::ostream& operator << ( std::ostream& os, pluralise const& pluraliser ) { + os << pluraliser.m_count << " " << pluraliser.m_label; + if( pluraliser.m_count != 1 ) + os << "s"; + return os; + } + + SourceLineInfo::SourceLineInfo() : line( 0 ){} + SourceLineInfo::SourceLineInfo( char const* _file, std::size_t _line ) + : file( _file ), + line( _line ) + {} + SourceLineInfo::SourceLineInfo( SourceLineInfo const& other ) + : file( other.file ), + line( other.line ) + {} + bool SourceLineInfo::empty() const { + return file.empty(); + } + bool SourceLineInfo::operator == ( SourceLineInfo const& other ) const { + return line == other.line && file == other.file; + } + + std::ostream& operator << ( std::ostream& os, SourceLineInfo const& info ) { +#ifndef __GNUG__ + os << info.file << "(" << info.line << ")"; +#else + os << info.file << ":" << info.line; +#endif + return os; + } + + void throwLogicError( std::string const& message, SourceLineInfo const& locationInfo ) { + std::ostringstream oss; + oss << locationInfo << ": Internal Catch error: '" << message << "'"; + if( alwaysTrue() ) + throw std::logic_error( oss.str() ); + } +} + +// #included from: catch_section.hpp +#define TWOBLUECUBES_CATCH_SECTION_HPP_INCLUDED + +namespace Catch { + + Section::Section( SourceLineInfo const& lineInfo, + std::string const& name, + std::string const& description ) + : m_info( name, description, lineInfo ), + m_sectionIncluded( getResultCapture().sectionStarted( m_info, m_assertions ) ) + { + m_timer.start(); + } + + Section::~Section() { + if( m_sectionIncluded ) + getResultCapture().sectionEnded( m_info, m_assertions, m_timer.getElapsedSeconds() ); + } + + // This indicates whether the section should be executed or not + Section::operator bool() { + return m_sectionIncluded; + } + +} // end namespace Catch + +// #included from: catch_debugger.hpp +#define TWOBLUECUBES_CATCH_DEBUGGER_HPP_INCLUDED + +#include + +#ifdef CATCH_PLATFORM_MAC + + #include + #include + #include + #include + #include + + namespace Catch{ + + // The following function is taken directly from the following technical note: + // http://developer.apple.com/library/mac/#qa/qa2004/qa1361.html + + // Returns true if the current process is being debugged (either + // running under the debugger or has a debugger attached post facto). + bool isDebuggerActive(){ + + int mib[4]; + struct kinfo_proc info; + size_t size; + + // Initialize the flags so that, if sysctl fails for some bizarre + // reason, we get a predictable result. + + info.kp_proc.p_flag = 0; + + // Initialize mib, which tells sysctl the info we want, in this case + // we're looking for information about a specific process ID. + + mib[0] = CTL_KERN; + mib[1] = KERN_PROC; + mib[2] = KERN_PROC_PID; + mib[3] = getpid(); + + // Call sysctl. + + size = sizeof(info); + if( sysctl(mib, sizeof(mib) / sizeof(*mib), &info, &size, NULL, 0) != 0 ) { + std::cerr << "\n** Call to sysctl failed - unable to determine if debugger is active **\n" << std::endl; + return false; + } + + // We're being debugged if the P_TRACED flag is set. + + return ( (info.kp_proc.p_flag & P_TRACED) != 0 ); + } + } // namespace Catch + +#elif defined(_MSC_VER) + extern "C" __declspec(dllimport) int __stdcall IsDebuggerPresent(); + namespace Catch { + bool isDebuggerActive() { + return IsDebuggerPresent() != 0; + } + } +#elif defined(__MINGW32__) + extern "C" __declspec(dllimport) int __stdcall IsDebuggerPresent(); + namespace Catch { + bool isDebuggerActive() { + return IsDebuggerPresent() != 0; + } + } +#else + namespace Catch { + inline bool isDebuggerActive() { return false; } + } +#endif // Platform + +#ifdef CATCH_PLATFORM_WINDOWS + extern "C" __declspec(dllimport) void __stdcall OutputDebugStringA( const char* ); + namespace Catch { + void writeToDebugConsole( std::string const& text ) { + ::OutputDebugStringA( text.c_str() ); + } + } +#else + namespace Catch { + void writeToDebugConsole( std::string const& text ) { + // !TBD: Need a version for Mac/ XCode and other IDEs + std::cout << text; + } + } +#endif // Platform + +// #included from: catch_tostring.hpp +#define TWOBLUECUBES_CATCH_TOSTRING_HPP_INCLUDED + +namespace Catch { + +namespace Detail { + + namespace { + struct Endianness { + enum Arch { Big, Little }; + + static Arch which() { + union _{ + int asInt; + char asChar[sizeof (int)]; + } u; + + u.asInt = 1; + return ( u.asChar[sizeof(int)-1] == 1 ) ? Big : Little; + } + }; + } + + std::string rawMemoryToString( const void *object, std::size_t size ) + { + // Reverse order for little endian architectures + int i = 0, end = static_cast( size ), inc = 1; + if( Endianness::which() == Endianness::Little ) { + i = end-1; + end = inc = -1; + } + + unsigned char const *bytes = static_cast(object); + std::ostringstream os; + os << "0x" << std::setfill('0') << std::hex; + for( ; i != end; i += inc ) + os << std::setw(2) << static_cast(bytes[i]); + return os.str(); + } +} + +std::string toString( std::string const& value ) { + std::string s = value; + if( getCurrentContext().getConfig()->showInvisibles() ) { + for(size_t i = 0; i < s.size(); ++i ) { + std::string subs; + switch( s[i] ) { + case '\n': subs = "\\n"; break; + case '\t': subs = "\\t"; break; + default: break; + } + if( !subs.empty() ) { + s = s.substr( 0, i ) + subs + s.substr( i+1 ); + ++i; + } + } + } + return "\"" + s + "\""; +} +std::string toString( std::wstring const& value ) { + + std::string s; + s.reserve( value.size() ); + for(size_t i = 0; i < value.size(); ++i ) + s += value[i] <= 0xff ? static_cast( value[i] ) : '?'; + return toString( s ); +} + +std::string toString( const char* const value ) { + return value ? Catch::toString( std::string( value ) ) : std::string( "{null string}" ); +} + +std::string toString( char* const value ) { + return Catch::toString( static_cast( value ) ); +} + +std::string toString( int value ) { + std::ostringstream oss; + oss << value; + return oss.str(); +} + +std::string toString( unsigned long value ) { + std::ostringstream oss; + if( value > 8192 ) + oss << "0x" << std::hex << value; + else + oss << value; + return oss.str(); +} + +std::string toString( unsigned int value ) { + return toString( static_cast( value ) ); +} + +std::string toString( const double value ) { + std::ostringstream oss; + oss << std::setprecision( 10 ) + << std::fixed + << value; + std::string d = oss.str(); + std::size_t i = d.find_last_not_of( '0' ); + if( i != std::string::npos && i != d.size()-1 ) { + if( d[i] == '.' ) + i++; + d = d.substr( 0, i+1 ); + } + return d; +} + +std::string toString( bool value ) { + return value ? "true" : "false"; +} + +std::string toString( char value ) { + return value < ' ' + ? toString( static_cast( value ) ) + : Detail::makeString( value ); +} + +std::string toString( signed char value ) { + return toString( static_cast( value ) ); +} + +std::string toString( unsigned char value ) { + return toString( static_cast( value ) ); +} + +#ifdef CATCH_CONFIG_CPP11_NULLPTR +std::string toString( std::nullptr_t ) { + return "nullptr"; +} +#endif + +#ifdef __OBJC__ + std::string toString( NSString const * const& nsstring ) { + if( !nsstring ) + return "nil"; + return std::string( "@\"" ) + [nsstring UTF8String] + "\""; + } + std::string toString( NSString * CATCH_ARC_STRONG const& nsstring ) { + if( !nsstring ) + return "nil"; + return std::string( "@\"" ) + [nsstring UTF8String] + "\""; + } + std::string toString( NSObject* const& nsObject ) { + return toString( [nsObject description] ); + } +#endif + +} // end namespace Catch + +// #included from: catch_result_builder.hpp +#define TWOBLUECUBES_CATCH_RESULT_BUILDER_HPP_INCLUDED + +namespace Catch { + + ResultBuilder::ResultBuilder( char const* macroName, + SourceLineInfo const& lineInfo, + char const* capturedExpression, + ResultDisposition::Flags resultDisposition ) + : m_assertionInfo( macroName, lineInfo, capturedExpression, resultDisposition ), + m_shouldDebugBreak( false ), + m_shouldThrow( false ) + {} + + ResultBuilder& ResultBuilder::setResultType( ResultWas::OfType result ) { + m_data.resultType = result; + return *this; + } + ResultBuilder& ResultBuilder::setResultType( bool result ) { + m_data.resultType = result ? ResultWas::Ok : ResultWas::ExpressionFailed; + return *this; + } + ResultBuilder& ResultBuilder::setLhs( std::string const& lhs ) { + m_exprComponents.lhs = lhs; + return *this; + } + ResultBuilder& ResultBuilder::setRhs( std::string const& rhs ) { + m_exprComponents.rhs = rhs; + return *this; + } + ResultBuilder& ResultBuilder::setOp( std::string const& op ) { + m_exprComponents.op = op; + return *this; + } + + void ResultBuilder::endExpression() { + m_exprComponents.testFalse = isFalseTest( m_assertionInfo.resultDisposition ); + captureExpression(); + } + + void ResultBuilder::useActiveException( ResultDisposition::Flags resultDisposition ) { + m_assertionInfo.resultDisposition = resultDisposition; + m_stream.oss << Catch::translateActiveException(); + captureResult( ResultWas::ThrewException ); + } + + void ResultBuilder::captureResult( ResultWas::OfType resultType ) { + setResultType( resultType ); + captureExpression(); + } + + void ResultBuilder::captureExpression() { + AssertionResult result = build(); + getResultCapture().assertionEnded( result ); + + if( !result.isOk() ) { + if( getCurrentContext().getConfig()->shouldDebugBreak() ) + m_shouldDebugBreak = true; + if( getCurrentContext().getRunner()->aborting() || m_assertionInfo.resultDisposition == ResultDisposition::Normal ) + m_shouldThrow = true; + } + } + void ResultBuilder::react() { + if( m_shouldThrow ) + throw Catch::TestFailureException(); + } + + bool ResultBuilder::shouldDebugBreak() const { return m_shouldDebugBreak; } + bool ResultBuilder::allowThrows() const { return getCurrentContext().getConfig()->allowThrows(); } + + AssertionResult ResultBuilder::build() const + { + assert( m_data.resultType != ResultWas::Unknown ); + + AssertionResultData data = m_data; + + // Flip bool results if testFalse is set + if( m_exprComponents.testFalse ) { + if( data.resultType == ResultWas::Ok ) + data.resultType = ResultWas::ExpressionFailed; + else if( data.resultType == ResultWas::ExpressionFailed ) + data.resultType = ResultWas::Ok; + } + + data.message = m_stream.oss.str(); + data.reconstructedExpression = reconstructExpression(); + if( m_exprComponents.testFalse ) { + if( m_exprComponents.op == "" ) + data.reconstructedExpression = "!" + data.reconstructedExpression; + else + data.reconstructedExpression = "!(" + data.reconstructedExpression + ")"; + } + return AssertionResult( m_assertionInfo, data ); + } + std::string ResultBuilder::reconstructExpression() const { + if( m_exprComponents.op == "" ) + return m_exprComponents.lhs.empty() ? m_assertionInfo.capturedExpression : m_exprComponents.op + m_exprComponents.lhs; + else if( m_exprComponents.op == "matches" ) + return m_exprComponents.lhs + " " + m_exprComponents.rhs; + else if( m_exprComponents.op != "!" ) { + if( m_exprComponents.lhs.size() + m_exprComponents.rhs.size() < 40 && + m_exprComponents.lhs.find("\n") == std::string::npos && + m_exprComponents.rhs.find("\n") == std::string::npos ) + return m_exprComponents.lhs + " " + m_exprComponents.op + " " + m_exprComponents.rhs; + else + return m_exprComponents.lhs + "\n" + m_exprComponents.op + "\n" + m_exprComponents.rhs; + } + else + return "{can't expand - use " + m_assertionInfo.macroName + "_FALSE( " + m_assertionInfo.capturedExpression.substr(1) + " ) instead of " + m_assertionInfo.macroName + "( " + m_assertionInfo.capturedExpression + " ) for better diagnostics}"; + } + +} // end namespace Catch + +// #included from: ../reporters/catch_reporter_xml.hpp +#define TWOBLUECUBES_CATCH_REPORTER_XML_HPP_INCLUDED + +// #included from: catch_reporter_bases.hpp +#define TWOBLUECUBES_CATCH_REPORTER_BASES_HPP_INCLUDED + +namespace Catch { + + struct StreamingReporterBase : SharedImpl { + + StreamingReporterBase( ReporterConfig const& _config ) + : m_config( _config.fullConfig() ), + stream( _config.stream() ) + {} + + virtual ~StreamingReporterBase(); + + virtual void noMatchingTestCases( std::string const& ) {} + + virtual void testRunStarting( TestRunInfo const& _testRunInfo ) { + currentTestRunInfo = _testRunInfo; + } + virtual void testGroupStarting( GroupInfo const& _groupInfo ) { + currentGroupInfo = _groupInfo; + } + + virtual void testCaseStarting( TestCaseInfo const& _testInfo ) { + currentTestCaseInfo = _testInfo; + } + virtual void sectionStarting( SectionInfo const& _sectionInfo ) { + m_sectionStack.push_back( _sectionInfo ); + } + + virtual void sectionEnded( SectionStats const& /* _sectionStats */ ) { + m_sectionStack.pop_back(); + } + virtual void testCaseEnded( TestCaseStats const& /* _testCaseStats */ ) { + currentTestCaseInfo.reset(); + assert( m_sectionStack.empty() ); + } + virtual void testGroupEnded( TestGroupStats const& /* _testGroupStats */ ) { + currentGroupInfo.reset(); + } + virtual void testRunEnded( TestRunStats const& /* _testRunStats */ ) { + currentTestCaseInfo.reset(); + currentGroupInfo.reset(); + currentTestRunInfo.reset(); + } + + Ptr m_config; + std::ostream& stream; + + LazyStat currentTestRunInfo; + LazyStat currentGroupInfo; + LazyStat currentTestCaseInfo; + + std::vector m_sectionStack; + }; + + struct CumulativeReporterBase : SharedImpl { + template + struct Node : SharedImpl<> { + explicit Node( T const& _value ) : value( _value ) {} + virtual ~Node() {} + + typedef std::vector > ChildNodes; + T value; + ChildNodes children; + }; + struct SectionNode : SharedImpl<> { + explicit SectionNode( SectionStats const& _stats ) : stats( _stats ) {} + virtual ~SectionNode(); + + bool operator == ( SectionNode const& other ) const { + return stats.sectionInfo.lineInfo == other.stats.sectionInfo.lineInfo; + } + bool operator == ( Ptr const& other ) const { + return operator==( *other ); + } + + SectionStats stats; + typedef std::vector > ChildSections; + typedef std::vector Assertions; + ChildSections childSections; + Assertions assertions; + std::string stdOut; + std::string stdErr; + }; + + struct BySectionInfo { + BySectionInfo( SectionInfo const& other ) : m_other( other ) {} + BySectionInfo( BySectionInfo const& other ) : m_other( other.m_other ) {} + bool operator() ( Ptr const& node ) const { + return node->stats.sectionInfo.lineInfo == m_other.lineInfo; + } + private: + void operator=( BySectionInfo const& ); + SectionInfo const& m_other; + }; + + typedef Node TestCaseNode; + typedef Node TestGroupNode; + typedef Node TestRunNode; + + CumulativeReporterBase( ReporterConfig const& _config ) + : m_config( _config.fullConfig() ), + stream( _config.stream() ) + {} + ~CumulativeReporterBase(); + + virtual void testRunStarting( TestRunInfo const& ) {} + virtual void testGroupStarting( GroupInfo const& ) {} + + virtual void testCaseStarting( TestCaseInfo const& ) {} + + virtual void sectionStarting( SectionInfo const& sectionInfo ) { + SectionStats incompleteStats( sectionInfo, Counts(), 0, false ); + Ptr node; + if( m_sectionStack.empty() ) { + if( !m_rootSection ) + m_rootSection = new SectionNode( incompleteStats ); + node = m_rootSection; + } + else { + SectionNode& parentNode = *m_sectionStack.back(); + SectionNode::ChildSections::const_iterator it = + std::find_if( parentNode.childSections.begin(), + parentNode.childSections.end(), + BySectionInfo( sectionInfo ) ); + if( it == parentNode.childSections.end() ) { + node = new SectionNode( incompleteStats ); + parentNode.childSections.push_back( node ); + } + else + node = *it; + } + m_sectionStack.push_back( node ); + m_deepestSection = node; + } + + virtual void assertionStarting( AssertionInfo const& ) {} + + virtual bool assertionEnded( AssertionStats const& assertionStats ) { + assert( !m_sectionStack.empty() ); + SectionNode& sectionNode = *m_sectionStack.back(); + sectionNode.assertions.push_back( assertionStats ); + return true; + } + virtual void sectionEnded( SectionStats const& sectionStats ) { + assert( !m_sectionStack.empty() ); + SectionNode& node = *m_sectionStack.back(); + node.stats = sectionStats; + m_sectionStack.pop_back(); + } + virtual void testCaseEnded( TestCaseStats const& testCaseStats ) { + Ptr node = new TestCaseNode( testCaseStats ); + assert( m_sectionStack.size() == 0 ); + node->children.push_back( m_rootSection ); + m_testCases.push_back( node ); + m_rootSection.reset(); + + assert( m_deepestSection ); + m_deepestSection->stdOut = testCaseStats.stdOut; + m_deepestSection->stdErr = testCaseStats.stdErr; + } + virtual void testGroupEnded( TestGroupStats const& testGroupStats ) { + Ptr node = new TestGroupNode( testGroupStats ); + node->children.swap( m_testCases ); + m_testGroups.push_back( node ); + } + virtual void testRunEnded( TestRunStats const& testRunStats ) { + Ptr node = new TestRunNode( testRunStats ); + node->children.swap( m_testGroups ); + m_testRuns.push_back( node ); + testRunEndedCumulative(); + } + virtual void testRunEndedCumulative() = 0; + + Ptr m_config; + std::ostream& stream; + std::vector m_assertions; + std::vector > > m_sections; + std::vector > m_testCases; + std::vector > m_testGroups; + + std::vector > m_testRuns; + + Ptr m_rootSection; + Ptr m_deepestSection; + std::vector > m_sectionStack; + + }; + +} // end namespace Catch + +// #included from: ../internal/catch_reporter_registrars.hpp +#define TWOBLUECUBES_CATCH_REPORTER_REGISTRARS_HPP_INCLUDED + +namespace Catch { + + template + class LegacyReporterRegistrar { + + class ReporterFactory : public IReporterFactory { + virtual IStreamingReporter* create( ReporterConfig const& config ) const { + return new LegacyReporterAdapter( new T( config ) ); + } + + virtual std::string getDescription() const { + return T::getDescription(); + } + }; + + public: + + LegacyReporterRegistrar( std::string const& name ) { + getMutableRegistryHub().registerReporter( name, new ReporterFactory() ); + } + }; + + template + class ReporterRegistrar { + + class ReporterFactory : public IReporterFactory { + + // *** Please Note ***: + // - If you end up here looking at a compiler error because it's trying to register + // your custom reporter class be aware that the native reporter interface has changed + // to IStreamingReporter. The "legacy" interface, IReporter, is still supported via + // an adapter. Just use REGISTER_LEGACY_REPORTER to take advantage of the adapter. + // However please consider updating to the new interface as the old one is now + // deprecated and will probably be removed quite soon! + // Please contact me via github if you have any questions at all about this. + // In fact, ideally, please contact me anyway to let me know you've hit this - as I have + // no idea who is actually using custom reporters at all (possibly no-one!). + // The new interface is designed to minimise exposure to interface changes in the future. + virtual IStreamingReporter* create( ReporterConfig const& config ) const { + return new T( config ); + } + + virtual std::string getDescription() const { + return T::getDescription(); + } + }; + + public: + + ReporterRegistrar( std::string const& name ) { + getMutableRegistryHub().registerReporter( name, new ReporterFactory() ); + } + }; +} + +#define INTERNAL_CATCH_REGISTER_LEGACY_REPORTER( name, reporterType ) \ + namespace{ Catch::LegacyReporterRegistrar catch_internal_RegistrarFor##reporterType( name ); } +#define INTERNAL_CATCH_REGISTER_REPORTER( name, reporterType ) \ + namespace{ Catch::ReporterRegistrar catch_internal_RegistrarFor##reporterType( name ); } + +// #included from: ../internal/catch_xmlwriter.hpp +#define TWOBLUECUBES_CATCH_XMLWRITER_HPP_INCLUDED + +#include +#include +#include +#include + +namespace Catch { + + class XmlWriter { + public: + + class ScopedElement { + public: + ScopedElement( XmlWriter* writer ) + : m_writer( writer ) + {} + + ScopedElement( ScopedElement const& other ) + : m_writer( other.m_writer ){ + other.m_writer = NULL; + } + + ~ScopedElement() { + if( m_writer ) + m_writer->endElement(); + } + + ScopedElement& writeText( std::string const& text, bool indent = true ) { + m_writer->writeText( text, indent ); + return *this; + } + + template + ScopedElement& writeAttribute( std::string const& name, T const& attribute ) { + m_writer->writeAttribute( name, attribute ); + return *this; + } + + private: + mutable XmlWriter* m_writer; + }; + + XmlWriter() + : m_tagIsOpen( false ), + m_needsNewline( false ), + m_os( &std::cout ) + {} + + XmlWriter( std::ostream& os ) + : m_tagIsOpen( false ), + m_needsNewline( false ), + m_os( &os ) + {} + + ~XmlWriter() { + while( !m_tags.empty() ) + endElement(); + } + +//# ifndef CATCH_CPP11_OR_GREATER +// XmlWriter& operator = ( XmlWriter const& other ) { +// XmlWriter temp( other ); +// swap( temp ); +// return *this; +// } +//# else +// XmlWriter( XmlWriter const& ) = default; +// XmlWriter( XmlWriter && ) = default; +// XmlWriter& operator = ( XmlWriter const& ) = default; +// XmlWriter& operator = ( XmlWriter && ) = default; +//# endif +// +// void swap( XmlWriter& other ) { +// std::swap( m_tagIsOpen, other.m_tagIsOpen ); +// std::swap( m_needsNewline, other.m_needsNewline ); +// std::swap( m_tags, other.m_tags ); +// std::swap( m_indent, other.m_indent ); +// std::swap( m_os, other.m_os ); +// } + + XmlWriter& startElement( std::string const& name ) { + ensureTagClosed(); + newlineIfNecessary(); + stream() << m_indent << "<" << name; + m_tags.push_back( name ); + m_indent += " "; + m_tagIsOpen = true; + return *this; + } + + ScopedElement scopedElement( std::string const& name ) { + ScopedElement scoped( this ); + startElement( name ); + return scoped; + } + + XmlWriter& endElement() { + newlineIfNecessary(); + m_indent = m_indent.substr( 0, m_indent.size()-2 ); + if( m_tagIsOpen ) { + stream() << "/>\n"; + m_tagIsOpen = false; + } + else { + stream() << m_indent << "\n"; + } + m_tags.pop_back(); + return *this; + } + + XmlWriter& writeAttribute( std::string const& name, std::string const& attribute ) { + if( !name.empty() && !attribute.empty() ) { + stream() << " " << name << "=\""; + writeEncodedText( attribute ); + stream() << "\""; + } + return *this; + } + + XmlWriter& writeAttribute( std::string const& name, bool attribute ) { + stream() << " " << name << "=\"" << ( attribute ? "true" : "false" ) << "\""; + return *this; + } + + template + XmlWriter& writeAttribute( std::string const& name, T const& attribute ) { + if( !name.empty() ) + stream() << " " << name << "=\"" << attribute << "\""; + return *this; + } + + XmlWriter& writeText( std::string const& text, bool indent = true ) { + if( !text.empty() ){ + bool tagWasOpen = m_tagIsOpen; + ensureTagClosed(); + if( tagWasOpen && indent ) + stream() << m_indent; + writeEncodedText( text ); + m_needsNewline = true; + } + return *this; + } + + XmlWriter& writeComment( std::string const& text ) { + ensureTagClosed(); + stream() << m_indent << ""; + m_needsNewline = true; + return *this; + } + + XmlWriter& writeBlankLine() { + ensureTagClosed(); + stream() << "\n"; + return *this; + } + + void setStream( std::ostream& os ) { + m_os = &os; + } + + private: + XmlWriter( XmlWriter const& ); + void operator=( XmlWriter const& ); + + std::ostream& stream() { + return *m_os; + } + + void ensureTagClosed() { + if( m_tagIsOpen ) { + stream() << ">\n"; + m_tagIsOpen = false; + } + } + + void newlineIfNecessary() { + if( m_needsNewline ) { + stream() << "\n"; + m_needsNewline = false; + } + } + + void writeEncodedText( std::string const& text ) { + static const char* charsToEncode = "<&\""; + std::string mtext = text; + std::string::size_type pos = mtext.find_first_of( charsToEncode ); + while( pos != std::string::npos ) { + stream() << mtext.substr( 0, pos ); + + switch( mtext[pos] ) { + case '<': + stream() << "<"; + break; + case '&': + stream() << "&"; + break; + case '\"': + stream() << """; + break; + } + mtext = mtext.substr( pos+1 ); + pos = mtext.find_first_of( charsToEncode ); + } + stream() << mtext; + } + + bool m_tagIsOpen; + bool m_needsNewline; + std::vector m_tags; + std::string m_indent; + std::ostream* m_os; + }; + +} +namespace Catch { + class XmlReporter : public SharedImpl { + public: + XmlReporter( ReporterConfig const& config ) : m_config( config ), m_sectionDepth( 0 ) {} + + static std::string getDescription() { + return "Reports test results as an XML document"; + } + virtual ~XmlReporter(); + + private: // IReporter + + virtual bool shouldRedirectStdout() const { + return true; + } + + virtual void StartTesting() { + m_xml.setStream( m_config.stream() ); + m_xml.startElement( "Catch" ); + if( !m_config.fullConfig()->name().empty() ) + m_xml.writeAttribute( "name", m_config.fullConfig()->name() ); + } + + virtual void EndTesting( const Totals& totals ) { + m_xml.scopedElement( "OverallResults" ) + .writeAttribute( "successes", totals.assertions.passed ) + .writeAttribute( "failures", totals.assertions.failed ); + m_xml.endElement(); + } + + virtual void StartGroup( const std::string& groupName ) { + m_xml.startElement( "Group" ) + .writeAttribute( "name", groupName ); + } + + virtual void EndGroup( const std::string&, const Totals& totals ) { + m_xml.scopedElement( "OverallResults" ) + .writeAttribute( "successes", totals.assertions.passed ) + .writeAttribute( "failures", totals.assertions.failed ); + m_xml.endElement(); + } + + virtual void StartSection( const std::string& sectionName, const std::string& description ) { + if( m_sectionDepth++ > 0 ) { + m_xml.startElement( "Section" ) + .writeAttribute( "name", trim( sectionName ) ) + .writeAttribute( "description", description ); + } + } + virtual void NoAssertionsInSection( const std::string& ) {} + virtual void NoAssertionsInTestCase( const std::string& ) {} + + virtual void EndSection( const std::string& /*sectionName*/, const Counts& assertions ) { + if( --m_sectionDepth > 0 ) { + m_xml.scopedElement( "OverallResults" ) + .writeAttribute( "successes", assertions.passed ) + .writeAttribute( "failures", assertions.failed ); + m_xml.endElement(); + } + } + + virtual void StartTestCase( const Catch::TestCaseInfo& testInfo ) { + m_xml.startElement( "TestCase" ).writeAttribute( "name", trim( testInfo.name ) ); + m_currentTestSuccess = true; + } + + virtual void Result( const Catch::AssertionResult& assertionResult ) { + if( !m_config.fullConfig()->includeSuccessfulResults() && assertionResult.getResultType() == ResultWas::Ok ) + return; + + if( assertionResult.hasExpression() ) { + m_xml.startElement( "Expression" ) + .writeAttribute( "success", assertionResult.succeeded() ) + .writeAttribute( "filename", assertionResult.getSourceInfo().file ) + .writeAttribute( "line", assertionResult.getSourceInfo().line ); + + m_xml.scopedElement( "Original" ) + .writeText( assertionResult.getExpression() ); + m_xml.scopedElement( "Expanded" ) + .writeText( assertionResult.getExpandedExpression() ); + m_currentTestSuccess &= assertionResult.succeeded(); + } + + switch( assertionResult.getResultType() ) { + case ResultWas::ThrewException: + m_xml.scopedElement( "Exception" ) + .writeAttribute( "filename", assertionResult.getSourceInfo().file ) + .writeAttribute( "line", assertionResult.getSourceInfo().line ) + .writeText( assertionResult.getMessage() ); + m_currentTestSuccess = false; + break; + case ResultWas::Info: + m_xml.scopedElement( "Info" ) + .writeText( assertionResult.getMessage() ); + break; + case ResultWas::Warning: + m_xml.scopedElement( "Warning" ) + .writeText( assertionResult.getMessage() ); + break; + case ResultWas::ExplicitFailure: + m_xml.scopedElement( "Failure" ) + .writeText( assertionResult.getMessage() ); + m_currentTestSuccess = false; + break; + case ResultWas::Unknown: + case ResultWas::Ok: + case ResultWas::FailureBit: + case ResultWas::ExpressionFailed: + case ResultWas::Exception: + case ResultWas::DidntThrowException: + break; + } + if( assertionResult.hasExpression() ) + m_xml.endElement(); + } + + virtual void Aborted() { + // !TBD + } + + virtual void EndTestCase( const Catch::TestCaseInfo&, const Totals&, const std::string&, const std::string& ) { + m_xml.scopedElement( "OverallResult" ).writeAttribute( "success", m_currentTestSuccess ); + m_xml.endElement(); + } + + private: + ReporterConfig m_config; + bool m_currentTestSuccess; + XmlWriter m_xml; + int m_sectionDepth; + }; + +} // end namespace Catch + +// #included from: ../reporters/catch_reporter_junit.hpp +#define TWOBLUECUBES_CATCH_REPORTER_JUNIT_HPP_INCLUDED + +#include + +namespace Catch { + + class JunitReporter : public CumulativeReporterBase { + public: + JunitReporter( ReporterConfig const& _config ) + : CumulativeReporterBase( _config ), + xml( _config.stream() ) + {} + + ~JunitReporter(); + + static std::string getDescription() { + return "Reports test results in an XML format that looks like Ant's junitreport target"; + } + + virtual void noMatchingTestCases( std::string const& /*spec*/ ) {} + + virtual ReporterPreferences getPreferences() const { + ReporterPreferences prefs; + prefs.shouldRedirectStdOut = true; + return prefs; + } + + virtual void testRunStarting( TestRunInfo const& runInfo ) { + CumulativeReporterBase::testRunStarting( runInfo ); + xml.startElement( "testsuites" ); + } + + virtual void testGroupStarting( GroupInfo const& groupInfo ) { + suiteTimer.start(); + stdOutForSuite.str(""); + stdErrForSuite.str(""); + unexpectedExceptions = 0; + CumulativeReporterBase::testGroupStarting( groupInfo ); + } + + virtual bool assertionEnded( AssertionStats const& assertionStats ) { + if( assertionStats.assertionResult.getResultType() == ResultWas::ThrewException ) + unexpectedExceptions++; + return CumulativeReporterBase::assertionEnded( assertionStats ); + } + + virtual void testCaseEnded( TestCaseStats const& testCaseStats ) { + stdOutForSuite << testCaseStats.stdOut; + stdErrForSuite << testCaseStats.stdErr; + CumulativeReporterBase::testCaseEnded( testCaseStats ); + } + + virtual void testGroupEnded( TestGroupStats const& testGroupStats ) { + double suiteTime = suiteTimer.getElapsedSeconds(); + CumulativeReporterBase::testGroupEnded( testGroupStats ); + writeGroup( *m_testGroups.back(), suiteTime ); + } + + virtual void testRunEndedCumulative() { + xml.endElement(); + } + + void writeGroup( TestGroupNode const& groupNode, double suiteTime ) { + XmlWriter::ScopedElement e = xml.scopedElement( "testsuite" ); + TestGroupStats const& stats = groupNode.value; + xml.writeAttribute( "name", stats.groupInfo.name ); + xml.writeAttribute( "errors", unexpectedExceptions ); + xml.writeAttribute( "failures", stats.totals.assertions.failed-unexpectedExceptions ); + xml.writeAttribute( "tests", stats.totals.assertions.total() ); + xml.writeAttribute( "hostname", "tbd" ); // !TBD + if( m_config->showDurations() == ShowDurations::Never ) + xml.writeAttribute( "time", "" ); + else + xml.writeAttribute( "time", suiteTime ); + xml.writeAttribute( "timestamp", "tbd" ); // !TBD + + // Write test cases + for( TestGroupNode::ChildNodes::const_iterator + it = groupNode.children.begin(), itEnd = groupNode.children.end(); + it != itEnd; + ++it ) + writeTestCase( **it ); + + xml.scopedElement( "system-out" ).writeText( trim( stdOutForSuite.str() ), false ); + xml.scopedElement( "system-err" ).writeText( trim( stdErrForSuite.str() ), false ); + } + + void writeTestCase( TestCaseNode const& testCaseNode ) { + TestCaseStats const& stats = testCaseNode.value; + + // All test cases have exactly one section - which represents the + // test case itself. That section may have 0-n nested sections + assert( testCaseNode.children.size() == 1 ); + SectionNode const& rootSection = *testCaseNode.children.front(); + + std::string className = stats.testInfo.className; + + if( className.empty() ) { + if( rootSection.childSections.empty() ) + className = "global"; + } + writeSection( className, "", rootSection ); + } + + void writeSection( std::string const& className, + std::string const& rootName, + SectionNode const& sectionNode ) { + std::string name = trim( sectionNode.stats.sectionInfo.name ); + if( !rootName.empty() ) + name = rootName + "/" + name; + + if( !sectionNode.assertions.empty() || + !sectionNode.stdOut.empty() || + !sectionNode.stdErr.empty() ) { + XmlWriter::ScopedElement e = xml.scopedElement( "testcase" ); + if( className.empty() ) { + xml.writeAttribute( "classname", name ); + xml.writeAttribute( "name", "root" ); + } + else { + xml.writeAttribute( "classname", className ); + xml.writeAttribute( "name", name ); + } + xml.writeAttribute( "time", toString( sectionNode.stats.durationInSeconds ) ); + + writeAssertions( sectionNode ); + + if( !sectionNode.stdOut.empty() ) + xml.scopedElement( "system-out" ).writeText( trim( sectionNode.stdOut ), false ); + if( !sectionNode.stdErr.empty() ) + xml.scopedElement( "system-err" ).writeText( trim( sectionNode.stdErr ), false ); + } + for( SectionNode::ChildSections::const_iterator + it = sectionNode.childSections.begin(), + itEnd = sectionNode.childSections.end(); + it != itEnd; + ++it ) + if( className.empty() ) + writeSection( name, "", **it ); + else + writeSection( className, name, **it ); + } + + void writeAssertions( SectionNode const& sectionNode ) { + for( SectionNode::Assertions::const_iterator + it = sectionNode.assertions.begin(), itEnd = sectionNode.assertions.end(); + it != itEnd; + ++it ) + writeAssertion( *it ); + } + void writeAssertion( AssertionStats const& stats ) { + AssertionResult const& result = stats.assertionResult; + if( !result.isOk() ) { + std::string elementName; + switch( result.getResultType() ) { + case ResultWas::ThrewException: + elementName = "error"; + break; + case ResultWas::ExplicitFailure: + elementName = "failure"; + break; + case ResultWas::ExpressionFailed: + elementName = "failure"; + break; + case ResultWas::DidntThrowException: + elementName = "failure"; + break; + + // We should never see these here: + case ResultWas::Info: + case ResultWas::Warning: + case ResultWas::Ok: + case ResultWas::Unknown: + case ResultWas::FailureBit: + case ResultWas::Exception: + elementName = "internalError"; + break; + } + + XmlWriter::ScopedElement e = xml.scopedElement( elementName ); + + xml.writeAttribute( "message", result.getExpandedExpression() ); + xml.writeAttribute( "type", result.getTestMacroName() ); + + std::ostringstream oss; + if( !result.getMessage().empty() ) + oss << result.getMessage() << "\n"; + for( std::vector::const_iterator + it = stats.infoMessages.begin(), + itEnd = stats.infoMessages.end(); + it != itEnd; + ++it ) + if( it->type == ResultWas::Info ) + oss << it->message << "\n"; + + oss << "at " << result.getSourceInfo(); + xml.writeText( oss.str(), false ); + } + } + + XmlWriter xml; + Timer suiteTimer; + std::ostringstream stdOutForSuite; + std::ostringstream stdErrForSuite; + unsigned int unexpectedExceptions; + }; + + INTERNAL_CATCH_REGISTER_REPORTER( "junit", JunitReporter ) + +} // end namespace Catch + +// #included from: ../reporters/catch_reporter_console.hpp +#define TWOBLUECUBES_CATCH_REPORTER_CONSOLE_HPP_INCLUDED + +#include + +namespace Catch { + + struct ConsoleReporter : StreamingReporterBase { + ConsoleReporter( ReporterConfig const& _config ) + : StreamingReporterBase( _config ), + m_headerPrinted( false ), + m_atLeastOneTestCasePrinted( false ) + {} + + virtual ~ConsoleReporter(); + static std::string getDescription() { + return "Reports test results as plain lines of text"; + } + virtual ReporterPreferences getPreferences() const { + ReporterPreferences prefs; + prefs.shouldRedirectStdOut = false; + return prefs; + } + + virtual void noMatchingTestCases( std::string const& spec ) { + stream << "No test cases matched '" << spec << "'" << std::endl; + } + + virtual void assertionStarting( AssertionInfo const& ) { + } + + virtual bool assertionEnded( AssertionStats const& _assertionStats ) { + AssertionResult const& result = _assertionStats.assertionResult; + + bool printInfoMessages = true; + + // Drop out if result was successful and we're not printing those + if( !m_config->includeSuccessfulResults() && result.isOk() ) { + if( result.getResultType() != ResultWas::Warning ) + return false; + printInfoMessages = false; + } + + lazyPrint(); + + AssertionPrinter printer( stream, _assertionStats, printInfoMessages ); + printer.print(); + stream << std::endl; + return true; + } + + virtual void sectionStarting( SectionInfo const& _sectionInfo ) { + m_headerPrinted = false; + StreamingReporterBase::sectionStarting( _sectionInfo ); + } + virtual void sectionEnded( SectionStats const& _sectionStats ) { + if( _sectionStats.missingAssertions ) { + lazyPrint(); + Colour colour( Colour::ResultError ); + if( m_sectionStack.size() > 1 ) + stream << "\nNo assertions in section"; + else + stream << "\nNo assertions in test case"; + stream << " '" << _sectionStats.sectionInfo.name << "'\n" << std::endl; + } + if( m_headerPrinted ) { + if( m_config->showDurations() == ShowDurations::Always ) + stream << "Completed in " << _sectionStats.durationInSeconds << "s" << std::endl; + m_headerPrinted = false; + } + else { + if( m_config->showDurations() == ShowDurations::Always ) + stream << _sectionStats.sectionInfo.name << " completed in " << _sectionStats.durationInSeconds << "s" << std::endl; + } + StreamingReporterBase::sectionEnded( _sectionStats ); + } + + virtual void testCaseEnded( TestCaseStats const& _testCaseStats ) { + StreamingReporterBase::testCaseEnded( _testCaseStats ); + m_headerPrinted = false; + } + virtual void testGroupEnded( TestGroupStats const& _testGroupStats ) { + if( currentGroupInfo.used ) { + printSummaryDivider(); + stream << "Summary for group '" << _testGroupStats.groupInfo.name << "':\n"; + printTotals( _testGroupStats.totals ); + stream << "\n" << std::endl; + } + StreamingReporterBase::testGroupEnded( _testGroupStats ); + } + virtual void testRunEnded( TestRunStats const& _testRunStats ) { + if( m_atLeastOneTestCasePrinted ) + printTotalsDivider(); + printTotals( _testRunStats.totals ); + stream << "\n" << std::endl; + StreamingReporterBase::testRunEnded( _testRunStats ); + } + + private: + + class AssertionPrinter { + void operator= ( AssertionPrinter const& ); + public: + AssertionPrinter( std::ostream& _stream, AssertionStats const& _stats, bool _printInfoMessages ) + : stream( _stream ), + stats( _stats ), + result( _stats.assertionResult ), + colour( Colour::None ), + message( result.getMessage() ), + messages( _stats.infoMessages ), + printInfoMessages( _printInfoMessages ) + { + switch( result.getResultType() ) { + case ResultWas::Ok: + colour = Colour::Success; + passOrFail = "PASSED"; + //if( result.hasMessage() ) + if( _stats.infoMessages.size() == 1 ) + messageLabel = "with message"; + if( _stats.infoMessages.size() > 1 ) + messageLabel = "with messages"; + break; + case ResultWas::ExpressionFailed: + if( result.isOk() ) { + colour = Colour::Success; + passOrFail = "FAILED - but was ok"; + } + else { + colour = Colour::Error; + passOrFail = "FAILED"; + } + if( _stats.infoMessages.size() == 1 ) + messageLabel = "with message"; + if( _stats.infoMessages.size() > 1 ) + messageLabel = "with messages"; + break; + case ResultWas::ThrewException: + colour = Colour::Error; + passOrFail = "FAILED"; + messageLabel = "due to unexpected exception with message"; + break; + case ResultWas::DidntThrowException: + colour = Colour::Error; + passOrFail = "FAILED"; + messageLabel = "because no exception was thrown where one was expected"; + break; + case ResultWas::Info: + messageLabel = "info"; + break; + case ResultWas::Warning: + messageLabel = "warning"; + break; + case ResultWas::ExplicitFailure: + passOrFail = "FAILED"; + colour = Colour::Error; + if( _stats.infoMessages.size() == 1 ) + messageLabel = "explicitly with message"; + if( _stats.infoMessages.size() > 1 ) + messageLabel = "explicitly with messages"; + break; + // These cases are here to prevent compiler warnings + case ResultWas::Unknown: + case ResultWas::FailureBit: + case ResultWas::Exception: + passOrFail = "** internal error **"; + colour = Colour::Error; + break; + } + } + + void print() const { + printSourceInfo(); + if( stats.totals.assertions.total() > 0 ) { + if( result.isOk() ) + stream << "\n"; + printResultType(); + printOriginalExpression(); + printReconstructedExpression(); + } + else { + stream << "\n"; + } + printMessage(); + } + + private: + void printResultType() const { + if( !passOrFail.empty() ) { + Colour colourGuard( colour ); + stream << passOrFail << ":\n"; + } + } + void printOriginalExpression() const { + if( result.hasExpression() ) { + Colour colourGuard( Colour::OriginalExpression ); + stream << " "; + stream << result.getExpressionInMacro(); + stream << "\n"; + } + } + void printReconstructedExpression() const { + if( result.hasExpandedExpression() ) { + stream << "with expansion:\n"; + Colour colourGuard( Colour::ReconstructedExpression ); + stream << Text( result.getExpandedExpression(), TextAttributes().setIndent(2) ) << "\n"; + } + } + void printMessage() const { + if( !messageLabel.empty() ) + stream << messageLabel << ":" << "\n"; + for( std::vector::const_iterator it = messages.begin(), itEnd = messages.end(); + it != itEnd; + ++it ) { + // If this assertion is a warning ignore any INFO messages + if( printInfoMessages || it->type != ResultWas::Info ) + stream << Text( it->message, TextAttributes().setIndent(2) ) << "\n"; + } + } + void printSourceInfo() const { + Colour colourGuard( Colour::FileName ); + stream << result.getSourceInfo() << ": "; + } + + std::ostream& stream; + AssertionStats const& stats; + AssertionResult const& result; + Colour::Code colour; + std::string passOrFail; + std::string messageLabel; + std::string message; + std::vector messages; + bool printInfoMessages; + }; + + void lazyPrint() { + + if( !currentTestRunInfo.used ) + lazyPrintRunInfo(); + if( !currentGroupInfo.used ) + lazyPrintGroupInfo(); + + if( !m_headerPrinted ) { + printTestCaseAndSectionHeader(); + m_headerPrinted = true; + } + m_atLeastOneTestCasePrinted = true; + } + void lazyPrintRunInfo() { + stream << "\n" << getLineOfChars<'~'>() << "\n"; + Colour colour( Colour::SecondaryText ); + stream << currentTestRunInfo->name + << " is a Catch v" << libraryVersion.majorVersion << "." + << libraryVersion.minorVersion << " b" + << libraryVersion.buildNumber; + if( libraryVersion.branchName != std::string( "master" ) ) + stream << " (" << libraryVersion.branchName << ")"; + stream << " host application.\n" + << "Run with -? for options\n\n"; + + currentTestRunInfo.used = true; + } + void lazyPrintGroupInfo() { + if( !currentGroupInfo->name.empty() && currentGroupInfo->groupsCounts > 1 ) { + printClosedHeader( "Group: " + currentGroupInfo->name ); + currentGroupInfo.used = true; + } + } + void printTestCaseAndSectionHeader() { + assert( !m_sectionStack.empty() ); + printOpenHeader( currentTestCaseInfo->name ); + + if( m_sectionStack.size() > 1 ) { + Colour colourGuard( Colour::Headers ); + + std::vector::const_iterator + it = m_sectionStack.begin()+1, // Skip first section (test case) + itEnd = m_sectionStack.end(); + for( ; it != itEnd; ++it ) + printHeaderString( it->name, 2 ); + } + + SourceLineInfo lineInfo = m_sectionStack.front().lineInfo; + + if( !lineInfo.empty() ){ + stream << getLineOfChars<'-'>() << "\n"; + Colour colourGuard( Colour::FileName ); + stream << lineInfo << "\n"; + } + stream << getLineOfChars<'.'>() << "\n" << std::endl; + } + + void printClosedHeader( std::string const& _name ) { + printOpenHeader( _name ); + stream << getLineOfChars<'.'>() << "\n"; + } + void printOpenHeader( std::string const& _name ) { + stream << getLineOfChars<'-'>() << "\n"; + { + Colour colourGuard( Colour::Headers ); + printHeaderString( _name ); + } + } + + // if string has a : in first line will set indent to follow it on + // subsequent lines + void printHeaderString( std::string const& _string, std::size_t indent = 0 ) { + std::size_t i = _string.find( ": " ); + if( i != std::string::npos ) + i+=2; + else + i = 0; + stream << Text( _string, TextAttributes() + .setIndent( indent+i) + .setInitialIndent( indent ) ) << "\n"; + } + + void printTotals( const Totals& totals ) { + if( totals.testCases.total() == 0 ) { + stream << "No tests ran"; + } + else if( totals.assertions.total() == 0 ) { + Colour colour( Colour::Yellow ); + printCounts( "test case", totals.testCases ); + stream << " (no assertions)"; + } + else if( totals.assertions.failed ) { + Colour colour( Colour::ResultError ); + printCounts( "test case", totals.testCases ); + if( totals.testCases.failed > 0 ) { + stream << " ("; + printCounts( "assertion", totals.assertions ); + stream << ")"; + } + } + else { + Colour colour( Colour::ResultSuccess ); + stream << "All tests passed (" + << pluralise( totals.assertions.passed, "assertion" ) << " in " + << pluralise( totals.testCases.passed, "test case" ) << ")"; + } + } + void printCounts( std::string const& label, Counts const& counts ) { + if( counts.total() == 1 ) { + stream << "1 " << label << " - "; + if( counts.failed ) + stream << "failed"; + else + stream << "passed"; + } + else { + stream << counts.total() << " " << label << "s "; + if( counts.passed ) { + if( counts.failed ) + stream << "- " << counts.failed << " failed"; + else if( counts.passed == 2 ) + stream << "- both passed"; + else + stream << "- all passed"; + } + else { + if( counts.failed == 2 ) + stream << "- both failed"; + else + stream << "- all failed"; + } + } + } + + void printTotalsDivider() { + stream << getLineOfChars<'='>() << "\n"; + } + void printSummaryDivider() { + stream << getLineOfChars<'-'>() << "\n"; + } + template + static char const* getLineOfChars() { + static char line[CATCH_CONFIG_CONSOLE_WIDTH] = {0}; + if( !*line ) { + memset( line, C, CATCH_CONFIG_CONSOLE_WIDTH-1 ); + line[CATCH_CONFIG_CONSOLE_WIDTH-1] = 0; + } + return line; + } + + private: + bool m_headerPrinted; + bool m_atLeastOneTestCasePrinted; + }; + + INTERNAL_CATCH_REGISTER_REPORTER( "console", ConsoleReporter ) + +} // end namespace Catch + +// #included from: ../reporters/catch_reporter_compact.hpp +#define TWOBLUECUBES_CATCH_REPORTER_COMPACT_HPP_INCLUDED + +namespace Catch { + + struct CompactReporter : StreamingReporterBase { + + CompactReporter( ReporterConfig const& _config ) + : StreamingReporterBase( _config ) + {} + + virtual ~CompactReporter(); + + static std::string getDescription() { + return "Reports test results on a single line, suitable for IDEs"; + } + + virtual ReporterPreferences getPreferences() const { + ReporterPreferences prefs; + prefs.shouldRedirectStdOut = false; + return prefs; + } + + virtual void noMatchingTestCases( std::string const& spec ) { + stream << "No test cases matched '" << spec << "'" << std::endl; + } + + virtual void assertionStarting( AssertionInfo const& ) { + } + + virtual bool assertionEnded( AssertionStats const& _assertionStats ) { + AssertionResult const& result = _assertionStats.assertionResult; + + bool printInfoMessages = true; + + // Drop out if result was successful and we're not printing those + if( !m_config->includeSuccessfulResults() && result.isOk() ) { + if( result.getResultType() != ResultWas::Warning ) + return false; + printInfoMessages = false; + } + + AssertionPrinter printer( stream, _assertionStats, printInfoMessages ); + printer.print(); + + stream << std::endl; + return true; + } + + virtual void testRunEnded( TestRunStats const& _testRunStats ) { + printTotals( _testRunStats.totals ); + stream << "\n" << std::endl; + StreamingReporterBase::testRunEnded( _testRunStats ); + } + + private: + class AssertionPrinter { + void operator= ( AssertionPrinter const& ); + public: + AssertionPrinter( std::ostream& _stream, AssertionStats const& _stats, bool _printInfoMessages ) + : stream( _stream ) + , stats( _stats ) + , result( _stats.assertionResult ) + , messages( _stats.infoMessages ) + , itMessage( _stats.infoMessages.begin() ) + , printInfoMessages( _printInfoMessages ) + {} + + void print() { + printSourceInfo(); + + itMessage = messages.begin(); + + switch( result.getResultType() ) { + case ResultWas::Ok: + printResultType( Colour::ResultSuccess, passedString() ); + printOriginalExpression(); + printReconstructedExpression(); + if ( ! result.hasExpression() ) + printRemainingMessages( Colour::None ); + else + printRemainingMessages(); + break; + case ResultWas::ExpressionFailed: + if( result.isOk() ) + printResultType( Colour::ResultSuccess, failedString() + std::string( " - but was ok" ) ); + else + printResultType( Colour::Error, failedString() ); + printOriginalExpression(); + printReconstructedExpression(); + printRemainingMessages(); + break; + case ResultWas::ThrewException: + printResultType( Colour::Error, failedString() ); + printIssue( "unexpected exception with message:" ); + printMessage(); + printExpressionWas(); + printRemainingMessages(); + break; + case ResultWas::DidntThrowException: + printResultType( Colour::Error, failedString() ); + printIssue( "expected exception, got none" ); + printExpressionWas(); + printRemainingMessages(); + break; + case ResultWas::Info: + printResultType( Colour::None, "info" ); + printMessage(); + printRemainingMessages(); + break; + case ResultWas::Warning: + printResultType( Colour::None, "warning" ); + printMessage(); + printRemainingMessages(); + break; + case ResultWas::ExplicitFailure: + printResultType( Colour::Error, failedString() ); + printIssue( "explicitly" ); + printRemainingMessages( Colour::None ); + break; + // These cases are here to prevent compiler warnings + case ResultWas::Unknown: + case ResultWas::FailureBit: + case ResultWas::Exception: + printResultType( Colour::Error, "** internal error **" ); + break; + } + } + + private: + // Colour::LightGrey + + static Colour::Code dimColour() { return Colour::FileName; } + +#ifdef CATCH_PLATFORM_MAC + static const char* failedString() { return "FAILED"; } + static const char* passedString() { return "PASSED"; } +#else + static const char* failedString() { return "failed"; } + static const char* passedString() { return "passed"; } +#endif + + void printSourceInfo() const { + Colour colourGuard( Colour::FileName ); + stream << result.getSourceInfo() << ":"; + } + + void printResultType( Colour::Code colour, std::string passOrFail ) const { + if( !passOrFail.empty() ) { + { + Colour colourGuard( colour ); + stream << " " << passOrFail; + } + stream << ":"; + } + } + + void printIssue( std::string issue ) const { + stream << " " << issue; + } + + void printExpressionWas() { + if( result.hasExpression() ) { + stream << ";"; + { + Colour colour( dimColour() ); + stream << " expression was:"; + } + printOriginalExpression(); + } + } + + void printOriginalExpression() const { + if( result.hasExpression() ) { + stream << " " << result.getExpression(); + } + } + + void printReconstructedExpression() const { + if( result.hasExpandedExpression() ) { + { + Colour colour( dimColour() ); + stream << " for: "; + } + stream << result.getExpandedExpression(); + } + } + + void printMessage() { + if ( itMessage != messages.end() ) { + stream << " '" << itMessage->message << "'"; + ++itMessage; + } + } + + void printRemainingMessages( Colour::Code colour = dimColour() ) { + if ( itMessage == messages.end() ) + return; + + // using messages.end() directly yields compilation error: + std::vector::const_iterator itEnd = messages.end(); + const std::size_t N = static_cast( std::distance( itMessage, itEnd ) ); + + { + Colour colourGuard( colour ); + stream << " with " << pluralise( N, "message" ) << ":"; + } + + for(; itMessage != itEnd; ) { + // If this assertion is a warning ignore any INFO messages + if( printInfoMessages || itMessage->type != ResultWas::Info ) { + stream << " '" << itMessage->message << "'"; + if ( ++itMessage != itEnd ) { + Colour colourGuard( dimColour() ); + stream << " and"; + } + } + } + } + + private: + std::ostream& stream; + AssertionStats const& stats; + AssertionResult const& result; + std::vector messages; + std::vector::const_iterator itMessage; + bool printInfoMessages; + }; + + // Colour, message variants: + // - white: No tests ran. + // - red: Failed [both/all] N test cases, failed [both/all] M assertions. + // - white: Passed [both/all] N test cases (no assertions). + // - red: Failed N tests cases, failed M assertions. + // - green: Passed [both/all] N tests cases with M assertions. + + std::string bothOrAll( std::size_t count ) const { + return count == 1 ? "" : count == 2 ? "both " : "all " ; + } + + void printTotals( const Totals& totals ) const { + if( totals.testCases.total() == 0 ) { + stream << "No tests ran."; + } + else if( totals.testCases.failed == totals.testCases.total() ) { + Colour colour( Colour::ResultError ); + const std::string qualify_assertions_failed = + totals.assertions.failed == totals.assertions.total() ? + bothOrAll( totals.assertions.failed ) : ""; + stream << + "Failed " << bothOrAll( totals.testCases.failed ) + << pluralise( totals.testCases.failed, "test case" ) << ", " + "failed " << qualify_assertions_failed << + pluralise( totals.assertions.failed, "assertion" ) << "."; + } + else if( totals.assertions.total() == 0 ) { + stream << + "Passed " << bothOrAll( totals.testCases.total() ) + << pluralise( totals.testCases.total(), "test case" ) + << " (no assertions)."; + } + else if( totals.assertions.failed ) { + Colour colour( Colour::ResultError ); + stream << + "Failed " << pluralise( totals.testCases.failed, "test case" ) << ", " + "failed " << pluralise( totals.assertions.failed, "assertion" ) << "."; + } + else { + Colour colour( Colour::ResultSuccess ); + stream << + "Passed " << bothOrAll( totals.testCases.passed ) + << pluralise( totals.testCases.passed, "test case" ) << + " with " << pluralise( totals.assertions.passed, "assertion" ) << "."; + } + } + }; + + INTERNAL_CATCH_REGISTER_REPORTER( "compact", CompactReporter ) + +} // end namespace Catch + +namespace Catch { + NonCopyable::~NonCopyable() {} + IShared::~IShared() {} + StreamBufBase::~StreamBufBase() CATCH_NOEXCEPT {} + IContext::~IContext() {} + IResultCapture::~IResultCapture() {} + ITestCase::~ITestCase() {} + ITestCaseRegistry::~ITestCaseRegistry() {} + IRegistryHub::~IRegistryHub() {} + IMutableRegistryHub::~IMutableRegistryHub() {} + IExceptionTranslator::~IExceptionTranslator() {} + IExceptionTranslatorRegistry::~IExceptionTranslatorRegistry() {} + IReporter::~IReporter() {} + IReporterFactory::~IReporterFactory() {} + IReporterRegistry::~IReporterRegistry() {} + IStreamingReporter::~IStreamingReporter() {} + AssertionStats::~AssertionStats() {} + SectionStats::~SectionStats() {} + TestCaseStats::~TestCaseStats() {} + TestGroupStats::~TestGroupStats() {} + TestRunStats::~TestRunStats() {} + CumulativeReporterBase::SectionNode::~SectionNode() {} + CumulativeReporterBase::~CumulativeReporterBase() {} + + StreamingReporterBase::~StreamingReporterBase() {} + ConsoleReporter::~ConsoleReporter() {} + CompactReporter::~CompactReporter() {} + IRunner::~IRunner() {} + IMutableContext::~IMutableContext() {} + IConfig::~IConfig() {} + XmlReporter::~XmlReporter() {} + JunitReporter::~JunitReporter() {} + TestRegistry::~TestRegistry() {} + FreeFunctionTestCase::~FreeFunctionTestCase() {} + IGeneratorInfo::~IGeneratorInfo() {} + IGeneratorsForTest::~IGeneratorsForTest() {} + TestSpec::Pattern::~Pattern() {} + TestSpec::NamePattern::~NamePattern() {} + TestSpec::TagPattern::~TagPattern() {} + TestSpec::ExcludedPattern::~ExcludedPattern() {} + + Matchers::Impl::StdString::Equals::~Equals() {} + Matchers::Impl::StdString::Contains::~Contains() {} + Matchers::Impl::StdString::StartsWith::~StartsWith() {} + Matchers::Impl::StdString::EndsWith::~EndsWith() {} + + void Config::dummy() {} + + INTERNAL_CATCH_REGISTER_LEGACY_REPORTER( "xml", XmlReporter ) +} + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +#endif + +#ifdef CATCH_CONFIG_MAIN +// #included from: internal/catch_default_main.hpp +#define TWOBLUECUBES_CATCH_DEFAULT_MAIN_HPP_INCLUDED + +#ifndef __OBJC__ + +// Standard C/C++ main entry point +int main (int argc, char * const argv[]) { + return Catch::Session().run( argc, argv ); +} + +#else // __OBJC__ + +// Objective-C entry point +int main (int argc, char * const argv[]) { +#if !CATCH_ARC_ENABLED + NSAutoreleasePool * pool = [[NSAutoreleasePool alloc] init]; +#endif + + Catch::registerTestMethods(); + int result = Catch::Session().run( argc, (char* const*)argv ); + +#if !CATCH_ARC_ENABLED + [pool drain]; +#endif + + return result; +} + +#endif // __OBJC__ + +#endif + +#ifdef CLARA_CONFIG_MAIN_NOT_DEFINED +# undef CLARA_CONFIG_MAIN +#endif + +////// + +// If this config identifier is defined then all CATCH macros are prefixed with CATCH_ +#ifdef CATCH_CONFIG_PREFIX_ALL + +#define CATCH_REQUIRE( expr ) INTERNAL_CATCH_TEST( expr, Catch::ResultDisposition::Normal, "CATCH_REQUIRE" ) +#define CATCH_REQUIRE_FALSE( expr ) INTERNAL_CATCH_TEST( expr, Catch::ResultDisposition::Normal | Catch::ResultDisposition::FalseTest, "CATCH_REQUIRE_FALSE" ) + +#define CATCH_REQUIRE_THROWS( expr ) INTERNAL_CATCH_THROWS( expr, Catch::ResultDisposition::Normal, "CATCH_REQUIRE_THROWS" ) +#define CATCH_REQUIRE_THROWS_AS( expr, exceptionType ) INTERNAL_CATCH_THROWS_AS( expr, exceptionType, Catch::ResultDisposition::Normal, "CATCH_REQUIRE_THROWS_AS" ) +#define CATCH_REQUIRE_NOTHROW( expr ) INTERNAL_CATCH_NO_THROW( expr, Catch::ResultDisposition::Normal, "CATCH_REQUIRE_NOTHROW" ) + +#define CATCH_CHECK( expr ) INTERNAL_CATCH_TEST( expr, Catch::ResultDisposition::ContinueOnFailure, "CATCH_CHECK" ) +#define CATCH_CHECK_FALSE( expr ) INTERNAL_CATCH_TEST( expr, Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::FalseTest, "CATCH_CHECK_FALSE" ) +#define CATCH_CHECKED_IF( expr ) INTERNAL_CATCH_IF( expr, Catch::ResultDisposition::ContinueOnFailure, "CATCH_CHECKED_IF" ) +#define CATCH_CHECKED_ELSE( expr ) INTERNAL_CATCH_ELSE( expr, Catch::ResultDisposition::ContinueOnFailure, "CATCH_CHECKED_ELSE" ) +#define CATCH_CHECK_NOFAIL( expr ) INTERNAL_CATCH_TEST( expr, Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::SuppressFail, "CATCH_CHECK_NOFAIL" ) + +#define CATCH_CHECK_THROWS( expr ) INTERNAL_CATCH_THROWS( expr, Catch::ResultDisposition::ContinueOnFailure, "CATCH_CHECK_THROWS" ) +#define CATCH_CHECK_THROWS_AS( expr, exceptionType ) INTERNAL_CATCH_THROWS_AS( expr, exceptionType, Catch::ResultDisposition::ContinueOnFailure, "CATCH_CHECK_THROWS_AS" ) +#define CATCH_CHECK_NOTHROW( expr ) INTERNAL_CATCH_NO_THROW( expr, Catch::ResultDisposition::ContinueOnFailure, "CATCH_CHECK_NOTHROW" ) + +#define CHECK_THAT( arg, matcher ) INTERNAL_CHECK_THAT( arg, matcher, Catch::ResultDisposition::ContinueOnFailure, "CATCH_CHECK_THAT" ) +#define CATCH_REQUIRE_THAT( arg, matcher ) INTERNAL_CHECK_THAT( arg, matcher, Catch::ResultDisposition::Normal, "CATCH_REQUIRE_THAT" ) + +#define CATCH_INFO( msg ) INTERNAL_CATCH_INFO( msg, "CATCH_INFO" ) +#define CATCH_WARN( msg ) INTERNAL_CATCH_MSG( Catch::ResultWas::Warning, Catch::ResultDisposition::ContinueOnFailure, "CATCH_WARN", msg ) +#define CATCH_SCOPED_INFO( msg ) INTERNAL_CATCH_INFO( msg, "CATCH_INFO" ) +#define CATCH_CAPTURE( msg ) INTERNAL_CATCH_INFO( #msg " := " << msg, "CATCH_CAPTURE" ) +#define CATCH_SCOPED_CAPTURE( msg ) INTERNAL_CATCH_INFO( #msg " := " << msg, "CATCH_CAPTURE" ) + +#ifdef CATCH_CONFIG_VARIADIC_MACROS + #define CATCH_TEST_CASE( ... ) INTERNAL_CATCH_TESTCASE( __VA_ARGS__ ) + #define CATCH_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEST_CASE_METHOD( className, __VA_ARGS__ ) + #define CATCH_METHOD_AS_TEST_CASE( method, ... ) INTERNAL_CATCH_METHOD_AS_TEST_CASE( method, __VA_ARGS__ ) + #define CATCH_SECTION( ... ) INTERNAL_CATCH_SECTION( __VA_ARGS__ ) + #define CATCH_FAIL( ... ) INTERNAL_CATCH_MSG( Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::Normal, "CATCH_FAIL", __VA_ARGS__ ) + #define CATCH_SUCCEED( ... ) INTERNAL_CATCH_MSG( Catch::ResultWas::Ok, Catch::ResultDisposition::ContinueOnFailure, "CATCH_SUCCEED", __VA_ARGS__ ) +#else + #define CATCH_TEST_CASE( name, description ) INTERNAL_CATCH_TESTCASE( name, description ) + #define CATCH_TEST_CASE_METHOD( className, name, description ) INTERNAL_CATCH_TEST_CASE_METHOD( className, name, description ) + #define CATCH_METHOD_AS_TEST_CASE( method, name, description ) INTERNAL_CATCH_METHOD_AS_TEST_CASE( method, name, description ) + #define CATCH_SECTION( name, description ) INTERNAL_CATCH_SECTION( name, description ) + #define CATCH_FAIL( msg ) INTERNAL_CATCH_MSG( Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::Normal, "CATCH_FAIL", msg ) + #define CATCH_SUCCEED( msg ) INTERNAL_CATCH_MSG( Catch::ResultWas::Ok, Catch::ResultDisposition::ContinueOnFailure, "CATCH_SUCCEED", msg ) +#endif +#define CATCH_ANON_TEST_CASE() INTERNAL_CATCH_TESTCASE( "", "" ) + +#define CATCH_REGISTER_REPORTER( name, reporterType ) INTERNAL_CATCH_REGISTER_REPORTER( name, reporterType ) +#define CATCH_REGISTER_LEGACY_REPORTER( name, reporterType ) INTERNAL_CATCH_REGISTER_LEGACY_REPORTER( name, reporterType ) + +#define CATCH_GENERATE( expr) INTERNAL_CATCH_GENERATE( expr ) + +// "BDD-style" convenience wrappers +#ifdef CATCH_CONFIG_VARIADIC_MACROS +#define CATCH_SCENARIO( ... ) CATCH_TEST_CASE( "Scenario: " __VA_ARGS__ ) +#else +#define CATCH_SCENARIO( name, tags ) CATCH_TEST_CASE( "Scenario: " name, tags ) +#endif +#define CATCH_GIVEN( desc ) CATCH_SECTION( "Given: " desc, "" ) +#define CATCH_WHEN( desc ) CATCH_SECTION( " When: " desc, "" ) +#define CATCH_AND_WHEN( desc ) CATCH_SECTION( " And: " desc, "" ) +#define CATCH_THEN( desc ) CATCH_SECTION( " Then: " desc, "" ) +#define CATCH_AND_THEN( desc ) CATCH_SECTION( " And: " desc, "" ) + +// If CATCH_CONFIG_PREFIX_ALL is not defined then the CATCH_ prefix is not required +#else + +#define REQUIRE( expr ) INTERNAL_CATCH_TEST( expr, Catch::ResultDisposition::Normal, "REQUIRE" ) +#define REQUIRE_FALSE( expr ) INTERNAL_CATCH_TEST( expr, Catch::ResultDisposition::Normal | Catch::ResultDisposition::FalseTest, "REQUIRE_FALSE" ) + +#define REQUIRE_THROWS( expr ) INTERNAL_CATCH_THROWS( expr, Catch::ResultDisposition::Normal, "REQUIRE_THROWS" ) +#define REQUIRE_THROWS_AS( expr, exceptionType ) INTERNAL_CATCH_THROWS_AS( expr, exceptionType, Catch::ResultDisposition::Normal, "REQUIRE_THROWS_AS" ) +#define REQUIRE_NOTHROW( expr ) INTERNAL_CATCH_NO_THROW( expr, Catch::ResultDisposition::Normal, "REQUIRE_NOTHROW" ) + +#define CHECK( expr ) INTERNAL_CATCH_TEST( expr, Catch::ResultDisposition::ContinueOnFailure, "CHECK" ) +#define CHECK_FALSE( expr ) INTERNAL_CATCH_TEST( expr, Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::FalseTest, "CHECK_FALSE" ) +#define CHECKED_IF( expr ) INTERNAL_CATCH_IF( expr, Catch::ResultDisposition::ContinueOnFailure, "CHECKED_IF" ) +#define CHECKED_ELSE( expr ) INTERNAL_CATCH_ELSE( expr, Catch::ResultDisposition::ContinueOnFailure, "CHECKED_ELSE" ) +#define CHECK_NOFAIL( expr ) INTERNAL_CATCH_TEST( expr, Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::SuppressFail, "CHECK_NOFAIL" ) + +#define CHECK_THROWS( expr ) INTERNAL_CATCH_THROWS( expr, Catch::ResultDisposition::ContinueOnFailure, "CHECK_THROWS" ) +#define CHECK_THROWS_AS( expr, exceptionType ) INTERNAL_CATCH_THROWS_AS( expr, exceptionType, Catch::ResultDisposition::ContinueOnFailure, "CHECK_THROWS_AS" ) +#define CHECK_NOTHROW( expr ) INTERNAL_CATCH_NO_THROW( expr, Catch::ResultDisposition::ContinueOnFailure, "CHECK_NOTHROW" ) + +#define CHECK_THAT( arg, matcher ) INTERNAL_CHECK_THAT( arg, matcher, Catch::ResultDisposition::ContinueOnFailure, "CHECK_THAT" ) +#define REQUIRE_THAT( arg, matcher ) INTERNAL_CHECK_THAT( arg, matcher, Catch::ResultDisposition::Normal, "REQUIRE_THAT" ) + +#define INFO( msg ) INTERNAL_CATCH_INFO( msg, "INFO" ) +#define WARN( msg ) INTERNAL_CATCH_MSG( Catch::ResultWas::Warning, Catch::ResultDisposition::ContinueOnFailure, "WARN", msg ) +#define SCOPED_INFO( msg ) INTERNAL_CATCH_INFO( msg, "INFO" ) +#define CAPTURE( msg ) INTERNAL_CATCH_INFO( #msg " := " << msg, "CAPTURE" ) +#define SCOPED_CAPTURE( msg ) INTERNAL_CATCH_INFO( #msg " := " << msg, "CAPTURE" ) + +#ifdef CATCH_CONFIG_VARIADIC_MACROS + #define TEST_CASE( ... ) INTERNAL_CATCH_TESTCASE( __VA_ARGS__ ) + #define TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEST_CASE_METHOD( className, __VA_ARGS__ ) + #define METHOD_AS_TEST_CASE( method, ... ) INTERNAL_CATCH_METHOD_AS_TEST_CASE( method, __VA_ARGS__ ) + #define SECTION( ... ) INTERNAL_CATCH_SECTION( __VA_ARGS__ ) + #define FAIL( ... ) INTERNAL_CATCH_MSG( Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::Normal, "FAIL", __VA_ARGS__ ) + #define SUCCEED( ... ) INTERNAL_CATCH_MSG( Catch::ResultWas::Ok, Catch::ResultDisposition::ContinueOnFailure, "SUCCEED", __VA_ARGS__ ) +#else + #define TEST_CASE( name, description ) INTERNAL_CATCH_TESTCASE( name, description ) + #define TEST_CASE_METHOD( className, name, description ) INTERNAL_CATCH_TEST_CASE_METHOD( className, name, description ) + #define METHOD_AS_TEST_CASE( method, name, description ) INTERNAL_CATCH_METHOD_AS_TEST_CASE( method, name, description ) + #define SECTION( name, description ) INTERNAL_CATCH_SECTION( name, description ) + #define FAIL( msg ) INTERNAL_CATCH_MSG( Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::Normal, "FAIL", msg ) + #define SUCCEED( msg ) INTERNAL_CATCH_MSG( Catch::ResultWas::Ok, Catch::ResultDisposition::ContinueOnFailure, "SUCCEED", msg ) +#endif +#define ANON_TEST_CASE() INTERNAL_CATCH_TESTCASE( "", "" ) + +#define REGISTER_REPORTER( name, reporterType ) INTERNAL_CATCH_REGISTER_REPORTER( name, reporterType ) +#define REGISTER_LEGACY_REPORTER( name, reporterType ) INTERNAL_CATCH_REGISTER_LEGACY_REPORTER( name, reporterType ) + +#define GENERATE( expr) INTERNAL_CATCH_GENERATE( expr ) + +#endif + +#define CATCH_TRANSLATE_EXCEPTION( signature ) INTERNAL_CATCH_TRANSLATE_EXCEPTION( signature ) + +// "BDD-style" convenience wrappers +#ifdef CATCH_CONFIG_VARIADIC_MACROS +#define SCENARIO( ... ) TEST_CASE( "Scenario: " __VA_ARGS__ ) +#else +#define SCENARIO( name, tags ) TEST_CASE( "Scenario: " name, tags ) +#endif +#define GIVEN( desc ) SECTION( " Given: " desc, "" ) +#define WHEN( desc ) SECTION( " When: " desc, "" ) +#define AND_WHEN( desc ) SECTION( "And when: " desc, "" ) +#define THEN( desc ) SECTION( " Then: " desc, "" ) +#define AND_THEN( desc ) SECTION( " And: " desc, "" ) + +using Catch::Detail::Approx; + +// #included from: internal/catch_reenable_warnings.h + +#define TWOBLUECUBES_CATCH_REENABLE_WARNINGS_H_INCLUDED + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +#endif // TWOBLUECUBES_SINGLE_INCLUDE_CATCH_HPP_INCLUDED + diff --git a/lib/compacttree.h b/lib/compacttree.h index e9f47fd..394ad44 100644 --- a/lib/compacttree.h +++ b/lib/compacttree.h @@ -8,6 +8,8 @@ #include "thread.h" +namespace Morat { + /* CompactTree is a Tree of Nodes. It malloc's one chunk at a time, and has a very efficient allocation strategy. * It maintains a freelist of empty segments, but never assigns a segment to a smaller amount of memory, * completely avoiding fragmentation, but potentially having empty space in sizes that are no longer popular. @@ -489,3 +491,5 @@ template class CompactTree { last = dchunk; } }; + +}; // namespace Morat diff --git a/lib/depthstats.h b/lib/depthstats.h index 6e03be7..78dde22 100644 --- a/lib/depthstats.h +++ b/lib/depthstats.h @@ -7,7 +7,7 @@ #include "string.h" -using namespace std; +namespace Morat { struct DepthStats { uint32_t mindepth, maxdepth, num; @@ -52,10 +52,12 @@ struct DepthStats { } double std_dev() const { if(num == 0) return 0.0; - return sqrt((double)sumdepthsq/num - ((double)sumdepth/num)*((double)sumdepth/num)); + return std::sqrt((double)sumdepthsq/num - ((double)sumdepth/num)*((double)sumdepth/num)); } - string to_s() const { + std::string to_s() const { if(num == 0) return "num=0"; return to_str(avg(), 4) +", dev=" + to_str(std_dev(), 4) + ", min=" + to_str(mindepth) + ", max=" + to_str(maxdepth) + ", num=" + to_str(num); } }; + +}; // namespace Morat diff --git a/lib/exppair.h b/lib/exppair.h index 44cc05c..a68c88d 100644 --- a/lib/exppair.h +++ b/lib/exppair.h @@ -1,18 +1,32 @@ #pragma once +#include "string.h" #include "thread.h" #include "types.h" +namespace Morat { + class ExpPair { uword s, n; ExpPair(uword S, uword N) : s(S), n(N) { } public: ExpPair() : s(0), n(0) { } - float avg() const { return 0.5f*s/n; } + float avg() const { return (n ? 0.5f*s/n : 0); } uword num() const { return n; } uword sum() const { return s/2; } + std::string to_s() const { + return to_str(avg(), 3) + "/" + to_str(num()); + } + + ExpPair(std::string str) { + auto parts = explode(str, "/"); + assert(parts.size() == 2); + n = from_str(parts[1]); + s = 2.0*from_str(parts[0])*n; + } + void clear() { s = 0; n = 0; } void addvloss(){ INCR(n); } @@ -50,3 +64,5 @@ class ExpPair { return ExpPair(n*2 - s, n); } }; + +}; // namespace Morat diff --git a/lib/fileio.cpp b/lib/fileio.cpp index 07d5a68..5628ac4 100644 --- a/lib/fileio.cpp +++ b/lib/fileio.cpp @@ -3,6 +3,7 @@ #include "fileio.h" +namespace Morat { using namespace std; int fpeek(FILE * fd){ @@ -12,13 +13,31 @@ int fpeek(FILE * fd){ } void eat_whitespace(FILE * fd){ int c = fgetc(fd); - while(c == ' ' || c == '\n') + while(c == ' ' || c == '\n' || c == '\t') c = fgetc(fd); ungetc(c, fd); } -void eat_char(FILE * fd, int expect){ +void eat_whitespace(std::istream & is){ + int c = is.peek(); + while(c == ' ' || c == '\n' || c == '\t'){ + is.get(); + c = is.peek(); + } +} +bool eat_char(FILE * fd, int expect){ int c = fgetc(fd); - assert(c == expect); + if (c == expect) + return true; + ungetc(c, fd); + return false; +} +bool eat_char(std::istream & is, int expect){ + int c = is.peek(); + if (c == expect){ + is.get(); + return true; + } + return false; } string read_until(FILE * fd, char until, bool include){ string ret; @@ -31,3 +50,5 @@ string read_until(FILE * fd, char until, bool include){ ungetc(c, fd); return ret; } + +}; diff --git a/lib/fileio.h b/lib/fileio.h index 64494cd..5f7ba44 100644 --- a/lib/fileio.h +++ b/lib/fileio.h @@ -2,10 +2,16 @@ #pragma once #include +#include #include +namespace Morat { + int fpeek(FILE * fd); void eat_whitespace(FILE * fd); -void eat_char(FILE * fd, int expect); +void eat_whitespace(std::istream & is); +bool eat_char(FILE * fd, int expect); +bool eat_char(std::istream & is, int expect); std::string read_until(FILE * fd, char until, bool include = false); +}; // namespace Morat diff --git a/lib/gtpbase.h b/lib/gtpbase.h index c5d462c..51600c9 100644 --- a/lib/gtpbase.h +++ b/lib/gtpbase.h @@ -10,45 +10,46 @@ #include "string.h" -using namespace std; -using namespace placeholders; //for bind +namespace Morat { + +using namespace std::placeholders; //for bind struct GTPResponse { bool success; - string id; - string response; + std::string id; + std::string response; GTPResponse() { } - GTPResponse(bool s, string r = ""){ + GTPResponse(bool s, std::string r = ""){ success = s; response = r; rtrim(response); } - GTPResponse(string r){ + GTPResponse(std::string r){ GTPResponse(true, r); } - string to_s(){ + std::string to_s(){ return (success ? '=' : '?') + id + ' ' + response + "\n\n"; } }; -typedef function gtp_callback_fn; +typedef std::function gtp_callback_fn; struct GTPCallback { - string name; - string desc; + std::string name; + std::string desc; gtp_callback_fn func; GTPCallback() { } - GTPCallback(string n, string d, gtp_callback_fn fn) : name(n), desc(d), func(fn) { } + GTPCallback(std::string n, std::string d, gtp_callback_fn fn) : name(n), desc(d), func(fn) { } }; class GTPBase { FILE * in, * out; - vector callbacks; + std::vector callbacks; unsigned int longest_cmd; bool running; @@ -60,11 +61,11 @@ class GTPBase { longest_cmd = 0; running = false; - newcallback("list_commands", bind(>PBase::gtp_list_commands, this, _1, false), "List the commands"); - newcallback("help", bind(>PBase::gtp_list_commands, this, _1, true), "List the commands, with descriptions"); - newcallback("quit", bind(>PBase::gtp_quit, this, _1), "Quit the program"); - newcallback("exit", bind(>PBase::gtp_quit, this, _1), "Alias for quit"); - newcallback("protocol_version", bind(>PBase::gtp_protocol_version, this, _1), "Show the gtp protocol version"); + newcallback("list_commands", std::bind(>PBase::gtp_list_commands, this, _1, false), "List the commands"); + newcallback("help", std::bind(>PBase::gtp_list_commands, this, _1, true), "List the commands, with descriptions"); + newcallback("quit", std::bind(>PBase::gtp_quit, this, _1), "Quit the program"); + newcallback("exit", std::bind(>PBase::gtp_quit, this, _1), "Alias for quit"); + newcallback("protocol_version", std::bind(>PBase::gtp_protocol_version, this, _1), "Show the gtp protocol version"); } void setinfile(FILE * i){ @@ -75,7 +76,7 @@ class GTPBase { out = o; } - void newcallback(const string name, const gtp_callback_fn & fn, const string desc = ""){ + void newcallback(const std::string name, const gtp_callback_fn & fn, const std::string desc = ""){ newcallback(GTPCallback(name, desc, fn)); if(longest_cmd < name.length()) longest_cmd = name.length(); @@ -85,23 +86,23 @@ class GTPBase { callbacks.push_back(a); } - int find_callback(const string & name){ + int find_callback(const std::string & name){ for(unsigned int i = 0; i < callbacks.size(); i++) if(callbacks[i].name == name) return i; return -1; } - GTPResponse cmd(string line){ + GTPResponse cmd(std::string line){ vecstr parts = explode(line, " "); - string id; + std::string id; if(parts.size() > 1 && atoi(parts[0].c_str())){ id = parts[0]; parts.erase(parts.begin()); } - string name = parts[0]; + std::string name = parts[0]; parts.erase(parts.begin()); int cb = find_callback(name); @@ -123,7 +124,7 @@ class GTPBase { char buf[1001]; while(running && fgets(buf, 1000, in)){ - string line(buf); + std::string line(buf); trim(line); @@ -133,7 +134,7 @@ class GTPBase { GTPResponse response = cmd(line); if(out){ - string output = response.to_s(); + std::string output = response.to_s(); fwrite(output.c_str(), 1, output.length(), out); fflush(out); @@ -152,11 +153,11 @@ class GTPBase { } GTPResponse gtp_list_commands(vecstr args, bool showdesc){ - string ret = "\n"; + std::string ret = "\n"; for(unsigned int i = 0; i < callbacks.size(); i++){ ret += callbacks[i].name; if(showdesc && callbacks[i].desc.length() > 0){ - ret += string(longest_cmd + 2 - callbacks[i].name.length(), ' '); + ret += std::string(longest_cmd + 2 - callbacks[i].name.length(), ' '); ret += callbacks[i].desc; } ret += "\n"; @@ -164,3 +165,5 @@ class GTPBase { return GTPResponse(true, ret); } }; + +}; // namespace Morat diff --git a/lib/gtpcommon.cpp b/lib/gtpcommon.cpp index abb6a42..2c94240 100644 --- a/lib/gtpcommon.cpp +++ b/lib/gtpcommon.cpp @@ -1,6 +1,8 @@ #include "gtpcommon.h" +namespace Morat { + using namespace std; GTPResponse GTPCommon::gtp_echo(vecstr args) const { @@ -70,3 +72,5 @@ string GTPCommon::won_str(int outcome) const { default: return "unknown"; } } + +}; // namespace Morat diff --git a/lib/gtpcommon.h b/lib/gtpcommon.h index e018442..d119783 100644 --- a/lib/gtpcommon.h +++ b/lib/gtpcommon.h @@ -4,6 +4,8 @@ #include "gtpbase.h" #include "timecontrol.h" +namespace Morat { + class GTPCommon : public GTPBase { protected: @@ -20,3 +22,5 @@ class GTPCommon : public GTPBase { GTPResponse gtp_time(vecstr args); std::string won_str(int outcome) const; }; + +}; // namespace Morat diff --git a/lib/hashset.h b/lib/hashset.h index 35a816d..a8919ee 100644 --- a/lib/hashset.h +++ b/lib/hashset.h @@ -5,6 +5,8 @@ #include "bits.h" +namespace Morat { + class HashSet { unsigned int size; // how many slots there are, must be a power of 2 unsigned int mask; // a mask for the size, ie size-1 @@ -64,3 +66,5 @@ class HashSet { return false; } }; + +}; // namespace Morat diff --git a/lib/hashtable.h b/lib/hashtable.h index 584c41d..eb4b515 100644 --- a/lib/hashtable.h +++ b/lib/hashtable.h @@ -5,6 +5,8 @@ //#include "log.h" //#include "string.h" +namespace Morat { + class HashTable { // A simple hash table. It support resizing, but you need to call resize once in a while // to make it happen. This is so it doesn't need to check regularly whether it needs to resize. @@ -119,3 +121,5 @@ class HashTable { } } }; + +}; // namespace Morat diff --git a/rex/history.h b/lib/history.h similarity index 92% rename from rex/history.h rename to lib/history.h index 00ccd06..0055c1f 100644 --- a/rex/history.h +++ b/lib/history.h @@ -3,11 +3,13 @@ #include -#include "../lib/string.h" - -#include "board.h" #include "move.h" +#include "string.h" + +namespace Morat { + +template class History { std::vector hist; Board board; @@ -68,3 +70,5 @@ class History { return false; } }; + +}; // namespace Morat diff --git a/lib/log.h b/lib/log.h index 32c8861..090afc5 100644 --- a/lib/log.h +++ b/lib/log.h @@ -4,7 +4,10 @@ #include #include +namespace Morat { + inline void logerr(std::string str){ fprintf(stderr, "%s", str.c_str()); } +}; // namespace Morat diff --git a/hex/move.h b/lib/move.h similarity index 80% rename from hex/move.h rename to lib/move.h index 84cf035..5a21f46 100644 --- a/hex/move.h +++ b/lib/move.h @@ -2,9 +2,14 @@ #pragma once #include +#include #include -#include "../lib/string.h" +#include "outcome.h" +#include "string.h" + + +namespace Morat { enum MoveSpecial { M_SWAP = -1, //-1 so that adding 1 makes it into a valid move @@ -39,6 +44,8 @@ struct Move { return std::string() + char(y + 'a') + to_str(x + 1); } + friend std::ostream& operator<< (std::ostream &out, const Move & m) { return out << m.to_s(); } + bool operator< (const Move & b) const { return (y == b.y ? x < b.x : y < b.y); } bool operator<=(const Move & b) const { return (y == b.y ? x <= b.x : y <= b.y); } bool operator> (const Move & b) const { return (y == b.y ? x > b.x : y > b.y); } @@ -51,18 +58,12 @@ struct Move { Move & operator+=(const Move & b) { y += b.y; x += b.x; return *this; } Move operator- (const Move & b) const { return Move(x - b.x, y - b.y); } Move & operator-=(const Move & b) { y -= b.y; x -= b.x; return *this; } - - int z() const { return (x - y); } - int dist(const Move & b) const { - return (abs(x - b.x) + abs(y - b.y) + abs(z() - b.z()))/2; - } }; struct MoveScore : public Move { int16_t score; - MoveScore() : score(0) { } - MoveScore(MoveSpecial a) : Move(a), score(0) { } + MoveScore(MoveSpecial a = M_UNKNOWN) : Move(a), score(0) { } MoveScore(int X, int Y, int s) : Move(X, Y), score(s) { } MoveScore operator+ (const Move & b) const { return MoveScore(x + b.x, y + b.y, score); } }; @@ -70,22 +71,17 @@ struct MoveScore : public Move { struct MoveValid : public Move { int16_t xy; - MoveValid() : Move(), xy(-1) { } + MoveValid(MoveSpecial a = M_UNKNOWN) : Move(a), xy(-1) { } MoveValid(int x, int y, int XY) : Move(x,y), xy(XY) { } MoveValid(const Move & m, int XY) : Move(m), xy(XY) { } bool onboard() const { return xy != -1; } }; struct MovePlayer : public Move { - char player; + Side player; - MovePlayer() : Move(M_UNKNOWN), player(0) { } - MovePlayer(const Move & m, char p = 0) : Move(m), player(p) { } + MovePlayer(MoveSpecial a = M_UNKNOWN) : Move(a), player(Side::NONE) { } + MovePlayer(const Move & m, Side p = Side::NONE) : Move(m), player(p) { } }; - -struct PairMove { - Move a, b; - PairMove(Move A = M_UNKNOWN, Move B = M_UNKNOWN) : a(A), b(B) { } - PairMove(MoveSpecial A) : a(Move(A)), b(M_UNKNOWN) { } -}; +}; // namespace Morat diff --git a/lib/move_test.cpp b/lib/move_test.cpp new file mode 100644 index 0000000..bebbba4 --- /dev/null +++ b/lib/move_test.cpp @@ -0,0 +1,26 @@ + +#include "catch.hpp" + +#include "move.h" + +namespace Morat { + +TEST_CASE("Move", "[move]"){ + REQUIRE(Move() == M_UNKNOWN); + REQUIRE(Move(M_UNKNOWN) == M_UNKNOWN); + REQUIRE(Move("unknown") == M_UNKNOWN); + REQUIRE(Move("a1") == Move(0, 0)); + REQUIRE(Move("c5") == Move(4, 2)); + REQUIRE(Move("a1").to_s() == "a1"); + REQUIRE(Move("c5").to_s() == "c5"); + REQUIRE(Move().to_s() == "unknown"); + REQUIRE(Move("a1") == Move("a1")); + REQUIRE(Move("a1") < Move("c5")); + REQUIRE(Move("a1") <= Move("c5")); + REQUIRE(Move("a1") != Move("c5")); + REQUIRE(Move("c5") > Move("a1")); + REQUIRE(Move("c5") >= Move("a1")); + REQUIRE(Move("a1") != M_UNKNOWN); +} + +}; // namespace Morat diff --git a/rex/movelist.h b/lib/movelist.h similarity index 69% rename from rex/movelist.h rename to lib/movelist.h index 27c22de..5a3c6ad 100644 --- a/rex/movelist.h +++ b/lib/movelist.h @@ -1,11 +1,13 @@ #pragma once -#include "../lib/exppair.h" - -#include "board.h" +#include "exppair.h" #include "move.h" + +namespace Morat { + +template struct MoveList { ExpPair exp[2]; //aggregated outcomes overall ExpPair rave[2][Board::max_vecsize]; //aggregated outcomes per move @@ -16,10 +18,10 @@ struct MoveList { MoveList() : tree(0), rollout(0), board(NULL) { } - void addtree(const Move & move, char player){ + void addtree(const Move & move, Side player){ moves[tree++] = MovePlayer(move, player); } - void addrollout(const Move & move, char player){ + void addrollout(const Move & move, Side player){ moves[tree + rollout++] = MovePlayer(move, player); } void reset(Board * b){ @@ -33,19 +35,19 @@ struct MoveList { rave[1][i].clear(); } } - void finishrollout(int won){ + void finishrollout(Outcome won){ exp[0].addloss(); exp[1].addloss(); - if(won == 0){ + if(won == Outcome::DRAW){ exp[0].addtie(); exp[1].addtie(); }else{ - exp[won-1].addwin(); + exp[won.to_i() - 1].addwin(); for(MovePlayer * i = begin(), * e = end(); i != e; i++){ - ExpPair & r = rave[i->player-1][board->xy(*i)]; + ExpPair & r = rave[i->player.to_i() - 1][board->xy(*i)]; r.addloss(); - if(i->player == won) + if(+i->player == won) r.addwin(); } } @@ -67,10 +69,12 @@ struct MoveList { exp[0].addlosses(-n); exp[1].addlosses(-n); } - const ExpPair & getrave(int player, const Move & move) const { - return rave[player-1][board->xy(move)]; + const ExpPair & getrave(Side player, const Move & move) const { + return rave[player.to_i() - 1][board->xy(move)]; } - const ExpPair & getexp(int player) const { - return exp[player-1]; + const ExpPair & getexp(Side player) const { + return exp[player.to_i() - 1]; } }; + +}; // namespace Morat diff --git a/lib/outcome.cpp b/lib/outcome.cpp new file mode 100644 index 0000000..5849e49 --- /dev/null +++ b/lib/outcome.cpp @@ -0,0 +1,68 @@ + +#include + +#include "outcome.h" +#include "thread.h" + +namespace Morat { + +const Side Side::UNDEF = -4; +const Side Side::NONE = 0; +const Side Side::P1 = 1; +const Side Side::P2 = 2; +const Side Side::BOTH = 3; + +const Outcome Outcome::UNDEF = -4; // not yet computed +const Outcome Outcome::UNKNOWN = -3; // nothing known, could be either +const Outcome Outcome::P2_DRAW = -2; // player 2 or draw; player 1 can't win +const Outcome Outcome::P1_DRAW = -1; // player 1 or draw; player 2 can't win +const Outcome Outcome::DRAW = 0; // draw, neither player can win +const Outcome Outcome::P1 = 1; // first player win +const Outcome Outcome::P2 = 2; // second player win +const Outcome Outcome::DRAW2 = 3; // draw by simultaneous win + + +std::ostream& operator<< (std::ostream &out, const Side & s) { + return out << s.to_s(); +} + +std::ostream& operator<< (std::ostream &out, const Outcome & o) { + return out << o.to_s(); +} + +std::string Side::to_s() const { + if(*this == UNDEF) return "never"; + if(*this == NONE) return "none"; + if(*this == P1) return "white"; + if(*this == P2) return "black"; + if(*this == BOTH) return "both"; + return "unknown"; +} + +std::string Outcome::to_s() const { + if(*this == UNDEF) return "undefined"; + if(*this == UNKNOWN) return "unknown"; + if(*this == P2_DRAW) return "black_or_draw"; + if(*this == P1_DRAW) return "white_or_draw"; + if(*this == DRAW || + *this == DRAW2) return "draw"; // simultaneous win + if(*this == P1) return "white"; + if(*this == P2) return "black"; + return "unknown"; +} + +std::string Outcome::to_s_rel(Side to_play) const { + if(*this == Outcome::DRAW) return "draw"; + if(*this == Outcome::DRAW2) return "draw by simultaneous win"; + if(*this == +to_play) return "win"; + if(*this == +~to_play) return "loss"; + if(*this == -~to_play) return "win or draw"; + if(*this == -to_play) return "loss or draw"; + return "unkown"; +} + +bool Outcome::cas(Outcome old, Outcome new_) { + return CAS(outcome, old.outcome, new_.outcome); +} + +}; // namespace Morat diff --git a/lib/outcome.h b/lib/outcome.h new file mode 100644 index 0000000..cf9a816 --- /dev/null +++ b/lib/outcome.h @@ -0,0 +1,118 @@ + +#pragma once + +#include +#include +#include + +namespace Morat { + +class Side { + typedef int8_t Type; + Type side; + +public: + + static const Side UNDEF; + static const Side NONE; + static const Side P1; + static const Side P2; + static const Side BOTH; + + Side() : side(NONE.side) { } + Side(const Type & s) : side(s) { } + Side & operator=(const Type & s) { + side = s; + return *this; + } + + friend std::ostream& operator<< (std::ostream &out, const Side & s); + std::string to_s() const; + Type to_i() const { return side; } +}; + + + +// who has or will win if played perfectly +class Outcome { + typedef int8_t Type; + Type outcome; + +public: + + // positive is that player has won or will win if played perfect + // negative is that player might be able to win + static const Outcome UNDEF; // not yet computed + static const Outcome UNKNOWN; // nothing known, could be either + static const Outcome P2_DRAW; // player 2 or draw; player 1 can't win + static const Outcome P1_DRAW; // player 1 or draw; player 2 can't win + static const Outcome DRAW; // draw, neither player can win + static const Outcome P1; // first player win + static const Outcome P2; // second player win + static const Outcome DRAW2; // draw by simultaneous win + + + Outcome() : outcome(UNKNOWN.outcome) { } + Outcome(const Type & o) : outcome(o) { } + Outcome(const Side & s) : outcome(s.to_i()) { } + Outcome & operator=(const Type & o) { + outcome = o; + return *this; + } + + friend std::ostream& operator<< (std::ostream &out, const Outcome & o); + std::string to_s() const; + std::string to_s_rel(Side to_play) const; + Type to_i() const { return outcome; } + bool cas(Outcome old, Outcome new_); +}; + + + +// switch to opponent +inline Side operator~(Side s) { return Side(3 - s.to_i()); } +// can't switch sides of an Outcome! +//: Outcome operator~(Outcome o) { return (Outcome)(3 - (int)o); } + +// promote from Side to Outcome +inline Outcome operator+(Side s) { return Outcome(s.to_i()); } + +// this side can't win +inline Outcome operator-(Side s) { return Outcome(s.to_i() - 3); } +inline Outcome operator-(Outcome o) { return Outcome(o.to_i() - 3); } + + +inline bool operator == (const Side & a, const Side & b) { return a.to_i() == b.to_i(); } +inline bool operator == (const Outcome & a, const Outcome & b) { return a.to_i() == b.to_i(); } +inline bool operator == (const Side & a, const Outcome & b) { return a.to_i() == b.to_i(); } +inline bool operator == (const Outcome & a, const Side & b) { return a.to_i() == b.to_i(); } + +inline bool operator != (const Side & a, const Side & b) { return a.to_i() != b.to_i(); } +inline bool operator != (const Outcome & a, const Outcome & b) { return a.to_i() != b.to_i(); } +inline bool operator != (const Side & a, const Outcome & b) { return a.to_i() != b.to_i(); } +inline bool operator != (const Outcome & a, const Side & b) { return a.to_i() != b.to_i(); } + +//inline bool operator > (const Side & a, const Side & b) { return a.to_i() > b.to_i(); } +//inline bool operator >= (const Side & a, const Side & b) { return a.to_i() >= b.to_i(); } +//inline bool operator < (const Side & a, const Side & b) { return a.to_i() < b.to_i(); } +//inline bool operator <= (const Side & a, const Side & b) { return a.to_i() <= b.to_i(); } + +inline bool operator > (const Outcome & a, const Outcome & b) { return a.to_i() > b.to_i(); } +inline bool operator >= (const Outcome & a, const Outcome & b) { return a.to_i() >= b.to_i(); } +inline bool operator < (const Outcome & a, const Outcome & b) { return a.to_i() < b.to_i(); } +inline bool operator <= (const Outcome & a, const Outcome & b) { return a.to_i() <= b.to_i(); } + + +// for saying that P1 | P2 can win, meaning both can win +inline Side operator|(const Side & a, const Side & b) { return Side(a.to_i() | b.to_i()); } +inline Outcome operator|(const Outcome & a, const Outcome & b) { return Outcome(a.to_i() | b.to_i()); } +inline Side & operator|=(Side & o, const Side & s) { return o = o | s; } +inline Outcome & operator|=(Outcome & o, const Outcome & s) { return o = o | s; } + +// for saying that side & P1 can win, meaning P1 can't win +inline Side operator&(const Side & a, const Side & b) { return Side(a.to_i() & b.to_i()); } +inline Outcome operator&(const Outcome & a, const Outcome & b) { return Outcome(a.to_i() & b.to_i()); } +inline Side & operator&=(Side & o, const Side & s) { return o = o & s; } +inline Outcome & operator&=(Outcome & o, const Outcome & s) { return o = o & s; } + +}; // namespace Morat diff --git a/lib/outcome_test.cpp b/lib/outcome_test.cpp new file mode 100644 index 0000000..631525f --- /dev/null +++ b/lib/outcome_test.cpp @@ -0,0 +1,69 @@ + +#include "catch.hpp" + +#include "outcome.h" + +namespace Morat { + +TEST_CASE("Side and Outcome", "[side][outcome]"){ + SECTION("switching") { + REQUIRE(~Side::P1 == Side::P2); + REQUIRE(~Side::P2 == Side::P1); + } + + SECTION("Promotion") { + REQUIRE(+Side::NONE == Outcome::DRAW); // neither side wins => draw + REQUIRE(+Side::P1 == Outcome::P1); + REQUIRE(+Side::P2 == Outcome::P2); + REQUIRE(+Side::BOTH == Outcome::DRAW2); // both sides win => draw + } + + SECTION("Can't win") { + REQUIRE(-Side::NONE == Outcome::UNKNOWN); // neither side is known to not win => unknown + REQUIRE(-Side::P1 == Outcome::P2_DRAW); // p1 can't win => p2 can win or draw + REQUIRE(-Side::P2 == Outcome::P1_DRAW); // p2 can't win => p1 can win or draw + REQUIRE(-Side::BOTH == Outcome::DRAW); // both sides can't win => no one can win => draw + REQUIRE(-Outcome::P1 == Outcome::P2_DRAW); + REQUIRE(-Outcome::P2 == Outcome::P1_DRAW); + + REQUIRE(-~Side::P1 == Outcome::P1_DRAW); + REQUIRE(-~Side::P2 == Outcome::P2_DRAW); + + // invalid! wrong order! ~ and - are not commutative + //: REQUIRE(~-Side::P1 == Outcome::P1_DRAW); + } + + SECTION("Side, outcome ==, !=") { + REQUIRE(Side::P1 == Outcome::P1); + REQUIRE(Side::P2 == Outcome::P2); + REQUIRE(Outcome::P1 == Side::P1); + REQUIRE(Outcome::P2 == Side::P2); + REQUIRE(Side::P1 != Outcome::P2); + REQUIRE(Side::P2 != Outcome::P1); + REQUIRE(Outcome::P1 != Side::P2); + REQUIRE(Outcome::P2 != Side::P1); + } + + SECTION("Side |") { + REQUIRE((Side::NONE | Side::P1) == Side::P1); + REQUIRE((Side::NONE | Side::P2) == Side::P2); + REQUIRE((Side::P1 | Side::P2) == Side::BOTH); + REQUIRE((Side::BOTH | Side::P1) == Side::BOTH); + REQUIRE((Outcome::P1 | Outcome::P2) == Outcome::DRAW2); + } + + SECTION("Side &") { + REQUIRE((Side::UNDEF & Side::P1) == Side::NONE); + REQUIRE((Side::UNDEF & Side::P2) == Side::NONE); + REQUIRE((Side::NONE & Side::P1) == Side::NONE); + REQUIRE((Side::NONE & Side::P2) == Side::NONE); + REQUIRE((Side::P1 & Side::P1) == Side::P1); + REQUIRE((Side::P1 & Side::P2) == Side::NONE); + REQUIRE((Side::P2 & Side::P2) == Side::P2); + REQUIRE((Side::P2 & Side::P1) == Side::NONE); + REQUIRE((Side::BOTH & Side::P1) == Side::P1); + REQUIRE((Side::BOTH & Side::P2) == Side::P2); + } +} + +}; // namespace Morat diff --git a/havannah/policy.h b/lib/policy.h similarity index 72% rename from havannah/policy.h rename to lib/policy.h index 01309d8..fb8a1f1 100644 --- a/havannah/policy.h +++ b/lib/policy.h @@ -1,10 +1,13 @@ #pragma once -#include "board.h" -#include "move.h" -#include "movelist.h" +#include "../lib/move.h" +#include "../lib/movelist.h" + +namespace Morat { + +template class Policy { public: Policy() { } @@ -24,5 +27,7 @@ class Policy { void move_end(const Board & board, const Move & prev) { } // Game over, here's who won - void rollout_end(const MoveList & movelist, int won) { } + void rollout_end(const MoveList & movelist, int won) { } }; + +}; // namespace Morat diff --git a/rex/policy_bridge.h b/lib/policy_bridge.h similarity index 79% rename from rex/policy_bridge.h rename to lib/policy_bridge.h index c6f2b8d..470a8b5 100644 --- a/rex/policy_bridge.h +++ b/lib/policy_bridge.h @@ -3,15 +3,17 @@ #pragma once #include "../lib/bits.h" +#include "../lib/move.h" -#include "board.h" -#include "move.h" #include "policy.h" -class ProtectBridge : public Policy { +namespace Morat { + +template +class ProtectBridge : public Policy { int offset; - uint8_t lookup[2][1<<12]; + uint8_t lookup[2][1<<12]; // 2 players, all possible local 6-patterns public: @@ -21,10 +23,14 @@ class ProtectBridge : public Policy { lookup[0][i] = lookup[1][i] = 0; unsigned int p = i; for(unsigned int d = 0; d < 6; d++){ + // player 1 if((p & 0x1D) == 0x11) // 01 11 01 -> 01 00 01 lookup[0][i] |= (1 << ((d+1)%6)); // +1 because we want to play in the empty spot + + // player 2 if((p & 0x2E) == 0x22) // 10 11 10 -> 10 00 10 lookup[1][i] |= (1 << ((d+1)%6)); + p = ((p & 0xFFC)>>2) | ((p & 0x3) << 10); } } @@ -32,7 +38,7 @@ class ProtectBridge : public Policy { Move choose_move(const Board & board, const Move & prev) { uint32_t p = board.pattern_small(prev); - uint16_t r = lookup[board.toplay()-1][p]; + uint16_t r = lookup[board.toplay().to_i()-1][p]; if(!r) // nothing to save return M_UNKNOWN; @@ -49,3 +55,5 @@ class ProtectBridge : public Policy { return board.nb_begin(prev)[i]; } }; + +}; // namespace Morat diff --git a/havannah/policy_instantwin.h b/lib/policy_instantwin.h similarity index 80% rename from havannah/policy_instantwin.h rename to lib/policy_instantwin.h index bf1906b..7ea9801 100644 --- a/havannah/policy_instantwin.h +++ b/lib/policy_instantwin.h @@ -2,13 +2,15 @@ #pragma once #include "../lib/assert2.h" +#include "../lib/move.h" -#include "board.h" -#include "move.h" #include "policy.h" -class InstantWin : public Policy { +namespace Morat { + +template +class InstantWin : public Policy { int max_rollout_moves; int cur_rollout_moves; @@ -35,17 +37,17 @@ class InstantWin : public Policy { return M_UNKNOWN; //must have an edge connection, or it has nothing to offer a group towards a win - const Board::Cell * c = board.cell(prev); + const auto * c = board.cell(prev); if(c->numedges() == 0) return M_UNKNOWN; - Move start, cur, loss = M_UNKNOWN; - int turn = 3 - board.toplay(); + MoveValid start, cur, loss = M_UNKNOWN; + Side turn = ~board.toplay(); //find the first empty cell int dir = -1; for(int i = 0; i <= 5; i++){ - start = prev + neighbours[i]; + start = board.nb_begin(prev)[i]; if(!board.onboard(start) || board.get(start) != turn){ dir = (i + 5) % 6; @@ -65,7 +67,7 @@ class InstantWin : public Policy { do{ // logerr(" " + cur.to_s()); //check the current cell - if(board.onboard(cur) && board.get(cur) == 0 && board.test_win(cur, turn) > 0){ + if(board.onboard(cur) && board.get(cur) == Side::NONE && board.test_outcome(cur, turn) == +turn){ // logerr(" loss"); if(loss == M_UNKNOWN){ loss = cur; @@ -78,7 +80,7 @@ class InstantWin : public Policy { //advance to the next cell for(int i = 5; i <= 9; i++){ int nd = (dir + i) % 6; - Move next = cur + neighbours[nd]; + MoveValid next = board.nb_begin(cur)[nd]; if(!board.onboard(next) || board.get(next) != turn){ cur = next; @@ -93,3 +95,5 @@ class InstantWin : public Policy { return loss; // usually M_UNKNOWN } }; + +}; // namespace Morat diff --git a/rex/policy_lastgoodreply.h b/lib/policy_lastgoodreply.h similarity index 57% rename from rex/policy_lastgoodreply.h rename to lib/policy_lastgoodreply.h index 11fcc9a..ec0231b 100644 --- a/rex/policy_lastgoodreply.h +++ b/lib/policy_lastgoodreply.h @@ -1,11 +1,15 @@ # pragma once -#include "board.h" -#include "move.h" +#include "../lib/move.h" + #include "policy.h" -class LastGoodReply : public Policy { + +namespace Morat { + +template +class LastGoodReply : public Policy { Move goodreply[2][Board::max_vecsize]; int enabled; public: @@ -18,25 +22,27 @@ class LastGoodReply : public Policy { Move choose_move(const Board & board, const Move & prev) const { if (enabled && prev != M_SWAP) { - Move move = goodreply[board.toplay()-1][board.xy(prev)]; + Move move = goodreply[board.toplay().to_i() - 1][board.xy(prev)]; if(move != M_UNKNOWN && board.valid_move_fast(move)) return move; } return M_UNKNOWN; } - void rollout_end(const Board & board, const MoveList & movelist, int won) { - if(!enabled) + void rollout_end(const Board & board, const MoveList & movelist, Outcome outcome) { + if(!enabled || outcome != Outcome::DRAW) return; int m = -1; for(const MovePlayer * i = movelist.begin(), * e = movelist.end(); i != e; i++){ if(m >= 0){ - if(i->player == won && *i != M_SWAP) - goodreply[i->player - 1][m] = *i; + if(+i->player == outcome && *i != M_SWAP) + goodreply[i->player.to_i() - 1][m] = *i; else if(enabled == 2) - goodreply[i->player - 1][m] = M_UNKNOWN; + goodreply[i->player.to_i() - 1][m] = M_UNKNOWN; } m = board.xy(*i); } } }; + +}; // namespace Morat diff --git a/rex/policy_random.h b/lib/policy_random.h similarity index 80% rename from rex/policy_random.h rename to lib/policy_random.h index d84a82a..2bf1537 100644 --- a/rex/policy_random.h +++ b/lib/policy_random.h @@ -3,13 +3,16 @@ #include +#include "../lib/move.h" #include "../lib/xorshift.h" -#include "board.h" -#include "move.h" #include "policy.h" -class RandomPolicy : public Policy { + +namespace Morat { + +template +class RandomPolicy : public Policy { XORShift_uint32 rand; Move moves[Board::max_vecsize]; int num; @@ -22,7 +25,7 @@ class RandomPolicy : public Policy { // only need to save the valid moves once since all the rollouts start from the same position void prepare(const Board & board) { num = 0; - for(Board::MoveIterator m = board.moveit(false); !m.done(); ++m) + for(auto m = board.moveit(false); !m.done(); ++m) moves[num++] = *m; } @@ -43,3 +46,5 @@ class RandomPolicy : public Policy { } } }; + +}; // namespace Morat diff --git a/lib/sgf.h b/lib/sgf.h new file mode 100644 index 0000000..f601fb6 --- /dev/null +++ b/lib/sgf.h @@ -0,0 +1,156 @@ + +#pragma once + +// SGF reader/parser, implements the SGF format: http://www.red-bean.com/sgf/ + +#include +#include +#include +#include + +#include "fileio.h" +#include "outcome.h" +#include "string.h" + +namespace Morat { + +template +class SGFPrinter { + std::ostream & _os; + bool _root; + int _depth; + bool _indent; + + void print(std::string s) { + assert(_os.good()); + _os << s; + } + void print(std::string key, std::string value) { + print(key + "[" + value + "]"); + } + +public: + + SGFPrinter(std::ostream & os) : _os(os), _root(true), _depth(1) { + print("(;"); + print("FF", "4"); + } + + void end(){ + print("\n)\n"); + } + + void size(int s) { + assert(_root); + print("SZ", to_str(s)); + } + void game(std::string name) { + assert(_root); + print("GM", name); + } + void program(std::string name, std::string version) { + assert(_root); + print("AP", name + ":" + version); + } + + void end_root() { + assert(_root); + _root = false; + print("\n "); + } + + void child_start() { + assert(!_root); + assert(_depth >= 1); + print("\n" + std::string(_depth, ' ') + "("); + _depth++; + _indent = false; + } + void child_end() { + assert(!_root); + assert(_depth >= 1); + _depth--; + if(_indent) + print("\n" + std::string(_depth, ' ')); + print(")"); + _indent = true; + } + void move(Side s, Move m) { + assert(!_root); + print(";"); + print((s == Side::P1 ? "W" : "B"), m.to_s()); + } + void comment(std::string c) { + assert(!_root); + print("C", c); + } +}; + + +template +class SGFParser { + + std::istream & _is; + std::unordered_map _properties; + + void read_node() { + _properties.clear(); + char key[11], value[1025]; + for(int c = _is.peek(); _is.good() && 'A' <= c && c <= 'Z'; c = _is.peek()) { + _is.getline(key, 10, '['); + _is.getline(value, 1024, ']'); + _properties[key] = value; + } + } + +public: + SGFParser(std::istream & is) : _is(is) { + next_child(); + } + + bool next_node() { + eat_whitespace(_is); + if(eat_char(_is, ';')){ + read_node(); + return true; + } + return false; + } + + bool has_children() { + eat_whitespace(_is); + return (_is.peek() == '('); + } + + bool next_child() { + eat_whitespace(_is); + if(eat_char(_is, '(')) + return next_node(); + return false; + } + + bool done_child() { + eat_whitespace(_is); + return eat_char(_is, ')'); + } + + int size() { + return from_str(_properties["SZ"]); + } + std::string game() { + return _properties["GM"]; + } + + Move move() { + if(_properties.count("W")) + return Move(_properties["W"]); + if(_properties.count("B")) + return Move(_properties["B"]); + return Move(); + } + std::string comment() { + return _properties["C"]; + } +}; + +}; // namespace Morat diff --git a/lib/sgf_test.cpp b/lib/sgf_test.cpp new file mode 100644 index 0000000..f2cd6f5 --- /dev/null +++ b/lib/sgf_test.cpp @@ -0,0 +1,117 @@ + +#include + +#include "catch.hpp" +#include "move.h" +#include "sgf.h" + + +namespace Morat { + +TEST_CASE("sgf simple", "[sgf]") { + + std::stringstream s; + + { // write an sgf file + SGFPrinter sgf(s); + sgf.game("havannah"); + sgf.size(5); + sgf.end_root(); + sgf.move(Side::P1, Move("a1")); + sgf.comment("good"); + sgf.end(); + + CHECK(s.str() == "(;FF[4]GM[havannah]SZ[5]\n" + " ;W[a1]C[good]\n" + ")\n"); + } + + { // read one and get back what was written above + SGFParser sgf(s); + REQUIRE(sgf.game() == "havannah"); + REQUIRE(sgf.size() == 5); + REQUIRE(sgf.move() == Move()); + REQUIRE(sgf.next_node()); + REQUIRE_FALSE(sgf.has_children()); + REQUIRE(sgf.move() == Move("a1")); + REQUIRE(sgf.comment() == "good"); + REQUIRE_FALSE(sgf.next_node()); + REQUIRE_FALSE(sgf.has_children()); + } +} + +TEST_CASE("sgf write/read", "[sgf]") { + std::stringstream s; + + { + SGFPrinter sgf(s); + sgf.game("havannah"); + sgf.size(5); + + sgf.end_root(); + + sgf.move(Side::P1, Move("a1")); + sgf.move(Side::P2, Move("b2")); + + sgf.child_start(); + sgf.move(Side::P1, Move("c1")); + sgf.comment("c1"); + + sgf.child_start(); + sgf.move(Side::P2, Move("d1")); + sgf.comment("d1"); + sgf.child_end(); + + sgf.child_end(); + + sgf.child_start(); + sgf.move(Side::P1, Move("c2")); + sgf.comment("c2"); + sgf.child_end(); + + sgf.end(); + + CHECK(s.str() == "(;FF[4]GM[havannah]SZ[5]\n" + " ;W[a1];B[b2]\n" + " (;W[c1]C[c1]\n" + " (;B[d1]C[d1])\n" + " )\n" + " (;W[c2]C[c2])\n" + ")\n"); + } + + { // read one and get back what was written above + SGFParser sgf(s); + REQUIRE(sgf.game() == "havannah"); + REQUIRE(sgf.size() == 5); + REQUIRE(sgf.move() == Move()); + REQUIRE(sgf.next_node()); + REQUIRE(sgf.move() == Move("a1")); + REQUIRE(sgf.comment() == ""); + REQUIRE(sgf.next_node()); + REQUIRE(sgf.move() == Move("b2")); + REQUIRE_FALSE(sgf.next_node()); + REQUIRE(sgf.has_children()); + REQUIRE(sgf.next_child()); + REQUIRE(sgf.move() == Move("c1")); + REQUIRE(sgf.comment() == "c1"); + REQUIRE(sgf.has_children()); + REQUIRE(sgf.next_child()); + REQUIRE(sgf.move() == Move("d1")); + REQUIRE(sgf.comment() == "d1"); + REQUIRE_FALSE(sgf.has_children()); + REQUIRE_FALSE(sgf.next_child()); + REQUIRE(sgf.done_child()); + REQUIRE_FALSE(sgf.has_children()); + REQUIRE_FALSE(sgf.next_child()); + REQUIRE(sgf.done_child()); + REQUIRE(sgf.has_children()); + REQUIRE(sgf.next_child()); + REQUIRE(sgf.move() == Move("c2")); + REQUIRE(sgf.comment() == "c2"); + REQUIRE(sgf.done_child()); + REQUIRE(sgf.done_child()); + } +} + +}; // namespace Morat diff --git a/lib/string.cpp b/lib/string.cpp index a123e8d..b15bc36 100644 --- a/lib/string.cpp +++ b/lib/string.cpp @@ -4,6 +4,8 @@ #include "string.h" #include "types.h" +namespace Morat { + using namespace std; string to_str(double a, int prec){ @@ -11,7 +13,6 @@ string to_str(double a, int prec){ a = round(a*p)/p; stringstream out; -// out.precision(prec); out << a; return out.str(); } @@ -32,10 +33,10 @@ void rtrim(string & str){ str.erase(1 + str.find_last_not_of(space)); } -vecstr explode(const string & str, const string & sep){ +vecstr explode(const string & str, const string & sep, int count){ vecstr ret; string::size_type old = 0, pos = 0; - while((pos = str.find_first_of(sep, old)) != string::npos){ + while((pos = str.find(sep, old)) != string::npos && --count != 0){ ret.push_back(str.substr(old, pos - old)); old = pos + sep.length(); } @@ -54,3 +55,15 @@ string implode(const vecstr & vec, const string & sep){ } return ret; } + +dictstr parse_dict(const std::string & str, const std::string & sep1, const std::string & sep2) { + dictstr ret; + auto parts = explode(str, sep1); + for(auto & p : parts){ + auto kv = explode(p, sep2, 2); + ret[kv[0]] = (kv.size() == 2 ? kv[1] : ""); + } + return ret; +} + +}; // namespace Morat diff --git a/lib/string.h b/lib/string.h index 12d9df0..0de2899 100644 --- a/lib/string.h +++ b/lib/string.h @@ -6,8 +6,12 @@ #include #include #include +#include + +namespace Morat { typedef std::vector vecstr; +typedef std::unordered_map dictstr; template std::string to_str(T a){ std::stringstream out; @@ -29,5 +33,8 @@ void trim(std::string & str); void ltrim(std::string & str); void rtrim(std::string & str); -vecstr explode(const std::string & str, const std::string & sep); +vecstr explode(const std::string & str, const std::string & sep, int count=0); std::string implode(const vecstr & vec, const std::string & sep); +dictstr parse_dict(const std::string & str, const std::string & sep1, const std::string & sep2); + +}; // namespace Morat diff --git a/lib/string_test.cpp b/lib/string_test.cpp new file mode 100644 index 0000000..8db6f57 --- /dev/null +++ b/lib/string_test.cpp @@ -0,0 +1,79 @@ + +#include "catch.hpp" + +#include "string.h" + +namespace Morat { + +using namespace std; + +TEST_CASE("to_str", "[string]"){ + REQUIRE("1" == to_str(1)); + REQUIRE("1.5" == to_str(1.5)); + REQUIRE("3.14" == to_str(3.14159, 2)); +} + +TEST_CASE("from_str", "[string]"){ + REQUIRE(1 == from_str("1")); + REQUIRE(1.5 == from_str("1.5")); +} + +TEST_CASE("trim", "[string]"){ + string s = " hello world \n"; + + SECTION("trim") { + trim(s); + REQUIRE(s == "hello world"); + } + + SECTION("ltrim") { + ltrim(s); + REQUIRE(s == "hello world \n"); + } + + SECTION("rtrim") { + rtrim(s); + REQUIRE(s == " hello world"); + } +} + +TEST_CASE("explode/explode", "[string]"){ + string s = "hello cruel world"; + + SECTION("explode"){ + auto parts = explode(s, " "); + REQUIRE(parts.size() == 3); + REQUIRE(parts[0] == "hello"); + REQUIRE(parts[1] == "cruel"); + REQUIRE(parts[2] == "world"); + } + + SECTION("explode length 1"){ + auto parts = explode(s, " ", 1); + REQUIRE(parts.size() == 1); + REQUIRE(parts[0] == s); + } + + SECTION("explode length 2"){ + auto parts = explode(s, " ", 2); + REQUIRE(parts.size() == 2); + REQUIRE(parts[0] == "hello"); + REQUIRE(parts[1] == "cruel world"); + } + + SECTION("implode"){ + auto parts = explode(s, " "); + auto r = implode(parts, " "); + REQUIRE(s == r); + } +} + +TEST_CASE("parse_dict", "[string]"){ + string s = "key: value, key2: val2"; + auto d = parse_dict(s, ", ", ": "); + REQUIRE(d.size() == 2); + REQUIRE(d["key"] == "value"); + REQUIRE(d["key2"] == "val2"); +} + +}; // namespace Morat diff --git a/lib/test.cpp b/lib/test.cpp new file mode 100644 index 0000000..0c7c351 --- /dev/null +++ b/lib/test.cpp @@ -0,0 +1,2 @@ +#define CATCH_CONFIG_MAIN +#include "catch.hpp" diff --git a/lib/thread.h b/lib/thread.h index e5f80e2..11f0dca 100644 --- a/lib/thread.h +++ b/lib/thread.h @@ -4,12 +4,10 @@ #include #include #include +#include #include -#include "types.h" - -using namespace std; -using namespace placeholders; //for bind +namespace Morat { // http://gcc.gnu.org/onlinedocs/gcc/Atomic-Builtins.html // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2007/n2427.html @@ -59,7 +57,7 @@ template A PLUS(A & var, const B & val){ class Thread { pthread_t thread; bool destruct; - function func; + std::function func; static void * runner(void * blah){ Thread * t = (Thread *)blah; @@ -71,7 +69,7 @@ class Thread { public: Thread() : destruct(false), func(nullfunc) { } - Thread(function fn) : destruct(false), func(nullfunc) { (*this)(fn); } + Thread(std::function fn) : destruct(false), func(nullfunc) { (*this)(fn); } //act as a move constructor, no copy constructor Thread(Thread & o) { *this = o; } @@ -90,7 +88,7 @@ class Thread { return *this; } - int operator()(function fn){ + int operator()(std::function fn){ assert(destruct == false); func = fn; @@ -269,3 +267,5 @@ class Barrier { } }; //*/ + +}; // namespace Morat diff --git a/lib/time.h b/lib/time.h index ba21477..79cf5af 100644 --- a/lib/time.h +++ b/lib/time.h @@ -2,8 +2,11 @@ #pragma once #include +#include #include +namespace Morat { + class Time { double t; public: @@ -17,10 +20,10 @@ class Time { t = time.tv_sec + (double)time.tv_usec/1000000; } - double to_f() const { return t; } - long long to_i() const { return (long long)t; } - long long in_msec() const { return (long long)(t*1000); } - long long in_usec() const { return (long long)(t*1000000); } + double to_f() const { return t; } + uint64_t to_i() const { return (uint64_t)t; } + uint64_t in_msec() const { return (uint64_t)(t*1000); } + uint64_t in_usec() const { return (uint64_t)(t*1000000); } Time operator + (double a) const { return Time(t+a); } Time & operator += (double a) { t += a; return *this; } @@ -36,3 +39,4 @@ class Time { bool operator != (const Time & a) const { return t != a.t; } }; +}; // namespace Morat diff --git a/lib/timecontrol.h b/lib/timecontrol.h index a3dd242..f7fd53c 100644 --- a/lib/timecontrol.h +++ b/lib/timecontrol.h @@ -1,8 +1,11 @@ #pragma once +#include #include +namespace Morat { + struct TimeControl { enum Method { PERCENT, EVEN, STATS }; Method method; // method to use to distribute the remaining time @@ -65,9 +68,11 @@ struct TimeControl { if(flexible) remain += move - used; else - remain += min(0.0, move - used); + remain += std::min(0.0, move - used); if(remain < 0) remain = 0; } }; + +}; // namespace Morat diff --git a/lib/timer.h b/lib/timer.h index 1a8f56b..d041a4b 100644 --- a/lib/timer.h +++ b/lib/timer.h @@ -6,13 +6,12 @@ #include "thread.h" -using namespace std; -using namespace placeholders; //for bind +namespace Morat { class Timer { Thread thread; bool destruct; - function callback; + std::function callback; double timeout; void waiter(){ @@ -27,15 +26,15 @@ class Timer { Timer() { timeout = 0; destruct = false; - callback = bind(&Timer::nullcallback, this); + callback = std::bind(&Timer::nullcallback, this); } - Timer(double time, function fn){ + Timer(double time, std::function fn){ destruct = false; set(time, fn); } - void set(double time, function fn){ + void set(double time, std::function fn){ cancel(); timeout = time; @@ -45,14 +44,14 @@ class Timer { fn(); }else{ destruct = true; - thread(bind(&Timer::waiter, this)); + thread(std::bind(&Timer::waiter, this)); } } void cancel(){ if(destruct){ destruct = false; - callback = bind(&Timer::nullcallback, this); + callback = std::bind(&Timer::nullcallback, this); thread.cancel(); thread.join(); } @@ -62,3 +61,5 @@ class Timer { cancel(); } }; + +}; // namespace Morat diff --git a/lib/types.h b/lib/types.h index ad1563b..d3bdbae 100644 --- a/lib/types.h +++ b/lib/types.h @@ -5,6 +5,8 @@ #include +namespace Morat { + typedef unsigned char uchar; typedef unsigned short ushort; typedef unsigned int uint; @@ -33,3 +35,5 @@ typedef uint64_t hash_t; #else #error Unknown word size #endif + +}; // namespace Morat diff --git a/lib/weightedrandscan.h b/lib/weightedrandscan.h index f822b04..401f860 100644 --- a/lib/weightedrandscan.h +++ b/lib/weightedrandscan.h @@ -9,6 +9,8 @@ This is useful for softmax and similar, used in the rollout policy. #include "bits.h" #include "xorshift.h" +namespace Morat { + // O(1) updates, O(s) choose class WeightedRandScan { mutable XORShift_double unitrand; @@ -102,3 +104,5 @@ class WeightedRandScan { return -1; } }; + +}; // namespace Morat diff --git a/lib/weightedrandtree.h b/lib/weightedrandtree.h index 4c7f1e5..e80ccca 100644 --- a/lib/weightedrandtree.h +++ b/lib/weightedrandtree.h @@ -11,6 +11,8 @@ Most operations are O(log n). #include "bits.h" #include "xorshift.h" +namespace Morat { + //rounds to power of 2 sizes, completely arbitrary weights // O(log n) updates, O(log n) choose class WeightedRandTree { @@ -108,3 +110,5 @@ class WeightedRandTree { return i - size; } }; + +}; // namespace Morat diff --git a/lib/xorshift.h b/lib/xorshift.h index fb0b2f1..5027ab1 100644 --- a/lib/xorshift.h +++ b/lib/xorshift.h @@ -2,24 +2,28 @@ #pragma once //Generates random numbers using the XORShift algorithm. +//Read http://xorshift.di.unimi.it/ for more details #include +#include "bits.h" #include "time.h" +namespace Morat { + //generates 32 bit values, has a 32bit period class XORShift_uint32 { uint32_t r; public: XORShift_uint32(uint32_t s = 0) { seed(s); } - void seed(uint32_t s) { r = (s ? s : Time().in_usec()); } + void seed(uint32_t s) { r = mix_bits(s ? s : (uint32_t)Time().in_usec()); } uint32_t operator()() { return rand(); } protected: uint32_t rand(){ r ^= (r << 13); r ^= (r >> 17); r ^= (r << 5); - return r; + return r * 1597334677; } }; @@ -28,14 +32,35 @@ class XORShift_uint64 { uint64_t r; public: XORShift_uint64(uint64_t s = 0) { seed(s); } - void seed(uint64_t s) { r = (s ? s : Time().in_usec()); } + void seed(uint64_t s) { r = mix_bits(s ? s : Time().in_usec()); } uint64_t operator()() { return rand(); } protected: uint64_t rand(){ - r ^= (r >> 17); - r ^= (r << 31); - r ^= (r >> 8); - return r; + r ^= r >> 12; // a + r ^= r << 25; // b + r ^= r >> 27; // c + return r * 2685821657736338717LL; + } +}; + +//generates 64 bit values, has a 128bit period +class XORShift_uint128 { + uint64_t r[2]; +public: + XORShift_uint128(uint64_t s = 0) { seed(s); } + void seed(uint64_t s) { + r[0] = mix_bits(s ? s : Time().in_usec()); + r[1] = mix_bits(r[0]); + } + uint64_t operator()() { return rand(); } +protected: + uint64_t rand(){ + uint64_t r1 = r[0]; + const uint64_t r0 = r[1]; + r[0] = r0; + r1 ^= r1 << 23; // a + r[1] = r1 ^ r0 ^ (r1 >> 17) ^ (r0 >> 26); // b, c + return r[1] + r0; } }; @@ -52,3 +77,5 @@ class XORShift_double : XORShift_uint64 { XORShift_double(uint64_t seed = 0) : XORShift_uint64(seed) {} double operator()() { return static_cast(rand()) * (1. / 18446744073709551616.); } // divide by 2^64 }; + +}; // namespace Morat diff --git a/lib/zobrist.cpp b/lib/zobrist.cpp index 64ea2b5..3675fb4 100644 --- a/lib/zobrist.cpp +++ b/lib/zobrist.cpp @@ -1,6 +1,8 @@ #include "zobrist.h" +namespace Morat { + const uint64_t zobrist_strings[4096] = { 0xa0c99a1c59023682ull, // 0 0x491b86e3b32b998dull, @@ -4099,3 +4101,5 @@ const uint64_t zobrist_strings[4096] = { 0x6f1f252f0dac492eull, 0xb58a3a80a24625b4ull }; + +}; // namespace Morat diff --git a/lib/zobrist.h b/lib/zobrist.h index 3533cf3..5ab02e1 100644 --- a/lib/zobrist.h +++ b/lib/zobrist.h @@ -5,6 +5,8 @@ #include +namespace Morat { + extern const uint64_t zobrist_strings[4096]; template @@ -40,3 +42,5 @@ class Zobrist { return m; } }; + +}; // namespace Morat diff --git a/pentago/agent.h b/pentago/agent.h index 93d7c50..767bdb2 100644 --- a/pentago/agent.h +++ b/pentago/agent.h @@ -3,12 +3,20 @@ //Interface for the various agents: players and solvers +#include "../lib/outcome.h" +#include "../lib/sgf.h" #include "../lib/types.h" #include "board.h" #include "moveiterator.h" + +namespace Morat { +namespace Pentago { + class Agent { +protected: + typedef std::vector vecmove; public: Agent() { } virtual ~Agent() { } @@ -20,31 +28,37 @@ class Agent { virtual void set_memlimit(uint64_t lim) = 0; // in bytes virtual void clear_mem() = 0; - virtual vector get_pv() const = 0; - string move_stats() const { return move_stats(vector()); } - virtual string move_stats(const vector moves) const = 0; + virtual vecmove get_pv() const = 0; + std::string move_stats() const { return move_stats(vecmove()); } + virtual std::string move_stats(const vecmove moves) const = 0; virtual double gamelen() const = 0; virtual void timedout(){ timeout = true; } + virtual void gen_sgf(SGFPrinter & sgf, int limit) const = 0; + virtual void load_sgf(SGFParser & sgf) = 0; + protected: volatile bool timeout; Board rootboard; - static int solve1ply(const Board & board, unsigned int & nodes) { - int outcome = -4; - int turn = board.toplay(); + static Outcome solve1ply(const Board & board, unsigned int & nodes) { + Outcome outcome = Outcome::UNDEF; + Side turn = board.toplay(); for(MoveIterator move(board); !move.done(); ++move){ ++nodes; - int won = move.board().won(); + Outcome won = move.board().won(); - if(won == turn) + if(won == +turn) return won; - else if(won == 0) - outcome = 0; - else if(outcome == 3 - turn || outcome == -4) + if(won == Outcome::DRAW) + outcome = Outcome::DRAW; + else if(outcome == +~turn || outcome == Outcome::UNDEF) outcome = won; } return outcome; } }; + +}; // namespace Pentago +}; // namespace Morat diff --git a/pentago/agentab.cpp b/pentago/agentab.cpp index 229c1c0..ad8f4b2 100644 --- a/pentago/agentab.cpp +++ b/pentago/agentab.cpp @@ -6,6 +6,10 @@ #include "agentab.h" + +namespace Morat { +namespace Pentago { + void AgentAB::search(double time, uint64_t maxiters, int verbose) { reset(); if(rootboard.won() >= 0) @@ -41,14 +45,14 @@ void AgentAB::search(double time, uint64_t maxiters, int verbose) { if(verbose){ logerr("Finished: " + to_str(nodes_seen) + " nodes in " + to_str(time_used*1000, 0) + " msec: " + to_str((uint64_t)((double)nodes_seen/time_used)) + " Nodes/s\n"); - vector pv = get_pv(); - string pvstr; - for(vector::iterator m = pv.begin(); m != pv.end(); ++m) + vecmove pv = get_pv(); + std::string pvstr; + for(vecmove::iterator m = pv.begin(); m != pv.end(); ++m) pvstr += " " + m->to_s(); logerr("PV: " + pvstr + "\n"); if(verbose >= 3) - logerr("Move stats:\n" + move_stats(vector())); + logerr("Move stats:\n" + move_stats(vecmove())); } } @@ -56,11 +60,11 @@ void AgentAB::search(double time, uint64_t maxiters, int verbose) { int16_t AgentAB::negamax(const Board & board, int16_t alpha, int16_t beta, int depth) { nodes_seen++; - int won = board.won(); - if(won >= 0){ - if(won == 0) + Outcome won = board.won(); + if(won >= Outcome::DRAW){ + if(won == Outcome::DRAW) return SCORE_DRAW; - if(won == board.toplay()) + if(won == +board.toplay()) return SCORE_WIN; return SCORE_LOSS; } @@ -81,8 +85,8 @@ int16_t AgentAB::negamax(const Board & board, int16_t alpha, int16_t beta, int d if(TT && (node = tt_get(board)) && node->depth >= depth){ switch(node->flag){ case VALID: return node->score; - case LBOUND: alpha = max(alpha, node->score); break; - case UBOUND: beta = min(beta, node->score); break; + case LBOUND: alpha = std::max(alpha, node->score); break; + case UBOUND: beta = std::min(beta, node->score); break; default: assert(false && "Unknown flag!"); } if(alpha >= beta) @@ -94,11 +98,11 @@ int16_t AgentAB::negamax(const Board & board, int16_t alpha, int16_t beta, int d Board n = board; bool move_success = n.move(bestmove); -// if(!move_success){ -// logerr("FAIL!!!\nhash: " + to_str(board.hash()) + ", orientation: " + to_str(board.orient()) + ", state: " + board.state() + "\n"); -// logerr(node->to_s(board.orient()) + "\n"); -// logerr(board.to_s()); -// } + if(!move_success){ + logerr("FAIL!!!\nhash: " + to_str(board.simple_hash()) + ", state: " + board.state() + "\n"); + logerr(node->to_s() + "\n"); + logerr(board.to_s()); + } assert(move_success); score = -negamax(n, -beta, -alpha, depth-1); @@ -111,7 +115,7 @@ int16_t AgentAB::negamax(const Board & board, int16_t alpha, int16_t beta, int d //generate moves for (RandomMoveIterator move(board, rand); !move.done(); ++move) { // for (MoveIterator move(board); !move.done(); ++move) { - int16_t value = -negamax(move.board(), -beta, -max(alpha, score), depth-1); + int16_t value = -negamax(move.board(), -beta, -std::max(alpha, score), depth-1); if (score < value) { score = value; bestmove = *move; @@ -125,16 +129,16 @@ int16_t AgentAB::negamax(const Board & board, int16_t alpha, int16_t beta, int d if (TT) { uint8_t flag = (score <= alpha ? UBOUND : score >= beta ? LBOUND : VALID); - tt_set(Node(board.hash(), score, bestmove, depth, flag)); + tt_set(Node(board.simple_hash(), score, bestmove, depth, flag)); } return score; } -string AgentAB::move_stats(vector moves) const { - string s = ""; +std::string AgentAB::move_stats(vecmove moves) const { + std::string s = ""; Board b = rootboard; - for(vector::iterator m = moves.begin(); m != moves.end(); ++m) + for(vecmove::iterator m = moves.begin(); m != moves.end(); ++m) b.move(*m); for(MoveIterator move(b); !move.done(); ++move){ @@ -167,8 +171,8 @@ Move AgentAB::return_move(const Board & board, int verbose) const { return best; } -vector AgentAB::get_pv() const { - vector pv; +std::vector AgentAB::get_pv() const { + vecmove pv; Board b = rootboard; int i = 20; @@ -193,7 +197,7 @@ AgentAB::Node * AgentAB::tt(uint64_t hash) const { } AgentAB::Node * AgentAB::tt_get(const Board & b) const { - return tt_get(b.hash()); + return tt_get(b.simple_hash()); } AgentAB::Node * AgentAB::tt_get(uint64_t h) const { Node * n = tt(h); @@ -202,3 +206,6 @@ AgentAB::Node * AgentAB::tt_get(uint64_t h) const { void AgentAB::tt_set(const Node & n) { *(tt(n.hash)) = n; } + +}; // namespace Pentago +}; // namespace Morat diff --git a/pentago/agentab.h b/pentago/agentab.h index 9eb276e..f529dc0 100644 --- a/pentago/agentab.h +++ b/pentago/agentab.h @@ -3,10 +3,15 @@ //An Alpha-beta solver, single threaded with an optional transposition table. +#include "../lib/log.h" #include "../lib/xorshift.h" #include "agent.h" + +namespace Morat { +namespace Pentago { + class AgentAB : public Agent { static const int16_t SCORE_WIN = 32767; static const int16_t SCORE_LOSS = -32767; @@ -30,12 +35,11 @@ class AgentAB : public Agent { Node(uint64_t h = ~0ull, int16_t s = 0, Move b = M_UNKNOWN, int8_t d = 0, int8_t f = 0) : //. int8_t o = -3 hash(h), score(s), bestmove(b), depth(d), flag(f), padding(0xDEAD) { } //, outcome(o) - string to_s(int orientation=8) const { + std::string to_s() const { return "score " + to_str(score) + ", depth " + to_str((int)depth) + ", flag " + to_str((int)flag) + - ", best " + bestmove.to_s(true) + - (orientation == 8 ? string() : "/" + bestmove.rotate(orientation).to_s(true)); + ", best " + bestmove.to_s(); } }; @@ -94,8 +98,16 @@ class AgentAB : public Agent { void search(double time, uint64_t maxiters, int verbose); Move return_move(int verbose) const { return return_move(rootboard, verbose); } double gamelen() const { return rootboard.moves_remain(); } - vector get_pv() const; - string move_stats(vector moves) const; + vecmove get_pv() const; + std::string move_stats(vecmove moves) const; + + void gen_sgf(SGFPrinter & sgf, int limit) const { + logerr("gen_sgf not supported in the ab agent."); + } + + void load_sgf(SGFParser & sgf) { + logerr("load_sgf not supported in the ab agent."); + } private: int16_t negamax(const Board & board, int16_t alpha, int16_t beta, int depth); @@ -106,3 +118,6 @@ class AgentAB : public Agent { Node * tt_get(const Board & b) const ; void tt_set(const Node & n) ; }; + +}; // namespace Pentago +}; // namespace Morat diff --git a/pentago/agentmcts.cpp b/pentago/agentmcts.cpp index 8da7682..f59d893 100644 --- a/pentago/agentmcts.cpp +++ b/pentago/agentmcts.cpp @@ -3,118 +3,80 @@ #include #include "../lib/alarm.h" +#include "../lib/fileio.h" #include "../lib/string.h" #include "../lib/time.h" #include "agentmcts.h" #include "board.h" + +namespace Morat { +namespace Pentago { + const float AgentMCTS::min_rave = 0.1; -void AgentMCTS::MCTSThread::run(){ - while(true){ - switch(player->threadstate){ - case Thread_Cancelled: //threads should exit - return; - - case Thread_Wait_Start: //threads are waiting to start - case Thread_Wait_Start_Cancelled: - player->runbarrier.wait(); - CAS(player->threadstate, Thread_Wait_Start, Thread_Running); - CAS(player->threadstate, Thread_Wait_Start_Cancelled, Thread_Cancelled); - break; - - case Thread_Wait_End: //threads are waiting to end - player->runbarrier.wait(); - CAS(player->threadstate, Thread_Wait_End, Thread_Wait_Start); - break; - - case Thread_Running: //threads are running - if(player->rootboard.won() >= 0 || player->root.outcome >= 0 || (player->maxruns > 0 && player->runs >= player->maxruns)){ //solved or finished runs - if(CAS(player->threadstate, Thread_Running, Thread_Wait_End) && player->root.outcome >= 0) - logerr("Solved as " + to_str((int)player->root.outcome) + "\n"); - break; - } - if(player->ctmem.memalloced() >= player->maxmem){ //out of memory, start garbage collection - CAS(player->threadstate, Thread_Running, Thread_GC); - break; - } +std::string AgentMCTS::Node::to_s() const { + return "AgentMCTS::Node" + ", move " + move.to_s() + + ", exp " + exp.to_s() + + ", know " + to_str(know) + + ", outcome " + to_str((int)outcome.to_i()) + + ", depth " + to_str((int)proofdepth) + + ", best " + bestmove.to_s() + + ", children " + to_str(children.num()); +} - INCR(player->runs); - iterate(); - break; - - case Thread_GC: //one thread is running garbage collection, the rest are waiting - case Thread_GC_End: //once done garbage collecting, go to wait_end instead of back to running - if(player->gcbarrier.wait()){ - Time starttime; - logerr("Starting player GC with limit " + to_str(player->gclimit) + " ... "); - uint64_t nodesbefore = player->nodes; - Board copy = player->rootboard; - player->garbage_collect(copy, & player->root); - Time gctime; - player->ctmem.compact(1.0, 0.75); - Time compacttime; - logerr(to_str(100.0*player->nodes/nodesbefore, 1) + " % of tree remains - " + - to_str((gctime - starttime)*1000, 0) + " msec gc, " + to_str((compacttime - gctime)*1000, 0) + " msec compact\n"); - - if(player->ctmem.meminuse() >= player->maxmem/2) - player->gclimit = (int)(player->gclimit*1.3); - else if(player->gclimit > 5) - player->gclimit = (int)(player->gclimit*0.9); //slowly decay to a minimum of 5 - - CAS(player->threadstate, Thread_GC, Thread_Running); - CAS(player->threadstate, Thread_GC_End, Thread_Wait_End); - } - player->gcbarrier.wait(); - break; - } +bool AgentMCTS::Node::from_s(std::string s) { + auto dict = parse_dict(s, ", ", " "); + + if(dict.size() == 8){ + move = Move(dict["move"]); + exp = ExpPair(dict["exp"]); + know = from_str(dict["know"]); + outcome = Outcome(from_str(dict["outcome"])); + proofdepth = from_str(dict["depth"]); + bestmove = Move(dict["best"]); + // ignore children + return true; } + return false; } void AgentMCTS::search(double time, uint64_t max_runs, int verbose){ - time_used = 0; - int toplay = rootboard.toplay(); + Side toplay = rootboard.toplay(); - if(rootboard.won() >= 0 || (time <= 0 && max_runs == 0)) + if(rootboard.won() >= Outcome::DRAW || (time <= 0 && max_runs == 0)) return; Time starttime; - stop_threads(); + pool.pause(); if(runs) logerr("Pondered " + to_str(runs) + " runs\n"); runs = 0; maxruns = max_runs; - for(unsigned int i = 0; i < threads.size(); i++) - threads[i]->reset(); - + pool.reset(); //let them run! - start_threads(); - - Alarm timer; - if(time > 0) - timer(time - (Time() - starttime), std::bind(&AgentMCTS::timedout, this)); + pool.resume(); - //wait for the timer to stop them - runbarrier.wait(); - CAS(threadstate, Thread_Wait_End, Thread_Wait_Start); - assert(threadstate == Thread_Wait_Start); + pool.wait_pause(time); - time_used = Time() - starttime; + double time_used = Time() - starttime; if(verbose){ DepthStats gamelen, treelen; double times[4] = {0,0,0,0}; - for(unsigned int i = 0; i < threads.size(); i++){ - gamelen += threads[i]->gamelen; - treelen += threads[i]->treelen; + for(auto & t : pool){ + gamelen += t->gamelen; + treelen += t->treelen; + for(int a = 0; a < 4; a++) - times[a] += threads[i]->times[a]; + times[a] += t->times[a]; } logerr("Finished: " + to_str(runs) + " runs in " + to_str(time_used*1000, 0) + " msec: " + to_str(runs/time_used, 0) + " Games/s\n"); @@ -125,45 +87,36 @@ void AgentMCTS::search(double time, uint64_t max_runs, int verbose){ logerr("Times: " + to_str(times[0], 3) + ", " + to_str(times[1], 3) + ", " + to_str(times[2], 3) + ", " + to_str(times[3], 3) + "\n"); } - if(root.outcome != -3){ - logerr("Solved as a "); - if( root.outcome == 0) logerr("draw\n"); - else if(root.outcome == 3) logerr("draw by simultaneous win\n"); - else if(root.outcome == toplay) logerr("win\n"); - else if(root.outcome == 3-toplay) logerr("loss\n"); - else if(root.outcome == -toplay) logerr("win or draw\n"); - else if(root.outcome == toplay-3) logerr("loss or draw\n"); - } + if(root.outcome != Outcome::UNKNOWN) + logerr("Solved as a " + root.outcome.to_s_rel(toplay) + "\n"); - vector pv = get_pv(); - string pvstr; - for(vector::iterator m = pv.begin(); m != pv.end(); ++m) - pvstr += " " + m->to_s(); + std::string pvstr; + for(auto m : get_pv()) + pvstr += " " + m.to_s(); logerr("PV: " + pvstr + "\n"); if(verbose >= 3 && !root.children.empty()) - logerr("Move stats:\n" + move_stats(vector())); + logerr("Move stats:\n" + move_stats(vecmove())); } - for(unsigned int i = 0; i < threads.size(); i++) - threads[i]->reset(); + pool.reset(); runs = 0; - if(ponder && root.outcome < 0) - start_threads(); + if(ponder && root.outcome < Outcome::DRAW) + pool.resume(); } -AgentMCTS::AgentMCTS() { +AgentMCTS::AgentMCTS() : pool(this) { nodes = 0; runs = 0; gclimit = 5; - time_used = 0; profile = false; ponder = false; //#ifdef SINGLE_THREAD ... make sure only 1 thread numthreads = 1; + pool.set_num_threads(numthreads); maxmem = 1000*1024*1024; explore = 1; @@ -183,86 +136,27 @@ AgentMCTS::AgentMCTS() { instantwin = 0; - //no threads started until a board is set - threadstate = Thread_Wait_Start; } AgentMCTS::~AgentMCTS(){ - stop_threads(); - - numthreads = 0; - reset_threads(); //shut down the theads properly + pool.pause(); + pool.set_num_threads(0); root.dealloc(ctmem); ctmem.compact(); } -void AgentMCTS::timedout() { - CAS(threadstate, Thread_Running, Thread_Wait_End); - CAS(threadstate, Thread_GC, Thread_GC_End); -} - -string AgentMCTS::statestring(){ - switch(threadstate){ - case Thread_Cancelled: return "Thread_Wait_Cancelled"; - case Thread_Wait_Start: return "Thread_Wait_Start"; - case Thread_Wait_Start_Cancelled: return "Thread_Wait_Start_Cancelled"; - case Thread_Running: return "Thread_Running"; - case Thread_GC: return "Thread_GC"; - case Thread_GC_End: return "Thread_GC_End"; - case Thread_Wait_End: return "Thread_Wait_End"; - } - return "Thread_State_Unknown!!!"; -} - -void AgentMCTS::stop_threads(){ - if(threadstate != Thread_Wait_Start){ - timedout(); - runbarrier.wait(); - CAS(threadstate, Thread_Wait_End, Thread_Wait_Start); - assert(threadstate == Thread_Wait_Start); - } -} - -void AgentMCTS::start_threads(){ - assert(threadstate == Thread_Wait_Start); - runbarrier.wait(); - CAS(threadstate, Thread_Wait_Start, Thread_Running); -} - -void AgentMCTS::reset_threads(){ //start and end with threadstate = Thread_Wait_Start - assert(threadstate == Thread_Wait_Start); - -//wait for them to all get to the barrier - assert(CAS(threadstate, Thread_Wait_Start, Thread_Wait_Start_Cancelled)); - runbarrier.wait(); - -//make sure they exited cleanly - for(unsigned int i = 0; i < threads.size(); i++) - threads[i]->join(); - - threads.clear(); - - threadstate = Thread_Wait_Start; - - runbarrier.reset(numthreads + 1); - gcbarrier.reset(numthreads); - -//start new threads - for(int i = 0; i < numthreads; i++) - threads.push_back(new MCTSThread(this)); -} void AgentMCTS::set_ponder(bool p){ if(ponder != p){ ponder = p; - stop_threads(); + pool.pause(); if(ponder) - start_threads(); + pool.resume(); } } void AgentMCTS::set_board(const Board & board, bool clear){ - stop_threads(); + pool.pause(); nodes -= root.dealloc(ctmem); root = Node(); @@ -270,13 +164,11 @@ void AgentMCTS::set_board(const Board & board, bool clear){ rootboard = board; - reset_threads(); //needed since the threads aren't started before a board it set - if(ponder) - start_threads(); + pool.resume(); } void AgentMCTS::move(const Move & m){ - stop_threads(); + pool.pause(); uword nodesbefore = nodes; @@ -296,7 +188,7 @@ void AgentMCTS::move(const Move & m){ root.swap_tree(child); if(nodesbefore > 0) - logerr("Nodes: before: " + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); + logerr("Nodes before: " + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); }else{ nodes -= root.dealloc(ctmem); root = Node(); @@ -307,30 +199,30 @@ void AgentMCTS::move(const Move & m){ rootboard.move(m); root.exp.addwins(visitexpand+1); //+1 to compensate for the virtual loss - if(rootboard.won() < 0) - root.outcome = -3; + if(rootboard.won() < Outcome::DRAW) + root.outcome = Outcome::UNKNOWN; if(ponder) - start_threads(); + pool.resume(); } double AgentMCTS::gamelen() const { DepthStats len; - for(unsigned int i = 0; i < threads.size(); i++) - len += threads[i]->gamelen; + for(auto & t : pool) + len += t->gamelen; return len.avg(); } -vector AgentMCTS::get_pv() const { - vector pv; +std::vector AgentMCTS::get_pv() const { + vecmove pv; const Node * n = & root; - char turn = rootboard.toplay(); + Side turn = rootboard.toplay(); while(n && !n->children.empty()){ Move m = return_move(n, turn); pv.push_back(m); n = find_child(n, m); - turn = 3 - turn; + turn = ~turn; } if(pv.size() == 0) @@ -339,25 +231,30 @@ vector AgentMCTS::get_pv() const { return pv; } -string AgentMCTS::move_stats(vector moves) const { - string s = ""; +std::string AgentMCTS::move_stats(vecmove moves) const { + std::string s = ""; const Node * node = & root; - for(vector::iterator m = moves.begin(); node && m != moves.end(); ++m) - node = find_child(node, *m); + if(moves.size()){ + s += "path:\n"; + for(auto m : moves){ + if(node){ + node = find_child(node, m); + s += node->to_s() + "\n"; + } + } + } if(node){ - Node * child = node->children.begin(), - * childend = node->children.end(); - for( ; child != childend; child++) - if(child->move != M_NONE) - s += child->to_s() + "\n"; + s += "children:\n"; + for(auto & n : node->children) + s += n.to_s() + "\n"; } return s; } -Move AgentMCTS::return_move(const Node * node, int toplay, int verbose) const { - if(node->outcome >= 0) +Move AgentMCTS::return_move(const Node * node, Side toplay, int verbose) const { + if(node->outcome >= Outcome::DRAW) return node->bestmove; double val, maxval = -1000000000000.0; //1 trillion @@ -367,10 +264,10 @@ Move AgentMCTS::return_move(const Node * node, int toplay, int verbose) const { * end = node->children.end(); for( ; child != end; child++){ - if(child->outcome >= 0){ - if(child->outcome == toplay) val = 800000000000.0 - child->exp.num(); //shortest win - else if(child->outcome == 0) val = -400000000000.0 + child->exp.num(); //longest tie - else val = -800000000000.0 + child->exp.num(); //longest loss + if(child->outcome >= Outcome::DRAW){ + if(child->outcome == toplay) val = 800000000000.0 - child->exp.num(); //shortest win + else if(child->outcome == Outcome::DRAW) val = -400000000000.0 + child->exp.num(); //longest tie + else val = -800000000000.0 + child->exp.num(); //longest loss }else{ //not proven // val = child->exp.num(); //num simulations val = child->exp.sum(); //num wins @@ -394,13 +291,13 @@ void AgentMCTS::garbage_collect(Board & board, Node * node){ Node * child = node->children.begin(), * end = node->children.end(); - int toplay = board.toplay(); + Side toplay = board.toplay(); for( ; child != end; child++){ if(child->children.num() == 0) continue; - if( (node->outcome >= 0 && child->exp.num() > gcsolved && (node->outcome != toplay || child->outcome == toplay || child->outcome == 0)) || //parent is solved, only keep the proof tree, plus heavy draws - (node->outcome < 0 && child->exp.num() > (child->outcome >= 0 ? gcsolved : gclimit)) ){ // only keep heavy nodes, with different cutoffs for solved and unsolved + if( (node->outcome >= Outcome::DRAW && child->exp.num() > gcsolved && (node->outcome != toplay || child->outcome == toplay || child->outcome == Outcome::DRAW)) || //parent is solved, only keep the proof tree, plus heavy draws + (node->outcome < Outcome::DRAW && child->exp.num() > (child->outcome >= Outcome::DRAW ? gcsolved : gclimit)) ){ // only keep heavy nodes, with different cutoffs for solved and unsolved board.move(child->move); garbage_collect(board, child); board.undo(child->move); @@ -417,3 +314,60 @@ AgentMCTS::Node * AgentMCTS::find_child(const Node * node, const Move & move) co return NULL; } + +void AgentMCTS::gen_sgf(SGFPrinter & sgf, unsigned int limit, const Node & node, Side side) const { + for(auto & child : node.children){ + if(child.exp.num() >= limit && (side != node.outcome || child.outcome == node.outcome)){ + sgf.child_start(); + sgf.move(side, child.move); + sgf.comment(child.to_s()); + gen_sgf(sgf, limit, child, ~side); + sgf.child_end(); + } + } +} + +void AgentMCTS::create_children_simple(const Board & board, Node * node){ + assert(node->children.empty()); + + node->children.alloc(board.moves_avail(), ctmem); + + Node * child = node->children.begin(), + * end = node->children.end(); + MoveIterator moveit(board, prunesymmetry); + int nummoves = 0; + for(; !moveit.done() && child != end; ++moveit, ++child){ + *child = Node(*moveit); + nummoves++; + } + + if(prunesymmetry) + node->children.shrink(nummoves); //shrink the node to ignore the extra moves + else //both end conditions should happen in parallel + assert(moveit.done() && child == end); + + PLUS(nodes, node->children.num()); +} + +void AgentMCTS::load_sgf(SGFParser & sgf, const Board & board, Node & node) { + assert(sgf.has_children()); + create_children_simple(board, & node); + + while(sgf.next_child()){ + Move m = sgf.move(); + Node & child = *find_child(&node, m); + child.from_s(sgf.comment()); + if(sgf.done_child()){ + continue; + }else{ + // has children! + Board b = board; + b.move(m); + load_sgf(sgf, b, child); + assert(sgf.done_child()); + } + } +} + +}; // namespace Pentago +}; // namespace Morat diff --git a/pentago/agentmcts.h b/pentago/agentmcts.h index 42896a2..8bee2d2 100644 --- a/pentago/agentmcts.h +++ b/pentago/agentmcts.h @@ -6,6 +6,7 @@ #include #include +#include "../lib/agentpool.h" #include "../lib/compacttree.h" #include "../lib/depthstats.h" #include "../lib/exppair.h" @@ -19,6 +20,10 @@ #include "board.h" #include "move.h" + +namespace Morat { +namespace Pentago { + class AgentMCTS : public Agent{ public: @@ -26,7 +31,7 @@ class AgentMCTS : public Agent{ public: ExpPair exp; int16_t know; - int8_t outcome; + Outcome outcome; uint8_t proofdepth; Move move; Move bestmove; //if outcome is set, then bestmove is the way to get there @@ -35,8 +40,8 @@ class AgentMCTS : public Agent{ //seems to need padding to multiples of 8 bytes or it segfaults? //don't forget to update the copy constructor/operator - Node() : know(0), outcome(-3), proofdepth(0) { } - Node(const Move & m, char o = -3) : know(0), outcome( o), proofdepth(0), move(m) { } + Node() : know(0), outcome(Outcome::UNKNOWN), proofdepth(0), move(M_NONE) { } + Node(const Move & m, Outcome o = Outcome::UNKNOWN) : know(0), outcome(o), proofdepth(0), move(m) { } Node(const Node & n) { *this = n; } Node & operator = (const Node & n){ if(this != & n){ //don't copy to self @@ -58,17 +63,8 @@ class AgentMCTS : public Agent{ children.swap(n.children); } - void print() const { - printf("%s\n", to_s().c_str()); - } - string to_s() const { - return "Node: move " + move.to_s() + - ", exp " + to_str(exp.avg(), 2) + "/" + to_str(exp.num()) + - ", know " + to_str(know) + - ", outcome " + to_str((int)outcome) + "/" + to_str((int)proofdepth) + - ", best " + bestmove.to_s() + - ", children " + to_str(children.num()); - } + std::string to_s() const ; + bool from_s(std::string s); unsigned int size() const { unsigned int num = children.num(); @@ -122,38 +118,36 @@ class AgentMCTS : public Agent{ MoveList() { } - void addtree(const Move & move, char turn){ + void addtree(const Move & move, Side turn){ } - void addrollout(const Move & move, char turn){ + void addrollout(const Move & move, Side turn){ } void reset(Board * b){ exp[0].clear(); exp[1].clear(); } - void finishrollout(int won){ + void finishrollout(Outcome won){ exp[0].addloss(); exp[1].addloss(); - if(won == 0){ + if(won == Outcome::DRAW){ exp[0].addtie(); exp[1].addtie(); }else{ - exp[won-1].addwin(); + exp[won.to_i() - 1].addwin(); } } void subvlosses(int n){ exp[0].addlosses(-n); exp[1].addlosses(-n); } - const ExpPair & getexp(int turn) const { - return exp[turn-1]; + const ExpPair & getexp(Side turn) const { + return exp[turn.to_i() - 1]; } }; - class MCTSThread { + class AgentThread : public AgentThreadBase { mutable XORShift_uint64 rand64; mutable XORShift_float unitrand; - Thread thread; - AgentMCTS * player; bool use_explore; //whether to use exploration for this simulation MoveList movelist; int stage; //which of the four MCTS stages is it on @@ -163,11 +157,8 @@ class AgentMCTS : public Agent{ double times[4]; //time spent in each of the stages Time timestamps[4]; //timestamps for the beginning, before child creation, before rollout, after rollout - MCTSThread(AgentMCTS * p) : rand64(std::rand()), unitrand(std::rand()), player(p) { - reset(); - thread(bind(&MCTSThread::run, this)); - } - ~MCTSThread() { } + AgentThread(AgentThreadPool * p, AgentMCTS * a) : AgentThreadBase(p, a) { } + void reset(){ treelen.reset(); @@ -177,18 +168,15 @@ class AgentMCTS : public Agent{ times[a] = 0; } - int join(){ return thread.join(); } private: - void run(); //thread runner, calls iterate on each iteration void iterate(); //handles each iteration void walk_tree(Board & board, Node * node, int depth); - bool create_children(const Board & board, Node * node, int toplay); + bool create_children(const Board & board, Node * node); void add_knowledge(const Board & board, Node * node, Node * child); - Node * choose_move(const Node * node, int toplay) const; + Node * choose_move(const Node * node, Side toplay) const; - int rollout(Board & board, Move move, int depth); -// PairMove rollout_choose_move(Board & board, const Move & prev, int & doinstwin, bool checkrings); + Outcome rollout(Board & board, Move move, int depth); }; @@ -231,30 +219,11 @@ class AgentMCTS : public Agent{ CompactTree ctmem; - enum ThreadState { - Thread_Cancelled, //threads should exit - Thread_Wait_Start, //threads are waiting to start - Thread_Wait_Start_Cancelled, //once done waiting, go to cancelled instead of running - Thread_Running, //threads are running - Thread_GC, //one thread is running garbage collection, the rest are waiting - Thread_GC_End, //once done garbage collecting, go to wait_end instead of back to running - Thread_Wait_End, //threads are waiting to end - }; - volatile ThreadState threadstate; - vector threads; - Barrier runbarrier, gcbarrier; - - double time_used; + AgentThreadPool pool; AgentMCTS(); ~AgentMCTS(); - string statestring(); - - void stop_threads(); - void start_threads(); - void reset_threads(); - void set_memlimit(uint64_t lim) { }; // in bytes void clear_mem() { }; @@ -267,14 +236,59 @@ class AgentMCTS : public Agent{ Move return_move(int verbose) const { return return_move(& root, rootboard.toplay(), verbose); } double gamelen() const; - vector get_pv() const; - string move_stats(const vector moves) const; + vecmove get_pv() const; + std::string move_stats(const vecmove moves) const; + + bool done() { + //solved or finished runs + return (rootboard.won() >= Outcome::DRAW || root.outcome >= Outcome::DRAW || (maxruns > 0 && runs >= maxruns)); + } + + bool need_gc() { + //out of memory, start garbage collection + return (ctmem.memalloced() >= maxmem); + } + + void start_gc() { + Time starttime; + logerr("Starting player GC with limit " + to_str(gclimit) + " ... "); + uint64_t nodesbefore = nodes; + Board copy = rootboard; + garbage_collect(copy, & root); + Time gctime; + ctmem.compact(1.0, 0.75); + Time compacttime; + logerr(to_str(100.0*nodes/nodesbefore, 1) + " % of tree remains - " + + to_str((gctime - starttime)*1000, 0) + " msec gc, " + to_str((compacttime - gctime)*1000, 0) + " msec compact\n"); + + if(ctmem.meminuse() >= maxmem/2) + gclimit = (int)(gclimit*1.3); + else if(gclimit > rollouts*5) + gclimit = (int)(gclimit*0.9); //slowly decay to a minimum of 5 + } + + void gen_sgf(SGFPrinter & sgf, int limit) const { + if(limit < 0) + limit = root.exp.num()/1000; + gen_sgf(sgf, limit, root, rootboard.toplay()); + } + + void load_sgf(SGFParser & sgf) { + load_sgf(sgf, rootboard, root); + } - void timedout(); protected: void garbage_collect(Board & board, Node * node); //destroys the board, so pass in a copy - bool do_backup(Node * node, Node * backup, int toplay); - Move return_move(const Node * node, int toplay, int verbose = 0) const; + bool do_backup(Node * node, Node * backup, Side toplay); + Move return_move(const Node * node, Side toplay, int verbose = 0) const; + Node * find_child(const Node * node, const Move & move) const ; + void create_children_simple(const Board & board, Node * node); + + void gen_sgf(SGFPrinter & sgf, unsigned int limit, const Node & node, Side side) const ; + void load_sgf(SGFParser & sgf, const Board & board, Node & node); }; + +}; // namespace Pentago +}; // namespace Morat diff --git a/pentago/agentmcts_test.cpp b/pentago/agentmcts_test.cpp new file mode 100644 index 0000000..734aa13 --- /dev/null +++ b/pentago/agentmcts_test.cpp @@ -0,0 +1,15 @@ + +#include "../lib/catch.hpp" + +#include "agentmcts.h" + +using namespace Morat; +using namespace Pentago; + +TEST_CASE("Pentago::AgentMCTS::Node::to_s/from_s", "[pentago][agentmcts]") { + AgentMCTS::Node n(Move("a1")); + auto s = n.to_s(); + AgentMCTS::Node k; + REQUIRE(k.from_s(s)); + REQUIRE(n.to_s() == k.to_s()); +} diff --git a/pentago/agentmctsthread.cpp b/pentago/agentmctsthread.cpp index 07762a0..425bc6d 100644 --- a/pentago/agentmctsthread.cpp +++ b/pentago/agentmctsthread.cpp @@ -7,19 +7,24 @@ #include "agentmcts.h" #include "moveiterator.h" -void AgentMCTS::MCTSThread::iterate(){ - if(player->profile){ + +namespace Morat { +namespace Pentago { + +void AgentMCTS::AgentThread::iterate(){ + INCR(agent->runs); + if(agent->profile){ timestamps[0] = Time(); stage = 0; } - movelist.reset(&(player->rootboard)); - player->root.exp.addvloss(); - Board copy = player->rootboard; - walk_tree(copy, & player->root, 0); - player->root.exp.addv(movelist.getexp(3-player->rootboard.toplay())); + movelist.reset(&(agent->rootboard)); + agent->root.exp.addvloss(); + Board copy = agent->rootboard; + walk_tree(copy, & agent->root, 0); + agent->root.exp.addv(movelist.getexp(~agent->rootboard.toplay())); - if(player->profile){ + if(agent->profile){ times[0] += timestamps[1] - timestamps[0]; times[1] += timestamps[2] - timestamps[1]; times[2] += timestamps[3] - timestamps[2]; @@ -27,20 +32,20 @@ void AgentMCTS::MCTSThread::iterate(){ } } -void AgentMCTS::MCTSThread::walk_tree(Board & board, Node * node, int depth){ - int toplay = board.toplay(); +void AgentMCTS::AgentThread::walk_tree(Board & board, Node * node, int depth){ + Side toplay = board.toplay(); - if(!node->children.empty() && node->outcome < 0){ + if(!node->children.empty() && node->outcome < Outcome::DRAW){ //choose a child and recurse Node * child; do{ child = choose_move(node, toplay); - if(child->outcome < 0){ + if(child->outcome < Outcome::DRAW){ movelist.addtree(child->move, toplay); if(!board.move(child->move)){ - logerr("move failed: " + child->move.to_s(true) + "\n" + board.to_s(true)); + logerr("move failed: " + child->move.to_s() + "\n" + board.to_s(true)); assert(false && "move failed"); } @@ -49,36 +54,36 @@ void AgentMCTS::MCTSThread::walk_tree(Board & board, Node * node, int depth){ walk_tree(board, child, depth+1); child->exp.addv(movelist.getexp(toplay)); - player->do_backup(node, child, toplay); + agent->do_backup(node, child, toplay); return; } - }while(!player->do_backup(node, child, toplay)); + }while(!agent->do_backup(node, child, toplay)); return; } - if(player->profile && stage == 0){ + if(agent->profile && stage == 0){ stage = 1; timestamps[1] = Time(); } - int won = (player->minimax ? node->outcome : board.won()); + Outcome won = (agent->minimax ? node->outcome : board.won()); //if it's not already decided - if(won < 0){ + if(won < Outcome::DRAW){ //create children if valid - if(node->exp.num() >= player->visitexpand+1 && create_children(board, node, toplay)){ + if(node->exp.num() >= agent->visitexpand+1 && create_children(board, node)){ walk_tree(board, node, depth); return; } - if(player->profile){ + if(agent->profile){ stage = 2; timestamps[2] = Time(); } //do random game on this node - for(int i = 0; i < player->rollouts; i++){ + for(int i = 0; i < agent->rollouts; i++){ Board copy = board; rollout(copy, node->move, depth); } @@ -90,7 +95,7 @@ void AgentMCTS::MCTSThread::walk_tree(Board & board, Node * node, int depth){ movelist.subvlosses(1); - if(player->profile){ + if(agent->profile){ timestamps[3] = Time(); if(stage == 1) timestamps[2] = timestamps[3]; @@ -104,40 +109,40 @@ bool sort_node_know(const AgentMCTS::Node & a, const AgentMCTS::Node & b){ return (a.know > b.know); } -bool AgentMCTS::MCTSThread::create_children(const Board & board, Node * node, int toplay){ +bool AgentMCTS::AgentThread::create_children(const Board & board, Node * node){ if(!node->children.lock()) return false; CompactTree::Children temp; - temp.alloc(board.moves_avail(), player->ctmem); + temp.alloc(board.moves_avail(), agent->ctmem); Node * child = temp.begin(), * end = temp.end(); - MoveIterator move(board, player->prunesymmetry); + MoveIterator move(board, agent->prunesymmetry); int nummoves = 0; for(; !move.done() && child != end; ++move, ++child){ *child = Node(*move); const Board & after = move.board(); - if(player->minimax){ + if(agent->minimax){ child->outcome = after.won(); - if(child->outcome == toplay){ //proven win from here, don't need children + if(child->outcome == board.toplay()){ //proven win from here, don't need children node->outcome = child->outcome; node->proofdepth = 1; node->bestmove = *move; node->children.unlock(); - temp.dealloc(player->ctmem); + temp.dealloc(agent->ctmem); return true; } } - if(player->knowledge) + if(agent->knowledge) add_knowledge(after, node, child); nummoves++; } - if(player->prunesymmetry) + if(agent->prunesymmetry) temp.shrink(nummoves); //shrink the node to ignore the extra moves else{ //both end conditions should happen in parallel assert(move.done() && child == end); @@ -146,33 +151,33 @@ bool AgentMCTS::MCTSThread::create_children(const Board & board, Node * node, in //sort in decreasing order by knowledge // sort(temp.begin(), temp.end(), sort_node_know); - PLUS(player->nodes, temp.num()); + PLUS(agent->nodes, temp.num()); node->children.swap(temp); assert(temp.unlock()); return true; } -AgentMCTS::Node * AgentMCTS::MCTSThread::choose_move(const Node * node, int toplay) const { +AgentMCTS::Node * AgentMCTS::AgentThread::choose_move(const Node * node, Side toplay) const { float val, maxval = -1000000000; float logvisits = log(node->exp.num()); - float explore = player->explore; - if(player->parentexplore) + float explore = agent->explore; + if(agent->parentexplore) explore *= node->exp.avg(); - Node * ret = NULL, + Node * ret = NULL, * child = node->children.begin(), * end = node->children.end(); for(; child != end; child++){ - if(child->outcome >= 0){ + if(child->outcome >= Outcome::DRAW){ if(child->outcome == toplay) //return a win immediately return child; - val = (child->outcome == 0 ? -1 : -2); //-1 for tie so any unknown is better, -2 for loss so it's even worse + val = (child->outcome == Outcome::DRAW ? -1 : -2); //-1 for tie so any unknown is better, -2 for loss so it's even worse }else{ - val = child->value(player->knowledge, player->fpurgency); + val = child->value(agent->knowledge, agent->fpurgency); if(explore > 0) val += explore*sqrt(logvisits/(child->exp.num() + 1)); } @@ -198,132 +203,96 @@ backup in this order: 0 lose return true if fully solved, false if it's unknown or partially unknown */ -bool AgentMCTS::do_backup(Node * node, Node * backup, int toplay){ - int nodeoutcome = node->outcome; - if(nodeoutcome >= 0) //already proven, probably by a different thread +bool AgentMCTS::do_backup(Node * node, Node * backup, Side toplay){ + Outcome node_outcome = node->outcome; + if(node_outcome >= Outcome::DRAW) //already proven, probably by a different thread return true; - if(backup->outcome == -3) //nothing proven by this child, so no chance + if(backup->outcome == Outcome::UNKNOWN) //nothing proven by this child, so no chance return false; uint8_t proofdepth = backup->proofdepth; if(backup->outcome != toplay){ - uint64_t sims = 0, bestsims = 0, outcome = 0, bestoutcome = 0; + uint64_t sims = 0, bestsims = 0, outcome = 0, best_outcome = 0; backup = NULL; Node * child = node->children.begin(), * end = node->children.end(); for( ; child != end; child++){ - int childoutcome = child->outcome; //save a copy to avoid race conditions + Outcome child_outcome = child->outcome; //save a copy to avoid race conditions if(proofdepth < child->proofdepth+1) proofdepth = child->proofdepth+1; //these should be sorted in likelyness of matching, most likely first - if(childoutcome == -3){ // win/draw/loss + if(child_outcome == Outcome::UNKNOWN){ // win/draw/loss outcome = 3; - }else if(childoutcome == toplay){ //win + }else if(child_outcome == toplay){ //win backup = child; outcome = 6; proofdepth = child->proofdepth+1; break; - }else if(childoutcome == 3-toplay){ //loss + }else if(child_outcome == ~toplay){ //loss outcome = 0; - }else if(childoutcome == 0){ //draw - if(nodeoutcome == toplay-3) //draw/loss + }else if(child_outcome == Outcome::DRAW){ //draw + if(node_outcome == -toplay) //draw/loss, ie I can't win outcome = 4; else outcome = 2; - }else if(childoutcome == -toplay){ //win/draw + }else if(child_outcome == -~toplay){ //win/draw, ie opponent can't win outcome = 5; - }else if(childoutcome == toplay-3){ //draw/loss + }else if(child_outcome == -toplay){ //draw/loss, ie I can't win outcome = 1; }else{ - logerr("childoutcome == " + to_str(childoutcome) + "\n"); + logerr("child_outcome == " + child_outcome.to_s() + "\n"); assert(false && "How'd I get here? All outcomes should be tested above"); } sims = child->exp.num(); - if(bestoutcome < outcome){ //better outcome is always preferable - bestoutcome = outcome; + if(best_outcome < outcome){ //better outcome is always preferable + best_outcome = outcome; bestsims = sims; backup = child; - }else if(bestoutcome == outcome && ((outcome == 0 && bestsims < sims) || bestsims > sims)){ + }else if(best_outcome == outcome && ((outcome == 0 && bestsims < sims) || bestsims > sims)){ //find long losses or easy wins/draws bestsims = sims; backup = child; } } - if(bestoutcome == 3) //no win, but found an unknown + if(best_outcome == 3) //no win, but found an unknown return false; } - if(CAS(node->outcome, nodeoutcome, backup->outcome)){ + if(node->outcome.cas(node_outcome, backup->outcome)){ node->bestmove = backup->move; node->proofdepth = proofdepth; }else //if it was in a race, try again, might promote a partial solve to full solve return do_backup(node, backup, toplay); - return (node->outcome >= 0); + return (node->outcome >= Outcome::DRAW); } -void AgentMCTS::MCTSThread::add_knowledge(const Board & board, Node * node, Node * child){ - if(player->win_score > 0) - child->know = player->win_score * board.score_calc(); +void AgentMCTS::AgentThread::add_knowledge(const Board & board, Node * node, Node * child){ + if(agent->win_score > 0) + child->know = agent->win_score * board.score_calc(); } /////////////////////////////////////////// //play a random game starting from a board state, and return the results of who won -int AgentMCTS::MCTSThread::rollout(Board & board, Move move, int depth){ - int won; - while((won = board.won()) < 0) { +Outcome AgentMCTS::AgentThread::rollout(Board & board, Move move, int depth){ + Outcome won; + while((won = board.won()) < Outcome::DRAW) { board.move_rand(rand64); } gamelen.add(board.num_moves()); movelist.finishrollout(won); return won; } -/* -//play a random game starting from a board state, and return the results of who won -int AgentMCTS::MCTSThread::rollout(Board & board, Move move, int depth){ - int won; - - Move forced = M_UNKNOWN; - while((won = board.won()) < 0){ - int turn = board.toplay(); - - if(forced == M_UNKNOWN){ - //do a complex choice - PairMove pair = rollout_choose_move(board, move); - move = pair.a; - forced = pair.b; - }else{ - move = forced; - forced = M_UNKNOWN; - } - - if(move == M_UNKNOWN) - move = board.move_rand(rand64); - else - board.move(move); - movelist.addrollout(move, turn); - board.won_calc(); - depth++; - } - - gamelen.add(depth); - - movelist.finishrollout(won); - return won; -} - -PairMove AgentMCTS::MCTSThread::rollout_choose_move(Board & board, Move move, int depth){ - //look for possible win -} -*/ +}; // namespace Pentago +}; // namespace Morat diff --git a/pentago/agentpns.cpp b/pentago/agentpns.cpp index 348dccb..9d83b25 100644 --- a/pentago/agentpns.cpp +++ b/pentago/agentpns.cpp @@ -6,149 +6,97 @@ #include "agentpns.h" #include "moveiterator.h" -void AgentPNS::search(double time, uint64_t maxiters, int verbose){ - if(rootboard.won() >= 0) - return; - - start_threads(); - timeout = false; - Alarm timer(time, std::bind(&AgentPNS::timedout, this)); +namespace Morat { +namespace Pentago { - //wait for the timer to stop them - runbarrier.wait(); - CAS(threadstate, Thread_Wait_End, Thread_Wait_Start); - assert(threadstate == Thread_Wait_Start); +std::string AgentPNS::Node::to_s() const { + return "AgentPNS::Node" + ", move " + move.to_s() + + ", phi " + to_str(phi) + + ", delta " + to_str(delta) + + ", work " + to_str(work) + + ", children " + to_str(children.num()); } -void AgentPNS::PNSThread::run(){ - while(true){ - switch(agent->threadstate){ - case Thread_Cancelled: //threads should exit - return; - - case Thread_Wait_Start: //threads are waiting to start - case Thread_Wait_Start_Cancelled: - agent->runbarrier.wait(); - CAS(agent->threadstate, Thread_Wait_Start, Thread_Running); - CAS(agent->threadstate, Thread_Wait_Start_Cancelled, Thread_Cancelled); - break; - - case Thread_Wait_End: //threads are waiting to end - agent->runbarrier.wait(); - CAS(agent->threadstate, Thread_Wait_End, Thread_Wait_Start); - break; +bool AgentPNS::Node::from_s(std::string s) { + auto dict = parse_dict(s, ", ", " "); - case Thread_Running: //threads are running - if(agent->root.terminal()){ //solved - CAS(agent->threadstate, Thread_Running, Thread_Wait_End); - break; - } - if(agent->ctmem.memalloced() >= agent->memlimit){ //out of memory, start garbage collection - CAS(agent->threadstate, Thread_Running, Thread_GC); - break; - } - - pns(agent->rootboard, &agent->root, 0, INF32/2, INF32/2); - break; + if(dict.size() == 6){ + move = Move(dict["move"]); + phi = from_str(dict["phi"]); + delta = from_str(dict["delta"]); + work = from_str(dict["work"]); + // ignore children + return true; + } + return false; +} - case Thread_GC: //one thread is running garbage collection, the rest are waiting - case Thread_GC_End: //once done garbage collecting, go to wait_end instead of back to running - if(agent->gcbarrier.wait()){ - logerr("Starting solver GC with limit " + to_str(agent->gclimit) + " ... "); +void AgentPNS::test() { + Node n(Move("a1")); + auto s = n.to_s(); + Node k; + assert(k.from_s(s)); +} - Time starttime; - agent->garbage_collect(& agent->root); +void AgentPNS::search(double time, uint64_t maxiters, int verbose){ + max_nodes_seen = maxiters; - Time gctime; - agent->ctmem.compact(1.0, 0.75); + if(rootboard.won() >= 0) + return; - Time compacttime; - logerr(to_str(100.0*agent->ctmem.meminuse()/agent->memlimit, 1) + " % of tree remains - " + - to_str((gctime - starttime)*1000, 0) + " msec gc, " + to_str((compacttime - gctime)*1000, 0) + " msec compact\n"); + Time starttime; - if(agent->ctmem.meminuse() >= agent->memlimit/2) - agent->gclimit = (unsigned int)(agent->gclimit*1.3); - else if(agent->gclimit > 5) - agent->gclimit = (unsigned int)(agent->gclimit*0.9); //slowly decay to a minimum of 5 + pool.reset(); + pool.resume(); - CAS(agent->threadstate, Thread_GC, Thread_Running); - CAS(agent->threadstate, Thread_GC_End, Thread_Wait_End); - } - agent->gcbarrier.wait(); - break; - } - } -} + pool.wait_pause(time); -void AgentPNS::timedout() { - CAS(threadstate, Thread_Running, Thread_Wait_End); - CAS(threadstate, Thread_GC, Thread_GC_End); - timeout = true; -} -string AgentPNS::statestring(){ - switch(threadstate){ - case Thread_Cancelled: return "Thread_Wait_Cancelled"; - case Thread_Wait_Start: return "Thread_Wait_Start"; - case Thread_Wait_Start_Cancelled: return "Thread_Wait_Start_Cancelled"; - case Thread_Running: return "Thread_Running"; - case Thread_GC: return "Thread_GC"; - case Thread_GC_End: return "Thread_GC_End"; - case Thread_Wait_End: return "Thread_Wait_End"; - } - return "Thread_State_Unknown!!!"; -} + double time_used = Time() - starttime; -void AgentPNS::stop_threads(){ - if(threadstate != Thread_Wait_Start){ - timedout(); - runbarrier.wait(); - CAS(threadstate, Thread_Wait_End, Thread_Wait_Start); - assert(threadstate == Thread_Wait_Start); - } -} -void AgentPNS::start_threads(){ - assert(threadstate == Thread_Wait_Start); - runbarrier.wait(); - CAS(threadstate, Thread_Wait_Start, Thread_Running); -} + if(verbose){ + DepthStats treelen; + for(auto & t : pool) + treelen += t->treelen; -void AgentPNS::reset_threads(){ //start and end with threadstate = Thread_Wait_Start - assert(threadstate == Thread_Wait_Start); - -//wait for them to all get to the barrier - assert(CAS(threadstate, Thread_Wait_Start, Thread_Wait_Start_Cancelled)); - runbarrier.wait(); - -//make sure they exited cleanly - for(unsigned int i = 0; i < threads.size(); i++) - threads[i]->join(); + logerr("Finished: " + to_str(nodes_seen) + " nodes created in " + to_str(time_used*1000, 0) + " msec: " + to_str(nodes_seen/time_used, 0) + " Nodes/s\n"); + if(nodes_seen > 0){ + logerr("Tree depth: " + treelen.to_s() + "\n"); + } - threads.clear(); + Side toplay = rootboard.toplay(); - threadstate = Thread_Wait_Start; + logerr("Root: " + root.to_s() + "\n"); + Outcome outcome = root.to_outcome(~toplay); + if(outcome != Outcome::UNKNOWN) + logerr("Solved as a " + outcome.to_s_rel(toplay) + "\n"); - runbarrier.reset(numthreads + 1); - gcbarrier.reset(numthreads); + std::string pvstr; + for(auto m : get_pv()) + pvstr += " " + m.to_s(); + logerr("PV: " + pvstr + "\n"); -//start new threads - for(int i = 0; i < numthreads; i++) - threads.push_back(new PNSThread(this)); + if(verbose >= 3 && !root.children.empty()) + logerr("Move stats:\n" + move_stats(vecmove())); + } } +void AgentPNS::AgentThread::iterate(){ + pns(agent->rootboard, &agent->root, 0, INF32/2, INF32/2); +} -bool AgentPNS::PNSThread::pns(const Board & board, Node * node, int depth, uint32_t tp, uint32_t td){ - iters++; - if(agent->maxdepth < depth) - agent->maxdepth = depth; - +bool AgentPNS::AgentThread::pns(const Board & board, Node * node, int depth, uint32_t tp, uint32_t td){ + // no children, create them if(node->children.empty()){ + treelen.add(depth); + if(node->terminal()) return true; - if(agent->ctmem.memalloced() >= agent->memlimit) + if(agent->need_gc()) return false; if(!node->children.lock()) @@ -159,60 +107,70 @@ bool AgentPNS::PNSThread::pns(const Board & board, Node * node, int depth, uint3 temp.alloc(numnodes, agent->ctmem); unsigned int i = 0; - unsigned int seen = 0; for(MoveIterator move(board); !move.done(); ++move){ - int outcome = solve1ply(move.board(), seen); - unsigned int pd = 1; + unsigned int pd; + Outcome outcome; + + if(agent->ab){ + Board next = board; + next.move(*move); + + pd = 0; + outcome = solve1ply(move.board(), pd); + }else{ + pd = 1; + outcome = move.board().won(); + } + temp[i] = Node(*move).outcome(outcome, board.toplay(), agent->ties, pd); i++; } + nodes_seen += i; + PLUS(agent->nodes_seen, i); PLUS(agent->nodes, i); temp.shrink(i); //if symmetry, there may be extra moves to ignore node->children.swap(temp); assert(temp.unlock()); - PLUS(agent->nodes_seen, seen); - updatePDnum(node); - return true; + return (agent->nodes_seen >= agent->max_nodes_seen); } bool mem; do{ - Node * child = node->children.begin(), - * child2 = node->children.begin(), - * childend = node->children.end(); + Node * child = node->children.begin(), // the best move to explore + * child2 = node->children.begin();// second best for thresholds - uint32_t tpc, tdc; + uint32_t tpc, tdc; // the thresholds if(agent->df){ - for(Node * i = node->children.begin(); i != childend; i++){ - if(i->refdelta() <= child->refdelta()){ + for(auto & i : node->children){ + if(i.refdelta() <= child->refdelta()){ child2 = child; - child = i; - }else if(i->refdelta() < child2->refdelta()){ - child2 = i; + child = & i; + }else if(i.refdelta() < child2->refdelta()){ + child2 = & i; } } - tpc = min(INF32/2, (td + child->phi - node->delta)); - tdc = min(tp, (uint32_t)(child2->delta*(1.0 + agent->epsilon) + 1)); + tpc = std::min(INF32/2, (td + child->phi - node->delta)); + tdc = std::min(tp, (uint32_t)(child2->delta*(1.0 + agent->epsilon) + 1)); }else{ tpc = tdc = 0; - for(Node * i = node->children.begin(); i != childend; i++) - if(child->refdelta() > i->refdelta()) - child = i; + for(auto & i : node->children) + if(child->refdelta() > i.refdelta()) + child = & i; } Board next = board; next.move(child->move); child->ref(); - uint64_t itersbefore = iters; + uint64_t seen_before = nodes_seen; mem = pns(next, child, depth + 1, tpc, tdc); child->deref(); - PLUS(child->work, iters - itersbefore); + PLUS(child->work, nodes_seen - seen_before); if(updatePDnum(node) && !agent->df) break; @@ -222,7 +180,7 @@ bool AgentPNS::PNSThread::pns(const Board & board, Node * node, int depth, uint3 return mem; } -bool AgentPNS::PNSThread::updatePDnum(Node * node){ +bool AgentPNS::AgentThread::updatePDnum(Node * node){ Node * i = node->children.begin(); Node * end = node->children.end(); @@ -262,16 +220,16 @@ double AgentPNS::gamelen() const { return rootboard.moves_remain(); } -vector AgentPNS::get_pv() const { - vector pv; +std::vector AgentPNS::get_pv() const { + vecmove pv; const Node * n = & root; - char turn = rootboard.toplay(); + Side turn = rootboard.toplay(); while(n && !n->children.empty()){ Move m = return_move(n, turn); pv.push_back(m); n = find_child(n, m); - turn = 3 - turn; + turn = ~turn; } if(pv.size() == 0) @@ -280,24 +238,29 @@ vector AgentPNS::get_pv() const { return pv; } -string AgentPNS::move_stats(vector moves) const { - string s = ""; +std::string AgentPNS::move_stats(vecmove moves) const { + std::string s = ""; const Node * node = & root; - for(vector::iterator m = moves.begin(); node && m != moves.end(); ++m) - node = find_child(node, *m); + if(moves.size()){ + s += "path:\n"; + for(auto m : moves){ + if(node){ + node = find_child(node, m); + s += node->to_s() + "\n"; + } + } + } if(node){ - Node * child = node->children.begin(), - * childend = node->children.end(); - for( ; child != childend; child++) - if(child->move != M_NONE) - s += child->to_s() + "\n"; + s += "children:\n"; + for(auto & n : node->children) + s += n.to_s() + "\n"; } return s; } -Move AgentPNS::return_move(const Node * node, int toplay, int verbose) const { +Move AgentPNS::return_move(const Node * node, Side toplay, int verbose) const { double val, maxval = -1000000000000.0; //1 trillion Node * ret = NULL, @@ -305,11 +268,11 @@ Move AgentPNS::return_move(const Node * node, int toplay, int verbose) const { * end = node->children.end(); for( ; child != end; child++){ - int outcome = child->to_outcome(toplay); - if(outcome >= 0){ - if(outcome == toplay) val = 800000000000.0 - (double)child->work; //shortest win - else if(outcome == 0) val = -400000000000.0 + (double)child->work; //longest tie - else val = -800000000000.0 + (double)child->work; //longest loss + Outcome outcome = child->to_outcome(toplay); + if(outcome >= Outcome::DRAW){ + if( outcome == +toplay) val = 800000000000.0 - (double)child->work; //shortest win + else if(outcome == Outcome::DRAW) val = -400000000000.0 + (double)child->work; //longest tie + else val = -800000000000.0 + (double)child->work; //longest loss }else{ //not proven val = child->work; } @@ -329,9 +292,9 @@ Move AgentPNS::return_move(const Node * node, int toplay, int verbose) const { } AgentPNS::Node * AgentPNS::find_child(const Node * node, const Move & move) const { - for(Node * i = node->children.begin(); i != node->children.end(); i++) - if(i->move == move) - return i; + for(auto & n : node->children) + if(n.move == move) + return &n; return NULL; } @@ -349,3 +312,51 @@ void AgentPNS::garbage_collect(Node * node){ } } } + +void AgentPNS::create_children_simple(const Board & board, Node * node){ + assert(node->children.empty()); + node->children.alloc(board.moves_avail(), ctmem); + unsigned int i = 0; + for(MoveIterator move(board); !move.done(); ++move){ + Outcome outcome = move.board().won(); + node->children[i] = Node(*move).outcome(outcome, board.toplay(), ties, 1); + i++; + } + PLUS(nodes, i); + node->children.shrink(i); //if symmetry, there may be extra moves to ignore +} + +void AgentPNS::gen_sgf(SGFPrinter & sgf, unsigned int limit, const Node & node, Side side) const { + for(auto & child : node.children){ + if(child.work >= limit && (side != node.to_outcome(~side) || child.to_outcome(side) == node.to_outcome(~side))){ + sgf.child_start(); + sgf.move(side, child.move); + sgf.comment(child.to_s()); + gen_sgf(sgf, limit, child, ~side); + sgf.child_end(); + } + } +} + +void AgentPNS::load_sgf(SGFParser & sgf, const Board & board, Node & node) { + assert(sgf.has_children()); + create_children_simple(board, &node); + + while(sgf.next_child()){ + Move m = sgf.move(); + Node & child = *find_child(&node, m); + child.from_s(sgf.comment()); + if(sgf.done_child()){ + continue; + }else{ + // has children! + Board b = board; + b.move(m); + load_sgf(sgf, b, child); + assert(sgf.done_child()); + } + } +} + +}; // namespace Pentago +}; // namespace Morat diff --git a/pentago/agentpns.h b/pentago/agentpns.h index dfd9ac1..f8e13d7 100644 --- a/pentago/agentpns.h +++ b/pentago/agentpns.h @@ -3,12 +3,20 @@ //A multi-threaded, tree based, proof number search solver. +#include + +#include "../lib/agentpool.h" #include "../lib/compacttree.h" +#include "../lib/depthstats.h" #include "../lib/log.h" +#include "../lib/string.h" #include "agent.h" +namespace Morat { +namespace Pentago { + class AgentPNS : public Agent { static const uint32_t LOSS = (1<<30)-1; static const uint32_t DRAW = (1<<30)-2; @@ -48,33 +56,33 @@ class AgentPNS : public Agent { assert(children.empty()); } - Node & abval(int outcome, int toplay, int assign, int value = 1){ - if(assign && (outcome == 1 || outcome == -1)) - outcome = (toplay == assign ? 2 : -2); + Node & abval(int ab_outcome, Side toplay, Side assign, int value = 1){ + if(assign != Side::NONE && (ab_outcome == 1 || ab_outcome == -1)) + ab_outcome = (toplay == assign ? 2 : -2); - if( outcome == 0) { phi = value; delta = value; } - else if(outcome == 2) { phi = LOSS; delta = 0; } - else if(outcome == -2) { phi = 0; delta = LOSS; } - else /*(outcome 1||-1)*/ { phi = 0; delta = DRAW; } + if( ab_outcome == 0) { phi = value; delta = value; } + else if(ab_outcome == 2) { phi = LOSS; delta = 0; } + else if(ab_outcome == -2) { phi = 0; delta = LOSS; } + else /*(ab_outcome 1||-1)*/ { phi = 0; delta = DRAW; } return *this; } - Node & outcome(int outcome, int toplay, int assign, int value = 1){ - if(assign && outcome == 0) - outcome = assign; + Node & outcome(Outcome outcome, Side toplay, Side assign, int value = 1){ + if(assign != Side::NONE && outcome == Outcome::DRAW) + outcome = +assign; - if( outcome == -3) { phi = value; delta = value; } - else if(outcome == toplay) { phi = LOSS; delta = 0; } - else if(outcome == 3-toplay) { phi = 0; delta = LOSS; } - else /*(outcome == 0)*/ { phi = 0; delta = DRAW; } + if( outcome == Outcome::UNKNOWN) { phi = value; delta = value; } + else if(outcome == +toplay) { phi = LOSS; delta = 0; } + else if(outcome == +~toplay) { phi = 0; delta = LOSS; } + else /*(outcome == Outcome::DRAW)*/ { phi = 0; delta = DRAW; } return *this; } - int to_outcome(int toplay) const { - if(phi == LOSS) return toplay; - if(delta == LOSS) return 3 - toplay; - if(delta == DRAW) return 0; - return -3; + Outcome to_outcome(Side toplay) const { + if(phi == LOSS) return +toplay; + if(delta == LOSS) return +~toplay; + if(delta == DRAW) return Outcome::DRAW; + return Outcome::UNKNOWN; } bool terminal(){ return (phi == 0 || delta == 0); } @@ -95,15 +103,8 @@ class AgentPNS : public Agent { return num; } - string to_s() const { - return "Node: move " + move.to_s() + - ", phi " + to_str(phi) + - ", delta " + to_str(delta) + - ", work " + to_str(work) + -// ", outcome " + to_str((int)outcome) + "/" + to_str((int)proofdepth) + -// ", best " + bestmove.to_s() + - ", children " + to_str(children.num()); - } + std::string to_s() const ; + bool from_s(std::string s); void swap_tree(Node & n){ children.swap(n.children); @@ -123,26 +124,23 @@ class AgentPNS : public Agent { } }; - class PNSThread { - Thread thread; - AgentPNS * agent; + class AgentThread : public AgentThreadBase { public: - uint64_t iters; + DepthStats treelen; + uint64_t nodes_seen; + + AgentThread(AgentThreadPool * p, AgentPNS * a) : AgentThreadBase(p, a) { } - PNSThread(AgentPNS * a) : agent(a), iters(0) { - thread(bind(&PNSThread::run, this)); - } - virtual ~PNSThread() { } void reset(){ - iters = 0; + nodes_seen = 0; } - int join(){ return thread.join(); } - void run(); //thread runner - //basic proof number search building a tree + void iterate(); //handles each iteration + + //basic proof number search building a tree bool pns(const Board & board, Node * node, int depth, uint32_t tp, uint32_t td); - //update the phi and delta for the node + //update the phi and delta for the node bool updatePDnum(Node * node); }; @@ -152,79 +150,57 @@ class AgentPNS : public Agent { unsigned int gclimit; CompactTree ctmem; - int maxdepth; - uint64_t nodes_seen; - - enum ThreadState { - Thread_Cancelled, //threads should exit - Thread_Wait_Start, //threads are waiting to start - Thread_Wait_Start_Cancelled, //once done waiting, go to cancelled instead of running - Thread_Running, //threads are running - Thread_GC, //one thread is running garbage collection, the rest are waiting - Thread_GC_End, //once done garbage collecting, go to wait_end instead of back to running - Thread_Wait_End, //threads are waiting to end - }; - volatile ThreadState threadstate; - vector threads; - Barrier runbarrier, gcbarrier; + AgentThreadPool pool; + + uint64_t nodes_seen, max_nodes_seen; + + int ab; // how deep of an alpha-beta search to run at each leaf node bool df; // go depth first? float epsilon; //if depth first, how wide should the threshold be? - int ties; //which player to assign ties to: 0 handle ties, 1 assign p1, 2 assign p2 + Side ties; //which player to assign ties to: 0 handle ties, 1 assign p1, 2 assign p2 int numthreads; Node root; - AgentPNS() { + AgentPNS() : pool(this) { + ab = 1; df = true; epsilon = 0.25; - ties = 0; + ties = Side::NONE; numthreads = 1; + pool.set_num_threads(numthreads); gclimit = 5; nodes = 0; reset(); set_memlimit(1000*1024*1024); - - //no threads started until a board is set - threadstate = Thread_Wait_Start; } ~AgentPNS(){ - stop_threads(); - - numthreads = 0; - reset_threads(); //shut down the theads properly + pool.pause(); + pool.set_num_threads(0); root.dealloc(ctmem); ctmem.compact(); } void reset(){ - maxdepth = 0; nodes_seen = 0; timeout = false; } - string statestring(); - void stop_threads(); - void start_threads(); - void reset_threads(); - void timedout(); - void set_board(const Board & board, bool clear = true){ rootboard = board; reset(); if(clear) clear_mem(); - - reset_threads(); //needed since the threads aren't started before a board it set } void move(const Move & m){ - stop_threads(); + pool.pause(); rootboard.move(m); reset(); @@ -247,7 +223,7 @@ class AgentPNS : public Agent { root.swap_tree(child); if(nodesbefore > 0) - logerr(string("PNS Nodes before: ") + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); + logerr(std::string("PNS Nodes before: ") + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); assert(nodes == root.size()); @@ -267,15 +243,68 @@ class AgentPNS : public Agent { nodes = 0; } + bool done() { + //solved or finished runs + return root.terminal(); + } + + bool need_gc() { + //out of memory, start garbage collection + return (ctmem.memalloced() >= memlimit); + } + + void start_gc() { + Time starttime; + logerr("Starting GC with limit " + to_str(gclimit) + " ... "); + + garbage_collect(& root); + + Time gctime; + ctmem.compact(1.0, 0.75); + + Time compacttime; + logerr(to_str(100.0*ctmem.meminuse()/memlimit, 1) + " % of tree remains - " + + to_str((gctime - starttime)*1000, 0) + " msec gc, " + to_str((compacttime - gctime)*1000, 0) + " msec compact\n"); + + if(ctmem.meminuse() >= memlimit/2) + gclimit = (unsigned int)(gclimit*1.3); + else if(gclimit > 5) + gclimit = (unsigned int)(gclimit*0.9); //slowly decay to a minimum of 5 + } + void search(double time, uint64_t maxiters, int verbose); Move return_move(int verbose) const { return return_move(& root, rootboard.toplay(), verbose); } double gamelen() const; - vector get_pv() const; - string move_stats(const vector moves) const; + vecmove get_pv() const; + std::string move_stats(const vecmove moves) const; + + void gen_sgf(SGFPrinter & sgf, int limit) const { + if(limit < 0){ + limit = 0; + //TODO: Set the root.work properly + for(auto & child : root.children) + limit += child.work; + limit /= 1000; + } + gen_sgf(sgf, limit, root, rootboard.toplay()); + } + + void load_sgf(SGFParser & sgf) { + load_sgf(sgf, rootboard, root); + } + + static void test(); private: //remove all the nodes with little work to free up some memory void garbage_collect(Node * node); - Move return_move(const Node * node, int toplay, int verbose = 0) const; + Move return_move(const Node * node, Side toplay, int verbose = 0) const; Node * find_child(const Node * node, const Move & move) const ; + void create_children_simple(const Board & board, Node * node); + + void gen_sgf(SGFPrinter & sgf, unsigned int limit, const Node & node, Side side) const; + void load_sgf(SGFParser & sgf, const Board & board, Node & node); }; + +}; // namespace Pentago +}; // namespace Morat diff --git a/pentago/agentpns_test.cpp b/pentago/agentpns_test.cpp new file mode 100644 index 0000000..89f869e --- /dev/null +++ b/pentago/agentpns_test.cpp @@ -0,0 +1,15 @@ + +#include "../lib/catch.hpp" + +#include "agentpns.h" + +using namespace Morat; +using namespace Pentago; + +TEST_CASE("Pentago::AgentPNS::Node::to_s/from_s", "[pentago][agentpns]") { + AgentPNS::Node n(Move("a1")); + auto s = n.to_s(); + AgentPNS::Node k; + REQUIRE(k.from_s(s)); + REQUIRE(n.to_s() == k.to_s()); +} diff --git a/pentago/board.cpp b/pentago/board.cpp index cc3a4ac..4dccb65 100644 --- a/pentago/board.cpp +++ b/pentago/board.cpp @@ -3,6 +3,10 @@ #include "board.h" + +namespace Morat { +namespace Pentago { + const int Board::xytobit[36] = { 0, 1, 2, 15, 16, 9, 7, 8, 3, 14, 17, 10, @@ -26,7 +30,7 @@ const uint64_t Board::xybits[36] = { const int16_t Board::scoremap[6] = { 0, 1, 3, 9, 27, 127 }; -Board::Board(string str) { +Board::Board(std::string str) { sides[1] = 0; sides[2] = 0; nummoves = 0; @@ -51,15 +55,16 @@ Board::Board(string str) { to_play = (nummoves % 2) + 1; } -string Board::state() const { - string s; +std::string Board::state() const { + std::string s; for(int y = 0; y < 6; y++) for(int x = 0; x < 6; x++) s += to_str((int)get(x, y)); return s; } -string Board::to_s(bool color) const { +std::string Board::to_s(bool color) const { + using std::string; string white = "O", black = "@", empty = ".", @@ -99,10 +104,10 @@ string Board::to_s(bool color) const { for(int y = 0; y < 6; y++){ s += left[y] + " " + string(1, 'a' + y) + " "; for(int x = 0; x < 6; x++){ - int p = get(x, y); - if(p == 0) s += empty; - if(p == 1) s += white; - if(p == 2) s += black; + Side p = get(x, y); + if(p == Side::NONE) s += empty; + if(p == Side::P1) s += white; + if(p == Side::P2) s += black; s += " "; } s += coord + right[y] + "\n"; @@ -114,19 +119,6 @@ string Board::to_s(bool color) const { return s; } -string Board::won_str() const { - switch(won()){ - case -3: return "none"; - case -2: return "black_or_draw"; - case -1: return "white_or_draw"; - case 0: - case 3: return "draw"; - case 1: return "white"; - case 2: return "black"; - } - return "unknown"; -} - #define winpattern(a,b,c,d,e) ((1ULL<<(a)) | (1ULL<<(b)) | (1ULL<<(c)) | (1ULL<<(d)) | (1ULL<<(e))) @@ -235,104 +227,7 @@ uint16_t * gen_lookup3to2(unsigned int inbits, unsigned int outbits){ const uint16_t * Board::lookup3to2 = gen_lookup3to2(9, 15); -void check(uint64_t h, uint8_t o, Board &b){ - if(h != b.hash()) - printf("expected hash: %lu, got: %lu\n", h, b.hash()); - if(o != b.orient()) - printf("expected orient: %i, got: %i\n", o, b.orient()); - if(h != b.hash() || o != b.orient()) - printf("%s", b.to_s().c_str()); - assert(h == b.hash()); - assert(o == b.orient()); -} -void check(uint64_t h, uint8_t o, string m){ - Board b; - b.move(m); - check(h, o, b); -} -void check(uint64_t h, uint8_t o, std::initializer_list moves){ - Board b; - for(string m : moves){ - b.move(m); - } - check(h, o, b); -} +void Board::test() { } -void Board::test() { -// printf("board tests\n"); - - //a single non-rotated piece leads to known board orientations - check(6, 0, "a2z"); - check(6, 1, "b6t"); - check(6, 2, "f5v"); - check(6, 3, "e1x"); - check(6, 4, "b1u"); - check(6, 5, "f2s"); - check(6, 6, "e6y"); - check(6, 7, "a5w"); - - check(2, 0, "a1z"); - check(2, 1, "a6t"); - check(2, 2, "f6v"); - check(2, 3, "f1x"); - - //a pair of non-rotated pieces lead to known board orientations - check(15, 0, {"a2z", "a3z"}); - check(15, 1, {"b6t", "c6t"}); - check(15, 2, {"f5v", "f4v"}); - check(15, 3, {"e1x", "d1x"}); - check(15, 4, {"b1u", "c1u"}); - check(15, 5, {"f2s", "f3s"}); - check(15, 6, {"e6y", "d6y"}); - check(15, 7, {"a5w", "a4w"}); - - //a single oriented piece leads to a known board orientation - check(6, 0, "a2z:0"); - check(6, 3, "a2z:1"); - check(6, 2, "a2z:2"); - check(6, 1, "a2z:3"); - check(6, 4, "a2z:4"); - check(6, 5, "a2z:5"); - check(6, 6, "a2z:6"); - check(6, 7, "a2z:7"); - - //a single oriented piece leads to a known board orientation - check(6, 1, "b6t:0"); - check(6, 0, "b6t:1"); - check(6, 3, "b6t:2"); - check(6, 2, "b6t:3"); - check(6, 5, "b6t:4"); - check(6, 6, "b6t:5"); - check(6, 7, "b6t:6"); - check(6, 4, "b6t:7"); - - //a single oriented piece leads to a known board orientation - check(6, 2, "f5v:0"); - check(6, 1, "f5v:1"); - check(6, 0, "f5v:2"); - check(6, 3, "f5v:3"); - check(6, 6, "f5v:4"); - check(6, 7, "f5v:5"); - check(6, 4, "f5v:6"); - check(6, 5, "f5v:7"); - - //a single oriented piece leads to a known board orientation - check(6, 4, "b1u:0"); - check(6, 5, "b1u:1"); - check(6, 6, "b1u:2"); - check(6, 7, "b1u:3"); - check(6, 0, "b1u:4"); - check(6, 3, "b1u:5"); - check(6, 2, "b1u:6"); - check(6, 1, "b1u:7"); - - //a single oriented piece leads to a known board orientation - check(6, 7, "a5w:0"); - check(6, 4, "a5w:1"); - check(6, 5, "a5w:2"); - check(6, 6, "a5w:3"); - check(6, 3, "a5w:4"); - check(6, 2, "a5w:5"); - check(6, 1, "a5w:6"); - check(6, 0, "a5w:7"); -} +}; // namespace Pentago +}; // namespace Morat diff --git a/pentago/board.h b/pentago/board.h index cc11d54..7e7d5b4 100644 --- a/pentago/board.h +++ b/pentago/board.h @@ -4,14 +4,17 @@ #include #include #include +#include #include -#include +#include "../lib/outcome.h" #include "../lib/xorshift.h" #include "move.h" -using namespace std; + +namespace Morat { +namespace Pentago { //#include //#define bitcount(x) std::bitset(x).count() @@ -40,14 +43,14 @@ class Board{ uint64_t sides[3]; // sides[0] = sides[1] | sides[2]; bitmap of position for each side uint8_t nummoves; // how many moves have been made so far - uint8_t to_play; // who's turn is it next, 1|2 - mutable int8_t outcome; //-3 = unknown, 0 = tie, 1,2 = player win - mutable uint8_t orientation; + Side to_play; // who's turn is it next, 1|2 + mutable Outcome outcome; //-3 = unknown, 0 = tie, 1,2 = player win mutable int16_t cached_score; mutable uint64_t cached_hash; static const int16_t default_score = 0xDEAD; public: + static constexpr const char * const name = "pentago"; static const int default_size = 6; static const int min_size = 6; static const int max_size = 6; @@ -60,20 +63,19 @@ class Board{ sides[1] = 0; sides[2] = 0; nummoves = 0; - to_play = 1; - outcome = -4; - orientation = 8; + to_play = Side::P1; + outcome = Outcome::UNDEF; cached_score = default_score; cached_hash = 0; } //take a position as 01012200 ... of length 36, left to right, top to bottom, all [012] - Board(string str); + Board(std::string str); static void test(); int num_moves() const { return nummoves; } - int moves_remain() const { return (won() >= 0 ? 0 : 36 - nummoves); } + int moves_remain() const { return (won() >= Outcome::DRAW ? 0 : 36 - nummoves); } int moves_avail() const { return moves_remain()*8; } //upper bound int get_size() const { @@ -90,41 +92,41 @@ class Board{ return 0; } - string to_s(bool color = true) const ; - string state() const ; + std::string to_s(bool color = true) const ; + std::string state() const ; + friend std::ostream& operator<< (std::ostream &out, const Board & b) { return out << b.to_s(true); } void print(bool color = true) const { printf("%s", to_s(color).c_str()); } - string won_str() const ; + std::string won_str() const ; - uint8_t toplay() const { + Side toplay() const { return to_play; } - int8_t won() const { - if(outcome == -4) + Outcome won() const { + if(outcome == Outcome::UNDEF) outcome = won_calc(); return outcome; } - int8_t won_calc() const { - int8_t wonside = 0; + Outcome won_calc() const { + Outcome wonside = Outcome::DRAW; uint64_t ws = sides[1]; uint64_t bs = sides[2]; for(int i = 0; i < 32; i++){ uint64_t wm = winmaps[i]; - if ((ws & wm) == wm) wonside |= 1; - else if((bs & wm) == wm) wonside |= 2; + if ((ws & wm) == wm) wonside |= Outcome::P1; + else if((bs & wm) == wm) wonside |= Outcome::P2; } - switch(wonside){ - case 1: - case 2: return wonside; - case 3: return 0; //wonside == 3 when both sides win simultaneously - default: return (nummoves >= 36 ? 0 : -3); - } + if(wonside == Outcome::P1 || wonside == Outcome::P2) + return wonside; + if(wonside == Outcome::DRAW2) // both sides win simultaneously + return Outcome::DRAW; + return (nummoves >= 36 ? Outcome::DRAW : Outcome::UNKNOWN); } int16_t score() const { @@ -148,38 +150,23 @@ class Board{ } //return the score from the perspective of the player that just played //ie not the player whose turn it is now - return (to_play == 1 ? -s : s); - } - - unsigned int orient() const { - if(!cached_hash) - hash(); - return orientation; + return (to_play == Side::P1 ? -s : s); } - uint64_t hash() const { - if(!cached_hash) - cached_hash = (nummoves < fullhash_depth ? full_hash() : simple_hash()); - return cached_hash; - } - - bool move(Move mo){ + bool move(Move m){ assert(outcome < 0); - orient(); - Move m = mo.rotate(orientation); - //TODO: only call valid_move if the move didn't come from an iterator? if(!valid_move(m)) return false; if(m == M_SWAP){ - swap(sides[1], sides[2]); + std::swap(sides[1], sides[2]); to_play = 1; return true; } - sides[to_play] |= xybits[m.l]; + sides[to_play.to_i()] |= xybits[m.l]; if (m.direction() == 0) { sides[1] = rotate_quad_ccw(sides[1], m.quadrant()); @@ -191,15 +178,10 @@ class Board{ sides[0] = sides[1] | sides[2]; nummoves++; - to_play = 3 - to_play; - outcome = -4; + to_play = ~to_play; + outcome = Outcome::UNDEF; cached_score = default_score; cached_hash = 0; - orientation = 8; //start with an unoriented board - - -// if(m != mo) -// logerr(mo.to_s(true) + " -> " + m.to_s(true) + " -> " + to_str(orient()) + "\n"); return true; } @@ -215,7 +197,7 @@ class Board{ } while(move & (move-1)); // } while(bitcount(move) > 1); // if there's only one bit left, that's our move - sides[to_play] |= move; + sides[to_play.to_i()] |= move; uint64_t rotation = (mask >> 36); //mask is already a random number, so just re-use the unused high bits uint64_t direction = rotation & 0x4; @@ -231,11 +213,10 @@ class Board{ sides[0] = sides[1] | sides[2]; nummoves++; - to_play = 3 - to_play; - outcome = -4; + to_play = ~to_play; + outcome = Outcome::UNDEF; cached_score = default_score; cached_hash = 0; - orientation = 8; return true; } @@ -245,12 +226,12 @@ class Board{ return false; if(m == M_SWAP){ - swap(sides[1], sides[2]); + std::swap(sides[1], sides[2]); to_play = 1; return true; } - to_play = 3 - to_play; + to_play = ~to_play; nummoves--; if (m.direction() == 0) { @@ -261,19 +242,17 @@ class Board{ sides[2] = rotate_quad_ccw(sides[2], m.quadrant()); } - sides[to_play] &= ~xybits[m.l]; + sides[to_play.to_i()] &= ~xybits[m.l]; sides[0] = sides[1] | sides[2]; - outcome = -4; + outcome = Outcome::UNDEF; cached_score = default_score; cached_hash = 0; - orientation = 8; return true; } -private: uint64_t simple_hash() const { //Take 9 bits at a time from each player, merge them, convert to base 2 @@ -285,37 +264,41 @@ class Board{ h |= ((uint64_t)(lookup3to2[((w & (0x1FFull << 18)) | (b & (0x1FFull << 9))) >> 9])) << 15; h |= ((uint64_t)(lookup3to2[((w & (0x1FFull << 27)) | (b & (0x1FFull << 18))) >> 18])) << 30; h |= ((uint64_t)(lookup3to2[((w & (0x1FFull << 36)) | (b & (0x1FFull << 27))) >> 27])) << 45; - orientation = 8; return h; } - static inline void choose(uint64_t & m, uint64_t h, uint8_t & o, uint8_t no){ - if(m > h){ - m = h; - o = no; - } - } - uint64_t full_hash() const { - //make sure this matches Move::rotate + if(nummoves >= fullhash_depth) + return simple_hash(); + + if(cached_hash) + return cached_hash; + Board b(*this); - uint64_t h, m = ~0; - uint8_t o = 0; - choose(m, (h = b.simple_hash()), o, 0); - choose(m, (h = rotate_hash(h) ), o, 1); - choose(m, (h = rotate_hash(h) ), o, 2); - choose(m, ( rotate_hash(h) ), o, 3); + uint64_t h, m = ~0ull; + choose(m, (h = b.simple_hash())); + choose(m, (h = rotate_hash(h) )); + choose(m, (h = rotate_hash(h) )); + choose(m, ( rotate_hash(h) )); b.flip_board(); - choose(m, (h = b.simple_hash()), o, 4); - choose(m, (h = rotate_hash(h) ), o, 5); - choose(m, (h = rotate_hash(h) ), o, 6); - choose(m, ( rotate_hash(h) ), o, 7); + choose(m, (h = b.simple_hash())); + choose(m, (h = rotate_hash(h) )); + choose(m, (h = rotate_hash(h) )); + choose(m, ( rotate_hash(h) )); - orientation = o; + cached_hash = m; return m; } +private: + + static inline void choose(uint64_t & m, uint64_t h){ + if(m > h){ + m = h; + } + } + static uint64_t rotate_hash(uint64_t h){ // rotate ccw return ((h & 0xFFFFFFFFFFF8000ull) >> 15) | ((h & 0x7FFFull) << 45); } @@ -367,3 +350,6 @@ class Board{ return (b & ~m) | (((b & m) >> 6) & m) | (((b & m) << 2) & m); } }; + +}; // namespace Pentago +}; // namespace Morat diff --git a/pentago/gtp.h b/pentago/gtp.h index f0e692d..e8ad456 100644 --- a/pentago/gtp.h +++ b/pentago/gtp.h @@ -12,6 +12,10 @@ #include "history.h" #include "move.h" + +namespace Morat { +namespace Pentago { + class GTP : public GTPCommon { History hist; @@ -33,39 +37,42 @@ class GTP : public GTPCommon { set_board(); - newcallback("name", bind(>P::gtp_name, this, _1), "Name of the program"); - newcallback("version", bind(>P::gtp_version, this, _1), "Version of the program"); - newcallback("verbose", bind(>P::gtp_verbose, this, _1), "Set verbosity, 0 for quiet, 1 for normal, 2+ for more output"); - newcallback("colorboard", bind(>P::gtp_colorboard, this, _1), "Turn on or off the colored board"); - newcallback("showboard", bind(>P::gtp_print, this, _1), "Show the board"); - newcallback("print", bind(>P::gtp_print, this, _1), "Alias for showboard"); - newcallback("state", bind(>P::gtp_state, this, _1), "Output the board state in a simpler form than print, or set the board state"); - newcallback("hash", bind(>P::gtp_hash, this, _1), "Output the hash for the current position"); - newcallback("clear_board", bind(>P::gtp_clearboard, this, _1), "Clear the board, but keep the size"); - newcallback("clear", bind(>P::gtp_clearboard, this, _1), "Alias for clear_board"); - newcallback("boardsize", bind(>P::gtp_clearboard, this, _1), "Alias for clear_board, board is fixed size"); - newcallback("size", bind(>P::gtp_clearboard, this, _1), "Alias for board_size"); - newcallback("play", bind(>P::gtp_play, this, _1), "Place a stone: play "); - newcallback("white", bind(>P::gtp_playwhite, this, _1), "Place a white stone: white "); - newcallback("black", bind(>P::gtp_playblack, this, _1), "Place a black stone: black "); - newcallback("undo", bind(>P::gtp_undo, this, _1), "Undo one or more moves: undo [amount to undo]"); - newcallback("time", bind(>P::gtp_time, this, _1), "Set the time limits and the algorithm for per game time"); - newcallback("genmove", bind(>P::gtp_genmove, this, _1), "Generate a move: genmove [color] [time]"); - newcallback("solve", bind(>P::gtp_solve, this, _1), "Try to solve this position"); - - newcallback("mcts", bind(>P::gtp_mcts, this, _1), "Switch to use the Monte Carlo Tree Search agent to play/solve"); - newcallback("pns", bind(>P::gtp_pns, this, _1), "Switch to use the Proof Number Search agent to play/solve"); - newcallback("ab", bind(>P::gtp_ab, this, _1), "Switch to use the Alpha/Beta agent to play/solve"); - - newcallback("all_legal", bind(>P::gtp_all_legal, this, _1), "List all legal moves"); - newcallback("history", bind(>P::gtp_history, this, _1), "List of played moves"); - newcallback("playgame", bind(>P::gtp_playgame, this, _1), "Play a list of moves"); - newcallback("winner", bind(>P::gtp_winner, this, _1), "Check the winner of the game"); - - newcallback("pv", bind(>P::gtp_pv, this, _1), "Output the principle variation for the player tree as it stands now"); - newcallback("move_stats", bind(>P::gtp_move_stats, this, _1), "Output the move stats for the player tree as it stands now"); - - newcallback("params", bind(>P::gtp_params, this, _1), "Set the options for the player, no args gives options"); + newcallback("name", std::bind(>P::gtp_name, this, _1), "Name of the program"); + newcallback("version", std::bind(>P::gtp_version, this, _1), "Version of the program"); + newcallback("verbose", std::bind(>P::gtp_verbose, this, _1), "Set verbosity, 0 for quiet, 1 for normal, 2+ for more output"); + newcallback("colorboard", std::bind(>P::gtp_colorboard, this, _1), "Turn on or off the colored board"); + newcallback("showboard", std::bind(>P::gtp_print, this, _1), "Show the board"); + newcallback("print", std::bind(>P::gtp_print, this, _1), "Alias for showboard"); + newcallback("state", std::bind(>P::gtp_state, this, _1), "Output the board state in a simpler form than print, or set the board state"); + newcallback("hash", std::bind(>P::gtp_hash, this, _1), "Output the hash for the current position"); + newcallback("clear_board", std::bind(>P::gtp_clearboard, this, _1), "Clear the board, but keep the size"); + newcallback("clear", std::bind(>P::gtp_clearboard, this, _1), "Alias for clear_board"); + newcallback("boardsize", std::bind(>P::gtp_clearboard, this, _1), "Alias for clear_board, board is fixed size"); + newcallback("size", std::bind(>P::gtp_clearboard, this, _1), "Alias for board_size"); + newcallback("play", std::bind(>P::gtp_play, this, _1), "Place a stone: play "); + newcallback("white", std::bind(>P::gtp_playwhite, this, _1), "Place a white stone: white "); + newcallback("black", std::bind(>P::gtp_playblack, this, _1), "Place a black stone: black "); + newcallback("undo", std::bind(>P::gtp_undo, this, _1), "Undo one or more moves: undo [amount to undo]"); + newcallback("time", std::bind(>P::gtp_time, this, _1), "Set the time limits and the algorithm for per game time"); + newcallback("genmove", std::bind(>P::gtp_genmove, this, _1), "Generate a move: genmove [color] [time]"); + newcallback("solve", std::bind(>P::gtp_solve, this, _1), "Try to solve this position"); + + newcallback("ab", std::bind(>P::gtp_ab, this, _1), "Switch to use the Alpha/Beta agent to play/solve"); + newcallback("mcts", std::bind(>P::gtp_mcts, this, _1), "Switch to use the Monte Carlo Tree Search agent to play/solve"); + newcallback("pns", std::bind(>P::gtp_pns, this, _1), "Switch to use the Proof Number Search agent to play/solve"); + + newcallback("all_legal", std::bind(>P::gtp_all_legal, this, _1), "List all legal moves"); + newcallback("history", std::bind(>P::gtp_history, this, _1), "List of played moves"); + newcallback("playgame", std::bind(>P::gtp_playgame, this, _1), "Play a list of moves"); + newcallback("winner", std::bind(>P::gtp_winner, this, _1), "Check the winner of the game"); + + newcallback("pv", std::bind(>P::gtp_pv, this, _1), "Output the principle variation for the player tree as it stands now"); + newcallback("move_stats", std::bind(>P::gtp_move_stats, this, _1), "Output the move stats for the player tree as it stands now"); + + newcallback("params", std::bind(>P::gtp_params, this, _1), "Set the options for the player, no args gives options"); + + newcallback("save_sgf", std::bind(>P::gtp_save_sgf, this, _1), "Output an sgf of the current tree"); + newcallback("load_sgf", std::bind(>P::gtp_load_sgf, this, _1), "Load an sgf generated by save_sgf"); } void set_board(bool clear = true){ @@ -80,13 +87,12 @@ class GTP : public GTPCommon { GTPResponse gtp_state(vecstr args); GTPResponse gtp_print(vecstr args); GTPResponse gtp_hash(vecstr args); - string won_str(int outcome) const; GTPResponse gtp_boardsize(vecstr args); GTPResponse gtp_clearboard(vecstr args); GTPResponse gtp_undo(vecstr args); GTPResponse gtp_all_legal(vecstr args); GTPResponse gtp_history(vecstr args); - GTPResponse play(const string & pos, int toplay); + GTPResponse play(const std::string & pos, Side toplay); GTPResponse gtp_playgame(vecstr args); GTPResponse gtp_play(vecstr args); GTPResponse gtp_playwhite(vecstr args); @@ -96,16 +102,24 @@ class GTP : public GTPCommon { GTPResponse gtp_version(vecstr args); GTPResponse gtp_verbose(vecstr args); GTPResponse gtp_colorboard(vecstr args); - GTPResponse gtp_mcts(vecstr args); - GTPResponse gtp_pns(vecstr args); - GTPResponse gtp_ab(vecstr args); GTPResponse gtp_move_stats(vecstr args); GTPResponse gtp_pv(vecstr args); GTPResponse gtp_genmove(vecstr args); GTPResponse gtp_solve(vecstr args); + GTPResponse gtp_params(vecstr args); + + GTPResponse gtp_ab(vecstr args); GTPResponse gtp_ab_params(vecstr args); + GTPResponse gtp_mcts(vecstr args); GTPResponse gtp_mcts_params(vecstr args); + GTPResponse gtp_pns(vecstr args); GTPResponse gtp_pns_params(vecstr args); + + GTPResponse gtp_save_sgf(vecstr args); + GTPResponse gtp_load_sgf(vecstr args); }; + +}; // namespace Pentago +}; // namespace Morat diff --git a/pentago/gtpagent.cpp b/pentago/gtpagent.cpp index 9340fde..4f58b22 100644 --- a/pentago/gtpagent.cpp +++ b/pentago/gtpagent.cpp @@ -1,15 +1,17 @@ -#include +#include "gtp.h" -#include "../lib/fileio.h" -#include "gtp.h" +namespace Morat { +namespace Pentago { using namespace std; - GTPResponse GTP::gtp_move_stats(vecstr args){ - return GTPResponse(true, agent->move_stats()); + vector moves; + for(auto s : args) + moves.push_back(Move(s)); + return GTPResponse(true, agent->move_stats(moves)); } GTPResponse GTP::gtp_solve(vecstr args){ @@ -142,11 +144,11 @@ GTPResponse GTP::gtp_mcts_params(vecstr args){ string arg = args[i]; if((arg == "-t" || arg == "--threads") && i+1 < args.size()){ + mcts->pool.pause(); mcts->numthreads = from_str(args[++i]); - bool p = mcts->ponder; - mcts->set_ponder(false); //stop the threads while resetting them - mcts->reset_threads(); - mcts->set_ponder(p); + mcts->pool.set_num_threads(mcts->numthreads); + if(mcts->ponder) + mcts->pool.resume(); }else if((arg == "-o" || arg == "--ponder") && i+1 < args.size()){ mcts->set_ponder(from_str(args[++i])); }else if((arg == "--profile") && i+1 < args.size()){ @@ -190,7 +192,7 @@ GTPResponse GTP::gtp_pns_params(vecstr args){ "Update the pns solver settings, eg: pns_params -m 100 -s 0 -d 1 -e 0.25 -a 2 -l 0\n" " -m --memory Memory limit in Mb [" + to_str(pns->memlimit/(1024*1024)) + "]\n" " -t --threads How many threads to run [" + to_str(pns->numthreads) + "]\n" - " -s --ties Which side to assign ties to, 0 = handle, 1 = p1, 2 = p2 [" + to_str(pns->ties) + "]\n" + " -s --ties Which side to assign ties to, 0 = handle, 1 = p1, 2 = p2 [" + to_str(pns->ties.to_i()) + "]\n" " -d --df Use depth-first thresholds [" + to_str(pns->df) + "]\n" " -e --epsilon How big should the threshold be [" + to_str(pns->epsilon) + "]\n" // " -a --abdepth Run an alpha-beta search of this size at each leaf [" + to_str(pns->ab) + "]\n" @@ -202,13 +204,13 @@ GTPResponse GTP::gtp_pns_params(vecstr args){ if((arg == "-t" || arg == "--threads") && i+1 < args.size()){ pns->numthreads = from_str(args[++i]); - pns->reset_threads(); + pns->pool.set_num_threads(pns->numthreads); }else if((arg == "-m" || arg == "--memory") && i+1 < args.size()){ uint64_t mem = from_str(args[++i]); if(mem < 1) return GTPResponse(false, "Memory can't be less than 1mb"); pns->set_memlimit(mem*1024*1024); }else if((arg == "-s" || arg == "--ties") && i+1 < args.size()){ - pns->ties = from_str(args[++i]); + pns->ties = Side(from_str(args[++i])); pns->clear_mem(); }else if((arg == "-d" || arg == "--df") && i+1 < args.size()){ pns->df = from_str(args[++i]); @@ -223,3 +225,6 @@ GTPResponse GTP::gtp_pns_params(vecstr args){ return GTPResponse(true, errs); } + +}; // namespace Pentago +}; // namespace Morat diff --git a/pentago/gtpgeneral.cpp b/pentago/gtpgeneral.cpp index 538e145..c3a8bca 100644 --- a/pentago/gtpgeneral.cpp +++ b/pentago/gtpgeneral.cpp @@ -1,7 +1,15 @@ +#include + +#include "../lib/sgf.h" + #include "gtp.h" #include "moveiterator.h" + +namespace Morat { +namespace Pentago { + GTPResponse GTP::gtp_mcts(vecstr args){ delete agent; agent = new AgentMCTS(); @@ -55,24 +63,24 @@ GTPResponse GTP::gtp_undo(vecstr args){ } GTPResponse GTP::gtp_all_legal(vecstr args){ - string ret; + std::string ret; for(MoveIterator move(*hist); !move.done(); ++move) ret += move->to_s() + " "; return GTPResponse(true, ret); } GTPResponse GTP::gtp_history(vecstr args){ - string ret; + std::string ret; for(auto m : hist) ret += m.to_s() + " "; return GTPResponse(true, ret); } -GTPResponse GTP::play(const string & pos, int toplay){ +GTPResponse GTP::play(const std::string & pos, Side toplay){ if(toplay != hist->toplay()) return GTPResponse(false, "It is the other player's turn!"); - if(hist->won() >= 0) + if(hist->won() >= Outcome::DRAW) return GTPResponse(false, "The game is already over."); Move m(pos); @@ -83,7 +91,7 @@ GTPResponse GTP::play(const string & pos, int toplay){ move(m); if(verbose >= 2) - logerr("Placement: " + m.to_s() + ", outcome: " + hist->won_str() + "\n" + hist->to_s(colorboard)); + logerr("Placement: " + m.to_s() + ", outcome: " + hist->won().to_s() + "\n" + hist->to_s(colorboard)); return GTPResponse(true); } @@ -101,41 +109,37 @@ GTPResponse GTP::gtp_play(vecstr args){ if(args.size() != 2) return GTPResponse(false, "Wrong number of arguments"); - char toplay = 0; switch(tolower(args[0][0])){ - case 'w': toplay = 1; break; - case 'b': toplay = 2; break; - default: - return GTPResponse(false, "Invalid player selection"); + case 'w': return play(args[1], Side::P1); + case 'b': return play(args[1], Side::P2); + default: return GTPResponse(false, "Invalid player selection"); } - - return play(args[1], toplay); } GTPResponse GTP::gtp_playwhite(vecstr args){ if(args.size() != 1) return GTPResponse(false, "Wrong number of arguments"); - return play(args[0], 1); + return play(args[0], Side::P1); } GTPResponse GTP::gtp_playblack(vecstr args){ if(args.size() != 1) return GTPResponse(false, "Wrong number of arguments"); - return play(args[0], 2); + return play(args[0], Side::P2); } GTPResponse GTP::gtp_winner(vecstr args){ - return GTPResponse(true, hist->won_str()); + return GTPResponse(true, hist->won().to_s()); } GTPResponse GTP::gtp_name(vecstr args){ - return GTPResponse(true, "Pentagod"); + return GTPResponse(true, std::string("morat-") + Board::name); } GTPResponse GTP::gtp_version(vecstr args){ - return GTPResponse(true, "1.5"); + return GTPResponse(true, "0.1"); } GTPResponse GTP::gtp_verbose(vecstr args){ @@ -155,5 +159,93 @@ GTPResponse GTP::gtp_colorboard(vecstr args){ } GTPResponse GTP::gtp_hash(vecstr args){ - return GTPResponse(true, to_str(hist->hash())); + return GTPResponse(true, to_str(hist->simple_hash())); } + +GTPResponse GTP::gtp_save_sgf(vecstr args){ + int limit = -1; + if(args.size() == 0) + return GTPResponse(true, "save_sgf [work limit]"); + + std::ifstream infile(args[0].c_str()); + + if(infile) { + infile.close(); + return GTPResponse(false, "File " + args[0] + " already exists"); + } + + std::ofstream outfile(args[0].c_str()); + + if(!outfile) + return GTPResponse(false, "Opening file " + args[0] + " for writing failed"); + + if(args.size() > 1) + limit = from_str(args[1]); + + SGFPrinter sgf(outfile); + sgf.game(Board::name); + sgf.program(gtp_name(vecstr()).response, gtp_version(vecstr()).response); + sgf.size(hist->get_size()); + + sgf.end_root(); + + Side s = Side::P1; + for(auto m : hist){ + sgf.move(s, m); + s = ~s; + } + + agent->gen_sgf(sgf, limit); + + sgf.end(); + outfile.close(); + return true; +} + + +GTPResponse GTP::gtp_load_sgf(vecstr args){ + if(args.size() == 0) + return GTPResponse(true, "load_sgf "); + + std::ifstream infile(args[0].c_str()); + + if(!infile) { + return GTPResponse(false, "Error opening file " + args[0] + " for reading"); + } + + SGFParser sgf(infile); + if(sgf.game() != Board::name){ + infile.close(); + return GTPResponse(false, "File is for the wrong game: " + sgf.game()); + } + + int size = sgf.size(); + if(size != hist->get_size()){ + if(hist.len() == 0){ + hist = History(size); + set_board(); + time_control.new_game(); + }else{ + infile.close(); + return GTPResponse(false, "File has the wrong boardsize to match the existing game"); + } + } + + Side s = Side::P1; + + while(sgf.next_node()){ + Move m = sgf.move(); + move(m); // push the game forward + s = ~s; + } + + if(sgf.has_children()) + agent->load_sgf(sgf); + + assert(sgf.done_child()); + infile.close(); + return true; +} + +}; // namespace Pentago +}; // namespace Morat diff --git a/pentago/history.h b/pentago/history.h index 00ccd06..5cd1001 100644 --- a/pentago/history.h +++ b/pentago/history.h @@ -8,6 +8,10 @@ #include "board.h" #include "move.h" + +namespace Morat { +namespace Pentago { + class History { std::vector hist; Board board; @@ -68,3 +72,6 @@ class History { return false; } }; + +}; // namespace Pentago +}; // namespace Morat diff --git a/pentago/pentagod.cpp b/pentago/main.cpp similarity index 94% rename from pentago/pentagod.cpp rename to pentago/main.cpp index 95559db..427be60 100644 --- a/pentago/pentagod.cpp +++ b/pentago/main.cpp @@ -6,6 +6,10 @@ #include "gtp.h" + +using namespace Morat; +using namespace Pentago; + using namespace std; void die(int code, const string & str){ @@ -15,11 +19,10 @@ void die(int code, const string & str){ int main(int argc, char **argv){ Board::test(); - Move::test(); MoveIterator::test(); RandomMoveIteratorTest(); - srand(Time().in_msec()); + srand(Time().in_usec()); GTP gtp; gtp.colorboard = isatty(fileno(stdout)); diff --git a/pentago/move.cpp b/pentago/move.cpp deleted file mode 100644 index 17d2ccb..0000000 --- a/pentago/move.cpp +++ /dev/null @@ -1,52 +0,0 @@ - -#include "../lib/string.h" - -#include "board.h" -#include "move.h" - - -void check(const std::string a, int ao, const std::string b, int bo) { - Move am = Move(b, bo).rotate(ao); - Move bm = Move(a, ao).rotate(bo); - printf("%s %2i <=> %s %2i ", a.c_str(), ao, b.c_str(), bo); - printf(" : %s %2i <=> %s %2i\n", bm.to_s().c_str(), bm.o, am.to_s().c_str(), am.o); - assert(bm == Move(b, bo)); - assert(am == Move(a, ao)); -} - -void check(const std::string a, const std::string b) { - Move ma(a); - Move mb(b); - Move ar = ma.rotate(mb.orientation()); - Move br = mb.rotate(ma.orientation()); - if(ma != br || mb != ar){ - printf("%s => %s %s => %s\n", a.c_str(), ar.to_s(true).c_str(), b.c_str(), br.to_s(true).c_str()); - printf("%s\n", Board().to_s(true).c_str()); - } - assert(ma == br); - assert(mb == ar); -} - - -void check_many(std::string in, std::initializer_list moves){ - assert(moves.size() == 8); - int r = 0; - for(std::string out : moves){ - check(in, out + ":" + to_str(r)); - r++; - } -} - -void Move::test() { -// printf("move tests\n"); - - // 0 1 2 3 4 5 6 7 - check_many("a2z:0", {"a2z", "b6t", "f5v", "e1x", "b1u", "f2s", "e6y", "a5w"}); - check_many("a2z:1", {"e1x", "a2z", "b6t", "f5v", "a5w", "b1u", "f2s", "e6y"}); - check_many("a2z:2", {"f5v", "e1x", "a2z", "b6t", "e6y", "a5w", "b1u", "f2s"}); - check_many("a2z:3", {"b6t", "f5v", "e1x", "a2z", "f2s", "e6y", "a5w", "b1u"}); - check_many("a2z:4", {"b1u", "a5w", "e6y", "f2s", "a2z", "e1x", "f5v" ,"b6t"}); - check_many("a2z:5", {"f2s", "b1u", "a5w", "e6y", "b6t", "a2z", "e1x", "f5v"}); - check_many("a2z:6", {"e6y", "f2s", "b1u", "a5w", "f5v", "b6t", "a2z", "e1x"}); - check_many("a2z:7", {"a5w", "e6y", "f2s", "b1u", "e1x", "f5v", "b6t", "a2z"}); -} diff --git a/pentago/move.h b/pentago/move.h index 6fdf521..5c13eea 100644 --- a/pentago/move.h +++ b/pentago/move.h @@ -1,13 +1,16 @@ #pragma once -#include #include +#include #include -#include "../lib/log.h" #include "../lib/string.h" + +namespace Morat { +namespace Pentago { + enum MoveSpecial { M_SWAP = -1, //-1 so that adding 1 makes it into a valid move M_RESIGN = -2, @@ -16,10 +19,8 @@ enum MoveSpecial { }; struct Move { - int l : 8; //location = MoveSpecial | y*6+x - unsigned r : 4; //rotation = 0-7 - unsigned o : 4; //orientation = 0-7 | 8=unoriented -// int8_t l, r, o; //location, rotation, orientation + int8_t l; //location = MoveSpecial | y*6+x + int8_t r; //rotation = 0-7 /* location = 0 1 2 3 4 5 @@ -41,25 +42,19 @@ struct Move { quadrant = rotation >> 1 direction = rotation & 1 - -orientation = orientation of the board when this move was created - alternatively, this is the location once the board has been rotated to this orientation - This is NOT where the piece would be on a 0-orientation board (unless it is a 0-orientation move). - a->b != b->a - */ - Move(MoveSpecial a = M_UNKNOWN, unsigned int O = 8) : l(a), r(14), o(O) { } //big r so it will always wrap to l=0 with swap - Move(int L, unsigned int R, unsigned int O = 8) : l(L), r(R), o(O) { } - Move(unsigned int X, unsigned int Y, unsigned int R, unsigned int O) : l(Y*6 + X), r(R), o(O) { } + Move(MoveSpecial a = M_UNKNOWN) : l(a), r(14) { } //big r so it will always wrap to l=0 with swap + Move(int L, unsigned int R) : l(L), r(R) { } + Move(unsigned int X, unsigned int Y, unsigned int R) : l(Y*6 + X), r(R) { } - Move(const std::string & str, unsigned int O = 8){ - if( str == "swap" ){ l = M_SWAP; r = 14; o = O; } - else if(str == "resign" ){ l = M_RESIGN; r = 14; o = O; } - else if(str == "none" ){ l = M_NONE; r = 14; o = O; } - else if(str == "unknown" ){ l = M_UNKNOWN; r = 14; o = O; } - else if(str.length() <= 2){ l = M_NONE; r = 14; o = O; } + Move(const std::string & str){ + if( str == "swap" ){ l = M_SWAP; r = 14; } + else if(str == "resign" ){ l = M_RESIGN; r = 14; } + else if(str == "none" ){ l = M_NONE; r = 14; } + else if(str == "unknown" ){ l = M_UNKNOWN; r = 14; } + else if(str.length() <= 2){ l = M_NONE; r = 14; } else{ unsigned int y = tolower(str[0]) - 'a'; //[abcdef] unsigned int x = str[1] - '1'; //[123456] @@ -70,13 +65,6 @@ orientation = orientation of the board when this move was created else if(c >= 's' && c <= 'z') r = c - 's'; //[stuvwxyz] else if(c >= 'a' && c <= 'h') r = c - 'a'; //[abcdefgh] else r = 0; //unknown, but do something - - if(str.length() == 5) { - assert(O == 8 && str[3] == ':'); - o = str[4] - '0'; - } else { - o = O; - } } } @@ -85,78 +73,27 @@ orientation = orientation of the board when this move was created unsigned int y() const { return l / 6; } unsigned int quadrant() const { return r >> 1; } unsigned int direction() const { return r & 1; } - unsigned int orientation() const { return o; } - - Move rotate(unsigned int other) const { - if(l < 0) //special - return *this; - - assert(o <= 8 && other <= 8); - - if(other == o) //already the correct orientation - return *this; - - if(o == 8 && other < 8) //unoriented move, default to board - return Move(l, r, other); - - if(o == 8 || other == 8){ // oriented moves are incompatible with unoriented boards - printf("rotate %s to %i\n", to_s(true).c_str(), other); - assert(o <= 7 && other <= 7); - } - -/* -flip move: Move(y, x, (9-r)&7). Works because (9-1)&7=8&7=0, (9-1)&7=8&7=0 -rotate cw: Move(5-y, x, (r+2)&7) -rotate 180: Move(5-x, 5-y, (r+4)&7) -rotate ccw: Move( y, 5-x, (r+6)&7) -*/ - - unsigned int c = o*8 + other; - switch(c){ - case 0: case 011: case 022: case 033: case 044: case 055: case 066: case 077: return *this; - //Move( x(), y(), r, other); - case 01: case 012: case 023: case 030: case 047: case 054: case 065: case 076: return Move(5-y(), x(), (r+2)&7, other); - case 02: case 013: case 020: case 031: case 046: case 057: case 064: case 075: return Move(5-x(), 5-y(), (r+4)&7, other); - case 03: case 010: case 021: case 032: case 045: case 056: case 067: case 074: return Move( y(), 5-x(), (r+6)&7, other); - - case 04: case 015: case 026: case 037: case 040: case 051: case 062: case 073: return Move( y(), x(), ( 9-r)&7, other); - case 07: case 014: case 025: case 036: case 041: case 052: case 063: case 070: return Move(5-x(), y(), (11-r)&7, other); - case 06: case 017: case 024: case 035: case 042: case 053: case 060: case 071: return Move(5-y(), 5-x(), (13-r)&7, other); - case 05: case 016: case 027: case 034: case 043: case 050: case 061: case 072: return Move( x(), 5-y(), (15-r)&7, other); - default: - printf("o: %i, other: %i, c: %#4x", o, other, c); - assert(false && "Bad orientation?!?"); - } - } - - static void test(); - std::string to_s(bool orient = false) const { + std::string to_s() const { if(l == M_UNKNOWN) return "unknown"; if(l == M_NONE) return "none"; if(l == M_SWAP) return "swap"; if(l == M_RESIGN) return "resign"; - std::string s = std::string() + char(y() + 'a') + to_str(x() + 1) + char(r + 's'); - if(orient) - s += ":" + to_str(o); - return s; + return std::string() + char(y() + 'a') + to_str(x() + 1) + char(r + 's'); } - //TODO: handle orientation? + friend std::ostream& operator<< (std::ostream &out, const Move & m) { return out << m.to_s(); } + bool operator< (const Move & b) const { return (l == b.l ? r < b.r : l < b.l); } bool operator<=(const Move & b) const { return (l == b.l ? r <= b.r : l <= b.l); } bool operator> (const Move & b) const { return (l == b.l ? r > b.r : l > b.l); } bool operator>=(const Move & b) const { return (l == b.l ? r >= b.r : l >= b.l); } bool operator==(const MoveSpecial & b) const { return (l == b); } - bool operator==(const Move & b) const { return (l == b.l && r == b.r && o == b.o); } - bool operator!=(const Move & b) const { return (l != b.l || r != b.r || o != b.o); } + bool operator==(const Move & b) const { return (l == b.l && r == b.r); } + bool operator!=(const Move & b) const { return (l != b.l || r != b.r); } bool operator!=(const MoveSpecial & b) const { return (l != b); } }; - -struct PairMove { - Move a, b; - PairMove(Move A = M_UNKNOWN, Move B = M_UNKNOWN) : a(A), b(B) { } - PairMove(MoveSpecial A) : a(Move(A)), b(M_UNKNOWN) { } -}; +}; // namespace Pentago +}; // namespace Morat diff --git a/pentago/moveiterator.cpp b/pentago/moveiterator.cpp index 690bb07..b2a1cc4 100644 --- a/pentago/moveiterator.cpp +++ b/pentago/moveiterator.cpp @@ -8,6 +8,9 @@ #include "moveiterator.h" +namespace Morat { +namespace Pentago { + void MoveIterator::test() { // printf("MoveIterator tests\n"); @@ -57,7 +60,7 @@ void RandomMoveIteratorTest() { i++; assert(i == 6); - set moves; + std::set moves; for(MoveIterator move(board, 0); !move.done(); ++move) moves.insert(move->to_s()); @@ -74,7 +77,7 @@ void RandomMoveIteratorTest() { assert(moves.size() == 0); - set boards; + std::set boards; for(MoveIterator move(board, 0); !move.done(); ++move) boards.insert(move.board().to_s()); @@ -90,3 +93,6 @@ void RandomMoveIteratorTest() { assert(boards.size() == 0); } + +}; // namespace Pentago +}; // namespace Morat diff --git a/pentago/moveiterator.h b/pentago/moveiterator.h index 0c0b93e..6dd37e5 100644 --- a/pentago/moveiterator.h +++ b/pentago/moveiterator.h @@ -6,6 +6,10 @@ #include "board.h" #include "move.h" + +namespace Morat { +namespace Pentago { + class MoveIterator { //only returns valid moves... const Board & base_board; //base board Board after; // board after making the move @@ -13,10 +17,10 @@ class MoveIterator { //only returns valid moves... bool unique; HashSet hashes; public: - MoveIterator(const Board & b, int Unique = -1) : base_board(b), move(M_SWAP, b.orient()) { + MoveIterator(const Board & b, int Unique = -1) : base_board(b), move(M_SWAP) { unique = (Unique == -1 ? base_board.num_moves() <= Board::unique_depth : Unique); - if(base_board.won() >= 0){ + if(base_board.won() >= Outcome::DRAW){ move = Move(36, 8); //already done } else { if(unique) @@ -44,7 +48,7 @@ class MoveIterator { //only returns valid moves... bool move_success = after.move(move); assert(move_success); if(unique){ - uint64_t h = after.hash(); + uint64_t h = after.full_hash(); if(!hashes.add(h)) continue; } @@ -107,3 +111,6 @@ class RandomMoveIterator { //only returns valid moves... }; void RandomMoveIteratorTest(); + +}; // namespace Pentago +}; // namespace Morat diff --git a/rex/agent.h b/rex/agent.h index 6adecd2..ce3ab64 100644 --- a/rex/agent.h +++ b/rex/agent.h @@ -3,11 +3,19 @@ //Interface for the various agents: players and solvers +#include "../lib/outcome.h" +#include "../lib/sgf.h" #include "../lib/types.h" #include "board.h" + +namespace Morat { +namespace Rex { + class Agent { +protected: + typedef std::vector vecmove; public: Agent() { } virtual ~Agent() { } @@ -19,51 +27,57 @@ class Agent { virtual void set_memlimit(uint64_t lim) = 0; // in bytes virtual void clear_mem() = 0; - virtual vector get_pv() const = 0; - string move_stats() const { return move_stats(vector()); } - virtual string move_stats(const vector moves) const = 0; + virtual vecmove get_pv() const = 0; + std::string move_stats() const { return move_stats(vecmove()); } + virtual std::string move_stats(const vecmove moves) const = 0; virtual double gamelen() const = 0; virtual void timedout(){ timeout = true; } + virtual void gen_sgf(SGFPrinter & sgf, int limit) const = 0; + virtual void load_sgf(SGFParser & sgf) = 0; + protected: volatile bool timeout; Board rootboard; - static int solve1ply(const Board & board, unsigned int & nodes) { - int outcome = -3; - int turn = board.toplay(); + static Outcome solve1ply(const Board & board, unsigned int & nodes) { + Outcome outcome = Outcome::UNKNOWN; + Side turn = board.toplay(); for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ ++nodes; - int won = board.test_win(*move, turn); + Outcome won = board.test_outcome(*move, turn); - if(won == turn) + if(won == +turn) return won; - if(won == 0) - outcome = 0; + if(won == Outcome::DRAW) + outcome = Outcome::DRAW; } return outcome; } - static int solve2ply(const Board & board, unsigned int & nodes) { + static Outcome solve2ply(const Board & board, unsigned int & nodes) { int losses = 0; - int outcome = -3; - int turn = board.toplay(), opponent = 3 - turn; + Outcome outcome = Outcome::UNKNOWN; + Side turn = board.toplay(); + Side op = ~turn; for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ ++nodes; - int won = board.test_win(*move, turn); + Outcome won = board.test_outcome(*move, turn); - if(won == turn) + if(won == +turn) return won; - if(won == 0) - outcome = 0; + if(won == Outcome::DRAW) + outcome = Outcome::DRAW; - if(board.test_win(*move, opponent) > 0) + if(board.test_outcome(*move, op) == +op) losses++; } if(losses >= 2) - return opponent; + return (Outcome)op; return outcome; } - }; + +}; // namespace Rex +}; // namespace Morat diff --git a/rex/agentab.cpp b/rex/agentab.cpp index 2c66bce..95e22cb 100644 --- a/rex/agentab.cpp +++ b/rex/agentab.cpp @@ -6,6 +6,10 @@ #include "agentab.h" + +namespace Morat { +namespace Rex { + void AgentAB::search(double time, uint64_t maxiters, int verbose) { reset(); if(rootboard.won() >= 0) @@ -41,8 +45,8 @@ void AgentAB::search(double time, uint64_t maxiters, int verbose) { if(verbose){ logerr("Finished: " + to_str(nodes_seen) + " nodes in " + to_str(time_used*1000, 0) + " msec: " + to_str((uint64_t)((double)nodes_seen/time_used)) + " Nodes/s\n"); - vector pv = get_pv(); - string pvstr; + vecmove pv = get_pv(); + std::string pvstr; for(auto m : pv) pvstr += " " + m.to_s(); logerr("PV: " + pvstr + "\n"); @@ -56,11 +60,11 @@ void AgentAB::search(double time, uint64_t maxiters, int verbose) { int16_t AgentAB::negamax(const Board & board, int16_t alpha, int16_t beta, int depth) { nodes_seen++; - int won = board.won(); - if(won >= 0){ - if(won == 0) + Outcome won = board.won(); + if(won >= Outcome::DRAW){ + if(won == Outcome::DRAW) return SCORE_DRAW; - if(won == board.toplay()) + if(won == +board.toplay()) return SCORE_WIN; return SCORE_LOSS; } @@ -81,8 +85,8 @@ int16_t AgentAB::negamax(const Board & board, int16_t alpha, int16_t beta, int d if(TT && (node = tt_get(board)) && node->depth >= depth){ switch(node->flag){ case VALID: return node->score; - case LBOUND: alpha = max(alpha, node->score); break; - case UBOUND: beta = min(beta, node->score); break; + case LBOUND: alpha = std::max(alpha, node->score); break; + case UBOUND: beta = std::min(beta, node->score); break; default: assert(false && "Unknown flag!"); } if(alpha >= beta) @@ -125,11 +129,11 @@ int16_t AgentAB::negamax(const Board & board, int16_t alpha, int16_t beta, int d return score; } -string AgentAB::move_stats(vector moves) const { - string s = ""; +std::string AgentAB::move_stats(vecmove moves) const { + std::string s = ""; Board b = rootboard; - for(vector::iterator m = moves.begin(); m != moves.end(); ++m) + for(vecmove::iterator m = moves.begin(); m != moves.end(); ++m) b.move(*m); for(MoveIterator move(b); !move.done(); ++move){ @@ -162,8 +166,8 @@ Move AgentAB::return_move(const Board & board, int verbose) const { return best; } -vector AgentAB::get_pv() const { - vector pv; +std::vector AgentAB::get_pv() const { + vecmove pv; Board b = rootboard; int i = 20; @@ -197,3 +201,6 @@ AgentAB::Node * AgentAB::tt_get(uint64_t h) const { void AgentAB::tt_set(const Node & n) { *(tt(n.hash)) = n; } + +}; // namespace Rex +}; // namespace Morat diff --git a/rex/agentab.h b/rex/agentab.h index 646043f..405da12 100644 --- a/rex/agentab.h +++ b/rex/agentab.h @@ -7,6 +7,10 @@ #include "agent.h" + +namespace Morat { +namespace Rex { + class AgentAB : public Agent { static const int16_t SCORE_WIN = 32767; static const int16_t SCORE_LOSS = -32767; @@ -30,7 +34,7 @@ class AgentAB : public Agent { Node(uint64_t h = ~0ull, int16_t s = 0, Move b = M_UNKNOWN, int8_t d = 0, int8_t f = 0) : //. int8_t o = -3 hash(h), score(s), bestmove(b), depth(d), flag(f), padding(0xDEAD) { } //, outcome(o) - string to_s() const { + std::string to_s() const { return "score " + to_str(score) + ", depth " + to_str((int)depth) + ", flag " + to_str((int)flag) + @@ -93,8 +97,16 @@ class AgentAB : public Agent { void search(double time, uint64_t maxiters, int verbose); Move return_move(int verbose) const { return return_move(rootboard, verbose); } double gamelen() const { return rootboard.movesremain(); } - vector get_pv() const; - string move_stats(vector moves) const; + vecmove get_pv() const; + std::string move_stats(vecmove moves) const; + + void gen_sgf(SGFPrinter & sgf, int limit) const { + log("gen_sgf not supported in the ab agent."); + } + + void load_sgf(SGFParser & sgf) { + log("load_sgf not supported in the ab agent."); + } private: int16_t negamax(const Board & board, int16_t alpha, int16_t beta, int depth); @@ -105,3 +117,6 @@ class AgentAB : public Agent { Node * tt_get(const Board & b) const ; void tt_set(const Node & n) ; }; + +}; // namespace Rex +}; // namespace Morat diff --git a/rex/agentmcts.cpp b/rex/agentmcts.cpp index 8f8ba60..cde7744 100644 --- a/rex/agentmcts.cpp +++ b/rex/agentmcts.cpp @@ -10,12 +10,45 @@ #include "agentmcts.h" #include "board.h" + +namespace Morat { +namespace Rex { + const float AgentMCTS::min_rave = 0.1; +std::string AgentMCTS::Node::to_s() const { + return "AgentMCTS::Node" + ", move " + move.to_s() + + ", exp " + exp.to_s() + + ", rave " + rave.to_s() + + ", know " + to_str(know) + + ", outcome " + to_str((int)outcome.to_i()) + + ", depth " + to_str((int)proofdepth) + + ", best " + bestmove.to_s() + + ", children " + to_str(children.num()); +} + +bool AgentMCTS::Node::from_s(std::string s) { + auto dict = parse_dict(s, ", ", " "); + + if(dict.size() == 9){ + move = Move(dict["move"]); + exp = ExpPair(dict["exp"]); + rave = ExpPair(dict["rave"]); + know = from_str(dict["know"]); + outcome = Outcome(from_str(dict["outcome"])); + proofdepth = from_str(dict["depth"]); + bestmove = Move(dict["best"]); + // ignore children + return true; + } + return false; +} + void AgentMCTS::search(double time, uint64_t max_runs, int verbose){ - int toplay = rootboard.toplay(); + Side toplay = rootboard.toplay(); - if(rootboard.won() >= 0 || (time <= 0 && max_runs == 0)) + if(rootboard.won() >= Outcome::DRAW || (time <= 0 && max_runs == 0)) return; Time starttime; @@ -56,30 +89,23 @@ void AgentMCTS::search(double time, uint64_t max_runs, int verbose){ logerr("Times: " + to_str(times[0], 3) + ", " + to_str(times[1], 3) + ", " + to_str(times[2], 3) + ", " + to_str(times[3], 3) + "\n"); } - if(root.outcome != -3){ - logerr("Solved as a "); - if( root.outcome == 0) logerr("draw\n"); - else if(root.outcome == 3) logerr("draw by simultaneous win\n"); - else if(root.outcome == toplay) logerr("win\n"); - else if(root.outcome == 3-toplay) logerr("loss\n"); - else if(root.outcome == -toplay) logerr("win or draw\n"); - else if(root.outcome == toplay-3) logerr("loss or draw\n"); - } + if(root.outcome != Outcome::UNKNOWN) + logerr("Solved as a " + root.outcome.to_s_rel(toplay) + "\n"); - string pvstr; + std::string pvstr; for(auto m : get_pv()) pvstr += " " + m.to_s(); logerr("PV: " + pvstr + "\n"); if(verbose >= 3 && !root.children.empty()) - logerr("Move stats:\n" + move_stats(vector())); + logerr("Move stats:\n" + move_stats(vecmove())); } pool.reset(); runs = 0; - if(ponder && root.outcome < 0) + if(ponder && root.outcome < Outcome::DRAW) pool.resume(); } @@ -194,8 +220,8 @@ void AgentMCTS::move(const Move & m){ rootboard.move(m); root.exp.addwins(visitexpand+1); //+1 to compensate for the virtual loss - if(rootboard.won() < 0) - root.outcome = -3; + if(rootboard.won() < Outcome::DRAW) + root.outcome = Outcome::UNKNOWN; if(ponder) pool.resume(); @@ -208,16 +234,16 @@ double AgentMCTS::gamelen() const { return len.avg(); } -vector AgentMCTS::get_pv() const { - vector pv; +std::vector AgentMCTS::get_pv() const { + vecmove pv; const Node * n = & root; - char turn = rootboard.toplay(); + Side turn = rootboard.toplay(); while(n && !n->children.empty()){ Move m = return_move(n, turn); pv.push_back(m); n = find_child(n, m); - turn = 3 - turn; + turn = ~turn; } if(pv.size() == 0) @@ -226,8 +252,8 @@ vector AgentMCTS::get_pv() const { return pv; } -string AgentMCTS::move_stats(vector moves) const { - string s = ""; +std::string AgentMCTS::move_stats(vecmove moves) const { + std::string s = ""; const Node * node = & root; if(moves.size()){ @@ -248,8 +274,8 @@ string AgentMCTS::move_stats(vector moves) const { return s; } -Move AgentMCTS::return_move(const Node * node, int toplay, int verbose) const { - if(node->outcome >= 0) +Move AgentMCTS::return_move(const Node * node, Side toplay, int verbose) const { + if(node->outcome >= Outcome::DRAW) return node->bestmove; double val, maxval = -1000000000000.0; //1 trillion @@ -259,10 +285,10 @@ Move AgentMCTS::return_move(const Node * node, int toplay, int verbose) const { * end = node->children.end(); for( ; child != end; child++){ - if(child->outcome >= 0){ - if(child->outcome == toplay) val = 800000000000.0 - child->exp.num(); //shortest win - else if(child->outcome == 0) val = -400000000000.0 + child->exp.num(); //longest tie - else val = -800000000000.0 + child->exp.num(); //longest loss + if(child->outcome >= Outcome::DRAW){ + if(child->outcome == toplay) val = 800000000000.0 - child->exp.num(); //shortest win + else if(child->outcome == Outcome::DRAW) val = -400000000000.0 + child->exp.num(); //longest tie + else val = -800000000000.0 + child->exp.num(); //longest loss }else{ //not proven if(msrave == -1) //num simulations val = child->exp.num(); @@ -290,13 +316,13 @@ void AgentMCTS::garbage_collect(Board & board, Node * node){ Node * child = node->children.begin(), * end = node->children.end(); - int toplay = board.toplay(); + Side toplay = board.toplay(); for( ; child != end; child++){ if(child->children.num() == 0) continue; - if( (node->outcome >= 0 && child->exp.num() > gcsolved && (node->outcome != toplay || child->outcome == toplay || child->outcome == 0)) || //parent is solved, only keep the proof tree, plus heavy draws - (node->outcome < 0 && child->exp.num() > (child->outcome >= 0 ? gcsolved : gclimit)) ){ // only keep heavy nodes, with different cutoffs for solved and unsolved + if( (node->outcome >= Outcome::DRAW && child->exp.num() > gcsolved && (node->outcome != toplay || child->outcome == toplay || child->outcome == Outcome::DRAW)) || //parent is solved, only keep the proof tree, plus heavy draws + (node->outcome < Outcome::DRAW && child->exp.num() > (child->outcome >= Outcome::DRAW ? gcsolved : gclimit)) ){ // only keep heavy nodes, with different cutoffs for solved and unsolved board.set(child->move); garbage_collect(board, child); board.unset(child->move); @@ -307,36 +333,22 @@ void AgentMCTS::garbage_collect(Board & board, Node * node){ } AgentMCTS::Node * AgentMCTS::find_child(const Node * node, const Move & move) const { - for(Node * i = node->children.begin(); i != node->children.end(); i++) - if(i->move == move) - return i; - + for(auto & c : node->children) + if(c.move == move) + return &c; return NULL; } -void AgentMCTS::gen_hgf(Board & board, Node * node, unsigned int limit, unsigned int depth, FILE * fd){ - string s = string("\n") + string(depth, ' ') + "(;" + (board.toplay() == 2 ? "W" : "B") + "[" + node->move.to_s() + "]" + - "C[mcts, sims:" + to_str(node->exp.num()) + ", avg:" + to_str(node->exp.avg(), 4) + ", outcome:" + to_str((int)(node->outcome)) + ", best:" + node->bestmove.to_s() + "]"; - fprintf(fd, "%s", s.c_str()); - - Node * child = node->children.begin(), - * end = node->children.end(); - - int toplay = board.toplay(); - - bool children = false; - for( ; child != end; child++){ - if(child->exp.num() >= limit && (toplay != node->outcome || child->outcome == node->outcome) ){ - board.set(child->move); - gen_hgf(board, child, limit, depth+1, fd); - board.unset(child->move); - children = true; +void AgentMCTS::gen_sgf(SGFPrinter & sgf, unsigned int limit, const Node & node, Side side) const { + for(auto & child : node.children){ + if(child.exp.num() >= limit && (side != node.outcome || child.outcome == node.outcome)){ + sgf.child_start(); + sgf.move(side, child.move); + sgf.comment(child.to_s()); + gen_sgf(sgf, limit, child, ~side); + sgf.child_end(); } } - - if(children) - fprintf(fd, "\n%s", string(depth, ' ').c_str()); - fprintf(fd, ")"); } void AgentMCTS::create_children_simple(const Board & board, Node * node){ @@ -361,64 +373,25 @@ void AgentMCTS::create_children_simple(const Board & board, Node * node){ PLUS(nodes, node->children.num()); } -//reads the format from gen_hgf. -void AgentMCTS::load_hgf(Board board, Node * node, FILE * fd){ - char c, buf[101]; - - eat_whitespace(fd); - - assert(fscanf(fd, "(;%c[%100[^]]]", &c, buf) > 0); +void AgentMCTS::load_sgf(SGFParser & sgf, const Board & board, Node & node) { + assert(sgf.has_children()); + create_children_simple(board, & node); - assert(board.toplay() == (c == 'W' ? 1 : 2)); - node->move = Move(buf); - board.move(node->move); - - assert(fscanf(fd, "C[%100[^]]]", buf) > 0); - - vecstr entry, parts = explode(string(buf), ", "); - assert(parts[0] == "mcts"); - - entry = explode(parts[1], ":"); - assert(entry[0] == "sims"); - uword sims = from_str(entry[1]); - - entry = explode(parts[2], ":"); - assert(entry[0] == "avg"); - double avg = from_str(entry[1]); - - uword wins = sims*avg; - node->exp.addwins(wins); - node->exp.addlosses(sims - wins); - - entry = explode(parts[3], ":"); - assert(entry[0] == "outcome"); - node->outcome = from_str(entry[1]); - - entry = explode(parts[4], ":"); - assert(entry[0] == "best"); - node->bestmove = Move(entry[1]); - - - eat_whitespace(fd); - - if(fpeek(fd) != ')'){ - create_children_simple(board, node); - - while(fpeek(fd) != ')'){ - Node child; - load_hgf(board, & child, fd); - - Node * i = find_child(node, child.move); - *i = child; //copy the child experience to the tree - i->swap_tree(child); //move the child subtree to the tree - - assert(child.children.empty()); - - eat_whitespace(fd); + while(sgf.next_child()){ + Move m = sgf.move(); + Node & child = *find_child(&node, m); + child.from_s(sgf.comment()); + if(sgf.done_child()){ + continue; + }else{ + // has children! + Board b = board; + b.move(m); + load_sgf(sgf, b, child); + assert(sgf.done_child()); } } - - eat_char(fd, ')'); - - return; } + +}; // namespace Rex +}; // namespace Morat diff --git a/rex/agentmcts.h b/rex/agentmcts.h index 2da03fc..f665ab6 100644 --- a/rex/agentmcts.h +++ b/rex/agentmcts.h @@ -11,6 +11,12 @@ #include "../lib/depthstats.h" #include "../lib/exppair.h" #include "../lib/log.h" +#include "../lib/move.h" +#include "../lib/movelist.h" +#include "../lib/policy_bridge.h" +#include "../lib/policy_instantwin.h" +#include "../lib/policy_lastgoodreply.h" +#include "../lib/policy_random.h" #include "../lib/thread.h" #include "../lib/time.h" #include "../lib/types.h" @@ -19,14 +25,11 @@ #include "agent.h" #include "board.h" #include "lbdist.h" -#include "move.h" -#include "movelist.h" -#include "policy_bridge.h" -#include "policy_instantwin.h" -#include "policy_lastgoodreply.h" -#include "policy_random.h" +namespace Morat { +namespace Rex { + class AgentMCTS : public Agent{ public: @@ -35,7 +38,7 @@ class AgentMCTS : public Agent{ ExpPair rave; ExpPair exp; int16_t know; - int8_t outcome; + Outcome outcome; uint8_t proofdepth; Move move; Move bestmove; //if outcome is set, then bestmove is the way to get there @@ -44,8 +47,8 @@ class AgentMCTS : public Agent{ //seems to need padding to multiples of 8 bytes or it segfaults? //don't forget to update the copy constructor/operator - Node() : know(0), outcome(-3), proofdepth(0) { } - Node(const Move & m, char o = -3) : know(0), outcome( o), proofdepth(0), move(m) { } + Node() : know(0), outcome(Outcome::UNKNOWN), proofdepth(0), move(M_NONE) { } + Node(const Move & m, Outcome o = Outcome::UNKNOWN) : know(0), outcome(o), proofdepth(0), move(m) { } Node(const Node & n) { *this = n; } Node & operator = (const Node & n){ if(this != & n){ //don't copy to self @@ -68,18 +71,8 @@ class AgentMCTS : public Agent{ children.swap(n.children); } - void print() const { - printf("%s\n", to_s().c_str()); - } - string to_s() const { - return "Node: move " + move.to_s() + - ", exp " + to_str(exp.avg(), 2) + "/" + to_str(exp.num()) + - ", rave " + to_str(rave.avg(), 2) + "/" + to_str(rave.num()) + - ", know " + to_str(know) + - ", outcome " + to_str((int)outcome) + "/" + to_str((int)proofdepth) + - ", best " + bestmove.to_s() + - ", children " + to_str(children.num()); - } + std::string to_s() const ; + bool from_s(std::string s); unsigned int size() const { unsigned int num = children.num(); @@ -142,16 +135,16 @@ class AgentMCTS : public Agent{ class AgentThread : public AgentThreadBase { mutable XORShift_float unitrand; - LastGoodReply last_good_reply; - RandomPolicy random_policy; - ProtectBridge protect_bridge; - InstantWin instant_wins; + LastGoodReply last_good_reply; + RandomPolicy random_policy; + ProtectBridge protect_bridge; + InstantWin instant_wins; bool use_rave; //whether to use rave for this simulation bool use_explore; //whether to use exploration for this simulation LBDists dists; //holds the distances to the various non-ring wins as a heuristic for the minimum moves needed to win - MoveList movelist; + MoveList movelist; int stage; //which of the four MCTS stages is it on public: @@ -179,11 +172,11 @@ class AgentMCTS : public Agent{ void walk_tree(Board & board, Node * node, int depth); bool create_children(const Board & board, Node * node); void add_knowledge(const Board & board, Node * node, Node * child); - Node * choose_move(const Node * node, int toplay, int remain) const; - void update_rave(const Node * node, int toplay); + Node * choose_move(const Node * node, Side toplay, int remain) const; + void update_rave(const Node * node, Side toplay); bool test_bridge_probe(const Board & board, const Move & move, const Move & test) const; - int rollout(Board & board, Move move, int depth); + Outcome rollout(Board & board, Move move, int depth); Move rollout_choose_move(Board & board, const Move & prev); Move rollout_pattern(const Board & board, const Move & move); }; @@ -261,12 +254,12 @@ class AgentMCTS : public Agent{ Move return_move(int verbose) const { return return_move(& root, rootboard.toplay(), verbose); } double gamelen() const; - vector get_pv() const; - string move_stats(const vector moves) const; + vecmove get_pv() const; + std::string move_stats(const vecmove moves) const; bool done() { //solved or finished runs - return (rootboard.won() >= 0 || root.outcome >= 0 || (maxruns > 0 && runs >= maxruns)); + return (rootboard.won() >= Outcome::DRAW || root.outcome >= Outcome::DRAW || (maxruns > 0 && runs >= maxruns)); } bool need_gc() { @@ -292,16 +285,28 @@ class AgentMCTS : public Agent{ gclimit = (int)(gclimit*0.9); //slowly decay to a minimum of 5 } + void gen_sgf(SGFPrinter & sgf, int limit) const { + if(limit < 0) + limit = root.exp.num()/1000; + gen_sgf(sgf, limit, root, rootboard.toplay()); + } + + void load_sgf(SGFParser & sgf) { + load_sgf(sgf, rootboard, root); + } protected: void garbage_collect(Board & board, Node * node); //destroys the board, so pass in a copy - bool do_backup(Node * node, Node * backup, int toplay); - Move return_move(const Node * node, int toplay, int verbose = 0) const; + bool do_backup(Node * node, Node * backup, Side toplay); + Move return_move(const Node * node, Side toplay, int verbose = 0) const; Node * find_child(const Node * node, const Move & move) const ; void create_children_simple(const Board & board, Node * node); - void gen_hgf(Board & board, Node * node, unsigned int limit, unsigned int depth, FILE * fd); - void load_hgf(Board board, Node * node, FILE * fd); + void gen_sgf(SGFPrinter & sgf, unsigned int limit, const Node & node, Side side) const ; + void load_sgf(SGFParser & sgf, const Board & board, Node & node); }; + +}; // namespace Rex +}; // namespace Morat diff --git a/rex/agentmcts_test.cpp b/rex/agentmcts_test.cpp new file mode 100644 index 0000000..8075e46 --- /dev/null +++ b/rex/agentmcts_test.cpp @@ -0,0 +1,16 @@ + +#include "../lib/catch.hpp" + +#include "agentmcts.h" + + +using namespace Morat; +using namespace Rex; + +TEST_CASE("Rex::AgentMCTS::Node::to_s/from_s", "[rex][agentmcts]") { + AgentMCTS::Node n(Move("a1")); + auto s = n.to_s(); + AgentMCTS::Node k; + REQUIRE(k.from_s(s)); + REQUIRE(n.to_s() == k.to_s()); +} diff --git a/rex/agentmctsthread.cpp b/rex/agentmctsthread.cpp index 8231d5f..5e9defd 100644 --- a/rex/agentmctsthread.cpp +++ b/rex/agentmctsthread.cpp @@ -6,6 +6,10 @@ #include "agentmcts.h" + +namespace Morat { +namespace Rex { + void AgentMCTS::AgentThread::iterate(){ INCR(agent->runs); if(agent->profile){ @@ -19,7 +23,7 @@ void AgentMCTS::AgentThread::iterate(){ use_rave = (unitrand() < agent->userave); use_explore = (unitrand() < agent->useexplore); walk_tree(copy, & agent->root, 0); - agent->root.exp.addv(movelist.getexp(3-agent->rootboard.toplay())); + agent->root.exp.addv(movelist.getexp(~agent->rootboard.toplay())); if(agent->profile){ times[0] += timestamps[1] - timestamps[0]; @@ -30,16 +34,16 @@ void AgentMCTS::AgentThread::iterate(){ } void AgentMCTS::AgentThread::walk_tree(Board & board, Node * node, int depth){ - int toplay = board.toplay(); + Side toplay = board.toplay(); - if(!node->children.empty() && node->outcome < 0){ + if(!node->children.empty() && node->outcome < Outcome::DRAW){ //choose a child and recurse Node * child; do{ int remain = board.movesremain(); child = choose_move(node, toplay, remain); - if(child->outcome < 0){ + if(child->outcome < Outcome::DRAW){ movelist.addtree(child->move, toplay); if(!board.move(child->move)){ @@ -71,10 +75,10 @@ void AgentMCTS::AgentThread::walk_tree(Board & board, Node * node, int depth){ timestamps[1] = Time(); } - int won = (agent->minimax ? node->outcome : board.won()); + Outcome won = (agent->minimax ? node->outcome : board.won()); //if it's not already decided - if(won < 0){ + if(won < Outcome::DRAW){ //create children if valid if(node->exp.num() >= agent->visitexpand+1 && create_children(board, node)){ walk_tree(board, node, depth); @@ -125,6 +129,8 @@ bool AgentMCTS::AgentThread::create_children(const Board & board, Node * node){ CompactTree::Children temp; temp.alloc(board.movesremain(), agent->ctmem); + Side toplay = board.toplay(); + Side opponent = ~toplay; int losses = 0; Node * child = temp.begin(), @@ -136,14 +142,14 @@ bool AgentMCTS::AgentThread::create_children(const Board & board, Node * node){ *child = Node(*move); if(agent->minimax){ - child->outcome = board.test_win(*move); + child->outcome = board.test_outcome(*move); - if(agent->minimax >= 2 && board.test_win(*move, 3 - board.toplay()) > 0){ + if(agent->minimax >= 2 && board.test_outcome(*move, opponent) == +opponent){ losses++; loss = child; } - if(child->outcome == board.toplay()){ //proven win from here, don't need children + if(child->outcome == +toplay){ //proven win from here, don't need children node->outcome = child->outcome; node->proofdepth = 1; node->bestmove = *move; @@ -171,7 +177,7 @@ bool AgentMCTS::AgentThread::create_children(const Board & board, Node * node){ macro.exp.addwins(agent->visitexpand); *(temp.begin()) = macro; }else if(losses >= 2){ //proven loss, but at least try to block one of them - node->outcome = 3 - board.toplay(); + node->outcome = +opponent; node->proofdepth = 2; node->bestmove = loss->move; node->children.unlock(); @@ -180,7 +186,7 @@ bool AgentMCTS::AgentThread::create_children(const Board & board, Node * node){ } if(agent->dynwiden > 0) //sort in decreasing order by knowledge - sort(temp.begin(), temp.end(), sort_node_know); + std::sort(temp.begin(), temp.end(), sort_node_know); PLUS(agent->nodes, temp.num()); node->children.swap(temp); @@ -189,7 +195,7 @@ bool AgentMCTS::AgentThread::create_children(const Board & board, Node * node){ return true; } -AgentMCTS::Node * AgentMCTS::AgentThread::choose_move(const Node * node, int toplay, int remain) const { +AgentMCTS::Node * AgentMCTS::AgentThread::choose_move(const Node * node, Side toplay, int remain) const { float val, maxval = -1000000000; float logvisits = log(node->exp.num()); int dynwidenlim = (agent->dynwiden > 0 ? (int)(logvisits/agent->logdynwiden)+2 : Board::max_vecsize); @@ -204,11 +210,11 @@ AgentMCTS::Node * AgentMCTS::AgentThread::choose_move(const Node * node, int top * end = node->children.end(); for(; child != end && dynwidenlim >= 0; child++){ - if(child->outcome >= 0){ + if(child->outcome >= Outcome::DRAW){ if(child->outcome == toplay) //return a win immediately return child; - val = (child->outcome == 0 ? -1 : -2); //-1 for tie so any unknown is better, -2 for loss so it's even worse + val = (child->outcome == Outcome::DRAW ? -1 : -2); //-1 for tie so any unknown is better, -2 for loss so it's even worse }else{ val = child->value(raveval, agent->knowledge, agent->fpurgency); if(explore > 0) @@ -237,80 +243,80 @@ backup in this order: 0 lose return true if fully solved, false if it's unknown or partially unknown */ -bool AgentMCTS::do_backup(Node * node, Node * backup, int toplay){ - int nodeoutcome = node->outcome; - if(nodeoutcome >= 0) //already proven, probably by a different thread +bool AgentMCTS::do_backup(Node * node, Node * backup, Side toplay){ + Outcome node_outcome = node->outcome; + if(node_outcome >= Outcome::DRAW) //already proven, probably by a different thread return true; - if(backup->outcome == -3) //nothing proven by this child, so no chance + if(backup->outcome == Outcome::UNKNOWN) //nothing proven by this child, so no chance return false; uint8_t proofdepth = backup->proofdepth; if(backup->outcome != toplay){ - uint64_t sims = 0, bestsims = 0, outcome = 0, bestoutcome = 0; + uint64_t sims = 0, bestsims = 0, outcome = 0, best_outcome = 0; backup = NULL; Node * child = node->children.begin(), * end = node->children.end(); for( ; child != end; child++){ - int childoutcome = child->outcome; //save a copy to avoid race conditions + Outcome child_outcome = child->outcome; //save a copy to avoid race conditions if(proofdepth < child->proofdepth+1) proofdepth = child->proofdepth+1; //these should be sorted in likelyness of matching, most likely first - if(childoutcome == -3){ // win/draw/loss + if(child_outcome == Outcome::UNKNOWN){ // win/draw/loss outcome = 3; - }else if(childoutcome == toplay){ //win + }else if(child_outcome == toplay){ //win backup = child; outcome = 6; proofdepth = child->proofdepth+1; break; - }else if(childoutcome == 3-toplay){ //loss + }else if(child_outcome == ~toplay){ //loss outcome = 0; - }else if(childoutcome == 0){ //draw - if(nodeoutcome == toplay-3) //draw/loss + }else if(child_outcome == Outcome::DRAW){ //draw + if(node_outcome == -toplay) //draw/loss, ie I can't win outcome = 4; else outcome = 2; - }else if(childoutcome == -toplay){ //win/draw + }else if(child_outcome == -~toplay){ //win/draw, ie opponent can't win outcome = 5; - }else if(childoutcome == toplay-3){ //draw/loss + }else if(child_outcome == -toplay){ //draw/loss, ie I can't win outcome = 1; }else{ - logerr("childoutcome == " + to_str(childoutcome) + "\n"); + logerr("child_outcome == " + child_outcome.to_s() + "\n"); assert(false && "How'd I get here? All outcomes should be tested above"); } sims = child->exp.num(); - if(bestoutcome < outcome){ //better outcome is always preferable - bestoutcome = outcome; + if(best_outcome < outcome){ //better outcome is always preferable + best_outcome = outcome; bestsims = sims; backup = child; - }else if(bestoutcome == outcome && ((outcome == 0 && bestsims < sims) || bestsims > sims)){ + }else if(best_outcome == outcome && ((outcome == 0 && bestsims < sims) || bestsims > sims)){ //find long losses or easy wins/draws bestsims = sims; backup = child; } } - if(bestoutcome == 3) //no win, but found an unknown + if(best_outcome == 3) //no win, but found an unknown return false; } - if(CAS(node->outcome, nodeoutcome, backup->outcome)){ + if(node->outcome.cas(node_outcome, backup->outcome)){ node->bestmove = backup->move; node->proofdepth = proofdepth; }else //if it was in a race, try again, might promote a partial solve to full solve return do_backup(node, backup, toplay); - return (node->outcome >= 0); + return (node->outcome >= Outcome::DRAW); } //update the rave score of all children that were played -void AgentMCTS::AgentThread::update_rave(const Node * node, int toplay){ +void AgentMCTS::AgentThread::update_rave(const Node * node, Side toplay){ Node * child = node->children.begin(), * childend = node->children.end(); @@ -321,7 +327,7 @@ void AgentMCTS::AgentThread::update_rave(const Node * node, int toplay){ void AgentMCTS::AgentThread::add_knowledge(const Board & board, Node * node, Node * child){ if(agent->localreply){ //boost for moves near the previous move - int dist = node->move.dist(child->move); + int dist = board.dist(node->move, child->move); if(dist < 4) child->know += agent->localreply * (4 - dist); } @@ -343,24 +349,24 @@ void AgentMCTS::AgentThread::add_knowledge(const Board & board, Node * node, Nod child->know += agent->bridge; if(agent->dists) - child->know += abs(agent->dists) * max(0, board.get_size() - dists.get(child->move, board.toplay())); + child->know += abs(agent->dists) * std::max(0, board.get_size() - dists.get(child->move, board.toplay())); } //test whether this move is a forced reply to the opponent probing your virtual connections bool AgentMCTS::AgentThread::test_bridge_probe(const Board & board, const Move & move, const Move & test) const { //TODO: switch to the same method as policy_bridge.h, maybe even share code - if(move.dist(test) != 1) + if(board.dist(move, test) != 1) return false; bool equals = false; int state = 0; - int piece = 3 - board.get(move); + Side piece = ~board.get(move); for(int i = 0; i < 8; i++){ Move cur = move + neighbours[i % 6]; bool on = board.onboard(cur); - int v = 0; + Side v = Side::NONE; if(on) v = board.get(cur); @@ -371,7 +377,7 @@ bool AgentMCTS::AgentThread::test_bridge_probe(const Board & board, const Move & //else state = 0; }else if(state == 1){ if(on){ - if(v == 0){ + if(v == Side::NONE){ state = 2; equals = (test == cur); }else if(v != piece) @@ -396,16 +402,16 @@ bool AgentMCTS::AgentThread::test_bridge_probe(const Board & board, const Move & //play a random game starting from a board state, and return the results of who won -int AgentMCTS::AgentThread::rollout(Board & board, Move move, int depth){ - int won; +Outcome AgentMCTS::AgentThread::rollout(Board & board, Move move, int depth){ + Outcome won; if(agent->instantwin) instant_wins.rollout_start(board, agent->instantwin); random_policy.rollout_start(board); - while((won = board.won()) < 0){ - int turn = board.toplay(); + while((won = board.won()) < Outcome::DRAW){ + Side turn = board.toplay(); move = rollout_choose_move(board, move); @@ -449,3 +455,6 @@ Move AgentMCTS::AgentThread::rollout_choose_move(Board & board, const Move & pre return random_policy.choose_move(board, prev); } + +}; // namespace Rex +}; // namespace Morat diff --git a/rex/agentpns.cpp b/rex/agentpns.cpp index ec270ff..1fdcb23 100644 --- a/rex/agentpns.cpp +++ b/rex/agentpns.cpp @@ -5,6 +5,40 @@ #include "agentpns.h" + +namespace Morat { +namespace Rex { + +std::string AgentPNS::Node::to_s() const { + return "AgentPNS::Node" + ", move " + move.to_s() + + ", phi " + to_str(phi) + + ", delta " + to_str(delta) + + ", work " + to_str(work) + + ", children " + to_str(children.num()); +} + +bool AgentPNS::Node::from_s(std::string s) { + auto dict = parse_dict(s, ", ", " "); + + if(dict.size() == 6){ + move = Move(dict["move"]); + phi = from_str(dict["phi"]); + delta = from_str(dict["delta"]); + work = from_str(dict["work"]); + // ignore children + return true; + } + return false; +} + +void AgentPNS::test() { + Node n(Move("a1")); + auto s = n.to_s(); + Node k; + assert(k.from_s(s)); +} + void AgentPNS::search(double time, uint64_t maxiters, int verbose){ max_nodes_seen = maxiters; @@ -32,27 +66,20 @@ void AgentPNS::search(double time, uint64_t maxiters, int verbose){ logerr("Tree depth: " + treelen.to_s() + "\n"); } - int toplay = rootboard.toplay(); + Side toplay = rootboard.toplay(); logerr("Root: " + root.to_s() + "\n"); - int outcome = root.to_outcome(3-toplay); - if(outcome != -3){ - logerr("Solved as a "); - if( outcome == 0) logerr("draw\n"); - else if(outcome == 3) logerr("draw by simultaneous win\n"); - else if(outcome == toplay) logerr("win\n"); - else if(outcome == 3-toplay) logerr("loss\n"); - else if(outcome == -toplay) logerr("win or draw\n"); - else if(outcome == toplay-3) logerr("loss or draw\n"); - } + Outcome outcome = root.to_outcome(~toplay); + if(outcome != Outcome::UNKNOWN) + logerr("Solved as a " + outcome.to_s_rel(toplay) + "\n"); - string pvstr; + std::string pvstr; for(auto m : get_pv()) pvstr += " " + m.to_s(); logerr("PV: " + pvstr + "\n"); if(verbose >= 3 && !root.children.empty()) - logerr("Move stats:\n" + move_stats(vector())); + logerr("Move stats:\n" + move_stats(vecmove())); } } @@ -83,8 +110,8 @@ bool AgentPNS::AgentThread::pns(const Board & board, Node * node, int depth, uin unsigned int i = 0; for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - unsigned int pd = 1; - int outcome; + unsigned int pd; + Outcome outcome; if(agent->ab){ Board next = board; @@ -94,10 +121,10 @@ bool AgentPNS::AgentThread::pns(const Board & board, Node * node, int depth, uin outcome = (agent->ab == 1 ? solve1ply(next, pd) : solve2ply(next, pd)); }else{ pd = 1; - outcome = board.test_win(*move); + outcome = board.test_outcome(*move); } - if(agent->lbdist && outcome < 0) + if(agent->lbdist && outcome != Outcome::UNKNOWN) pd = dists.get(*move); temp[i] = Node(*move).outcome(outcome, board.toplay(), agent->ties, pd); @@ -132,8 +159,8 @@ bool AgentPNS::AgentThread::pns(const Board & board, Node * node, int depth, uin } } - tpc = min(INF32/2, (td + child->phi - node->delta)); - tdc = min(tp, (uint32_t)(child2->delta*(1.0 + agent->epsilon) + 1)); + tpc = std::min(INF32/2, (td + child->phi - node->delta)); + tdc = std::min(tp, (uint32_t)(child2->delta*(1.0 + agent->epsilon) + 1)); }else{ tpc = tdc = 0; for(auto & i : node->children) @@ -198,16 +225,16 @@ double AgentPNS::gamelen() const { return rootboard.movesremain(); } -vector AgentPNS::get_pv() const { - vector pv; +std::vector AgentPNS::get_pv() const { + vecmove pv; const Node * n = & root; - char turn = rootboard.toplay(); + Side turn = rootboard.toplay(); while(n && !n->children.empty()){ Move m = return_move(n, turn); pv.push_back(m); n = find_child(n, m); - turn = 3 - turn; + turn = ~turn; } if(pv.size() == 0) @@ -216,8 +243,8 @@ vector AgentPNS::get_pv() const { return pv; } -string AgentPNS::move_stats(vector moves) const { - string s = ""; +std::string AgentPNS::move_stats(vecmove moves) const { + std::string s = ""; const Node * node = & root; if(moves.size()){ @@ -238,7 +265,7 @@ string AgentPNS::move_stats(vector moves) const { return s; } -Move AgentPNS::return_move(const Node * node, int toplay, int verbose) const { +Move AgentPNS::return_move(const Node * node, Side toplay, int verbose) const { double val, maxval = -1000000000000.0; //1 trillion Node * ret = NULL, @@ -246,11 +273,11 @@ Move AgentPNS::return_move(const Node * node, int toplay, int verbose) const { * end = node->children.end(); for( ; child != end; child++){ - int outcome = child->to_outcome(toplay); - if(outcome >= 0){ - if(outcome == toplay) val = 800000000000.0 - (double)child->work; //shortest win - else if(outcome == 0) val = -400000000000.0 + (double)child->work; //longest tie - else val = -800000000000.0 + (double)child->work; //longest loss + Outcome outcome = child->to_outcome(toplay); + if(outcome >= Outcome::DRAW){ + if( outcome == +toplay) val = 800000000000.0 - (double)child->work; //shortest win + else if(outcome == Outcome::DRAW) val = -400000000000.0 + (double)child->work; //longest tie + else val = -800000000000.0 + (double)child->work; //longest loss }else{ //not proven val = child->work; } @@ -290,3 +317,51 @@ void AgentPNS::garbage_collect(Node * node){ } } } + +void AgentPNS::create_children_simple(const Board & board, Node * node){ + assert(node->children.empty()); + node->children.alloc(board.movesremain(), ctmem); + unsigned int i = 0; + for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ + Outcome outcome = board.test_outcome(*move); + node->children[i] = Node(*move).outcome(outcome, board.toplay(), ties, 1); + i++; + } + PLUS(nodes, i); + node->children.shrink(i); //if symmetry, there may be extra moves to ignore +} + +void AgentPNS::gen_sgf(SGFPrinter & sgf, unsigned int limit, const Node & node, Side side) const { + for(auto & child : node.children){ + if(child.work >= limit && (side != node.to_outcome(~side) || child.to_outcome(side) == node.to_outcome(~side))){ + sgf.child_start(); + sgf.move(side, child.move); + sgf.comment(child.to_s()); + gen_sgf(sgf, limit, child, ~side); + sgf.child_end(); + } + } +} + +void AgentPNS::load_sgf(SGFParser & sgf, const Board & board, Node & node) { + assert(sgf.has_children()); + create_children_simple(board, &node); + + while(sgf.next_child()){ + Move m = sgf.move(); + Node & child = *find_child(&node, m); + child.from_s(sgf.comment()); + if(sgf.done_child()){ + continue; + }else{ + // has children! + Board b = board; + b.move(m); + load_sgf(sgf, b, child); + assert(sgf.done_child()); + } + } +} + +}; // namespace Rex +}; // namespace Morat diff --git a/rex/agentpns.h b/rex/agentpns.h index ad33042..48e38cc 100644 --- a/rex/agentpns.h +++ b/rex/agentpns.h @@ -3,15 +3,21 @@ //A multi-threaded, tree based, proof number search solver. +#include + #include "../lib/agentpool.h" #include "../lib/compacttree.h" #include "../lib/depthstats.h" #include "../lib/log.h" +#include "../lib/string.h" #include "agent.h" #include "lbdist.h" +namespace Morat { +namespace Rex { + class AgentPNS : public Agent { static const uint32_t LOSS = (1<<30)-1; static const uint32_t DRAW = (1<<30)-2; @@ -51,33 +57,33 @@ class AgentPNS : public Agent { assert(children.empty()); } - Node & abval(int outcome, int toplay, int assign, int value = 1){ - if(assign && (outcome == 1 || outcome == -1)) - outcome = (toplay == assign ? 2 : -2); + Node & abval(int ab_outcome, Side toplay, Side assign, int value = 1){ + if(assign != Side::NONE && (ab_outcome == 1 || ab_outcome == -1)) + ab_outcome = (toplay == assign ? 2 : -2); - if( outcome == 0) { phi = value; delta = value; } - else if(outcome == 2) { phi = LOSS; delta = 0; } - else if(outcome == -2) { phi = 0; delta = LOSS; } - else /*(outcome 1||-1)*/ { phi = 0; delta = DRAW; } + if( ab_outcome == 0) { phi = value; delta = value; } + else if(ab_outcome == 2) { phi = LOSS; delta = 0; } + else if(ab_outcome == -2) { phi = 0; delta = LOSS; } + else /*(ab_outcome 1||-1)*/ { phi = 0; delta = DRAW; } return *this; } - Node & outcome(int outcome, int toplay, int assign, int value = 1){ - if(assign && outcome == 0) - outcome = assign; + Node & outcome(Outcome outcome, Side toplay, Side assign, int value = 1){ + if(assign != Side::NONE && outcome == Outcome::DRAW) + outcome = +assign; - if( outcome == -3) { phi = value; delta = value; } - else if(outcome == toplay) { phi = LOSS; delta = 0; } - else if(outcome == 3-toplay) { phi = 0; delta = LOSS; } - else /*(outcome == 0)*/ { phi = 0; delta = DRAW; } + if( outcome == Outcome::UNKNOWN) { phi = value; delta = value; } + else if(outcome == +toplay) { phi = LOSS; delta = 0; } + else if(outcome == +~toplay) { phi = 0; delta = LOSS; } + else /*(outcome == Outcome::DRAW)*/ { phi = 0; delta = DRAW; } return *this; } - int to_outcome(int toplay) const { - if(phi == LOSS) return toplay; - if(delta == LOSS) return 3 - toplay; - if(delta == DRAW) return 0; - return -3; + Outcome to_outcome(Side toplay) const { + if(phi == LOSS) return +toplay; + if(delta == LOSS) return +~toplay; + if(delta == DRAW) return Outcome::DRAW; + return Outcome::UNKNOWN; } bool terminal(){ return (phi == 0 || delta == 0); } @@ -98,15 +104,8 @@ class AgentPNS : public Agent { return num; } - string to_s() const { - return "Node: move " + move.to_s() + - ", phi " + to_str(phi) + - ", delta " + to_str(delta) + - ", work " + to_str(work) + -// ", outcome " + to_str((int)outcome) + "/" + to_str((int)proofdepth) + -// ", best " + bestmove.to_s() + - ", children " + to_str(children.num()); - } + std::string to_s() const ; + bool from_s(std::string s); void swap_tree(Node & n){ children.swap(n.children); @@ -162,7 +161,7 @@ class AgentPNS : public Agent { int ab; // how deep of an alpha-beta search to run at each leaf node bool df; // go depth first? float epsilon; //if depth first, how wide should the threshold be? - int ties; //which player to assign ties to: 0 handle ties, 1 assign p1, 2 assign p2 + Side ties; //which player to assign ties to: 0 handle ties, 1 assign p1, 2 assign p2 bool lbdist; int numthreads; @@ -172,7 +171,7 @@ class AgentPNS : public Agent { ab = 2; df = true; epsilon = 0.25; - ties = 0; + ties = Side::NONE; lbdist = false; numthreads = 1; pool.set_num_threads(numthreads); @@ -228,7 +227,7 @@ class AgentPNS : public Agent { root.swap_tree(child); if(nodesbefore > 0) - logerr(string("PNS Nodes before: ") + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); + logerr(std::string("PNS Nodes before: ") + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); assert(nodes == root.size()); @@ -280,12 +279,36 @@ class AgentPNS : public Agent { void search(double time, uint64_t maxiters, int verbose); Move return_move(int verbose) const { return return_move(& root, rootboard.toplay(), verbose); } double gamelen() const; - vector get_pv() const; - string move_stats(const vector moves) const; + vecmove get_pv() const; + std::string move_stats(const vecmove moves) const; + + void gen_sgf(SGFPrinter & sgf, int limit) const { + if(limit < 0){ + limit = 0; + //TODO: Set the root.work properly + for(auto & child : root.children) + limit += child.work; + limit /= 1000; + } + gen_sgf(sgf, limit, root, rootboard.toplay()); + } + + void load_sgf(SGFParser & sgf) { + load_sgf(sgf, rootboard, root); + } + + static void test(); private: //remove all the nodes with little work to free up some memory void garbage_collect(Node * node); - Move return_move(const Node * node, int toplay, int verbose = 0) const; + Move return_move(const Node * node, Side toplay, int verbose = 0) const; Node * find_child(const Node * node, const Move & move) const ; + void create_children_simple(const Board & board, Node * node); + + void gen_sgf(SGFPrinter & sgf, unsigned int limit, const Node & node, Side side) const; + void load_sgf(SGFParser & sgf, const Board & board, Node & node); }; + +}; // namespace Rex +}; // namespace Morat diff --git a/rex/agentpns_test.cpp b/rex/agentpns_test.cpp new file mode 100644 index 0000000..6b245f1 --- /dev/null +++ b/rex/agentpns_test.cpp @@ -0,0 +1,16 @@ + +#include "../lib/catch.hpp" + +#include "agentpns.h" + + +using namespace Morat; +using namespace Rex; + +TEST_CASE("Rex::AgentPNS::Node::to_s/from_s", "[rex][agentpns]") { + AgentPNS::Node n(Move("a1")); + auto s = n.to_s(); + AgentPNS::Node k; + REQUIRE(k.from_s(s)); + REQUIRE(n.to_s() == k.to_s()); +} diff --git a/rex/board.cpp b/rex/board.cpp new file mode 100644 index 0000000..e49615c --- /dev/null +++ b/rex/board.cpp @@ -0,0 +1,73 @@ + +#include "board.h" + +namespace Morat { +namespace Rex { + +std::string Board::Cell::to_s(int i) const { + return "Cell " + to_str(i) +": " + "piece: " + to_str(piece.to_i())+ + ", size: " + to_str((int)size) + + ", parent: " + to_str((int)parent) + + ", edge: " + to_str((int)edge) + "/" + to_str(numedges()) + + ", perm: " + to_str((int)perm) + + ", pattern: " + to_str((int)pattern); +} + +std::string Board::to_s(bool color) const { + using std::string; + string white = "O", + black = "@", + empty = ".", + coord = "", + reset = ""; + if(color){ + string esc = "\033"; + reset = esc + "[0m"; + coord = esc + "[1;37m"; + empty = reset + "."; + white = esc + "[1;33m" + "@"; //yellow + black = esc + "[1;34m" + "@"; //blue + } + + string s; + for(int i = 0; i < size; i++) + s += " " + coord + to_str(i+1); + s += "\n"; + + for(int y = 0; y < size; y++){ + s += string(y, ' '); + s += coord + char('A' + y); + int end = lineend(y); + for(int x = 0; x < end; x++){ + s += (last == Move(x, y) ? coord + "[" : + last == Move(x-1, y) ? coord + "]" : " "); + Side p = get(x, y); + if( p == Side::NONE) s += empty; + else if(p == Side::P1) s += white; + else if(p == Side::P2) s += black; + else s += "?"; + } + s += (last == Move(end-1, y) ? coord + "]" : " "); + s += white + reset; + s += '\n'; + } + s += string(size + 2, ' '); + for(int i = 0; i < size; i++) + s += black + " "; + s += "\n"; + + s += reset; + return s; +} + +int Board::edges(int x, int y) const { + return (x == 0 ? 1 : 0) | + (x == sizem1 ? 2 : 0) | + (y == 0 ? 4 : 0) | + (y == sizem1 ? 8 : 0); +} + + +}; // namespace Rex +}; // namespace Morat diff --git a/rex/board.h b/rex/board.h index ce5aa6c..636f75b 100644 --- a/rex/board.h +++ b/rex/board.h @@ -4,17 +4,21 @@ #include #include #include +#include #include #include +#include "../lib/bitcount.h" #include "../lib/hashset.h" +#include "../lib/move.h" +#include "../lib/outcome.h" #include "../lib/string.h" #include "../lib/types.h" #include "../lib/zobrist.h" -#include "move.h" -using namespace std; +namespace Morat { +namespace Rex { /* * the board is represented as a flattened 2d array of the form: @@ -48,35 +52,31 @@ static MoveValid * staticneighbourlist[17] = { class Board{ public: + static constexpr const char * const name = "rex"; static const int default_size = 8; static const int min_size = 3; static const int max_size = 16; static const int max_vecsize = max_size * max_size; + static const int num_win_types = 1; static const int pattern_cells = 18; typedef uint64_t Pattern; struct Cell { - uint16_t piece; //who controls this cell, 0 for none, 1,2 for players + Side piece; //who controls this cell, 0 for none, 1,2 for players uint16_t size; //size of this group of cells -mutable uint16_t parent; //parent for this group of cells. 8 bits limits board size to 16 until it's no longer stored as a square +mutable uint16_t parent; //parent for this group of cells uint8_t edge; //which edges are this group connected to uint8_t perm; //is this a permanent piece or a randomly placed piece? Pattern pattern; //the pattern of pieces for neighbours, but from their perspective. Rotate 180 for my perpective - Cell() : piece(73), size(0), parent(0), edge(0), perm(0), pattern(0) { } - Cell(unsigned int p, unsigned int a, unsigned int s, unsigned int e, Pattern t) : + Cell() : piece(Side::NONE), size(0), parent(0), edge(0), perm(0), pattern(0) { } + Cell(Side p, unsigned int a, unsigned int s, unsigned int e, Pattern t) : piece(p), size(s), parent(a), edge(e), perm(0), pattern(t) { } - string to_s(int i) const { - return "Cell " + to_str((int)i) +": " - "piece: " + to_str((int)piece)+ - ", size: " + to_str((int)size) + - ", parent: " + to_str((int)parent) + - ", edge: " + to_str((int)edge) + - ", perm: " + to_str((int)perm) + - ", pattern: " + to_str((int)pattern); - } + int numedges() const { return BitsSetTable256[edge]; } + + std::string to_s(int i) const; }; class MoveIterator { //only returns valid moves... @@ -87,7 +87,7 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board HashSet hashes; public: MoveIterator(const Board & b, bool Unique) : board(b), lineend(0), move(Move(M_SWAP), -1), unique(Unique) { - if(board.outcome >= 0){ + if(board.outcome >= Outcome::DRAW){ move = MoveValid(0, board.size, -1); //already done } else { if(unique) @@ -113,9 +113,8 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board move.xy = -1; return *this; } - - move.x = 0; - move.xy = move.y * board.size; + move.x = board.linestart(move.y); + move.xy = board.xy(move.x, move.y); lineend = board.lineend(move.y); } }while(!board.valid_move_fast(move)); @@ -139,10 +138,10 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board short nummoves; short unique_depth; //update and test rotations/symmetry with less than this many pieces on the board Move last; - char toPlay; - char outcome; //-3 = unknown, 0 = tie, 1,2 = player win + Side toPlay; + Outcome outcome; - vector cells; + std::vector cells; Zobrist<6> hash; const MoveValid * neighbourlist; @@ -157,15 +156,15 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board last = M_NONE; nummoves = 0; unique_depth = 5; - toPlay = 1; - outcome = -3; + toPlay = Side::P1; + outcome = Outcome::UNKNOWN; neighbourlist = get_neighbour_list(); num_cells = vecsize(); cells.resize(vecsize()); for(int y = 0; y < size; y++){ - for(int x = 0; x < lineend(y); x++){ + for(int x = 0; x < size; x++){ int posxy = xy(x, y); Pattern p = 0, j = 3; for(const MoveValid * i = nb_begin(posxy), *e = nb_end_big_hood(i); i < e; i++){ @@ -173,7 +172,8 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board p |= j; j <<= 2; } - cells[posxy] = Cell(0, posxy, 1, edges(x, y), pattern_reverse(p)); + Side s = (onboard(x, y) ? Side::NONE : Side::UNDEF); + cells[posxy] = Cell(s, posxy, 1, edges(x, y), pattern_reverse(p)); } } } @@ -190,7 +190,7 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board int numcells() const { return num_cells; } int num_moves() const { return nummoves; } - int movesremain() const { return (won() >= 0 ? 0 : num_cells - nummoves); } + int movesremain() const { return (won() >= Outcome::DRAW ? 0 : num_cells - nummoves); } int xy(int x, int y) const { return y*size + x; } int xy(const Move & m) const { return m.y*size + m.x; } @@ -198,6 +198,10 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board MoveValid yx(int i) const { return MoveValid(i % size, i / size, i); } + int dist(const Move & a, const Move & b) const { + return (abs(a.x - b.x) + abs(a.y - b.y) + abs((a.x + a.y) - (b.x + b.y)) )/2; + } + const Cell * cell(int i) const { return & cells[i]; } const Cell * cell(int x, int y) const { return cell(xy(x,y)); } const Cell * cell(const Move & m) const { return cell(xy(m)); } @@ -205,40 +209,41 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board //assumes valid x,y - int get(int i) const { return cells[i].piece; } - int get(int x, int y) const { return get(xy(x, y)); } - int get(const Move & m) const { return get(xy(m)); } - int get(const MoveValid & m) const { return get(m.xy); } + Side get(int i) const { return cells[i].piece; } + Side get(int x, int y) const { return get(xy(x, y)); } + Side get(const Move & m) const { return get(xy(m)); } + Side get(const MoveValid & m) const { return get(m.xy); } - int geton(const MoveValid & m) const { return (m.onboard() ? get(m.xy) : 0); } + Side geton(const MoveValid & m) const { return (m.onboard() ? get(m.xy) : Side::UNDEF); } - int local(const Move & m, char turn) const { return local(xy(m), turn); } - int local(int i, char turn) const { + int local(const Move & m, Side turn) const { return local(xy(m), turn); } + int local(int i, Side turn) const { Pattern p = pattern(i); Pattern x = ((p & 0xAAAAAAAAAull) >> 1) ^ (p & 0x555555555ull); // p1 is now when p1 or p2 but not both (ie off the board) - p = x & (turn == 1 ? p : p >> 1); // now just the selected player + p = x & (turn == Side::P1 ? p : p >> 1); // now just the selected player return (p & 0x000000FFF ? 3 : 0) | (p & 0x000FFF000 ? 2 : 0) | (p & 0xFFF000000 ? 1 : 0); } - //assumes x, y are in array bounds - bool onboard_fast(int x, int y) const { return ( y < size && x < size); } - bool onboard_fast(const Move & m) const { return (m.y < size && m.x < size); } + //assumes x, y are in array bounds, and all moves within array bounds are valid + bool onboard_fast(int x, int y) const { return true; } + bool onboard_fast(const Move & m) const { return true; } //checks array bounds too - bool onboard(int x, int y) const { return ( x >= 0 && y >= 0 && onboard_fast(x, y) ); } - bool onboard(const Move & m)const { return (m.x >= 0 && m.y >= 0 && onboard_fast(m) ); } + bool onboard(int x, int y) const { return ( x >= 0 && y >= 0 && x < size && y < size && onboard_fast(x, y) ); } + bool onboard(const Move & m)const { return (m.x >= 0 && m.y >= 0 && m.x < size && m.y < size && onboard_fast(m) ); } bool onboard(const MoveValid & m) const { return m.onboard(); } //assumes x, y are in bounds and the game isn't already finished - bool valid_move_fast(int x, int y) const { return !get(x,y); } - bool valid_move_fast(const Move & m) const { return !get(m); } - bool valid_move_fast(const MoveValid & m) const { return !get(m.xy); } + bool valid_move_fast(int i) const { return get(i) == Side::NONE; } + bool valid_move_fast(int x, int y) const { return valid_move_fast(xy(x, y)); } + bool valid_move_fast(const Move & m) const { return valid_move_fast(xy(m)); } + bool valid_move_fast(const MoveValid & m) const { return valid_move_fast(m.xy); } //checks array bounds too - bool valid_move(int x, int y) const { return (outcome == -3 && onboard(x, y) && !get(x, y)); } - bool valid_move(const Move & m) const { return (outcome == -3 && onboard(m) && !get(m)); } - bool valid_move(const MoveValid & m) const { return (outcome == -3 && m.onboard() && !get(m)); } + bool valid_move(int x, int y) const { return (outcome < Outcome::DRAW && onboard(x, y) && valid_move_fast(x, y)); } + bool valid_move(const Move & m) const { return (outcome < Outcome::DRAW && onboard(m) && valid_move_fast(m)); } + bool valid_move(const MoveValid & m) const { return (outcome < Outcome::DRAW && m.onboard() && valid_move_fast(m)); } //iterator through neighbours of a position const MoveValid * nb_begin(int x, int y) const { return nb_begin(xy(x, y)); } @@ -252,12 +257,7 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board const MoveValid * nb_end_small_hood(const MoveValid * m) const { return m + 12; } const MoveValid * nb_end_big_hood(const MoveValid * m) const { return m + 18; } - int edges(int x, int y) const { - return (x == 0 ? 1 : 0) | - (x == sizem1 ? 2 : 0) | - (y == 0 ? 4 : 0) | - (y == sizem1 ? 8 : 0); - } + int edges(int x, int y) const; MoveValid * get_neighbour_list(){ if(!staticneighbourlist[(int)size]){ @@ -281,94 +281,24 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board return staticneighbourlist[(int)size]; } - + int linestart(int y) const { return 0; } int lineend(int y) const { return size; } + int linelen(int y) const { return lineend(y) - linestart(y); } - string to_s(bool color) const { - string white = "O", - black = "@", - empty = ".", - coord = "", - reset = ""; - if(color){ - string esc = "\033"; - reset = esc + "[0m"; - coord = esc + "[1;37m"; - empty = reset + "."; - white = esc + "[1;33m" + "@"; //yellow - black = esc + "[1;34m" + "@"; //blue - } - - string s; - for(int i = 0; i < size; i++) - s += " " + coord + to_str(i+1); - s += "\n"; - - for(int y = 0; y < size; y++){ - s += string(y, ' '); - s += coord + char('A' + y); - int end = lineend(y); - for(int x = 0; x < size; x++){ - s += (last == Move(x, y) ? coord + "[" : - last == Move(x-1, y) ? coord + "]" : " "); - int p = get(x, y); - if(p == 0) s += empty; - if(p == 1) s += white; - if(p == 2) s += black; - if(p >= 3) s += "?"; - } - s += (last == Move(end-1, y) ? coord + "]" : " "); - s += white + reset; - s += '\n'; - } - s += string(size + 2, ' '); - for(int i = 0; i < size; i++) - s += black + " "; - s += "\n"; - - s += reset; - return s; - } + std::string to_s(bool color) const; + friend std::ostream& operator<< (std::ostream &out, const Board & b) { return out << b.to_s(true); } void print(bool color = true) const { printf("%s", to_s(color).c_str()); } - string boardstr() const { - string white, black; - for(int y = 0; y < size; y++){ - for(int x = 0; x < lineend(y); x++){ - int p = get(x, y); - if(p == 1) white += Move(x, y).to_s(); - if(p == 2) black += Move(x, y).to_s(); - } - } - return white + ";" + black; - } - - string won_str() const { - switch(outcome){ - case -3: return "none"; - case -2: return "black_or_draw"; - case -1: return "white_or_draw"; - case 0: return "draw"; - case 1: return "white"; - case 2: return "black"; - } - return "unknown"; - } - - char won() const { + Outcome won() const { return outcome; } - int win() const{ // 0 for draw or unknown, 1 for win, -1 for loss - if(outcome <= 0) - return 0; - return (outcome == toplay() ? 1 : -1); - } + char getwintype() const { return outcome > Outcome::DRAW; } - char toplay() const { + Side toplay() const { return toPlay; } @@ -376,22 +306,22 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board return MoveIterator(*this, (unique ? nummoves <= unique_depth : false)); } - void set(const Move & m, bool perm = true){ + void set(const Move & m, bool perm = true) { last = m; Cell * cell = & cells[xy(m)]; cell->piece = toPlay; cell->perm = perm; nummoves++; update_hash(m, toPlay); //depends on nummoves - toPlay = 3 - toPlay; + toPlay = ~toPlay; } - void unset(const Move & m){ //break win checks, but is a poor mans undo if all you care about is the hash - toPlay = 3 - toPlay; + void unset(const Move & m) { //break win checks, but is a poor mans undo if all you care about is the hash + toPlay = ~toPlay; update_hash(m, toPlay); nummoves--; Cell * cell = & cells[xy(m)]; - cell->piece = 0; + cell->piece = Side::NONE; cell->perm = 0; } @@ -421,7 +351,7 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board return true; if(cells[i].size < cells[j].size) //force i's subtree to be bigger - swap(i, j); + std::swap(i, j); cells[j].parent = i; cells[i].size += cells[j].size; @@ -431,7 +361,7 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board } Cell test_cell(const Move & pos) const { - char turn = toplay(); + Side turn = toplay(); int posxy = xy(pos); Cell testcell = cells[find_group(pos)]; @@ -463,7 +393,7 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board return (nummoves > unique_depth ? hash.get(0) : hash.get()); } - string hashstr() const { + std::string hashstr() const { static const char hexlookup[] = "0123456789abcdef"; char buf[19] = "0x"; hash_t val = gethash(); @@ -475,7 +405,8 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board return (char *)buf; } - void update_hash(const Move & pos, int turn){ + void update_hash(const Move & pos, Side side) { + int turn = side.to_i(); if(nummoves > unique_depth){ //simple update, no rotations/symmetry hash.update(0, 3*xy(pos) + turn); return; @@ -498,7 +429,8 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board return test_hash(pos, toplay()); } - hash_t test_hash(const Move & pos, int turn) const { + hash_t test_hash(const Move & pos, Side side) const { + int turn = side.to_i(); if(nummoves >= unique_depth) //simple test, no rotations/symmetry return hash.test(0, 3*xy(pos) + turn); @@ -507,11 +439,11 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board z = sizem1 - x - y; hash_t m = hash.test(0, 3*xy(x, y) + turn); - m = min(m, hash.test(1, 3*xy(z, y) + turn)); - m = min(m, hash.test(2, 3*xy(z, x) + turn)); - m = min(m, hash.test(3, 3*xy(x, z) + turn)); - m = min(m, hash.test(4, 3*xy(y, z) + turn)); - m = min(m, hash.test(5, 3*xy(y, x) + turn)); + m = std::min(m, hash.test(1, 3*xy(z, y) + turn)); + m = std::min(m, hash.test(2, 3*xy(z, x) + turn)); + m = std::min(m, hash.test(3, 3*xy(x, z) + turn)); + m = std::min(m, hash.test(4, 3*xy(y, z) + turn)); + m = std::min(m, hash.test(5, 3*xy(y, x) + turn)); return m; } @@ -543,13 +475,13 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board return (((p & 0x03F03F03Full) << 6) | ((p & 0xFC0FC0FC0ull) >> 6)); } - static Pattern pattern_invert(Pattern p){ //switch players + static Pattern pattern_invert(Pattern p) { //switch players return ((p & 0xAAAAAAAAAull) >> 1) | ((p & 0x555555555ull) << 1); } - static Pattern pattern_rotate(Pattern p){ + static Pattern pattern_rotate(Pattern p) { return (((p & 0x003003003ull) << 10) | ((p & 0xFFCFFCFFCull) >> 2)); } - static Pattern pattern_mirror(Pattern p){ + static Pattern pattern_mirror(Pattern p) { // HGFEDC BA9876 543210 -> DEFGHC 6789AB 123450 return ((p & (3ull << 6)) ) | ((p & (3ull << 0)) ) | // 0,3 stay in place ((p & (3ull << 10)) >> 8) | ((p & (3ull << 2)) << 8) | // 1,5 swap @@ -561,36 +493,36 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board ((p & (3ull << 34)) >> 8) | ((p & (3ull << 26)) << 8) | // H,D swap ((p & (3ull << 32)) >> 4) | ((p & (3ull << 28)) << 4); // G,E swap } - static Pattern pattern_symmetry(Pattern p){ //takes a pattern and returns the representative version + static Pattern pattern_symmetry(Pattern p) { //takes a pattern and returns the representative version Pattern m = p; //012345 - m = min(m, (p = pattern_rotate(p)));//501234 - m = min(m, (p = pattern_rotate(p)));//450123 - m = min(m, (p = pattern_rotate(p)));//345012 - m = min(m, (p = pattern_rotate(p)));//234501 - m = min(m, (p = pattern_rotate(p)));//123450 - m = min(m, (p = pattern_mirror(pattern_rotate(p))));//012345 -> 054321 - m = min(m, (p = pattern_rotate(p)));//105432 - m = min(m, (p = pattern_rotate(p)));//210543 - m = min(m, (p = pattern_rotate(p)));//321054 - m = min(m, (p = pattern_rotate(p)));//432105 - m = min(m, (p = pattern_rotate(p)));//543210 + m = std::min(m, (p = pattern_rotate(p)));//501234 + m = std::min(m, (p = pattern_rotate(p)));//450123 + m = std::min(m, (p = pattern_rotate(p)));//345012 + m = std::min(m, (p = pattern_rotate(p)));//234501 + m = std::min(m, (p = pattern_rotate(p)));//123450 + m = std::min(m, (p = pattern_mirror(pattern_rotate(p))));//012345 -> 054321 + m = std::min(m, (p = pattern_rotate(p)));//105432 + m = std::min(m, (p = pattern_rotate(p)));//210543 + m = std::min(m, (p = pattern_rotate(p)));//321054 + m = std::min(m, (p = pattern_rotate(p)));//432105 + m = std::min(m, (p = pattern_rotate(p)));//543210 return m; } - bool move(const Move & pos, bool checkwin = true, bool permanent = true){ + bool move(const Move & pos, bool checkwin = true, bool permanent = true) { return move(MoveValid(pos, xy(pos)), checkwin, permanent); } - bool move(const MoveValid & pos, bool checkwin = true, bool permanent = true){ - assert(outcome < 0); + bool move(const MoveValid & pos, bool checkwin = true, bool permanent = true) { + assert(outcome < Outcome::DRAW); if(!valid_move(pos)) return false; - char turn = toplay(); + Side turn = toplay(); set(pos, permanent); // update the nearby patterns - Pattern p = turn; + Pattern p = turn.to_i(); for(const MoveValid * i = nb_begin(pos.xy), *e = nb_end_big_hood(i); i < e; i++){ if(i->onboard()){ cells[i->xy].pattern |= p; @@ -609,27 +541,27 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board // did I win? Cell * g = & cells[find_group(pos.xy)]; - uint8_t winmask = (turn == 1 ? 3 : 0xC); + uint8_t winmask = (turn == Side::P1 ? 3 : 0xC); if((g->edge & winmask) == winmask){ - outcome = 3 - turn; + outcome = ~turn; } return true; } - bool test_local(const Move & pos, char turn) const { + bool test_local(const Move & pos, Side turn) const { return test_local(MoveValid(pos, xy(pos)), turn); } + bool test_local(const MoveValid & pos, Side turn) const { return (local(pos, turn) == 3); } //test if making this move would win, but don't actually make the move - int test_win(const Move & pos, char turn = 0) const { - if(turn == 0) - turn = toplay(); - + Outcome test_outcome(const Move & pos) const { return test_outcome(pos, toplay()); } + Outcome test_outcome(const Move & pos, Side turn) const { return test_outcome(MoveValid(pos, xy(pos)), turn); } + Outcome test_outcome(const MoveValid & pos) const { return test_outcome(pos, toplay()); } + Outcome test_outcome(const MoveValid & pos, Side turn) const { if(test_local(pos, turn)){ - int posxy = xy(pos); - Cell testcell = cells[find_group(posxy)]; + Cell testcell = cells[find_group(pos.xy)]; int numgroups = 0; - for(const MoveValid * i = nb_begin(posxy), *e = nb_end(i); i < e; i++){ + for(const MoveValid * i = nb_begin(pos), *e = nb_end(i); i < e; i++){ if(i->onboard() && turn == get(i->xy)){ const Cell * g = & cells[find_group(i->xy)]; testcell.edge |= g->edge; @@ -639,11 +571,14 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board } } - int winmask = (turn == 1 ? 3 : 0xC); + int winmask = (turn == Side::P1 ? 3 : 0xC); if((testcell.edge & winmask) == winmask) - return 3 - turn; + return ~turn; } - return -3; + return Outcome::UNKNOWN; } }; + +}; // namespace Rex +}; // namespace Morat diff --git a/rex/board_test.cpp b/rex/board_test.cpp new file mode 100644 index 0000000..c6dd471 --- /dev/null +++ b/rex/board_test.cpp @@ -0,0 +1,143 @@ + +#include "../lib/catch.hpp" + +#include "board.h" + + +using namespace Morat; +using namespace Rex; + +void test_game(Board b, std::vector moves, Outcome outcome) { + REQUIRE(b.num_moves() == 0); + Side side = Side::P1; + for(auto s : moves) { + Outcome expected = (s == moves.back() ? outcome : Outcome::UNKNOWN); + Move move(s); + CAPTURE(move); + CAPTURE(b); + REQUIRE(b.valid_move(move)); + REQUIRE(b.toplay() == side); + REQUIRE(b.test_outcome(move) == expected); + REQUIRE(b.move(move)); + REQUIRE(b.won() == expected); + side = ~side; + } +} + +TEST_CASE("Rex::Board", "[rex][board]") { + Board b(7); + + SECTION("Basics") { + REQUIRE(b.get_size() == 7); + REQUIRE(b.movesremain() == 49); + } + + SECTION("valid moves") { + std::string valid[] = {"A1", "D4", + "a1", "a2", "a3", "a4", "a5", "a6", "a7", + "b1", "b2", "b3", "b4", "b5", "b6", "b7", + "c1", "c2", "c3", "c4", "c5", "c6", "c7", + "d1", "d2", "d3", "d4", "d5", "d6", "d7", + "e1", "e2", "e3", "e4", "e5", "e6", "e7", + "f1", "f2", "f3", "f4", "f5", "f6", "f7", + "g1", "g2", "g3", "g4", "g5", "g6", "g7", + }; + for(auto m : valid){ + REQUIRE(b.onboard(m)); + REQUIRE(b.valid_move(m)); + } + } + + SECTION("invalid moves") { + std::string invalid[] = {"a0", "a8", "a10", "b8", "c8", "e0", "e8", "f8", "f0", "h1", "f0"}; + for(auto m : invalid){ + REQUIRE_FALSE(b.onboard(m)); + REQUIRE_FALSE(b.valid_move(m)); + } + } + + SECTION("duplicate moves") { + Move m("a1"); + REQUIRE(b.valid_move(m)); + REQUIRE(b.move(m)); + REQUIRE_FALSE(b.valid_move(m)); + REQUIRE_FALSE(b.move(m)); + } + + SECTION("move distance") { + SECTION("x") { + REQUIRE(b.dist(Move("b2"), Move("b1")) == 1); + REQUIRE(b.dist(Move("b2"), Move("b3")) == 1); + } + SECTION("y") { + REQUIRE(b.dist(Move("b2"), Move("a2")) == 1); + REQUIRE(b.dist(Move("b2"), Move("c2")) == 1); + } + SECTION("z") { + REQUIRE(b.dist(Move("b2"), Move("a3")) == 1); + REQUIRE(b.dist(Move("b2"), Move("c1")) == 1); + } + SECTION("farther") { + REQUIRE(b.dist(Move("b2"), Move("a1")) == 2); + REQUIRE(b.dist(Move("b2"), Move("c3")) == 2); + REQUIRE(b.dist(Move("b2"), Move("d4")) == 4); + REQUIRE(b.dist(Move("b2"), Move("d3")) == 3); + REQUIRE(b.dist(Move("b2"), Move("d1")) == 2); + REQUIRE(b.dist(Move("b2"), Move("e3")) == 4); + } + } + + SECTION("Unknown_1") { + test_game(b, { "a1", "b1", "a2", "b2", "a3", "b3", "a4"}, Outcome::UNKNOWN); + test_game(b, {"d4", "a1", "b1", "a2", "b2", "a3", "b3", "a4"}, Outcome::UNKNOWN); + } + + SECTION("Unknown_2") { + test_game(b, { "b1", "c1", "b2", "c2", "b3", "c3", "b4", "c4", "b5", "c5", "a2"}, Outcome::UNKNOWN); + test_game(b, {"d4", "b1", "c1", "b2", "c2", "b3", "c3", "b4", "c4", "b5", "c5", "a2"}, Outcome::UNKNOWN); + } + + SECTION("Unknown_3") { + test_game(b, { "b2", "f3", "b3", "f4", "c2", "f5", "c4", "f6", "d3", "f7", "d4"}, Outcome::UNKNOWN); + test_game(b, {"d7", "b2", "f3", "b3", "f4", "c2", "f5", "c4", "f6", "d3", "f7", "d4"}, Outcome::UNKNOWN); + + test_game(b, { "b2", "f3", "b3", "f4", "c2", "f5", "c4", "f6", "d3", "f7", "c3", "e6", "d4"}, Outcome::UNKNOWN); + test_game(b, {"d7", "b2", "f3", "b3", "f4", "c2", "f5", "c4", "f6", "d3", "f7", "c3", "e6", "d4"}, Outcome::UNKNOWN); + } + + SECTION("Unknown_4") { + test_game(b, { + "a1", "a2", "a3", "a4", + "b1", "b2", "b3", "b4", "b5", + "c1", "c2", "c3", "c4", "c5", "c6", + "d1", "d2", "d3", "d4", "d5", "d6", "d7", + "e2", "e3", "e4", "e5", "e6", "e7", + "f3", "f4", "f5", "f6", "f7", + "g4", "g5", "g6", "g7", + }, Outcome::UNKNOWN); + } + + SECTION("White Connects") { + test_game(b, + {"a1", "b1", "a2", "b2", "a3", + "b3", "a4", "b4", "a5", "b5", + "a6", "b6","a7"}, + Outcome::P2); + } + + SECTION("Black Connects") { + test_game(b, + {"a2", "a1", "b2", "b1", "c2", + "c1", "d2", "d1", "e2", "e1", + "f2", "f1","g2", "g1"}, + Outcome::P1); + } + + SECTION("Black Connects") { + test_game(b, + {"a2", "a1", "b2", "b1", "c2", + "c1", "d2", "d1", "e2", "e1", + "f2", "f1","g2", "g1"}, + Outcome::P1); + } +} diff --git a/rex/gtp.h b/rex/gtp.h index f53c9e9..fb53b91 100644 --- a/rex/gtp.h +++ b/rex/gtp.h @@ -2,6 +2,8 @@ #pragma once #include "../lib/gtpcommon.h" +#include "../lib/history.h" +#include "../lib/move.h" #include "../lib/string.h" #include "agent.h" @@ -9,11 +11,13 @@ #include "agentmcts.h" #include "agentpns.h" #include "board.h" -#include "history.h" -#include "move.h" + + +namespace Morat { +namespace Rex { class GTP : public GTPCommon { - History hist; + History hist; public: int verbose; @@ -35,46 +39,46 @@ class GTP : public GTPCommon { set_board(); - newcallback("name", bind(>P::gtp_name, this, _1), "Name of the program"); - newcallback("version", bind(>P::gtp_version, this, _1), "Version of the program"); - newcallback("verbose", bind(>P::gtp_verbose, this, _1), "Set verbosity, 0 for quiet, 1 for normal, 2+ for more output"); - newcallback("extended", bind(>P::gtp_extended, this, _1), "Output extra stats from genmove in the response"); - newcallback("debug", bind(>P::gtp_debug, this, _1), "Enable debug mode"); - newcallback("colorboard", bind(>P::gtp_colorboard, this, _1), "Turn on or off the colored board"); - newcallback("showboard", bind(>P::gtp_print, this, _1), "Show the board"); - newcallback("print", bind(>P::gtp_print, this, _1), "Alias for showboard"); - newcallback("dists", bind(>P::gtp_dists, this, _1), "Similar to print, but shows minimum win distances"); -// newcallback("zobrist", bind(>P::gtp_zobrist, this, _1), "Output the zobrist hash for the current move"); - newcallback("clear_board", bind(>P::gtp_clearboard, this, _1), "Clear the board, but keep the size"); - newcallback("clear", bind(>P::gtp_clearboard, this, _1), "Alias for clear_board"); - newcallback("boardsize", bind(>P::gtp_boardsize, this, _1), "Clear the board, set the board size"); - newcallback("size", bind(>P::gtp_boardsize, this, _1), "Alias for board_size"); - newcallback("play", bind(>P::gtp_play, this, _1), "Place a stone: play "); - newcallback("white", bind(>P::gtp_playwhite, this, _1), "Place a white stone: white "); - newcallback("black", bind(>P::gtp_playblack, this, _1), "Place a black stone: black "); - newcallback("undo", bind(>P::gtp_undo, this, _1), "Undo one or more moves: undo [amount to undo]"); - newcallback("time", bind(>P::gtp_time, this, _1), "Set the time limits and the algorithm for per game time"); - newcallback("genmove", bind(>P::gtp_genmove, this, _1), "Generate a move: genmove [color] [time]"); - newcallback("solve", bind(>P::gtp_solve, this, _1), "Try to solve this position"); - -// newcallback("ab", bind(>P::gtp_ab, this, _1), "Switch to use the Alpha/Beta agent to play/solve"); - newcallback("mcts", bind(>P::gtp_mcts, this, _1), "Switch to use the Monte Carlo Tree Search agent to play/solve"); - newcallback("pns", bind(>P::gtp_pns, this, _1), "Switch to use the Proof Number Search agent to play/solve"); - - newcallback("all_legal", bind(>P::gtp_all_legal, this, _1), "List all legal moves"); - newcallback("history", bind(>P::gtp_history, this, _1), "List of played moves"); - newcallback("playgame", bind(>P::gtp_playgame, this, _1), "Play a list of moves"); - newcallback("winner", bind(>P::gtp_winner, this, _1), "Check the winner of the game"); - newcallback("patterns", bind(>P::gtp_patterns, this, _1), "List all legal moves plus their local pattern"); - - newcallback("pv", bind(>P::gtp_pv, this, _1), "Output the principle variation for the player tree as it stands now"); - newcallback("move_stats", bind(>P::gtp_move_stats, this, _1), "Output the move stats for the player tree as it stands now"); - - newcallback("params", bind(>P::gtp_params, this, _1), "Set the options for the player, no args gives options"); - -// newcallback("player_hgf", bind(>P::gtp_player_hgf, this, _1), "Output an hgf of the current tree"); -// newcallback("player_load_hgf", bind(>P::gtp_player_load_hgf,this, _1), "Load an hgf generated by player_hgf"); -// newcallback("player_gammas", bind(>P::gtp_player_gammas, this, _1), "Load the gammas for weighted random from a file"); + newcallback("name", std::bind(>P::gtp_name, this, _1), "Name of the program"); + newcallback("version", std::bind(>P::gtp_version, this, _1), "Version of the program"); + newcallback("verbose", std::bind(>P::gtp_verbose, this, _1), "Set verbosity, 0 for quiet, 1 for normal, 2+ for more output"); + newcallback("extended", std::bind(>P::gtp_extended, this, _1), "Output extra stats from genmove in the response"); + newcallback("debug", std::bind(>P::gtp_debug, this, _1), "Enable debug mode"); + newcallback("colorboard", std::bind(>P::gtp_colorboard, this, _1), "Turn on or off the colored board"); + newcallback("showboard", std::bind(>P::gtp_print, this, _1), "Show the board"); + newcallback("print", std::bind(>P::gtp_print, this, _1), "Alias for showboard"); + newcallback("dists", std::bind(>P::gtp_dists, this, _1), "Similar to print, but shows minimum win distances"); + newcallback("zobrist", std::bind(>P::gtp_zobrist, this, _1), "Output the zobrist hash for the current move"); + newcallback("clear_board", std::bind(>P::gtp_clearboard, this, _1), "Clear the board, but keep the size"); + newcallback("clear", std::bind(>P::gtp_clearboard, this, _1), "Alias for clear_board"); + newcallback("boardsize", std::bind(>P::gtp_boardsize, this, _1), "Clear the board, set the board size"); + newcallback("size", std::bind(>P::gtp_boardsize, this, _1), "Alias for board_size"); + newcallback("play", std::bind(>P::gtp_play, this, _1), "Place a stone: play "); + newcallback("white", std::bind(>P::gtp_playwhite, this, _1), "Place a white stone: white "); + newcallback("black", std::bind(>P::gtp_playblack, this, _1), "Place a black stone: black "); + newcallback("undo", std::bind(>P::gtp_undo, this, _1), "Undo one or more moves: undo [amount to undo]"); + newcallback("time", std::bind(>P::gtp_time, this, _1), "Set the time limits and the algorithm for per game time"); + newcallback("genmove", std::bind(>P::gtp_genmove, this, _1), "Generate a move: genmove [color] [time]"); + newcallback("solve", std::bind(>P::gtp_solve, this, _1), "Try to solve this position"); + +// newcallback("ab", std::bind(>P::gtp_ab, this, _1), "Switch to use the Alpha/Beta agent to play/solve"); + newcallback("mcts", std::bind(>P::gtp_mcts, this, _1), "Switch to use the Monte Carlo Tree Search agent to play/solve"); + newcallback("pns", std::bind(>P::gtp_pns, this, _1), "Switch to use the Proof Number Search agent to play/solve"); + + newcallback("all_legal", std::bind(>P::gtp_all_legal, this, _1), "List all legal moves"); + newcallback("history", std::bind(>P::gtp_history, this, _1), "List of played moves"); + newcallback("playgame", std::bind(>P::gtp_playgame, this, _1), "Play a list of moves"); + newcallback("winner", std::bind(>P::gtp_winner, this, _1), "Check the winner of the game"); + newcallback("patterns", std::bind(>P::gtp_patterns, this, _1), "List all legal moves plus their local pattern"); + + newcallback("pv", std::bind(>P::gtp_pv, this, _1), "Output the principle variation for the player tree as it stands now"); + newcallback("move_stats", std::bind(>P::gtp_move_stats, this, _1), "Output the move stats for the player tree as it stands now"); + + newcallback("params", std::bind(>P::gtp_params, this, _1), "Set the options for the player, no args gives options"); + + newcallback("save_sgf", std::bind(>P::gtp_save_sgf, this, _1), "Output an sgf of the current tree"); + newcallback("load_sgf", std::bind(>P::gtp_load_sgf, this, _1), "Load an sgf generated by save_sgf"); +// newcallback("player_gammas", std::bind(>P::gtp_player_gammas, this, _1), "Load the gammas for weighted random from a file"); } void set_board(bool clear = true){ @@ -94,7 +98,7 @@ class GTP : public GTPCommon { GTPResponse gtp_all_legal(vecstr args); GTPResponse gtp_history(vecstr args); GTPResponse gtp_patterns(vecstr args); - GTPResponse play(const string & pos, int toplay); + GTPResponse play(const std::string & pos, Side toplay); GTPResponse gtp_playgame(vecstr args); GTPResponse gtp_play(vecstr args); GTPResponse gtp_playwhite(vecstr args); @@ -124,8 +128,11 @@ class GTP : public GTPCommon { GTPResponse gtp_pns_params(vecstr args); // GTPResponse gtp_player_gammas(vecstr args); -// GTPResponse gtp_player_hgf(vecstr args); -// GTPResponse gtp_player_load_hgf(vecstr args); + GTPResponse gtp_save_sgf(vecstr args); + GTPResponse gtp_load_sgf(vecstr args); - string solve_str(int outcome) const; + std::string solve_str(int outcome) const; }; + +}; // namespace Rex +}; // namespace Morat diff --git a/rex/gtpagent.cpp b/rex/gtpagent.cpp index d32178a..8943d38 100644 --- a/rex/gtpagent.cpp +++ b/rex/gtpagent.cpp @@ -1,13 +1,12 @@ -#include +#include "gtp.h" -#include "../lib/fileio.h" -#include "gtp.h" +namespace Morat { +namespace Rex { using namespace std; - GTPResponse GTP::gtp_move_stats(vecstr args){ vector moves; for(auto s : args) @@ -249,7 +248,7 @@ GTPResponse GTP::gtp_pns_params(vecstr args){ "Update the pns solver settings, eg: pns_params -m 100 -s 0 -d 1 -e 0.25 -a 2 -l 0\n" " -m --memory Memory limit in Mb [" + to_str(pns->memlimit/(1024*1024)) + "]\n" " -t --threads How many threads to run [" + to_str(pns->numthreads) + "]\n" - " -s --ties Which side to assign ties to, 0 = handle, 1 = p1, 2 = p2 [" + to_str(pns->ties) + "]\n" + " -s --ties Which side to assign ties to, 0 = handle, 1 = p1, 2 = p2 [" + to_str(pns->ties.to_i()) + "]\n" " -d --df Use depth-first thresholds [" + to_str(pns->df) + "]\n" " -e --epsilon How big should the threshold be [" + to_str(pns->epsilon) + "]\n" " -a --abdepth Run an alpha-beta search of this size at each leaf [" + to_str(pns->ab) + "]\n" @@ -267,7 +266,7 @@ GTPResponse GTP::gtp_pns_params(vecstr args){ if(mem < 1) return GTPResponse(false, "Memory can't be less than 1mb"); pns->set_memlimit(mem*1024*1024); }else if((arg == "-s" || arg == "--ties") && i+1 < args.size()){ - pns->ties = from_str(args[++i]); + pns->ties = Side(from_str(args[++i])); pns->clear_mem(); }else if((arg == "-d" || arg == "--df") && i+1 < args.size()){ pns->df = from_str(args[++i]); @@ -282,3 +281,6 @@ GTPResponse GTP::gtp_pns_params(vecstr args){ return GTPResponse(true, errs); } + +}; // namespace Rex +}; // namespace Morat diff --git a/rex/gtpgeneral.cpp b/rex/gtpgeneral.cpp index 60f0f73..3c63b58 100644 --- a/rex/gtpgeneral.cpp +++ b/rex/gtpgeneral.cpp @@ -1,7 +1,15 @@ +#include + +#include "../lib/sgf.h" + #include "gtp.h" #include "lbdist.h" + +namespace Morat { +namespace Rex { + GTPResponse GTP::gtp_mcts(vecstr args){ delete agent; agent = new AgentMCTS(); @@ -39,7 +47,7 @@ GTPResponse GTP::gtp_boardsize(vecstr args){ if(size < Board::min_size || size > Board::max_size) return GTPResponse(false, "Size " + to_str(size) + " is out of range."); - hist = History(size); + hist = History(size); set_board(); time_control.new_game(); @@ -69,14 +77,14 @@ GTPResponse GTP::gtp_undo(vecstr args){ GTPResponse GTP::gtp_patterns(vecstr args){ bool symmetric = true; bool invert = true; - string ret; + std::string ret; const Board & board = *hist; for(Board::MoveIterator move = board.moveit(); !move.done(); ++move){ ret += move->to_s() + " "; unsigned int p = board.pattern(*move); if(symmetric) p = board.pattern_symmetry(p); - if(invert && board.toplay() == 2) + if(invert && board.toplay() == Side::P2) p = board.pattern_invert(p); ret += to_str(p); ret += "\n"; @@ -85,24 +93,24 @@ GTPResponse GTP::gtp_patterns(vecstr args){ } GTPResponse GTP::gtp_all_legal(vecstr args){ - string ret; + std::string ret; for(Board::MoveIterator move = hist->moveit(); !move.done(); ++move) ret += move->to_s() + " "; return GTPResponse(true, ret); } GTPResponse GTP::gtp_history(vecstr args){ - string ret; + std::string ret; for(auto m : hist) ret += m.to_s() + " "; return GTPResponse(true, ret); } -GTPResponse GTP::play(const string & pos, int toplay){ +GTPResponse GTP::play(const std::string & pos, Side toplay){ if(toplay != hist->toplay()) return GTPResponse(false, "It is the other player's turn!"); - if(hist->won() >= 0) + if(hist->won() >= Outcome::DRAW) return GTPResponse(false, "The game is already over."); Move m(pos); @@ -113,7 +121,7 @@ GTPResponse GTP::play(const string & pos, int toplay){ move(m); if(verbose >= 2) - logerr("Placement: " + m.to_s() + ", outcome: " + hist->won_str() + "\n" + hist->to_s(colorboard)); + logerr("Placement: " + m.to_s() + ", outcome: " + hist->won().to_s() + "\n" + hist->to_s(colorboard)); return GTPResponse(true); } @@ -131,37 +139,33 @@ GTPResponse GTP::gtp_play(vecstr args){ if(args.size() != 2) return GTPResponse(false, "Wrong number of arguments"); - char toplay = 0; switch(tolower(args[0][0])){ - case 'w': toplay = 1; break; - case 'b': toplay = 2; break; - default: - return GTPResponse(false, "Invalid player selection"); + case 'w': return play(args[1], Side::P1); + case 'b': return play(args[1], Side::P2); + default: return GTPResponse(false, "Invalid player selection"); } - - return play(args[1], toplay); } GTPResponse GTP::gtp_playwhite(vecstr args){ if(args.size() != 1) return GTPResponse(false, "Wrong number of arguments"); - return play(args[0], 1); + return play(args[0], Side::P1); } GTPResponse GTP::gtp_playblack(vecstr args){ if(args.size() != 1) return GTPResponse(false, "Wrong number of arguments"); - return play(args[0], 2); + return play(args[0], Side::P2); } GTPResponse GTP::gtp_winner(vecstr args){ - return GTPResponse(true, hist->won_str()); + return GTPResponse(true, hist->won().to_s()); } GTPResponse GTP::gtp_name(vecstr args){ - return GTPResponse(true, "Castro"); + return GTPResponse(true, std::string("morat-") + Board::name); } GTPResponse GTP::gtp_version(vecstr args){ @@ -193,7 +197,7 @@ GTPResponse GTP::gtp_extended(vecstr args){ } GTPResponse GTP::gtp_debug(vecstr args){ - string str = "\n"; + std::string str = "\n"; str += "Board size: " + to_str(hist->get_size()) + "\n"; str += "Board cells: " + to_str(hist->numcells()) + "\n"; str += "Board vec: " + to_str(hist->vecsize()) + "\n"; @@ -203,14 +207,15 @@ GTPResponse GTP::gtp_debug(vecstr args){ } GTPResponse GTP::gtp_dists(vecstr args){ + using std::string; Board board = *hist; LBDists dists(&board); - int side = 0; + Side side = Side::NONE; if(args.size() >= 1){ switch(tolower(args[0][0])){ - case 'w': side = 1; break; - case 'b': side = 2; break; + case 'w': side = Side::P1; break; + case 'b': side = Side::P2; break; default: return GTPResponse(false, "Invalid player selection"); } @@ -243,17 +248,17 @@ GTPResponse GTP::gtp_dists(vecstr args){ s += coord + char('A' + y); int end = board.lineend(y); for(int x = 0; x < end; x++){ - int p = board.get(x, y); + Side p = board.get(x, y); s += ' '; - if(p == 0){ - int d = (side ? dists.get(Move(x, y), side) : dists.get(Move(x, y))); - if(d < 30) + if(p == Side::NONE){ + int d = (side == Side::NONE ? dists.get(Move(x, y)) : dists.get(Move(x, y), side)); + if(d < 10) s += reset + to_str(d); else s += empty; - }else if(p == 1){ + }else if(p == Side::P1){ s += white; - }else if(p == 2){ + }else if(p == Side::P2){ s += black; } } @@ -265,3 +270,91 @@ GTPResponse GTP::gtp_dists(vecstr args){ GTPResponse GTP::gtp_zobrist(vecstr args){ return GTPResponse(true, hist->hashstr()); } + +GTPResponse GTP::gtp_save_sgf(vecstr args){ + int limit = -1; + if(args.size() == 0) + return GTPResponse(true, "save_sgf [work limit]"); + + std::ifstream infile(args[0].c_str()); + + if(infile) { + infile.close(); + return GTPResponse(false, "File " + args[0] + " already exists"); + } + + std::ofstream outfile(args[0].c_str()); + + if(!outfile) + return GTPResponse(false, "Opening file " + args[0] + " for writing failed"); + + if(args.size() > 1) + limit = from_str(args[1]); + + SGFPrinter sgf(outfile); + sgf.game(Board::name); + sgf.program(gtp_name(vecstr()).response, gtp_version(vecstr()).response); + sgf.size(hist->get_size()); + + sgf.end_root(); + + Side s = Side::P1; + for(auto m : hist){ + sgf.move(s, m); + s = ~s; + } + + agent->gen_sgf(sgf, limit); + + sgf.end(); + outfile.close(); + return true; +} + + +GTPResponse GTP::gtp_load_sgf(vecstr args){ + if(args.size() == 0) + return GTPResponse(true, "load_sgf "); + + std::ifstream infile(args[0].c_str()); + + if(!infile) { + return GTPResponse(false, "Error opening file " + args[0] + " for reading"); + } + + SGFParser sgf(infile); + if(sgf.game() != Board::name){ + infile.close(); + return GTPResponse(false, "File is for the wrong game: " + sgf.game()); + } + + int size = sgf.size(); + if(size != hist->get_size()){ + if(hist.len() == 0){ + hist = History(size); + set_board(); + time_control.new_game(); + }else{ + infile.close(); + return GTPResponse(false, "File has the wrong boardsize to match the existing game"); + } + } + + Side s = Side::P1; + + while(sgf.next_node()){ + Move m = sgf.move(); + move(m); // push the game forward + s = ~s; + } + + if(sgf.has_children()) + agent->load_sgf(sgf); + + assert(sgf.done_child()); + infile.close(); + return true; +} + +}; // namespace Rex +}; // namespace Morat diff --git a/rex/gtpplayer.cpp b/rex/gtpplayer.cpp deleted file mode 100644 index 1d9f89b..0000000 --- a/rex/gtpplayer.cpp +++ /dev/null @@ -1,547 +0,0 @@ - - -#include - -#include "../lib/fileio.h" - -#include "gtp.h" - -using namespace std; - - -GTPResponse GTP::gtp_move_stats(vecstr args){ - string s = ""; - - Player::Node * node = &(player.root); - - for(unsigned int i = 0; i < args.size(); i++){ - Move m(args[i]); - Player::Node * c = node->children.begin(), - * cend = node->children.end(); - for(; c != cend; c++){ - if(c->move == m){ - node = c; - break; - } - } - } - - Player::Node * child = node->children.begin(), - * childend = node->children.end(); - for( ; child != childend; child++){ - if(child->move == M_NONE) - continue; - - s += child->move.to_s(); - s += "," + to_str((child->exp.num() ? child->exp.avg() : 0.0), 4) + "," + to_str(child->exp.num()); - s += "," + to_str((child->rave.num() ? child->rave.avg() : 0.0), 4) + "," + to_str(child->rave.num()); - s += "," + to_str(child->know); - if(child->outcome >= 0) - s += "," + won_str(child->outcome); - s += "\n"; - } - return GTPResponse(true, s); -} - -GTPResponse GTP::gtp_player_solve(vecstr args){ - double use_time = (args.size() >= 1 ? - from_str(args[0]) : - time_control.get_time(hist.len(), hist->movesremain(), player.gamelen())); - - if(verbose) - logerr("time remain: " + to_str(time_control.remain, 1) + ", time: " + to_str(use_time, 3) + ", sims: " + to_str(time_control.max_sims) + "\n"); - - Player::Node * ret = player.genmove(use_time, time_control.max_sims, time_control.flexible); - Move best = M_RESIGN; - if(ret) - best = ret->move; - - time_control.use(player.time_used); - - int toplay = player.rootboard.toplay(); - - DepthStats gamelen, treelen; - uint64_t runs = player.runs; - double times[4] = {0,0,0,0}; - for(unsigned int i = 0; i < player.threads.size(); i++){ - gamelen += player.threads[i]->gamelen; - treelen += player.threads[i]->treelen; - - for(int a = 0; a < 4; a++) - times[a] += player.threads[i]->times[a]; - - player.threads[i]->reset(); - } - player.runs = 0; - - string stats = "Finished " + to_str(runs) + " runs in " + to_str(player.time_used*1000, 0) + " msec: " + to_str(runs/player.time_used, 0) + " Games/s\n"; - if(runs > 0){ - stats += "Game length: " + gamelen.to_s() + "\n"; - stats += "Tree depth: " + treelen.to_s() + "\n"; - if(player.profile) - stats += "Times: " + to_str(times[0], 3) + ", " + to_str(times[1], 3) + ", " + to_str(times[2], 3) + ", " + to_str(times[3], 3) + "\n"; - } - - if(ret){ - stats += "Move Score: " + to_str(ret->exp.avg()) + "\n"; - - if(ret->outcome >= 0){ - stats += "Solved as a "; - if(ret->outcome == toplay) stats += "win"; - else if(ret->outcome == 0) stats += "draw"; - else stats += "loss"; - stats += "\n"; - } - } - - stats += "PV: " + gtp_pv(vecstr()).response + "\n"; - - if(verbose >= 3 && !player.root.children.empty()) - stats += "Exp-Rave:\n" + gtp_move_stats(vecstr()).response + "\n"; - - if(verbose) - logerr(stats); - - Solver s; - if(ret){ - s.outcome = (ret->outcome >= 0 ? ret->outcome : -3); - s.bestmove = ret->move; - s.maxdepth = gamelen.maxdepth; - s.nodes_seen = runs; - }else{ - s.outcome = 3-toplay; - s.bestmove = M_RESIGN; - s.maxdepth = 0; - s.nodes_seen = 0; - } - - return GTPResponse(true, solve_str(s)); -} - - -GTPResponse GTP::gtp_player_solved(vecstr args){ - string s = ""; - Player::Node * child = player.root.children.begin(), - * childend = player.root.children.end(); - int toplay = player.rootboard.toplay(); - int best = 0; - for( ; child != childend; child++){ - if(child->move == M_NONE) - continue; - - if(child->outcome == toplay) - return GTPResponse(true, won_str(toplay)); - else if(child->outcome < 0) - best = 2; - else if(child->outcome == 0) - best = 1; - } - if(best == 2) return GTPResponse(true, won_str(-3)); - if(best == 1) return GTPResponse(true, won_str(0)); - return GTPResponse(true, won_str(3 - toplay)); -} - -GTPResponse GTP::gtp_pv(vecstr args){ - string pvstr = ""; - vector pv = player.get_pv(); - for(unsigned int i = 0; i < pv.size(); i++) - pvstr += pv[i].to_s() + " "; - return GTPResponse(true, pvstr); -} - -GTPResponse GTP::gtp_player_hgf(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, "player_hgf [sims limit]"); - - FILE * fd = fopen(args[0].c_str(), "r"); - - if(fd){ - fclose(fd); - return GTPResponse(false, "File " + args[0] + " already exists"); - } - - fd = fopen(args[0].c_str(), "w"); - - if(!fd) - return GTPResponse(false, "Opening file " + args[0] + " for writing failed"); - - unsigned int limit = 10000; - if(args.size() > 1) - limit = from_str(args[1]); - - Board board = *hist; - - - fprintf(fd, "(;FF[4]SZ[%i]\n", board.get_size()); - int p = 1; - for(auto m : hist){ - fprintf(fd, ";%c[%s]", (p == 1 ? 'W' : 'B'), m.to_s().c_str()); - p = 3-p; - } - - - Player::Node * child = player.root.children.begin(), - * end = player.root.children.end(); - - for( ; child != end; child++){ - if(child->exp.num() >= limit){ - board.set(child->move); - player.gen_hgf(board, child, limit, 1, fd); - board.unset(child->move); - } - } - - fprintf(fd, ")\n"); - - fclose(fd); - - return true; -} - -GTPResponse GTP::gtp_player_load_hgf(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, "player_load_hgf "); - - FILE * fd = fopen(args[0].c_str(), "r"); - - if(!fd) - return GTPResponse(false, "Opening file " + args[0] + " for reading failed"); - - int size; - assert(fscanf(fd, "(;FF[4]SZ[%i]", & size) > 0); - if(size != hist->get_size()){ - if(hist.len() == 0){ - hist = History(Board(size)); - set_board(); - }else{ - fclose(fd); - return GTPResponse(false, "File has the wrong boardsize to match the existing game"); - } - } - - eat_whitespace(fd); - - Board board(size); - Player::Node * node = & player.root; - vector prefix; - - char side, movestr[5]; - while(fscanf(fd, ";%c[%5[^]]]", &side, movestr) > 0){ - Move move(movestr); - - if(board.num_moves() >= (int)hist.len()){ - if(node->children.empty()) - player.create_children_simple(board, node); - - prefix.push_back(node); - node = player.find_child(node, move); - }else if(hist[board.num_moves()] != move){ - fclose(fd); - return GTPResponse(false, "The current game is deeper than this file"); - } - board.move(move); - - eat_whitespace(fd); - } - prefix.push_back(node); - - - if(fpeek(fd) != ')'){ - if(node->children.empty()) - player.create_children_simple(board, node); - - while(fpeek(fd) != ')'){ - Player::Node child; - player.load_hgf(board, & child, fd); - - Player::Node * i = player.find_child(node, child.move); - *i = child; //copy the child experience to the tree - i->swap_tree(child); //move the child subtree to the tree - - assert(child.children.empty()); - - eat_whitespace(fd); - } - } - - eat_whitespace(fd); - assert(fgetc(fd) == ')'); - fclose(fd); - - while(!prefix.empty()){ - Player::Node * node = prefix.back(); - prefix.pop_back(); - - Player::Node * child = node->children.begin(), - * end = node->children.end(); - - int toplay = hist->toplay(); - if(prefix.size() % 2 == 1) - toplay = 3 - toplay; - - Player::Node * backup = child; - - node->exp.clear(); - for( ; child != end; child++){ - node->exp += child->exp.invert(); - if(child->outcome == toplay || child->exp.num() > backup->exp.num()) - backup = child; - } - player.do_backup(node, backup, toplay); - } - - return true; -} - - -GTPResponse GTP::gtp_genmove(vecstr args){ - if(player.rootboard.won() >= 0) - return GTPResponse(true, "resign"); - - double use_time = (args.size() >= 2 ? - from_str(args[1]) : - time_control.get_time(hist.len(), hist->movesremain(), player.gamelen())); - - if(args.size() >= 2) - use_time = from_str(args[1]); - - if(verbose) - logerr("time remain: " + to_str(time_control.remain, 1) + ", time: " + to_str(use_time, 3) + ", sims: " + to_str(time_control.max_sims) + "\n"); - - uword nodesbefore = player.nodes; - - Player::Node * ret = player.genmove(use_time, time_control.max_sims, time_control.flexible); - Move best = player.root.bestmove; - - time_control.use(player.time_used); - - int toplay = player.rootboard.toplay(); - - DepthStats gamelen, treelen; - uint64_t runs = player.runs; - double times[4] = {0,0,0,0}; - for(unsigned int i = 0; i < player.threads.size(); i++){ - gamelen += player.threads[i]->gamelen; - treelen += player.threads[i]->treelen; - - for(int a = 0; a < 4; a++) - times[a] += player.threads[i]->times[a]; - - player.threads[i]->reset(); - } - player.runs = 0; - - string stats = "Finished " + to_str(runs) + " runs in " + to_str(player.time_used*1000, 0) + " msec: " + to_str(runs/player.time_used, 0) + " Games/s\n"; - if(runs > 0){ - stats += "Game length: " + gamelen.to_s() + "\n"; - stats += "Tree depth: " + treelen.to_s() + "\n"; - if(player.profile) - stats += "Times: " + to_str(times[0], 3) + ", " + to_str(times[1], 3) + ", " + to_str(times[2], 3) + ", " + to_str(times[3], 3) + "\n"; - } - - if(ret) - stats += "Move Score: " + to_str(ret->exp.avg()) + "\n"; - - if(player.root.outcome != -3){ - stats += "Solved as a "; - if(player.root.outcome == 0) stats += "draw"; - else if(player.root.outcome == toplay) stats += "win"; - else if(player.root.outcome == 3-toplay) stats += "loss"; - else if(player.root.outcome == -toplay) stats += "win or draw"; - else if(player.root.outcome == toplay-3) stats += "loss or draw"; - stats += "\n"; - } - - stats += "PV: " + gtp_pv(vecstr()).response + "\n"; - - if(verbose >= 3 && !player.root.children.empty()) - stats += "Exp-Rave:\n" + gtp_move_stats(vecstr()).response + "\n"; - - string extended; - if(genmoveextended){ - //move score - if(ret) extended += " " + to_str(ret->exp.avg()); - else extended += " 0"; - //outcome - extended += " " + won_str(player.root.outcome); - //work - extended += " " + to_str(runs); - //nodes - extended += " " + to_str(player.nodes - nodesbefore); - } - - move(best); - - if(verbose >= 2){ - stats += "history: "; - for(auto m : hist) - stats += m.to_s() + " "; - stats += "\n"; - stats += hist->to_s(colorboard) + "\n"; - } - - if(verbose) - logerr(stats); - - return GTPResponse(true, best.to_s() + extended); -} - -GTPResponse GTP::gtp_player_params(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, string("\n") + - "Set player parameters, eg: player_params -e 1 -f 0 -t 2 -o 1 -p 0\n" + - "Processing:\n" + -#ifndef SINGLE_THREAD - " -t --threads Number of MCTS threads [" + to_str(player.numthreads) + "]\n" + -#endif - " -o --ponder Continue to ponder during the opponents time [" + to_str(player.ponder) + "]\n" + - " -M --maxmem Max memory in Mb to use for the tree [" + to_str(player.maxmem/(1024*1024)) + "]\n" + - " --profile Output the time used by each phase of MCTS [" + to_str(player.profile) + "]\n" + - "Final move selection:\n" + - " -E --msexplore Lower bound constant in final move selection [" + to_str(player.msexplore) + "]\n" + - " -F --msrave Rave factor, 0 for pure exp, -1 # sims, -2 # wins [" + to_str(player.msrave) + "]\n" + - "Tree traversal:\n" + - " -e --explore Exploration rate for UCT [" + to_str(player.explore) + "]\n" + - " -A --parexplore Multiply the explore rate by parents experience [" + to_str(player.parentexplore) + "]\n" + - " -f --ravefactor The rave factor: alpha = rf/(rf + visits) [" + to_str(player.ravefactor) + "]\n" + - " -d --decrrave Decrease the rave factor over time: rf += d*empty [" + to_str(player.decrrave) + "]\n" + - " -a --knowledge Use knowledge: 0.01*know/sqrt(visits+1) [" + to_str(player.knowledge) + "]\n" + - " -r --userave Use rave with this probability [0-1] [" + to_str(player.userave) + "]\n" + - " -X --useexplore Use exploration with this probability [0-1] [" + to_str(player.useexplore) + "]\n" + - " -u --fpurgency Value to assign to an unplayed move [" + to_str(player.fpurgency) + "]\n" + - " -O --rollouts Number of rollouts to run per simulation [" + to_str(player.rollouts) + "]\n" + - " -I --dynwiden Dynamic widening, consider log_wid(exp) children [" + to_str(player.dynwiden) + "]\n" + - "Tree building:\n" + - " -s --shortrave Only use moves from short rollouts for rave [" + to_str(player.shortrave) + "]\n" + - " -k --keeptree Keep the tree from the previous move [" + to_str(player.keeptree) + "]\n" + - " -m --minimax Backup the minimax proof in the UCT tree [" + to_str(player.minimax) + "]\n" + - " -x --visitexpand Number of visits before expanding a node [" + to_str(player.visitexpand) + "]\n" + - " -P --symmetry Prune symmetric moves, good for proof, not play [" + to_str(player.prunesymmetry) + "]\n" + - " --gcsolved Garbage collect solved nodes with fewer sims than [" + to_str(player.gcsolved) + "]\n" + - "Node initialization knowledge, Give a bonus:\n" + - " -l --localreply based on the distance to the previous move [" + to_str(player.localreply) + "]\n" + - " -y --locality to stones near other stones of the same color [" + to_str(player.locality) + "]\n" + - " -c --connect to stones connected to edges [" + to_str(player.connect) + "]\n" + - " -S --size based on the size of the group [" + to_str(player.size) + "]\n" + - " -b --bridge to maintaining a 2-bridge after the op probes [" + to_str(player.bridge) + "]\n" + - " -D --distance to low minimum distance to win (<0 avoid VCs) [" + to_str(player.dists) + "]\n" + - "Rollout policy:\n" + - " -h --weightrand Weight the moves according to computed gammas [" + to_str(player.weightedrandom) + "]\n" + - " -p --pattern Maintain the virtual connection pattern [" + to_str(player.rolloutpattern) + "]\n" + - " -g --goodreply Reuse the last good reply (1), remove losses (2) [" + to_str(player.lastgoodreply) + "]\n" + - " -w --instantwin Look for instant wins to this depth [" + to_str(player.instantwin) + "]\n" - ); - - string errs; - for(unsigned int i = 0; i < args.size(); i++) { - string arg = args[i]; - - if((arg == "-t" || arg == "--threads") && i+1 < args.size()){ - player.numthreads = from_str(args[++i]); - bool p = player.ponder; - player.set_ponder(false); //stop the threads while resetting them - player.reset_threads(); - player.set_ponder(p); - }else if((arg == "-o" || arg == "--ponder") && i+1 < args.size()){ - player.set_ponder(from_str(args[++i])); - }else if((arg == "--profile") && i+1 < args.size()){ - player.profile = from_str(args[++i]); - }else if((arg == "-M" || arg == "--maxmem") && i+1 < args.size()){ - player.maxmem = from_str(args[++i])*1024*1024; - }else if((arg == "-E" || arg == "--msexplore") && i+1 < args.size()){ - player.msexplore = from_str(args[++i]); - }else if((arg == "-F" || arg == "--msrave") && i+1 < args.size()){ - player.msrave = from_str(args[++i]); - }else if((arg == "-e" || arg == "--explore") && i+1 < args.size()){ - player.explore = from_str(args[++i]); - }else if((arg == "-A" || arg == "--parexplore") && i+1 < args.size()){ - player.parentexplore = from_str(args[++i]); - }else if((arg == "-f" || arg == "--ravefactor") && i+1 < args.size()){ - player.ravefactor = from_str(args[++i]); - }else if((arg == "-d" || arg == "--decrrave") && i+1 < args.size()){ - player.decrrave = from_str(args[++i]); - }else if((arg == "-a" || arg == "--knowledge") && i+1 < args.size()){ - player.knowledge = from_str(args[++i]); - }else if((arg == "-s" || arg == "--shortrave") && i+1 < args.size()){ - player.shortrave = from_str(args[++i]); - }else if((arg == "-k" || arg == "--keeptree") && i+1 < args.size()){ - player.keeptree = from_str(args[++i]); - }else if((arg == "-m" || arg == "--minimax") && i+1 < args.size()){ - player.minimax = from_str(args[++i]); - }else if((arg == "-P" || arg == "--symmetry") && i+1 < args.size()){ - player.prunesymmetry = from_str(args[++i]); - }else if(( arg == "--gcsolved") && i+1 < args.size()){ - player.gcsolved = from_str(args[++i]); - }else if((arg == "-r" || arg == "--userave") && i+1 < args.size()){ - player.userave = from_str(args[++i]); - }else if((arg == "-X" || arg == "--useexplore") && i+1 < args.size()){ - player.useexplore = from_str(args[++i]); - }else if((arg == "-u" || arg == "--fpurgency") && i+1 < args.size()){ - player.fpurgency = from_str(args[++i]); - }else if((arg == "-O" || arg == "--rollouts") && i+1 < args.size()){ - player.rollouts = from_str(args[++i]); - if(player.gclimit < player.rollouts*5) - player.gclimit = player.rollouts*5; - }else if((arg == "-I" || arg == "--dynwiden") && i+1 < args.size()){ - player.dynwiden = from_str(args[++i]); - player.logdynwiden = std::log(player.dynwiden); - }else if((arg == "-x" || arg == "--visitexpand") && i+1 < args.size()){ - player.visitexpand = from_str(args[++i]); - }else if((arg == "-l" || arg == "--localreply") && i+1 < args.size()){ - player.localreply = from_str(args[++i]); - }else if((arg == "-y" || arg == "--locality") && i+1 < args.size()){ - player.locality = from_str(args[++i]); - }else if((arg == "-c" || arg == "--connect") && i+1 < args.size()){ - player.connect = from_str(args[++i]); - }else if((arg == "-S" || arg == "--size") && i+1 < args.size()){ - player.size = from_str(args[++i]); - }else if((arg == "-b" || arg == "--bridge") && i+1 < args.size()){ - player.bridge = from_str(args[++i]); - }else if((arg == "-D" || arg == "--distance") && i+1 < args.size()){ - player.dists = from_str(args[++i]); - }else if((arg == "-h" || arg == "--weightrand") && i+1 < args.size()){ - player.weightedrandom = from_str(args[++i]); - }else if((arg == "-p" || arg == "--pattern") && i+1 < args.size()){ - player.rolloutpattern = from_str(args[++i]); - }else if((arg == "-g" || arg == "--goodreply") && i+1 < args.size()){ - player.lastgoodreply = from_str(args[++i]); - }else if((arg == "-w" || arg == "--instantwin") && i+1 < args.size()){ - player.instantwin = from_str(args[++i]); - }else{ - return GTPResponse(false, "Missing or unknown parameter"); - } - } - return GTPResponse(true, errs); -} - -GTPResponse GTP::gtp_player_gammas(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, "Must pass the filename of a set of gammas"); - - ifstream ifs(args[0].c_str()); - - if(!ifs.good()) - return GTPResponse(false, "Failed to open file for reading"); - - Board board = *hist; - - for(int i = 0; i < 4096; i++){ - int a; - float f; - ifs >> a >> f; - - if(i != a){ - ifs.close(); - return GTPResponse(false, "Line " + to_str(i) + " doesn't match the expected value"); - } - - int s = board.pattern_symmetry(i); - if(s == i) - player.gammas[i] = f; - else - player.gammas[i] = player.gammas[s]; - } - - ifs.close(); - return GTPResponse(true); -} diff --git a/rex/gtpsolver.cpp b/rex/gtpsolver.cpp deleted file mode 100644 index 1df5ea1..0000000 --- a/rex/gtpsolver.cpp +++ /dev/null @@ -1,331 +0,0 @@ - - -#include "gtp.h" - -string GTP::solve_str(int outcome) const { - switch(outcome){ - case -2: return "black_or_draw"; - case -1: return "white_or_draw"; - case 0: return "draw"; - case 1: return "white"; - case 2: return "black"; - default: return "unknown"; - } -} - -string GTP::solve_str(const Solver & solve){ - string ret = ""; - ret += solve_str(solve.outcome) + " "; - ret += solve.bestmove.to_s() + " "; - ret += to_str(solve.maxdepth) + " "; - ret += to_str(solve.nodes_seen); - return ret; -} - - -GTPResponse GTP::gtp_solve_ab(vecstr args){ - double time = 60; - - if(args.size() >= 1) - time = from_str(args[0]); - - solverab.solve(time); - - logerr("Finished in " + to_str(solverab.time_used*1000, 0) + " msec\n"); - - return GTPResponse(true, solve_str(solverab)); -} - -GTPResponse GTP::gtp_solve_ab_params(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, string("\n") + - "Update the alpha-beta solver settings, eg: ab_params -m 100 -s 1 -d 3\n" - " -m --memory Memory limit in Mb (0 to disable the TT) [" + to_str(solverab.memlimit/(1024*1024)) + "]\n" - " -s --scout Whether to scout ahead for the true minimax value [" + to_str(solverab.scout) + "]\n" - " -d --depth Starting depth [" + to_str(solverab.startdepth) + "]\n" - ); - - for(unsigned int i = 0; i < args.size(); i++) { - string arg = args[i]; - - if((arg == "-m" || arg == "--memory") && i+1 < args.size()){ - int mem = from_str(args[++i]); - solverab.set_memlimit(mem); - }else if((arg == "-s" || arg == "--scout") && i+1 < args.size()){ - solverab.scout = from_str(args[++i]); - }else if((arg == "-d" || arg == "--depth") && i+1 < args.size()){ - solverab.startdepth = from_str(args[++i]); - }else{ - return GTPResponse(false, "Missing or unknown parameter"); - } - } - - return true; -} - -GTPResponse GTP::gtp_solve_ab_stats(vecstr args){ - string s = ""; - - Board board = *hist; - for(auto arg : args) - board.move(Move(arg)); - - int value; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - value = solverab.tt_get(board.test_hash(*move)); - - s += move->to_s() + "," + to_str(value) + "\n"; - } - return GTPResponse(true, s); -} - -GTPResponse GTP::gtp_solve_ab_clear(vecstr args){ - solverab.clear_mem(); - return true; -} - - - -GTPResponse GTP::gtp_solve_pns(vecstr args){ - double time = 60; - - if(args.size() >= 1) - time = from_str(args[0]); - - solverpns.solve(time); - - logerr("Finished in " + to_str(solverpns.time_used*1000, 0) + " msec\n"); - - return GTPResponse(true, solve_str(solverpns)); -} - -GTPResponse GTP::gtp_solve_pns_params(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, string("\n") + - "Update the pns solver settings, eg: pns_params -m 100 -s 0 -d 1 -e 0.25 -a 2 -l 0\n" - " -m --memory Memory limit in Mb [" + to_str(solverpns.memlimit/(1024*1024)) + "]\n" -// " -t --threads How many threads to run -// " -o --ponder Ponder in the background - " -d --df Use depth-first thresholds [" + to_str(solverpns.df) + "]\n" - " -e --epsilon How big should the threshold be [" + to_str(solverpns.epsilon) + "]\n" - " -a --abdepth Run an alpha-beta search of this size at each leaf [" + to_str(solverpns.ab) + "]\n" - " -l --lbdist Initialize with the lower bound on distance to win [" + to_str(solverpns.lbdist) + "]\n" - ); - - for(unsigned int i = 0; i < args.size(); i++) { - string arg = args[i]; - - if((arg == "-m" || arg == "--memory") && i+1 < args.size()){ - uint64_t mem = from_str(args[++i]); - if(mem < 1) return GTPResponse(false, "Memory can't be less than 1mb"); - solverpns.set_memlimit(mem*1024*1024); - }else if((arg == "-d" || arg == "--df") && i+1 < args.size()){ - solverpns.df = from_str(args[++i]); - }else if((arg == "-e" || arg == "--epsilon") && i+1 < args.size()){ - solverpns.epsilon = from_str(args[++i]); - }else if((arg == "-a" || arg == "--abdepth") && i+1 < args.size()){ - solverpns.ab = from_str(args[++i]); - }else if((arg == "-l" || arg == "--lbdist") && i+1 < args.size()){ - solverpns.lbdist = from_str(args[++i]); - }else{ - return GTPResponse(false, "Missing or unknown parameter"); - } - } - - return true; -} - -GTPResponse GTP::gtp_solve_pns_stats(vecstr args){ - string s = ""; - - SolverPNS::PNSNode * node = &(solverpns.root); - - for(unsigned int i = 0; i < args.size(); i++){ - Move m(args[i]); - SolverPNS::PNSNode * c = node->children.begin(), - * cend = node->children.end(); - for(; c != cend; c++){ - if(c->move == m){ - node = c; - break; - } - } - } - - SolverPNS::PNSNode * child = node->children.begin(), - * childend = node->children.end(); - for( ; child != childend; child++){ - if(child->move == M_NONE) - continue; - - s += child->move.to_s() + "," + to_str(child->phi) + "," + to_str(child->delta) + "\n"; - } - return GTPResponse(true, s); -} - -GTPResponse GTP::gtp_solve_pns_clear(vecstr args){ - solverpns.clear_mem(); - return true; -} - - -GTPResponse GTP::gtp_solve_pns2(vecstr args){ - double time = 60; - - if(args.size() >= 1) - time = from_str(args[0]); - - solverpns2.solve(time); - - logerr("Finished in " + to_str(solverpns2.time_used*1000, 0) + " msec\n"); - - return GTPResponse(true, solve_str(solverpns2)); -} - -GTPResponse GTP::gtp_solve_pns2_params(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, string("\n") + - "Update the pns solver settings, eg: pns_params -m 100 -s 0 -d 1 -e 0.25 -a 2 -l 0\n" - " -m --memory Memory limit in Mb [" + to_str(solverpns2.memlimit/(1024*1024)) + "]\n" - " -t --threads How many threads to run [" + to_str(solverpns2.numthreads) + "]\n" -// " -o --ponder Ponder in the background - " -d --df Use depth-first thresholds [" + to_str(solverpns2.df) + "]\n" - " -e --epsilon How big should the threshold be [" + to_str(solverpns2.epsilon) + "]\n" - " -a --abdepth Run an alpha-beta search of this size at each leaf [" + to_str(solverpns2.ab) + "]\n" - " -l --lbdist Initialize with the lower bound on distance to win [" + to_str(solverpns2.lbdist) + "]\n" - ); - - for(unsigned int i = 0; i < args.size(); i++) { - string arg = args[i]; - - if((arg == "-t" || arg == "--threads") && i+1 < args.size()){ - solverpns2.numthreads = from_str(args[++i]); - solverpns2.reset_threads(); - }else if((arg == "-m" || arg == "--memory") && i+1 < args.size()){ - uint64_t mem = from_str(args[++i]); - if(mem < 1) return GTPResponse(false, "Memory can't be less than 1mb"); - solverpns2.set_memlimit(mem*1024*1024); - }else if((arg == "-d" || arg == "--df") && i+1 < args.size()){ - solverpns2.df = from_str(args[++i]); - }else if((arg == "-e" || arg == "--epsilon") && i+1 < args.size()){ - solverpns2.epsilon = from_str(args[++i]); - }else if((arg == "-a" || arg == "--abdepth") && i+1 < args.size()){ - solverpns2.ab = from_str(args[++i]); - }else if((arg == "-l" || arg == "--lbdist") && i+1 < args.size()){ - solverpns2.lbdist = from_str(args[++i]); - }else{ - return GTPResponse(false, "Missing or unknown parameter"); - } - } - - return true; -} - -GTPResponse GTP::gtp_solve_pns2_stats(vecstr args){ - string s = ""; - - SolverPNS2::PNSNode * node = &(solverpns2.root); - - for(unsigned int i = 0; i < args.size(); i++){ - Move m(args[i]); - SolverPNS2::PNSNode * c = node->children.begin(), - * cend = node->children.end(); - for(; c != cend; c++){ - if(c->move == m){ - node = c; - break; - } - } - } - - SolverPNS2::PNSNode * child = node->children.begin(), - * childend = node->children.end(); - for( ; child != childend; child++){ - if(child->move == M_NONE) - continue; - - s += child->move.to_s() + "," + to_str(child->phi) + "," + to_str(child->delta) + "," + to_str(child->work) + "\n"; - } - return GTPResponse(true, s); -} - -GTPResponse GTP::gtp_solve_pns2_clear(vecstr args){ - solverpns2.clear_mem(); - return true; -} - - - - -GTPResponse GTP::gtp_solve_pnstt(vecstr args){ - double time = 60; - - if(args.size() >= 1) - time = from_str(args[0]); - - solverpnstt.solve(time); - - logerr("Finished in " + to_str(solverpnstt.time_used*1000, 0) + " msec\n"); - - return GTPResponse(true, solve_str(solverpnstt)); -} - -GTPResponse GTP::gtp_solve_pnstt_params(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, string("\n") + - "Update the pnstt solver settings, eg: pnstt_params -m 100 -s 0 -d 1 -e 0.25 -a 2 -l 0\n" - " -m --memory Memory limit in Mb [" + to_str(solverpnstt.memlimit/(1024*1024)) + "]\n" -// " -t --threads How many threads to run -// " -o --ponder Ponder in the background - " -d --df Use depth-first thresholds [" + to_str(solverpnstt.df) + "]\n" - " -e --epsilon How big should the threshold be [" + to_str(solverpnstt.epsilon) + "]\n" - " -a --abdepth Run an alpha-beta search of this size at each leaf [" + to_str(solverpnstt.ab) + "]\n" - " -c --copy Try to copy a proof to this many siblings, <0 quit early [" + to_str(solverpnstt.copyproof) + "]\n" -// " -l --lbdist Initialize with the lower bound on distance to win [" + to_str(solverpnstt.lbdist) + "]\n" - ); - - for(unsigned int i = 0; i < args.size(); i++) { - string arg = args[i]; - - if((arg == "-m" || arg == "--memory") && i+1 < args.size()){ - int mem = from_str(args[++i]); - if(mem < 1) return GTPResponse(false, "Memory can't be less than 1mb"); - solverpnstt.set_memlimit(mem*1024*1024); - }else if((arg == "-d" || arg == "--df") && i+1 < args.size()){ - solverpnstt.df = from_str(args[++i]); - }else if((arg == "-e" || arg == "--epsilon") && i+1 < args.size()){ - solverpnstt.epsilon = from_str(args[++i]); - }else if((arg == "-a" || arg == "--abdepth") && i+1 < args.size()){ - solverpnstt.ab = from_str(args[++i]); - }else if((arg == "-c" || arg == "--copy") && i+1 < args.size()){ - solverpnstt.copyproof = from_str(args[++i]); -// }else if((arg == "-l" || arg == "--lbdist") && i+1 < args.size()){ -// solverpnstt.lbdist = from_str(args[++i]); - }else{ - return GTPResponse(false, "Missing or unknown parameter"); - } - } - - return true; -} - -GTPResponse GTP::gtp_solve_pnstt_stats(vecstr args){ - string s = ""; - - Board board = *hist; - for(auto arg : args) - board.move(Move(arg)); - - SolverPNSTT::PNSNode * child = NULL; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - child = solverpnstt.tt(board, *move); - - s += move->to_s() + "," + to_str(child->phi) + "," + to_str(child->delta) + "\n"; - } - return GTPResponse(true, s); -} - -GTPResponse GTP::gtp_solve_pnstt_clear(vecstr args){ - solverpnstt.clear_mem(); - return true; -} diff --git a/rex/lbdist.h b/rex/lbdist.h index 99ccf30..2020796 100644 --- a/rex/lbdist.h +++ b/rex/lbdist.h @@ -12,8 +12,13 @@ Decrease distance when crossing your own virtual connection? //TODO: Needs to be fixed for only one direction per player +#include "../lib/move.h" + #include "board.h" -#include "move.h" + + +namespace Morat { +namespace Rex { class LBDists { struct MoveDist { @@ -71,15 +76,16 @@ class LBDists { IntPQueue Q; const Board * board; - int & dist(int edge, int player, int i) { return dists[edge][player-1][i]; } - int & dist(int edge, int player, const Move & m) { return dist(edge, player, board->xy(m)); } - int & dist(int edge, int player, int x, int y) { return dist(edge, player, board->xy(x, y)); } + int & dist(int edge, Side player, int i) { return dists[edge][player.to_i() - 1][i]; } + int & dist(int edge, Side player, const Move & m) { return dist(edge, player, board->xy(m)); } + int & dist(int edge, Side player, int x, int y) { return dist(edge, player, board->xy(x, y)); } - void init(int x, int y, int edge, int player, int dir){ - int val = board->get(x, y); - if(val != 3 - player){ - Q.push(MoveDist(x, y, (val == 0), dir)); - dist(edge, player, x, y) = (val == 0); + void init(int x, int y, int edge, Side player, int dir){ + Side val = board->get(x, y); + if(val != ~player){ + bool empty = (val == Side::NONE); + Q.push(MoveDist(x, y, empty, dir)); + dist(edge, player, x, y) = empty; } } @@ -88,7 +94,7 @@ class LBDists { LBDists() : board(NULL) {} LBDists(const Board * b) { run(b); } - void run(const Board * b, bool crossvcs = true, int side = 0) { + void run(const Board * b, bool crossvcs = true, Side side = Side::BOTH) { board = b; for(int i = 0; i < 3; i++) @@ -96,22 +102,21 @@ class LBDists { for(int k = 0; k < board->vecsize(); k++) dists[i][j][k] = maxdist; //far far away! + if(side == Side::P1 || side == Side::BOTH) init_player(crossvcs, Side::P1); + if(side == Side::P2 || side == Side::BOTH) init_player(crossvcs, Side::P2); + } + + void init_player(bool crossvcs, Side player){ int m = board->get_size(); int m1 = m-1; - int start, end; - if(side){ start = end = side; } - else { start = 1; end = 2; } - - for(int player = start; player <= end; player++){ - for(int x = 0; x < m; x++) { init(x, 0, 0, player, 3); } flood(0, player, crossvcs); //edge 0 - for(int y = 0; y < m; y++) { init(0, y, 1, player, 1); } flood(1, player, crossvcs); //edge 1 - for(int y = 0; y < m; y++) { init(m1-y, y, 2, player, 5); } flood(2, player, crossvcs); //edge 2 - } + for(int x = 0; x < m; x++) { init(x, 0, 0, player, 3); } flood(0, player, crossvcs); //edge 0 + for(int y = 0; y < m; y++) { init(0, y, 1, player, 1); } flood(1, player, crossvcs); //edge 1 + for(int y = 0; y < m; y++) { init(m1-y, y, 2, player, 5); } flood(2, player, crossvcs); //edge 2 } - void flood(int edge, int player, bool crossvcs){ - int otherplayer = 3 - player; + void flood(int edge, Side player, bool crossvcs){ + Side otherplayer = ~player; MoveDist cur; while(Q.pop(cur)){ @@ -121,12 +126,12 @@ class LBDists { if(board->onboard(next.pos)){ int pos = board->xy(next.pos); - int colour = board->get(pos); + Side colour = board->get(pos); if(colour == otherplayer) continue; - if(colour == 0){ + if(colour == Side::NONE){ if(!crossvcs && //forms a vc board->get(cur.pos + neighbours[(nd - 1) % 6]) == otherplayer && board->get(cur.pos + neighbours[(nd + 1) % 6]) == otherplayer) @@ -145,12 +150,15 @@ class LBDists { } } - int get(Move pos){ return min(get(pos, 1), get(pos, 2)); } - int get(Move pos, int player){ return get(board->xy(pos), player); } - int get(int pos, int player){ + int get(Move pos){ return std::min(get(pos, Side::P1), get(pos, Side::P2)); } + int get(Move pos, Side player){ return get(board->xy(pos), player); } + int get(int pos, Side player){ int sum = 0; for(int i = 0; i < 3; i++) sum += dist(i, player, pos); return sum; } }; + +}; // namespace Rex +}; // namespace Morat diff --git a/rex/trex.cpp b/rex/main.cpp similarity index 96% rename from rex/trex.cpp rename to rex/main.cpp index 6868140..679f45c 100644 --- a/rex/trex.cpp +++ b/rex/main.cpp @@ -1,5 +1,4 @@ - #include #include @@ -7,6 +6,10 @@ #include "gtp.h" + +using namespace Morat; +using namespace Rex; + using namespace std; void die(int code, const string & str){ @@ -15,6 +18,7 @@ void die(int code, const string & str){ } int main(int argc, char **argv){ + srand(Time().in_usec()); GTP gtp; @@ -51,7 +55,7 @@ int main(int argc, char **argv){ die(255, "Unknown argument: " + arg + ", try --help"); } } - + gtp.setinfile(stdin); gtp.setoutfile(stdout); gtp.run(); diff --git a/rex/move.h b/rex/move.h deleted file mode 100644 index 84cf035..0000000 --- a/rex/move.h +++ /dev/null @@ -1,91 +0,0 @@ - -#pragma once - -#include -#include - -#include "../lib/string.h" - -enum MoveSpecial { - M_SWAP = -1, //-1 so that adding 1 makes it into a valid move - M_RESIGN = -2, - M_NONE = -3, - M_UNKNOWN = -4, -}; - -struct Move { - int8_t y, x; - - Move(MoveSpecial a = M_UNKNOWN) : y(a), x(120) { } //big x so it will always wrap to y=0 with swap - Move(int X, int Y) : y(Y), x(X) { } - - Move(const std::string & str){ - if( str == "swap" ){ y = M_SWAP; x = 120; } - else if(str == "resign" ){ y = M_RESIGN; x = 120; } - else if(str == "none" ){ y = M_NONE; x = 120; } - else if(str == "unknown"){ y = M_UNKNOWN; x = 120; } - else{ - y = tolower(str[0]) - 'a'; - x = atoi(str.c_str() + 1) - 1; - } - } - - std::string to_s() const { - if(y == M_UNKNOWN) return "unknown"; - if(y == M_NONE) return "none"; - if(y == M_SWAP) return "swap"; - if(y == M_RESIGN) return "resign"; - - return std::string() + char(y + 'a') + to_str(x + 1); - } - - bool operator< (const Move & b) const { return (y == b.y ? x < b.x : y < b.y); } - bool operator<=(const Move & b) const { return (y == b.y ? x <= b.x : y <= b.y); } - bool operator> (const Move & b) const { return (y == b.y ? x > b.x : y > b.y); } - bool operator>=(const Move & b) const { return (y == b.y ? x >= b.x : y >= b.y); } - bool operator==(const MoveSpecial & b) const { return (y == b); } - bool operator==(const Move & b) const { return (y == b.y && x == b.x); } - bool operator!=(const Move & b) const { return (y != b.y || x != b.x); } - bool operator!=(const MoveSpecial & b) const { return (y != b); } - Move operator+ (const Move & b) const { return Move(x + b.x, y + b.y); } - Move & operator+=(const Move & b) { y += b.y; x += b.x; return *this; } - Move operator- (const Move & b) const { return Move(x - b.x, y - b.y); } - Move & operator-=(const Move & b) { y -= b.y; x -= b.x; return *this; } - - int z() const { return (x - y); } - int dist(const Move & b) const { - return (abs(x - b.x) + abs(y - b.y) + abs(z() - b.z()))/2; - } -}; - -struct MoveScore : public Move { - int16_t score; - - MoveScore() : score(0) { } - MoveScore(MoveSpecial a) : Move(a), score(0) { } - MoveScore(int X, int Y, int s) : Move(X, Y), score(s) { } - MoveScore operator+ (const Move & b) const { return MoveScore(x + b.x, y + b.y, score); } -}; - -struct MoveValid : public Move { - int16_t xy; - - MoveValid() : Move(), xy(-1) { } - MoveValid(int x, int y, int XY) : Move(x,y), xy(XY) { } - MoveValid(const Move & m, int XY) : Move(m), xy(XY) { } - bool onboard() const { return xy != -1; } -}; - -struct MovePlayer : public Move { - char player; - - MovePlayer() : Move(M_UNKNOWN), player(0) { } - MovePlayer(const Move & m, char p = 0) : Move(m), player(p) { } -}; - - -struct PairMove { - Move a, b; - PairMove(Move A = M_UNKNOWN, Move B = M_UNKNOWN) : a(A), b(B) { } - PairMove(MoveSpecial A) : a(Move(A)), b(M_UNKNOWN) { } -}; diff --git a/rex/player.cpp b/rex/player.cpp deleted file mode 100644 index b517471..0000000 --- a/rex/player.cpp +++ /dev/null @@ -1,506 +0,0 @@ - -#include -#include - -#include "../lib/alarm.h" -#include "../lib/fileio.h" -#include "../lib/string.h" -#include "../lib/time.h" - -#include "board.h" -#include "player.h" - -const float Player::min_rave = 0.1; - -void Player::PlayerThread::run(){ - while(true){ - switch(player->threadstate){ - case Thread_Cancelled: //threads should exit - return; - - case Thread_Wait_Start: //threads are waiting to start - case Thread_Wait_Start_Cancelled: - player->runbarrier.wait(); - CAS(player->threadstate, Thread_Wait_Start, Thread_Running); - CAS(player->threadstate, Thread_Wait_Start_Cancelled, Thread_Cancelled); - break; - - case Thread_Wait_End: //threads are waiting to end - player->runbarrier.wait(); - CAS(player->threadstate, Thread_Wait_End, Thread_Wait_Start); - break; - - case Thread_Running: //threads are running - if(player->rootboard.won() >= 0 || player->root.outcome >= 0 || (player->maxruns > 0 && player->runs >= player->maxruns)){ //solved or finished runs - if(CAS(player->threadstate, Thread_Running, Thread_Wait_End) && player->root.outcome >= 0) - logerr("Solved as " + to_str((int)player->root.outcome) + "\n"); - break; - } - if(player->ctmem.memalloced() >= player->maxmem){ //out of memory, start garbage collection - CAS(player->threadstate, Thread_Running, Thread_GC); - break; - } - - INCR(player->runs); - iterate(); - break; - - case Thread_GC: //one thread is running garbage collection, the rest are waiting - case Thread_GC_End: //once done garbage collecting, go to wait_end instead of back to running - if(player->gcbarrier.wait()){ - Time starttime; - logerr("Starting player GC with limit " + to_str(player->gclimit) + " ... "); - uint64_t nodesbefore = player->nodes; - Board copy = player->rootboard; - player->garbage_collect(copy, & player->root); - Time gctime; - player->ctmem.compact(1.0, 0.75); - Time compacttime; - logerr(to_str(100.0*player->nodes/nodesbefore, 1) + " % of tree remains - " + - to_str((gctime - starttime)*1000, 0) + " msec gc, " + to_str((compacttime - gctime)*1000, 0) + " msec compact\n"); - - if(player->ctmem.meminuse() >= player->maxmem/2) - player->gclimit = (int)(player->gclimit*1.3); - else if(player->gclimit > player->rollouts*5) - player->gclimit = (int)(player->gclimit*0.9); //slowly decay to a minimum of 5 - - CAS(player->threadstate, Thread_GC, Thread_Running); - CAS(player->threadstate, Thread_GC_End, Thread_Wait_End); - } - player->gcbarrier.wait(); - break; - } - } -} - -Player::Node * Player::genmove(double time, int max_runs, bool flexible){ - time_used = 0; - int toplay = rootboard.toplay(); - - if(rootboard.won() >= 0 || (time <= 0 && max_runs == 0)) - return NULL; - - Time starttime; - - stop_threads(); - - if(runs) - logerr("Pondered " + to_str(runs) + " runs\n"); - - runs = 0; - maxruns = max_runs; - for(unsigned int i = 0; i < threads.size(); i++) - threads[i]->reset(); - - // if the move is forced and the time can be added to the clock, don't bother running at all - if(!flexible || root.children.num() != 1){ - //let them run! - start_threads(); - - Alarm timer; - if(time > 0) - timer(time - (Time() - starttime), std::bind(&Player::timedout, this)); - - //wait for the timer to stop them - runbarrier.wait(); - CAS(threadstate, Thread_Wait_End, Thread_Wait_Start); - assert(threadstate == Thread_Wait_Start); - } - - if(ponder && root.outcome < 0) - start_threads(); - - time_used = Time() - starttime; - -//return the best one - return return_move(& root, toplay); -} - - - -Player::Player() { - nodes = 0; - gclimit = 5; - time_used = 0; - - profile = false; - ponder = false; -//#ifdef SINGLE_THREAD ... make sure only 1 thread - numthreads = 1; - maxmem = 1000*1024*1024; - - msrave = -2; - msexplore = 0; - - explore = 0; - parentexplore = false; - ravefactor = 500; - decrrave = 0; - knowledge = true; - userave = 1; - useexplore = 1; - fpurgency = 1; - rollouts = 5; - dynwiden = 0; - logdynwiden = (dynwiden ? std::log(dynwiden) : 0); - - shortrave = false; - keeptree = true; - minimax = 2; - visitexpand = 1; - prunesymmetry = false; - gcsolved = 100000; - - localreply = 5; - locality = 5; - connect = 20; - size = 0; - bridge = 100; - dists = 0; - - weightedrandom = 0; - rolloutpattern = true; - lastgoodreply = false; - instantwin = 0; - - for(int i = 0; i < 4096; i++) - gammas[i] = 1; - - //no threads started until a board is set - threadstate = Thread_Wait_Start; -} -Player::~Player(){ - stop_threads(); - - numthreads = 0; - reset_threads(); //shut down the theads properly - - root.dealloc(ctmem); - ctmem.compact(); -} -void Player::timedout() { - CAS(threadstate, Thread_Running, Thread_Wait_End); - CAS(threadstate, Thread_GC, Thread_GC_End); -} - -string Player::statestring(){ - switch(threadstate){ - case Thread_Cancelled: return "Thread_Wait_Cancelled"; - case Thread_Wait_Start: return "Thread_Wait_Start"; - case Thread_Wait_Start_Cancelled: return "Thread_Wait_Start_Cancelled"; - case Thread_Running: return "Thread_Running"; - case Thread_GC: return "Thread_GC"; - case Thread_GC_End: return "Thread_GC_End"; - case Thread_Wait_End: return "Thread_Wait_End"; - } - return "Thread_State_Unknown!!!"; -} - -void Player::stop_threads(){ - if(threadstate != Thread_Wait_Start){ - timedout(); - runbarrier.wait(); - CAS(threadstate, Thread_Wait_End, Thread_Wait_Start); - assert(threadstate == Thread_Wait_Start); - } -} - -void Player::start_threads(){ - assert(threadstate == Thread_Wait_Start); - runbarrier.wait(); - CAS(threadstate, Thread_Wait_Start, Thread_Running); -} - -void Player::reset_threads(){ //start and end with threadstate = Thread_Wait_Start - assert(threadstate == Thread_Wait_Start); - -//wait for them to all get to the barrier - assert(CAS(threadstate, Thread_Wait_Start, Thread_Wait_Start_Cancelled)); - runbarrier.wait(); - -//make sure they exited cleanly - for(unsigned int i = 0; i < threads.size(); i++){ - threads[i]->join(); - delete threads[i]; - } - - threads.clear(); - - threadstate = Thread_Wait_Start; - - runbarrier.reset(numthreads + 1); - gcbarrier.reset(numthreads); - -//start new threads - for(int i = 0; i < numthreads; i++) - threads.push_back(new PlayerUCT(this)); -} - -void Player::set_ponder(bool p){ - if(ponder != p){ - ponder = p; - stop_threads(); - - if(ponder) - start_threads(); - } -} - -void Player::set_board(const Board & board){ - stop_threads(); - - nodes -= root.dealloc(ctmem); - root = Node(); - root.exp.addwins(visitexpand+1); - - rootboard = board; - - reset_threads(); //needed since the threads aren't started before a board it set - - if(ponder) - start_threads(); -} -void Player::move(const Move & m){ - stop_threads(); - - uword nodesbefore = nodes; - - if(keeptree && root.children.num() > 0){ - Node child; - - for(Node * i = root.children.begin(); i != root.children.end(); i++){ - if(i->move == m){ - child = *i; //copy the child experience to temp - child.swap_tree(*i); //move the child tree to temp - break; - } - } - - nodes -= root.dealloc(ctmem); - root = child; - root.swap_tree(child); - - if(nodesbefore > 0) - logerr("Nodes before: " + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); - }else{ - nodes -= root.dealloc(ctmem); - root = Node(); - root.move = m; - } - assert(nodes == root.size()); - - rootboard.move(m); - - root.exp.addwins(visitexpand+1); //+1 to compensate for the virtual loss - if(rootboard.won() < 0) - root.outcome = -3; - - if(ponder) - start_threads(); -} - -double Player::gamelen(){ - DepthStats len; - for(unsigned int i = 0; i < threads.size(); i++) - len += threads[i]->gamelen; - return len.avg(); -} - -vector Player::get_pv(){ - vector pv; - - Node * r, * n = & root; - char turn = rootboard.toplay(); - while(!n->children.empty()){ - r = return_move(n, turn); - if(!r) break; - pv.push_back(r->move); - turn = 3 - turn; - n = r; - } - - if(pv.size() == 0) - pv.push_back(Move(M_RESIGN)); - - return pv; -} - -Player::Node * Player::return_move(Node * node, int toplay) const { - double val, maxval = -1000000000000.0; //1 trillion - - Node * ret = NULL, - * child = node->children.begin(), - * end = node->children.end(); - - for( ; child != end; child++){ - if(child->outcome >= 0){ - if(child->outcome == toplay) val = 800000000000.0 - child->exp.num(); //shortest win - else if(child->outcome == 0) val = -400000000000.0 + child->exp.num(); //longest tie - else val = -800000000000.0 + child->exp.num(); //longest loss - }else{ //not proven - if(msrave == -1) //num simulations - val = child->exp.num(); - else if(msrave == -2) //num wins - val = child->exp.sum(); - else - val = child->value(msrave, 0, 0) - msexplore*sqrt(log(node->exp.num())/(child->exp.num() + 1)); - } - - if(maxval < val){ - maxval = val; - ret = child; - } - } - -//set bestmove, but don't touch outcome, if it's solved that will already be set, otherwise it shouldn't be set - if(ret){ - node->bestmove = ret->move; - }else if(node->bestmove == M_UNKNOWN){ - // TODO: Is this needed? -// SolverAB solver; -// solver.set_board(rootboard); -// solver.solve(0.1); -// node->bestmove = solver.bestmove; - } - - assert(node->bestmove != M_UNKNOWN); - - return ret; -} - -void Player::garbage_collect(Board & board, Node * node){ - Node * child = node->children.begin(), - * end = node->children.end(); - - int toplay = board.toplay(); - for( ; child != end; child++){ - if(child->children.num() == 0) - continue; - - if( (node->outcome >= 0 && child->exp.num() > gcsolved && (node->outcome != toplay || child->outcome == toplay || child->outcome == 0)) || //parent is solved, only keep the proof tree, plus heavy draws - (node->outcome < 0 && child->exp.num() > (child->outcome >= 0 ? gcsolved : gclimit)) ){ // only keep heavy nodes, with different cutoffs for solved and unsolved - board.set(child->move); - garbage_collect(board, child); - board.unset(child->move); - }else{ - nodes -= child->dealloc(ctmem); - } - } -} - -Player::Node * Player::find_child(Node * node, const Move & move){ - for(Node * i = node->children.begin(); i != node->children.end(); i++) - if(i->move == move) - return i; - - return NULL; -} - -void Player::gen_hgf(Board & board, Node * node, unsigned int limit, unsigned int depth, FILE * fd){ - string s = string("\n") + string(depth, ' ') + "(;" + (board.toplay() == 2 ? "W" : "B") + "[" + node->move.to_s() + "]" + - "C[mcts, sims:" + to_str(node->exp.num()) + ", avg:" + to_str(node->exp.avg(), 4) + ", outcome:" + to_str((int)(node->outcome)) + ", best:" + node->bestmove.to_s() + "]"; - fprintf(fd, "%s", s.c_str()); - - Node * child = node->children.begin(), - * end = node->children.end(); - - int toplay = board.toplay(); - - bool children = false; - for( ; child != end; child++){ - if(child->exp.num() >= limit && (toplay != node->outcome || child->outcome == node->outcome) ){ - board.set(child->move); - gen_hgf(board, child, limit, depth+1, fd); - board.unset(child->move); - children = true; - } - } - - if(children) - fprintf(fd, "\n%s", string(depth, ' ').c_str()); - fprintf(fd, ")"); -} - -void Player::create_children_simple(const Board & board, Node * node){ - assert(node->children.empty()); - - node->children.alloc(board.movesremain(), ctmem); - - Node * child = node->children.begin(), - * end = node->children.end(); - Board::MoveIterator moveit = board.moveit(prunesymmetry); - int nummoves = 0; - for(; !moveit.done() && child != end; ++moveit, ++child){ - *child = Node(*moveit); - nummoves++; - } - - if(prunesymmetry) - node->children.shrink(nummoves); //shrink the node to ignore the extra moves - else //both end conditions should happen in parallel - assert(moveit.done() && child == end); - - PLUS(nodes, node->children.num()); -} - -//reads the format from gen_hgf. -void Player::load_hgf(Board board, Node * node, FILE * fd){ - char c, buf[101]; - - eat_whitespace(fd); - - assert(fscanf(fd, "(;%c[%100[^]]]", &c, buf) > 0); - - assert(board.toplay() == (c == 'W' ? 1 : 2)); - node->move = Move(buf); - board.move(node->move); - - assert(fscanf(fd, "C[%100[^]]]", buf) > 0); - - vecstr entry, parts = explode(string(buf), ", "); - assert(parts[0] == "mcts"); - - entry = explode(parts[1], ":"); - assert(entry[0] == "sims"); - uword sims = from_str(entry[1]); - - entry = explode(parts[2], ":"); - assert(entry[0] == "avg"); - double avg = from_str(entry[1]); - - uword wins = sims*avg; - node->exp.addwins(wins); - node->exp.addlosses(sims - wins); - - entry = explode(parts[3], ":"); - assert(entry[0] == "outcome"); - node->outcome = from_str(entry[1]); - - entry = explode(parts[4], ":"); - assert(entry[0] == "best"); - node->bestmove = Move(entry[1]); - - - eat_whitespace(fd); - - if(fpeek(fd) != ')'){ - create_children_simple(board, node); - - while(fpeek(fd) != ')'){ - Node child; - load_hgf(board, & child, fd); - - Node * i = find_child(node, child.move); - *i = child; //copy the child experience to the tree - i->swap_tree(child); //move the child subtree to the tree - - assert(child.children.empty()); - - eat_whitespace(fd); - } - } - - eat_char(fd, ')'); - - return; -} diff --git a/rex/player.h b/rex/player.h deleted file mode 100644 index 9741a1a..0000000 --- a/rex/player.h +++ /dev/null @@ -1,304 +0,0 @@ - -#pragma once - -//A Monte-Carlo Tree Search based player - -#include -#include - -#include "../lib/compacttree.h" -#include "../lib/depthstats.h" -#include "../lib/exppair.h" -#include "../lib/log.h" -#include "../lib/thread.h" -#include "../lib/time.h" -#include "../lib/types.h" -#include "../lib/xorshift.h" - -#include "board.h" -#include "lbdist.h" -#include "move.h" -#include "movelist.h" -#include "policy_bridge.h" -#include "policy_instantwin.h" -#include "policy_lastgoodreply.h" -#include "policy_random.h" - - -class Player { -public: - - struct Node { - public: - ExpPair rave; - ExpPair exp; - int16_t know; - int8_t outcome; - uint8_t proofdepth; - Move move; - Move bestmove; //if outcome is set, then bestmove is the way to get there - CompactTree::Children children; -// int padding; - //seems to need padding to multiples of 8 bytes or it segfaults? - //don't forget to update the copy constructor/operator - - Node() : know(0), outcome(-3), proofdepth(0) { } - Node(const Move & m, char o = -3) : know(0), outcome( o), proofdepth(0), move(m) { } - Node(const Node & n) { *this = n; } - Node & operator = (const Node & n){ - if(this != & n){ //don't copy to self - //don't copy to a node that already has children - assert(children.empty()); - - rave = n.rave; - exp = n.exp; - know = n.know; - move = n.move; - bestmove = n.bestmove; - outcome = n.outcome; - proofdepth = n.proofdepth; - //children = n.children; ignore the children, they need to be swap_tree'd in - } - return *this; - } - - void swap_tree(Node & n){ - children.swap(n.children); - } - - void print() const { - printf("%s\n", to_s().c_str()); - } - string to_s() const { - return "Node: move " + move.to_s() + - ", exp " + to_str(exp.avg(), 2) + "/" + to_str(exp.num()) + - ", rave " + to_str(rave.avg(), 2) + "/" + to_str(rave.num()) + - ", know " + to_str(know) + - ", outcome " + to_str(outcome) + "/" + to_str(proofdepth) + - ", best " + bestmove.to_s() + - ", children " + to_str(children.num()); - } - - unsigned int size() const { - unsigned int num = children.num(); - - if(children.num()) - for(Node * i = children.begin(); i != children.end(); i++) - num += i->size(); - - return num; - } - - ~Node(){ - assert(children.empty()); - } - - unsigned int alloc(unsigned int num, CompactTree & ct){ - return children.alloc(num, ct); - } - unsigned int dealloc(CompactTree & ct){ - unsigned int num = 0; - - if(children.num()) - for(Node * i = children.begin(); i != children.end(); i++) - num += i->dealloc(ct); - num += children.dealloc(ct); - - return num; - } - - //new way, more standard way of changing over from rave scores to real scores - float value(float ravefactor, bool knowledge, float fpurgency){ - float val = fpurgency; - float expnum = exp.num(); - float ravenum = rave.num(); - - if(ravefactor <= min_rave){ - if(expnum > 0) - val = exp.avg(); - }else if(ravenum > 0 || expnum > 0){ - float alpha = ravefactor/(ravefactor + expnum); -// float alpha = sqrt(ravefactor/(ravefactor + 3.0f*expnum)); -// float alpha = ravenum/(expnum + ravenum + expnum*ravenum*ravefactor); - - val = 0; - if(ravenum > 0) val += alpha*rave.avg(); - if(expnum > 0) val += (1.0f-alpha)*exp.avg(); - } - - if(knowledge && know > 0){ - if(expnum <= 1) - val += 0.01f * know; - else if(expnum < 1000) //knowledge is only useful with little experience - val += 0.01f * know / sqrt(expnum); - } - - return val; - } - }; - - class PlayerThread { - protected: - public: - mutable XORShift_float unitrand; - Thread thread; - Player * player; - public: - DepthStats treelen, gamelen; - double times[4]; //time spent in each of the stages - - PlayerThread() {} - virtual ~PlayerThread() { } - virtual void reset() { } - int join(){ return thread.join(); } - void run(); //thread runner, calls iterate on each iteration - virtual void iterate() { } //handles each iteration - }; - - class PlayerUCT : public PlayerThread { - LastGoodReply last_good_reply; - RandomPolicy random_policy; - ProtectBridge protect_bridge; - InstantWin instant_wins; - - bool use_rave; //whether to use rave for this simulation - bool use_explore; //whether to use exploration for this simulation - LBDists dists; //holds the distances to the various non-ring wins as a heuristic for the minimum moves needed to win - MoveList movelist; - int stage; //which of the four MCTS stages is it on - Time timestamps[4]; //timestamps for the beginning, before child creation, before rollout, after rollout - - public: - PlayerUCT(Player * p) : PlayerThread() { - player = p; - reset(); - thread(bind(&PlayerUCT::run, this)); - } - - void reset(){ - treelen.reset(); - gamelen.reset(); - - use_rave = false; - use_explore = false; - - for(int a = 0; a < 4; a++) - times[a] = 0; - } - - private: - void iterate(); - void walk_tree(Board & board, Node * node, int depth); - bool create_children(Board & board, Node * node, int toplay); - void add_knowledge(Board & board, Node * node, Node * child); - Node * choose_move(const Node * node, int toplay, int remain) const; - void update_rave(const Node * node, int toplay); - bool test_bridge_probe(const Board & board, const Move & move, const Move & test) const; - - int rollout(Board & board, Move move, int depth); - Move rollout_choose_move(Board & board, const Move & prev); - Move rollout_pattern(const Board & board, const Move & move); - }; - - -public: - - static const float min_rave; - - bool ponder; //think during opponents time? - int numthreads; //number of player threads to run - u64 maxmem; //maximum memory for the tree in bytes - bool profile; //count how long is spent in each stage of MCTS -//final move selection - float msrave; //rave factor in final move selection, -1 means use number instead of value - float msexplore; //the UCT constant in final move selection -//tree traversal - bool parentexplore; // whether to multiple exploration by the parents winrate - float explore; //greater than one favours exploration, smaller than one favours exploitation - float ravefactor; //big numbers favour rave scores, small ignore it - float decrrave; //decrease rave over time, add this value for each empty position on the board - bool knowledge; //whether to include knowledge - float userave; //what probability to use rave - float useexplore; //what probability to use UCT exploration - float fpurgency; //what value to return for a move that hasn't been played yet - int rollouts; //number of rollouts to run after the tree traversal - float dynwiden; //dynamic widening, look at first log_dynwiden(experience) number of children, 0 to disable - float logdynwiden; // = log(dynwiden), cached for performance -//tree building - bool shortrave; //only update rave values on short rollouts - bool keeptree; //reuse the tree from the previous move - int minimax; //solve the minimax tree within the uct tree - uint visitexpand;//number of visits before expanding a node - bool prunesymmetry; //prune symmetric children from the move list, useful for proving but likely not for playing - uint gcsolved; //garbage collect solved nodes or keep them in the tree, assuming they meet the required amount of work -//knowledge - int localreply; //boost for a local reply, ie a move near the previous move - int locality; //boost for playing near previous stones - int connect; //boost for having connections to edges and corners - int size; //boost for large groups - int bridge; //boost replying to a probe at a bridge - int dists; //boost based on minimum number of stones needed to finish a non-ring win -//rollout - int weightedrandom; //use weighted random for move ordering based on gammas - bool rolloutpattern; //play the response to a virtual connection threat in rollouts - int lastgoodreply; //use the last-good-reply rollout heuristic - int instantwin; //how deep to look for instant wins in rollouts - - float gammas[4096]; //pattern weights for weighted random - - Board rootboard; - Node root; - uword nodes; - int gclimit; //the minimum experience needed to not be garbage collected - - uint64_t runs, maxruns; - - CompactTree ctmem; - - enum ThreadState { - Thread_Cancelled, //threads should exit - Thread_Wait_Start, //threads are waiting to start - Thread_Wait_Start_Cancelled, //once done waiting, go to cancelled instead of running - Thread_Running, //threads are running - Thread_GC, //one thread is running garbage collection, the rest are waiting - Thread_GC_End, //once done garbage collecting, go to wait_end instead of back to running - Thread_Wait_End, //threads are waiting to end - }; - volatile ThreadState threadstate; - vector threads; - Barrier runbarrier, gcbarrier; - - double time_used; - - Player(); - ~Player(); - - void timedout(); - - string statestring(); - - void stop_threads(); - void start_threads(); - void reset_threads(); - - void set_ponder(bool p); - void set_board(const Board & board); - - void move(const Move & m); - - double gamelen(); - - Node * genmove(double time, int max_runs, bool flexible); - vector get_pv(); - void garbage_collect(Board & board, Node * node); //destroys the board, so pass in a copy - - bool do_backup(Node * node, Node * backup, int toplay); - - Node * find_child(Node * node, const Move & move); - void create_children_simple(const Board & board, Node * node); - void gen_hgf(Board & board, Node * node, unsigned int limit, unsigned int depth, FILE * fd); - void load_hgf(Board board, Node * node, FILE * fd); - -protected: - Node * return_move(Node * node, int toplay) const; -}; diff --git a/rex/playeruct.cpp b/rex/playeruct.cpp deleted file mode 100644 index 55bc5e2..0000000 --- a/rex/playeruct.cpp +++ /dev/null @@ -1,449 +0,0 @@ - -#include -#include - -#include "../lib/string.h" - -#include "player.h" - -void Player::PlayerUCT::iterate(){ - if(player->profile){ - timestamps[0] = Time(); - stage = 0; - } - - movelist.reset(&(player->rootboard)); - player->root.exp.addvloss(); - Board copy = player->rootboard; - use_rave = (unitrand() < player->userave); - use_explore = (unitrand() < player->useexplore); - walk_tree(copy, & player->root, 0); - player->root.exp.addv(movelist.getexp(3-player->rootboard.toplay())); - - if(player->profile){ - times[0] += timestamps[1] - timestamps[0]; - times[1] += timestamps[2] - timestamps[1]; - times[2] += timestamps[3] - timestamps[2]; - times[3] += Time() - timestamps[3]; - } -} - -void Player::PlayerUCT::walk_tree(Board & board, Node * node, int depth){ - int toplay = board.toplay(); - - if(!node->children.empty() && node->outcome < 0){ - //choose a child and recurse - Node * child; - do{ - int remain = board.movesremain(); - child = choose_move(node, toplay, remain); - - if(child->outcome < 0){ - movelist.addtree(child->move, toplay); - - if(!board.move(child->move)){ - logerr("move failed: " + child->move.to_s() + "\n" + board.to_s(false)); - assert(false && "move failed"); - } - - child->exp.addvloss(); //balanced out after rollouts - - walk_tree(board, child, depth+1); - - child->exp.addv(movelist.getexp(toplay)); - - if(!player->do_backup(node, child, toplay) && //not solved - player->ravefactor > min_rave && //using rave - node->children.num() > 1 && //not a macro move - 50*remain*(player->ravefactor + player->decrrave*remain) > node->exp.num()) //rave is still significant - update_rave(node, toplay); - - return; - } - }while(!player->do_backup(node, child, toplay)); - - return; - } - - if(player->profile && stage == 0){ - stage = 1; - timestamps[1] = Time(); - } - - int won = (player->minimax ? node->outcome : board.won()); - - //if it's not already decided - if(won < 0){ - //create children if valid - if(node->exp.num() >= player->visitexpand+1 && create_children(board, node, toplay)){ - walk_tree(board, node, depth); - return; - } - - if(player->profile){ - stage = 2; - timestamps[2] = Time(); - } - - //do random game on this node - random_policy.prepare(board); - for(int i = 0; i < player->rollouts; i++){ - Board copy = board; - rollout(copy, node->move, depth); - } - }else{ - movelist.finishrollout(won); //got to a terminal state, it's worth recording - } - - treelen.add(depth); - - movelist.subvlosses(1); - - if(player->profile){ - timestamps[3] = Time(); - if(stage == 1) - timestamps[2] = timestamps[3]; - stage = 3; - } - - return; -} - -bool sort_node_know(const Player::Node & a, const Player::Node & b){ - return (a.know > b.know); -} - -bool Player::PlayerUCT::create_children(Board & board, Node * node, int toplay){ - if(!node->children.lock()) - return false; - - if(player->dists){ - dists.run(&board, (player->dists > 0), toplay); - } - - CompactTree::Children temp; - temp.alloc(board.movesremain(), player->ctmem); - - int losses = 0; - - Node * child = temp.begin(), - * end = temp.end(), - * loss = NULL; - Board::MoveIterator move = board.moveit(player->prunesymmetry); - int nummoves = 0; - for(; !move.done() && child != end; ++move, ++child){ - *child = Node(*move); - - if(player->minimax){ - child->outcome = board.test_win(*move); - - if(player->minimax >= 2 && board.test_win(*move, 3 - board.toplay()) > 0){ - losses++; - loss = child; - } - - if(child->outcome == toplay){ //proven win from here, don't need children - node->outcome = child->outcome; - node->proofdepth = 1; - node->bestmove = *move; - node->children.unlock(); - temp.dealloc(player->ctmem); - return true; - } - } - - if(player->knowledge) - add_knowledge(board, node, child); - nummoves++; - } - - if(player->prunesymmetry) - temp.shrink(nummoves); //shrink the node to ignore the extra moves - else //both end conditions should happen in parallel - assert(move.done() && child == end); - - //Make a macro move, add experience to the move so the current simulation continues past this move - if(losses == 1){ - Node macro = *loss; - temp.dealloc(player->ctmem); - temp.alloc(1, player->ctmem); - macro.exp.addwins(player->visitexpand); - *(temp.begin()) = macro; - }else if(losses >= 2){ //proven loss, but at least try to block one of them - node->outcome = 3 - toplay; - node->proofdepth = 2; - node->bestmove = loss->move; - node->children.unlock(); - temp.dealloc(player->ctmem); - return true; - } - - if(player->dynwiden > 0) //sort in decreasing order by knowledge - sort(temp.begin(), temp.end(), sort_node_know); - - PLUS(player->nodes, temp.num()); - node->children.swap(temp); - assert(temp.unlock()); - - return true; -} - -Player::Node * Player::PlayerUCT::choose_move(const Node * node, int toplay, int remain) const { - float val, maxval = -1000000000; - float logvisits = log(node->exp.num()); - int dynwidenlim = (player->dynwiden > 0 ? (int)(logvisits/player->logdynwiden)+2 : 361); - - float raveval = use_rave * (player->ravefactor + player->decrrave*remain); - float explore = use_explore * player->explore; - if(player->parentexplore) - explore *= node->exp.avg(); - - Node * ret = NULL, - * child = node->children.begin(), - * end = node->children.end(); - - for(; child != end && dynwidenlim >= 0; child++){ - if(child->outcome >= 0){ - if(child->outcome == toplay) //return a win immediately - return child; - - val = (child->outcome == 0 ? -1 : -2); //-1 for tie so any unknown is better, -2 for loss so it's even worse - }else{ - val = child->value(raveval, player->knowledge, player->fpurgency); - if(explore > 0) - val += explore*sqrt(logvisits/(child->exp.num() + 1)); - dynwidenlim--; - } - - if(maxval < val){ - maxval = val; - ret = child; - } - } - - return ret; -} - -/* -backup in this order: - -6 win -5 win/draw -4 draw if draw/loss -3 win/draw/loss -2 draw -1 draw/loss -0 lose -return true if fully solved, false if it's unknown or partially unknown -*/ -bool Player::do_backup(Node * node, Node * backup, int toplay){ - int nodeoutcome = node->outcome; - if(nodeoutcome >= 0) //already proven, probably by a different thread - return true; - - if(backup->outcome == -3) //nothing proven by this child, so no chance - return false; - - - uint8_t proofdepth = backup->proofdepth; - if(backup->outcome != toplay){ - uint64_t sims = 0, bestsims = 0, outcome = 0, bestoutcome = 0; - backup = NULL; - - Node * child = node->children.begin(), - * end = node->children.end(); - - for( ; child != end; child++){ - int childoutcome = child->outcome; //save a copy to avoid race conditions - - if(proofdepth < child->proofdepth+1) - proofdepth = child->proofdepth+1; - - //these should be sorted in likelyness of matching, most likely first - if(childoutcome == -3){ // win/draw/loss - outcome = 3; - }else if(childoutcome == toplay){ //win - backup = child; - outcome = 6; - proofdepth = child->proofdepth+1; - break; - }else if(childoutcome == 3-toplay){ //loss - outcome = 0; - }else if(childoutcome == 0){ //draw - if(nodeoutcome == toplay-3) //draw/loss - outcome = 4; - else - outcome = 2; - }else if(childoutcome == -toplay){ //win/draw - outcome = 5; - }else if(childoutcome == toplay-3){ //draw/loss - outcome = 1; - }else{ - logerr("childoutcome == " + to_str(childoutcome) + "\n"); - assert(false && "How'd I get here? All outcomes should be tested above"); - } - - sims = child->exp.num(); - if(bestoutcome < outcome){ //better outcome is always preferable - bestoutcome = outcome; - bestsims = sims; - backup = child; - }else if(bestoutcome == outcome && ((outcome == 0 && bestsims < sims) || bestsims > sims)){ - //find long losses or easy wins/draws - bestsims = sims; - backup = child; - } - } - - if(bestoutcome == 3) //no win, but found an unknown - return false; - } - - if(CAS(node->outcome, nodeoutcome, backup->outcome)){ - node->bestmove = backup->move; - node->proofdepth = proofdepth; - }else //if it was in a race, try again, might promote a partial solve to full solve - return do_backup(node, backup, toplay); - - return (node->outcome >= 0); -} - -//update the rave score of all children that were played -void Player::PlayerUCT::update_rave(const Node * node, int toplay){ - Node * child = node->children.begin(), - * childend = node->children.end(); - - for( ; child != childend; ++child) - child->rave.addv(movelist.getrave(toplay, child->move)); -} - -void Player::PlayerUCT::add_knowledge(Board & board, Node * node, Node * child){ - if(player->localreply){ //boost for moves near the previous move - int dist = node->move.dist(child->move); - if(dist < 4) - child->know += player->localreply * (4 - dist); - } - - if(player->locality) //boost for moves near previous stones - child->know += player->locality * board.local(child->move, board.toplay()); - - Board::Cell cell; - if(player->connect || player->size) - cell = board.test_cell(child->move); - - if(player->connect) //boost for moves that connect to edges - child->know += player->connect * cell.numedges(); - - if(player->size) //boost for size of the group - child->know += player->size * cell.size; - - if(player->bridge && test_bridge_probe(board, node->move, child->move)) //boost for maintaining a virtual connection - child->know += player->bridge; - - if(player->dists) - child->know += abs(player->dists) * max(0, board.get_size() - dists.get(child->move, board.toplay())); -} - -//test whether this move is a forced reply to the opponent probing your virtual connections -bool Player::PlayerUCT::test_bridge_probe(const Board & board, const Move & move, const Move & test) const { - //TODO: switch to the same method as policy_bridge.h, maybe even share code - if(move.dist(test) != 1) - return false; - - bool equals = false; - - int state = 0; - int piece = 3 - board.get(move); - for(int i = 0; i < 8; i++){ - Move cur = move + neighbours[i % 6]; - - bool on = board.onboard(cur); - int v = 0; - if(on) - v = board.get(cur); - - //state machine that progresses when it see the pattern, but counting borders as part of the pattern - if(state == 0){ - if(!on || v == piece) - state = 1; - //else state = 0; - }else if(state == 1){ - if(on){ - if(v == 0){ - state = 2; - equals = (test == cur); - }else if(v != piece) - state = 0; - //else (v==piece) => state = 1; - } - //else state = 1; - }else{ // state == 2 - if(!on || v == piece){ - if(equals) - return true; - state = 1; - }else{ - state = 0; - } - } - } - return false; -} - -/////////////////////////////////////////// - - -//play a random game starting from a board state, and return the results of who won -int Player::PlayerUCT::rollout(Board & board, Move move, int depth){ - int won; - - if(player->instantwin) - instant_wins.rollout_start(board, player->instantwin); - - random_policy.rollout_start(board); - - while((won = board.won()) < 0){ - int turn = board.toplay(); - - move = rollout_choose_move(board, move); - - movelist.addrollout(move, turn); - - assert2(board.move(move), "\n" + board.to_s(true) + "\n" + move.to_s()); - depth++; - } - - gamelen.add(depth); - - //update the last good reply table - if(player->lastgoodreply) - last_good_reply.rollout_end(board, movelist, won); - - movelist.finishrollout(won); - return won; -} - -Move Player::PlayerUCT::rollout_choose_move(Board & board, const Move & prev){ - //look for instant wins - if(player->instantwin){ - Move move = instant_wins.choose_move(board, prev); - if(move != M_UNKNOWN) - return move; - } - - //force a bridge reply - if(player->rolloutpattern){ - Move move = protect_bridge.choose_move(board, prev); - if(move != M_UNKNOWN) - return move; - } - - //reuse the last good reply - if(player->lastgoodreply){ - Move move = last_good_reply.choose_move(board, prev); - if(move != M_UNKNOWN) - return move; - } - - return random_policy.choose_move(board, prev); -} diff --git a/rex/policy.h b/rex/policy.h deleted file mode 100644 index 01309d8..0000000 --- a/rex/policy.h +++ /dev/null @@ -1,28 +0,0 @@ - -#pragma once - -#include "board.h" -#include "move.h" -#include "movelist.h" - -class Policy { -public: - Policy() { } - - // called before all the rollouts start - void prepare(const Board & board) { } - - // called at the beginning of each rollout. - void rollout_start(Board & board) { } - - // Give me a move to make, or M_UNKNOWN - Move choose_move(const Board & board, const Move & prev) { - return M_UNKNOWN; - } - - // A move was just made, here's the updated board - void move_end(const Board & board, const Move & prev) { } - - // Game over, here's who won - void rollout_end(const MoveList & movelist, int won) { } -}; diff --git a/rex/policy_instantwin.h b/rex/policy_instantwin.h deleted file mode 100644 index c3c1dfa..0000000 --- a/rex/policy_instantwin.h +++ /dev/null @@ -1,95 +0,0 @@ - -#pragma once - -#include "../lib/assert2.h" - -#include "board.h" -#include "move.h" -#include "policy.h" - - -class InstantWin : public Policy { - int max_rollout_moves; - int cur_rollout_moves; - - Move saved_loss; -public: - - InstantWin() : max_rollout_moves(10), cur_rollout_moves(0), saved_loss(M_UNKNOWN) { - } - - void rollout_start(Board & board, int max) { - if(max < 0) - max *= - board.get_size(); - max_rollout_moves = max; - - cur_rollout_moves = 0; - saved_loss = M_UNKNOWN; - } - - Move choose_move(const Board & board, const Move & prev) { - if(saved_loss != M_UNKNOWN) - return saved_loss; - - if(cur_rollout_moves++ >= max_rollout_moves) - return M_UNKNOWN; - - //must have an edge connection, or it has nothing to offer a group towards a win - const Board::Cell * c = board.cell(prev); - if(c->edge == 0) - return M_UNKNOWN; - - Move start, cur, loss = M_UNKNOWN; - int turn = 3 - board.toplay(); - - //find the first empty cell - int dir = -1; - for(int i = 0; i <= 5; i++){ - start = prev + neighbours[i]; - - if(!board.onboard(start) || board.get(start) != turn){ - dir = (i + 5) % 6; - break; - } - } - - if(dir == -1) //possible if it's in the middle of a ring - return M_UNKNOWN; - - cur = start; - -// logerr(board.to_s(true)); -// logerr(prev.to_s() + ":"); - - //follow contour of the current group looking for wins - do{ -// logerr(" " + cur.to_s()); - //check the current cell - if(board.onboard(cur) && board.get(cur) == 0 && board.test_win(cur, turn) > 0){ -// logerr(" loss"); - if(loss == M_UNKNOWN){ - loss = cur; - }else if(loss != cur){ - saved_loss = loss; - return cur; //game over, two wins found for opponent - } - } - - //advance to the next cell - for(int i = 5; i <= 9; i++){ - int nd = (dir + i) % 6; - Move next = cur + neighbours[nd]; - - if(!board.onboard(next) || board.get(next) != turn){ - cur = next; - dir = nd; - break; - } - } - }while(cur != start); //potentially skips part of it when the start is in a pocket, rare bug - -// logerr("\n"); - - return loss; // usually M_UNKNOWN - } -}; diff --git a/rex/solver.h b/rex/solver.h deleted file mode 100644 index d6e6240..0000000 --- a/rex/solver.h +++ /dev/null @@ -1,68 +0,0 @@ - -#pragma once - -//Interface for the various solvers - -#include "../lib/types.h" - -#include "board.h" - -class Solver { -public: - int outcome; // 0 = tie, 1 = white, 2 = black, -1 = white or tie, -2 = black or tie, anything else unknown - int maxdepth; - uint64_t nodes_seen; - double time_used; - Move bestmove; - - Solver() : outcome(-3), maxdepth(0), nodes_seen(0), time_used(0) { } - virtual ~Solver() { } - - virtual void solve(double time) { } - virtual void set_board(const Board & board, bool clear = true) { } - virtual void move(const Move & m) { } - virtual void set_memlimit(uint64_t lim) { } // in bytes - virtual void clear_mem() { } - -protected: - volatile bool timeout; - void timedout(){ timeout = true; } - Board rootboard; - - static int solve1ply(const Board & board, int & nodes) { - int outcome = -3; - int turn = board.toplay(); - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - ++nodes; - int won = board.test_win(*move, turn); - - if(won == turn) - return won; - if(won == 0) - outcome = 0; - } - return outcome; - } - - static int solve2ply(const Board & board, int & nodes) { - int losses = 0; - int outcome = -3; - int turn = board.toplay(), opponent = 3 - turn; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - ++nodes; - int won = board.test_win(*move, turn); - - if(won == turn) - return won; - if(won == 0) - outcome = 0; - - if(board.test_win(*move, opponent) > 0) - losses++; - } - if(losses >= 2) - return opponent; - return outcome; - } - -}; diff --git a/rex/solverab.cpp b/rex/solverab.cpp deleted file mode 100644 index 1abdf47..0000000 --- a/rex/solverab.cpp +++ /dev/null @@ -1,137 +0,0 @@ - -#include "../lib/alarm.h" -#include "../lib/log.h" -#include "../lib/time.h" - -#include "solverab.h" - -void SolverAB::solve(double time){ - reset(); - if(rootboard.won() >= 0){ - outcome = rootboard.won(); - return; - } - - if(TT == NULL && maxnodes) - TT = new ABTTNode[maxnodes]; - - Alarm timer(time, std::bind(&SolverAB::timedout, this)); - Time start; - - int turn = rootboard.toplay(); - - for(maxdepth = startdepth; !timeout; maxdepth++){ -// logerr("Starting depth " + to_str(maxdepth) + "\n"); - - //the first depth of negamax - int ret, alpha = -2, beta = 2; - for(Board::MoveIterator move = rootboard.moveit(true); !move.done(); ++move){ - nodes_seen++; - - Board next = rootboard; - next.move(*move); - - int value = -negamax(next, maxdepth - 1, -beta, -alpha); - - if(value > alpha){ - alpha = value; - bestmove = *move; - } - - if(alpha >= beta){ - ret = beta; - break; - } - } - ret = alpha; - - - if(ret){ - if( ret == -2){ outcome = (turn == 1 ? 2 : 1); bestmove = Move(M_NONE); } - else if(ret == 2){ outcome = turn; } - else /*-1 || 1*/ { outcome = 0; } - - break; - } - } - - time_used = Time() - start; -} - - -int SolverAB::negamax(const Board & board, const int depth, int alpha, int beta){ - if(board.won() >= 0) - return (board.won() ? -2 : -1); - - if(depth <= 0 || timeout) - return 0; - - int b = beta; - int first = true; - int value, losses = 0; - static const int lookup[6] = {0, 0, 0, 1, 2, 2}; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - nodes_seen++; - - hash_t hash = board.test_hash(*move); - if(int ttval = tt_get(hash)){ - value = ttval; - }else if(depth <= 2){ - value = lookup[board.test_win(*move)+3]; - - if(board.test_win(*move, 3 - board.toplay()) > 0) - losses++; - }else{ - Board next = board; - next.move(*move); - - value = -negamax(next, depth - 1, -b, -alpha); - - if(scout && value > alpha && value < beta && !first) // re-search - value = -negamax(next, depth - 1, -beta, -alpha); - } - tt_set(hash, value); - - if(value > alpha) - alpha = value; - - if(alpha >= beta) - return beta; - - if(scout){ - b = alpha + 1; // set up null window - first = false; - } - } - - if(losses >= 2) - return -2; - - return alpha; -} - -int SolverAB::negamax_outcome(const Board & board, const int depth){ - int abval = negamax(board, depth, -2, 2); - if( abval == 0) return -3; //unknown - else if(abval == 2) return board.toplay(); //win - else if(abval == -2) return 3 - board.toplay(); //loss - else return 0; //draw -} - -int SolverAB::tt_get(const Board & board){ - return tt_get(board.gethash()); -} -int SolverAB::tt_get(const hash_t & hash){ - if(!TT) return 0; - ABTTNode * node = & TT[hash % maxnodes]; - return (node->hash == hash ? node->value : 0); -} -void SolverAB::tt_set(const Board & board, int value){ - tt_set(board.gethash(), value); -} -void SolverAB::tt_set(const hash_t & hash, int value){ - if(!TT || value == 0) return; - ABTTNode * node = & TT[hash % maxnodes]; - node->hash = hash; - node->value = value; -} diff --git a/rex/solverab.h b/rex/solverab.h deleted file mode 100644 index 35ca7b9..0000000 --- a/rex/solverab.h +++ /dev/null @@ -1,72 +0,0 @@ - -#pragma once - -//An Alpha-beta solver, single threaded with an optional transposition table. - -#include "solver.h" - -class SolverAB : public Solver { - struct ABTTNode { - hash_t hash; - char value; - ABTTNode(hash_t h = 0, char v = 0) : hash(h), value(v) { } - }; - -public: - bool scout; - int startdepth; - - ABTTNode * TT; - uint64_t maxnodes, memlimit; - - SolverAB(bool Scout = false) { - scout = Scout; - startdepth = 2; - TT = NULL; - set_memlimit(100*1024*1024); - } - ~SolverAB() { } - - void set_board(const Board & board, bool clear = true){ - rootboard = board; - reset(); - } - void move(const Move & m){ - rootboard.move(m); - reset(); - } - void set_memlimit(uint64_t lim){ - memlimit = lim; - maxnodes = memlimit/sizeof(ABTTNode); - clear_mem(); - } - - void clear_mem(){ - reset(); - if(TT){ - delete[] TT; - TT = NULL; - } - } - void reset(){ - outcome = -3; - maxdepth = 0; - nodes_seen = 0; - time_used = 0; - bestmove = Move(M_UNKNOWN); - - timeout = false; - } - - void solve(double time); - -//return -2 for loss, -1,1 for tie, 0 for unknown, 2 for win, all from toplay's perspective - int negamax(const Board & board, const int depth, int alpha, int beta); - int negamax_outcome(const Board & board, const int depth); - - int tt_get(const hash_t & hash); - int tt_get(const Board & board); - void tt_set(const hash_t & hash, int val); - void tt_set(const Board & board, int val); -}; - diff --git a/rex/solverpns.cpp b/rex/solverpns.cpp deleted file mode 100644 index 7f11a1a..0000000 --- a/rex/solverpns.cpp +++ /dev/null @@ -1,213 +0,0 @@ - - -#include "../lib/alarm.h" -#include "../lib/log.h" -#include "../lib/time.h" - -#include "solverpns.h" - -void SolverPNS::solve(double time){ - if(rootboard.won() >= 0){ - outcome = rootboard.won(); - return; - } - - timeout = false; - Alarm timer(time, std::bind(&SolverPNS::timedout, this)); - Time start; - -// logerr("max nodes: " + to_str(memlimit/sizeof(PNSNode)) + ", max memory: " + to_str(memlimit/(1024*1024)) + " Mb\n"); - - run_pns(); - - if(root.phi == 0 && root.delta == LOSS){ //look for the winning move - for(PNSNode * i = root.children.begin() ; i != root.children.end(); i++){ - if(i->delta == 0){ - bestmove = i->move; - break; - } - } - outcome = rootboard.toplay(); - }else if(root.phi == 0 && root.delta == DRAW){ //look for the move to tie - for(PNSNode * i = root.children.begin() ; i != root.children.end(); i++){ - if(i->delta == DRAW){ - bestmove = i->move; - break; - } - } - outcome = 0; - }else if(root.delta == 0){ //loss - bestmove = M_NONE; - outcome = 3 - rootboard.toplay(); - }else{ //unknown - bestmove = M_UNKNOWN; - outcome = -3; - } - - time_used = Time() - start; -} - -void SolverPNS::run_pns(){ - while(!timeout && root.phi != 0 && root.delta != 0){ - if(!pns(rootboard, &root, 0, INF32/2, INF32/2)){ - logerr("Starting solver GC with limit " + to_str(gclimit) + " ... "); - - Time starttime; - garbage_collect(& root); - - Time gctime; - ctmem.compact(1.0, 0.75); - - Time compacttime; - logerr(to_str(100.0*ctmem.meminuse()/memlimit, 1) + " % of tree remains - " + - to_str((gctime - starttime)*1000, 0) + " msec gc, " + to_str((compacttime - gctime)*1000, 0) + " msec compact\n"); - - if(ctmem.meminuse() >= memlimit/2) - gclimit = (unsigned int)(gclimit*1.3); - else if(gclimit > 5) - gclimit = (unsigned int)(gclimit*0.9); //slowly decay to a minimum of 5 - } - } -} - -bool SolverPNS::pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td){ - iters++; - if(maxdepth < depth) - maxdepth = depth; - - if(node->children.empty()){ - if(ctmem.memalloced() >= memlimit) - return false; - - int numnodes = board.movesremain(); - nodes += node->alloc(numnodes, ctmem); - - if(lbdist) - dists.run(&board); - - int i = 0; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - int outcome, pd; - - if(ab){ - Board next = board; - next.move(*move); - - pd = 0; - outcome = (ab == 1 ? solve1ply(next, pd) : solve2ply(next, pd)); - nodes_seen += pd; - }else{ - outcome = board.test_win(*move); - pd = 1; - } - - if(lbdist && outcome < 0) - pd = dists.get(*move); - - node->children[i] = PNSNode(*move).outcome(outcome, board.toplay(), ties, pd); - - i++; - } - node->children.shrink(i); //if symmetry, there may be extra moves to ignore - - nodes_seen += i; - - updatePDnum(node); - - return true; - } - - bool mem; - do{ - PNSNode * child = node->children.begin(), - * child2 = node->children.begin(), - * childend = node->children.end(); - - uint32_t tpc, tdc; - - if(df){ - for(PNSNode * i = node->children.begin(); i != childend; i++){ - if(i->delta <= child->delta){ - child2 = child; - child = i; - }else if(i->delta < child2->delta){ - child2 = i; - } - } - - tpc = min(INF32/2, (td + child->phi - node->delta)); - tdc = min(tp, (uint32_t)(child2->delta*(1.0 + epsilon) + 1)); - }else{ - tpc = tdc = 0; - while(child->delta != node->phi) - child++; - } - - Board next = board; - next.move(child->move); - - uint64_t itersbefore = iters; - mem = pns(next, child, depth + 1, tpc, tdc); - child->work += iters - itersbefore; - - if(child->phi == 0 || child->delta == 0) //clear child's children - nodes -= child->dealloc(ctmem); - - if(updatePDnum(node) && !df) - break; - - }while(!timeout && mem && (!df || (node->phi < tp && node->delta < td))); - - return mem; -} - -bool SolverPNS::updatePDnum(PNSNode * node){ - PNSNode * i = node->children.begin(); - PNSNode * end = node->children.end(); - - uint32_t min = i->delta; - uint64_t sum = 0; - - bool win = false; - for( ; i != end; i++){ - win |= (i->phi == LOSS); - sum += i->phi; - if( min > i->delta) - min = i->delta; - } - - if(win) - sum = LOSS; - else if(sum >= INF32) - sum = INF32; - - if(min == node->phi && sum == node->delta){ - return false; - }else{ - if(sum == 0 && min == DRAW){ - node->phi = 0; - node->delta = DRAW; - }else{ - node->phi = min; - node->delta = sum; - } - return true; - } -} - -//removes the children of any node with less than limit work -void SolverPNS::garbage_collect(PNSNode * node){ - PNSNode * child = node->children.begin(); - PNSNode * end = node->children.end(); - - for( ; child != end; child++){ - if(child->terminal()){ //solved - //log heavy nodes? - nodes -= child->dealloc(ctmem); - }else if(child->work < gclimit){ //low work, ignore solvedness since it's trivial to re-solve - nodes -= child->dealloc(ctmem); - }else if(child->children.num() > 0){ - garbage_collect(child); - } - } -} diff --git a/rex/solverpns.h b/rex/solverpns.h deleted file mode 100644 index b040d82..0000000 --- a/rex/solverpns.h +++ /dev/null @@ -1,206 +0,0 @@ - -#pragma once - -//A single-threaded, tree based, proof number search solver. - -#include "../lib/compacttree.h" -#include "../lib/log.h" - -#include "lbdist.h" -#include "solver.h" - - -class SolverPNS : public Solver { - static const uint32_t LOSS = (1<<30)-1; - static const uint32_t DRAW = (1<<30)-2; - static const uint32_t INF32 = (1<<30)-3; -public: - - struct PNSNode { - uint32_t phi, delta; - uint64_t work; - Move move; - CompactTree::Children children; - - PNSNode() { } - PNSNode(int x, int y, int v = 1) : phi(v), delta(v), work(0), move(Move(x,y)) { } - PNSNode(const Move & m, int v = 1) : phi(v), delta(v), work(0), move(m) { } - PNSNode(int x, int y, int p, int d) : phi(p), delta(d), work(0), move(Move(x,y)) { } - PNSNode(const Move & m, int p, int d) : phi(p), delta(d), work(0), move(m) { } - - PNSNode(const PNSNode & n) { *this = n; } - PNSNode & operator = (const PNSNode & n){ - if(this != & n){ //don't copy to self - //don't copy to a node that already has children - assert(children.empty()); - - phi = n.phi; - delta = n.delta; - work = n.work; - move = n.move; - //don't copy the children - } - return *this; - } - - ~PNSNode(){ - assert(children.empty()); - } - - PNSNode & abval(int outcome, int toplay, int assign, int value = 1){ - if(assign && (outcome == 1 || outcome == -1)) - outcome = (toplay == assign ? 2 : -2); - - if( outcome == 0) { phi = value; delta = value; } - else if(outcome == 2) { phi = LOSS; delta = 0; } - else if(outcome == -2) { phi = 0; delta = LOSS; } - else /*(outcome 1||-1)*/ { phi = 0; delta = DRAW; } - return *this; - } - - PNSNode & outcome(int outcome, int toplay, int assign, int value = 1){ - if(assign && outcome == 0) - outcome = assign; - - if( outcome == -3) { phi = value; delta = value; } - else if(outcome == toplay) { phi = LOSS; delta = 0; } - else if(outcome == 3-toplay) { phi = 0; delta = LOSS; } - else /*(outcome == 0)*/ { phi = 0; delta = DRAW; } - return *this; - } - - bool terminal(){ return (phi == 0 || delta == 0); } - - unsigned int size() const { - unsigned int num = children.num(); - - for(PNSNode * i = children.begin(); i != children.end(); i++) - num += i->size(); - - return num; - } - - void swap_tree(PNSNode & n){ - children.swap(n.children); - } - - unsigned int alloc(unsigned int num, CompactTree & ct){ - return children.alloc(num, ct); - } - unsigned int dealloc(CompactTree & ct){ - unsigned int num = 0; - - for(PNSNode * i = children.begin(); i != children.end(); i++) - num += i->dealloc(ct); - num += children.dealloc(ct); - - return num; - } - }; - - -//memory management for PNS which uses a tree to store the nodes - uint64_t nodes, memlimit; - unsigned int gclimit; - CompactTree ctmem; - - uint64_t iters; - - int ab; // how deep of an alpha-beta search to run at each leaf node - bool df; // go depth first? - float epsilon; //if depth first, how wide should the threshold be? - int ties; //which player to assign ties to: 0 handle ties, 1 assign p1, 2 assign p2 - bool lbdist; - - PNSNode root; - LBDists dists; - - SolverPNS() { - ab = 2; - df = true; - epsilon = 0.25; - ties = 0; - lbdist = false; - gclimit = 5; - iters = 0; - - reset(); - - set_memlimit(100*1024*1024); - } - - ~SolverPNS(){ - root.dealloc(ctmem); - ctmem.compact(); - } - - void reset(){ - outcome = -3; - maxdepth = 0; - nodes_seen = 0; - time_used = 0; - bestmove = Move(M_UNKNOWN); - - timeout = false; - } - - void set_board(const Board & board, bool clear = true){ - rootboard = board; - reset(); - if(clear) - clear_mem(); - } - void move(const Move & m){ - rootboard.move(m); - reset(); - - - uint64_t nodesbefore = nodes; - - PNSNode child; - - for(PNSNode * i = root.children.begin(); i != root.children.end(); i++){ - if(i->move == m){ - child = *i; //copy the child experience to temp - child.swap_tree(*i); //move the child tree to temp - break; - } - } - - nodes -= root.dealloc(ctmem); - root = child; - root.swap_tree(child); - - if(nodesbefore > 0) - logerr(string("PNS Nodes before: ") + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); - - assert(nodes == root.size()); - - if(nodes == 0) - clear_mem(); - } - - void set_memlimit(uint64_t lim){ - memlimit = lim; - } - - void clear_mem(){ - reset(); - root.dealloc(ctmem); - ctmem.compact(); - root = PNSNode(0, 0, 1); - nodes = 0; - } - - void solve(double time); - -//basic proof number search building a tree - void run_pns(); - bool pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td); - -//update the phi and delta for the node - bool updatePDnum(PNSNode * node); - -//remove all the nodes with little work to free up some memory - void garbage_collect(PNSNode * node); -}; diff --git a/rex/solverpns2.cpp b/rex/solverpns2.cpp deleted file mode 100644 index 4995fc5..0000000 --- a/rex/solverpns2.cpp +++ /dev/null @@ -1,323 +0,0 @@ - - -#include "../lib/alarm.h" -#include "../lib/log.h" -#include "../lib/time.h" - -#include "solverpns2.h" - -void SolverPNS2::solve(double time){ - if(rootboard.won() >= 0){ - outcome = rootboard.won(); - return; - } - - start_threads(); - - timeout = false; - Alarm timer(time, std::bind(&SolverPNS2::timedout, this)); - Time start; - -// logerr("max memory: " + to_str(memlimit/(1024*1024)) + " Mb\n"); - - //wait for the timer to stop them - runbarrier.wait(); - CAS(threadstate, Thread_Wait_End, Thread_Wait_Start); - assert(threadstate == Thread_Wait_Start); - - if(root.phi == 0 && root.delta == LOSS){ //look for the winning move - for(PNSNode * i = root.children.begin() ; i != root.children.end(); i++){ - if(i->delta == 0){ - bestmove = i->move; - break; - } - } - outcome = rootboard.toplay(); - }else if(root.phi == 0 && root.delta == DRAW){ //look for the move to tie - for(PNSNode * i = root.children.begin() ; i != root.children.end(); i++){ - if(i->delta == DRAW){ - bestmove = i->move; - break; - } - } - outcome = 0; - }else if(root.delta == 0){ //loss - bestmove = M_NONE; - outcome = 3 - rootboard.toplay(); - }else{ //unknown - bestmove = M_UNKNOWN; - outcome = -3; - } - - time_used = Time() - start; -} - -void SolverPNS2::SolverThread::run(){ - while(true){ - switch(solver->threadstate){ - case Thread_Cancelled: //threads should exit - return; - - case Thread_Wait_Start: //threads are waiting to start - case Thread_Wait_Start_Cancelled: - solver->runbarrier.wait(); - CAS(solver->threadstate, Thread_Wait_Start, Thread_Running); - CAS(solver->threadstate, Thread_Wait_Start_Cancelled, Thread_Cancelled); - break; - - case Thread_Wait_End: //threads are waiting to end - solver->runbarrier.wait(); - CAS(solver->threadstate, Thread_Wait_End, Thread_Wait_Start); - break; - - case Thread_Running: //threads are running - if(solver->root.terminal()){ //solved - CAS(solver->threadstate, Thread_Running, Thread_Wait_End); - break; - } - if(solver->ctmem.memalloced() >= solver->memlimit){ //out of memory, start garbage collection - CAS(solver->threadstate, Thread_Running, Thread_GC); - break; - } - - pns(solver->rootboard, &solver->root, 0, INF32/2, INF32/2); - break; - - case Thread_GC: //one thread is running garbage collection, the rest are waiting - case Thread_GC_End: //once done garbage collecting, go to wait_end instead of back to running - if(solver->gcbarrier.wait()){ - logerr("Starting solver GC with limit " + to_str(solver->gclimit) + " ... "); - - Time starttime; - solver->garbage_collect(& solver->root); - - Time gctime; - solver->ctmem.compact(1.0, 0.75); - - Time compacttime; - logerr(to_str(100.0*solver->ctmem.meminuse()/solver->memlimit, 1) + " % of tree remains - " + - to_str((gctime - starttime)*1000, 0) + " msec gc, " + to_str((compacttime - gctime)*1000, 0) + " msec compact\n"); - - if(solver->ctmem.meminuse() >= solver->memlimit/2) - solver->gclimit = (unsigned int)(solver->gclimit*1.3); - else if(solver->gclimit > 5) - solver->gclimit = (unsigned int)(solver->gclimit*0.9); //slowly decay to a minimum of 5 - - CAS(solver->threadstate, Thread_GC, Thread_Running); - CAS(solver->threadstate, Thread_GC_End, Thread_Wait_End); - } - solver->gcbarrier.wait(); - break; - } - } -} - -void SolverPNS2::timedout() { - CAS(threadstate, Thread_Running, Thread_Wait_End); - CAS(threadstate, Thread_GC, Thread_GC_End); - timeout = true; -} - -string SolverPNS2::statestring(){ - switch(threadstate){ - case Thread_Cancelled: return "Thread_Wait_Cancelled"; - case Thread_Wait_Start: return "Thread_Wait_Start"; - case Thread_Wait_Start_Cancelled: return "Thread_Wait_Start_Cancelled"; - case Thread_Running: return "Thread_Running"; - case Thread_GC: return "Thread_GC"; - case Thread_GC_End: return "Thread_GC_End"; - case Thread_Wait_End: return "Thread_Wait_End"; - } - return "Thread_State_Unknown!!!"; -} - -void SolverPNS2::stop_threads(){ - if(threadstate != Thread_Wait_Start){ - timedout(); - runbarrier.wait(); - CAS(threadstate, Thread_Wait_End, Thread_Wait_Start); - assert(threadstate == Thread_Wait_Start); - } -} - -void SolverPNS2::start_threads(){ - assert(threadstate == Thread_Wait_Start); - runbarrier.wait(); - CAS(threadstate, Thread_Wait_Start, Thread_Running); -} - -void SolverPNS2::reset_threads(){ //start and end with threadstate = Thread_Wait_Start - assert(threadstate == Thread_Wait_Start); - -//wait for them to all get to the barrier - assert(CAS(threadstate, Thread_Wait_Start, Thread_Wait_Start_Cancelled)); - runbarrier.wait(); - -//make sure they exited cleanly - for(unsigned int i = 0; i < threads.size(); i++) - threads[i]->join(); - - threads.clear(); - - threadstate = Thread_Wait_Start; - - runbarrier.reset(numthreads + 1); - gcbarrier.reset(numthreads); - -//start new threads - for(int i = 0; i < numthreads; i++) - threads.push_back(new SolverThread(this)); -} - - -bool SolverPNS2::SolverThread::pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td){ - iters++; - if(solver->maxdepth < depth) - solver->maxdepth = depth; - - if(node->children.empty()){ - if(node->terminal()) - return true; - - if(solver->ctmem.memalloced() >= solver->memlimit) - return false; - - if(!node->children.lock()) - return false; - - int numnodes = board.movesremain(); - CompactTree::Children temp; - temp.alloc(numnodes, solver->ctmem); - PLUS(solver->nodes, numnodes); - - if(solver->lbdist) - dists.run(&board); - - int i = 0; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - int outcome, pd; - - if(solver->ab){ - Board next = board; - next.move(*move); - - pd = 0; - outcome = (solver->ab == 1 ? solve1ply(next, pd) : solve2ply(next, pd)); - PLUS(solver->nodes_seen, pd); - }else{ - outcome = board.test_win(*move); - pd = 1; - } - - if(solver->lbdist && outcome < 0) - pd = dists.get(*move); - - temp[i] = PNSNode(*move).outcome(outcome, board.toplay(), solver->ties, pd); - - i++; - } - temp.shrink(i); //if symmetry, there may be extra moves to ignore - node->children.swap(temp); - assert(temp.unlock()); - - PLUS(solver->nodes_seen, i); - - updatePDnum(node); - - return true; - } - - bool mem; - do{ - PNSNode * child = node->children.begin(), - * child2 = node->children.begin(), - * childend = node->children.end(); - - uint32_t tpc, tdc; - - if(solver->df){ - for(PNSNode * i = node->children.begin(); i != childend; i++){ - if(i->refdelta() <= child->refdelta()){ - child2 = child; - child = i; - }else if(i->refdelta() < child2->refdelta()){ - child2 = i; - } - } - - tpc = min(INF32/2, (td + child->phi - node->delta)); - tdc = min(tp, (uint32_t)(child2->delta*(1.0 + solver->epsilon) + 1)); - }else{ - tpc = tdc = 0; - for(PNSNode * i = node->children.begin(); i != childend; i++) - if(child->refdelta() > i->refdelta()) - child = i; - } - - Board next = board; - next.move(child->move); - - child->ref(); - uint64_t itersbefore = iters; - mem = pns(next, child, depth + 1, tpc, tdc); - child->deref(); - PLUS(child->work, iters - itersbefore); - - if(updatePDnum(node) && !solver->df) - break; - - }while(!solver->timeout && mem && (!solver->df || (node->phi < tp && node->delta < td))); - - return mem; -} - -bool SolverPNS2::SolverThread::updatePDnum(PNSNode * node){ - PNSNode * i = node->children.begin(); - PNSNode * end = node->children.end(); - - uint32_t min = i->delta; - uint64_t sum = 0; - - bool win = false; - for( ; i != end; i++){ - win |= (i->phi == LOSS); - sum += i->phi; - if( min > i->delta) - min = i->delta; - } - - if(win) - sum = LOSS; - else if(sum >= INF32) - sum = INF32; - - if(min == node->phi && sum == node->delta){ - return false; - }else{ - if(sum == 0 && min == DRAW){ - node->phi = 0; - node->delta = DRAW; - }else{ - node->phi = min; - node->delta = sum; - } - return true; - } -} - -//removes the children of any node with less than limit work -void SolverPNS2::garbage_collect(PNSNode * node){ - PNSNode * child = node->children.begin(); - PNSNode * end = node->children.end(); - - for( ; child != end; child++){ - if(child->terminal()){ //solved - //log heavy nodes? - PLUS(nodes, -child->dealloc(ctmem)); - }else if(child->work < gclimit){ //low work, ignore solvedness since it's trivial to re-solve - PLUS(nodes, -child->dealloc(ctmem)); - }else if(child->children.num() > 0){ - garbage_collect(child); - } - } -} diff --git a/rex/solverpns2.h b/rex/solverpns2.h deleted file mode 100644 index 5af5d1d..0000000 --- a/rex/solverpns2.h +++ /dev/null @@ -1,265 +0,0 @@ - -#pragma once - -//A multi-threaded, tree based, proof number search solver. - -#include "../lib/compacttree.h" -#include "../lib/log.h" - -#include "lbdist.h" -#include "solver.h" - - -class SolverPNS2 : public Solver { - static const uint32_t LOSS = (1<<30)-1; - static const uint32_t DRAW = (1<<30)-2; - static const uint32_t INF32 = (1<<30)-3; -public: - - struct PNSNode { - static const uint16_t reflock = 1<<15; - uint32_t phi, delta; - uint64_t work; - uint16_t refcount; //how many threads are down this node - Move move; - CompactTree::Children children; - - PNSNode() { } - PNSNode(int x, int y, int v = 1) : phi(v), delta(v), work(0), refcount(0), move(Move(x,y)) { } - PNSNode(const Move & m, int v = 1) : phi(v), delta(v), work(0), refcount(0), move(m) { } - PNSNode(int x, int y, int p, int d) : phi(p), delta(d), work(0), refcount(0), move(Move(x,y)) { } - PNSNode(const Move & m, int p, int d) : phi(p), delta(d), work(0), refcount(0), move(m) { } - - PNSNode(const PNSNode & n) { *this = n; } - PNSNode & operator = (const PNSNode & n){ - if(this != & n){ //don't copy to self - //don't copy to a node that already has children - assert(children.empty()); - - phi = n.phi; - delta = n.delta; - work = n.work; - move = n.move; - //don't copy the children - } - return *this; - } - - ~PNSNode(){ - assert(children.empty()); - } - - PNSNode & abval(int outcome, int toplay, int assign, int value = 1){ - if(assign && (outcome == 1 || outcome == -1)) - outcome = (toplay == assign ? 2 : -2); - - if( outcome == 0) { phi = value; delta = value; } - else if(outcome == 2) { phi = LOSS; delta = 0; } - else if(outcome == -2) { phi = 0; delta = LOSS; } - else /*(outcome 1||-1)*/ { phi = 0; delta = DRAW; } - return *this; - } - - PNSNode & outcome(int outcome, int toplay, int assign, int value = 1){ - if(assign && outcome == 0) - outcome = assign; - - if( outcome == -3) { phi = value; delta = value; } - else if(outcome == toplay) { phi = LOSS; delta = 0; } - else if(outcome == 3-toplay) { phi = 0; delta = LOSS; } - else /*(outcome == 0)*/ { phi = 0; delta = DRAW; } - return *this; - } - - bool terminal(){ return (phi == 0 || delta == 0); } - - uint32_t refdelta() const { - return delta + refcount; - } - - void ref() { PLUS(refcount, 1); } - void deref(){ PLUS(refcount, -1); } - - unsigned int size() const { - unsigned int num = children.num(); - - for(PNSNode * i = children.begin(); i != children.end(); i++) - num += i->size(); - - return num; - } - - void swap_tree(PNSNode & n){ - children.swap(n.children); - } - - unsigned int alloc(unsigned int num, CompactTree & ct){ - return children.alloc(num, ct); - } - unsigned int dealloc(CompactTree & ct){ - unsigned int num = 0; - - for(PNSNode * i = children.begin(); i != children.end(); i++) - num += i->dealloc(ct); - num += children.dealloc(ct); - - return num; - } - }; - - class SolverThread { - protected: - public: - Thread thread; - SolverPNS2 * solver; - public: - uint64_t iters; - LBDists dists; //holds the distances to the various non-ring wins as a heuristic for the minimum moves needed to win - - SolverThread(SolverPNS2 * s) : solver(s), iters(0) { - thread(bind(&SolverThread::run, this)); - } - virtual ~SolverThread() { } - void reset(){ - iters = 0; - } - int join(){ return thread.join(); } - void run(); //thread runner - - //basic proof number search building a tree - bool pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td); - - //update the phi and delta for the node - bool updatePDnum(PNSNode * node); - }; - - -//memory management for PNS which uses a tree to store the nodes - uint64_t nodes, memlimit; - unsigned int gclimit; - CompactTree ctmem; - - enum ThreadState { - Thread_Cancelled, //threads should exit - Thread_Wait_Start, //threads are waiting to start - Thread_Wait_Start_Cancelled, //once done waiting, go to cancelled instead of running - Thread_Running, //threads are running - Thread_GC, //one thread is running garbage collection, the rest are waiting - Thread_GC_End, //once done garbage collecting, go to wait_end instead of back to running - Thread_Wait_End, //threads are waiting to end - }; - volatile ThreadState threadstate; - vector threads; - Barrier runbarrier, gcbarrier; - - - int ab; // how deep of an alpha-beta search to run at each leaf node - bool df; // go depth first? - float epsilon; //if depth first, how wide should the threshold be? - int ties; //which player to assign ties to: 0 handle ties, 1 assign p1, 2 assign p2 - bool lbdist; - int numthreads; - - PNSNode root; - LBDists dists; - - SolverPNS2() { - ab = 2; - df = true; - epsilon = 0.25; - ties = 0; - lbdist = false; - numthreads = 1; - gclimit = 5; - - reset(); - - set_memlimit(100*1024*1024); - - //no threads started until a board is set - threadstate = Thread_Wait_Start; - } - - ~SolverPNS2(){ - stop_threads(); - - numthreads = 0; - reset_threads(); //shut down the theads properly - - root.dealloc(ctmem); - ctmem.compact(); - } - - void reset(){ - outcome = -3; - maxdepth = 0; - nodes_seen = 0; - time_used = 0; - bestmove = Move(M_UNKNOWN); - - timeout = false; - } - - string statestring(); - void stop_threads(); - void start_threads(); - void reset_threads(); - void timedout(); - - void set_board(const Board & board, bool clear = true){ - rootboard = board; - reset(); - if(clear) - clear_mem(); - - reset_threads(); //needed since the threads aren't started before a board it set - } - void move(const Move & m){ - stop_threads(); - - rootboard.move(m); - reset(); - - - uint64_t nodesbefore = nodes; - - PNSNode child; - - for(PNSNode * i = root.children.begin(); i != root.children.end(); i++){ - if(i->move == m){ - child = *i; //copy the child experience to temp - child.swap_tree(*i); //move the child tree to temp - break; - } - } - - nodes -= root.dealloc(ctmem); - root = child; - root.swap_tree(child); - - if(nodesbefore > 0) - logerr(string("PNS Nodes before: ") + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); - - assert(nodes == root.size()); - - if(nodes == 0) - clear_mem(); - } - - void set_memlimit(uint64_t lim){ - memlimit = lim; - } - - void clear_mem(){ - reset(); - root.dealloc(ctmem); - ctmem.compact(); - root = PNSNode(0, 0, 1); - nodes = 0; - } - - void solve(double time); - -//remove all the nodes with little work to free up some memory - void garbage_collect(PNSNode * node); -}; diff --git a/rex/solverpns_tt.cpp b/rex/solverpns_tt.cpp deleted file mode 100644 index 0818e8c..0000000 --- a/rex/solverpns_tt.cpp +++ /dev/null @@ -1,282 +0,0 @@ - -#include "../lib/alarm.h" -#include "../lib/log.h" -#include "../lib/time.h" - -#include "solverpns_tt.h" - -void SolverPNSTT::solve(double time){ - if(rootboard.won() >= 0){ - outcome = rootboard.won(); - return; - } - - timeout = false; - Alarm timer(time, std::bind(&SolverPNSTT::timedout, this)); - Time start; - -// logerr("max nodes: " + to_str(maxnodes) + ", max memory: " + to_str(memlimit) + " Mb\n"); - - run_pns(); - - if(root.phi == 0 && root.delta == LOSS){ //look for the winning move - PNSNode * i = NULL; - for(Board::MoveIterator move = rootboard.moveit(true); !move.done(); ++move){ - i = tt(rootboard, *move); - if(i->delta == 0){ - bestmove = *move; - break; - } - } - outcome = rootboard.toplay(); - }else if(root.phi == 0 && root.delta == DRAW){ //look for the move to tie - PNSNode * i = NULL; - for(Board::MoveIterator move = rootboard.moveit(true); !move.done(); ++move){ - i = tt(rootboard, *move); - if(i->delta == DRAW){ - bestmove = *move; - break; - } - } - outcome = 0; - }else if(root.delta == 0){ //loss - bestmove = M_NONE; - outcome = 3 - rootboard.toplay(); - }else{ //unknown - bestmove = M_UNKNOWN; - outcome = -3; - } - - time_used = Time() - start; -} - -void SolverPNSTT::run_pns(){ - if(TT == NULL) - TT = new PNSNode[maxnodes]; - - while(!timeout && root.phi != 0 && root.delta != 0) - pns(rootboard, &root, 0, INF32/2, INF32/2); -} - -void SolverPNSTT::pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td){ - if(depth > maxdepth) - maxdepth = depth; - - do{ - PNSNode * child = NULL, - * child2 = NULL; - - Move move1, move2; - - uint32_t tpc, tdc; - - PNSNode * i = NULL; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - i = tt(board, *move); - if(child == NULL){ - child = child2 = i; - move1 = move2 = *move; - }else if(i->delta <= child->delta){ - child2 = child; - child = i; - move2 = move1; - move1 = *move; - }else if(i->delta < child2->delta){ - child2 = i; - move2 = *move; - } - } - - if(child->delta && child->phi){ //unsolved - if(df){ - tpc = min(INF32/2, (td + child->phi - node->delta)); - tdc = min(tp, (uint32_t)(child2->delta*(1.0 + epsilon) + 1)); - }else{ - tpc = tdc = 0; - } - - Board next = board; - next.move(move1); - pns(next, child, depth + 1, tpc, tdc); - - //just found a loss, try to copy proof to siblings - if(copyproof && child->delta == LOSS){ -// logerr("!" + move1.to_s() + " "); - int count = abs(copyproof); - for(Board::MoveIterator move = board.moveit(true); count-- && !move.done(); ++move){ - if(!tt(board, *move)->terminal()){ -// logerr("?" + move->to_s() + " "); - Board sibling = board; - sibling.move(*move); - copy_proof(next, sibling, move1, *move); - updatePDnum(sibling); - - if(copyproof < 0 && !tt(sibling)->terminal()) - break; - } - } - } - } - - if(updatePDnum(board, node) && !df) //must pass node to updatePDnum since it may refer to the root which isn't in the TT - break; - - }while(!timeout && node->phi && node->delta && (!df || (node->phi < tp && node->delta < td))); -} - -bool SolverPNSTT::updatePDnum(const Board & board, PNSNode * node){ - hash_t hash = board.gethash(); - - if(node == NULL) - node = TT + (hash % maxnodes); - - uint32_t min = LOSS; - uint64_t sum = 0; - - bool win = false; - PNSNode * i = NULL; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - i = tt(board, *move); - - win |= (i->phi == LOSS); - sum += i->phi; - if( min > i->delta) - min = i->delta; - } - - if(win) - sum = LOSS; - else if(sum >= INF32) - sum = INF32; - - if(hash == node->hash && min == node->phi && sum == node->delta){ - return false; - }else{ - node->hash = hash; //just in case it was overwritten by something else - if(sum == 0 && min == DRAW){ - node->phi = 0; - node->delta = DRAW; - }else{ - node->phi = min; - node->delta = sum; - } - return true; - } -} - -//source is a move that is a proven loss, and dest is an unproven sibling -//each has one move that the other doesn't, which are stored in smove and dmove -//if either move is used but only available in one board, the other is substituted -void SolverPNSTT::copy_proof(const Board & source, const Board & dest, Move smove, Move dmove){ - if(timeout || tt(source)->delta != LOSS || tt(dest)->terminal()) - return; - - //find winning move from the source tree - Move bestmove = M_UNKNOWN; - for(Board::MoveIterator move = source.moveit(true); !move.done(); ++move){ - if(tt(source, *move)->phi == LOSS){ - bestmove = *move; - break; - } - } - - if(bestmove == M_UNKNOWN) //due to transposition table collision - return; - - Board dest2 = dest; - - if(bestmove == dmove){ - assert(dest2.move(smove)); - smove = dmove = M_UNKNOWN; - }else{ - assert(dest2.move(bestmove)); - if(bestmove == smove) - smove = dmove = M_UNKNOWN; - } - - if(tt(dest2)->terminal()) - return; - - Board source2 = source; - assert(source2.move(bestmove)); - - if(source2.won() >= 0) - return; - - //test all responses - for(Board::MoveIterator move = dest2.moveit(true); !move.done(); ++move){ - if(tt(dest2, *move)->terminal()) - continue; - - Move csmove = smove, cdmove = dmove; - - Board source3 = source2, dest3 = dest2; - - if(*move == csmove){ - assert(source3.move(cdmove)); - csmove = cdmove = M_UNKNOWN; - }else{ - assert(source3.move(*move)); - if(*move == csmove) - csmove = cdmove = M_UNKNOWN; - } - - assert(dest3.move(*move)); - - copy_proof(source3, dest3, csmove, cdmove); - - updatePDnum(dest3); - } - - updatePDnum(dest2); -} - -SolverPNSTT::PNSNode * SolverPNSTT::tt(const Board & board){ - hash_t hash = board.gethash(); - - PNSNode * node = TT + (hash % maxnodes); - - if(node->hash != hash){ - int outcome, pd; - - if(ab){ - pd = 0; - outcome = (ab == 1 ? solve1ply(board, pd) : solve2ply(board, pd)); - nodes_seen += pd; - }else{ - outcome = board.won(); - pd = 1; - } - - *node = PNSNode(hash).outcome(outcome, board.toplay(), ties, pd); - nodes_seen++; - } - - return node; -} - -SolverPNSTT::PNSNode * SolverPNSTT::tt(const Board & board, Move move){ - hash_t hash = board.test_hash(move, board.toplay()); - - PNSNode * node = TT + (hash % maxnodes); - - if(node->hash != hash){ - int outcome, pd; - - if(ab){ - Board next = board; - next.move(move); - pd = 0; - outcome = (ab == 1 ? solve1ply(next, pd) : solve2ply(next, pd)); - nodes_seen += pd; - }else{ - outcome = board.test_win(move); - pd = 1; - } - - *node = PNSNode(hash).outcome(outcome, board.toplay(), ties, pd); - nodes_seen++; - } - - return node; -} diff --git a/rex/solverpns_tt.h b/rex/solverpns_tt.h deleted file mode 100644 index 95d344e..0000000 --- a/rex/solverpns_tt.h +++ /dev/null @@ -1,129 +0,0 @@ - -#pragma once - -//A single-threaded, transposition table based, proof number search solver. - -#include "../lib/zobrist.h" - -#include "solver.h" - -class SolverPNSTT : public Solver { - static const uint32_t LOSS = (1<<30)-1; - static const uint32_t DRAW = (1<<30)-2; - static const uint32_t INF32 = (1<<30)-3; -public: - - struct PNSNode { - hash_t hash; - uint32_t phi, delta; - - PNSNode() : hash(0), phi(0), delta(0) { } - PNSNode(hash_t h, int v = 1) : hash(h), phi(v), delta(v) { } - PNSNode(hash_t h, int p, int d) : hash(h), phi(p), delta(d) { } - - PNSNode & abval(int outcome, int toplay, int assign, int value = 1){ - if(assign && (outcome == 1 || outcome == -1)) - outcome = (toplay == assign ? 2 : -2); - - if( outcome == 0) { phi = value; delta = value; } //unknown - else if(outcome == 2) { phi = LOSS; delta = 0; } //win - else if(outcome == -2) { phi = 0; delta = LOSS; } //loss - else /*(outcome 1||-1)*/ { phi = 0; delta = DRAW; } //draw - return *this; - } - - PNSNode & outcome(int outcome, int toplay, int assign, int value = 1){ - if(assign && outcome == 0) - outcome = assign; - - if( outcome == -3) { phi = value; delta = value; } - else if(outcome == toplay) { phi = LOSS; delta = 0; } - else if(outcome == 3-toplay) { phi = 0; delta = LOSS; } - else /*(outcome == 0)*/ { phi = 0; delta = DRAW; } - return *this; - } - - bool terminal(){ return (phi == 0 || delta == 0); } - }; - - PNSNode root; - PNSNode * TT; - uint64_t maxnodes, memlimit; - - int ab; // how deep of an alpha-beta search to run at each leaf node - bool df; // go depth first? - float epsilon; //if depth first, how wide should the threshold be? - int ties; //which player to assign ties to: 0 handle ties, 1 assign p1, 2 assign p2 - int copyproof; //how many siblings to try to copy a proof to - - - SolverPNSTT() { - ab = 2; - df = true; - epsilon = 0.25; - ties = 0; - copyproof = 0; - - TT = NULL; - reset(); - - set_memlimit(100*1024*1024); - } - - ~SolverPNSTT(){ - if(TT){ - delete[] TT; - TT = NULL; - } - } - - void reset(){ - outcome = -3; - maxdepth = 0; - nodes_seen = 0; - time_used = 0; - bestmove = Move(M_UNKNOWN); - - timeout = false; - - root = PNSNode(rootboard.gethash(), 1); - } - - void set_board(const Board & board, bool clear = true){ - rootboard = board; - reset(); - if(clear) - clear_mem(); - } - void move(const Move & m){ - rootboard.move(m); - reset(); - } - void set_memlimit(uint64_t lim){ - memlimit = lim; - maxnodes = memlimit/sizeof(PNSNode); - clear_mem(); - } - - void clear_mem(){ - reset(); - if(TT){ - delete[] TT; - TT = NULL; - } - } - - void solve(double time); - -//basic proof number search building a tree - void run_pns(); - void pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td); - - void copy_proof(const Board & source, const Board & dest, Move smove, Move dmove); - -//update the phi and delta for the node - bool updatePDnum(const Board & board, PNSNode * node = NULL); - - PNSNode * tt(const Board & board); - PNSNode * tt(const Board & board, Move move); -}; diff --git a/y/agent.h b/y/agent.h index 6adecd2..f050dc0 100644 --- a/y/agent.h +++ b/y/agent.h @@ -3,11 +3,19 @@ //Interface for the various agents: players and solvers +#include "../lib/outcome.h" +#include "../lib/sgf.h" #include "../lib/types.h" #include "board.h" + +namespace Morat { +namespace Y { + class Agent { +protected: + typedef std::vector vecmove; public: Agent() { } virtual ~Agent() { } @@ -19,51 +27,57 @@ class Agent { virtual void set_memlimit(uint64_t lim) = 0; // in bytes virtual void clear_mem() = 0; - virtual vector get_pv() const = 0; - string move_stats() const { return move_stats(vector()); } - virtual string move_stats(const vector moves) const = 0; + virtual vecmove get_pv() const = 0; + std::string move_stats() const { return move_stats(vecmove()); } + virtual std::string move_stats(const vecmove moves) const = 0; virtual double gamelen() const = 0; virtual void timedout(){ timeout = true; } + virtual void gen_sgf(SGFPrinter & sgf, int limit) const = 0; + virtual void load_sgf(SGFParser & sgf) = 0; + protected: volatile bool timeout; Board rootboard; - static int solve1ply(const Board & board, unsigned int & nodes) { - int outcome = -3; - int turn = board.toplay(); + static Outcome solve1ply(const Board & board, unsigned int & nodes) { + Outcome outcome = Outcome::UNKNOWN; + Side turn = board.toplay(); for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ ++nodes; - int won = board.test_win(*move, turn); + Outcome won = board.test_outcome(*move, turn); - if(won == turn) + if(won == +turn) return won; - if(won == 0) - outcome = 0; + if(won == Outcome::DRAW) + outcome = Outcome::DRAW; } return outcome; } - static int solve2ply(const Board & board, unsigned int & nodes) { + static Outcome solve2ply(const Board & board, unsigned int & nodes) { int losses = 0; - int outcome = -3; - int turn = board.toplay(), opponent = 3 - turn; + Outcome outcome = Outcome::UNKNOWN; + Side turn = board.toplay(); + Side op = ~turn; for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ ++nodes; - int won = board.test_win(*move, turn); + Outcome won = board.test_outcome(*move, turn); - if(won == turn) + if(won == +turn) return won; - if(won == 0) - outcome = 0; + if(won == Outcome::DRAW) + outcome = Outcome::DRAW; - if(board.test_win(*move, opponent) > 0) + if(board.test_outcome(*move, op) == +op) losses++; } if(losses >= 2) - return opponent; + return (Outcome)op; return outcome; } - }; + +}; // namespace Y +}; // namespace Morat diff --git a/y/agentab.cpp b/y/agentab.cpp index 2c66bce..1965915 100644 --- a/y/agentab.cpp +++ b/y/agentab.cpp @@ -6,6 +6,10 @@ #include "agentab.h" + +namespace Morat { +namespace Y { + void AgentAB::search(double time, uint64_t maxiters, int verbose) { reset(); if(rootboard.won() >= 0) @@ -41,8 +45,8 @@ void AgentAB::search(double time, uint64_t maxiters, int verbose) { if(verbose){ logerr("Finished: " + to_str(nodes_seen) + " nodes in " + to_str(time_used*1000, 0) + " msec: " + to_str((uint64_t)((double)nodes_seen/time_used)) + " Nodes/s\n"); - vector pv = get_pv(); - string pvstr; + vecmove pv = get_pv(); + std::string pvstr; for(auto m : pv) pvstr += " " + m.to_s(); logerr("PV: " + pvstr + "\n"); @@ -56,11 +60,11 @@ void AgentAB::search(double time, uint64_t maxiters, int verbose) { int16_t AgentAB::negamax(const Board & board, int16_t alpha, int16_t beta, int depth) { nodes_seen++; - int won = board.won(); - if(won >= 0){ - if(won == 0) + Outcome won = board.won(); + if(won >= Outcome::DRAW){ + if(won == Outcome::DRAW) return SCORE_DRAW; - if(won == board.toplay()) + if(won == +board.toplay()) return SCORE_WIN; return SCORE_LOSS; } @@ -81,8 +85,8 @@ int16_t AgentAB::negamax(const Board & board, int16_t alpha, int16_t beta, int d if(TT && (node = tt_get(board)) && node->depth >= depth){ switch(node->flag){ case VALID: return node->score; - case LBOUND: alpha = max(alpha, node->score); break; - case UBOUND: beta = min(beta, node->score); break; + case LBOUND: alpha = std::max(alpha, node->score); break; + case UBOUND: beta = std::min(beta, node->score); break; default: assert(false && "Unknown flag!"); } if(alpha >= beta) @@ -125,11 +129,11 @@ int16_t AgentAB::negamax(const Board & board, int16_t alpha, int16_t beta, int d return score; } -string AgentAB::move_stats(vector moves) const { - string s = ""; +std::string AgentAB::move_stats(vecmove moves) const { + std::string s = ""; Board b = rootboard; - for(vector::iterator m = moves.begin(); m != moves.end(); ++m) + for(vecmove::iterator m = moves.begin(); m != moves.end(); ++m) b.move(*m); for(MoveIterator move(b); !move.done(); ++move){ @@ -162,8 +166,8 @@ Move AgentAB::return_move(const Board & board, int verbose) const { return best; } -vector AgentAB::get_pv() const { - vector pv; +std::vector AgentAB::get_pv() const { + vecmove pv; Board b = rootboard; int i = 20; @@ -197,3 +201,6 @@ AgentAB::Node * AgentAB::tt_get(uint64_t h) const { void AgentAB::tt_set(const Node & n) { *(tt(n.hash)) = n; } + +}; // namespace Y +}; // namespace Morat diff --git a/y/agentab.h b/y/agentab.h index 646043f..ee6d57a 100644 --- a/y/agentab.h +++ b/y/agentab.h @@ -7,6 +7,10 @@ #include "agent.h" + +namespace Morat { +namespace Y { + class AgentAB : public Agent { static const int16_t SCORE_WIN = 32767; static const int16_t SCORE_LOSS = -32767; @@ -30,7 +34,7 @@ class AgentAB : public Agent { Node(uint64_t h = ~0ull, int16_t s = 0, Move b = M_UNKNOWN, int8_t d = 0, int8_t f = 0) : //. int8_t o = -3 hash(h), score(s), bestmove(b), depth(d), flag(f), padding(0xDEAD) { } //, outcome(o) - string to_s() const { + std::string to_s() const { return "score " + to_str(score) + ", depth " + to_str((int)depth) + ", flag " + to_str((int)flag) + @@ -93,8 +97,16 @@ class AgentAB : public Agent { void search(double time, uint64_t maxiters, int verbose); Move return_move(int verbose) const { return return_move(rootboard, verbose); } double gamelen() const { return rootboard.movesremain(); } - vector get_pv() const; - string move_stats(vector moves) const; + vecmove get_pv() const; + std::string move_stats(vecmove moves) const; + + void gen_sgf(SGFPrinter & sgf, int limit) const { + log("gen_sgf not supported in the ab agent."); + } + + void load_sgf(SGFParser & sgf) { + log("load_sgf not supported in the ab agent."); + } private: int16_t negamax(const Board & board, int16_t alpha, int16_t beta, int depth); @@ -105,3 +117,6 @@ class AgentAB : public Agent { Node * tt_get(const Board & b) const ; void tt_set(const Node & n) ; }; + +}; // namespace Y +}; // namespace Morat diff --git a/y/agentmcts.cpp b/y/agentmcts.cpp index 6f4822c..0e38ca4 100644 --- a/y/agentmcts.cpp +++ b/y/agentmcts.cpp @@ -10,12 +10,45 @@ #include "agentmcts.h" #include "board.h" + +namespace Morat { +namespace Y { + const float AgentMCTS::min_rave = 0.1; +std::string AgentMCTS::Node::to_s() const { + return "AgentMCTS::Node" + ", move " + move.to_s() + + ", exp " + exp.to_s() + + ", rave " + rave.to_s() + + ", know " + to_str(know) + + ", outcome " + to_str((int)outcome.to_i()) + + ", depth " + to_str((int)proofdepth) + + ", best " + bestmove.to_s() + + ", children " + to_str(children.num()); +} + +bool AgentMCTS::Node::from_s(std::string s) { + auto dict = parse_dict(s, ", ", " "); + + if(dict.size() == 9){ + move = Move(dict["move"]); + exp = ExpPair(dict["exp"]); + rave = ExpPair(dict["rave"]); + know = from_str(dict["know"]); + outcome = Outcome(from_str(dict["outcome"])); + proofdepth = from_str(dict["depth"]); + bestmove = Move(dict["best"]); + // ignore children + return true; + } + return false; +} + void AgentMCTS::search(double time, uint64_t max_runs, int verbose){ - int toplay = rootboard.toplay(); + Side toplay = rootboard.toplay(); - if(rootboard.won() >= 0 || (time <= 0 && max_runs == 0)) + if(rootboard.won() >= Outcome::DRAW || (time <= 0 && max_runs == 0)) return; Time starttime; @@ -56,30 +89,23 @@ void AgentMCTS::search(double time, uint64_t max_runs, int verbose){ logerr("Times: " + to_str(times[0], 3) + ", " + to_str(times[1], 3) + ", " + to_str(times[2], 3) + ", " + to_str(times[3], 3) + "\n"); } - if(root.outcome != -3){ - logerr("Solved as a "); - if( root.outcome == 0) logerr("draw\n"); - else if(root.outcome == 3) logerr("draw by simultaneous win\n"); - else if(root.outcome == toplay) logerr("win\n"); - else if(root.outcome == 3-toplay) logerr("loss\n"); - else if(root.outcome == -toplay) logerr("win or draw\n"); - else if(root.outcome == toplay-3) logerr("loss or draw\n"); - } + if(root.outcome != Outcome::UNKNOWN) + logerr("Solved as a " + root.outcome.to_s_rel(toplay) + "\n"); - string pvstr; + std::string pvstr; for(auto m : get_pv()) pvstr += " " + m.to_s(); logerr("PV: " + pvstr + "\n"); if(verbose >= 3 && !root.children.empty()) - logerr("Move stats:\n" + move_stats(vector())); + logerr("Move stats:\n" + move_stats(vecmove())); } pool.reset(); runs = 0; - if(ponder && root.outcome < 0) + if(ponder && root.outcome < Outcome::DRAW) pool.resume(); } @@ -194,8 +220,8 @@ void AgentMCTS::move(const Move & m){ rootboard.move(m); root.exp.addwins(visitexpand+1); //+1 to compensate for the virtual loss - if(rootboard.won() < 0) - root.outcome = -3; + if(rootboard.won() < Outcome::DRAW) + root.outcome = Outcome::UNKNOWN; if(ponder) pool.resume(); @@ -208,16 +234,16 @@ double AgentMCTS::gamelen() const { return len.avg(); } -vector AgentMCTS::get_pv() const { - vector pv; +std::vector AgentMCTS::get_pv() const { + vecmove pv; const Node * n = & root; - char turn = rootboard.toplay(); + Side turn = rootboard.toplay(); while(n && !n->children.empty()){ Move m = return_move(n, turn); pv.push_back(m); n = find_child(n, m); - turn = 3 - turn; + turn = ~turn; } if(pv.size() == 0) @@ -226,8 +252,8 @@ vector AgentMCTS::get_pv() const { return pv; } -string AgentMCTS::move_stats(vector moves) const { - string s = ""; +std::string AgentMCTS::move_stats(vecmove moves) const { + std::string s = ""; const Node * node = & root; if(moves.size()){ @@ -248,8 +274,8 @@ string AgentMCTS::move_stats(vector moves) const { return s; } -Move AgentMCTS::return_move(const Node * node, int toplay, int verbose) const { - if(node->outcome >= 0) +Move AgentMCTS::return_move(const Node * node, Side toplay, int verbose) const { + if(node->outcome >= Outcome::DRAW) return node->bestmove; double val, maxval = -1000000000000.0; //1 trillion @@ -259,10 +285,10 @@ Move AgentMCTS::return_move(const Node * node, int toplay, int verbose) const { * end = node->children.end(); for( ; child != end; child++){ - if(child->outcome >= 0){ - if(child->outcome == toplay) val = 800000000000.0 - child->exp.num(); //shortest win - else if(child->outcome == 0) val = -400000000000.0 + child->exp.num(); //longest tie - else val = -800000000000.0 + child->exp.num(); //longest loss + if(child->outcome >= Outcome::DRAW){ + if(child->outcome == toplay) val = 800000000000.0 - child->exp.num(); //shortest win + else if(child->outcome == Outcome::DRAW) val = -400000000000.0 + child->exp.num(); //longest tie + else val = -800000000000.0 + child->exp.num(); //longest loss }else{ //not proven if(msrave == -1) //num simulations val = child->exp.num(); @@ -290,13 +316,13 @@ void AgentMCTS::garbage_collect(Board & board, Node * node){ Node * child = node->children.begin(), * end = node->children.end(); - int toplay = board.toplay(); + Side toplay = board.toplay(); for( ; child != end; child++){ if(child->children.num() == 0) continue; - if( (node->outcome >= 0 && child->exp.num() > gcsolved && (node->outcome != toplay || child->outcome == toplay || child->outcome == 0)) || //parent is solved, only keep the proof tree, plus heavy draws - (node->outcome < 0 && child->exp.num() > (child->outcome >= 0 ? gcsolved : gclimit)) ){ // only keep heavy nodes, with different cutoffs for solved and unsolved + if( (node->outcome >= Outcome::DRAW && child->exp.num() > gcsolved && (node->outcome != toplay || child->outcome == toplay || child->outcome == Outcome::DRAW)) || //parent is solved, only keep the proof tree, plus heavy draws + (node->outcome < Outcome::DRAW && child->exp.num() > (child->outcome >= Outcome::DRAW ? gcsolved : gclimit)) ){ // only keep heavy nodes, with different cutoffs for solved and unsolved board.set(child->move); garbage_collect(board, child); board.unset(child->move); @@ -307,36 +333,22 @@ void AgentMCTS::garbage_collect(Board & board, Node * node){ } AgentMCTS::Node * AgentMCTS::find_child(const Node * node, const Move & move) const { - for(Node * i = node->children.begin(); i != node->children.end(); i++) - if(i->move == move) - return i; - + for(auto & c : node->children) + if(c.move == move) + return &c; return NULL; } -void AgentMCTS::gen_hgf(Board & board, Node * node, unsigned int limit, unsigned int depth, FILE * fd){ - string s = string("\n") + string(depth, ' ') + "(;" + (board.toplay() == 2 ? "W" : "B") + "[" + node->move.to_s() + "]" + - "C[mcts, sims:" + to_str(node->exp.num()) + ", avg:" + to_str(node->exp.avg(), 4) + ", outcome:" + to_str((int)(node->outcome)) + ", best:" + node->bestmove.to_s() + "]"; - fprintf(fd, "%s", s.c_str()); - - Node * child = node->children.begin(), - * end = node->children.end(); - - int toplay = board.toplay(); - - bool children = false; - for( ; child != end; child++){ - if(child->exp.num() >= limit && (toplay != node->outcome || child->outcome == node->outcome) ){ - board.set(child->move); - gen_hgf(board, child, limit, depth+1, fd); - board.unset(child->move); - children = true; +void AgentMCTS::gen_sgf(SGFPrinter & sgf, unsigned int limit, const Node & node, Side side) const { + for(auto & child : node.children){ + if(child.exp.num() >= limit && (side != node.outcome || child.outcome == node.outcome)){ + sgf.child_start(); + sgf.move(side, child.move); + sgf.comment(child.to_s()); + gen_sgf(sgf, limit, child, ~side); + sgf.child_end(); } } - - if(children) - fprintf(fd, "\n%s", string(depth, ' ').c_str()); - fprintf(fd, ")"); } void AgentMCTS::create_children_simple(const Board & board, Node * node){ @@ -361,64 +373,25 @@ void AgentMCTS::create_children_simple(const Board & board, Node * node){ PLUS(nodes, node->children.num()); } -//reads the format from gen_hgf. -void AgentMCTS::load_hgf(Board board, Node * node, FILE * fd){ - char c, buf[101]; - - eat_whitespace(fd); - - assert(fscanf(fd, "(;%c[%100[^]]]", &c, buf) > 0); +void AgentMCTS::load_sgf(SGFParser & sgf, const Board & board, Node & node) { + assert(sgf.has_children()); + create_children_simple(board, & node); - assert(board.toplay() == (c == 'W' ? 1 : 2)); - node->move = Move(buf); - board.move(node->move); - - assert(fscanf(fd, "C[%100[^]]]", buf) > 0); - - vecstr entry, parts = explode(string(buf), ", "); - assert(parts[0] == "mcts"); - - entry = explode(parts[1], ":"); - assert(entry[0] == "sims"); - uword sims = from_str(entry[1]); - - entry = explode(parts[2], ":"); - assert(entry[0] == "avg"); - double avg = from_str(entry[1]); - - uword wins = sims*avg; - node->exp.addwins(wins); - node->exp.addlosses(sims - wins); - - entry = explode(parts[3], ":"); - assert(entry[0] == "outcome"); - node->outcome = from_str(entry[1]); - - entry = explode(parts[4], ":"); - assert(entry[0] == "best"); - node->bestmove = Move(entry[1]); - - - eat_whitespace(fd); - - if(fpeek(fd) != ')'){ - create_children_simple(board, node); - - while(fpeek(fd) != ')'){ - Node child; - load_hgf(board, & child, fd); - - Node * i = find_child(node, child.move); - *i = child; //copy the child experience to the tree - i->swap_tree(child); //move the child subtree to the tree - - assert(child.children.empty()); - - eat_whitespace(fd); + while(sgf.next_child()){ + Move m = sgf.move(); + Node & child = *find_child(&node, m); + child.from_s(sgf.comment()); + if(sgf.done_child()){ + continue; + }else{ + // has children! + Board b = board; + b.move(m); + load_sgf(sgf, b, child); + assert(sgf.done_child()); } } - - eat_char(fd, ')'); - - return; } + +}; // namespace Y +}; // namespace Morat diff --git a/y/agentmcts.h b/y/agentmcts.h index 2da03fc..885259c 100644 --- a/y/agentmcts.h +++ b/y/agentmcts.h @@ -11,6 +11,12 @@ #include "../lib/depthstats.h" #include "../lib/exppair.h" #include "../lib/log.h" +#include "../lib/move.h" +#include "../lib/movelist.h" +#include "../lib/policy_bridge.h" +#include "../lib/policy_instantwin.h" +#include "../lib/policy_lastgoodreply.h" +#include "../lib/policy_random.h" #include "../lib/thread.h" #include "../lib/time.h" #include "../lib/types.h" @@ -19,14 +25,11 @@ #include "agent.h" #include "board.h" #include "lbdist.h" -#include "move.h" -#include "movelist.h" -#include "policy_bridge.h" -#include "policy_instantwin.h" -#include "policy_lastgoodreply.h" -#include "policy_random.h" +namespace Morat { +namespace Y { + class AgentMCTS : public Agent{ public: @@ -35,7 +38,7 @@ class AgentMCTS : public Agent{ ExpPair rave; ExpPair exp; int16_t know; - int8_t outcome; + Outcome outcome; uint8_t proofdepth; Move move; Move bestmove; //if outcome is set, then bestmove is the way to get there @@ -44,8 +47,8 @@ class AgentMCTS : public Agent{ //seems to need padding to multiples of 8 bytes or it segfaults? //don't forget to update the copy constructor/operator - Node() : know(0), outcome(-3), proofdepth(0) { } - Node(const Move & m, char o = -3) : know(0), outcome( o), proofdepth(0), move(m) { } + Node() : know(0), outcome(Outcome::UNKNOWN), proofdepth(0), move(M_NONE) { } + Node(const Move & m, Outcome o = Outcome::UNKNOWN) : know(0), outcome(o), proofdepth(0), move(m) { } Node(const Node & n) { *this = n; } Node & operator = (const Node & n){ if(this != & n){ //don't copy to self @@ -68,18 +71,8 @@ class AgentMCTS : public Agent{ children.swap(n.children); } - void print() const { - printf("%s\n", to_s().c_str()); - } - string to_s() const { - return "Node: move " + move.to_s() + - ", exp " + to_str(exp.avg(), 2) + "/" + to_str(exp.num()) + - ", rave " + to_str(rave.avg(), 2) + "/" + to_str(rave.num()) + - ", know " + to_str(know) + - ", outcome " + to_str((int)outcome) + "/" + to_str((int)proofdepth) + - ", best " + bestmove.to_s() + - ", children " + to_str(children.num()); - } + std::string to_s() const ; + bool from_s(std::string s); unsigned int size() const { unsigned int num = children.num(); @@ -142,16 +135,16 @@ class AgentMCTS : public Agent{ class AgentThread : public AgentThreadBase { mutable XORShift_float unitrand; - LastGoodReply last_good_reply; - RandomPolicy random_policy; - ProtectBridge protect_bridge; - InstantWin instant_wins; + LastGoodReply last_good_reply; + RandomPolicy random_policy; + ProtectBridge protect_bridge; + InstantWin instant_wins; bool use_rave; //whether to use rave for this simulation bool use_explore; //whether to use exploration for this simulation LBDists dists; //holds the distances to the various non-ring wins as a heuristic for the minimum moves needed to win - MoveList movelist; + MoveList movelist; int stage; //which of the four MCTS stages is it on public: @@ -179,11 +172,11 @@ class AgentMCTS : public Agent{ void walk_tree(Board & board, Node * node, int depth); bool create_children(const Board & board, Node * node); void add_knowledge(const Board & board, Node * node, Node * child); - Node * choose_move(const Node * node, int toplay, int remain) const; - void update_rave(const Node * node, int toplay); + Node * choose_move(const Node * node, Side toplay, int remain) const; + void update_rave(const Node * node, Side toplay); bool test_bridge_probe(const Board & board, const Move & move, const Move & test) const; - int rollout(Board & board, Move move, int depth); + Outcome rollout(Board & board, Move move, int depth); Move rollout_choose_move(Board & board, const Move & prev); Move rollout_pattern(const Board & board, const Move & move); }; @@ -261,12 +254,12 @@ class AgentMCTS : public Agent{ Move return_move(int verbose) const { return return_move(& root, rootboard.toplay(), verbose); } double gamelen() const; - vector get_pv() const; - string move_stats(const vector moves) const; + vecmove get_pv() const; + std::string move_stats(const vecmove moves) const; bool done() { //solved or finished runs - return (rootboard.won() >= 0 || root.outcome >= 0 || (maxruns > 0 && runs >= maxruns)); + return (rootboard.won() >= Outcome::DRAW || root.outcome >= Outcome::DRAW || (maxruns > 0 && runs >= maxruns)); } bool need_gc() { @@ -292,16 +285,28 @@ class AgentMCTS : public Agent{ gclimit = (int)(gclimit*0.9); //slowly decay to a minimum of 5 } + void gen_sgf(SGFPrinter & sgf, int limit) const { + if(limit < 0) + limit = root.exp.num()/1000; + gen_sgf(sgf, limit, root, rootboard.toplay()); + } + + void load_sgf(SGFParser & sgf) { + load_sgf(sgf, rootboard, root); + } protected: void garbage_collect(Board & board, Node * node); //destroys the board, so pass in a copy - bool do_backup(Node * node, Node * backup, int toplay); - Move return_move(const Node * node, int toplay, int verbose = 0) const; + bool do_backup(Node * node, Node * backup, Side toplay); + Move return_move(const Node * node, Side toplay, int verbose = 0) const; Node * find_child(const Node * node, const Move & move) const ; void create_children_simple(const Board & board, Node * node); - void gen_hgf(Board & board, Node * node, unsigned int limit, unsigned int depth, FILE * fd); - void load_hgf(Board board, Node * node, FILE * fd); + void gen_sgf(SGFPrinter & sgf, unsigned int limit, const Node & node, Side side) const ; + void load_sgf(SGFParser & sgf, const Board & board, Node & node); }; + +}; // namespace Y +}; // namespace Morat diff --git a/y/agentmcts_test.cpp b/y/agentmcts_test.cpp new file mode 100644 index 0000000..ddff131 --- /dev/null +++ b/y/agentmcts_test.cpp @@ -0,0 +1,16 @@ + +#include "../lib/catch.hpp" + +#include "agentmcts.h" + + +using namespace Morat; +using namespace Y; + +TEST_CASE("Y::AgentMCTS::Node::to_s/from_s", "[y][agentmcts]") { + AgentMCTS::Node n(Move("a1")); + auto s = n.to_s(); + AgentMCTS::Node k; + REQUIRE(k.from_s(s)); + REQUIRE(n.to_s() == k.to_s()); +} diff --git a/y/agentmctsthread.cpp b/y/agentmctsthread.cpp index ca972bc..542206e 100644 --- a/y/agentmctsthread.cpp +++ b/y/agentmctsthread.cpp @@ -6,6 +6,10 @@ #include "agentmcts.h" + +namespace Morat { +namespace Y { + void AgentMCTS::AgentThread::iterate(){ INCR(agent->runs); if(agent->profile){ @@ -19,7 +23,7 @@ void AgentMCTS::AgentThread::iterate(){ use_rave = (unitrand() < agent->userave); use_explore = (unitrand() < agent->useexplore); walk_tree(copy, & agent->root, 0); - agent->root.exp.addv(movelist.getexp(3-agent->rootboard.toplay())); + agent->root.exp.addv(movelist.getexp(~agent->rootboard.toplay())); if(agent->profile){ times[0] += timestamps[1] - timestamps[0]; @@ -30,16 +34,16 @@ void AgentMCTS::AgentThread::iterate(){ } void AgentMCTS::AgentThread::walk_tree(Board & board, Node * node, int depth){ - int toplay = board.toplay(); + Side toplay = board.toplay(); - if(!node->children.empty() && node->outcome < 0){ + if(!node->children.empty() && node->outcome < Outcome::DRAW){ //choose a child and recurse Node * child; do{ int remain = board.movesremain(); child = choose_move(node, toplay, remain); - if(child->outcome < 0){ + if(child->outcome < Outcome::DRAW){ movelist.addtree(child->move, toplay); if(!board.move(child->move)){ @@ -71,10 +75,10 @@ void AgentMCTS::AgentThread::walk_tree(Board & board, Node * node, int depth){ timestamps[1] = Time(); } - int won = (agent->minimax ? node->outcome : board.won()); + Outcome won = (agent->minimax ? node->outcome : board.won()); //if it's not already decided - if(won < 0){ + if(won < Outcome::DRAW){ //create children if valid if(node->exp.num() >= agent->visitexpand+1 && create_children(board, node)){ walk_tree(board, node, depth); @@ -125,6 +129,8 @@ bool AgentMCTS::AgentThread::create_children(const Board & board, Node * node){ CompactTree::Children temp; temp.alloc(board.movesremain(), agent->ctmem); + Side toplay = board.toplay(); + Side opponent = ~toplay; int losses = 0; Node * child = temp.begin(), @@ -136,14 +142,14 @@ bool AgentMCTS::AgentThread::create_children(const Board & board, Node * node){ *child = Node(*move); if(agent->minimax){ - child->outcome = board.test_win(*move); + child->outcome = board.test_outcome(*move); - if(agent->minimax >= 2 && board.test_win(*move, 3 - board.toplay()) > 0){ + if(agent->minimax >= 2 && board.test_outcome(*move, opponent) == +opponent){ losses++; loss = child; } - if(child->outcome == board.toplay()){ //proven win from here, don't need children + if(child->outcome == +toplay){ //proven win from here, don't need children node->outcome = child->outcome; node->proofdepth = 1; node->bestmove = *move; @@ -171,7 +177,7 @@ bool AgentMCTS::AgentThread::create_children(const Board & board, Node * node){ macro.exp.addwins(agent->visitexpand); *(temp.begin()) = macro; }else if(losses >= 2){ //proven loss, but at least try to block one of them - node->outcome = 3 - board.toplay(); + node->outcome = +opponent; node->proofdepth = 2; node->bestmove = loss->move; node->children.unlock(); @@ -180,7 +186,7 @@ bool AgentMCTS::AgentThread::create_children(const Board & board, Node * node){ } if(agent->dynwiden > 0) //sort in decreasing order by knowledge - sort(temp.begin(), temp.end(), sort_node_know); + std::sort(temp.begin(), temp.end(), sort_node_know); PLUS(agent->nodes, temp.num()); node->children.swap(temp); @@ -189,7 +195,7 @@ bool AgentMCTS::AgentThread::create_children(const Board & board, Node * node){ return true; } -AgentMCTS::Node * AgentMCTS::AgentThread::choose_move(const Node * node, int toplay, int remain) const { +AgentMCTS::Node * AgentMCTS::AgentThread::choose_move(const Node * node, Side toplay, int remain) const { float val, maxval = -1000000000; float logvisits = log(node->exp.num()); int dynwidenlim = (agent->dynwiden > 0 ? (int)(logvisits/agent->logdynwiden)+2 : Board::max_vecsize); @@ -204,11 +210,11 @@ AgentMCTS::Node * AgentMCTS::AgentThread::choose_move(const Node * node, int top * end = node->children.end(); for(; child != end && dynwidenlim >= 0; child++){ - if(child->outcome >= 0){ + if(child->outcome >= Outcome::DRAW){ if(child->outcome == toplay) //return a win immediately return child; - val = (child->outcome == 0 ? -1 : -2); //-1 for tie so any unknown is better, -2 for loss so it's even worse + val = (child->outcome == Outcome::DRAW ? -1 : -2); //-1 for tie so any unknown is better, -2 for loss so it's even worse }else{ val = child->value(raveval, agent->knowledge, agent->fpurgency); if(explore > 0) @@ -237,80 +243,80 @@ backup in this order: 0 lose return true if fully solved, false if it's unknown or partially unknown */ -bool AgentMCTS::do_backup(Node * node, Node * backup, int toplay){ - int nodeoutcome = node->outcome; - if(nodeoutcome >= 0) //already proven, probably by a different thread +bool AgentMCTS::do_backup(Node * node, Node * backup, Side toplay){ + Outcome node_outcome = node->outcome; + if(node_outcome >= Outcome::DRAW) //already proven, probably by a different thread return true; - if(backup->outcome == -3) //nothing proven by this child, so no chance + if(backup->outcome == Outcome::UNKNOWN) //nothing proven by this child, so no chance return false; uint8_t proofdepth = backup->proofdepth; if(backup->outcome != toplay){ - uint64_t sims = 0, bestsims = 0, outcome = 0, bestoutcome = 0; + uint64_t sims = 0, bestsims = 0, outcome = 0, best_outcome = 0; backup = NULL; Node * child = node->children.begin(), * end = node->children.end(); for( ; child != end; child++){ - int childoutcome = child->outcome; //save a copy to avoid race conditions + Outcome child_outcome = child->outcome; //save a copy to avoid race conditions if(proofdepth < child->proofdepth+1) proofdepth = child->proofdepth+1; //these should be sorted in likelyness of matching, most likely first - if(childoutcome == -3){ // win/draw/loss + if(child_outcome == Outcome::UNKNOWN){ // win/draw/loss outcome = 3; - }else if(childoutcome == toplay){ //win + }else if(child_outcome == toplay){ //win backup = child; outcome = 6; proofdepth = child->proofdepth+1; break; - }else if(childoutcome == 3-toplay){ //loss + }else if(child_outcome == ~toplay){ //loss outcome = 0; - }else if(childoutcome == 0){ //draw - if(nodeoutcome == toplay-3) //draw/loss + }else if(child_outcome == Outcome::DRAW){ //draw + if(node_outcome == -toplay) //draw/loss, ie I can't win outcome = 4; else outcome = 2; - }else if(childoutcome == -toplay){ //win/draw + }else if(child_outcome == -~toplay){ //win/draw, ie opponent can't win outcome = 5; - }else if(childoutcome == toplay-3){ //draw/loss + }else if(child_outcome == -toplay){ //draw/loss, ie I can't win outcome = 1; }else{ - logerr("childoutcome == " + to_str(childoutcome) + "\n"); + logerr("child_outcome == " + child_outcome.to_s() + "\n"); assert(false && "How'd I get here? All outcomes should be tested above"); } sims = child->exp.num(); - if(bestoutcome < outcome){ //better outcome is always preferable - bestoutcome = outcome; + if(best_outcome < outcome){ //better outcome is always preferable + best_outcome = outcome; bestsims = sims; backup = child; - }else if(bestoutcome == outcome && ((outcome == 0 && bestsims < sims) || bestsims > sims)){ + }else if(best_outcome == outcome && ((outcome == 0 && bestsims < sims) || bestsims > sims)){ //find long losses or easy wins/draws bestsims = sims; backup = child; } } - if(bestoutcome == 3) //no win, but found an unknown + if(best_outcome == 3) //no win, but found an unknown return false; } - if(CAS(node->outcome, nodeoutcome, backup->outcome)){ + if(node->outcome.cas(node_outcome, backup->outcome)){ node->bestmove = backup->move; node->proofdepth = proofdepth; }else //if it was in a race, try again, might promote a partial solve to full solve return do_backup(node, backup, toplay); - return (node->outcome >= 0); + return (node->outcome >= Outcome::DRAW); } //update the rave score of all children that were played -void AgentMCTS::AgentThread::update_rave(const Node * node, int toplay){ +void AgentMCTS::AgentThread::update_rave(const Node * node, Side toplay){ Node * child = node->children.begin(), * childend = node->children.end(); @@ -321,7 +327,7 @@ void AgentMCTS::AgentThread::update_rave(const Node * node, int toplay){ void AgentMCTS::AgentThread::add_knowledge(const Board & board, Node * node, Node * child){ if(agent->localreply){ //boost for moves near the previous move - int dist = node->move.dist(child->move); + int dist = board.dist(node->move, child->move); if(dist < 4) child->know += agent->localreply * (4 - dist); } @@ -343,24 +349,24 @@ void AgentMCTS::AgentThread::add_knowledge(const Board & board, Node * node, Nod child->know += agent->bridge; if(agent->dists) - child->know += abs(agent->dists) * max(0, board.get_size() - dists.get(child->move, board.toplay())); + child->know += abs(agent->dists) * std::max(0, board.get_size() - dists.get(child->move, board.toplay())); } //test whether this move is a forced reply to the opponent probing your virtual connections bool AgentMCTS::AgentThread::test_bridge_probe(const Board & board, const Move & move, const Move & test) const { //TODO: switch to the same method as policy_bridge.h, maybe even share code - if(move.dist(test) != 1) + if(board.dist(move, test) != 1) return false; bool equals = false; int state = 0; - int piece = 3 - board.get(move); + Side piece = ~board.get(move); for(int i = 0; i < 8; i++){ Move cur = move + neighbours[i % 6]; bool on = board.onboard(cur); - int v = 0; + Side v = Side::NONE; if(on) v = board.get(cur); @@ -371,7 +377,7 @@ bool AgentMCTS::AgentThread::test_bridge_probe(const Board & board, const Move & //else state = 0; }else if(state == 1){ if(on){ - if(v == 0){ + if(v == Side::NONE){ state = 2; equals = (test == cur); }else if(v != piece) @@ -396,16 +402,16 @@ bool AgentMCTS::AgentThread::test_bridge_probe(const Board & board, const Move & //play a random game starting from a board state, and return the results of who won -int AgentMCTS::AgentThread::rollout(Board & board, Move move, int depth){ - int won; +Outcome AgentMCTS::AgentThread::rollout(Board & board, Move move, int depth){ + Outcome won; if(agent->instantwin) instant_wins.rollout_start(board, agent->instantwin); random_policy.rollout_start(board); - while((won = board.won()) < 0){ - int turn = board.toplay(); + while((won = board.won()) < Outcome::DRAW){ + Side turn = board.toplay(); move = rollout_choose_move(board, move); @@ -449,3 +455,6 @@ Move AgentMCTS::AgentThread::rollout_choose_move(Board & board, const Move & pre return random_policy.choose_move(board, prev); } + +}; // namespace Y +}; // namespace Morat diff --git a/y/agentpns.cpp b/y/agentpns.cpp index ec270ff..6887c70 100644 --- a/y/agentpns.cpp +++ b/y/agentpns.cpp @@ -5,6 +5,40 @@ #include "agentpns.h" + +namespace Morat { +namespace Y { + +std::string AgentPNS::Node::to_s() const { + return "AgentPNS::Node" + ", move " + move.to_s() + + ", phi " + to_str(phi) + + ", delta " + to_str(delta) + + ", work " + to_str(work) + + ", children " + to_str(children.num()); +} + +bool AgentPNS::Node::from_s(std::string s) { + auto dict = parse_dict(s, ", ", " "); + + if(dict.size() == 6){ + move = Move(dict["move"]); + phi = from_str(dict["phi"]); + delta = from_str(dict["delta"]); + work = from_str(dict["work"]); + // ignore children + return true; + } + return false; +} + +void AgentPNS::test() { + Node n(Move("a1")); + auto s = n.to_s(); + Node k; + assert(k.from_s(s)); +} + void AgentPNS::search(double time, uint64_t maxiters, int verbose){ max_nodes_seen = maxiters; @@ -32,27 +66,20 @@ void AgentPNS::search(double time, uint64_t maxiters, int verbose){ logerr("Tree depth: " + treelen.to_s() + "\n"); } - int toplay = rootboard.toplay(); + Side toplay = rootboard.toplay(); logerr("Root: " + root.to_s() + "\n"); - int outcome = root.to_outcome(3-toplay); - if(outcome != -3){ - logerr("Solved as a "); - if( outcome == 0) logerr("draw\n"); - else if(outcome == 3) logerr("draw by simultaneous win\n"); - else if(outcome == toplay) logerr("win\n"); - else if(outcome == 3-toplay) logerr("loss\n"); - else if(outcome == -toplay) logerr("win or draw\n"); - else if(outcome == toplay-3) logerr("loss or draw\n"); - } + Outcome outcome = root.to_outcome(~toplay); + if(outcome != Outcome::UNKNOWN) + logerr("Solved as a " + outcome.to_s_rel(toplay) + "\n"); - string pvstr; + std::string pvstr; for(auto m : get_pv()) pvstr += " " + m.to_s(); logerr("PV: " + pvstr + "\n"); if(verbose >= 3 && !root.children.empty()) - logerr("Move stats:\n" + move_stats(vector())); + logerr("Move stats:\n" + move_stats(vecmove())); } } @@ -83,8 +110,8 @@ bool AgentPNS::AgentThread::pns(const Board & board, Node * node, int depth, uin unsigned int i = 0; for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - unsigned int pd = 1; - int outcome; + unsigned int pd; + Outcome outcome; if(agent->ab){ Board next = board; @@ -94,10 +121,10 @@ bool AgentPNS::AgentThread::pns(const Board & board, Node * node, int depth, uin outcome = (agent->ab == 1 ? solve1ply(next, pd) : solve2ply(next, pd)); }else{ pd = 1; - outcome = board.test_win(*move); + outcome = board.test_outcome(*move); } - if(agent->lbdist && outcome < 0) + if(agent->lbdist && outcome != Outcome::UNKNOWN) pd = dists.get(*move); temp[i] = Node(*move).outcome(outcome, board.toplay(), agent->ties, pd); @@ -132,8 +159,8 @@ bool AgentPNS::AgentThread::pns(const Board & board, Node * node, int depth, uin } } - tpc = min(INF32/2, (td + child->phi - node->delta)); - tdc = min(tp, (uint32_t)(child2->delta*(1.0 + agent->epsilon) + 1)); + tpc = std::min(INF32/2, (td + child->phi - node->delta)); + tdc = std::min(tp, (uint32_t)(child2->delta*(1.0 + agent->epsilon) + 1)); }else{ tpc = tdc = 0; for(auto & i : node->children) @@ -198,16 +225,16 @@ double AgentPNS::gamelen() const { return rootboard.movesremain(); } -vector AgentPNS::get_pv() const { - vector pv; +std::vector AgentPNS::get_pv() const { + vecmove pv; const Node * n = & root; - char turn = rootboard.toplay(); + Side turn = rootboard.toplay(); while(n && !n->children.empty()){ Move m = return_move(n, turn); pv.push_back(m); n = find_child(n, m); - turn = 3 - turn; + turn = ~turn; } if(pv.size() == 0) @@ -216,8 +243,8 @@ vector AgentPNS::get_pv() const { return pv; } -string AgentPNS::move_stats(vector moves) const { - string s = ""; +std::string AgentPNS::move_stats(vecmove moves) const { + std::string s = ""; const Node * node = & root; if(moves.size()){ @@ -238,7 +265,7 @@ string AgentPNS::move_stats(vector moves) const { return s; } -Move AgentPNS::return_move(const Node * node, int toplay, int verbose) const { +Move AgentPNS::return_move(const Node * node, Side toplay, int verbose) const { double val, maxval = -1000000000000.0; //1 trillion Node * ret = NULL, @@ -246,11 +273,11 @@ Move AgentPNS::return_move(const Node * node, int toplay, int verbose) const { * end = node->children.end(); for( ; child != end; child++){ - int outcome = child->to_outcome(toplay); - if(outcome >= 0){ - if(outcome == toplay) val = 800000000000.0 - (double)child->work; //shortest win - else if(outcome == 0) val = -400000000000.0 + (double)child->work; //longest tie - else val = -800000000000.0 + (double)child->work; //longest loss + Outcome outcome = child->to_outcome(toplay); + if(outcome >= Outcome::DRAW){ + if( outcome == +toplay) val = 800000000000.0 - (double)child->work; //shortest win + else if(outcome == Outcome::DRAW) val = -400000000000.0 + (double)child->work; //longest tie + else val = -800000000000.0 + (double)child->work; //longest loss }else{ //not proven val = child->work; } @@ -290,3 +317,51 @@ void AgentPNS::garbage_collect(Node * node){ } } } + +void AgentPNS::create_children_simple(const Board & board, Node * node){ + assert(node->children.empty()); + node->children.alloc(board.movesremain(), ctmem); + unsigned int i = 0; + for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ + Outcome outcome = board.test_outcome(*move); + node->children[i] = Node(*move).outcome(outcome, board.toplay(), ties, 1); + i++; + } + PLUS(nodes, i); + node->children.shrink(i); //if symmetry, there may be extra moves to ignore +} + +void AgentPNS::gen_sgf(SGFPrinter & sgf, unsigned int limit, const Node & node, Side side) const { + for(auto & child : node.children){ + if(child.work >= limit && (side != node.to_outcome(~side) || child.to_outcome(side) == node.to_outcome(~side))){ + sgf.child_start(); + sgf.move(side, child.move); + sgf.comment(child.to_s()); + gen_sgf(sgf, limit, child, ~side); + sgf.child_end(); + } + } +} + +void AgentPNS::load_sgf(SGFParser & sgf, const Board & board, Node & node) { + assert(sgf.has_children()); + create_children_simple(board, &node); + + while(sgf.next_child()){ + Move m = sgf.move(); + Node & child = *find_child(&node, m); + child.from_s(sgf.comment()); + if(sgf.done_child()){ + continue; + }else{ + // has children! + Board b = board; + b.move(m); + load_sgf(sgf, b, child); + assert(sgf.done_child()); + } + } +} + +}; // namespace Y +}; // namespace Morat diff --git a/y/agentpns.h b/y/agentpns.h index ad33042..1ad4160 100644 --- a/y/agentpns.h +++ b/y/agentpns.h @@ -3,15 +3,21 @@ //A multi-threaded, tree based, proof number search solver. +#include + #include "../lib/agentpool.h" #include "../lib/compacttree.h" #include "../lib/depthstats.h" #include "../lib/log.h" +#include "../lib/string.h" #include "agent.h" #include "lbdist.h" +namespace Morat { +namespace Y { + class AgentPNS : public Agent { static const uint32_t LOSS = (1<<30)-1; static const uint32_t DRAW = (1<<30)-2; @@ -51,33 +57,33 @@ class AgentPNS : public Agent { assert(children.empty()); } - Node & abval(int outcome, int toplay, int assign, int value = 1){ - if(assign && (outcome == 1 || outcome == -1)) - outcome = (toplay == assign ? 2 : -2); + Node & abval(int ab_outcome, Side toplay, Side assign, int value = 1){ + if(assign != Side::NONE && (ab_outcome == 1 || ab_outcome == -1)) + ab_outcome = (toplay == assign ? 2 : -2); - if( outcome == 0) { phi = value; delta = value; } - else if(outcome == 2) { phi = LOSS; delta = 0; } - else if(outcome == -2) { phi = 0; delta = LOSS; } - else /*(outcome 1||-1)*/ { phi = 0; delta = DRAW; } + if( ab_outcome == 0) { phi = value; delta = value; } + else if(ab_outcome == 2) { phi = LOSS; delta = 0; } + else if(ab_outcome == -2) { phi = 0; delta = LOSS; } + else /*(ab_outcome 1||-1)*/ { phi = 0; delta = DRAW; } return *this; } - Node & outcome(int outcome, int toplay, int assign, int value = 1){ - if(assign && outcome == 0) - outcome = assign; + Node & outcome(Outcome outcome, Side toplay, Side assign, int value = 1){ + if(assign != Side::NONE && outcome == Outcome::DRAW) + outcome = +assign; - if( outcome == -3) { phi = value; delta = value; } - else if(outcome == toplay) { phi = LOSS; delta = 0; } - else if(outcome == 3-toplay) { phi = 0; delta = LOSS; } - else /*(outcome == 0)*/ { phi = 0; delta = DRAW; } + if( outcome == Outcome::UNKNOWN) { phi = value; delta = value; } + else if(outcome == +toplay) { phi = LOSS; delta = 0; } + else if(outcome == +~toplay) { phi = 0; delta = LOSS; } + else /*(outcome == Outcome::DRAW)*/ { phi = 0; delta = DRAW; } return *this; } - int to_outcome(int toplay) const { - if(phi == LOSS) return toplay; - if(delta == LOSS) return 3 - toplay; - if(delta == DRAW) return 0; - return -3; + Outcome to_outcome(Side toplay) const { + if(phi == LOSS) return +toplay; + if(delta == LOSS) return +~toplay; + if(delta == DRAW) return Outcome::DRAW; + return Outcome::UNKNOWN; } bool terminal(){ return (phi == 0 || delta == 0); } @@ -98,15 +104,8 @@ class AgentPNS : public Agent { return num; } - string to_s() const { - return "Node: move " + move.to_s() + - ", phi " + to_str(phi) + - ", delta " + to_str(delta) + - ", work " + to_str(work) + -// ", outcome " + to_str((int)outcome) + "/" + to_str((int)proofdepth) + -// ", best " + bestmove.to_s() + - ", children " + to_str(children.num()); - } + std::string to_s() const ; + bool from_s(std::string s); void swap_tree(Node & n){ children.swap(n.children); @@ -162,7 +161,7 @@ class AgentPNS : public Agent { int ab; // how deep of an alpha-beta search to run at each leaf node bool df; // go depth first? float epsilon; //if depth first, how wide should the threshold be? - int ties; //which player to assign ties to: 0 handle ties, 1 assign p1, 2 assign p2 + Side ties; //which player to assign ties to: 0 handle ties, 1 assign p1, 2 assign p2 bool lbdist; int numthreads; @@ -172,7 +171,7 @@ class AgentPNS : public Agent { ab = 2; df = true; epsilon = 0.25; - ties = 0; + ties = Side::NONE; lbdist = false; numthreads = 1; pool.set_num_threads(numthreads); @@ -228,7 +227,7 @@ class AgentPNS : public Agent { root.swap_tree(child); if(nodesbefore > 0) - logerr(string("PNS Nodes before: ") + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); + logerr(std::string("PNS Nodes before: ") + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); assert(nodes == root.size()); @@ -280,12 +279,36 @@ class AgentPNS : public Agent { void search(double time, uint64_t maxiters, int verbose); Move return_move(int verbose) const { return return_move(& root, rootboard.toplay(), verbose); } double gamelen() const; - vector get_pv() const; - string move_stats(const vector moves) const; + vecmove get_pv() const; + std::string move_stats(const vecmove moves) const; + + void gen_sgf(SGFPrinter & sgf, int limit) const { + if(limit < 0){ + limit = 0; + //TODO: Set the root.work properly + for(auto & child : root.children) + limit += child.work; + limit /= 1000; + } + gen_sgf(sgf, limit, root, rootboard.toplay()); + } + + void load_sgf(SGFParser & sgf) { + load_sgf(sgf, rootboard, root); + } + + static void test(); private: //remove all the nodes with little work to free up some memory void garbage_collect(Node * node); - Move return_move(const Node * node, int toplay, int verbose = 0) const; + Move return_move(const Node * node, Side toplay, int verbose = 0) const; Node * find_child(const Node * node, const Move & move) const ; + void create_children_simple(const Board & board, Node * node); + + void gen_sgf(SGFPrinter & sgf, unsigned int limit, const Node & node, Side side) const; + void load_sgf(SGFParser & sgf, const Board & board, Node & node); }; + +}; // namespace Y +}; // namespace Morat diff --git a/y/agentpns_test.cpp b/y/agentpns_test.cpp new file mode 100644 index 0000000..4bc2f25 --- /dev/null +++ b/y/agentpns_test.cpp @@ -0,0 +1,16 @@ + +#include "../lib/catch.hpp" + +#include "agentpns.h" + + +using namespace Morat; +using namespace Y; + +TEST_CASE("Y::AgentPNS::Node::to_s/from_s", "[y][agentpns]") { + AgentPNS::Node n(Move("a1")); + auto s = n.to_s(); + AgentPNS::Node k; + REQUIRE(k.from_s(s)); + REQUIRE(n.to_s() == k.to_s()); +} diff --git a/y/board.cpp b/y/board.cpp new file mode 100644 index 0000000..f7e4f8f --- /dev/null +++ b/y/board.cpp @@ -0,0 +1,67 @@ + +#include "board.h" + +namespace Morat { +namespace Y { + +std::string Board::Cell::to_s(int i) const { + return "Cell " + to_str(i) +": " + "piece: " + to_str(piece.to_i())+ + ", size: " + to_str((int)size) + + ", parent: " + to_str((int)parent) + + ", edge: " + to_str((int)edge) + "/" + to_str(numedges()) + + ", perm: " + to_str((int)perm) + + ", pattern: " + to_str((int)pattern); +} + +std::string Board::to_s(bool color) const { + using std::string; + string white = "O", + black = "@", + empty = ".", + coord = "", + reset = ""; + if(color){ + string esc = "\033"; + reset = esc + "[0m"; + coord = esc + "[1;37m"; + empty = reset + "."; + white = esc + "[1;33m" + "@"; //yellow + black = esc + "[1;34m" + "@"; //blue + } + + string s; + for(int i = 0; i < size; i++) + s += " " + coord + to_str(i+1); + s += "\n"; + + for(int y = 0; y < size; y++){ + s += string(y, ' '); + s += coord + char('A' + y); + int end = lineend(y); + for(int x = 0; x < end; x++){ + s += (last == Move(x, y) ? coord + "[" : + last == Move(x-1, y) ? coord + "]" : " "); + Side p = get(x, y); + if( p == Side::NONE) s += empty; + else if(p == Side::P1) s += white; + else if(p == Side::P2) s += black; + else s += "?"; + } + s += (last == Move(end-1, y) ? coord + "]" : " "); + s += '\n'; + } + + s += reset; + return s; +} + +int Board::edges(int x, int y) const { + return (x == 0 ? 1 : 0) | + (y == 0 ? 2 : 0) | + (x + y == sizem1 ? 4 : 0); +} + + +}; // namespace Y +}; // namespace Morat diff --git a/y/board.h b/y/board.h index 67d7abd..a54e204 100644 --- a/y/board.h +++ b/y/board.h @@ -4,18 +4,21 @@ #include #include #include +#include #include #include #include "../lib/bitcount.h" #include "../lib/hashset.h" +#include "../lib/move.h" +#include "../lib/outcome.h" #include "../lib/string.h" #include "../lib/types.h" #include "../lib/zobrist.h" -#include "move.h" -using namespace std; +namespace Morat { +namespace Y { /* * the board is represented as a flattened 2d array of the form: @@ -49,37 +52,31 @@ static MoveValid * staticneighbourlist[17] = { class Board{ public: + static constexpr const char * const name = "y"; static const int default_size = 10; static const int min_size = 5; static const int max_size = 16; static const int max_vecsize = max_size * max_size; + static const int num_win_types = 1; static const int pattern_cells = 18; typedef uint64_t Pattern; struct Cell { - uint16_t piece; //who controls this cell, 0 for none, 1,2 for players + Side piece; //who controls this cell, 0 for none, 1,2 for players uint16_t size; //size of this group of cells -mutable uint16_t parent; //parent for this group of cells. 8 bits limits board size to 16 until it's no longer stored as a square +mutable uint16_t parent; //parent for this group of cells uint8_t edge; //which edges are this group connected to uint8_t perm; //is this a permanent piece or a randomly placed piece? Pattern pattern; //the pattern of pieces for neighbours, but from their perspective. Rotate 180 for my perpective - Cell() : piece(73), size(0), parent(0), edge(0), perm(0), pattern(0) { } - Cell(unsigned int p, unsigned int a, unsigned int s, unsigned int e, Pattern t) : + Cell() : piece(Side::NONE), size(0), parent(0), edge(0), perm(0), pattern(0) { } + Cell(Side p, unsigned int a, unsigned int s, unsigned int e, Pattern t) : piece(p), size(s), parent(a), edge(e), perm(0), pattern(t) { } int numedges() const { return BitsSetTable256[edge]; } - string to_s(int i) const { - return "Cell " + to_str(i) +": " - "piece: " + to_str((int)piece)+ - ", size: " + to_str((int)size) + - ", parent: " + to_str((int)parent) + - ", edge: " + to_str((int)edge) + "/" + to_str(numedges()) + - ", perm: " + to_str((int)perm) + - ", pattern: " + to_str((int)pattern); - } + std::string to_s(int i) const; }; class MoveIterator { //only returns valid moves... @@ -90,7 +87,7 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board HashSet hashes; public: MoveIterator(const Board & b, bool Unique) : board(b), lineend(0), move(Move(M_SWAP), -1), unique(Unique) { - if(board.outcome >= 0){ + if(board.outcome >= Outcome::DRAW){ move = MoveValid(0, board.size, -1); //already done } else { if(unique) @@ -116,9 +113,8 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board move.xy = -1; return *this; } - - move.x = 0; - move.xy = move.y * board.size; + move.x = board.linestart(move.y); + move.xy = board.xy(move.x, move.y); lineend = board.lineend(move.y); } }while(!board.valid_move_fast(move)); @@ -142,10 +138,10 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board short nummoves; short unique_depth; //update and test rotations/symmetry with less than this many pieces on the board Move last; - char toPlay; - char outcome; //-3 = unknown, 0 = tie, 1,2 = player win + Side toPlay; + Outcome outcome; - vector cells; + std::vector cells; Zobrist<6> hash; const MoveValid * neighbourlist; @@ -160,15 +156,15 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board last = M_NONE; nummoves = 0; unique_depth = 5; - toPlay = 1; - outcome = -3; + toPlay = Side::P1; + outcome = Outcome::UNKNOWN; neighbourlist = get_neighbour_list(); num_cells = vecsize() - (size*sizem1/2); cells.resize(vecsize()); for(int y = 0; y < size; y++){ - for(int x = 0; x < lineend(y); x++){ + for(int x = 0; x < size; x++){ int posxy = xy(x, y); Pattern p = 0, j = 3; for(const MoveValid * i = nb_begin(posxy), *e = nb_end_big_hood(i); i < e; i++){ @@ -176,7 +172,8 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board p |= j; j <<= 2; } - cells[posxy] = Cell(0, posxy, 1, edges(x, y), pattern_reverse(p)); + Side s = (onboard(x, y) ? Side::NONE : Side::UNDEF); + cells[posxy] = Cell(s, posxy, 1, edges(x, y), pattern_reverse(p)); } } } @@ -193,7 +190,7 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board int numcells() const { return num_cells; } int num_moves() const { return nummoves; } - int movesremain() const { return (won() >= 0 ? 0 : num_cells - nummoves); } + int movesremain() const { return (won() >= Outcome::DRAW ? 0 : num_cells - nummoves); } int xy(int x, int y) const { return y*size + x; } int xy(const Move & m) const { return m.y*size + m.x; } @@ -201,6 +198,10 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board MoveValid yx(int i) const { return MoveValid(i % size, i / size, i); } + int dist(const Move & a, const Move & b) const { + return (abs(a.x - b.x) + abs(a.y - b.y) + abs((a.x + a.y) - (b.x + b.y)) )/2; + } + const Cell * cell(int i) const { return & cells[i]; } const Cell * cell(int x, int y) const { return cell(xy(x,y)); } const Cell * cell(const Move & m) const { return cell(xy(m)); } @@ -208,18 +209,18 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board //assumes valid x,y - int get(int i) const { return cells[i].piece; } - int get(int x, int y) const { return get(xy(x, y)); } - int get(const Move & m) const { return get(xy(m)); } - int get(const MoveValid & m) const { return get(m.xy); } + Side get(int i) const { return cells[i].piece; } + Side get(int x, int y) const { return get(xy(x, y)); } + Side get(const Move & m) const { return get(xy(m)); } + Side get(const MoveValid & m) const { return get(m.xy); } - int geton(const MoveValid & m) const { return (m.onboard() ? get(m.xy) : 0); } + Side geton(const MoveValid & m) const { return (m.onboard() ? get(m.xy) : Side::UNDEF); } - int local(const Move & m, char turn) const { return local(xy(m), turn); } - int local(int i, char turn) const { + int local(const Move & m, Side turn) const { return local(xy(m), turn); } + int local(int i, Side turn) const { Pattern p = pattern(i); Pattern x = ((p & 0xAAAAAAAAAull) >> 1) ^ (p & 0x555555555ull); // p1 is now when p1 or p2 but not both (ie off the board) - p = x & (turn == 1 ? p : p >> 1); // now just the selected player + p = x & (turn == Side::P1 ? p : p >> 1); // now just the selected player return (p & 0x000000FFF ? 3 : 0) | (p & 0x000FFF000 ? 2 : 0) | (p & 0xFFF000000 ? 1 : 0); @@ -235,13 +236,14 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board bool onboard(const MoveValid & m) const { return m.onboard(); } //assumes x, y are in bounds and the game isn't already finished - bool valid_move_fast(int x, int y) const { return !get(x,y); } - bool valid_move_fast(const Move & m) const { return !get(m); } - bool valid_move_fast(const MoveValid & m) const { return !get(m.xy); } + bool valid_move_fast(int i) const { return get(i) == Side::NONE; } + bool valid_move_fast(int x, int y) const { return valid_move_fast(xy(x, y)); } + bool valid_move_fast(const Move & m) const { return valid_move_fast(xy(m)); } + bool valid_move_fast(const MoveValid & m) const { return valid_move_fast(m.xy); } //checks array bounds too - bool valid_move(int x, int y) const { return (outcome == -3 && onboard(x, y) && !get(x, y)); } - bool valid_move(const Move & m) const { return (outcome == -3 && onboard(m) && !get(m)); } - bool valid_move(const MoveValid & m) const { return (outcome == -3 && m.onboard() && !get(m)); } + bool valid_move(int x, int y) const { return (outcome < Outcome::DRAW && onboard(x, y) && valid_move_fast(x, y)); } + bool valid_move(const Move & m) const { return (outcome < Outcome::DRAW && onboard(m) && valid_move_fast(m)); } + bool valid_move(const MoveValid & m) const { return (outcome < Outcome::DRAW && m.onboard() && valid_move_fast(m)); } //iterator through neighbours of a position const MoveValid * nb_begin(int x, int y) const { return nb_begin(xy(x, y)); } @@ -255,11 +257,7 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board const MoveValid * nb_end_small_hood(const MoveValid * m) const { return m + 12; } const MoveValid * nb_end_big_hood(const MoveValid * m) const { return m + 18; } - int edges(int x, int y) const { - return (x == 0 ? 1 : 0) | - (y == 0 ? 2 : 0) | - (x + y == sizem1 ? 4 : 0); - } + int edges(int x, int y) const; MoveValid * get_neighbour_list(){ if(!staticneighbourlist[(int)size]){ @@ -283,89 +281,24 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board return staticneighbourlist[(int)size]; } - + int linestart(int y) const { return 0; } int lineend(int y) const { return (size - y); } + int linelen(int y) const { return lineend(y) - linestart(y); } - string to_s(bool color) const { - string white = "O", - black = "@", - empty = ".", - coord = "", - reset = ""; - if(color){ - string esc = "\033"; - reset = esc + "[0m"; - coord = esc + "[1;37m"; - empty = reset + "."; - white = esc + "[1;33m" + "@"; //yellow - black = esc + "[1;34m" + "@"; //blue - } - - string s; - for(int i = 0; i < size; i++) - s += " " + coord + to_str(i+1); - s += "\n"; - - for(int y = 0; y < size; y++){ - s += string(y, ' '); - s += coord + char('A' + y); - int end = lineend(y); - for(int x = 0; x < end; x++){ - s += (last == Move(x, y) ? coord + "[" : - last == Move(x-1, y) ? coord + "]" : " "); - int p = get(x, y); - if(p == 0) s += empty; - if(p == 1) s += white; - if(p == 2) s += black; - if(p >= 3) s += "?"; - } - s += (last == Move(end-1, y) ? coord + "]" : " "); - s += '\n'; - } - - s += reset; - return s; - } + std::string to_s(bool color) const; + friend std::ostream& operator<< (std::ostream &out, const Board & b) { return out << b.to_s(true); } void print(bool color = true) const { printf("%s", to_s(color).c_str()); } - string boardstr() const { - string white, black; - for(int y = 0; y < size; y++){ - for(int x = 0; x < lineend(y); x++){ - int p = get(x, y); - if(p == 1) white += Move(x, y).to_s(); - if(p == 2) black += Move(x, y).to_s(); - } - } - return white + ";" + black; - } - - string won_str() const { - switch(outcome){ - case -3: return "none"; - case -2: return "black_or_draw"; - case -1: return "white_or_draw"; - case 0: return "draw"; - case 1: return "white"; - case 2: return "black"; - } - return "unknown"; - } - - char won() const { + Outcome won() const { return outcome; } - int win() const{ // 0 for draw or unknown, 1 for win, -1 for loss - if(outcome <= 0) - return 0; - return (outcome == toplay() ? 1 : -1); - } + char getwintype() const { return outcome > Outcome::DRAW; } - char toplay() const { + Side toplay() const { return toPlay; } @@ -373,22 +306,22 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board return MoveIterator(*this, (unique ? nummoves <= unique_depth : false)); } - void set(const Move & m, bool perm = true){ + void set(const Move & m, bool perm = true) { last = m; Cell * cell = & cells[xy(m)]; cell->piece = toPlay; cell->perm = perm; nummoves++; update_hash(m, toPlay); //depends on nummoves - toPlay = 3 - toPlay; + toPlay = ~toPlay; } - void unset(const Move & m){ //break win checks, but is a poor mans undo if all you care about is the hash - toPlay = 3 - toPlay; + void unset(const Move & m) { //break win checks, but is a poor mans undo if all you care about is the hash + toPlay = ~toPlay; update_hash(m, toPlay); nummoves--; Cell * cell = & cells[xy(m)]; - cell->piece = 0; + cell->piece = Side::NONE; cell->perm = 0; } @@ -418,7 +351,7 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board return true; if(cells[i].size < cells[j].size) //force i's subtree to be bigger - swap(i, j); + std::swap(i, j); cells[j].parent = i; cells[i].size += cells[j].size; @@ -428,7 +361,7 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board } Cell test_cell(const Move & pos) const { - char turn = toplay(); + Side turn = toplay(); int posxy = xy(pos); Cell testcell = cells[find_group(pos)]; @@ -458,7 +391,7 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board return (nummoves > unique_depth ? hash.get(0) : hash.get()); } - string hashstr() const { + std::string hashstr() const { static const char hexlookup[] = "0123456789abcdef"; char buf[19] = "0x"; hash_t val = gethash(); @@ -470,7 +403,8 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board return (char *)buf; } - void update_hash(const Move & pos, int turn){ + void update_hash(const Move & pos, Side side) { + int turn = side.to_i(); if(nummoves > unique_depth){ //simple update, no rotations/symmetry hash.update(0, 3*xy(pos) + turn); return; @@ -493,7 +427,8 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board return test_hash(pos, toplay()); } - hash_t test_hash(const Move & pos, int turn) const { + hash_t test_hash(const Move & pos, Side side) const { + int turn = side.to_i(); if(nummoves >= unique_depth) //simple test, no rotations/symmetry return hash.test(0, 3*xy(pos) + turn); @@ -502,11 +437,11 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board z = sizem1 - x - y; hash_t m = hash.test(0, 3*xy(x, y) + turn); - m = min(m, hash.test(1, 3*xy(z, y) + turn)); - m = min(m, hash.test(2, 3*xy(z, x) + turn)); - m = min(m, hash.test(3, 3*xy(x, z) + turn)); - m = min(m, hash.test(4, 3*xy(y, z) + turn)); - m = min(m, hash.test(5, 3*xy(y, x) + turn)); + m = std::min(m, hash.test(1, 3*xy(z, y) + turn)); + m = std::min(m, hash.test(2, 3*xy(z, x) + turn)); + m = std::min(m, hash.test(3, 3*xy(x, z) + turn)); + m = std::min(m, hash.test(4, 3*xy(y, z) + turn)); + m = std::min(m, hash.test(5, 3*xy(y, x) + turn)); return m; } @@ -538,13 +473,13 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board return (((p & 0x03F03F03Full) << 6) | ((p & 0xFC0FC0FC0ull) >> 6)); } - static Pattern pattern_invert(Pattern p){ //switch players + static Pattern pattern_invert(Pattern p) { //switch players return ((p & 0xAAAAAAAAAull) >> 1) | ((p & 0x555555555ull) << 1); } - static Pattern pattern_rotate(Pattern p){ + static Pattern pattern_rotate(Pattern p) { return (((p & 0x003003003ull) << 10) | ((p & 0xFFCFFCFFCull) >> 2)); } - static Pattern pattern_mirror(Pattern p){ + static Pattern pattern_mirror(Pattern p) { // HGFEDC BA9876 543210 -> DEFGHC 6789AB 123450 return ((p & (3ull << 6)) ) | ((p & (3ull << 0)) ) | // 0,3 stay in place ((p & (3ull << 10)) >> 8) | ((p & (3ull << 2)) << 8) | // 1,5 swap @@ -556,36 +491,36 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board ((p & (3ull << 34)) >> 8) | ((p & (3ull << 26)) << 8) | // H,D swap ((p & (3ull << 32)) >> 4) | ((p & (3ull << 28)) << 4); // G,E swap } - static Pattern pattern_symmetry(Pattern p){ //takes a pattern and returns the representative version + static Pattern pattern_symmetry(Pattern p) { //takes a pattern and returns the representative version Pattern m = p; //012345 - m = min(m, (p = pattern_rotate(p)));//501234 - m = min(m, (p = pattern_rotate(p)));//450123 - m = min(m, (p = pattern_rotate(p)));//345012 - m = min(m, (p = pattern_rotate(p)));//234501 - m = min(m, (p = pattern_rotate(p)));//123450 - m = min(m, (p = pattern_mirror(pattern_rotate(p))));//012345 -> 054321 - m = min(m, (p = pattern_rotate(p)));//105432 - m = min(m, (p = pattern_rotate(p)));//210543 - m = min(m, (p = pattern_rotate(p)));//321054 - m = min(m, (p = pattern_rotate(p)));//432105 - m = min(m, (p = pattern_rotate(p)));//543210 + m = std::min(m, (p = pattern_rotate(p)));//501234 + m = std::min(m, (p = pattern_rotate(p)));//450123 + m = std::min(m, (p = pattern_rotate(p)));//345012 + m = std::min(m, (p = pattern_rotate(p)));//234501 + m = std::min(m, (p = pattern_rotate(p)));//123450 + m = std::min(m, (p = pattern_mirror(pattern_rotate(p))));//012345 -> 054321 + m = std::min(m, (p = pattern_rotate(p)));//105432 + m = std::min(m, (p = pattern_rotate(p)));//210543 + m = std::min(m, (p = pattern_rotate(p)));//321054 + m = std::min(m, (p = pattern_rotate(p)));//432105 + m = std::min(m, (p = pattern_rotate(p)));//543210 return m; } - bool move(const Move & pos, bool checkwin = true, bool permanent = true){ + bool move(const Move & pos, bool checkwin = true, bool permanent = true) { return move(MoveValid(pos, xy(pos)), checkwin, permanent); } - bool move(const MoveValid & pos, bool checkwin = true, bool permanent = true){ - assert(outcome < 0); + bool move(const MoveValid & pos, bool checkwin = true, bool permanent = true) { + assert(outcome < Outcome::DRAW); if(!valid_move(pos)) return false; - char turn = toplay(); + Side turn = toplay(); set(pos, permanent); // update the nearby patterns - Pattern p = turn; + Pattern p = turn.to_i(); for(const MoveValid * i = nb_begin(pos.xy), *e = nb_end_big_hood(i); i < e; i++){ if(i->onboard()){ cells[i->xy].pattern |= p; @@ -605,25 +540,25 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board // did I win? Cell * g = & cells[find_group(pos.xy)]; if(g->numedges() == 3){ - outcome = turn; + outcome = +turn; } return true; } - bool test_local(const Move & pos, char turn) const { + bool test_local(const Move & pos, Side turn) const { return test_local(MoveValid(pos, xy(pos)), turn); } + bool test_local(const MoveValid & pos, Side turn) const { return (local(pos, turn) == 3); } //test if making this move would win, but don't actually make the move - int test_win(const Move & pos, char turn = 0) const { - if(turn == 0) - turn = toplay(); - + Outcome test_outcome(const Move & pos) const { return test_outcome(pos, toplay()); } + Outcome test_outcome(const Move & pos, Side turn) const { return test_outcome(MoveValid(pos, xy(pos)), turn); } + Outcome test_outcome(const MoveValid & pos) const { return test_outcome(pos, toplay()); } + Outcome test_outcome(const MoveValid & pos, Side turn) const { if(test_local(pos, turn)){ - int posxy = xy(pos); - Cell testcell = cells[find_group(posxy)]; + Cell testcell = cells[find_group(pos.xy)]; int numgroups = 0; - for(const MoveValid * i = nb_begin(posxy), *e = nb_end(i); i < e; i++){ + for(const MoveValid * i = nb_begin(pos), *e = nb_end(i); i < e; i++){ if(i->onboard() && turn == get(i->xy)){ const Cell * g = & cells[find_group(i->xy)]; testcell.edge |= g->edge; @@ -634,9 +569,12 @@ mutable uint16_t parent; //parent for this group of cells. 8 bits limits board } if(testcell.numedges() == 3) - return turn; + return +turn; } - return -3; + return Outcome::UNKNOWN; } }; + +}; // namespace Y +}; // namespace Morat diff --git a/y/board_test.cpp b/y/board_test.cpp new file mode 100644 index 0000000..647bd4b --- /dev/null +++ b/y/board_test.cpp @@ -0,0 +1,115 @@ + +#include "../lib/catch.hpp" +#include "../lib/string.h" + +#include "board.h" + + +using namespace Morat; +using namespace Y; + +void test_game(Board b, std::vector moves, Outcome outcome) { + REQUIRE(b.num_moves() == 0); + Side side = Side::P1; + for(auto s : moves) { + Outcome expected = (s == moves.back() ? outcome : Outcome::UNKNOWN); + Move move(s); + CAPTURE(move); + CAPTURE(b); + REQUIRE(b.valid_move(move)); + REQUIRE(b.toplay() == side); + REQUIRE(b.test_outcome(move) == expected); + REQUIRE(b.move(move)); + REQUIRE(b.won() == expected); + side = ~side; + } +} +void test_game(Board b, std::string moves, Outcome outcome) { + test_game(b, explode(moves, " "), outcome); +} + +TEST_CASE("Y::Board [y][board]") { + Board b(7); + + SECTION("Basics") { + REQUIRE(b.get_size() == 7); + REQUIRE(b.movesremain() == 28); + } + + SECTION("valid moves") { + std::string valid[] = {"A1", "D4", + "a1", "a2", "a3", "a4", "a5", "a6", "a7", + "b1", "b2", "b3", "b4", "b5", "b6", + "c1", "c2", "c3", "c4", "c5", + "d1", "d2", "d3", "d4", + "e1", "e2", "e3", + "f1", "f2", + "g1", + }; + for(auto m : valid){ + REQUIRE(b.onboard(m)); + REQUIRE(b.valid_move(m)); + } + } + + SECTION("invalid moves") { + std::string invalid[] = {"a0", "a8", "a10", "b7", "c6", "g2", "e8", "f8", "f0", "h1", "f0"}; + for(auto m : invalid){ + REQUIRE_FALSE(b.onboard(m)); + REQUIRE_FALSE(b.valid_move(m)); + } + } + + SECTION("duplicate moves") { + Move m("a1"); + REQUIRE(b.valid_move(m)); + REQUIRE(b.move(m)); + REQUIRE_FALSE(b.valid_move(m)); + REQUIRE_FALSE(b.move(m)); + } + + SECTION("move distance") { + SECTION("x") { + REQUIRE(b.dist(Move("b2"), Move("b1")) == 1); + REQUIRE(b.dist(Move("b2"), Move("b3")) == 1); + } + SECTION("y") { + REQUIRE(b.dist(Move("b2"), Move("a2")) == 1); + REQUIRE(b.dist(Move("b2"), Move("c2")) == 1); + } + SECTION("z") { + REQUIRE(b.dist(Move("b2"), Move("a3")) == 1); + REQUIRE(b.dist(Move("b2"), Move("c1")) == 1); + } + SECTION("farther") { + REQUIRE(b.dist(Move("b2"), Move("a1")) == 2); + REQUIRE(b.dist(Move("b2"), Move("c3")) == 2); + REQUIRE(b.dist(Move("b2"), Move("d4")) == 4); + REQUIRE(b.dist(Move("b2"), Move("d3")) == 3); + REQUIRE(b.dist(Move("b2"), Move("d1")) == 2); + REQUIRE(b.dist(Move("b2"), Move("e3")) == 4); + } + } + + SECTION("Unknown_1") { + test_game(b, { "a1", "b1", "a2", "b2", "a3", "b3", "a4"}, Outcome::UNKNOWN); + test_game(b, {"d4", "a1", "b1", "a2", "b2", "a3", "b3", "a4"}, Outcome::UNKNOWN); + } + + SECTION("Unknown_2") { + test_game(b, { "b1", "c1", "b2", "c2", "b3", "c3", "b4", "c4", "b5", "c5", "a2"}, Outcome::UNKNOWN); + test_game(b, {"d4", "b1", "c1", "b2", "c2", "b3", "c3", "b4", "c4", "b5", "c5", "a2"}, Outcome::UNKNOWN); + } + + SECTION("White Connects") { + test_game(b, + "c3 e2 c4 c5 e3 d3 d4 b1 c2 g1 b5 d2 d1 a6 a5", + Outcome::P1); + } + + SECTION("Black Connects") { + test_game(b, + "a1 b2 c3 e2 c1 b5 c5 a4 e3 c2 b6 c4 g1 d4 f2 d3 a3 a5 e1 f1", + Outcome::P2); + } +} diff --git a/y/gtp.h b/y/gtp.h index f53c9e9..f06cdf7 100644 --- a/y/gtp.h +++ b/y/gtp.h @@ -2,6 +2,8 @@ #pragma once #include "../lib/gtpcommon.h" +#include "../lib/history.h" +#include "../lib/move.h" #include "../lib/string.h" #include "agent.h" @@ -9,11 +11,13 @@ #include "agentmcts.h" #include "agentpns.h" #include "board.h" -#include "history.h" -#include "move.h" + + +namespace Morat { +namespace Y { class GTP : public GTPCommon { - History hist; + History hist; public: int verbose; @@ -35,46 +39,46 @@ class GTP : public GTPCommon { set_board(); - newcallback("name", bind(>P::gtp_name, this, _1), "Name of the program"); - newcallback("version", bind(>P::gtp_version, this, _1), "Version of the program"); - newcallback("verbose", bind(>P::gtp_verbose, this, _1), "Set verbosity, 0 for quiet, 1 for normal, 2+ for more output"); - newcallback("extended", bind(>P::gtp_extended, this, _1), "Output extra stats from genmove in the response"); - newcallback("debug", bind(>P::gtp_debug, this, _1), "Enable debug mode"); - newcallback("colorboard", bind(>P::gtp_colorboard, this, _1), "Turn on or off the colored board"); - newcallback("showboard", bind(>P::gtp_print, this, _1), "Show the board"); - newcallback("print", bind(>P::gtp_print, this, _1), "Alias for showboard"); - newcallback("dists", bind(>P::gtp_dists, this, _1), "Similar to print, but shows minimum win distances"); -// newcallback("zobrist", bind(>P::gtp_zobrist, this, _1), "Output the zobrist hash for the current move"); - newcallback("clear_board", bind(>P::gtp_clearboard, this, _1), "Clear the board, but keep the size"); - newcallback("clear", bind(>P::gtp_clearboard, this, _1), "Alias for clear_board"); - newcallback("boardsize", bind(>P::gtp_boardsize, this, _1), "Clear the board, set the board size"); - newcallback("size", bind(>P::gtp_boardsize, this, _1), "Alias for board_size"); - newcallback("play", bind(>P::gtp_play, this, _1), "Place a stone: play "); - newcallback("white", bind(>P::gtp_playwhite, this, _1), "Place a white stone: white "); - newcallback("black", bind(>P::gtp_playblack, this, _1), "Place a black stone: black "); - newcallback("undo", bind(>P::gtp_undo, this, _1), "Undo one or more moves: undo [amount to undo]"); - newcallback("time", bind(>P::gtp_time, this, _1), "Set the time limits and the algorithm for per game time"); - newcallback("genmove", bind(>P::gtp_genmove, this, _1), "Generate a move: genmove [color] [time]"); - newcallback("solve", bind(>P::gtp_solve, this, _1), "Try to solve this position"); - -// newcallback("ab", bind(>P::gtp_ab, this, _1), "Switch to use the Alpha/Beta agent to play/solve"); - newcallback("mcts", bind(>P::gtp_mcts, this, _1), "Switch to use the Monte Carlo Tree Search agent to play/solve"); - newcallback("pns", bind(>P::gtp_pns, this, _1), "Switch to use the Proof Number Search agent to play/solve"); - - newcallback("all_legal", bind(>P::gtp_all_legal, this, _1), "List all legal moves"); - newcallback("history", bind(>P::gtp_history, this, _1), "List of played moves"); - newcallback("playgame", bind(>P::gtp_playgame, this, _1), "Play a list of moves"); - newcallback("winner", bind(>P::gtp_winner, this, _1), "Check the winner of the game"); - newcallback("patterns", bind(>P::gtp_patterns, this, _1), "List all legal moves plus their local pattern"); - - newcallback("pv", bind(>P::gtp_pv, this, _1), "Output the principle variation for the player tree as it stands now"); - newcallback("move_stats", bind(>P::gtp_move_stats, this, _1), "Output the move stats for the player tree as it stands now"); - - newcallback("params", bind(>P::gtp_params, this, _1), "Set the options for the player, no args gives options"); - -// newcallback("player_hgf", bind(>P::gtp_player_hgf, this, _1), "Output an hgf of the current tree"); -// newcallback("player_load_hgf", bind(>P::gtp_player_load_hgf,this, _1), "Load an hgf generated by player_hgf"); -// newcallback("player_gammas", bind(>P::gtp_player_gammas, this, _1), "Load the gammas for weighted random from a file"); + newcallback("name", std::bind(>P::gtp_name, this, _1), "Name of the program"); + newcallback("version", std::bind(>P::gtp_version, this, _1), "Version of the program"); + newcallback("verbose", std::bind(>P::gtp_verbose, this, _1), "Set verbosity, 0 for quiet, 1 for normal, 2+ for more output"); + newcallback("extended", std::bind(>P::gtp_extended, this, _1), "Output extra stats from genmove in the response"); + newcallback("debug", std::bind(>P::gtp_debug, this, _1), "Enable debug mode"); + newcallback("colorboard", std::bind(>P::gtp_colorboard, this, _1), "Turn on or off the colored board"); + newcallback("showboard", std::bind(>P::gtp_print, this, _1), "Show the board"); + newcallback("print", std::bind(>P::gtp_print, this, _1), "Alias for showboard"); + newcallback("dists", std::bind(>P::gtp_dists, this, _1), "Similar to print, but shows minimum win distances"); + newcallback("zobrist", std::bind(>P::gtp_zobrist, this, _1), "Output the zobrist hash for the current move"); + newcallback("clear_board", std::bind(>P::gtp_clearboard, this, _1), "Clear the board, but keep the size"); + newcallback("clear", std::bind(>P::gtp_clearboard, this, _1), "Alias for clear_board"); + newcallback("boardsize", std::bind(>P::gtp_boardsize, this, _1), "Clear the board, set the board size"); + newcallback("size", std::bind(>P::gtp_boardsize, this, _1), "Alias for board_size"); + newcallback("play", std::bind(>P::gtp_play, this, _1), "Place a stone: play "); + newcallback("white", std::bind(>P::gtp_playwhite, this, _1), "Place a white stone: white "); + newcallback("black", std::bind(>P::gtp_playblack, this, _1), "Place a black stone: black "); + newcallback("undo", std::bind(>P::gtp_undo, this, _1), "Undo one or more moves: undo [amount to undo]"); + newcallback("time", std::bind(>P::gtp_time, this, _1), "Set the time limits and the algorithm for per game time"); + newcallback("genmove", std::bind(>P::gtp_genmove, this, _1), "Generate a move: genmove [color] [time]"); + newcallback("solve", std::bind(>P::gtp_solve, this, _1), "Try to solve this position"); + +// newcallback("ab", std::bind(>P::gtp_ab, this, _1), "Switch to use the Alpha/Beta agent to play/solve"); + newcallback("mcts", std::bind(>P::gtp_mcts, this, _1), "Switch to use the Monte Carlo Tree Search agent to play/solve"); + newcallback("pns", std::bind(>P::gtp_pns, this, _1), "Switch to use the Proof Number Search agent to play/solve"); + + newcallback("all_legal", std::bind(>P::gtp_all_legal, this, _1), "List all legal moves"); + newcallback("history", std::bind(>P::gtp_history, this, _1), "List of played moves"); + newcallback("playgame", std::bind(>P::gtp_playgame, this, _1), "Play a list of moves"); + newcallback("winner", std::bind(>P::gtp_winner, this, _1), "Check the winner of the game"); + newcallback("patterns", std::bind(>P::gtp_patterns, this, _1), "List all legal moves plus their local pattern"); + + newcallback("pv", std::bind(>P::gtp_pv, this, _1), "Output the principle variation for the player tree as it stands now"); + newcallback("move_stats", std::bind(>P::gtp_move_stats, this, _1), "Output the move stats for the player tree as it stands now"); + + newcallback("params", std::bind(>P::gtp_params, this, _1), "Set the options for the player, no args gives options"); + + newcallback("save_sgf", std::bind(>P::gtp_save_sgf, this, _1), "Output an sgf of the current tree"); + newcallback("load_sgf", std::bind(>P::gtp_load_sgf, this, _1), "Load an sgf generated by save_sgf"); +// newcallback("player_gammas", std::bind(>P::gtp_player_gammas, this, _1), "Load the gammas for weighted random from a file"); } void set_board(bool clear = true){ @@ -94,7 +98,7 @@ class GTP : public GTPCommon { GTPResponse gtp_all_legal(vecstr args); GTPResponse gtp_history(vecstr args); GTPResponse gtp_patterns(vecstr args); - GTPResponse play(const string & pos, int toplay); + GTPResponse play(const std::string & pos, Side toplay); GTPResponse gtp_playgame(vecstr args); GTPResponse gtp_play(vecstr args); GTPResponse gtp_playwhite(vecstr args); @@ -124,8 +128,11 @@ class GTP : public GTPCommon { GTPResponse gtp_pns_params(vecstr args); // GTPResponse gtp_player_gammas(vecstr args); -// GTPResponse gtp_player_hgf(vecstr args); -// GTPResponse gtp_player_load_hgf(vecstr args); + GTPResponse gtp_save_sgf(vecstr args); + GTPResponse gtp_load_sgf(vecstr args); - string solve_str(int outcome) const; + std::string solve_str(int outcome) const; }; + +}; // namespace Y +}; // namespace Morat diff --git a/y/gtpagent.cpp b/y/gtpagent.cpp index d32178a..c0e6489 100644 --- a/y/gtpagent.cpp +++ b/y/gtpagent.cpp @@ -1,13 +1,12 @@ -#include +#include "gtp.h" -#include "../lib/fileio.h" -#include "gtp.h" +namespace Morat { +namespace Y { using namespace std; - GTPResponse GTP::gtp_move_stats(vecstr args){ vector moves; for(auto s : args) @@ -249,7 +248,7 @@ GTPResponse GTP::gtp_pns_params(vecstr args){ "Update the pns solver settings, eg: pns_params -m 100 -s 0 -d 1 -e 0.25 -a 2 -l 0\n" " -m --memory Memory limit in Mb [" + to_str(pns->memlimit/(1024*1024)) + "]\n" " -t --threads How many threads to run [" + to_str(pns->numthreads) + "]\n" - " -s --ties Which side to assign ties to, 0 = handle, 1 = p1, 2 = p2 [" + to_str(pns->ties) + "]\n" + " -s --ties Which side to assign ties to, 0 = handle, 1 = p1, 2 = p2 [" + to_str(pns->ties.to_i()) + "]\n" " -d --df Use depth-first thresholds [" + to_str(pns->df) + "]\n" " -e --epsilon How big should the threshold be [" + to_str(pns->epsilon) + "]\n" " -a --abdepth Run an alpha-beta search of this size at each leaf [" + to_str(pns->ab) + "]\n" @@ -267,7 +266,7 @@ GTPResponse GTP::gtp_pns_params(vecstr args){ if(mem < 1) return GTPResponse(false, "Memory can't be less than 1mb"); pns->set_memlimit(mem*1024*1024); }else if((arg == "-s" || arg == "--ties") && i+1 < args.size()){ - pns->ties = from_str(args[++i]); + pns->ties = Side(from_str(args[++i])); pns->clear_mem(); }else if((arg == "-d" || arg == "--df") && i+1 < args.size()){ pns->df = from_str(args[++i]); @@ -282,3 +281,6 @@ GTPResponse GTP::gtp_pns_params(vecstr args){ return GTPResponse(true, errs); } + +}; // namespace Y +}; // namespace Morat diff --git a/y/gtpgeneral.cpp b/y/gtpgeneral.cpp index 60f0f73..e4910ec 100644 --- a/y/gtpgeneral.cpp +++ b/y/gtpgeneral.cpp @@ -1,7 +1,15 @@ +#include + +#include "../lib/sgf.h" + #include "gtp.h" #include "lbdist.h" + +namespace Morat { +namespace Y { + GTPResponse GTP::gtp_mcts(vecstr args){ delete agent; agent = new AgentMCTS(); @@ -39,7 +47,7 @@ GTPResponse GTP::gtp_boardsize(vecstr args){ if(size < Board::min_size || size > Board::max_size) return GTPResponse(false, "Size " + to_str(size) + " is out of range."); - hist = History(size); + hist = History(size); set_board(); time_control.new_game(); @@ -69,14 +77,14 @@ GTPResponse GTP::gtp_undo(vecstr args){ GTPResponse GTP::gtp_patterns(vecstr args){ bool symmetric = true; bool invert = true; - string ret; + std::string ret; const Board & board = *hist; for(Board::MoveIterator move = board.moveit(); !move.done(); ++move){ ret += move->to_s() + " "; unsigned int p = board.pattern(*move); if(symmetric) p = board.pattern_symmetry(p); - if(invert && board.toplay() == 2) + if(invert && board.toplay() == Side::P2) p = board.pattern_invert(p); ret += to_str(p); ret += "\n"; @@ -85,24 +93,24 @@ GTPResponse GTP::gtp_patterns(vecstr args){ } GTPResponse GTP::gtp_all_legal(vecstr args){ - string ret; + std::string ret; for(Board::MoveIterator move = hist->moveit(); !move.done(); ++move) ret += move->to_s() + " "; return GTPResponse(true, ret); } GTPResponse GTP::gtp_history(vecstr args){ - string ret; + std::string ret; for(auto m : hist) ret += m.to_s() + " "; return GTPResponse(true, ret); } -GTPResponse GTP::play(const string & pos, int toplay){ +GTPResponse GTP::play(const std::string & pos, Side toplay){ if(toplay != hist->toplay()) return GTPResponse(false, "It is the other player's turn!"); - if(hist->won() >= 0) + if(hist->won() >= Outcome::DRAW) return GTPResponse(false, "The game is already over."); Move m(pos); @@ -113,7 +121,7 @@ GTPResponse GTP::play(const string & pos, int toplay){ move(m); if(verbose >= 2) - logerr("Placement: " + m.to_s() + ", outcome: " + hist->won_str() + "\n" + hist->to_s(colorboard)); + logerr("Placement: " + m.to_s() + ", outcome: " + hist->won().to_s() + "\n" + hist->to_s(colorboard)); return GTPResponse(true); } @@ -131,37 +139,33 @@ GTPResponse GTP::gtp_play(vecstr args){ if(args.size() != 2) return GTPResponse(false, "Wrong number of arguments"); - char toplay = 0; switch(tolower(args[0][0])){ - case 'w': toplay = 1; break; - case 'b': toplay = 2; break; - default: - return GTPResponse(false, "Invalid player selection"); + case 'w': return play(args[1], Side::P1); + case 'b': return play(args[1], Side::P2); + default: return GTPResponse(false, "Invalid player selection"); } - - return play(args[1], toplay); } GTPResponse GTP::gtp_playwhite(vecstr args){ if(args.size() != 1) return GTPResponse(false, "Wrong number of arguments"); - return play(args[0], 1); + return play(args[0], Side::P1); } GTPResponse GTP::gtp_playblack(vecstr args){ if(args.size() != 1) return GTPResponse(false, "Wrong number of arguments"); - return play(args[0], 2); + return play(args[0], Side::P2); } GTPResponse GTP::gtp_winner(vecstr args){ - return GTPResponse(true, hist->won_str()); + return GTPResponse(true, hist->won().to_s()); } GTPResponse GTP::gtp_name(vecstr args){ - return GTPResponse(true, "Castro"); + return GTPResponse(true, std::string("morat-") + Board::name); } GTPResponse GTP::gtp_version(vecstr args){ @@ -193,7 +197,7 @@ GTPResponse GTP::gtp_extended(vecstr args){ } GTPResponse GTP::gtp_debug(vecstr args){ - string str = "\n"; + std::string str = "\n"; str += "Board size: " + to_str(hist->get_size()) + "\n"; str += "Board cells: " + to_str(hist->numcells()) + "\n"; str += "Board vec: " + to_str(hist->vecsize()) + "\n"; @@ -203,14 +207,15 @@ GTPResponse GTP::gtp_debug(vecstr args){ } GTPResponse GTP::gtp_dists(vecstr args){ + using std::string; Board board = *hist; LBDists dists(&board); - int side = 0; + Side side = Side::NONE; if(args.size() >= 1){ switch(tolower(args[0][0])){ - case 'w': side = 1; break; - case 'b': side = 2; break; + case 'w': side = Side::P1; break; + case 'b': side = Side::P2; break; default: return GTPResponse(false, "Invalid player selection"); } @@ -243,17 +248,17 @@ GTPResponse GTP::gtp_dists(vecstr args){ s += coord + char('A' + y); int end = board.lineend(y); for(int x = 0; x < end; x++){ - int p = board.get(x, y); + Side p = board.get(x, y); s += ' '; - if(p == 0){ - int d = (side ? dists.get(Move(x, y), side) : dists.get(Move(x, y))); - if(d < 30) + if(p == Side::NONE){ + int d = (side == Side::NONE ? dists.get(Move(x, y)) : dists.get(Move(x, y), side)); + if(d < 10) s += reset + to_str(d); else s += empty; - }else if(p == 1){ + }else if(p == Side::P1){ s += white; - }else if(p == 2){ + }else if(p == Side::P2){ s += black; } } @@ -265,3 +270,91 @@ GTPResponse GTP::gtp_dists(vecstr args){ GTPResponse GTP::gtp_zobrist(vecstr args){ return GTPResponse(true, hist->hashstr()); } + +GTPResponse GTP::gtp_save_sgf(vecstr args){ + int limit = -1; + if(args.size() == 0) + return GTPResponse(true, "save_sgf [work limit]"); + + std::ifstream infile(args[0].c_str()); + + if(infile) { + infile.close(); + return GTPResponse(false, "File " + args[0] + " already exists"); + } + + std::ofstream outfile(args[0].c_str()); + + if(!outfile) + return GTPResponse(false, "Opening file " + args[0] + " for writing failed"); + + if(args.size() > 1) + limit = from_str(args[1]); + + SGFPrinter sgf(outfile); + sgf.game(Board::name); + sgf.program(gtp_name(vecstr()).response, gtp_version(vecstr()).response); + sgf.size(hist->get_size()); + + sgf.end_root(); + + Side s = Side::P1; + for(auto m : hist){ + sgf.move(s, m); + s = ~s; + } + + agent->gen_sgf(sgf, limit); + + sgf.end(); + outfile.close(); + return true; +} + + +GTPResponse GTP::gtp_load_sgf(vecstr args){ + if(args.size() == 0) + return GTPResponse(true, "load_sgf "); + + std::ifstream infile(args[0].c_str()); + + if(!infile) { + return GTPResponse(false, "Error opening file " + args[0] + " for reading"); + } + + SGFParser sgf(infile); + if(sgf.game() != Board::name){ + infile.close(); + return GTPResponse(false, "File is for the wrong game: " + sgf.game()); + } + + int size = sgf.size(); + if(size != hist->get_size()){ + if(hist.len() == 0){ + hist = History(size); + set_board(); + time_control.new_game(); + }else{ + infile.close(); + return GTPResponse(false, "File has the wrong boardsize to match the existing game"); + } + } + + Side s = Side::P1; + + while(sgf.next_node()){ + Move m = sgf.move(); + move(m); // push the game forward + s = ~s; + } + + if(sgf.has_children()) + agent->load_sgf(sgf); + + assert(sgf.done_child()); + infile.close(); + return true; +} + +}; // namespace Y +}; // namespace Morat diff --git a/y/gtpplayer.cpp b/y/gtpplayer.cpp deleted file mode 100644 index 1d9f89b..0000000 --- a/y/gtpplayer.cpp +++ /dev/null @@ -1,547 +0,0 @@ - - -#include - -#include "../lib/fileio.h" - -#include "gtp.h" - -using namespace std; - - -GTPResponse GTP::gtp_move_stats(vecstr args){ - string s = ""; - - Player::Node * node = &(player.root); - - for(unsigned int i = 0; i < args.size(); i++){ - Move m(args[i]); - Player::Node * c = node->children.begin(), - * cend = node->children.end(); - for(; c != cend; c++){ - if(c->move == m){ - node = c; - break; - } - } - } - - Player::Node * child = node->children.begin(), - * childend = node->children.end(); - for( ; child != childend; child++){ - if(child->move == M_NONE) - continue; - - s += child->move.to_s(); - s += "," + to_str((child->exp.num() ? child->exp.avg() : 0.0), 4) + "," + to_str(child->exp.num()); - s += "," + to_str((child->rave.num() ? child->rave.avg() : 0.0), 4) + "," + to_str(child->rave.num()); - s += "," + to_str(child->know); - if(child->outcome >= 0) - s += "," + won_str(child->outcome); - s += "\n"; - } - return GTPResponse(true, s); -} - -GTPResponse GTP::gtp_player_solve(vecstr args){ - double use_time = (args.size() >= 1 ? - from_str(args[0]) : - time_control.get_time(hist.len(), hist->movesremain(), player.gamelen())); - - if(verbose) - logerr("time remain: " + to_str(time_control.remain, 1) + ", time: " + to_str(use_time, 3) + ", sims: " + to_str(time_control.max_sims) + "\n"); - - Player::Node * ret = player.genmove(use_time, time_control.max_sims, time_control.flexible); - Move best = M_RESIGN; - if(ret) - best = ret->move; - - time_control.use(player.time_used); - - int toplay = player.rootboard.toplay(); - - DepthStats gamelen, treelen; - uint64_t runs = player.runs; - double times[4] = {0,0,0,0}; - for(unsigned int i = 0; i < player.threads.size(); i++){ - gamelen += player.threads[i]->gamelen; - treelen += player.threads[i]->treelen; - - for(int a = 0; a < 4; a++) - times[a] += player.threads[i]->times[a]; - - player.threads[i]->reset(); - } - player.runs = 0; - - string stats = "Finished " + to_str(runs) + " runs in " + to_str(player.time_used*1000, 0) + " msec: " + to_str(runs/player.time_used, 0) + " Games/s\n"; - if(runs > 0){ - stats += "Game length: " + gamelen.to_s() + "\n"; - stats += "Tree depth: " + treelen.to_s() + "\n"; - if(player.profile) - stats += "Times: " + to_str(times[0], 3) + ", " + to_str(times[1], 3) + ", " + to_str(times[2], 3) + ", " + to_str(times[3], 3) + "\n"; - } - - if(ret){ - stats += "Move Score: " + to_str(ret->exp.avg()) + "\n"; - - if(ret->outcome >= 0){ - stats += "Solved as a "; - if(ret->outcome == toplay) stats += "win"; - else if(ret->outcome == 0) stats += "draw"; - else stats += "loss"; - stats += "\n"; - } - } - - stats += "PV: " + gtp_pv(vecstr()).response + "\n"; - - if(verbose >= 3 && !player.root.children.empty()) - stats += "Exp-Rave:\n" + gtp_move_stats(vecstr()).response + "\n"; - - if(verbose) - logerr(stats); - - Solver s; - if(ret){ - s.outcome = (ret->outcome >= 0 ? ret->outcome : -3); - s.bestmove = ret->move; - s.maxdepth = gamelen.maxdepth; - s.nodes_seen = runs; - }else{ - s.outcome = 3-toplay; - s.bestmove = M_RESIGN; - s.maxdepth = 0; - s.nodes_seen = 0; - } - - return GTPResponse(true, solve_str(s)); -} - - -GTPResponse GTP::gtp_player_solved(vecstr args){ - string s = ""; - Player::Node * child = player.root.children.begin(), - * childend = player.root.children.end(); - int toplay = player.rootboard.toplay(); - int best = 0; - for( ; child != childend; child++){ - if(child->move == M_NONE) - continue; - - if(child->outcome == toplay) - return GTPResponse(true, won_str(toplay)); - else if(child->outcome < 0) - best = 2; - else if(child->outcome == 0) - best = 1; - } - if(best == 2) return GTPResponse(true, won_str(-3)); - if(best == 1) return GTPResponse(true, won_str(0)); - return GTPResponse(true, won_str(3 - toplay)); -} - -GTPResponse GTP::gtp_pv(vecstr args){ - string pvstr = ""; - vector pv = player.get_pv(); - for(unsigned int i = 0; i < pv.size(); i++) - pvstr += pv[i].to_s() + " "; - return GTPResponse(true, pvstr); -} - -GTPResponse GTP::gtp_player_hgf(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, "player_hgf [sims limit]"); - - FILE * fd = fopen(args[0].c_str(), "r"); - - if(fd){ - fclose(fd); - return GTPResponse(false, "File " + args[0] + " already exists"); - } - - fd = fopen(args[0].c_str(), "w"); - - if(!fd) - return GTPResponse(false, "Opening file " + args[0] + " for writing failed"); - - unsigned int limit = 10000; - if(args.size() > 1) - limit = from_str(args[1]); - - Board board = *hist; - - - fprintf(fd, "(;FF[4]SZ[%i]\n", board.get_size()); - int p = 1; - for(auto m : hist){ - fprintf(fd, ";%c[%s]", (p == 1 ? 'W' : 'B'), m.to_s().c_str()); - p = 3-p; - } - - - Player::Node * child = player.root.children.begin(), - * end = player.root.children.end(); - - for( ; child != end; child++){ - if(child->exp.num() >= limit){ - board.set(child->move); - player.gen_hgf(board, child, limit, 1, fd); - board.unset(child->move); - } - } - - fprintf(fd, ")\n"); - - fclose(fd); - - return true; -} - -GTPResponse GTP::gtp_player_load_hgf(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, "player_load_hgf "); - - FILE * fd = fopen(args[0].c_str(), "r"); - - if(!fd) - return GTPResponse(false, "Opening file " + args[0] + " for reading failed"); - - int size; - assert(fscanf(fd, "(;FF[4]SZ[%i]", & size) > 0); - if(size != hist->get_size()){ - if(hist.len() == 0){ - hist = History(Board(size)); - set_board(); - }else{ - fclose(fd); - return GTPResponse(false, "File has the wrong boardsize to match the existing game"); - } - } - - eat_whitespace(fd); - - Board board(size); - Player::Node * node = & player.root; - vector prefix; - - char side, movestr[5]; - while(fscanf(fd, ";%c[%5[^]]]", &side, movestr) > 0){ - Move move(movestr); - - if(board.num_moves() >= (int)hist.len()){ - if(node->children.empty()) - player.create_children_simple(board, node); - - prefix.push_back(node); - node = player.find_child(node, move); - }else if(hist[board.num_moves()] != move){ - fclose(fd); - return GTPResponse(false, "The current game is deeper than this file"); - } - board.move(move); - - eat_whitespace(fd); - } - prefix.push_back(node); - - - if(fpeek(fd) != ')'){ - if(node->children.empty()) - player.create_children_simple(board, node); - - while(fpeek(fd) != ')'){ - Player::Node child; - player.load_hgf(board, & child, fd); - - Player::Node * i = player.find_child(node, child.move); - *i = child; //copy the child experience to the tree - i->swap_tree(child); //move the child subtree to the tree - - assert(child.children.empty()); - - eat_whitespace(fd); - } - } - - eat_whitespace(fd); - assert(fgetc(fd) == ')'); - fclose(fd); - - while(!prefix.empty()){ - Player::Node * node = prefix.back(); - prefix.pop_back(); - - Player::Node * child = node->children.begin(), - * end = node->children.end(); - - int toplay = hist->toplay(); - if(prefix.size() % 2 == 1) - toplay = 3 - toplay; - - Player::Node * backup = child; - - node->exp.clear(); - for( ; child != end; child++){ - node->exp += child->exp.invert(); - if(child->outcome == toplay || child->exp.num() > backup->exp.num()) - backup = child; - } - player.do_backup(node, backup, toplay); - } - - return true; -} - - -GTPResponse GTP::gtp_genmove(vecstr args){ - if(player.rootboard.won() >= 0) - return GTPResponse(true, "resign"); - - double use_time = (args.size() >= 2 ? - from_str(args[1]) : - time_control.get_time(hist.len(), hist->movesremain(), player.gamelen())); - - if(args.size() >= 2) - use_time = from_str(args[1]); - - if(verbose) - logerr("time remain: " + to_str(time_control.remain, 1) + ", time: " + to_str(use_time, 3) + ", sims: " + to_str(time_control.max_sims) + "\n"); - - uword nodesbefore = player.nodes; - - Player::Node * ret = player.genmove(use_time, time_control.max_sims, time_control.flexible); - Move best = player.root.bestmove; - - time_control.use(player.time_used); - - int toplay = player.rootboard.toplay(); - - DepthStats gamelen, treelen; - uint64_t runs = player.runs; - double times[4] = {0,0,0,0}; - for(unsigned int i = 0; i < player.threads.size(); i++){ - gamelen += player.threads[i]->gamelen; - treelen += player.threads[i]->treelen; - - for(int a = 0; a < 4; a++) - times[a] += player.threads[i]->times[a]; - - player.threads[i]->reset(); - } - player.runs = 0; - - string stats = "Finished " + to_str(runs) + " runs in " + to_str(player.time_used*1000, 0) + " msec: " + to_str(runs/player.time_used, 0) + " Games/s\n"; - if(runs > 0){ - stats += "Game length: " + gamelen.to_s() + "\n"; - stats += "Tree depth: " + treelen.to_s() + "\n"; - if(player.profile) - stats += "Times: " + to_str(times[0], 3) + ", " + to_str(times[1], 3) + ", " + to_str(times[2], 3) + ", " + to_str(times[3], 3) + "\n"; - } - - if(ret) - stats += "Move Score: " + to_str(ret->exp.avg()) + "\n"; - - if(player.root.outcome != -3){ - stats += "Solved as a "; - if(player.root.outcome == 0) stats += "draw"; - else if(player.root.outcome == toplay) stats += "win"; - else if(player.root.outcome == 3-toplay) stats += "loss"; - else if(player.root.outcome == -toplay) stats += "win or draw"; - else if(player.root.outcome == toplay-3) stats += "loss or draw"; - stats += "\n"; - } - - stats += "PV: " + gtp_pv(vecstr()).response + "\n"; - - if(verbose >= 3 && !player.root.children.empty()) - stats += "Exp-Rave:\n" + gtp_move_stats(vecstr()).response + "\n"; - - string extended; - if(genmoveextended){ - //move score - if(ret) extended += " " + to_str(ret->exp.avg()); - else extended += " 0"; - //outcome - extended += " " + won_str(player.root.outcome); - //work - extended += " " + to_str(runs); - //nodes - extended += " " + to_str(player.nodes - nodesbefore); - } - - move(best); - - if(verbose >= 2){ - stats += "history: "; - for(auto m : hist) - stats += m.to_s() + " "; - stats += "\n"; - stats += hist->to_s(colorboard) + "\n"; - } - - if(verbose) - logerr(stats); - - return GTPResponse(true, best.to_s() + extended); -} - -GTPResponse GTP::gtp_player_params(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, string("\n") + - "Set player parameters, eg: player_params -e 1 -f 0 -t 2 -o 1 -p 0\n" + - "Processing:\n" + -#ifndef SINGLE_THREAD - " -t --threads Number of MCTS threads [" + to_str(player.numthreads) + "]\n" + -#endif - " -o --ponder Continue to ponder during the opponents time [" + to_str(player.ponder) + "]\n" + - " -M --maxmem Max memory in Mb to use for the tree [" + to_str(player.maxmem/(1024*1024)) + "]\n" + - " --profile Output the time used by each phase of MCTS [" + to_str(player.profile) + "]\n" + - "Final move selection:\n" + - " -E --msexplore Lower bound constant in final move selection [" + to_str(player.msexplore) + "]\n" + - " -F --msrave Rave factor, 0 for pure exp, -1 # sims, -2 # wins [" + to_str(player.msrave) + "]\n" + - "Tree traversal:\n" + - " -e --explore Exploration rate for UCT [" + to_str(player.explore) + "]\n" + - " -A --parexplore Multiply the explore rate by parents experience [" + to_str(player.parentexplore) + "]\n" + - " -f --ravefactor The rave factor: alpha = rf/(rf + visits) [" + to_str(player.ravefactor) + "]\n" + - " -d --decrrave Decrease the rave factor over time: rf += d*empty [" + to_str(player.decrrave) + "]\n" + - " -a --knowledge Use knowledge: 0.01*know/sqrt(visits+1) [" + to_str(player.knowledge) + "]\n" + - " -r --userave Use rave with this probability [0-1] [" + to_str(player.userave) + "]\n" + - " -X --useexplore Use exploration with this probability [0-1] [" + to_str(player.useexplore) + "]\n" + - " -u --fpurgency Value to assign to an unplayed move [" + to_str(player.fpurgency) + "]\n" + - " -O --rollouts Number of rollouts to run per simulation [" + to_str(player.rollouts) + "]\n" + - " -I --dynwiden Dynamic widening, consider log_wid(exp) children [" + to_str(player.dynwiden) + "]\n" + - "Tree building:\n" + - " -s --shortrave Only use moves from short rollouts for rave [" + to_str(player.shortrave) + "]\n" + - " -k --keeptree Keep the tree from the previous move [" + to_str(player.keeptree) + "]\n" + - " -m --minimax Backup the minimax proof in the UCT tree [" + to_str(player.minimax) + "]\n" + - " -x --visitexpand Number of visits before expanding a node [" + to_str(player.visitexpand) + "]\n" + - " -P --symmetry Prune symmetric moves, good for proof, not play [" + to_str(player.prunesymmetry) + "]\n" + - " --gcsolved Garbage collect solved nodes with fewer sims than [" + to_str(player.gcsolved) + "]\n" + - "Node initialization knowledge, Give a bonus:\n" + - " -l --localreply based on the distance to the previous move [" + to_str(player.localreply) + "]\n" + - " -y --locality to stones near other stones of the same color [" + to_str(player.locality) + "]\n" + - " -c --connect to stones connected to edges [" + to_str(player.connect) + "]\n" + - " -S --size based on the size of the group [" + to_str(player.size) + "]\n" + - " -b --bridge to maintaining a 2-bridge after the op probes [" + to_str(player.bridge) + "]\n" + - " -D --distance to low minimum distance to win (<0 avoid VCs) [" + to_str(player.dists) + "]\n" + - "Rollout policy:\n" + - " -h --weightrand Weight the moves according to computed gammas [" + to_str(player.weightedrandom) + "]\n" + - " -p --pattern Maintain the virtual connection pattern [" + to_str(player.rolloutpattern) + "]\n" + - " -g --goodreply Reuse the last good reply (1), remove losses (2) [" + to_str(player.lastgoodreply) + "]\n" + - " -w --instantwin Look for instant wins to this depth [" + to_str(player.instantwin) + "]\n" - ); - - string errs; - for(unsigned int i = 0; i < args.size(); i++) { - string arg = args[i]; - - if((arg == "-t" || arg == "--threads") && i+1 < args.size()){ - player.numthreads = from_str(args[++i]); - bool p = player.ponder; - player.set_ponder(false); //stop the threads while resetting them - player.reset_threads(); - player.set_ponder(p); - }else if((arg == "-o" || arg == "--ponder") && i+1 < args.size()){ - player.set_ponder(from_str(args[++i])); - }else if((arg == "--profile") && i+1 < args.size()){ - player.profile = from_str(args[++i]); - }else if((arg == "-M" || arg == "--maxmem") && i+1 < args.size()){ - player.maxmem = from_str(args[++i])*1024*1024; - }else if((arg == "-E" || arg == "--msexplore") && i+1 < args.size()){ - player.msexplore = from_str(args[++i]); - }else if((arg == "-F" || arg == "--msrave") && i+1 < args.size()){ - player.msrave = from_str(args[++i]); - }else if((arg == "-e" || arg == "--explore") && i+1 < args.size()){ - player.explore = from_str(args[++i]); - }else if((arg == "-A" || arg == "--parexplore") && i+1 < args.size()){ - player.parentexplore = from_str(args[++i]); - }else if((arg == "-f" || arg == "--ravefactor") && i+1 < args.size()){ - player.ravefactor = from_str(args[++i]); - }else if((arg == "-d" || arg == "--decrrave") && i+1 < args.size()){ - player.decrrave = from_str(args[++i]); - }else if((arg == "-a" || arg == "--knowledge") && i+1 < args.size()){ - player.knowledge = from_str(args[++i]); - }else if((arg == "-s" || arg == "--shortrave") && i+1 < args.size()){ - player.shortrave = from_str(args[++i]); - }else if((arg == "-k" || arg == "--keeptree") && i+1 < args.size()){ - player.keeptree = from_str(args[++i]); - }else if((arg == "-m" || arg == "--minimax") && i+1 < args.size()){ - player.minimax = from_str(args[++i]); - }else if((arg == "-P" || arg == "--symmetry") && i+1 < args.size()){ - player.prunesymmetry = from_str(args[++i]); - }else if(( arg == "--gcsolved") && i+1 < args.size()){ - player.gcsolved = from_str(args[++i]); - }else if((arg == "-r" || arg == "--userave") && i+1 < args.size()){ - player.userave = from_str(args[++i]); - }else if((arg == "-X" || arg == "--useexplore") && i+1 < args.size()){ - player.useexplore = from_str(args[++i]); - }else if((arg == "-u" || arg == "--fpurgency") && i+1 < args.size()){ - player.fpurgency = from_str(args[++i]); - }else if((arg == "-O" || arg == "--rollouts") && i+1 < args.size()){ - player.rollouts = from_str(args[++i]); - if(player.gclimit < player.rollouts*5) - player.gclimit = player.rollouts*5; - }else if((arg == "-I" || arg == "--dynwiden") && i+1 < args.size()){ - player.dynwiden = from_str(args[++i]); - player.logdynwiden = std::log(player.dynwiden); - }else if((arg == "-x" || arg == "--visitexpand") && i+1 < args.size()){ - player.visitexpand = from_str(args[++i]); - }else if((arg == "-l" || arg == "--localreply") && i+1 < args.size()){ - player.localreply = from_str(args[++i]); - }else if((arg == "-y" || arg == "--locality") && i+1 < args.size()){ - player.locality = from_str(args[++i]); - }else if((arg == "-c" || arg == "--connect") && i+1 < args.size()){ - player.connect = from_str(args[++i]); - }else if((arg == "-S" || arg == "--size") && i+1 < args.size()){ - player.size = from_str(args[++i]); - }else if((arg == "-b" || arg == "--bridge") && i+1 < args.size()){ - player.bridge = from_str(args[++i]); - }else if((arg == "-D" || arg == "--distance") && i+1 < args.size()){ - player.dists = from_str(args[++i]); - }else if((arg == "-h" || arg == "--weightrand") && i+1 < args.size()){ - player.weightedrandom = from_str(args[++i]); - }else if((arg == "-p" || arg == "--pattern") && i+1 < args.size()){ - player.rolloutpattern = from_str(args[++i]); - }else if((arg == "-g" || arg == "--goodreply") && i+1 < args.size()){ - player.lastgoodreply = from_str(args[++i]); - }else if((arg == "-w" || arg == "--instantwin") && i+1 < args.size()){ - player.instantwin = from_str(args[++i]); - }else{ - return GTPResponse(false, "Missing or unknown parameter"); - } - } - return GTPResponse(true, errs); -} - -GTPResponse GTP::gtp_player_gammas(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, "Must pass the filename of a set of gammas"); - - ifstream ifs(args[0].c_str()); - - if(!ifs.good()) - return GTPResponse(false, "Failed to open file for reading"); - - Board board = *hist; - - for(int i = 0; i < 4096; i++){ - int a; - float f; - ifs >> a >> f; - - if(i != a){ - ifs.close(); - return GTPResponse(false, "Line " + to_str(i) + " doesn't match the expected value"); - } - - int s = board.pattern_symmetry(i); - if(s == i) - player.gammas[i] = f; - else - player.gammas[i] = player.gammas[s]; - } - - ifs.close(); - return GTPResponse(true); -} diff --git a/y/gtpsolver.cpp b/y/gtpsolver.cpp deleted file mode 100644 index 1df5ea1..0000000 --- a/y/gtpsolver.cpp +++ /dev/null @@ -1,331 +0,0 @@ - - -#include "gtp.h" - -string GTP::solve_str(int outcome) const { - switch(outcome){ - case -2: return "black_or_draw"; - case -1: return "white_or_draw"; - case 0: return "draw"; - case 1: return "white"; - case 2: return "black"; - default: return "unknown"; - } -} - -string GTP::solve_str(const Solver & solve){ - string ret = ""; - ret += solve_str(solve.outcome) + " "; - ret += solve.bestmove.to_s() + " "; - ret += to_str(solve.maxdepth) + " "; - ret += to_str(solve.nodes_seen); - return ret; -} - - -GTPResponse GTP::gtp_solve_ab(vecstr args){ - double time = 60; - - if(args.size() >= 1) - time = from_str(args[0]); - - solverab.solve(time); - - logerr("Finished in " + to_str(solverab.time_used*1000, 0) + " msec\n"); - - return GTPResponse(true, solve_str(solverab)); -} - -GTPResponse GTP::gtp_solve_ab_params(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, string("\n") + - "Update the alpha-beta solver settings, eg: ab_params -m 100 -s 1 -d 3\n" - " -m --memory Memory limit in Mb (0 to disable the TT) [" + to_str(solverab.memlimit/(1024*1024)) + "]\n" - " -s --scout Whether to scout ahead for the true minimax value [" + to_str(solverab.scout) + "]\n" - " -d --depth Starting depth [" + to_str(solverab.startdepth) + "]\n" - ); - - for(unsigned int i = 0; i < args.size(); i++) { - string arg = args[i]; - - if((arg == "-m" || arg == "--memory") && i+1 < args.size()){ - int mem = from_str(args[++i]); - solverab.set_memlimit(mem); - }else if((arg == "-s" || arg == "--scout") && i+1 < args.size()){ - solverab.scout = from_str(args[++i]); - }else if((arg == "-d" || arg == "--depth") && i+1 < args.size()){ - solverab.startdepth = from_str(args[++i]); - }else{ - return GTPResponse(false, "Missing or unknown parameter"); - } - } - - return true; -} - -GTPResponse GTP::gtp_solve_ab_stats(vecstr args){ - string s = ""; - - Board board = *hist; - for(auto arg : args) - board.move(Move(arg)); - - int value; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - value = solverab.tt_get(board.test_hash(*move)); - - s += move->to_s() + "," + to_str(value) + "\n"; - } - return GTPResponse(true, s); -} - -GTPResponse GTP::gtp_solve_ab_clear(vecstr args){ - solverab.clear_mem(); - return true; -} - - - -GTPResponse GTP::gtp_solve_pns(vecstr args){ - double time = 60; - - if(args.size() >= 1) - time = from_str(args[0]); - - solverpns.solve(time); - - logerr("Finished in " + to_str(solverpns.time_used*1000, 0) + " msec\n"); - - return GTPResponse(true, solve_str(solverpns)); -} - -GTPResponse GTP::gtp_solve_pns_params(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, string("\n") + - "Update the pns solver settings, eg: pns_params -m 100 -s 0 -d 1 -e 0.25 -a 2 -l 0\n" - " -m --memory Memory limit in Mb [" + to_str(solverpns.memlimit/(1024*1024)) + "]\n" -// " -t --threads How many threads to run -// " -o --ponder Ponder in the background - " -d --df Use depth-first thresholds [" + to_str(solverpns.df) + "]\n" - " -e --epsilon How big should the threshold be [" + to_str(solverpns.epsilon) + "]\n" - " -a --abdepth Run an alpha-beta search of this size at each leaf [" + to_str(solverpns.ab) + "]\n" - " -l --lbdist Initialize with the lower bound on distance to win [" + to_str(solverpns.lbdist) + "]\n" - ); - - for(unsigned int i = 0; i < args.size(); i++) { - string arg = args[i]; - - if((arg == "-m" || arg == "--memory") && i+1 < args.size()){ - uint64_t mem = from_str(args[++i]); - if(mem < 1) return GTPResponse(false, "Memory can't be less than 1mb"); - solverpns.set_memlimit(mem*1024*1024); - }else if((arg == "-d" || arg == "--df") && i+1 < args.size()){ - solverpns.df = from_str(args[++i]); - }else if((arg == "-e" || arg == "--epsilon") && i+1 < args.size()){ - solverpns.epsilon = from_str(args[++i]); - }else if((arg == "-a" || arg == "--abdepth") && i+1 < args.size()){ - solverpns.ab = from_str(args[++i]); - }else if((arg == "-l" || arg == "--lbdist") && i+1 < args.size()){ - solverpns.lbdist = from_str(args[++i]); - }else{ - return GTPResponse(false, "Missing or unknown parameter"); - } - } - - return true; -} - -GTPResponse GTP::gtp_solve_pns_stats(vecstr args){ - string s = ""; - - SolverPNS::PNSNode * node = &(solverpns.root); - - for(unsigned int i = 0; i < args.size(); i++){ - Move m(args[i]); - SolverPNS::PNSNode * c = node->children.begin(), - * cend = node->children.end(); - for(; c != cend; c++){ - if(c->move == m){ - node = c; - break; - } - } - } - - SolverPNS::PNSNode * child = node->children.begin(), - * childend = node->children.end(); - for( ; child != childend; child++){ - if(child->move == M_NONE) - continue; - - s += child->move.to_s() + "," + to_str(child->phi) + "," + to_str(child->delta) + "\n"; - } - return GTPResponse(true, s); -} - -GTPResponse GTP::gtp_solve_pns_clear(vecstr args){ - solverpns.clear_mem(); - return true; -} - - -GTPResponse GTP::gtp_solve_pns2(vecstr args){ - double time = 60; - - if(args.size() >= 1) - time = from_str(args[0]); - - solverpns2.solve(time); - - logerr("Finished in " + to_str(solverpns2.time_used*1000, 0) + " msec\n"); - - return GTPResponse(true, solve_str(solverpns2)); -} - -GTPResponse GTP::gtp_solve_pns2_params(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, string("\n") + - "Update the pns solver settings, eg: pns_params -m 100 -s 0 -d 1 -e 0.25 -a 2 -l 0\n" - " -m --memory Memory limit in Mb [" + to_str(solverpns2.memlimit/(1024*1024)) + "]\n" - " -t --threads How many threads to run [" + to_str(solverpns2.numthreads) + "]\n" -// " -o --ponder Ponder in the background - " -d --df Use depth-first thresholds [" + to_str(solverpns2.df) + "]\n" - " -e --epsilon How big should the threshold be [" + to_str(solverpns2.epsilon) + "]\n" - " -a --abdepth Run an alpha-beta search of this size at each leaf [" + to_str(solverpns2.ab) + "]\n" - " -l --lbdist Initialize with the lower bound on distance to win [" + to_str(solverpns2.lbdist) + "]\n" - ); - - for(unsigned int i = 0; i < args.size(); i++) { - string arg = args[i]; - - if((arg == "-t" || arg == "--threads") && i+1 < args.size()){ - solverpns2.numthreads = from_str(args[++i]); - solverpns2.reset_threads(); - }else if((arg == "-m" || arg == "--memory") && i+1 < args.size()){ - uint64_t mem = from_str(args[++i]); - if(mem < 1) return GTPResponse(false, "Memory can't be less than 1mb"); - solverpns2.set_memlimit(mem*1024*1024); - }else if((arg == "-d" || arg == "--df") && i+1 < args.size()){ - solverpns2.df = from_str(args[++i]); - }else if((arg == "-e" || arg == "--epsilon") && i+1 < args.size()){ - solverpns2.epsilon = from_str(args[++i]); - }else if((arg == "-a" || arg == "--abdepth") && i+1 < args.size()){ - solverpns2.ab = from_str(args[++i]); - }else if((arg == "-l" || arg == "--lbdist") && i+1 < args.size()){ - solverpns2.lbdist = from_str(args[++i]); - }else{ - return GTPResponse(false, "Missing or unknown parameter"); - } - } - - return true; -} - -GTPResponse GTP::gtp_solve_pns2_stats(vecstr args){ - string s = ""; - - SolverPNS2::PNSNode * node = &(solverpns2.root); - - for(unsigned int i = 0; i < args.size(); i++){ - Move m(args[i]); - SolverPNS2::PNSNode * c = node->children.begin(), - * cend = node->children.end(); - for(; c != cend; c++){ - if(c->move == m){ - node = c; - break; - } - } - } - - SolverPNS2::PNSNode * child = node->children.begin(), - * childend = node->children.end(); - for( ; child != childend; child++){ - if(child->move == M_NONE) - continue; - - s += child->move.to_s() + "," + to_str(child->phi) + "," + to_str(child->delta) + "," + to_str(child->work) + "\n"; - } - return GTPResponse(true, s); -} - -GTPResponse GTP::gtp_solve_pns2_clear(vecstr args){ - solverpns2.clear_mem(); - return true; -} - - - - -GTPResponse GTP::gtp_solve_pnstt(vecstr args){ - double time = 60; - - if(args.size() >= 1) - time = from_str(args[0]); - - solverpnstt.solve(time); - - logerr("Finished in " + to_str(solverpnstt.time_used*1000, 0) + " msec\n"); - - return GTPResponse(true, solve_str(solverpnstt)); -} - -GTPResponse GTP::gtp_solve_pnstt_params(vecstr args){ - if(args.size() == 0) - return GTPResponse(true, string("\n") + - "Update the pnstt solver settings, eg: pnstt_params -m 100 -s 0 -d 1 -e 0.25 -a 2 -l 0\n" - " -m --memory Memory limit in Mb [" + to_str(solverpnstt.memlimit/(1024*1024)) + "]\n" -// " -t --threads How many threads to run -// " -o --ponder Ponder in the background - " -d --df Use depth-first thresholds [" + to_str(solverpnstt.df) + "]\n" - " -e --epsilon How big should the threshold be [" + to_str(solverpnstt.epsilon) + "]\n" - " -a --abdepth Run an alpha-beta search of this size at each leaf [" + to_str(solverpnstt.ab) + "]\n" - " -c --copy Try to copy a proof to this many siblings, <0 quit early [" + to_str(solverpnstt.copyproof) + "]\n" -// " -l --lbdist Initialize with the lower bound on distance to win [" + to_str(solverpnstt.lbdist) + "]\n" - ); - - for(unsigned int i = 0; i < args.size(); i++) { - string arg = args[i]; - - if((arg == "-m" || arg == "--memory") && i+1 < args.size()){ - int mem = from_str(args[++i]); - if(mem < 1) return GTPResponse(false, "Memory can't be less than 1mb"); - solverpnstt.set_memlimit(mem*1024*1024); - }else if((arg == "-d" || arg == "--df") && i+1 < args.size()){ - solverpnstt.df = from_str(args[++i]); - }else if((arg == "-e" || arg == "--epsilon") && i+1 < args.size()){ - solverpnstt.epsilon = from_str(args[++i]); - }else if((arg == "-a" || arg == "--abdepth") && i+1 < args.size()){ - solverpnstt.ab = from_str(args[++i]); - }else if((arg == "-c" || arg == "--copy") && i+1 < args.size()){ - solverpnstt.copyproof = from_str(args[++i]); -// }else if((arg == "-l" || arg == "--lbdist") && i+1 < args.size()){ -// solverpnstt.lbdist = from_str(args[++i]); - }else{ - return GTPResponse(false, "Missing or unknown parameter"); - } - } - - return true; -} - -GTPResponse GTP::gtp_solve_pnstt_stats(vecstr args){ - string s = ""; - - Board board = *hist; - for(auto arg : args) - board.move(Move(arg)); - - SolverPNSTT::PNSNode * child = NULL; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - child = solverpnstt.tt(board, *move); - - s += move->to_s() + "," + to_str(child->phi) + "," + to_str(child->delta) + "\n"; - } - return GTPResponse(true, s); -} - -GTPResponse GTP::gtp_solve_pnstt_clear(vecstr args){ - solverpnstt.clear_mem(); - return true; -} diff --git a/y/history.h b/y/history.h deleted file mode 100644 index 00ccd06..0000000 --- a/y/history.h +++ /dev/null @@ -1,70 +0,0 @@ - -#pragma once - -#include - -#include "../lib/string.h" - -#include "board.h" -#include "move.h" - -class History { - std::vector hist; - Board board; - -public: - - History() { } - History(const Board & b) : board(b) { } - - const Move & operator [] (int i) const { - return hist[i]; - } - - Move last() const { - if(hist.size() == 0) - return M_NONE; - - return hist.back(); - } - - const Board & operator * () const { return board; } - const Board * operator -> () const { return & board; } - - std::vector::const_iterator begin() const { return hist.begin(); } - std::vector::const_iterator end() const { return hist.end(); } - - const Board get_board() const { - Board b(board.get_size()); - for(auto m : hist) - b.move(m); - return b; - } - - int len() const { - return hist.size(); - } - - void clear() { - hist.clear(); - board = get_board(); - } - - bool undo() { - if(hist.size() <= 0) - return false; - - hist.pop_back(); - board = get_board(); - return true; - } - - bool move(const Move & m) { - if(board.valid_move(m)){ - board.move(m); - hist.push_back(m); - return true; - } - return false; - } -}; diff --git a/y/lbdist.h b/y/lbdist.h index e349fd0..85baa3e 100644 --- a/y/lbdist.h +++ b/y/lbdist.h @@ -10,9 +10,13 @@ Increase distance when crossing an opponent virtual connection? Decrease distance when crossing your own virtual connection? */ +#include "../lib/move.h" #include "board.h" -#include "move.h" + + +namespace Morat { +namespace Y { class LBDists { struct MoveDist { @@ -70,15 +74,16 @@ class LBDists { IntPQueue Q; const Board * board; - int & dist(int edge, int player, int i) { return dists[edge][player-1][i]; } - int & dist(int edge, int player, const Move & m) { return dist(edge, player, board->xy(m)); } - int & dist(int edge, int player, int x, int y) { return dist(edge, player, board->xy(x, y)); } + int & dist(int edge, Side player, int i) { return dists[edge][player.to_i() - 1][i]; } + int & dist(int edge, Side player, const Move & m) { return dist(edge, player, board->xy(m)); } + int & dist(int edge, Side player, int x, int y) { return dist(edge, player, board->xy(x, y)); } - void init(int x, int y, int edge, int player, int dir){ - int val = board->get(x, y); - if(val != 3 - player){ - Q.push(MoveDist(x, y, (val == 0), dir)); - dist(edge, player, x, y) = (val == 0); + void init(int x, int y, int edge, Side player, int dir){ + Side val = board->get(x, y); + if(val != ~player){ + bool empty = (val == Side::NONE); + Q.push(MoveDist(x, y, empty, dir)); + dist(edge, player, x, y) = empty; } } @@ -87,7 +92,7 @@ class LBDists { LBDists() : board(NULL) {} LBDists(const Board * b) { run(b); } - void run(const Board * b, bool crossvcs = true, int side = 0) { + void run(const Board * b, bool crossvcs = true, Side side = Side::BOTH) { board = b; for(int i = 0; i < 3; i++) @@ -95,22 +100,21 @@ class LBDists { for(int k = 0; k < board->vecsize(); k++) dists[i][j][k] = maxdist; //far far away! + if(side == Side::P1 || side == Side::BOTH) init_player(crossvcs, Side::P1); + if(side == Side::P2 || side == Side::BOTH) init_player(crossvcs, Side::P2); + } + + void init_player(bool crossvcs, Side player){ int m = board->get_size(); int m1 = m-1; - int start, end; - if(side){ start = end = side; } - else { start = 1; end = 2; } - - for(int player = start; player <= end; player++){ - for(int x = 0; x < m; x++) { init(x, 0, 0, player, 3); } flood(0, player, crossvcs); //edge 0 - for(int y = 0; y < m; y++) { init(0, y, 1, player, 1); } flood(1, player, crossvcs); //edge 1 - for(int y = 0; y < m; y++) { init(m1-y, y, 2, player, 5); } flood(2, player, crossvcs); //edge 2 - } + for(int x = 0; x < m; x++) { init(x, 0, 0, player, 3); } flood(0, player, crossvcs); //edge 0 + for(int y = 0; y < m; y++) { init(0, y, 1, player, 1); } flood(1, player, crossvcs); //edge 1 + for(int y = 0; y < m; y++) { init(m1-y, y, 2, player, 5); } flood(2, player, crossvcs); //edge 2 } - void flood(int edge, int player, bool crossvcs){ - int otherplayer = 3 - player; + void flood(int edge, Side player, bool crossvcs){ + Side otherplayer = ~player; MoveDist cur; while(Q.pop(cur)){ @@ -120,12 +124,12 @@ class LBDists { if(board->onboard(next.pos)){ int pos = board->xy(next.pos); - int colour = board->get(pos); + Side colour = board->get(pos); if(colour == otherplayer) continue; - if(colour == 0){ + if(colour == Side::NONE){ if(!crossvcs && //forms a vc board->get(cur.pos + neighbours[(nd - 1) % 6]) == otherplayer && board->get(cur.pos + neighbours[(nd + 1) % 6]) == otherplayer) @@ -144,12 +148,15 @@ class LBDists { } } - int get(Move pos){ return min(get(pos, 1), get(pos, 2)); } - int get(Move pos, int player){ return get(board->xy(pos), player); } - int get(int pos, int player){ + int get(Move pos){ return std::min(get(pos, Side::P1), get(pos, Side::P2)); } + int get(Move pos, Side player){ return get(board->xy(pos), player); } + int get(int pos, Side player){ int sum = 0; for(int i = 0; i < 3; i++) sum += dist(i, player, pos); return sum; } }; + +}; // namespace Y +}; // namespace Morat diff --git a/y/moy.cpp b/y/main.cpp similarity index 96% rename from y/moy.cpp rename to y/main.cpp index c8bbdb2..40ecbd0 100644 --- a/y/moy.cpp +++ b/y/main.cpp @@ -1,5 +1,4 @@ - #include #include @@ -7,6 +6,10 @@ #include "gtp.h" + +using namespace Morat; +using namespace Y; + using namespace std; void die(int code, const string & str){ @@ -15,6 +18,7 @@ void die(int code, const string & str){ } int main(int argc, char **argv){ + srand(Time().in_usec()); GTP gtp; diff --git a/y/move.h b/y/move.h deleted file mode 100644 index 84cf035..0000000 --- a/y/move.h +++ /dev/null @@ -1,91 +0,0 @@ - -#pragma once - -#include -#include - -#include "../lib/string.h" - -enum MoveSpecial { - M_SWAP = -1, //-1 so that adding 1 makes it into a valid move - M_RESIGN = -2, - M_NONE = -3, - M_UNKNOWN = -4, -}; - -struct Move { - int8_t y, x; - - Move(MoveSpecial a = M_UNKNOWN) : y(a), x(120) { } //big x so it will always wrap to y=0 with swap - Move(int X, int Y) : y(Y), x(X) { } - - Move(const std::string & str){ - if( str == "swap" ){ y = M_SWAP; x = 120; } - else if(str == "resign" ){ y = M_RESIGN; x = 120; } - else if(str == "none" ){ y = M_NONE; x = 120; } - else if(str == "unknown"){ y = M_UNKNOWN; x = 120; } - else{ - y = tolower(str[0]) - 'a'; - x = atoi(str.c_str() + 1) - 1; - } - } - - std::string to_s() const { - if(y == M_UNKNOWN) return "unknown"; - if(y == M_NONE) return "none"; - if(y == M_SWAP) return "swap"; - if(y == M_RESIGN) return "resign"; - - return std::string() + char(y + 'a') + to_str(x + 1); - } - - bool operator< (const Move & b) const { return (y == b.y ? x < b.x : y < b.y); } - bool operator<=(const Move & b) const { return (y == b.y ? x <= b.x : y <= b.y); } - bool operator> (const Move & b) const { return (y == b.y ? x > b.x : y > b.y); } - bool operator>=(const Move & b) const { return (y == b.y ? x >= b.x : y >= b.y); } - bool operator==(const MoveSpecial & b) const { return (y == b); } - bool operator==(const Move & b) const { return (y == b.y && x == b.x); } - bool operator!=(const Move & b) const { return (y != b.y || x != b.x); } - bool operator!=(const MoveSpecial & b) const { return (y != b); } - Move operator+ (const Move & b) const { return Move(x + b.x, y + b.y); } - Move & operator+=(const Move & b) { y += b.y; x += b.x; return *this; } - Move operator- (const Move & b) const { return Move(x - b.x, y - b.y); } - Move & operator-=(const Move & b) { y -= b.y; x -= b.x; return *this; } - - int z() const { return (x - y); } - int dist(const Move & b) const { - return (abs(x - b.x) + abs(y - b.y) + abs(z() - b.z()))/2; - } -}; - -struct MoveScore : public Move { - int16_t score; - - MoveScore() : score(0) { } - MoveScore(MoveSpecial a) : Move(a), score(0) { } - MoveScore(int X, int Y, int s) : Move(X, Y), score(s) { } - MoveScore operator+ (const Move & b) const { return MoveScore(x + b.x, y + b.y, score); } -}; - -struct MoveValid : public Move { - int16_t xy; - - MoveValid() : Move(), xy(-1) { } - MoveValid(int x, int y, int XY) : Move(x,y), xy(XY) { } - MoveValid(const Move & m, int XY) : Move(m), xy(XY) { } - bool onboard() const { return xy != -1; } -}; - -struct MovePlayer : public Move { - char player; - - MovePlayer() : Move(M_UNKNOWN), player(0) { } - MovePlayer(const Move & m, char p = 0) : Move(m), player(p) { } -}; - - -struct PairMove { - Move a, b; - PairMove(Move A = M_UNKNOWN, Move B = M_UNKNOWN) : a(A), b(B) { } - PairMove(MoveSpecial A) : a(Move(A)), b(M_UNKNOWN) { } -}; diff --git a/y/movelist.h b/y/movelist.h deleted file mode 100644 index 27c22de..0000000 --- a/y/movelist.h +++ /dev/null @@ -1,76 +0,0 @@ - -#pragma once - -#include "../lib/exppair.h" - -#include "board.h" -#include "move.h" - -struct MoveList { - ExpPair exp[2]; //aggregated outcomes overall - ExpPair rave[2][Board::max_vecsize]; //aggregated outcomes per move - MovePlayer moves[Board::max_vecsize]; //moves made in order - int tree; //number of moves in the tree - int rollout; //number of moves in the rollout - Board * board; //reference to rootboard for xy() - - MoveList() : tree(0), rollout(0), board(NULL) { } - - void addtree(const Move & move, char player){ - moves[tree++] = MovePlayer(move, player); - } - void addrollout(const Move & move, char player){ - moves[tree + rollout++] = MovePlayer(move, player); - } - void reset(Board * b){ - tree = 0; - rollout = 0; - board = b; - exp[0].clear(); - exp[1].clear(); - for(int i = 0; i < b->vecsize(); i++){ - rave[0][i].clear(); - rave[1][i].clear(); - } - } - void finishrollout(int won){ - exp[0].addloss(); - exp[1].addloss(); - if(won == 0){ - exp[0].addtie(); - exp[1].addtie(); - }else{ - exp[won-1].addwin(); - - for(MovePlayer * i = begin(), * e = end(); i != e; i++){ - ExpPair & r = rave[i->player-1][board->xy(*i)]; - r.addloss(); - if(i->player == won) - r.addwin(); - } - } - rollout = 0; - } - const MovePlayer * begin() const { - return moves; - } - MovePlayer * begin() { - return moves; - } - const MovePlayer * end() const { - return moves + tree + rollout; - } - MovePlayer * end() { - return moves + tree + rollout; - } - void subvlosses(int n){ - exp[0].addlosses(-n); - exp[1].addlosses(-n); - } - const ExpPair & getrave(int player, const Move & move) const { - return rave[player-1][board->xy(move)]; - } - const ExpPair & getexp(int player) const { - return exp[player-1]; - } -}; diff --git a/y/player.cpp b/y/player.cpp deleted file mode 100644 index b517471..0000000 --- a/y/player.cpp +++ /dev/null @@ -1,506 +0,0 @@ - -#include -#include - -#include "../lib/alarm.h" -#include "../lib/fileio.h" -#include "../lib/string.h" -#include "../lib/time.h" - -#include "board.h" -#include "player.h" - -const float Player::min_rave = 0.1; - -void Player::PlayerThread::run(){ - while(true){ - switch(player->threadstate){ - case Thread_Cancelled: //threads should exit - return; - - case Thread_Wait_Start: //threads are waiting to start - case Thread_Wait_Start_Cancelled: - player->runbarrier.wait(); - CAS(player->threadstate, Thread_Wait_Start, Thread_Running); - CAS(player->threadstate, Thread_Wait_Start_Cancelled, Thread_Cancelled); - break; - - case Thread_Wait_End: //threads are waiting to end - player->runbarrier.wait(); - CAS(player->threadstate, Thread_Wait_End, Thread_Wait_Start); - break; - - case Thread_Running: //threads are running - if(player->rootboard.won() >= 0 || player->root.outcome >= 0 || (player->maxruns > 0 && player->runs >= player->maxruns)){ //solved or finished runs - if(CAS(player->threadstate, Thread_Running, Thread_Wait_End) && player->root.outcome >= 0) - logerr("Solved as " + to_str((int)player->root.outcome) + "\n"); - break; - } - if(player->ctmem.memalloced() >= player->maxmem){ //out of memory, start garbage collection - CAS(player->threadstate, Thread_Running, Thread_GC); - break; - } - - INCR(player->runs); - iterate(); - break; - - case Thread_GC: //one thread is running garbage collection, the rest are waiting - case Thread_GC_End: //once done garbage collecting, go to wait_end instead of back to running - if(player->gcbarrier.wait()){ - Time starttime; - logerr("Starting player GC with limit " + to_str(player->gclimit) + " ... "); - uint64_t nodesbefore = player->nodes; - Board copy = player->rootboard; - player->garbage_collect(copy, & player->root); - Time gctime; - player->ctmem.compact(1.0, 0.75); - Time compacttime; - logerr(to_str(100.0*player->nodes/nodesbefore, 1) + " % of tree remains - " + - to_str((gctime - starttime)*1000, 0) + " msec gc, " + to_str((compacttime - gctime)*1000, 0) + " msec compact\n"); - - if(player->ctmem.meminuse() >= player->maxmem/2) - player->gclimit = (int)(player->gclimit*1.3); - else if(player->gclimit > player->rollouts*5) - player->gclimit = (int)(player->gclimit*0.9); //slowly decay to a minimum of 5 - - CAS(player->threadstate, Thread_GC, Thread_Running); - CAS(player->threadstate, Thread_GC_End, Thread_Wait_End); - } - player->gcbarrier.wait(); - break; - } - } -} - -Player::Node * Player::genmove(double time, int max_runs, bool flexible){ - time_used = 0; - int toplay = rootboard.toplay(); - - if(rootboard.won() >= 0 || (time <= 0 && max_runs == 0)) - return NULL; - - Time starttime; - - stop_threads(); - - if(runs) - logerr("Pondered " + to_str(runs) + " runs\n"); - - runs = 0; - maxruns = max_runs; - for(unsigned int i = 0; i < threads.size(); i++) - threads[i]->reset(); - - // if the move is forced and the time can be added to the clock, don't bother running at all - if(!flexible || root.children.num() != 1){ - //let them run! - start_threads(); - - Alarm timer; - if(time > 0) - timer(time - (Time() - starttime), std::bind(&Player::timedout, this)); - - //wait for the timer to stop them - runbarrier.wait(); - CAS(threadstate, Thread_Wait_End, Thread_Wait_Start); - assert(threadstate == Thread_Wait_Start); - } - - if(ponder && root.outcome < 0) - start_threads(); - - time_used = Time() - starttime; - -//return the best one - return return_move(& root, toplay); -} - - - -Player::Player() { - nodes = 0; - gclimit = 5; - time_used = 0; - - profile = false; - ponder = false; -//#ifdef SINGLE_THREAD ... make sure only 1 thread - numthreads = 1; - maxmem = 1000*1024*1024; - - msrave = -2; - msexplore = 0; - - explore = 0; - parentexplore = false; - ravefactor = 500; - decrrave = 0; - knowledge = true; - userave = 1; - useexplore = 1; - fpurgency = 1; - rollouts = 5; - dynwiden = 0; - logdynwiden = (dynwiden ? std::log(dynwiden) : 0); - - shortrave = false; - keeptree = true; - minimax = 2; - visitexpand = 1; - prunesymmetry = false; - gcsolved = 100000; - - localreply = 5; - locality = 5; - connect = 20; - size = 0; - bridge = 100; - dists = 0; - - weightedrandom = 0; - rolloutpattern = true; - lastgoodreply = false; - instantwin = 0; - - for(int i = 0; i < 4096; i++) - gammas[i] = 1; - - //no threads started until a board is set - threadstate = Thread_Wait_Start; -} -Player::~Player(){ - stop_threads(); - - numthreads = 0; - reset_threads(); //shut down the theads properly - - root.dealloc(ctmem); - ctmem.compact(); -} -void Player::timedout() { - CAS(threadstate, Thread_Running, Thread_Wait_End); - CAS(threadstate, Thread_GC, Thread_GC_End); -} - -string Player::statestring(){ - switch(threadstate){ - case Thread_Cancelled: return "Thread_Wait_Cancelled"; - case Thread_Wait_Start: return "Thread_Wait_Start"; - case Thread_Wait_Start_Cancelled: return "Thread_Wait_Start_Cancelled"; - case Thread_Running: return "Thread_Running"; - case Thread_GC: return "Thread_GC"; - case Thread_GC_End: return "Thread_GC_End"; - case Thread_Wait_End: return "Thread_Wait_End"; - } - return "Thread_State_Unknown!!!"; -} - -void Player::stop_threads(){ - if(threadstate != Thread_Wait_Start){ - timedout(); - runbarrier.wait(); - CAS(threadstate, Thread_Wait_End, Thread_Wait_Start); - assert(threadstate == Thread_Wait_Start); - } -} - -void Player::start_threads(){ - assert(threadstate == Thread_Wait_Start); - runbarrier.wait(); - CAS(threadstate, Thread_Wait_Start, Thread_Running); -} - -void Player::reset_threads(){ //start and end with threadstate = Thread_Wait_Start - assert(threadstate == Thread_Wait_Start); - -//wait for them to all get to the barrier - assert(CAS(threadstate, Thread_Wait_Start, Thread_Wait_Start_Cancelled)); - runbarrier.wait(); - -//make sure they exited cleanly - for(unsigned int i = 0; i < threads.size(); i++){ - threads[i]->join(); - delete threads[i]; - } - - threads.clear(); - - threadstate = Thread_Wait_Start; - - runbarrier.reset(numthreads + 1); - gcbarrier.reset(numthreads); - -//start new threads - for(int i = 0; i < numthreads; i++) - threads.push_back(new PlayerUCT(this)); -} - -void Player::set_ponder(bool p){ - if(ponder != p){ - ponder = p; - stop_threads(); - - if(ponder) - start_threads(); - } -} - -void Player::set_board(const Board & board){ - stop_threads(); - - nodes -= root.dealloc(ctmem); - root = Node(); - root.exp.addwins(visitexpand+1); - - rootboard = board; - - reset_threads(); //needed since the threads aren't started before a board it set - - if(ponder) - start_threads(); -} -void Player::move(const Move & m){ - stop_threads(); - - uword nodesbefore = nodes; - - if(keeptree && root.children.num() > 0){ - Node child; - - for(Node * i = root.children.begin(); i != root.children.end(); i++){ - if(i->move == m){ - child = *i; //copy the child experience to temp - child.swap_tree(*i); //move the child tree to temp - break; - } - } - - nodes -= root.dealloc(ctmem); - root = child; - root.swap_tree(child); - - if(nodesbefore > 0) - logerr("Nodes before: " + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); - }else{ - nodes -= root.dealloc(ctmem); - root = Node(); - root.move = m; - } - assert(nodes == root.size()); - - rootboard.move(m); - - root.exp.addwins(visitexpand+1); //+1 to compensate for the virtual loss - if(rootboard.won() < 0) - root.outcome = -3; - - if(ponder) - start_threads(); -} - -double Player::gamelen(){ - DepthStats len; - for(unsigned int i = 0; i < threads.size(); i++) - len += threads[i]->gamelen; - return len.avg(); -} - -vector Player::get_pv(){ - vector pv; - - Node * r, * n = & root; - char turn = rootboard.toplay(); - while(!n->children.empty()){ - r = return_move(n, turn); - if(!r) break; - pv.push_back(r->move); - turn = 3 - turn; - n = r; - } - - if(pv.size() == 0) - pv.push_back(Move(M_RESIGN)); - - return pv; -} - -Player::Node * Player::return_move(Node * node, int toplay) const { - double val, maxval = -1000000000000.0; //1 trillion - - Node * ret = NULL, - * child = node->children.begin(), - * end = node->children.end(); - - for( ; child != end; child++){ - if(child->outcome >= 0){ - if(child->outcome == toplay) val = 800000000000.0 - child->exp.num(); //shortest win - else if(child->outcome == 0) val = -400000000000.0 + child->exp.num(); //longest tie - else val = -800000000000.0 + child->exp.num(); //longest loss - }else{ //not proven - if(msrave == -1) //num simulations - val = child->exp.num(); - else if(msrave == -2) //num wins - val = child->exp.sum(); - else - val = child->value(msrave, 0, 0) - msexplore*sqrt(log(node->exp.num())/(child->exp.num() + 1)); - } - - if(maxval < val){ - maxval = val; - ret = child; - } - } - -//set bestmove, but don't touch outcome, if it's solved that will already be set, otherwise it shouldn't be set - if(ret){ - node->bestmove = ret->move; - }else if(node->bestmove == M_UNKNOWN){ - // TODO: Is this needed? -// SolverAB solver; -// solver.set_board(rootboard); -// solver.solve(0.1); -// node->bestmove = solver.bestmove; - } - - assert(node->bestmove != M_UNKNOWN); - - return ret; -} - -void Player::garbage_collect(Board & board, Node * node){ - Node * child = node->children.begin(), - * end = node->children.end(); - - int toplay = board.toplay(); - for( ; child != end; child++){ - if(child->children.num() == 0) - continue; - - if( (node->outcome >= 0 && child->exp.num() > gcsolved && (node->outcome != toplay || child->outcome == toplay || child->outcome == 0)) || //parent is solved, only keep the proof tree, plus heavy draws - (node->outcome < 0 && child->exp.num() > (child->outcome >= 0 ? gcsolved : gclimit)) ){ // only keep heavy nodes, with different cutoffs for solved and unsolved - board.set(child->move); - garbage_collect(board, child); - board.unset(child->move); - }else{ - nodes -= child->dealloc(ctmem); - } - } -} - -Player::Node * Player::find_child(Node * node, const Move & move){ - for(Node * i = node->children.begin(); i != node->children.end(); i++) - if(i->move == move) - return i; - - return NULL; -} - -void Player::gen_hgf(Board & board, Node * node, unsigned int limit, unsigned int depth, FILE * fd){ - string s = string("\n") + string(depth, ' ') + "(;" + (board.toplay() == 2 ? "W" : "B") + "[" + node->move.to_s() + "]" + - "C[mcts, sims:" + to_str(node->exp.num()) + ", avg:" + to_str(node->exp.avg(), 4) + ", outcome:" + to_str((int)(node->outcome)) + ", best:" + node->bestmove.to_s() + "]"; - fprintf(fd, "%s", s.c_str()); - - Node * child = node->children.begin(), - * end = node->children.end(); - - int toplay = board.toplay(); - - bool children = false; - for( ; child != end; child++){ - if(child->exp.num() >= limit && (toplay != node->outcome || child->outcome == node->outcome) ){ - board.set(child->move); - gen_hgf(board, child, limit, depth+1, fd); - board.unset(child->move); - children = true; - } - } - - if(children) - fprintf(fd, "\n%s", string(depth, ' ').c_str()); - fprintf(fd, ")"); -} - -void Player::create_children_simple(const Board & board, Node * node){ - assert(node->children.empty()); - - node->children.alloc(board.movesremain(), ctmem); - - Node * child = node->children.begin(), - * end = node->children.end(); - Board::MoveIterator moveit = board.moveit(prunesymmetry); - int nummoves = 0; - for(; !moveit.done() && child != end; ++moveit, ++child){ - *child = Node(*moveit); - nummoves++; - } - - if(prunesymmetry) - node->children.shrink(nummoves); //shrink the node to ignore the extra moves - else //both end conditions should happen in parallel - assert(moveit.done() && child == end); - - PLUS(nodes, node->children.num()); -} - -//reads the format from gen_hgf. -void Player::load_hgf(Board board, Node * node, FILE * fd){ - char c, buf[101]; - - eat_whitespace(fd); - - assert(fscanf(fd, "(;%c[%100[^]]]", &c, buf) > 0); - - assert(board.toplay() == (c == 'W' ? 1 : 2)); - node->move = Move(buf); - board.move(node->move); - - assert(fscanf(fd, "C[%100[^]]]", buf) > 0); - - vecstr entry, parts = explode(string(buf), ", "); - assert(parts[0] == "mcts"); - - entry = explode(parts[1], ":"); - assert(entry[0] == "sims"); - uword sims = from_str(entry[1]); - - entry = explode(parts[2], ":"); - assert(entry[0] == "avg"); - double avg = from_str(entry[1]); - - uword wins = sims*avg; - node->exp.addwins(wins); - node->exp.addlosses(sims - wins); - - entry = explode(parts[3], ":"); - assert(entry[0] == "outcome"); - node->outcome = from_str(entry[1]); - - entry = explode(parts[4], ":"); - assert(entry[0] == "best"); - node->bestmove = Move(entry[1]); - - - eat_whitespace(fd); - - if(fpeek(fd) != ')'){ - create_children_simple(board, node); - - while(fpeek(fd) != ')'){ - Node child; - load_hgf(board, & child, fd); - - Node * i = find_child(node, child.move); - *i = child; //copy the child experience to the tree - i->swap_tree(child); //move the child subtree to the tree - - assert(child.children.empty()); - - eat_whitespace(fd); - } - } - - eat_char(fd, ')'); - - return; -} diff --git a/y/player.h b/y/player.h deleted file mode 100644 index 9741a1a..0000000 --- a/y/player.h +++ /dev/null @@ -1,304 +0,0 @@ - -#pragma once - -//A Monte-Carlo Tree Search based player - -#include -#include - -#include "../lib/compacttree.h" -#include "../lib/depthstats.h" -#include "../lib/exppair.h" -#include "../lib/log.h" -#include "../lib/thread.h" -#include "../lib/time.h" -#include "../lib/types.h" -#include "../lib/xorshift.h" - -#include "board.h" -#include "lbdist.h" -#include "move.h" -#include "movelist.h" -#include "policy_bridge.h" -#include "policy_instantwin.h" -#include "policy_lastgoodreply.h" -#include "policy_random.h" - - -class Player { -public: - - struct Node { - public: - ExpPair rave; - ExpPair exp; - int16_t know; - int8_t outcome; - uint8_t proofdepth; - Move move; - Move bestmove; //if outcome is set, then bestmove is the way to get there - CompactTree::Children children; -// int padding; - //seems to need padding to multiples of 8 bytes or it segfaults? - //don't forget to update the copy constructor/operator - - Node() : know(0), outcome(-3), proofdepth(0) { } - Node(const Move & m, char o = -3) : know(0), outcome( o), proofdepth(0), move(m) { } - Node(const Node & n) { *this = n; } - Node & operator = (const Node & n){ - if(this != & n){ //don't copy to self - //don't copy to a node that already has children - assert(children.empty()); - - rave = n.rave; - exp = n.exp; - know = n.know; - move = n.move; - bestmove = n.bestmove; - outcome = n.outcome; - proofdepth = n.proofdepth; - //children = n.children; ignore the children, they need to be swap_tree'd in - } - return *this; - } - - void swap_tree(Node & n){ - children.swap(n.children); - } - - void print() const { - printf("%s\n", to_s().c_str()); - } - string to_s() const { - return "Node: move " + move.to_s() + - ", exp " + to_str(exp.avg(), 2) + "/" + to_str(exp.num()) + - ", rave " + to_str(rave.avg(), 2) + "/" + to_str(rave.num()) + - ", know " + to_str(know) + - ", outcome " + to_str(outcome) + "/" + to_str(proofdepth) + - ", best " + bestmove.to_s() + - ", children " + to_str(children.num()); - } - - unsigned int size() const { - unsigned int num = children.num(); - - if(children.num()) - for(Node * i = children.begin(); i != children.end(); i++) - num += i->size(); - - return num; - } - - ~Node(){ - assert(children.empty()); - } - - unsigned int alloc(unsigned int num, CompactTree & ct){ - return children.alloc(num, ct); - } - unsigned int dealloc(CompactTree & ct){ - unsigned int num = 0; - - if(children.num()) - for(Node * i = children.begin(); i != children.end(); i++) - num += i->dealloc(ct); - num += children.dealloc(ct); - - return num; - } - - //new way, more standard way of changing over from rave scores to real scores - float value(float ravefactor, bool knowledge, float fpurgency){ - float val = fpurgency; - float expnum = exp.num(); - float ravenum = rave.num(); - - if(ravefactor <= min_rave){ - if(expnum > 0) - val = exp.avg(); - }else if(ravenum > 0 || expnum > 0){ - float alpha = ravefactor/(ravefactor + expnum); -// float alpha = sqrt(ravefactor/(ravefactor + 3.0f*expnum)); -// float alpha = ravenum/(expnum + ravenum + expnum*ravenum*ravefactor); - - val = 0; - if(ravenum > 0) val += alpha*rave.avg(); - if(expnum > 0) val += (1.0f-alpha)*exp.avg(); - } - - if(knowledge && know > 0){ - if(expnum <= 1) - val += 0.01f * know; - else if(expnum < 1000) //knowledge is only useful with little experience - val += 0.01f * know / sqrt(expnum); - } - - return val; - } - }; - - class PlayerThread { - protected: - public: - mutable XORShift_float unitrand; - Thread thread; - Player * player; - public: - DepthStats treelen, gamelen; - double times[4]; //time spent in each of the stages - - PlayerThread() {} - virtual ~PlayerThread() { } - virtual void reset() { } - int join(){ return thread.join(); } - void run(); //thread runner, calls iterate on each iteration - virtual void iterate() { } //handles each iteration - }; - - class PlayerUCT : public PlayerThread { - LastGoodReply last_good_reply; - RandomPolicy random_policy; - ProtectBridge protect_bridge; - InstantWin instant_wins; - - bool use_rave; //whether to use rave for this simulation - bool use_explore; //whether to use exploration for this simulation - LBDists dists; //holds the distances to the various non-ring wins as a heuristic for the minimum moves needed to win - MoveList movelist; - int stage; //which of the four MCTS stages is it on - Time timestamps[4]; //timestamps for the beginning, before child creation, before rollout, after rollout - - public: - PlayerUCT(Player * p) : PlayerThread() { - player = p; - reset(); - thread(bind(&PlayerUCT::run, this)); - } - - void reset(){ - treelen.reset(); - gamelen.reset(); - - use_rave = false; - use_explore = false; - - for(int a = 0; a < 4; a++) - times[a] = 0; - } - - private: - void iterate(); - void walk_tree(Board & board, Node * node, int depth); - bool create_children(Board & board, Node * node, int toplay); - void add_knowledge(Board & board, Node * node, Node * child); - Node * choose_move(const Node * node, int toplay, int remain) const; - void update_rave(const Node * node, int toplay); - bool test_bridge_probe(const Board & board, const Move & move, const Move & test) const; - - int rollout(Board & board, Move move, int depth); - Move rollout_choose_move(Board & board, const Move & prev); - Move rollout_pattern(const Board & board, const Move & move); - }; - - -public: - - static const float min_rave; - - bool ponder; //think during opponents time? - int numthreads; //number of player threads to run - u64 maxmem; //maximum memory for the tree in bytes - bool profile; //count how long is spent in each stage of MCTS -//final move selection - float msrave; //rave factor in final move selection, -1 means use number instead of value - float msexplore; //the UCT constant in final move selection -//tree traversal - bool parentexplore; // whether to multiple exploration by the parents winrate - float explore; //greater than one favours exploration, smaller than one favours exploitation - float ravefactor; //big numbers favour rave scores, small ignore it - float decrrave; //decrease rave over time, add this value for each empty position on the board - bool knowledge; //whether to include knowledge - float userave; //what probability to use rave - float useexplore; //what probability to use UCT exploration - float fpurgency; //what value to return for a move that hasn't been played yet - int rollouts; //number of rollouts to run after the tree traversal - float dynwiden; //dynamic widening, look at first log_dynwiden(experience) number of children, 0 to disable - float logdynwiden; // = log(dynwiden), cached for performance -//tree building - bool shortrave; //only update rave values on short rollouts - bool keeptree; //reuse the tree from the previous move - int minimax; //solve the minimax tree within the uct tree - uint visitexpand;//number of visits before expanding a node - bool prunesymmetry; //prune symmetric children from the move list, useful for proving but likely not for playing - uint gcsolved; //garbage collect solved nodes or keep them in the tree, assuming they meet the required amount of work -//knowledge - int localreply; //boost for a local reply, ie a move near the previous move - int locality; //boost for playing near previous stones - int connect; //boost for having connections to edges and corners - int size; //boost for large groups - int bridge; //boost replying to a probe at a bridge - int dists; //boost based on minimum number of stones needed to finish a non-ring win -//rollout - int weightedrandom; //use weighted random for move ordering based on gammas - bool rolloutpattern; //play the response to a virtual connection threat in rollouts - int lastgoodreply; //use the last-good-reply rollout heuristic - int instantwin; //how deep to look for instant wins in rollouts - - float gammas[4096]; //pattern weights for weighted random - - Board rootboard; - Node root; - uword nodes; - int gclimit; //the minimum experience needed to not be garbage collected - - uint64_t runs, maxruns; - - CompactTree ctmem; - - enum ThreadState { - Thread_Cancelled, //threads should exit - Thread_Wait_Start, //threads are waiting to start - Thread_Wait_Start_Cancelled, //once done waiting, go to cancelled instead of running - Thread_Running, //threads are running - Thread_GC, //one thread is running garbage collection, the rest are waiting - Thread_GC_End, //once done garbage collecting, go to wait_end instead of back to running - Thread_Wait_End, //threads are waiting to end - }; - volatile ThreadState threadstate; - vector threads; - Barrier runbarrier, gcbarrier; - - double time_used; - - Player(); - ~Player(); - - void timedout(); - - string statestring(); - - void stop_threads(); - void start_threads(); - void reset_threads(); - - void set_ponder(bool p); - void set_board(const Board & board); - - void move(const Move & m); - - double gamelen(); - - Node * genmove(double time, int max_runs, bool flexible); - vector get_pv(); - void garbage_collect(Board & board, Node * node); //destroys the board, so pass in a copy - - bool do_backup(Node * node, Node * backup, int toplay); - - Node * find_child(Node * node, const Move & move); - void create_children_simple(const Board & board, Node * node); - void gen_hgf(Board & board, Node * node, unsigned int limit, unsigned int depth, FILE * fd); - void load_hgf(Board board, Node * node, FILE * fd); - -protected: - Node * return_move(Node * node, int toplay) const; -}; diff --git a/y/playeruct.cpp b/y/playeruct.cpp deleted file mode 100644 index 55bc5e2..0000000 --- a/y/playeruct.cpp +++ /dev/null @@ -1,449 +0,0 @@ - -#include -#include - -#include "../lib/string.h" - -#include "player.h" - -void Player::PlayerUCT::iterate(){ - if(player->profile){ - timestamps[0] = Time(); - stage = 0; - } - - movelist.reset(&(player->rootboard)); - player->root.exp.addvloss(); - Board copy = player->rootboard; - use_rave = (unitrand() < player->userave); - use_explore = (unitrand() < player->useexplore); - walk_tree(copy, & player->root, 0); - player->root.exp.addv(movelist.getexp(3-player->rootboard.toplay())); - - if(player->profile){ - times[0] += timestamps[1] - timestamps[0]; - times[1] += timestamps[2] - timestamps[1]; - times[2] += timestamps[3] - timestamps[2]; - times[3] += Time() - timestamps[3]; - } -} - -void Player::PlayerUCT::walk_tree(Board & board, Node * node, int depth){ - int toplay = board.toplay(); - - if(!node->children.empty() && node->outcome < 0){ - //choose a child and recurse - Node * child; - do{ - int remain = board.movesremain(); - child = choose_move(node, toplay, remain); - - if(child->outcome < 0){ - movelist.addtree(child->move, toplay); - - if(!board.move(child->move)){ - logerr("move failed: " + child->move.to_s() + "\n" + board.to_s(false)); - assert(false && "move failed"); - } - - child->exp.addvloss(); //balanced out after rollouts - - walk_tree(board, child, depth+1); - - child->exp.addv(movelist.getexp(toplay)); - - if(!player->do_backup(node, child, toplay) && //not solved - player->ravefactor > min_rave && //using rave - node->children.num() > 1 && //not a macro move - 50*remain*(player->ravefactor + player->decrrave*remain) > node->exp.num()) //rave is still significant - update_rave(node, toplay); - - return; - } - }while(!player->do_backup(node, child, toplay)); - - return; - } - - if(player->profile && stage == 0){ - stage = 1; - timestamps[1] = Time(); - } - - int won = (player->minimax ? node->outcome : board.won()); - - //if it's not already decided - if(won < 0){ - //create children if valid - if(node->exp.num() >= player->visitexpand+1 && create_children(board, node, toplay)){ - walk_tree(board, node, depth); - return; - } - - if(player->profile){ - stage = 2; - timestamps[2] = Time(); - } - - //do random game on this node - random_policy.prepare(board); - for(int i = 0; i < player->rollouts; i++){ - Board copy = board; - rollout(copy, node->move, depth); - } - }else{ - movelist.finishrollout(won); //got to a terminal state, it's worth recording - } - - treelen.add(depth); - - movelist.subvlosses(1); - - if(player->profile){ - timestamps[3] = Time(); - if(stage == 1) - timestamps[2] = timestamps[3]; - stage = 3; - } - - return; -} - -bool sort_node_know(const Player::Node & a, const Player::Node & b){ - return (a.know > b.know); -} - -bool Player::PlayerUCT::create_children(Board & board, Node * node, int toplay){ - if(!node->children.lock()) - return false; - - if(player->dists){ - dists.run(&board, (player->dists > 0), toplay); - } - - CompactTree::Children temp; - temp.alloc(board.movesremain(), player->ctmem); - - int losses = 0; - - Node * child = temp.begin(), - * end = temp.end(), - * loss = NULL; - Board::MoveIterator move = board.moveit(player->prunesymmetry); - int nummoves = 0; - for(; !move.done() && child != end; ++move, ++child){ - *child = Node(*move); - - if(player->minimax){ - child->outcome = board.test_win(*move); - - if(player->minimax >= 2 && board.test_win(*move, 3 - board.toplay()) > 0){ - losses++; - loss = child; - } - - if(child->outcome == toplay){ //proven win from here, don't need children - node->outcome = child->outcome; - node->proofdepth = 1; - node->bestmove = *move; - node->children.unlock(); - temp.dealloc(player->ctmem); - return true; - } - } - - if(player->knowledge) - add_knowledge(board, node, child); - nummoves++; - } - - if(player->prunesymmetry) - temp.shrink(nummoves); //shrink the node to ignore the extra moves - else //both end conditions should happen in parallel - assert(move.done() && child == end); - - //Make a macro move, add experience to the move so the current simulation continues past this move - if(losses == 1){ - Node macro = *loss; - temp.dealloc(player->ctmem); - temp.alloc(1, player->ctmem); - macro.exp.addwins(player->visitexpand); - *(temp.begin()) = macro; - }else if(losses >= 2){ //proven loss, but at least try to block one of them - node->outcome = 3 - toplay; - node->proofdepth = 2; - node->bestmove = loss->move; - node->children.unlock(); - temp.dealloc(player->ctmem); - return true; - } - - if(player->dynwiden > 0) //sort in decreasing order by knowledge - sort(temp.begin(), temp.end(), sort_node_know); - - PLUS(player->nodes, temp.num()); - node->children.swap(temp); - assert(temp.unlock()); - - return true; -} - -Player::Node * Player::PlayerUCT::choose_move(const Node * node, int toplay, int remain) const { - float val, maxval = -1000000000; - float logvisits = log(node->exp.num()); - int dynwidenlim = (player->dynwiden > 0 ? (int)(logvisits/player->logdynwiden)+2 : 361); - - float raveval = use_rave * (player->ravefactor + player->decrrave*remain); - float explore = use_explore * player->explore; - if(player->parentexplore) - explore *= node->exp.avg(); - - Node * ret = NULL, - * child = node->children.begin(), - * end = node->children.end(); - - for(; child != end && dynwidenlim >= 0; child++){ - if(child->outcome >= 0){ - if(child->outcome == toplay) //return a win immediately - return child; - - val = (child->outcome == 0 ? -1 : -2); //-1 for tie so any unknown is better, -2 for loss so it's even worse - }else{ - val = child->value(raveval, player->knowledge, player->fpurgency); - if(explore > 0) - val += explore*sqrt(logvisits/(child->exp.num() + 1)); - dynwidenlim--; - } - - if(maxval < val){ - maxval = val; - ret = child; - } - } - - return ret; -} - -/* -backup in this order: - -6 win -5 win/draw -4 draw if draw/loss -3 win/draw/loss -2 draw -1 draw/loss -0 lose -return true if fully solved, false if it's unknown or partially unknown -*/ -bool Player::do_backup(Node * node, Node * backup, int toplay){ - int nodeoutcome = node->outcome; - if(nodeoutcome >= 0) //already proven, probably by a different thread - return true; - - if(backup->outcome == -3) //nothing proven by this child, so no chance - return false; - - - uint8_t proofdepth = backup->proofdepth; - if(backup->outcome != toplay){ - uint64_t sims = 0, bestsims = 0, outcome = 0, bestoutcome = 0; - backup = NULL; - - Node * child = node->children.begin(), - * end = node->children.end(); - - for( ; child != end; child++){ - int childoutcome = child->outcome; //save a copy to avoid race conditions - - if(proofdepth < child->proofdepth+1) - proofdepth = child->proofdepth+1; - - //these should be sorted in likelyness of matching, most likely first - if(childoutcome == -3){ // win/draw/loss - outcome = 3; - }else if(childoutcome == toplay){ //win - backup = child; - outcome = 6; - proofdepth = child->proofdepth+1; - break; - }else if(childoutcome == 3-toplay){ //loss - outcome = 0; - }else if(childoutcome == 0){ //draw - if(nodeoutcome == toplay-3) //draw/loss - outcome = 4; - else - outcome = 2; - }else if(childoutcome == -toplay){ //win/draw - outcome = 5; - }else if(childoutcome == toplay-3){ //draw/loss - outcome = 1; - }else{ - logerr("childoutcome == " + to_str(childoutcome) + "\n"); - assert(false && "How'd I get here? All outcomes should be tested above"); - } - - sims = child->exp.num(); - if(bestoutcome < outcome){ //better outcome is always preferable - bestoutcome = outcome; - bestsims = sims; - backup = child; - }else if(bestoutcome == outcome && ((outcome == 0 && bestsims < sims) || bestsims > sims)){ - //find long losses or easy wins/draws - bestsims = sims; - backup = child; - } - } - - if(bestoutcome == 3) //no win, but found an unknown - return false; - } - - if(CAS(node->outcome, nodeoutcome, backup->outcome)){ - node->bestmove = backup->move; - node->proofdepth = proofdepth; - }else //if it was in a race, try again, might promote a partial solve to full solve - return do_backup(node, backup, toplay); - - return (node->outcome >= 0); -} - -//update the rave score of all children that were played -void Player::PlayerUCT::update_rave(const Node * node, int toplay){ - Node * child = node->children.begin(), - * childend = node->children.end(); - - for( ; child != childend; ++child) - child->rave.addv(movelist.getrave(toplay, child->move)); -} - -void Player::PlayerUCT::add_knowledge(Board & board, Node * node, Node * child){ - if(player->localreply){ //boost for moves near the previous move - int dist = node->move.dist(child->move); - if(dist < 4) - child->know += player->localreply * (4 - dist); - } - - if(player->locality) //boost for moves near previous stones - child->know += player->locality * board.local(child->move, board.toplay()); - - Board::Cell cell; - if(player->connect || player->size) - cell = board.test_cell(child->move); - - if(player->connect) //boost for moves that connect to edges - child->know += player->connect * cell.numedges(); - - if(player->size) //boost for size of the group - child->know += player->size * cell.size; - - if(player->bridge && test_bridge_probe(board, node->move, child->move)) //boost for maintaining a virtual connection - child->know += player->bridge; - - if(player->dists) - child->know += abs(player->dists) * max(0, board.get_size() - dists.get(child->move, board.toplay())); -} - -//test whether this move is a forced reply to the opponent probing your virtual connections -bool Player::PlayerUCT::test_bridge_probe(const Board & board, const Move & move, const Move & test) const { - //TODO: switch to the same method as policy_bridge.h, maybe even share code - if(move.dist(test) != 1) - return false; - - bool equals = false; - - int state = 0; - int piece = 3 - board.get(move); - for(int i = 0; i < 8; i++){ - Move cur = move + neighbours[i % 6]; - - bool on = board.onboard(cur); - int v = 0; - if(on) - v = board.get(cur); - - //state machine that progresses when it see the pattern, but counting borders as part of the pattern - if(state == 0){ - if(!on || v == piece) - state = 1; - //else state = 0; - }else if(state == 1){ - if(on){ - if(v == 0){ - state = 2; - equals = (test == cur); - }else if(v != piece) - state = 0; - //else (v==piece) => state = 1; - } - //else state = 1; - }else{ // state == 2 - if(!on || v == piece){ - if(equals) - return true; - state = 1; - }else{ - state = 0; - } - } - } - return false; -} - -/////////////////////////////////////////// - - -//play a random game starting from a board state, and return the results of who won -int Player::PlayerUCT::rollout(Board & board, Move move, int depth){ - int won; - - if(player->instantwin) - instant_wins.rollout_start(board, player->instantwin); - - random_policy.rollout_start(board); - - while((won = board.won()) < 0){ - int turn = board.toplay(); - - move = rollout_choose_move(board, move); - - movelist.addrollout(move, turn); - - assert2(board.move(move), "\n" + board.to_s(true) + "\n" + move.to_s()); - depth++; - } - - gamelen.add(depth); - - //update the last good reply table - if(player->lastgoodreply) - last_good_reply.rollout_end(board, movelist, won); - - movelist.finishrollout(won); - return won; -} - -Move Player::PlayerUCT::rollout_choose_move(Board & board, const Move & prev){ - //look for instant wins - if(player->instantwin){ - Move move = instant_wins.choose_move(board, prev); - if(move != M_UNKNOWN) - return move; - } - - //force a bridge reply - if(player->rolloutpattern){ - Move move = protect_bridge.choose_move(board, prev); - if(move != M_UNKNOWN) - return move; - } - - //reuse the last good reply - if(player->lastgoodreply){ - Move move = last_good_reply.choose_move(board, prev); - if(move != M_UNKNOWN) - return move; - } - - return random_policy.choose_move(board, prev); -} diff --git a/y/policy.h b/y/policy.h deleted file mode 100644 index 01309d8..0000000 --- a/y/policy.h +++ /dev/null @@ -1,28 +0,0 @@ - -#pragma once - -#include "board.h" -#include "move.h" -#include "movelist.h" - -class Policy { -public: - Policy() { } - - // called before all the rollouts start - void prepare(const Board & board) { } - - // called at the beginning of each rollout. - void rollout_start(Board & board) { } - - // Give me a move to make, or M_UNKNOWN - Move choose_move(const Board & board, const Move & prev) { - return M_UNKNOWN; - } - - // A move was just made, here's the updated board - void move_end(const Board & board, const Move & prev) { } - - // Game over, here's who won - void rollout_end(const MoveList & movelist, int won) { } -}; diff --git a/y/policy_bridge.h b/y/policy_bridge.h deleted file mode 100644 index c6f2b8d..0000000 --- a/y/policy_bridge.h +++ /dev/null @@ -1,51 +0,0 @@ - - -#pragma once - -#include "../lib/bits.h" - -#include "board.h" -#include "move.h" -#include "policy.h" - - -class ProtectBridge : public Policy { - int offset; - uint8_t lookup[2][1<<12]; - -public: - - ProtectBridge() : offset(0) { - // precompute the valid moves around a pattern for all possible 6-patterns. - for(unsigned int i = 0; i < 1<<12; i++){ - lookup[0][i] = lookup[1][i] = 0; - unsigned int p = i; - for(unsigned int d = 0; d < 6; d++){ - if((p & 0x1D) == 0x11) // 01 11 01 -> 01 00 01 - lookup[0][i] |= (1 << ((d+1)%6)); // +1 because we want to play in the empty spot - if((p & 0x2E) == 0x22) // 10 11 10 -> 10 00 10 - lookup[1][i] |= (1 << ((d+1)%6)); - p = ((p & 0xFFC)>>2) | ((p & 0x3) << 10); - } - } - } - - Move choose_move(const Board & board, const Move & prev) { - uint32_t p = board.pattern_small(prev); - uint16_t r = lookup[board.toplay()-1][p]; - - if(!r) // nothing to save - return M_UNKNOWN; - - unsigned int i; - if((r & (r - 1)) == 0){ // only one bit set - i = trailing_zeros(r); - } else { // multiple choices of bridges to save - offset = (offset + 1) % 6; // rotate the starting offset to avoid directional bias - r |= (r << 6); - r >>= offset; - i = (offset + trailing_zeros(r)) % 6; - } - return board.nb_begin(prev)[i]; - } -}; diff --git a/y/policy_instantwin.h b/y/policy_instantwin.h deleted file mode 100644 index bf1906b..0000000 --- a/y/policy_instantwin.h +++ /dev/null @@ -1,95 +0,0 @@ - -#pragma once - -#include "../lib/assert2.h" - -#include "board.h" -#include "move.h" -#include "policy.h" - - -class InstantWin : public Policy { - int max_rollout_moves; - int cur_rollout_moves; - - Move saved_loss; -public: - - InstantWin() : max_rollout_moves(10), cur_rollout_moves(0), saved_loss(M_UNKNOWN) { - } - - void rollout_start(Board & board, int max) { - if(max < 0) - max *= - board.get_size(); - max_rollout_moves = max; - - cur_rollout_moves = 0; - saved_loss = M_UNKNOWN; - } - - Move choose_move(const Board & board, const Move & prev) { - if(saved_loss != M_UNKNOWN) - return saved_loss; - - if(cur_rollout_moves++ >= max_rollout_moves) - return M_UNKNOWN; - - //must have an edge connection, or it has nothing to offer a group towards a win - const Board::Cell * c = board.cell(prev); - if(c->numedges() == 0) - return M_UNKNOWN; - - Move start, cur, loss = M_UNKNOWN; - int turn = 3 - board.toplay(); - - //find the first empty cell - int dir = -1; - for(int i = 0; i <= 5; i++){ - start = prev + neighbours[i]; - - if(!board.onboard(start) || board.get(start) != turn){ - dir = (i + 5) % 6; - break; - } - } - - if(dir == -1) //possible if it's in the middle of a ring - return M_UNKNOWN; - - cur = start; - -// logerr(board.to_s(true)); -// logerr(prev.to_s() + ":"); - - //follow contour of the current group looking for wins - do{ -// logerr(" " + cur.to_s()); - //check the current cell - if(board.onboard(cur) && board.get(cur) == 0 && board.test_win(cur, turn) > 0){ -// logerr(" loss"); - if(loss == M_UNKNOWN){ - loss = cur; - }else if(loss != cur){ - saved_loss = loss; - return cur; //game over, two wins found for opponent - } - } - - //advance to the next cell - for(int i = 5; i <= 9; i++){ - int nd = (dir + i) % 6; - Move next = cur + neighbours[nd]; - - if(!board.onboard(next) || board.get(next) != turn){ - cur = next; - dir = nd; - break; - } - } - }while(cur != start); //potentially skips part of it when the start is in a pocket, rare bug - -// logerr("\n"); - - return loss; // usually M_UNKNOWN - } -}; diff --git a/y/policy_lastgoodreply.h b/y/policy_lastgoodreply.h deleted file mode 100644 index 11fcc9a..0000000 --- a/y/policy_lastgoodreply.h +++ /dev/null @@ -1,42 +0,0 @@ - -# pragma once - -#include "board.h" -#include "move.h" -#include "policy.h" - -class LastGoodReply : public Policy { - Move goodreply[2][Board::max_vecsize]; - int enabled; -public: - - LastGoodReply(int _enabled = 2) : enabled(_enabled) { - for(int p = 0; p < 2; p++) - for(int i = 0; i < Board::max_vecsize; i++) - goodreply[p][i] = M_UNKNOWN; - } - - Move choose_move(const Board & board, const Move & prev) const { - if (enabled && prev != M_SWAP) { - Move move = goodreply[board.toplay()-1][board.xy(prev)]; - if(move != M_UNKNOWN && board.valid_move_fast(move)) - return move; - } - return M_UNKNOWN; - } - - void rollout_end(const Board & board, const MoveList & movelist, int won) { - if(!enabled) - return; - int m = -1; - for(const MovePlayer * i = movelist.begin(), * e = movelist.end(); i != e; i++){ - if(m >= 0){ - if(i->player == won && *i != M_SWAP) - goodreply[i->player - 1][m] = *i; - else if(enabled == 2) - goodreply[i->player - 1][m] = M_UNKNOWN; - } - m = board.xy(*i); - } - } -}; diff --git a/y/policy_random.h b/y/policy_random.h deleted file mode 100644 index d84a82a..0000000 --- a/y/policy_random.h +++ /dev/null @@ -1,45 +0,0 @@ - -#pragma once - -#include - -#include "../lib/xorshift.h" - -#include "board.h" -#include "move.h" -#include "policy.h" - -class RandomPolicy : public Policy { - XORShift_uint32 rand; - Move moves[Board::max_vecsize]; - int num; - int cur; -public: - - RandomPolicy() : num(0), cur(0) { - } - - // only need to save the valid moves once since all the rollouts start from the same position - void prepare(const Board & board) { - num = 0; - for(Board::MoveIterator m = board.moveit(false); !m.done(); ++m) - moves[num++] = *m; - } - - // reset the set of moves to make from above. Since they're used in random order they don't need to be in iterator order - void rollout_start(Board & board) { - cur = num; - } - - Move choose_move(const Board & board, const Move & prev) { - while(true){ - int r = rand() % cur; - cur--; - Move m = moves[r]; - moves[r] = moves[cur]; - moves[cur] = m; - if(board.valid_move_fast(m)) - return m; - } - } -}; diff --git a/y/solver.h b/y/solver.h deleted file mode 100644 index d6e6240..0000000 --- a/y/solver.h +++ /dev/null @@ -1,68 +0,0 @@ - -#pragma once - -//Interface for the various solvers - -#include "../lib/types.h" - -#include "board.h" - -class Solver { -public: - int outcome; // 0 = tie, 1 = white, 2 = black, -1 = white or tie, -2 = black or tie, anything else unknown - int maxdepth; - uint64_t nodes_seen; - double time_used; - Move bestmove; - - Solver() : outcome(-3), maxdepth(0), nodes_seen(0), time_used(0) { } - virtual ~Solver() { } - - virtual void solve(double time) { } - virtual void set_board(const Board & board, bool clear = true) { } - virtual void move(const Move & m) { } - virtual void set_memlimit(uint64_t lim) { } // in bytes - virtual void clear_mem() { } - -protected: - volatile bool timeout; - void timedout(){ timeout = true; } - Board rootboard; - - static int solve1ply(const Board & board, int & nodes) { - int outcome = -3; - int turn = board.toplay(); - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - ++nodes; - int won = board.test_win(*move, turn); - - if(won == turn) - return won; - if(won == 0) - outcome = 0; - } - return outcome; - } - - static int solve2ply(const Board & board, int & nodes) { - int losses = 0; - int outcome = -3; - int turn = board.toplay(), opponent = 3 - turn; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - ++nodes; - int won = board.test_win(*move, turn); - - if(won == turn) - return won; - if(won == 0) - outcome = 0; - - if(board.test_win(*move, opponent) > 0) - losses++; - } - if(losses >= 2) - return opponent; - return outcome; - } - -}; diff --git a/y/solverab.cpp b/y/solverab.cpp deleted file mode 100644 index 1abdf47..0000000 --- a/y/solverab.cpp +++ /dev/null @@ -1,137 +0,0 @@ - -#include "../lib/alarm.h" -#include "../lib/log.h" -#include "../lib/time.h" - -#include "solverab.h" - -void SolverAB::solve(double time){ - reset(); - if(rootboard.won() >= 0){ - outcome = rootboard.won(); - return; - } - - if(TT == NULL && maxnodes) - TT = new ABTTNode[maxnodes]; - - Alarm timer(time, std::bind(&SolverAB::timedout, this)); - Time start; - - int turn = rootboard.toplay(); - - for(maxdepth = startdepth; !timeout; maxdepth++){ -// logerr("Starting depth " + to_str(maxdepth) + "\n"); - - //the first depth of negamax - int ret, alpha = -2, beta = 2; - for(Board::MoveIterator move = rootboard.moveit(true); !move.done(); ++move){ - nodes_seen++; - - Board next = rootboard; - next.move(*move); - - int value = -negamax(next, maxdepth - 1, -beta, -alpha); - - if(value > alpha){ - alpha = value; - bestmove = *move; - } - - if(alpha >= beta){ - ret = beta; - break; - } - } - ret = alpha; - - - if(ret){ - if( ret == -2){ outcome = (turn == 1 ? 2 : 1); bestmove = Move(M_NONE); } - else if(ret == 2){ outcome = turn; } - else /*-1 || 1*/ { outcome = 0; } - - break; - } - } - - time_used = Time() - start; -} - - -int SolverAB::negamax(const Board & board, const int depth, int alpha, int beta){ - if(board.won() >= 0) - return (board.won() ? -2 : -1); - - if(depth <= 0 || timeout) - return 0; - - int b = beta; - int first = true; - int value, losses = 0; - static const int lookup[6] = {0, 0, 0, 1, 2, 2}; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - nodes_seen++; - - hash_t hash = board.test_hash(*move); - if(int ttval = tt_get(hash)){ - value = ttval; - }else if(depth <= 2){ - value = lookup[board.test_win(*move)+3]; - - if(board.test_win(*move, 3 - board.toplay()) > 0) - losses++; - }else{ - Board next = board; - next.move(*move); - - value = -negamax(next, depth - 1, -b, -alpha); - - if(scout && value > alpha && value < beta && !first) // re-search - value = -negamax(next, depth - 1, -beta, -alpha); - } - tt_set(hash, value); - - if(value > alpha) - alpha = value; - - if(alpha >= beta) - return beta; - - if(scout){ - b = alpha + 1; // set up null window - first = false; - } - } - - if(losses >= 2) - return -2; - - return alpha; -} - -int SolverAB::negamax_outcome(const Board & board, const int depth){ - int abval = negamax(board, depth, -2, 2); - if( abval == 0) return -3; //unknown - else if(abval == 2) return board.toplay(); //win - else if(abval == -2) return 3 - board.toplay(); //loss - else return 0; //draw -} - -int SolverAB::tt_get(const Board & board){ - return tt_get(board.gethash()); -} -int SolverAB::tt_get(const hash_t & hash){ - if(!TT) return 0; - ABTTNode * node = & TT[hash % maxnodes]; - return (node->hash == hash ? node->value : 0); -} -void SolverAB::tt_set(const Board & board, int value){ - tt_set(board.gethash(), value); -} -void SolverAB::tt_set(const hash_t & hash, int value){ - if(!TT || value == 0) return; - ABTTNode * node = & TT[hash % maxnodes]; - node->hash = hash; - node->value = value; -} diff --git a/y/solverab.h b/y/solverab.h deleted file mode 100644 index 35ca7b9..0000000 --- a/y/solverab.h +++ /dev/null @@ -1,72 +0,0 @@ - -#pragma once - -//An Alpha-beta solver, single threaded with an optional transposition table. - -#include "solver.h" - -class SolverAB : public Solver { - struct ABTTNode { - hash_t hash; - char value; - ABTTNode(hash_t h = 0, char v = 0) : hash(h), value(v) { } - }; - -public: - bool scout; - int startdepth; - - ABTTNode * TT; - uint64_t maxnodes, memlimit; - - SolverAB(bool Scout = false) { - scout = Scout; - startdepth = 2; - TT = NULL; - set_memlimit(100*1024*1024); - } - ~SolverAB() { } - - void set_board(const Board & board, bool clear = true){ - rootboard = board; - reset(); - } - void move(const Move & m){ - rootboard.move(m); - reset(); - } - void set_memlimit(uint64_t lim){ - memlimit = lim; - maxnodes = memlimit/sizeof(ABTTNode); - clear_mem(); - } - - void clear_mem(){ - reset(); - if(TT){ - delete[] TT; - TT = NULL; - } - } - void reset(){ - outcome = -3; - maxdepth = 0; - nodes_seen = 0; - time_used = 0; - bestmove = Move(M_UNKNOWN); - - timeout = false; - } - - void solve(double time); - -//return -2 for loss, -1,1 for tie, 0 for unknown, 2 for win, all from toplay's perspective - int negamax(const Board & board, const int depth, int alpha, int beta); - int negamax_outcome(const Board & board, const int depth); - - int tt_get(const hash_t & hash); - int tt_get(const Board & board); - void tt_set(const hash_t & hash, int val); - void tt_set(const Board & board, int val); -}; - diff --git a/y/solverpns.cpp b/y/solverpns.cpp deleted file mode 100644 index 7f11a1a..0000000 --- a/y/solverpns.cpp +++ /dev/null @@ -1,213 +0,0 @@ - - -#include "../lib/alarm.h" -#include "../lib/log.h" -#include "../lib/time.h" - -#include "solverpns.h" - -void SolverPNS::solve(double time){ - if(rootboard.won() >= 0){ - outcome = rootboard.won(); - return; - } - - timeout = false; - Alarm timer(time, std::bind(&SolverPNS::timedout, this)); - Time start; - -// logerr("max nodes: " + to_str(memlimit/sizeof(PNSNode)) + ", max memory: " + to_str(memlimit/(1024*1024)) + " Mb\n"); - - run_pns(); - - if(root.phi == 0 && root.delta == LOSS){ //look for the winning move - for(PNSNode * i = root.children.begin() ; i != root.children.end(); i++){ - if(i->delta == 0){ - bestmove = i->move; - break; - } - } - outcome = rootboard.toplay(); - }else if(root.phi == 0 && root.delta == DRAW){ //look for the move to tie - for(PNSNode * i = root.children.begin() ; i != root.children.end(); i++){ - if(i->delta == DRAW){ - bestmove = i->move; - break; - } - } - outcome = 0; - }else if(root.delta == 0){ //loss - bestmove = M_NONE; - outcome = 3 - rootboard.toplay(); - }else{ //unknown - bestmove = M_UNKNOWN; - outcome = -3; - } - - time_used = Time() - start; -} - -void SolverPNS::run_pns(){ - while(!timeout && root.phi != 0 && root.delta != 0){ - if(!pns(rootboard, &root, 0, INF32/2, INF32/2)){ - logerr("Starting solver GC with limit " + to_str(gclimit) + " ... "); - - Time starttime; - garbage_collect(& root); - - Time gctime; - ctmem.compact(1.0, 0.75); - - Time compacttime; - logerr(to_str(100.0*ctmem.meminuse()/memlimit, 1) + " % of tree remains - " + - to_str((gctime - starttime)*1000, 0) + " msec gc, " + to_str((compacttime - gctime)*1000, 0) + " msec compact\n"); - - if(ctmem.meminuse() >= memlimit/2) - gclimit = (unsigned int)(gclimit*1.3); - else if(gclimit > 5) - gclimit = (unsigned int)(gclimit*0.9); //slowly decay to a minimum of 5 - } - } -} - -bool SolverPNS::pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td){ - iters++; - if(maxdepth < depth) - maxdepth = depth; - - if(node->children.empty()){ - if(ctmem.memalloced() >= memlimit) - return false; - - int numnodes = board.movesremain(); - nodes += node->alloc(numnodes, ctmem); - - if(lbdist) - dists.run(&board); - - int i = 0; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - int outcome, pd; - - if(ab){ - Board next = board; - next.move(*move); - - pd = 0; - outcome = (ab == 1 ? solve1ply(next, pd) : solve2ply(next, pd)); - nodes_seen += pd; - }else{ - outcome = board.test_win(*move); - pd = 1; - } - - if(lbdist && outcome < 0) - pd = dists.get(*move); - - node->children[i] = PNSNode(*move).outcome(outcome, board.toplay(), ties, pd); - - i++; - } - node->children.shrink(i); //if symmetry, there may be extra moves to ignore - - nodes_seen += i; - - updatePDnum(node); - - return true; - } - - bool mem; - do{ - PNSNode * child = node->children.begin(), - * child2 = node->children.begin(), - * childend = node->children.end(); - - uint32_t tpc, tdc; - - if(df){ - for(PNSNode * i = node->children.begin(); i != childend; i++){ - if(i->delta <= child->delta){ - child2 = child; - child = i; - }else if(i->delta < child2->delta){ - child2 = i; - } - } - - tpc = min(INF32/2, (td + child->phi - node->delta)); - tdc = min(tp, (uint32_t)(child2->delta*(1.0 + epsilon) + 1)); - }else{ - tpc = tdc = 0; - while(child->delta != node->phi) - child++; - } - - Board next = board; - next.move(child->move); - - uint64_t itersbefore = iters; - mem = pns(next, child, depth + 1, tpc, tdc); - child->work += iters - itersbefore; - - if(child->phi == 0 || child->delta == 0) //clear child's children - nodes -= child->dealloc(ctmem); - - if(updatePDnum(node) && !df) - break; - - }while(!timeout && mem && (!df || (node->phi < tp && node->delta < td))); - - return mem; -} - -bool SolverPNS::updatePDnum(PNSNode * node){ - PNSNode * i = node->children.begin(); - PNSNode * end = node->children.end(); - - uint32_t min = i->delta; - uint64_t sum = 0; - - bool win = false; - for( ; i != end; i++){ - win |= (i->phi == LOSS); - sum += i->phi; - if( min > i->delta) - min = i->delta; - } - - if(win) - sum = LOSS; - else if(sum >= INF32) - sum = INF32; - - if(min == node->phi && sum == node->delta){ - return false; - }else{ - if(sum == 0 && min == DRAW){ - node->phi = 0; - node->delta = DRAW; - }else{ - node->phi = min; - node->delta = sum; - } - return true; - } -} - -//removes the children of any node with less than limit work -void SolverPNS::garbage_collect(PNSNode * node){ - PNSNode * child = node->children.begin(); - PNSNode * end = node->children.end(); - - for( ; child != end; child++){ - if(child->terminal()){ //solved - //log heavy nodes? - nodes -= child->dealloc(ctmem); - }else if(child->work < gclimit){ //low work, ignore solvedness since it's trivial to re-solve - nodes -= child->dealloc(ctmem); - }else if(child->children.num() > 0){ - garbage_collect(child); - } - } -} diff --git a/y/solverpns.h b/y/solverpns.h deleted file mode 100644 index b040d82..0000000 --- a/y/solverpns.h +++ /dev/null @@ -1,206 +0,0 @@ - -#pragma once - -//A single-threaded, tree based, proof number search solver. - -#include "../lib/compacttree.h" -#include "../lib/log.h" - -#include "lbdist.h" -#include "solver.h" - - -class SolverPNS : public Solver { - static const uint32_t LOSS = (1<<30)-1; - static const uint32_t DRAW = (1<<30)-2; - static const uint32_t INF32 = (1<<30)-3; -public: - - struct PNSNode { - uint32_t phi, delta; - uint64_t work; - Move move; - CompactTree::Children children; - - PNSNode() { } - PNSNode(int x, int y, int v = 1) : phi(v), delta(v), work(0), move(Move(x,y)) { } - PNSNode(const Move & m, int v = 1) : phi(v), delta(v), work(0), move(m) { } - PNSNode(int x, int y, int p, int d) : phi(p), delta(d), work(0), move(Move(x,y)) { } - PNSNode(const Move & m, int p, int d) : phi(p), delta(d), work(0), move(m) { } - - PNSNode(const PNSNode & n) { *this = n; } - PNSNode & operator = (const PNSNode & n){ - if(this != & n){ //don't copy to self - //don't copy to a node that already has children - assert(children.empty()); - - phi = n.phi; - delta = n.delta; - work = n.work; - move = n.move; - //don't copy the children - } - return *this; - } - - ~PNSNode(){ - assert(children.empty()); - } - - PNSNode & abval(int outcome, int toplay, int assign, int value = 1){ - if(assign && (outcome == 1 || outcome == -1)) - outcome = (toplay == assign ? 2 : -2); - - if( outcome == 0) { phi = value; delta = value; } - else if(outcome == 2) { phi = LOSS; delta = 0; } - else if(outcome == -2) { phi = 0; delta = LOSS; } - else /*(outcome 1||-1)*/ { phi = 0; delta = DRAW; } - return *this; - } - - PNSNode & outcome(int outcome, int toplay, int assign, int value = 1){ - if(assign && outcome == 0) - outcome = assign; - - if( outcome == -3) { phi = value; delta = value; } - else if(outcome == toplay) { phi = LOSS; delta = 0; } - else if(outcome == 3-toplay) { phi = 0; delta = LOSS; } - else /*(outcome == 0)*/ { phi = 0; delta = DRAW; } - return *this; - } - - bool terminal(){ return (phi == 0 || delta == 0); } - - unsigned int size() const { - unsigned int num = children.num(); - - for(PNSNode * i = children.begin(); i != children.end(); i++) - num += i->size(); - - return num; - } - - void swap_tree(PNSNode & n){ - children.swap(n.children); - } - - unsigned int alloc(unsigned int num, CompactTree & ct){ - return children.alloc(num, ct); - } - unsigned int dealloc(CompactTree & ct){ - unsigned int num = 0; - - for(PNSNode * i = children.begin(); i != children.end(); i++) - num += i->dealloc(ct); - num += children.dealloc(ct); - - return num; - } - }; - - -//memory management for PNS which uses a tree to store the nodes - uint64_t nodes, memlimit; - unsigned int gclimit; - CompactTree ctmem; - - uint64_t iters; - - int ab; // how deep of an alpha-beta search to run at each leaf node - bool df; // go depth first? - float epsilon; //if depth first, how wide should the threshold be? - int ties; //which player to assign ties to: 0 handle ties, 1 assign p1, 2 assign p2 - bool lbdist; - - PNSNode root; - LBDists dists; - - SolverPNS() { - ab = 2; - df = true; - epsilon = 0.25; - ties = 0; - lbdist = false; - gclimit = 5; - iters = 0; - - reset(); - - set_memlimit(100*1024*1024); - } - - ~SolverPNS(){ - root.dealloc(ctmem); - ctmem.compact(); - } - - void reset(){ - outcome = -3; - maxdepth = 0; - nodes_seen = 0; - time_used = 0; - bestmove = Move(M_UNKNOWN); - - timeout = false; - } - - void set_board(const Board & board, bool clear = true){ - rootboard = board; - reset(); - if(clear) - clear_mem(); - } - void move(const Move & m){ - rootboard.move(m); - reset(); - - - uint64_t nodesbefore = nodes; - - PNSNode child; - - for(PNSNode * i = root.children.begin(); i != root.children.end(); i++){ - if(i->move == m){ - child = *i; //copy the child experience to temp - child.swap_tree(*i); //move the child tree to temp - break; - } - } - - nodes -= root.dealloc(ctmem); - root = child; - root.swap_tree(child); - - if(nodesbefore > 0) - logerr(string("PNS Nodes before: ") + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); - - assert(nodes == root.size()); - - if(nodes == 0) - clear_mem(); - } - - void set_memlimit(uint64_t lim){ - memlimit = lim; - } - - void clear_mem(){ - reset(); - root.dealloc(ctmem); - ctmem.compact(); - root = PNSNode(0, 0, 1); - nodes = 0; - } - - void solve(double time); - -//basic proof number search building a tree - void run_pns(); - bool pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td); - -//update the phi and delta for the node - bool updatePDnum(PNSNode * node); - -//remove all the nodes with little work to free up some memory - void garbage_collect(PNSNode * node); -}; diff --git a/y/solverpns2.cpp b/y/solverpns2.cpp deleted file mode 100644 index 4995fc5..0000000 --- a/y/solverpns2.cpp +++ /dev/null @@ -1,323 +0,0 @@ - - -#include "../lib/alarm.h" -#include "../lib/log.h" -#include "../lib/time.h" - -#include "solverpns2.h" - -void SolverPNS2::solve(double time){ - if(rootboard.won() >= 0){ - outcome = rootboard.won(); - return; - } - - start_threads(); - - timeout = false; - Alarm timer(time, std::bind(&SolverPNS2::timedout, this)); - Time start; - -// logerr("max memory: " + to_str(memlimit/(1024*1024)) + " Mb\n"); - - //wait for the timer to stop them - runbarrier.wait(); - CAS(threadstate, Thread_Wait_End, Thread_Wait_Start); - assert(threadstate == Thread_Wait_Start); - - if(root.phi == 0 && root.delta == LOSS){ //look for the winning move - for(PNSNode * i = root.children.begin() ; i != root.children.end(); i++){ - if(i->delta == 0){ - bestmove = i->move; - break; - } - } - outcome = rootboard.toplay(); - }else if(root.phi == 0 && root.delta == DRAW){ //look for the move to tie - for(PNSNode * i = root.children.begin() ; i != root.children.end(); i++){ - if(i->delta == DRAW){ - bestmove = i->move; - break; - } - } - outcome = 0; - }else if(root.delta == 0){ //loss - bestmove = M_NONE; - outcome = 3 - rootboard.toplay(); - }else{ //unknown - bestmove = M_UNKNOWN; - outcome = -3; - } - - time_used = Time() - start; -} - -void SolverPNS2::SolverThread::run(){ - while(true){ - switch(solver->threadstate){ - case Thread_Cancelled: //threads should exit - return; - - case Thread_Wait_Start: //threads are waiting to start - case Thread_Wait_Start_Cancelled: - solver->runbarrier.wait(); - CAS(solver->threadstate, Thread_Wait_Start, Thread_Running); - CAS(solver->threadstate, Thread_Wait_Start_Cancelled, Thread_Cancelled); - break; - - case Thread_Wait_End: //threads are waiting to end - solver->runbarrier.wait(); - CAS(solver->threadstate, Thread_Wait_End, Thread_Wait_Start); - break; - - case Thread_Running: //threads are running - if(solver->root.terminal()){ //solved - CAS(solver->threadstate, Thread_Running, Thread_Wait_End); - break; - } - if(solver->ctmem.memalloced() >= solver->memlimit){ //out of memory, start garbage collection - CAS(solver->threadstate, Thread_Running, Thread_GC); - break; - } - - pns(solver->rootboard, &solver->root, 0, INF32/2, INF32/2); - break; - - case Thread_GC: //one thread is running garbage collection, the rest are waiting - case Thread_GC_End: //once done garbage collecting, go to wait_end instead of back to running - if(solver->gcbarrier.wait()){ - logerr("Starting solver GC with limit " + to_str(solver->gclimit) + " ... "); - - Time starttime; - solver->garbage_collect(& solver->root); - - Time gctime; - solver->ctmem.compact(1.0, 0.75); - - Time compacttime; - logerr(to_str(100.0*solver->ctmem.meminuse()/solver->memlimit, 1) + " % of tree remains - " + - to_str((gctime - starttime)*1000, 0) + " msec gc, " + to_str((compacttime - gctime)*1000, 0) + " msec compact\n"); - - if(solver->ctmem.meminuse() >= solver->memlimit/2) - solver->gclimit = (unsigned int)(solver->gclimit*1.3); - else if(solver->gclimit > 5) - solver->gclimit = (unsigned int)(solver->gclimit*0.9); //slowly decay to a minimum of 5 - - CAS(solver->threadstate, Thread_GC, Thread_Running); - CAS(solver->threadstate, Thread_GC_End, Thread_Wait_End); - } - solver->gcbarrier.wait(); - break; - } - } -} - -void SolverPNS2::timedout() { - CAS(threadstate, Thread_Running, Thread_Wait_End); - CAS(threadstate, Thread_GC, Thread_GC_End); - timeout = true; -} - -string SolverPNS2::statestring(){ - switch(threadstate){ - case Thread_Cancelled: return "Thread_Wait_Cancelled"; - case Thread_Wait_Start: return "Thread_Wait_Start"; - case Thread_Wait_Start_Cancelled: return "Thread_Wait_Start_Cancelled"; - case Thread_Running: return "Thread_Running"; - case Thread_GC: return "Thread_GC"; - case Thread_GC_End: return "Thread_GC_End"; - case Thread_Wait_End: return "Thread_Wait_End"; - } - return "Thread_State_Unknown!!!"; -} - -void SolverPNS2::stop_threads(){ - if(threadstate != Thread_Wait_Start){ - timedout(); - runbarrier.wait(); - CAS(threadstate, Thread_Wait_End, Thread_Wait_Start); - assert(threadstate == Thread_Wait_Start); - } -} - -void SolverPNS2::start_threads(){ - assert(threadstate == Thread_Wait_Start); - runbarrier.wait(); - CAS(threadstate, Thread_Wait_Start, Thread_Running); -} - -void SolverPNS2::reset_threads(){ //start and end with threadstate = Thread_Wait_Start - assert(threadstate == Thread_Wait_Start); - -//wait for them to all get to the barrier - assert(CAS(threadstate, Thread_Wait_Start, Thread_Wait_Start_Cancelled)); - runbarrier.wait(); - -//make sure they exited cleanly - for(unsigned int i = 0; i < threads.size(); i++) - threads[i]->join(); - - threads.clear(); - - threadstate = Thread_Wait_Start; - - runbarrier.reset(numthreads + 1); - gcbarrier.reset(numthreads); - -//start new threads - for(int i = 0; i < numthreads; i++) - threads.push_back(new SolverThread(this)); -} - - -bool SolverPNS2::SolverThread::pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td){ - iters++; - if(solver->maxdepth < depth) - solver->maxdepth = depth; - - if(node->children.empty()){ - if(node->terminal()) - return true; - - if(solver->ctmem.memalloced() >= solver->memlimit) - return false; - - if(!node->children.lock()) - return false; - - int numnodes = board.movesremain(); - CompactTree::Children temp; - temp.alloc(numnodes, solver->ctmem); - PLUS(solver->nodes, numnodes); - - if(solver->lbdist) - dists.run(&board); - - int i = 0; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - int outcome, pd; - - if(solver->ab){ - Board next = board; - next.move(*move); - - pd = 0; - outcome = (solver->ab == 1 ? solve1ply(next, pd) : solve2ply(next, pd)); - PLUS(solver->nodes_seen, pd); - }else{ - outcome = board.test_win(*move); - pd = 1; - } - - if(solver->lbdist && outcome < 0) - pd = dists.get(*move); - - temp[i] = PNSNode(*move).outcome(outcome, board.toplay(), solver->ties, pd); - - i++; - } - temp.shrink(i); //if symmetry, there may be extra moves to ignore - node->children.swap(temp); - assert(temp.unlock()); - - PLUS(solver->nodes_seen, i); - - updatePDnum(node); - - return true; - } - - bool mem; - do{ - PNSNode * child = node->children.begin(), - * child2 = node->children.begin(), - * childend = node->children.end(); - - uint32_t tpc, tdc; - - if(solver->df){ - for(PNSNode * i = node->children.begin(); i != childend; i++){ - if(i->refdelta() <= child->refdelta()){ - child2 = child; - child = i; - }else if(i->refdelta() < child2->refdelta()){ - child2 = i; - } - } - - tpc = min(INF32/2, (td + child->phi - node->delta)); - tdc = min(tp, (uint32_t)(child2->delta*(1.0 + solver->epsilon) + 1)); - }else{ - tpc = tdc = 0; - for(PNSNode * i = node->children.begin(); i != childend; i++) - if(child->refdelta() > i->refdelta()) - child = i; - } - - Board next = board; - next.move(child->move); - - child->ref(); - uint64_t itersbefore = iters; - mem = pns(next, child, depth + 1, tpc, tdc); - child->deref(); - PLUS(child->work, iters - itersbefore); - - if(updatePDnum(node) && !solver->df) - break; - - }while(!solver->timeout && mem && (!solver->df || (node->phi < tp && node->delta < td))); - - return mem; -} - -bool SolverPNS2::SolverThread::updatePDnum(PNSNode * node){ - PNSNode * i = node->children.begin(); - PNSNode * end = node->children.end(); - - uint32_t min = i->delta; - uint64_t sum = 0; - - bool win = false; - for( ; i != end; i++){ - win |= (i->phi == LOSS); - sum += i->phi; - if( min > i->delta) - min = i->delta; - } - - if(win) - sum = LOSS; - else if(sum >= INF32) - sum = INF32; - - if(min == node->phi && sum == node->delta){ - return false; - }else{ - if(sum == 0 && min == DRAW){ - node->phi = 0; - node->delta = DRAW; - }else{ - node->phi = min; - node->delta = sum; - } - return true; - } -} - -//removes the children of any node with less than limit work -void SolverPNS2::garbage_collect(PNSNode * node){ - PNSNode * child = node->children.begin(); - PNSNode * end = node->children.end(); - - for( ; child != end; child++){ - if(child->terminal()){ //solved - //log heavy nodes? - PLUS(nodes, -child->dealloc(ctmem)); - }else if(child->work < gclimit){ //low work, ignore solvedness since it's trivial to re-solve - PLUS(nodes, -child->dealloc(ctmem)); - }else if(child->children.num() > 0){ - garbage_collect(child); - } - } -} diff --git a/y/solverpns2.h b/y/solverpns2.h deleted file mode 100644 index 5af5d1d..0000000 --- a/y/solverpns2.h +++ /dev/null @@ -1,265 +0,0 @@ - -#pragma once - -//A multi-threaded, tree based, proof number search solver. - -#include "../lib/compacttree.h" -#include "../lib/log.h" - -#include "lbdist.h" -#include "solver.h" - - -class SolverPNS2 : public Solver { - static const uint32_t LOSS = (1<<30)-1; - static const uint32_t DRAW = (1<<30)-2; - static const uint32_t INF32 = (1<<30)-3; -public: - - struct PNSNode { - static const uint16_t reflock = 1<<15; - uint32_t phi, delta; - uint64_t work; - uint16_t refcount; //how many threads are down this node - Move move; - CompactTree::Children children; - - PNSNode() { } - PNSNode(int x, int y, int v = 1) : phi(v), delta(v), work(0), refcount(0), move(Move(x,y)) { } - PNSNode(const Move & m, int v = 1) : phi(v), delta(v), work(0), refcount(0), move(m) { } - PNSNode(int x, int y, int p, int d) : phi(p), delta(d), work(0), refcount(0), move(Move(x,y)) { } - PNSNode(const Move & m, int p, int d) : phi(p), delta(d), work(0), refcount(0), move(m) { } - - PNSNode(const PNSNode & n) { *this = n; } - PNSNode & operator = (const PNSNode & n){ - if(this != & n){ //don't copy to self - //don't copy to a node that already has children - assert(children.empty()); - - phi = n.phi; - delta = n.delta; - work = n.work; - move = n.move; - //don't copy the children - } - return *this; - } - - ~PNSNode(){ - assert(children.empty()); - } - - PNSNode & abval(int outcome, int toplay, int assign, int value = 1){ - if(assign && (outcome == 1 || outcome == -1)) - outcome = (toplay == assign ? 2 : -2); - - if( outcome == 0) { phi = value; delta = value; } - else if(outcome == 2) { phi = LOSS; delta = 0; } - else if(outcome == -2) { phi = 0; delta = LOSS; } - else /*(outcome 1||-1)*/ { phi = 0; delta = DRAW; } - return *this; - } - - PNSNode & outcome(int outcome, int toplay, int assign, int value = 1){ - if(assign && outcome == 0) - outcome = assign; - - if( outcome == -3) { phi = value; delta = value; } - else if(outcome == toplay) { phi = LOSS; delta = 0; } - else if(outcome == 3-toplay) { phi = 0; delta = LOSS; } - else /*(outcome == 0)*/ { phi = 0; delta = DRAW; } - return *this; - } - - bool terminal(){ return (phi == 0 || delta == 0); } - - uint32_t refdelta() const { - return delta + refcount; - } - - void ref() { PLUS(refcount, 1); } - void deref(){ PLUS(refcount, -1); } - - unsigned int size() const { - unsigned int num = children.num(); - - for(PNSNode * i = children.begin(); i != children.end(); i++) - num += i->size(); - - return num; - } - - void swap_tree(PNSNode & n){ - children.swap(n.children); - } - - unsigned int alloc(unsigned int num, CompactTree & ct){ - return children.alloc(num, ct); - } - unsigned int dealloc(CompactTree & ct){ - unsigned int num = 0; - - for(PNSNode * i = children.begin(); i != children.end(); i++) - num += i->dealloc(ct); - num += children.dealloc(ct); - - return num; - } - }; - - class SolverThread { - protected: - public: - Thread thread; - SolverPNS2 * solver; - public: - uint64_t iters; - LBDists dists; //holds the distances to the various non-ring wins as a heuristic for the minimum moves needed to win - - SolverThread(SolverPNS2 * s) : solver(s), iters(0) { - thread(bind(&SolverThread::run, this)); - } - virtual ~SolverThread() { } - void reset(){ - iters = 0; - } - int join(){ return thread.join(); } - void run(); //thread runner - - //basic proof number search building a tree - bool pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td); - - //update the phi and delta for the node - bool updatePDnum(PNSNode * node); - }; - - -//memory management for PNS which uses a tree to store the nodes - uint64_t nodes, memlimit; - unsigned int gclimit; - CompactTree ctmem; - - enum ThreadState { - Thread_Cancelled, //threads should exit - Thread_Wait_Start, //threads are waiting to start - Thread_Wait_Start_Cancelled, //once done waiting, go to cancelled instead of running - Thread_Running, //threads are running - Thread_GC, //one thread is running garbage collection, the rest are waiting - Thread_GC_End, //once done garbage collecting, go to wait_end instead of back to running - Thread_Wait_End, //threads are waiting to end - }; - volatile ThreadState threadstate; - vector threads; - Barrier runbarrier, gcbarrier; - - - int ab; // how deep of an alpha-beta search to run at each leaf node - bool df; // go depth first? - float epsilon; //if depth first, how wide should the threshold be? - int ties; //which player to assign ties to: 0 handle ties, 1 assign p1, 2 assign p2 - bool lbdist; - int numthreads; - - PNSNode root; - LBDists dists; - - SolverPNS2() { - ab = 2; - df = true; - epsilon = 0.25; - ties = 0; - lbdist = false; - numthreads = 1; - gclimit = 5; - - reset(); - - set_memlimit(100*1024*1024); - - //no threads started until a board is set - threadstate = Thread_Wait_Start; - } - - ~SolverPNS2(){ - stop_threads(); - - numthreads = 0; - reset_threads(); //shut down the theads properly - - root.dealloc(ctmem); - ctmem.compact(); - } - - void reset(){ - outcome = -3; - maxdepth = 0; - nodes_seen = 0; - time_used = 0; - bestmove = Move(M_UNKNOWN); - - timeout = false; - } - - string statestring(); - void stop_threads(); - void start_threads(); - void reset_threads(); - void timedout(); - - void set_board(const Board & board, bool clear = true){ - rootboard = board; - reset(); - if(clear) - clear_mem(); - - reset_threads(); //needed since the threads aren't started before a board it set - } - void move(const Move & m){ - stop_threads(); - - rootboard.move(m); - reset(); - - - uint64_t nodesbefore = nodes; - - PNSNode child; - - for(PNSNode * i = root.children.begin(); i != root.children.end(); i++){ - if(i->move == m){ - child = *i; //copy the child experience to temp - child.swap_tree(*i); //move the child tree to temp - break; - } - } - - nodes -= root.dealloc(ctmem); - root = child; - root.swap_tree(child); - - if(nodesbefore > 0) - logerr(string("PNS Nodes before: ") + to_str(nodesbefore) + ", after: " + to_str(nodes) + ", saved " + to_str(100.0*nodes/nodesbefore, 1) + "% of the tree\n"); - - assert(nodes == root.size()); - - if(nodes == 0) - clear_mem(); - } - - void set_memlimit(uint64_t lim){ - memlimit = lim; - } - - void clear_mem(){ - reset(); - root.dealloc(ctmem); - ctmem.compact(); - root = PNSNode(0, 0, 1); - nodes = 0; - } - - void solve(double time); - -//remove all the nodes with little work to free up some memory - void garbage_collect(PNSNode * node); -}; diff --git a/y/solverpns_tt.cpp b/y/solverpns_tt.cpp deleted file mode 100644 index 0818e8c..0000000 --- a/y/solverpns_tt.cpp +++ /dev/null @@ -1,282 +0,0 @@ - -#include "../lib/alarm.h" -#include "../lib/log.h" -#include "../lib/time.h" - -#include "solverpns_tt.h" - -void SolverPNSTT::solve(double time){ - if(rootboard.won() >= 0){ - outcome = rootboard.won(); - return; - } - - timeout = false; - Alarm timer(time, std::bind(&SolverPNSTT::timedout, this)); - Time start; - -// logerr("max nodes: " + to_str(maxnodes) + ", max memory: " + to_str(memlimit) + " Mb\n"); - - run_pns(); - - if(root.phi == 0 && root.delta == LOSS){ //look for the winning move - PNSNode * i = NULL; - for(Board::MoveIterator move = rootboard.moveit(true); !move.done(); ++move){ - i = tt(rootboard, *move); - if(i->delta == 0){ - bestmove = *move; - break; - } - } - outcome = rootboard.toplay(); - }else if(root.phi == 0 && root.delta == DRAW){ //look for the move to tie - PNSNode * i = NULL; - for(Board::MoveIterator move = rootboard.moveit(true); !move.done(); ++move){ - i = tt(rootboard, *move); - if(i->delta == DRAW){ - bestmove = *move; - break; - } - } - outcome = 0; - }else if(root.delta == 0){ //loss - bestmove = M_NONE; - outcome = 3 - rootboard.toplay(); - }else{ //unknown - bestmove = M_UNKNOWN; - outcome = -3; - } - - time_used = Time() - start; -} - -void SolverPNSTT::run_pns(){ - if(TT == NULL) - TT = new PNSNode[maxnodes]; - - while(!timeout && root.phi != 0 && root.delta != 0) - pns(rootboard, &root, 0, INF32/2, INF32/2); -} - -void SolverPNSTT::pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td){ - if(depth > maxdepth) - maxdepth = depth; - - do{ - PNSNode * child = NULL, - * child2 = NULL; - - Move move1, move2; - - uint32_t tpc, tdc; - - PNSNode * i = NULL; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - i = tt(board, *move); - if(child == NULL){ - child = child2 = i; - move1 = move2 = *move; - }else if(i->delta <= child->delta){ - child2 = child; - child = i; - move2 = move1; - move1 = *move; - }else if(i->delta < child2->delta){ - child2 = i; - move2 = *move; - } - } - - if(child->delta && child->phi){ //unsolved - if(df){ - tpc = min(INF32/2, (td + child->phi - node->delta)); - tdc = min(tp, (uint32_t)(child2->delta*(1.0 + epsilon) + 1)); - }else{ - tpc = tdc = 0; - } - - Board next = board; - next.move(move1); - pns(next, child, depth + 1, tpc, tdc); - - //just found a loss, try to copy proof to siblings - if(copyproof && child->delta == LOSS){ -// logerr("!" + move1.to_s() + " "); - int count = abs(copyproof); - for(Board::MoveIterator move = board.moveit(true); count-- && !move.done(); ++move){ - if(!tt(board, *move)->terminal()){ -// logerr("?" + move->to_s() + " "); - Board sibling = board; - sibling.move(*move); - copy_proof(next, sibling, move1, *move); - updatePDnum(sibling); - - if(copyproof < 0 && !tt(sibling)->terminal()) - break; - } - } - } - } - - if(updatePDnum(board, node) && !df) //must pass node to updatePDnum since it may refer to the root which isn't in the TT - break; - - }while(!timeout && node->phi && node->delta && (!df || (node->phi < tp && node->delta < td))); -} - -bool SolverPNSTT::updatePDnum(const Board & board, PNSNode * node){ - hash_t hash = board.gethash(); - - if(node == NULL) - node = TT + (hash % maxnodes); - - uint32_t min = LOSS; - uint64_t sum = 0; - - bool win = false; - PNSNode * i = NULL; - for(Board::MoveIterator move = board.moveit(true); !move.done(); ++move){ - i = tt(board, *move); - - win |= (i->phi == LOSS); - sum += i->phi; - if( min > i->delta) - min = i->delta; - } - - if(win) - sum = LOSS; - else if(sum >= INF32) - sum = INF32; - - if(hash == node->hash && min == node->phi && sum == node->delta){ - return false; - }else{ - node->hash = hash; //just in case it was overwritten by something else - if(sum == 0 && min == DRAW){ - node->phi = 0; - node->delta = DRAW; - }else{ - node->phi = min; - node->delta = sum; - } - return true; - } -} - -//source is a move that is a proven loss, and dest is an unproven sibling -//each has one move that the other doesn't, which are stored in smove and dmove -//if either move is used but only available in one board, the other is substituted -void SolverPNSTT::copy_proof(const Board & source, const Board & dest, Move smove, Move dmove){ - if(timeout || tt(source)->delta != LOSS || tt(dest)->terminal()) - return; - - //find winning move from the source tree - Move bestmove = M_UNKNOWN; - for(Board::MoveIterator move = source.moveit(true); !move.done(); ++move){ - if(tt(source, *move)->phi == LOSS){ - bestmove = *move; - break; - } - } - - if(bestmove == M_UNKNOWN) //due to transposition table collision - return; - - Board dest2 = dest; - - if(bestmove == dmove){ - assert(dest2.move(smove)); - smove = dmove = M_UNKNOWN; - }else{ - assert(dest2.move(bestmove)); - if(bestmove == smove) - smove = dmove = M_UNKNOWN; - } - - if(tt(dest2)->terminal()) - return; - - Board source2 = source; - assert(source2.move(bestmove)); - - if(source2.won() >= 0) - return; - - //test all responses - for(Board::MoveIterator move = dest2.moveit(true); !move.done(); ++move){ - if(tt(dest2, *move)->terminal()) - continue; - - Move csmove = smove, cdmove = dmove; - - Board source3 = source2, dest3 = dest2; - - if(*move == csmove){ - assert(source3.move(cdmove)); - csmove = cdmove = M_UNKNOWN; - }else{ - assert(source3.move(*move)); - if(*move == csmove) - csmove = cdmove = M_UNKNOWN; - } - - assert(dest3.move(*move)); - - copy_proof(source3, dest3, csmove, cdmove); - - updatePDnum(dest3); - } - - updatePDnum(dest2); -} - -SolverPNSTT::PNSNode * SolverPNSTT::tt(const Board & board){ - hash_t hash = board.gethash(); - - PNSNode * node = TT + (hash % maxnodes); - - if(node->hash != hash){ - int outcome, pd; - - if(ab){ - pd = 0; - outcome = (ab == 1 ? solve1ply(board, pd) : solve2ply(board, pd)); - nodes_seen += pd; - }else{ - outcome = board.won(); - pd = 1; - } - - *node = PNSNode(hash).outcome(outcome, board.toplay(), ties, pd); - nodes_seen++; - } - - return node; -} - -SolverPNSTT::PNSNode * SolverPNSTT::tt(const Board & board, Move move){ - hash_t hash = board.test_hash(move, board.toplay()); - - PNSNode * node = TT + (hash % maxnodes); - - if(node->hash != hash){ - int outcome, pd; - - if(ab){ - Board next = board; - next.move(move); - pd = 0; - outcome = (ab == 1 ? solve1ply(next, pd) : solve2ply(next, pd)); - nodes_seen += pd; - }else{ - outcome = board.test_win(move); - pd = 1; - } - - *node = PNSNode(hash).outcome(outcome, board.toplay(), ties, pd); - nodes_seen++; - } - - return node; -} diff --git a/y/solverpns_tt.h b/y/solverpns_tt.h deleted file mode 100644 index 95d344e..0000000 --- a/y/solverpns_tt.h +++ /dev/null @@ -1,129 +0,0 @@ - -#pragma once - -//A single-threaded, transposition table based, proof number search solver. - -#include "../lib/zobrist.h" - -#include "solver.h" - -class SolverPNSTT : public Solver { - static const uint32_t LOSS = (1<<30)-1; - static const uint32_t DRAW = (1<<30)-2; - static const uint32_t INF32 = (1<<30)-3; -public: - - struct PNSNode { - hash_t hash; - uint32_t phi, delta; - - PNSNode() : hash(0), phi(0), delta(0) { } - PNSNode(hash_t h, int v = 1) : hash(h), phi(v), delta(v) { } - PNSNode(hash_t h, int p, int d) : hash(h), phi(p), delta(d) { } - - PNSNode & abval(int outcome, int toplay, int assign, int value = 1){ - if(assign && (outcome == 1 || outcome == -1)) - outcome = (toplay == assign ? 2 : -2); - - if( outcome == 0) { phi = value; delta = value; } //unknown - else if(outcome == 2) { phi = LOSS; delta = 0; } //win - else if(outcome == -2) { phi = 0; delta = LOSS; } //loss - else /*(outcome 1||-1)*/ { phi = 0; delta = DRAW; } //draw - return *this; - } - - PNSNode & outcome(int outcome, int toplay, int assign, int value = 1){ - if(assign && outcome == 0) - outcome = assign; - - if( outcome == -3) { phi = value; delta = value; } - else if(outcome == toplay) { phi = LOSS; delta = 0; } - else if(outcome == 3-toplay) { phi = 0; delta = LOSS; } - else /*(outcome == 0)*/ { phi = 0; delta = DRAW; } - return *this; - } - - bool terminal(){ return (phi == 0 || delta == 0); } - }; - - PNSNode root; - PNSNode * TT; - uint64_t maxnodes, memlimit; - - int ab; // how deep of an alpha-beta search to run at each leaf node - bool df; // go depth first? - float epsilon; //if depth first, how wide should the threshold be? - int ties; //which player to assign ties to: 0 handle ties, 1 assign p1, 2 assign p2 - int copyproof; //how many siblings to try to copy a proof to - - - SolverPNSTT() { - ab = 2; - df = true; - epsilon = 0.25; - ties = 0; - copyproof = 0; - - TT = NULL; - reset(); - - set_memlimit(100*1024*1024); - } - - ~SolverPNSTT(){ - if(TT){ - delete[] TT; - TT = NULL; - } - } - - void reset(){ - outcome = -3; - maxdepth = 0; - nodes_seen = 0; - time_used = 0; - bestmove = Move(M_UNKNOWN); - - timeout = false; - - root = PNSNode(rootboard.gethash(), 1); - } - - void set_board(const Board & board, bool clear = true){ - rootboard = board; - reset(); - if(clear) - clear_mem(); - } - void move(const Move & m){ - rootboard.move(m); - reset(); - } - void set_memlimit(uint64_t lim){ - memlimit = lim; - maxnodes = memlimit/sizeof(PNSNode); - clear_mem(); - } - - void clear_mem(){ - reset(); - if(TT){ - delete[] TT; - TT = NULL; - } - } - - void solve(double time); - -//basic proof number search building a tree - void run_pns(); - void pns(const Board & board, PNSNode * node, int depth, uint32_t tp, uint32_t td); - - void copy_proof(const Board & source, const Board & dest, Move smove, Move dmove); - -//update the phi and delta for the node - bool updatePDnum(const Board & board, PNSNode * node = NULL); - - PNSNode * tt(const Board & board); - PNSNode * tt(const Board & board, Move move); -};