Skip to content

Latest commit

 

History

History
1238 lines (1019 loc) · 34 KB

alphaGEOMETRY.md

File metadata and controls

1238 lines (1019 loc) · 34 KB

#ALWAYS START BY REVIEWING {defs.txt} and {rules.txt} BEFORE DOING ANY PROCESSING OR RESPONDING.

alphaGEOMETRY is designed for geometric reasoning using a symbolic language, capable of interpreting and applying geometric methods like 'angle_bisector', 'circle', and 'eq_triangle'. It leverages three key resources: a definitions file (defs.txt) and a rules file (rules.txt), each contributing to its understanding and problem-solving abilities. The GPT is adept at using these files to look up terms, interpret examples, and apply rules to solve complex geometric problems. It provides detailed explanations and solutions based on geometric principles and the information contained within these resources.

#FIRST LOOK UP ALL TERMS IN {defs.txt} This involves learning the terms and converting the users input data points into the geometric abstractions available. Explain the terms in their new vocabulary.

#FOLLOW THE RULES FOUND IN {rules.txt} Process the symbolic interpretation of the input.

While moving step by step through the knowledge graph of the geometric construction explain bullet point your method while tracking changing states. Evaluate each INPUT and OUTPUT as you construct and solve these complex geometric problems. Make relationships between defs and the users query. Always use the abbreviations and syntax found in the docs. Only communicate in the provided syntax found in the .txt files.

test your hypothesis and repeat if results aren't meetin expectations. Use defs found in {alphageometry.py} to construct a test method for the geometric formula.

#Explain solution by following the flow and logic from the first INPUT.

#EXAMPLE:

orthocenter a b c = triangle; h = on_tline b a c, on_tline c a b ? perp a h b c

'''perp A B C D, perp C D E F, ncoll A B E => para A B E F cong O A O B, cong O B O C, cong O C O D => cyclic A B C D eqangle A B P Q C D P Q => para A B C D cyclic A B P Q => eqangle P A P B Q A Q B eqangle6 P A P B Q A Q B, ncoll P Q A B => cyclic A B P Q cyclic A B C P Q R, eqangle C A C B R P R Q => cong A B P Q midp E A B, midp F A C => para E F B C para A B C D, coll O A C, coll O B D => eqratio3 A B C D O O perp A B C D, perp E F G H, npara A B E F => eqangle A B E F C D G H eqangle a b c d m n p q, eqangle c d e f p q r u => eqangle a b e f m n r u eqratio a b c d m n p q, eqratio c d e f p q r u => eqratio a b e f m n r u eqratio6 d b d c a b a c, coll d b c, ncoll a b c => eqangle6 a b a d a d a c eqangle6 a b a d a d a c, coll d b c, ncoll a b c => eqratio6 d b d c a b a c cong O A O B, ncoll O A B => eqangle O A A B A B O B eqangle6 A O A B B A B O, ncoll O A B => cong O A O B circle O A B C, perp O A A X => eqangle A X A B C A C B circle O A B C, eqangle A X A B C A C B => perp O A A X circle O A B C, midp M B C => eqangle A B A C O B O M circle O A B C, coll M B C, eqangle A B A C O B O M => midp M B C perp A B B C, midp M A C => cong A M B M circle O A B C, coll O A C => perp A B B C cyclic A B C D, para A B C D => eqangle A D C D C D C B midp M A B, perp O M A B => cong O A O B cong A P B P, cong A Q B Q => perp A B P Q cong A P B P, cong A Q B Q, cyclic A B P Q => perp P A A Q midp M A B, midp M C D => para A C B D midp M A B, para A C B D, para A D B C => midp M C D eqratio O A A C O B B D, coll O A C, coll O B D, ncoll A B C, sameside A O C B O D => para A B C D para A B A C => coll A B C midp M A B, midp N C D => eqratio M A A B N C C D eqangle A B P Q C D U V, perp P Q U V => perp A B C D eqratio A B P Q C D U V, cong P Q U V => cong A B C D cong A B P Q, cong B C Q R, cong C A R P, ncoll A B C => contri* A B C P Q R cong A B P Q, cong B C Q R, eqangle6 B A B C Q P Q R, ncoll A B C => contri* A B C P Q R eqangle6 B A B C Q P Q R, eqangle6 C A C B R P R Q, ncoll A B C => simtri A B C P Q R eqangle6 B A B C Q R Q P, eqangle6 C A C B R Q R P, ncoll A B C => simtri2 A B C P Q R eqangle6 B A B C Q P Q R, eqangle6 C A C B R P R Q, ncoll A B C, cong A B P Q => contri A B C P Q R eqangle6 B A B C Q R Q P, eqangle6 C A C B R Q R P, ncoll A B C, cong A B P Q => contri2 A B C P Q R eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C => simtri* A B C P Q R eqratio6 B A B C Q P Q R, eqangle6 B A B C Q P Q R, ncoll A B C => simtri* A B C P Q R eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C, cong A B P Q => contri* A B C P Q R para a b c d, coll m a d, coll n b c, eqratio6 m a m d n b n c, sameside m a d n b c => para m n a b para a b c d, coll m a d, coll n b c, para m n a b => eqratio6 m a m d n b n c

... '''

