Skip to content

Commit

Permalink
Tests for dynamic epsilon
Browse files Browse the repository at this point in the history
  • Loading branch information
loganbvh committed Oct 27, 2023
1 parent 0f3dd1f commit bbe5616
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
2 changes: 2 additions & 0 deletions tdgl/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,8 @@ def update_epsilon(self, time: float) -> np.ndarray:
epsilon = np.array(

Check warning on line 369 in tdgl/solver/solver.py

View check run for this annotation

Codecov / codecov/patch

tdgl/solver/solver.py#L369

Added line #L369 was not covered by tests
[float(self.disorder_epsilon(r, t=time)) for r in self.sites]
)
if self.use_cupy:
epsilon = cupy.asarray(epsilon)
return epsilon

Check warning on line 374 in tdgl/solver/solver.py

View check run for this annotation

Codecov / codecov/patch

tdgl/solver/solver.py#L372-L374

Added lines #L372 - L374 were not covered by tests

@staticmethod
Expand Down
24 changes: 21 additions & 3 deletions tdgl/test/test_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
@pytest.mark.parametrize("current", [5.0, lambda t: 10])
@pytest.mark.parametrize("field", [0, 1])
@pytest.mark.parametrize(
"terminal_psi, time_dependent, gpu",
[(0, False, True), (1, False, False), (1, True, True)],
"terminal_psi, time_dependent, gpu, vectorized",
[(0, False, True, True), (1, False, False, False), (1, True, True, True)],
)
def test_source_drain_current(
transport_device,
Expand All @@ -25,6 +25,7 @@ def test_source_drain_current(
terminal_psi,
time_dependent,
gpu,
vectorized,
):
device = transport_device
total_time = 10
Expand Down Expand Up @@ -76,6 +77,17 @@ def terminal_currents(t):
applied_vector_potential=field,
terminal_currents=terminal_currents,
)

if vectorized:

def disorder_epsilon(r):
return 1.0 * np.ones(len(r))

else:

def disorder_epsilon(r):
return 1.0

if time_dependent:
ramp = tdgl.sources.LinearRamp(tmin=1, tmax=8)
constant_field = tdgl.sources.ConstantField(
Expand All @@ -85,10 +97,16 @@ def terminal_currents(t):
)
field = ramp * constant_field
field = constant_field * ramp

_disorder_epsilon = disorder_epsilon

def disorder_epsilon(r, *, t, vectorized=vectorized):
return _disorder_epsilon(r)

solution = tdgl.solve(
device,
options,
disorder_epsilon=lambda r: 1,
disorder_epsilon=disorder_epsilon,
applied_vector_potential=field,
terminal_currents=terminal_currents,
)
Expand Down

0 comments on commit bbe5616

Please sign in to comment.