From d44551e5f7c6933cc9eb5c84fd591b7721aac13e Mon Sep 17 00:00:00 2001 From: Oliver Elbert Date: Thu, 5 Sep 2024 12:48:18 -0400 Subject: [PATCH] typing works --- ndsl/stencils/testing/translate.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/ndsl/stencils/testing/translate.py b/ndsl/stencils/testing/translate.py index a25316a..0914308 100644 --- a/ndsl/stencils/testing/translate.py +++ b/ndsl/stencils/testing/translate.py @@ -5,7 +5,7 @@ import ndsl.dsl.gt4py_utils as utils from ndsl.dsl.stencil import StencilFactory -from ndsl.dsl.typing import Field, Float # noqa: F401 +from ndsl.dsl.typing import Field, Float, Int # noqa: F401 from ndsl.quantity import Quantity from ndsl.stencils.testing.grid import Grid # type: ignore @@ -113,6 +113,12 @@ def make_storage_data( elif not full_shape and len(array.shape) < 3 and axis == len(array.shape) - 1: use_shape[1] = 1 start = (int(istart), int(jstart), int(kstart)) + if 'float' in str(array.dtype): + dtype = Float + elif 'int' in str(array.dtype): + dtype = Int + else: + dtype = array.dtype if names_4d: return utils.make_storage_dict( array, @@ -123,7 +129,7 @@ def make_storage_data( axis=axis, names=names_4d, backend=self.stencil_factory.backend, - dtype=array.dtype, + dtype=dtype, ) else: if len(array.shape) == 4: @@ -138,7 +144,7 @@ def make_storage_data( axis=axis, read_only=read_only, backend=self.stencil_factory.backend, - dtype=array.dtype, + dtype=dtype, ) def storage_vars(self):