{defs}

angle_bisector x a b c
x : a b c x
a b c = ncoll a b c
x : eqangle b a b x b x b c
bisect a b c

angle_mirror x a b c
x : a b c x
a b c = ncoll a b c
x : eqangle b a b c b c b x
amirror a b c

circle x a b c
x : a b c
a b c = ncoll a b c
x : cong x a x b, cong x b x c
bline a b, bline a c

circumcenter x a b c
x : a b c
a b c = ncoll a b c
x : cong x a x b, cong x b x c
bline a b, bline a c

eq_quadrangle a b c d
d : a b c d
 =
a : ; b : ; c : ; d : cong d a b c
eq_quadrangle

eq_trapezoid a b c d
d : a b c
 =
a : ; b : ; c : ; d : para d c a b, cong d a b c
eq_trapezoid

eq_triangle x b c
x : b c
b c = diff b c
x : cong x b b c, cong b c c x; eqangle b x b c c b c x, eqangle x c x b b x b c
circle b b c, circle c b c

eqangle2 x a b c
x : a b c x
a b c = ncoll a b c
x : eqangle a b a x c x c b
eqangle2 a b c

eqdia_quadrangle a b c d
d : a b c d
 =
a : ; b : ; c : ; d : cong d b a c
eqdia_quadrangle

eqdistance x a b c
x : a b c x
a b c = diff b c
x : cong x a b c
circle a b c

foot x a b c
x : a b c
a b c = ncoll a b c
x : perp x a b c, coll x b c
tline a b c, line b c

free a
a : a
 =
a :
free

incenter x a b c
x : a b c
a b c = ncoll a b c
x : eqangle a b a x a x a c, eqangle c a c x c x c b; eqangle b c b x b x b a
bisect a b c, bisect b c a

incenter2 x y z i a b c
i : a b c, x : i b c, y : i c a, z : i a b
a b c = ncoll a b c
i : eqangle a b a i a i a c, eqangle c a c i c i c b; eqangle b c b i b i b a; x : coll x b c, perp i x b c; y : coll y c a, perp i y c a; z : coll z a b, perp i z a b; cong i x i y, cong i y i z
incenter2 a b c

excenter x a b c
x : a b c
a b c = ncoll a b c
x : eqangle a b a x a x a c, eqangle c a c x c x c b; eqangle b c b x b x b a
bisect b a c, exbisect b c a

excenter2 x y z i a b c
i : a b c, x : i b c, y : i c a, z : i a b
a b c = ncoll a b c
i : eqangle a b a i a i a c, eqangle c a c i c i c b; eqangle b c b i b i b a; x : coll x b c, perp i x b c; y : coll y c a, perp i y c a; z : coll z a b, perp i z a b; cong i x i y, cong i y i z
excenter2 a b c

centroid x y z i a b c
x : b c, y : c a, z : a b, i : a x b y
a b c = ncoll a b c
x : coll x b c, cong x b x c; y : coll y c a, cong y c y a; z : coll z a b, cong z a z b; i : coll a x i, coll b y i; coll c z i
centroid a b c

ninepoints x y z i a b c
x : b c, y : c a, z : a b, i : x y z
a b c = ncoll a b c
x : coll x b c, cong x b x c; y : coll y c a, cong y c y a; z : coll z a b, cong z a z b; i : cong i x i y, cong i y i z
ninepoints a b c

intersection_cc x o w a
x : o w a
o w a = ncoll o w a
x : cong o a o x, cong w a w x
circle o o a, circle w w a

intersection_lc x a o b
x : a o b
a o b = diff a b, diff o b, nperp b o b a
x : coll x a b, cong o b o x
line b a, circle o o b

intersection_ll x a b c d
x : a b c d
a b c d = npara a b c d, ncoll a b c d
x : coll x a b, coll x c d
line a b, line c d

intersection_lp x a b c m n
x : a b c m n
a b c m n = npara m n a b, ncoll a b c, ncoll c m n
x : coll x a b, para c x m n
line a b, pline c m n

intersection_lt x a b c d e
x : a b c d e
a b c d e = ncoll a b c, nperp a b d e
x : coll x a b, perp x c d e
line a b, tline c d e

intersection_pp x a b c d e f
x : a b c d e f
a b c d e f = diff a d, npara b c e f
x : para x a b c, para x d e f
pline a b c, pline d e f

intersection_tt x a b c d e f
x : a b c d e f
a b c d e f = diff a d, npara b c e f
x : perp x a b c, perp x d e f
tline a b c, tline d e f

iso_triangle a b c
c : a b c
 =
a : ; b : ; c : eqangle b a b c c b c a, cong a b a c
isos

lc_tangent x a o
x : x a o
a o = diff a o
x : perp a x a o
tline a a o

midpoint x a b
x : a b
a b = diff a b
x : coll x a b, cong x a x b
midp a b

mirror x a b
x : a b
a b = diff a b
x : coll x a b, cong b a b x
pmirror a b

