Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
loganbvh committed Sep 22, 2023
1 parent 34dae37 commit c3e1de4
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 24 deletions.
2 changes: 1 addition & 1 deletion tdgl/solver/euler.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def solve_for_psi_squared(
except Exception:
logger.warning("Unable to solve for |psi|^2.", exc_info=True)
return None
if bool(xp.any(discriminant < 0)):
if xp.any(discriminant < 0):
return None
new_sq_psi = (2 * w2) / (two_c_1 + xp.sqrt(discriminant))
psi = w - z * new_sq_psi
Expand Down
4 changes: 2 additions & 2 deletions tdgl/solver/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def save_step(step):
) as pbar:
for i in it:
try:
dt = float(self.dt)
dt = self.dt
self.state["step"] = i
self.state["time"] = self.time
self.state["dt"] = dt
Expand Down Expand Up @@ -375,7 +375,7 @@ def save_step(step):
pbar.update(end_time - self.time)
if self.time >= end_time:
break
self.dt = float(new_dt)
self.dt = new_dt
self.running_state.step += 1
self.time += self.dt
except KeyboardInterrupt:
Expand Down
25 changes: 4 additions & 21 deletions tdgl/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,32 +447,17 @@ def update(
psi_laplacian,
options,
)

if use_cupy:
cupy.cuda.get_current_stream().synchronize()

# Compute the supercurrent, scalar potential, and normal current
supercurrent = operators.get_supercurrent(psi)

if use_cupy:
cupy.cuda.get_current_stream().synchronize()

rhs = (divergence @ (supercurrent - dA_dt)) - (
mu_boundary_laplacian @ mu_boundary
)
if options.sparse_solver is SparseSolver.PARDISO:
mu = pypardiso.spsolve(mu_laplacian, rhs)
else:
mu = mu_laplacian_lu(rhs)

if use_cupy:
cupy.cuda.get_current_stream().synchronize()

normal_current = -(mu_gradient @ mu) - dA_dt

if use_cupy:
cupy.cuda.get_current_stream().synchronize()

if not options.include_screening:
break

Expand Down Expand Up @@ -517,15 +502,13 @@ def update(
if options.adaptive:
# Compute the max abs change in |psi|^2, averaged over the adaptive window,
# and use it to select a new time step.
self.d_psi_sq_vals.append(xp.absolute(abs_sq_psi - old_sq_psi).max())
self.d_psi_sq_vals.append(float(xp.absolute(abs_sq_psi - old_sq_psi).max()))
window = options.adaptive_window
if step > window:
new_dt = options.dt_init / xp.clip(
xp.array(self.d_psi_sq_vals[-window:]).mean(),
1e-10,
xp.inf,
new_dt = options.dt_init / max(
1e-10, np.mean(self.d_psi_sq_vals[-window:])
)
self.tentative_dt = xp.clip(0.5 * (new_dt + dt), 0, self.dt_max)
self.tentative_dt = np.clip(0.5 * (new_dt + dt), 0, self.dt_max)

results = (
dt,
Expand Down

0 comments on commit c3e1de4

Please sign in to comment.