Skip to content

Commit

Permalink
Updated logic of unit tests in test_basic_operations.py
Browse files Browse the repository at this point in the history
  • Loading branch information
fmalatino committed Mar 29, 2024
1 parent 5e4daf4 commit 7f0d0f0
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions tests/test_basic_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from ndsl.stencils import basic_operations as basic


nx = 5
ny = 5
nz = 1
nx = 20
ny = 20
nz = 79
nhalo = 0
backend = "numpy"

Expand Down Expand Up @@ -125,88 +125,88 @@ def test_copy():
copy = Copy(stencil_factory)

infield = Quantity(
data=np.zeros([5, 5, 1]),
data=np.zeros([20, 20, 79]),
dims=[X_DIM, Y_DIM, Z_DIM],
units="m",
)

outfield = Quantity(
data=np.ones([5, 5, 1]),
data=np.ones([20, 20, 79]),
dims=[X_DIM, Y_DIM, Z_DIM],
units="m",
)

copy(f_in=infield.data, f_out=outfield.data)

assert infield.data.any() == outfield.data.any()
assert (infield.data == outfield.data).any()


def test_adjustmentfactor():
adfact = AdjustmentFactor(stencil_factory)

factorfield = Quantity(
data=np.full(shape=[5, 5], fill_value=2.0),
data=np.full(shape=[20, 20], fill_value=2.0),
dims=[X_DIM, Y_DIM],
units="m",
)

outfield = Quantity(
data=np.full(shape=[5, 5, 1], fill_value=2.0),
data=np.full(shape=[20, 20, 79], fill_value=2.0),
dims=[X_DIM, Y_DIM, Z_DIM],
units="m",
)

testfield = Quantity(
data=np.full(shape=[5, 5, 1], fill_value=4.0),
data=np.full(shape=[20, 20, 79], fill_value=4.0),
dims=[X_DIM, Y_DIM, Z_DIM],
units="m",
)

adfact(factor=factorfield.data, f_out=outfield.data)
assert outfield.data.any() == testfield.data.any()
assert (outfield.data == testfield.data).any()


def test_setvalue():
setvalue = SetValue(stencil_factory)

outfield = Quantity(
data=np.zeros(shape=[5, 5, 1]),
data=np.zeros(shape=[20, 20, 79]),
dims=[X_DIM, Y_DIM, Z_DIM],
units="m",
)

testfield = Quantity(
data=np.full(shape=[5, 5, 1], fill_value=2.0),
data=np.full(shape=[20, 20, 79], fill_value=2.0),
dims=[X_DIM, Y_DIM, Z_DIM],
units="m",
)

setvalue(f_out=outfield.data, value=2.0)

assert outfield.data.any() == testfield.data.any()
assert (outfield.data == testfield.data).any()


def test_adjustdivide():
addiv = AdjustDivide(stencil_factory)

factorfield = Quantity(
data=np.full(shape=[5, 5, 1], fill_value=2.0),
data=np.full(shape=[20, 20, 79], fill_value=2.0),
dims=[X_DIM, Y_DIM, Z_DIM],
units="m",
)

outfield = Quantity(
data=np.full(shape=[5, 5, 1], fill_value=2.0),
data=np.full(shape=[20, 20, 79], fill_value=2.0),
dims=[X_DIM, Y_DIM, Z_DIM],
units="m",
)

testfield = Quantity(
data=np.full(shape=[5, 5, 1], fill_value=1.0),
data=np.full(shape=[20, 20, 79], fill_value=1.0),
dims=[X_DIM, Y_DIM, Z_DIM],
units="m",
)

addiv(factor=factorfield.data, f_out=outfield.data)

assert outfield.data.any() == testfield.data.any()
assert (outfield.data == testfield.data).any()

0 comments on commit 7f0d0f0

Please sign in to comment.