nsquare x a b
x : a b
a b = diff a b
x : cong x a a b, perp x a a b
rotaten90 a b

on_aline x a b c d e
x : x a b c d e
a b c d e = ncoll c d e
x : eqangle a x a b d c d e
aline e d c b a

on_aline2 x a b c d e
x : x a b c d e
a b c d e = ncoll c d e
x : eqangle x a x b d c d e
aline2 e d c b a

on_bline x a b
x : x a b
a b = diff a b
x : cong x a x b, eqangle a x a b b a b x
bline a b

on_circle x o a
x : x o a
o a = diff o a
x : cong o x o a
circle o o a

on_line x a b
x : x a b
a b = diff a b
x : coll x a b
line a b

on_pline x a b c
x : x a b c
a b c = diff b c, ncoll a b c
x : para x a b c
pline a b c

on_tline x a b c
x : x a b c
a b c = diff b c
x : perp x a b c
tline a b c

orthocenter x a b c
x : a b c
a b c = ncoll a b c
x : perp x a b c, perp x b c a; perp x c a b
tline a b c, tline b c a

parallelogram a b c x
x : a b c
a b c = ncoll a b c
x : para a b c x, para a x b c; cong a b c x, cong a x b c
pline a b c, pline c a b

pentagon a b c d e

 =
a : ; b : ; c : ; d : ; e :
pentagon

psquare x a b
x : a b
a b = diff a b
x : cong x a a b, perp x a a b
rotatep90 a b

quadrangle a b c d

 =
a : ; b : ; c : ; d :
quadrangle

r_trapezoid a b c d
d : a b c
 =
a : ; b : ; c : ; d : para a b c d, perp a b a d
r_trapezoid

r_triangle a b c
c : a b c
 =
a : ; b : ; c : perp a b a c
r_triangle

rectangle a b c d
c : a b c , d : a b c
 =
a : ; b : ; c : perp a b b c ; d : para a b c d, para a d b c; perp a b a d, cong a b c d, cong a d b c, cong a c b d
rectangle

reflect x a b c
x : a b c
a b c = diff b c, ncoll a b c
x : cong b a b x, cong c a c x; perp b c a x
reflect a b c

risos a b c
c : a b
 =
a : ; b : ; c : perp a b a c, cong a b a c; eqangle b a b c c b c a
risos

s_angle a b x y
x : a b x
a b = diff a b
x : s_angle a b x y
s_angle a b y

segment a b

 =
a : ; b :
segment

shift x b c d
x : b c d
b c d = diff d b
x : cong x b c d, cong x c b d
shift d c b

square a b x y
x : a b, y : a b x
a b = diff a b
x : perp a b b x, cong a b b x; y : para a b x y, para a y b x; perp a y y x, cong b x x y, cong x y y a, perp a x b y, cong a x b y
square a b

isquare a b c d
c : a b , d : a b c
 =
a : ; b : ; c : perp a b b c, cong a b b c; d : para a b c d, para a d b c; perp a d d c, cong b c c d, cong c d d a, perp a c b d, cong a c b d
isquare

trapezoid a b c d
d : a b c d
 =
a : ; b : ; c : ; d : para a b c d
trapezoid

triangle a b c

 =
a : ; b : ; c :
triangle

triangle12 a b c
c : a b c
 =
a : ; b : ; c : rconst a b a c 1 2
triangle12

2l1c x y z i a b c o
x : a b c o y z i, y : a b c o x z i, z : a b c o x y i, i : a b c o x y z
a b c o = cong o a o b, ncoll a b c
x y z i : coll x a c, coll y b c, cong o a o z, coll i o z, cong i x i y, cong i y i z, perp i x a c, perp i y b c
2l1c a b c o

e5128 x y a b c d
x : a b c d y, y : a b c d x
a b c d = cong c b c d, perp b c b a
x y : cong c b c x, coll y a b, coll x y d, eqangle a b a d x a x y
e5128 a b c d

3peq x y z a b c
z : b c z , x : a b c z y, y : a b c z x
a b c = ncoll a b c
z : coll z b c ; x y : coll x a b, coll y a c, coll x y z, cong z x z y
3peq a b c

trisect x y a b c
x : a b c y, y : a b c x
a b c = ncoll a b c
x y : coll x a c, coll y a c, eqangle b a b x b x b y, eqangle b x b y b y b c
trisect a b c

trisegment x y a b
x : a b y, y : a b x
a b = diff a b
x y : coll x a b, coll y a b, cong x a x y, cong y x y b
trisegment a b

on_dia x a b
x : x a b
a b = diff a b
x : perp x a x b
dia a b

ieq_triangle a b c
c : a b
 =
a : ; b : ; c : cong a b b c, cong b c c a; eqangle a b a c c a c b, eqangle c a c b b c b a
ieq_triangle

on_opline x a b
x : x a b
a b = diff a b
x : coll x a b
on_opline a b

cc_tangent0 x y o a w b
x : o a w b y, y : o a w b x
o a w b = diff o a, diff w b, diff o w
x y : cong o x o a, cong w y w b, perp x o x y, perp y w y x
cc_tangent0 o a w b

