Skip to content

Commit

Permalink
Forgot to commit molecule changes for mo grads
Browse files Browse the repository at this point in the history
  • Loading branch information
jackbaker1001 committed Nov 10, 2023
1 parent 357434a commit 005af27
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions grad_dft/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,16 @@ def get_occ(self) -> Array:
nelecs = jnp.array([self.mo_occ[i].sum() for i in range(2)], dtype=jnp.int64)
naos = self.mo_occ.shape[1]
return get_occ(self.mo_energy, nelecs, naos)

def get_mo_grads(self, *args, **kwargs):
r"""Compute the gradient of the electronic energy with respect
to the molecular orbital coefficients.
Returns:
-------
Float[Array, "orbitals orbitals"]
"""
return orbital_grad(self.mo_coeff, self.mo_occ, self.fock, *args, **kwargs)

def to_dict(self) -> dict:
r""" Returns a dictionary with the attributes of the molecule."""
Expand All @@ -337,7 +347,8 @@ def orbital_grad(
F: Float[Array, "spin orbitals orbitals"],
precision: Precision = Precision.HIGHEST
) -> Float[Array, "orbitals orbitals"]:
r""" Computes the restricted Hartree Fock orbital gradients
r"""Compute the gradient of the electronic energy with respect
to the molecular orbital coefficients.
Parameters:
----------
Expand All @@ -356,7 +367,7 @@ def orbital_grad(
Notes:
-----
# Similar to pyscf/scf/hf.py:
# Performs same task as pyscf/scf/hf.py:
occidx = mo_occ > 0
viridx = ~occidx
g = reduce(jnp.dot, (mo_coeff[:,viridx].conj().T, fock_ao,
Expand Down

0 comments on commit 005af27

Please sign in to comment.