Skip to content

Commit

Permalink
predictors now return updated molecule objects. Tests updated to refl…
Browse files Browse the repository at this point in the history
…ect this
  • Loading branch information
jackbaker1001 committed Sep 14, 2023
1 parent 74e4aac commit 8d574e8
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 38 deletions.
56 changes: 28 additions & 28 deletions grad_dft/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def kernel(

def make_non_scf_predictor(
functional: Functional,
chunk_size: int = 1024,
**kwargs,
) -> Callable:
r"""
Expand All @@ -75,8 +76,8 @@ def make_non_scf_predictor(
---------
Callable
"""
def non_scf_predictor(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Array, Array]:
predict_molecule = molecule_predictor(functional, chunk_size=chunk_size, **kwargs)
def non_scf_predictor(params: PyTree, molecule: Molecule, *args) -> Molecule:
r"""Calculates the total energy at a fixed density non-self consistently.
Main parameters
Expand All @@ -88,15 +89,13 @@ def non_scf_predictor(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar
Returns
---------
tuple[Scalar, Array, Array]
The predicted energy only. Elements 1 and 2 of the tuple are
set to None as the density and rdm1 do not change when a non-self consistent
calculation is performed.
Molecule
A Grad-DFT Molecule object with updated attributes
"""
predicted_e = functional.energy(params, molecule, *args, **kwargs)
predicted_rho = None
predicted_rdm1 = None
return (predicted_e, predicted_rho, predicted_rdm1)
predicted_e, fock = predict_molecule(params, molecule, *args)
molecule = molecule.replace(fock=fock)
molecule = molecule.replace(energy=predicted_e)
return molecule

return non_scf_predictor

Expand Down Expand Up @@ -132,7 +131,7 @@ def make_simple_scf_loop(

predict_molecule = molecule_predictor(functional, chunk_size=chunk_size, **kwargs)

def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Array, Array]:
def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Molecule:
r"""
Implements a scf loop for a Molecule and a functional implicitly defined predict_molecule with
parameters params
Expand Down Expand Up @@ -167,6 +166,7 @@ def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Arr

# Update the molecular occupation
mo_occ = molecule.get_occ()
molecule = molecule.replace(mo_occ=mo_occ)
if verbose > 2:
print(
f"Cycle {cycle} took {time.time() - start_time:.1e} seconds to compute and diagonalize Fock matrix"
Expand Down Expand Up @@ -213,15 +213,15 @@ def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Arr
f" relative energy difference: {abs((predicted_e - old_e)/predicted_e):.5e}"
)
old_e = predicted_e
predicted_rho = molecule.density()
predicted_rdm1 = molecule.rdm1

if verbose > 1:
print(
f"cycle: {cycle}, predicted energy: {predicted_e:.7e}, energy difference: {abs(predicted_e - old_e):.4e}, norm_gradient_orbitals: {norm_gorb:.2e}"
)

return (predicted_e, predicted_rho, predicted_rdm1)
# Ensure molecule is fully updated
molecule = molecule.replace(fock=fock)
molecule = molecule.replace(energy=predicted_e)
return molecule

return scf_iterator

Expand All @@ -244,7 +244,7 @@ def make_jitted_simple_scf_loop(functional: Functional, cycles: int = 25, mixing
predict_molecule = molecule_predictor(functional, chunk_size=None, **kwargs)

@jit
def scf_jitted_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Array, Array]:
def scf_jitted_iterator(params: PyTree, molecule: Molecule, *args) -> Molecule:
r"""
Implements a scf loop intented for use in a jax.jit compiled function (training loop).
If you are looking for a more flexible but not differentiable scf loop, see evaluate.py make_scf_loop.
Expand Down Expand Up @@ -309,10 +309,9 @@ def loop_body(cycle, state):
# Compute the scf loop
final_state = fori_loop(0, cycles, body_fun=loop_body, init_val=state)
molecule, fock, predicted_e, old_e, norm_gorb = final_state
predicted_rho = molecule.density()
predicted_rdm1 = molecule.rdm1

return (predicted_e, predicted_rho, predicted_rdm1)
molecule = molecule.replace(fock=fock)
molecule = molecule.replace(energy=predicted_e)
return Molecule

return scf_jitted_iterator

Expand Down Expand Up @@ -349,7 +348,7 @@ def make_scf_loop(

predict_molecule = molecule_predictor(functional, chunk_size=chunk_size, **kwargs)

def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Array, Array]:
def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Molecule:
r"""
Implements a scf loop for a Molecule and a functional implicitly defined predict_molecule with
parameters params
Expand Down Expand Up @@ -544,10 +543,10 @@ def nelec_cost_fn(m, mo_es, sigma, _nelectron):
print(
f"cycle: {cycle}, predicted energy: {predicted_e:.7e}, energy difference: {abs(predicted_e - old_e):.4e}, norm_gradient_orbitals: {norm_gorb:.2e}"
)
predicted_rho = molecule.density()
predicted_rdm1 = molecule.rdm1

return (predicted_e, predicted_rho, predicted_rdm1)
# Ensure molecule is fully updated
molecule = molecule.replace(fock=fock)
molecule = molecule.replace(energy=predicted_e)
return molecule

return scf_iterator

Expand Down Expand Up @@ -839,10 +838,11 @@ def loop_body(cycle, state):
state = loop_body(0, state)
molecule, fock, predicted_e, _, _, _ = final_state

predicted_rho = molecule.density()
predicted_rdm1 = molecule.rdm1
# Ensure molecule is fully updated
molecule = molecule.replace(fock=fock)
molecule = molecule.replace(energy=predicted_e)

return (predicted_e, predicted_rho, predicted_rdm1)
return molecule

return scf_jitted_iterator

Expand Down
10 changes: 7 additions & 3 deletions grad_dft/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,8 @@ def mse_energy_loss(
"""
sum = 0
for i, molecule in enumerate(molecules):
E_predict, _, _ = molecule_predictor(params, molecule)
molecule_out = molecule_predictor(params, molecule)
E_predict = molecule_out.energy
diff = E_predict - truth_energies[i]
# Not jittable because of if.
if elec_num_norm:
Expand Down Expand Up @@ -433,7 +434,8 @@ def mse_density_loss(
"""
sum = 0
for i, molecule in enumerate(molecules):
_, rho_predict, _ = molecule_predictor(params, molecule)
molecule_out = molecule_predictor(params, molecule)
rho_predict = molecule_out.density()
diff = sq_electron_err_int(rho_predict, truth_rhos[i], molecule)
# Not jittable because of if.
if elec_num_norm:
Expand Down Expand Up @@ -475,7 +477,9 @@ def mse_energy_and_density_loss(
sum_energy = 0
sum_rho = 0
for i, molecule in enumerate(molecules):
energy_predict, rho_predict, _ = molecule_predictor(params, molecule)
molecule_out = molecule_predictor(params, molecule)
rho_predict = molecule_out.density()
energy_predict = molecule_out.energy
diff_rho = sq_electron_err_int(rho_predict, truth_rhos[i], molecule)
diff_energy = energy_predict - truth_energies[i]
# Not jittable because of if.
Expand Down
3 changes: 2 additions & 1 deletion tests/integration/test_predict_B3LYP.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def test_predict(mol_and_name: tuple[gto.Mole, str]) -> None:
molecule = molecule_from_pyscf(mf, energy=energy, omegas=[0.0], scf_iteration=0)

iterator = make_scf_loop(FUNCTIONAL, verbose=2, max_cycles=25)
e_XND = iterator(PARAMS, molecule)
molecule_out = iterator(PARAMS, molecule)
e_XND = molecule_out.energy

if name == "water":
mf = dft.RKS(mol)
Expand Down
7 changes: 4 additions & 3 deletions tests/integration/test_predict_B88.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def test_predict(mol_and_name: tuple[gto.Mole, str]) -> None:
molecule = molecule_from_pyscf(mf, energy=energy, omegas=[], scf_iteration=0)

iterator = make_scf_loop(FUNCTIONAL, verbose=2, max_cycles=25)
e_XND = iterator(PARAMS, molecule)

molecule_out = iterator(PARAMS, molecule)
e_XND = molecule_out.energy
mf = dft.UKS(mol)
mf.xc = "B88"
mf.max_cycle = 25
Expand All @@ -78,6 +78,7 @@ def test_predict(mol_and_name: tuple[gto.Mole, str]) -> None:

# Testing the training scf loop too.
iterator = make_jitted_scf_loop(FUNCTIONAL, cycles=25)
e_XND_jit, _, _ = iterator(PARAMS, molecule)
molecule_out = iterator(PARAMS, molecule)
e_XND_jit = molecule_out.energy
kcalmoldiff = (e_XND - e_XND_jit) * Hartree2kcalmol
assert np.allclose(kcalmoldiff, 0, atol=1e-6)
25 changes: 22 additions & 3 deletions tests/unit/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

from flax import struct

from typing import Optional


from grad_dft.train import mse_energy_loss, mse_density_loss, mse_energy_and_density_loss

Expand All @@ -36,6 +38,11 @@
TRUTH_ENERGIES = [12.8, 10.1]
TRUTH_DENSITIES = [normal(rand_key, (1000, 2)) for rand_key in RANDOM_KEYS]

INIT_ENERGIES = [12.8, 10.1]
TRUTH_DENSITIES = [normal(rand_key, (1000, 2)) for rand_key in RANDOM_KEYS]



PARAMS = jnp.array([0.11, 0.80, 0.24])
GRID_WEIGHTS = jnp.ones(shape=(1000,))

Expand All @@ -60,12 +67,18 @@ class dummy_molecule:
"""
num_elec: Scalar
grid: dummy_grid
energy: Optional[Scalar] = 0
rdm1: Optional[Float[Array, "spin orbitals orbitals"]] = 0
rho: Optional[Float[Array, "spin spin"]] = 0

def density(self):
return self.rho


MOLECULES = [dummy_molecule(1.0, GRID), dummy_molecule(2.0, GRID)]


def dummy_predictor(params: PyTree, molecule: dummy_molecule):
def dummy_predictor(params: PyTree, molecule: dummy_molecule) -> dummy_molecule:
r"""A dummy function matching the signature of the predictor functions in Grad-DFT
Args:
Expand All @@ -76,10 +89,16 @@ def dummy_predictor(params: PyTree, molecule: dummy_molecule):
tuple[Scalar, Array, Array]: The total energy, density and 1RDM
"""
total_energy = 10.0 + params[0]
density = jnp.ones(shape=(1000, 2)) + params[1]
rho = jnp.ones(shape=(1000, 2)) + params[1]
# we don't presently implement loss functions using an RDM1, but we could in the future
rdm1 = jnp.ones(shape=(10, 10)) + params[2]
return (total_energy, density, rdm1)
molecule = molecule.replace(energy=total_energy)
molecule = molecule.replace(rdm1=rdm1)
# Real molecule object doesn't store density, but we do it here to not write any dummy logic
# converting rdm1 to a density
molecule = molecule.replace(rho=rho)

return molecule


LOSS_ARGS = [
Expand Down

0 comments on commit 8d574e8

Please sign in to comment.