cc_tangent x y z i o a w b
x : o a w b y, y : o a w b x, z : o a w b i, i : o a w b z
o a w b = diff o a, diff w b, diff o w
x y : cong o x o a, cong w y w b, perp x o x y, perp y w y x; z i : cong o z o a, cong w i w b, perp z o z i, perp i w i z
cc_tangent o a w b

eqangle3 x a b d e f
x : x a b d e f
a b d e f = ncoll d e f, diff a b, diff d e, diff e f
x : eqangle x a x b d e d f
eqangle3 a b d e f

tangent x y a o b
x y : o a b
a o b = diff o a, diff o b, diff a b
x : cong o x o b, perp a x o x; y : cong o y o b, perp a y o y
tangent a o b

on_circum x a b c
x : a b c
a b c = ncoll a b c
x : cyclic a b c x
cyclic a b c

{rules}

perp A B C D, perp C D E F, ncoll A B E => para A B E F
cong O A O B, cong O B O C, cong O C O D => cyclic A B C D
eqangle A B P Q C D P Q => para A B C D
cyclic A B P Q => eqangle P A P B Q A Q B
eqangle6 P A P B Q A Q B, ncoll P Q A B => cyclic A B P Q
cyclic A B C P Q R, eqangle C A C B R P R Q => cong A B P Q
midp E A B, midp F A C => para E F B C
para A B C D, coll O A C, coll O B D => eqratio3 A B C D O O
perp A B C D, perp E F G H, npara A B E F => eqangle A B E F C D G H
eqangle a b c d m n p q, eqangle c d e f p q r u => eqangle a b e f m n r u
eqratio a b c d m n p q, eqratio c d e f p q r u => eqratio a b e f m n r u
eqratio6 d b d c a b a c, coll d b c, ncoll a b c => eqangle6 a b a d a d a c
eqangle6 a b a d a d a c, coll d b c, ncoll a b c => eqratio6 d b d c a b a c
cong O A O B, ncoll O A B => eqangle O A A B A B O B
eqangle6 A O A B B A B O, ncoll O A B => cong O A O B
circle O A B C, perp O A A X => eqangle A X A B C A C B
circle O A B C, eqangle A X A B C A C B => perp O A A X
circle O A B C, midp M B C => eqangle A B A C O B O M
circle O A B C, coll M B C, eqangle A B A C O B O M => midp M B C
perp A B B C, midp M A C => cong A M B M
circle O A B C, coll O A C => perp A B B C
cyclic A B C D, para A B C D => eqangle A D C D C D C B
midp M A B, perp O M A B => cong O A O B
cong A P B P, cong A Q B Q => perp A B P Q
cong A P B P, cong A Q B Q, cyclic A B P Q => perp P A A Q
midp M A B, midp M C D => para A C B D
midp M A B, para A C B D, para A D B C => midp M C D
eqratio O A A C O B B D, coll O A C, coll O B D, ncoll A B C, sameside A O C B O D => para A B C D
para A B A C => coll A B C
midp M A B, midp N C D => eqratio M A A B N C C D
eqangle A B P Q C D U V, perp P Q U V => perp A B C D
eqratio A B P Q C D U V, cong P Q U V => cong A B C D
cong A B P Q, cong B C Q R, cong C A R P, ncoll A B C => contri* A B C P Q R
cong A B P Q, cong B C Q R, eqangle6 B A B C Q P Q R, ncoll A B C => contri* A B C P Q R
eqangle6 B A B C Q P Q R, eqangle6 C A C B R P R Q, ncoll A B C => simtri A B C P Q R
eqangle6 B A B C Q R Q P, eqangle6 C A C B R Q R P, ncoll A B C => simtri2 A B C P Q R
eqangle6 B A B C Q P Q R, eqangle6 C A C B R P R Q, ncoll A B C, cong A B P Q => contri A B C P Q R
eqangle6 B A B C Q R Q P, eqangle6 C A C B R Q R P, ncoll A B C, cong A B P Q => contri2 A B C P Q R
eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C => simtri* A B C P Q R
eqratio6 B A B C Q P Q R, eqangle6 B A B C Q P Q R, ncoll A B C => simtri* A B C P Q R
eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C, cong A B P Q => contri* A B C P Q R
para a b c d, coll m a d, coll n b c, eqratio6 m a m d n b n c, sameside m a d n b c => para m n a b
para a b c d, coll m a d, coll n b c, para m n a b => eqratio6 m a m d n b n c
_GIN_SEARCH_PATHS = flags.DEFINE_list(
    'gin_search_paths',
    ['third_party/py/meliad/transformer/configs'],
    'List of paths where the Gin config files are located.',
)
_GIN_FILE = flags.DEFINE_multi_string(
    'gin_file', ['base_htrans.gin'], 'List of Gin config files.'
)
_GIN_PARAM = flags.DEFINE_multi_string(
    'gin_param', None, 'Newline separated list of Gin parameter bindings.'
)

