Skip to content

Commit

Permalink
Added a non-self consistent predictor matching the return signature o…
Browse files Browse the repository at this point in the history
…f the SCF predictors
  • Loading branch information
jackbaker1001 committed Sep 13, 2023
1 parent e69fd29 commit 74e4aac
Showing 1 changed file with 61 additions and 9 deletions.
70 changes: 61 additions & 9 deletions grad_dft/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,50 @@ def kernel(

return kernel

######## Non self-consistent iterator ################

def make_non_scf_predictor(
functional: Functional,
**kwargs,
) -> Callable:
r"""
Creates an non_scf_predictor function which when called non-self consistently
calculates the total energy at a fixed density.
Main parameters
---------------
functional: Functional
Returns
---------
Callable
"""

def non_scf_predictor(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Array, Array]:
r"""Calculates the total energy at a fixed density non-self consistently.
Main parameters
---------------
params: Pytree
Parameters of the neural functional
molecule: Molecule
A Grad-DFT molecule object
Returns
---------
tuple[Scalar, Array, Array]
The predicted energy only. Elements 1 and 2 of the tuple are
set to None as the density and rdm1 do not change when a non-self consistent
calculation is performed.
"""
predicted_e = functional.energy(params, molecule, *args, **kwargs)
predicted_rho = None
predicted_rdm1 = None
return (predicted_e, predicted_rho, predicted_rdm1)

return non_scf_predictor

# Add Harris-Foulkes predictor here too!

######## Test scf loop and orbital optimizers ########

Expand Down Expand Up @@ -88,7 +132,7 @@ def make_simple_scf_loop(

predict_molecule = molecule_predictor(functional, chunk_size=chunk_size, **kwargs)

def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Scalar]:
def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Array, Array]:
r"""
Implements a scf loop for a Molecule and a functional implicitly defined predict_molecule with
parameters params
Expand Down Expand Up @@ -169,13 +213,15 @@ def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Sca
f" relative energy difference: {abs((predicted_e - old_e)/predicted_e):.5e}"
)
old_e = predicted_e
predicted_rho = molecule.density()
predicted_rdm1 = molecule.rdm1

if verbose > 1:
print(
f"cycle: {cycle}, predicted energy: {predicted_e:.7e}, energy difference: {abs(predicted_e - old_e):.4e}, norm_gradient_orbitals: {norm_gorb:.2e}"
)

return predicted_e
return (predicted_e, predicted_rho, predicted_rdm1)

return scf_iterator

Expand All @@ -198,7 +244,7 @@ def make_jitted_simple_scf_loop(functional: Functional, cycles: int = 25, mixing
predict_molecule = molecule_predictor(functional, chunk_size=None, **kwargs)

@jit
def scf_jitted_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Scalar]:
def scf_jitted_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Array, Array]:
r"""
Implements a scf loop intented for use in a jax.jit compiled function (training loop).
If you are looking for a more flexible but not differentiable scf loop, see evaluate.py make_scf_loop.
Expand Down Expand Up @@ -263,8 +309,10 @@ def loop_body(cycle, state):
# Compute the scf loop
final_state = fori_loop(0, cycles, body_fun=loop_body, init_val=state)
molecule, fock, predicted_e, old_e, norm_gorb = final_state
predicted_rho = molecule.density()
predicted_rdm1 = molecule.rdm1

return predicted_e, fock, molecule.rdm1
return (predicted_e, predicted_rho, predicted_rdm1)

return scf_jitted_iterator

Expand Down Expand Up @@ -301,7 +349,7 @@ def make_scf_loop(

predict_molecule = molecule_predictor(functional, chunk_size=chunk_size, **kwargs)

def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Scalar]:
def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Array, Array]:
r"""
Implements a scf loop for a Molecule and a functional implicitly defined predict_molecule with
parameters params
Expand Down Expand Up @@ -496,8 +544,10 @@ def nelec_cost_fn(m, mo_es, sigma, _nelectron):
print(
f"cycle: {cycle}, predicted energy: {predicted_e:.7e}, energy difference: {abs(predicted_e - old_e):.4e}, norm_gradient_orbitals: {norm_gorb:.2e}"
)
predicted_rho = molecule.density()
predicted_rdm1 = molecule.rdm1

return predicted_e
return (predicted_e, predicted_rho, predicted_rdm1)

return scf_iterator

Expand Down Expand Up @@ -705,7 +755,7 @@ def make_jitted_scf_loop(functional: Functional, cycles: int = 25, **kwargs) ->
predict_molecule = molecule_predictor(functional, chunk_size=None, **kwargs)

@jit
def scf_jitted_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Scalar]:
def scf_jitted_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Array, Array]:
r"""
Implements a scf loop intented for use in a jax.jit compiled function (training loop).
If you are looking for a more flexible but not differentiable scf loop, see evaluate.py make_scf_loop.
Expand Down Expand Up @@ -788,9 +838,11 @@ def loop_body(cycle, state):
state = (molecule, fock, predicted_e, old_e, norm_gorb, diis_data)
state = loop_body(0, state)
molecule, fock, predicted_e, _, _, _ = final_state

predicted_rho = molecule.density()
predicted_rdm1 = molecule.rdm1


return predicted_e, fock, molecule.rdm1
return (predicted_e, predicted_rho, predicted_rdm1)

return scf_jitted_iterator

Expand Down

0 comments on commit 74e4aac

Please sign in to comment.