Skip to content

Commit

Permalink
num electrons now stored in Molecule class for normalization in loss.…
Browse files Browse the repository at this point in the history
… Simple energy loss re-implemented
  • Loading branch information
jackbaker1001 committed Sep 12, 2023
1 parent 7cb9ac9 commit 52fa2b2
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 33 deletions.
31 changes: 0 additions & 31 deletions grad_dft/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,37 +855,6 @@ def canonicalize_inputs(x):
return x


@partial(value_and_grad, has_aux=True)
def default_loss(params: PyTree, molecule_predict: Callable, molecule: Molecule, trueenergy: float):
r"""
Computes the default loss function, here MSE, between predicted and true energy
Parameters
----------
params: PyTree
functional parameters (weights)
molecule_predict: Callable.
Use molecule_predict = molecule_predictor(functional) to generate it.
molecule: Molecule
trueenergy: float
Returns
----------
Tuple[float, float]
The loss and predicted energy.
Note
----------
Since it has the decorator @partial(value_and_grad, has_aux = True)
it will compute the gradients with respect to params.
"""

predictedenergy, _ = molecule_predict(params, molecule)
cost_value = (predictedenergy - trueenergy) ** 2

return cost_value, predictedenergy


def _canonicalize_fxc(fxc: Functional) -> Callable:
if hasattr(fxc, "energy"):
return fxc.energy
Expand Down
3 changes: 2 additions & 1 deletion grad_dft/interface/pyscf.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def molecule_from_pyscf(
chi = None
spin = mf.mol.spin
charge = mf.mol.charge

num_elec = jnp.sum(mo_occ)
grid_level = mf.grids.level

return Molecule(
Expand All @@ -124,6 +124,7 @@ def molecule_from_pyscf(
vj,
mo_coeff,
mo_occ,
num_elec,
mo_energy,
mf_e_tot,
s1e,
Expand Down
1 change: 1 addition & 0 deletions grad_dft/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class Molecule:
vj: Float[Array, "spin orbitals orbitals"]
mo_coeff: Float[Array, "spin orbitals orbitals"]
mo_occ: Int[Array, "spin orbitals"]
num_elec: Scalar
mo_energy: Float[Array, "spin orbitals"]
mf_energy: Optional[Scalar] = None
s1e: Optional[Float[Array, "spin orbitals orbitals"]] = None # Not used during training
Expand Down
45 changes: 44 additions & 1 deletion grad_dft/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from jax import grad, numpy as jnp, vmap
from jax import value_and_grad
from jax.profiler import annotate_function
from jax.lax import stop_gradient
from jax.lax import stop_gradient, fori_loop, cond
from optax import OptState, GradientTransformation, apply_updates

from grad_dft.functional import DispersionFunctional, Functional
Expand Down Expand Up @@ -320,3 +320,46 @@ def get_grad(mo_coeff: Float[Array, "spin ao ao"],
C_vir = vmap(jnp.where, in_axes=(None, 1, None), out_axes=1)(mo_occ == 0, mo_coeff, 0)

return jnp.einsum("sab,sac,scd->bd", C_vir.conj(), F, C_occ)

##################### Loss Functions #####################

@partial(value_and_grad, has_aux=True)
def mse_energy_loss(params: PyTree, molecule_predictor: Callable,
molecules: list[Molecule], truth_energies: Float[Array, "energy"], elec_num_norm: Scalar=True
) -> Scalar:
r"""
Computes the default loss function, here MSE, between predicted and true energy.
This loss function does not yet support parallel execution for the loss contributions
and instead implemented a simple for loop.
Parameters
----------
params: PyTree
functional parameters (weights)
molecule_predict: Callable.
Use molecule_predict = molecule_predictor(functional) to generate it.
molecule: Molecule
trueenergy: float
Returns
----------
"""

def unnorm_sum():
def increment_loss(i, energy_sum):
E_predict, _ = molecule_predictor(params, molecules[i])
energy_sum += (E_predict - truth_energies[i])**2
return fori_loop(0, len(molecules), increment_loss, 0)

def norm_sum():
def increment_loss_norm(i, energy_sum):
E_predict, _ = molecule_predictor(params, molecules[i])
energy_sum += ((E_predict - truth_energies[i])/molecules[i].num_elec)**2
return fori_loop(0, len(molecules), increment_loss_norm, 0)

energy_sum = cond(elec_num_norm, norm_sum, unnorm_sum)

cost_value = energy_sum/len(molecules)

return cost_value

0 comments on commit 52fa2b2

Please sign in to comment.