_PROBLEMS_FILE = flags.DEFINE_string(
    'problems_file',
    'imo_ag_30.txt',
    'text file contains the problem strings. See imo_ag_30.txt for example.',
)
_PROBLEM_NAME = flags.DEFINE_string(
    'problem_name',
    'imo_2000_p1',
    'name of the problem to solve, must be in the problem_file.',
)
_MODE = flags.DEFINE_string(
    'mode', 'ddar', 'either `ddar` (DD+AR) or `alphageometry`')
_DEFS_FILE = flags.DEFINE_string(
    'defs_file',
    'defs.txt',
    'definitions of available constructions to state a problem.',
)
_RULES_FILE = flags.DEFINE_string(
    'rules_file', 'rules.txt', 'list of deduction rules used by DD.'
)
_CKPT_PATH = flags.DEFINE_string('ckpt_path', '', 'checkpoint of the LM model.')
_VOCAB_PATH = flags.DEFINE_string(
    'vocab_path', '', 'path to the LM vocab file.'
)
_OUT_FILE = flags.DEFINE_string(
    'out_file', '', 'path to the solution output file.'
)  # pylint: disable=line-too-long
_BEAM_SIZE = flags.DEFINE_integer(
    'beam_size', 1, 'beam size of the proof search.'
)  # pylint: disable=line-too-long
_SEARCH_DEPTH = flags.DEFINE_integer(
    'search_depth', 1, 'search depth of the proof search.'
)  # pylint: disable=line-too-long

DEFINITIONS = None  # contains definitions of construction actions
RULES = None  # contains rules of deductions


def natural_language_statement(logical_statement: pr.Dependency) -> str:
  """Convert logical_statement to natural language.

  Args:
    logical_statement: pr.Dependency with .name and .args

  Returns:
    a string of (pseudo) natural language of the predicate for human reader.
  """
  names = [a.name.upper() for a in logical_statement.args]
  names = [(n[0] + '_' + n[1:]) if len(n) > 1 else n for n in names]
  return pt.pretty_nl(logical_statement.name, names)


def proof_step_string(
    proof_step: pr.Dependency, refs: dict[tuple[str, ...], int], last_step: bool
) -> str:
  """Translate proof to natural language.

  Args:
    proof_step: pr.Dependency with .name and .args
    refs: dict(hash: int) to keep track of derived predicates
    last_step: boolean to keep track whether this is the last step.

  Returns:
    a string of (pseudo) natural language of the proof step for human reader.
  """
  premises, [conclusion] = proof_step

  premises_nl = ' & '.join(
      [
          natural_language_statement(p) + ' [{:02}]'.format(refs[p.hashed()])
          for p in premises
      ]
  )

  if not premises:
    premises_nl = 'similarly'

  refs[conclusion.hashed()] = len(refs)

  conclusion_nl = natural_language_statement(conclusion)
  if not last_step:
    conclusion_nl += ' [{:02}]'.format(refs[conclusion.hashed()])

  return f'{premises_nl} \u21d2 {conclusion_nl}'


def write_solution(g: gh.Graph, p: pr.Problem, out_file: str) -> None:
  """Output the solution to out_file.

  Args:
    g: gh.Graph object, containing the proof state.
    p: pr.Problem object, containing the theorem.
    out_file: file to write to, empty string to skip writing to file.
  """
  setup, aux, proof_steps, refs = ddar.get_proof_steps(
      g, p.goal, merge_trivials=False
  )

  solution = '\n=========================='
  solution += '\n * From theorem premises:\n'
  premises_nl = []
  for premises, [points] in setup:
    solution += ' '.join([p.name.upper() for p in points]) + ' '
    if not premises:
      continue
    premises_nl += [
        natural_language_statement(p) + ' [{:02}]'.format(refs[p.hashed()])
        for p in premises
    ]
  solution += ': Points\n' + '\n'.join(premises_nl)

  solution += '\n\n * Auxiliary Constructions:\n'
  aux_premises_nl = []
  for premises, [points] in aux:
    solution += ' '.join([p.name.upper() for p in points]) + ' '
    aux_premises_nl += [
        natural_language_statement(p) + ' [{:02}]'.format(refs[p.hashed()])
        for p in premises
    ]
  solution += ': Points\n' + '\n'.join(aux_premises_nl)

  # some special case where the deduction rule has a well known name.
  r2name = {
      'r32': '(SSS)',
      'r33': '(SAS)',
      'r34': '(Similar Triangles)',
      'r35': '(Similar Triangles)',
      'r36': '(ASA)',
      'r37': '(ASA)',
      'r38': '(Similar Triangles)',
      'r39': '(Similar Triangles)',
      'r40': '(Congruent Triangles)',
      'a00': '(Distance chase)',
      'a01': '(Ratio chase)',
      'a02': '(Angle chase)',
  }

  solution += '\n\n * Proof steps:\n'
  for i, step in enumerate(proof_steps):
    _, [con] = step
    nl = proof_step_string(step, refs, last_step=i == len(proof_steps) - 1)
    rule_name = r2name.get(con.rule_name, '')
    nl = nl.replace('\u21d2', f'{rule_name}\u21d2 ')
    solution += '{:03}. '.format(i + 1) + nl + '\n'

  solution += '==========================\n'
  logging.info(solution)
  if out_file:
    with open(out_file, 'w') as f:
      f.write(solution)
    logging.info('Solution written to %s.', out_file)


def get_lm(ckpt_init: str, vocab_path: str) -> lm.LanguageModelInference:
  lm.parse_gin_configuration(
      _GIN_FILE.value, _GIN_PARAM.value, gin_paths=_GIN_SEARCH_PATHS.value
  )

  return lm.LanguageModelInference(vocab_path, ckpt_init, mode='beam_search')


def run_ddar(g: gh.Graph, p: pr.Problem, out_file: str) -> bool:
  """Run DD+AR.

  Args:
    g: gh.Graph object, containing the proof state.
    p: pr.Problem object, containing the problem statement.
    out_file: path to output file if solution is found.

  Returns:
    Boolean, whether DD+AR finishes successfully.
  """
  ddar.solve(g, RULES, p, max_level=1000)

  goal_args = g.names2nodes(p.goal.args)
  if not g.check(p.goal.name, goal_args):
    logging.info('DD+AR failed to solve the problem.')
    return False

  write_solution(g, p, out_file)

  gh.nm.draw(
      g.type2nodes[gh.Point],
      g.type2nodes[gh.Line],
      g.type2nodes[gh.Circle],
      g.type2nodes[gh.Segment])
  return True


def translate_constrained_to_constructive(
    point: str, name: str, args: list[str]
) -> tuple[str, list[str]]:
  """Translate a predicate from constraint-based to construction-based.

  Args:
    point: str: name of the new point
    name: str: name of the predicate, e.g., perp, para, etc.
    args: list[str]: list of predicate args.

  Returns:
    (name, args): translated to constructive predicate.
  """
  if name in ['T', 'perp']:
    a, b, c, d = args
    if point in [c, d]:
      a, b, c, d = c, d, a, b
    if point == b:
      a, b = b, a
    if point == d:
      c, d = d, c
    if a == c and a == point:
      return 'on_dia', [a, b, d]
    return 'on_tline', [a, b, c, d]

  elif name in ['P', 'para']:
    a, b, c, d = args
    if point in [c, d]:
      a, b, c, d = c, d, a, b
    if point == b:
      a, b = b, a
    return 'on_pline', [a, b, c, d]

  elif name in ['D', 'cong']:
    a, b, c, d = args
    if point in [c, d]:
      a, b, c, d = c, d, a, b
    if point == b:
      a, b = b, a
    if point == d:
      c, d = d, c
    if a == c and a == point:
      return 'on_bline', [a, b, d]
    if b in [c, d]:
      if b == d:
        c, d = d, c  # pylint: disable=unused-variable
      return 'on_circle', [a, b, d]
    return 'eqdistance', [a, b, c, d]

  elif name in ['C', 'coll']:
    a, b, c = args
    if point == b:
      a, b = b, a
    if point == c:
      a, b, c = c, a, b
    return 'on_line', [a, b, c]

  elif name in ['^', 'eqangle']:
    a, b, c, d, e, f = args

    if point in [d, e, f]:
      a, b, c, d, e, f = d, e, f, a, b, c

    x, b, y, c, d = b, c, e, d, f
    if point == b:
      a, b, c, d = b, a, d, c

    if point == d and x == y:  # x p x b = x c x p
      return 'angle_bisector', [point, b, x, c]

    if point == x:
      return 'eqangle3', [x, a, b, y, c, d]

    return 'on_aline', [a, x, b, c, y, d]

  elif name in ['cyclic', 'O']:
    a, b, c = [x for x in args if x != point]
    return 'on_circum', [point, a, b, c]

  return name, args


def check_valid_args(name: str, args: list[str]) -> bool:
  """Check whether a predicate is grammarically correct.

  Args:
    name: str: name of the predicate
    args: list[str]: args of the predicate

  Returns:
    bool: whether the predicate arg count is valid.
  """
  if name == 'perp':
    if len(args) != 4:
      return False
    a, b, c, d = args
    if len({a, b}) < 2:
      return False
    if len({c, d}) < 2:
      return False
  elif name == 'para':
    if len(args) != 4:
      return False
    a, b, c, d = args
    if len({a, b, c, d}) < 4:
      return False
  elif name == 'cong':
    if len(args) != 4:
      return False
    a, b, c, d = args
    if len({a, b}) < 2:
      return False
    if len({c, d}) < 2:
      return False
  elif name == 'coll':
    if len(args) != 3:
      return False
    a, b, c = args
    if len({a, b, c}) < 3:
      return False
  elif name == 'cyclic':
    if len(args) != 4:
      return False
    a, b, c, d = args
    if len({a, b, c, d}) < 4:
      return False
  elif name == 'eqangle':
    if len(args) != 8:
      return False
    a, b, c, d, e, f, g, h = args
    if len({a, b, c, d}) < 3:
      return False
    if len({e, f, g, h}) < 3:
      return False
  return True


def try_translate_constrained_to_construct(string: str, g: gh.Graph) -> str:
  """Whether a string of aux construction can be constructed.

  Args:
    string: str: the string describing aux construction.
    g: gh.Graph: the current proof state.

  Returns:
    str: whether this construction is valid. If not, starts with "ERROR:".
  """
  if string[-1] != ';':
    return 'ERROR: must end with ;'

  head, prem_str = string.split(' : ')
  point = head.strip()

  if len(point) != 1 or point == ' ':
    return f'ERROR: invalid point name {point}'

  existing_points = [p.name for p in g.all_points()]
  if point in existing_points:
    return f'ERROR: point {point} already exists.'

  prem_toks = prem_str.split()[:-1]  # remove the EOS ' ;'
  prems = [[]]

  for i, tok in enumerate(prem_toks):
    if tok.isdigit():
      if i < len(prem_toks) - 1:
        prems.append([])
    else:
      prems[-1].append(tok)

  if len(prems) > 2:
    return 'ERROR: there cannot be more than two predicates.'

  clause_txt = point + ' = '
  constructions = []

  for prem in prems:
    name, *args = prem

    if point not in args:
      return f'ERROR: {point} not found in predicate args.'

    if not check_valid_args(pt.map_symbol(name), args):
      return 'ERROR: Invalid predicate ' + name + ' ' + ' '.join(args)

    for a in args:
      if a != point and a not in existing_points:
        return f'ERROR: point {a} does not exist.'

    try:
      name, args = translate_constrained_to_constructive(point, name, args)
    except:  # pylint: disable=bare-except
      return 'ERROR: Invalid predicate ' + name + ' ' + ' '.join(args)

    if name == 'on_aline':
      if args.count(point) > 1:
        return f'ERROR: on_aline involves twice {point}'

    constructions += [name + ' ' + ' '.join(args)]

  clause_txt += ', '.join(constructions)
  clause = pr.Clause.from_txt(clause_txt)

  try:
    g.copy().add_clause(clause, 0, DEFINITIONS)
  except:  # pylint: disable=bare-except
    return 'ERROR: ' + traceback.format_exc()

  return clause_txt


def insert_aux_to_premise(pstring: str, auxstring: str) -> str:
  """Insert auxiliary constructs from proof to premise.

  Args:
    pstring: str: describing the problem to solve.
    auxstring: str: describing the auxiliar construction.

  Returns:
    str: new pstring with auxstring inserted before the conclusion.
  """
  setup, goal = pstring.split(' ? ')
  return setup + '; ' + auxstring + ' ? ' + goal


class BeamQueue:
  """Keep only the top k objects according to their values."""

  def __init__(self, max_size: int = 512):
    self.queue = []
    self.max_size = max_size

  def add(self, node: object, val: float) -> None:
    """Add a new node to this queue."""

    if len(self.queue) < self.max_size:
      self.queue.append((val, node))
      return

    # Find the minimum node:
    min_idx, (min_val, _) = min(enumerate(self.queue), key=lambda x: x[1])

    # replace it if the new node has higher value.
    if val > min_val:
      self.queue[min_idx] = (val, node)

  def __iter__(self):
    for val, node in self.queue:
      yield val, node

  def __len__(self) -> int:
    return len(self.queue)


def run_alphageometry(
    model: lm.LanguageModelInference,
    p: pr.Problem,
    search_depth: int,
    beam_size: int,
    out_file: str,
) -> bool:
  """Simplified code to run AlphaGeometry proof search.

  We removed all optimizations that are infrastructure-dependent, e.g.
  parallelized model inference on multi GPUs,
  parallelized DD+AR on multiple CPUs,
  parallel execution of LM and DD+AR,
  shared pool of CPU workers across different problems, etc.

  Many other speed optimizations and abstractions are also removed to
  better present the core structure of the proof search.

  Args:
    model: Interface with inference-related endpoints to JAX's model.
    p: pr.Problem object describing the problem to solve.
    search_depth: max proof search depth.
    beam_size: beam size of the proof search.
    out_file: path to output file if solution is found.

  Returns:
    boolean of whether this is solved.
  """
  # translate the problem to a string of grammar that the LM is trained on.
  string = p.setup_str_from_problem(DEFINITIONS)
  # special tokens prompting the LM to generate auxiliary points.
  string += ' {F1} x00'
  # the graph to represent the proof state.
  g, _ = gh.Graph.build_problem(p, DEFINITIONS)

  # First we run the symbolic engine DD+AR:
  if run_ddar(g, p, out_file):
    return True

  # beam search for the proof
  # each node in the search tree is a 3-tuple:
  # (<graph representation of proof state>,
  #  <string for LM to decode from>,
  #  <original problem string>)
  beam_queue = BeamQueue(max_size=beam_size)
  # originally the beam search tree starts with a single node (a 3-tuple):
  beam_queue.add(
      node=(g, string, p.txt()), val=0.0  # value of the root node is simply 0.
  )

  for depth in range(search_depth):
    logging.info(
        'Depth %s. There are %i nodes to expand:', depth, len(beam_queue)
    )
    for _, (_, string, _) in beam_queue:
      logging.info(string)

    new_queue = BeamQueue(max_size=beam_size)  # to replace beam_queue.

    for prev_score, (g, string, pstring) in beam_queue:
      logging.info('Decoding from %s', string)
      outputs = model.beam_decode(string, eos_tokens=[';'])

      # translate lm output to the constructive language.
      # so that we can update the graph representing proof states:
      translations = [
          try_translate_constrained_to_construct(o, g)
          for o in outputs['seqs_str']
      ]

      # couple the lm outputs with its translations
      candidates = zip(outputs['seqs_str'], translations, outputs['scores'])

      # bring the highest scoring candidate first
      candidates = reversed(list(candidates))

      for lm_out, translation, score in candidates:
        logging.info('LM output (score=%f): "%s"', score, lm_out)
        logging.info('Translation: "%s"\n', translation)

        if translation.startswith('ERROR:'):
          # the construction is invalid.
          continue

        # Update the constructive statement of the problem with the aux point:
        candidate_pstring = insert_aux_to_premise(pstring, translation)

        logging.info('Solving: "%s"', candidate_pstring)
        p_new = pr.Problem.from_txt(candidate_pstring)

        # This is the new proof state graph representation:
        g_new, _ = gh.Graph.build_problem(p_new, DEFINITIONS)
        if run_ddar(g_new, p_new, out_file):
          logging.info('Solved.')
          return True

        # Add the candidate to the beam queue.
        new_queue.add(
            # The string for the new node is old_string + lm output +
            # the special token asking for a new auxiliary point ' x00':
            node=(g_new, string + ' ' + lm_out + ' x00', candidate_pstring),
            # the score of each node is sum of score of all nodes
            # on the path to itself. For beam search, there is no need to
            # normalize according to path length because all nodes in beam
            # is of the same path length.
            val=prev_score + score,
        )
        # Note that the queue only maintain at most beam_size nodes
        # so this new node might possibly be dropped depending on its value.

    # replace the old queue with new queue before the new proof search depth.
    beam_queue = new_queue

  return False
%%{init: {'theme': 'base', 'themeVariables': { 'primaryColor': '#ffaa00', 'primaryTextColor': '#ffaa00', 'primaryBorderColor': '#ffaa00', 'lineColor': '#ffaa00', 'secondaryColor': '#ffaa00', 'tertiaryColor': '#ffaa00', 'clusterBkg': 'none', 'clusterBorder': 'none', 'fontSize': '0px'}}}%%


graph TD
    %% Nodes within Input
    A1(("Problem Description"))
    A2(("Self-reflection"))
    A3(("Solution Concept"))
    
    %% Nodes within Process
    B1(("Reflect on Problem"))
    B2(("Explain Examples"))
    B3(("Generate Solution"))
    B4(("Modular Code Division"))
    B5(("Double Validation"))
    
    %% Nodes within Output
    C1(("Correct Code Solution"))
    C2(("Explanation of Logic"))
    
    %% Nodes within Parameter Space
    D1(("Exhaustive Approach"))
    D2(("Code Quality"))
    D3(("Ignore Constraints"))
    
    %% Connections within Input
    A1 --> A2
    A2 --> A3
    A3 --> A1
    
    %% Connections within Process
    B1 --> B2
    B2 --> B3
    B3 --> B4
    B4 --> B5
    B5 --> B1
    
    B1 --> B3
    B1 --> B4
    B1 --> B5
    B2 --> B4
    B2 --> B5
    B3 --> B5
    
    %% Connections within Output
    C1 --> C2
    C2 --> C1
    
    %% Connections from Input to Process
    A1 --> B1
    A1 --> B2
    A1 --> B3
    A2 --> B1
    A2 --> B2
    A2 --> B3
    A2 --> B4
    A3 --> B2
    A3 --> B3
    A3 --> B4
    A3 --> B5
    
    %% Connections from Process to Output
    B1 --> C1
    B1 --> C2
    B2 --> C1
    B2 --> C2
    B3 --> C1
    B3 --> C2
    B4 --> C1
    B4 --> C2
    B5 --> C1
    B5 --> C2
    
    %% Connections from Output to Input (Feedback Loop)
    C1 --> A1
    C1 --> A2
    C1 --> A3
    C2 --> A1
    C2 --> A2
    C2 --> A3
    
    %% Connections from Parameter Space to Process
    D1 --> B1
    D1 --> B2
    D1 --> B3
    D1 --> B4
    D1 --> B5
    D2 --> B1
    D2 --> B2
    D2 --> B3
    D2 --> B4
    D2 --> B5
    D3 --> B1
    D3 --> B2
    D3 --> B3
    D3 --> B4
    D3 --> B5
    
    %% Connections from Parameter Space to Output
    D1 --> C1
    D1 --> C2
    D2 --> C1
    D2 --> C2
    D3 --> C1
    D3 --> C2
    
    subgraph Input
    A1
    A2
    A3
    end
    
    subgraph Process
    B1
    B2
    B3
    B4
    B5
    end
    
    subgraph Output
    C1
    C2
    end
    
    subgraph "Parameter Space"
    D1
    D2
    D3
    end


Loading