GEOS integration (#9)
* Initialize GeosDycoreWrapper with bdt (timestep)

* Use GEOS version of constants

* 1. Add qcld to the list of tracers being advected
2. Make GEOS-specific changes to thresholds in saturation adjustment

* Accumulate diss_est

* Allow GEOS_WRAPPER to process device data

* Add clear to collector for third-party use; GEOS passes timings down to the caller

* Make kernel analysis run a copy stencil to compute local bandwidth
Parametrize the tool with backend and output format

* Move constant selection to an env var
Add saturation adjustment threshold to constants

* Remove unused if leading to empty code block

* Restrict dace to 0.14.1 due to a parsing bug

* Add guard for bdt==0
Fix bad merge for bdt with GEOS_Wrapper

* Remove unused code

* Fix theoretical timings

* Fix a bug where pkz was calculated twice and the second calculation was wrong

* Downgrade DaCe to 0.14.0 pending array aliasing fix

* Set default cache path for orchestrated DaCe to respect GT_CACHE_* env

* Remove previous per stencil override of default_build_folder

* Revert "Set default cache path for orchestrated DaCe to respect GT_CACHE_* env"

* Revert "Remove previous per stencil override of default_build_folder"

* Read cache_root in default dace backend

* Document faulty behavior with GT_CACHE_DIR_NAME

* Fix bad requirements syntax

* Check for the string value of CONST_VERSION directly instead of enum

* Protect constant selection more rigorously.
Abort cleanly when an unknown constant is given

* Log constants selection

* Refactor NQ to constants.py

* Fix or explain inlined import

* Verbose runtime error when dt_atmos is bad

* Verbose warm-up

* Re-initialize heat_source and diss_est on each call; add do_skeb check to accumulation

---------

Co-authored-by: Purnendu Chakraborty <[email protected]>
Co-authored-by: Oliver Elbert <[email protected]>
3 people authored Aug 1, 2023
1 parent 279feca commit 5334015
Showing 15 changed files with 540 additions and 243 deletions.
2 changes: 1 addition & 1 deletion constraints.txt
@@ -87,7 +87,7 @@ cytoolz==0.11.2
# via
# gt4py
# gt4py (external/gt4py/setup.cfg)
dace==0.14.1
dace==0.14.0
# via
# -r requirements_dev.txt
# pace-dsl
4 changes: 4 additions & 0 deletions doc_primer_orchestration.md
@@ -104,6 +104,10 @@ _Parsing errors_

DaCe cannot parse _any_ dynamic Python, nor any code that allocates memory on the fly (think list creation). It will also complain about any arguments whose memory it cannot describe (remember `dace_compiletime_args`).

_GT_CACHE_DIR_NAME_

We do not honor `GT_CACHE_DIR_NAME` with orchestration; `GT_CACHE_ROOT` is respected.

Conclusion
----------

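To make the new note concrete, here is a minimal sketch of the two cache variables: both names come from the note above, and we assume they are read into gt4py.cartesian.config.cache_settings at import time (paths are illustrative).

import os

# GT_CACHE_ROOT feeds cache_settings["root_path"]: respected by orchestration.
# GT_CACHE_DIR_NAME feeds cache_settings["dir_name"]: NOT honored when orchestrated.
os.environ["GT_CACHE_ROOT"] = "/scratch/caches"  # honored
os.environ["GT_CACHE_DIR_NAME"] = ".my_cache"  # silently ignored by orchestration

import gt4py.cartesian.config as gt_config  # import after setting the env

print(gt_config.cache_settings["root_path"])  # -> /scratch/caches
print(gt_config.cache_settings["dir_name"])  # -> .my_cache (unused by orchestration)
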
2 changes: 1 addition & 1 deletion driver/pace/driver/grid.py
@@ -215,5 +215,5 @@ def _transform_horizontal_grid(
grid.data[:, :, 0] = lon_transform[:]
grid.data[:, :, 1] = lat_transform[:]

metric_terms._grid.data[:] = grid.data[:]
metric_terms._grid.data[:] = grid.data[:] # type: ignore[attr-defined]
metric_terms._init_agrid()
4 changes: 4 additions & 0 deletions driver/pace/driver/performance/collector.py
@@ -66,6 +66,10 @@ def __init__(self, experiment_name: str, comm: pace.util.Comm):
self.experiment_name = experiment_name
self.comm = comm

def clear(self):
self.times_per_step = []
self.hits_per_step = []

def collect_performance(self):
"""
Take the accumulated timings and flush them into a new entry
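The new clear() method lets a third-party caller (here, GEOS) take the accumulated timings and reset the collector between episodes. A hedged usage sketch: `PerformanceCollector` is an assumed name for the class whose constructor is shown above, and `do_step` stands in for the host model's step function.

from pace.driver.performance.collector import PerformanceCollector  # assumed name

def run_and_report(collector: PerformanceCollector, do_step, n_steps: int):
    for _ in range(n_steps):
        do_step()
        collector.collect_performance()  # flush accumulated timings into a new entry
    # Hand the per-step records to the caller, then reset for the next episode.
    timings = (list(collector.hits_per_step), list(collector.times_per_step))
    collector.clear()
    return timings
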
38 changes: 36 additions & 2 deletions driver/pace/driver/tools.py
@@ -27,14 +27,48 @@
type=click.STRING,
)
@click.option("--report_detail", is_flag=True, type=click.BOOL, default=False)
def command_line(action: str, sdfg_path: Optional[str], report_detail: Optional[bool]):
@click.option(
"--hardware_bw_in_gb_s",
required=False,
type=click.FLOAT,
default=0.0,
)
@click.option(
"--output_format",
required=False,
type=click.STRING,
default=None,
)
@click.option(
"--backend",
required=False,
type=click.STRING,
default="dace:gpu",
)
def command_line(
action: str,
sdfg_path: Optional[str],
report_detail: Optional[bool],
hardware_bw_in_gb_s: Optional[float],
output_format: Optional[str],
backend: Optional[str],
):
"""
Run tooling.
"""
if action == ACTION_SDFG_MEMORY_STATIC_ANALYSIS:
print(memory_static_analysis_from_path(sdfg_path, detail_report=report_detail))
elif action == ACTION_SDFG_KERNEL_THEORETICAL_TIMING:
print(kernel_theoretical_timing_from_path(sdfg_path))
print(
kernel_theoretical_timing_from_path(
sdfg_path,
hardware_bw_in_GB_s=(
None if hardware_bw_in_gb_s == 0 else hardware_bw_in_gb_s
),
backend=backend,
output_format=output_format,
)
)


if __name__ == "__main__":
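The same analysis is reachable from Python. A sketch using the kernel_theoretical_timing_from_path signature from the utils.py diff below; the SDFG path is illustrative, and passing an explicit bandwidth skips the copy-stencil benchmark (mirroring the CLI's treatment of 0.0 as "not given").

from pace.dsl.dace.utils import kernel_theoretical_timing_from_path

report = kernel_theoretical_timing_from_path(
    sdfg_path="build/program.sdfg",  # illustrative path
    hardware_bw_in_GB_s=492.0,  # e.g. the P100 figure cited in utils.py
    backend="dace:gpu",
    output_format="csv",  # also writes kernel_theoretical_timing.csv
)
print(report)
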
148 changes: 113 additions & 35 deletions dsl/pace/dsl/dace/utils.py
@@ -1,15 +1,23 @@
import json
import time
from dataclasses import dataclass, field
from typing import Dict, List
from typing import Dict, List, Optional

import dace
import numpy as np
from dace.transformation.helpers import get_parent_map
from gt4py.cartesian.gtscript import PARALLEL, computation, interval

from pace.dsl.dace.dace_config import DaceConfig
from pace.dsl.stencil import CompilationConfig, FrozenStencil, StencilConfig
from pace.dsl.typing import Float, FloatField
from pace.util._optional_imports import cupy as cp
from pace.util.logging import pace_log


# ----------------------------------------------------------
# Rough timer & log for major operations of DaCe build stack
# ----------------------------------------------------------
class DaCeProgress:
"""Timer and log to track build progress"""

@@ -48,6 +56,9 @@ def _is_ref(sd: dace.sdfg.SDFG, aname: str):
return found


# ----------------------------------------------------------
# Memory analyser from SDFG
# ----------------------------------------------------------
@dataclass
class ArrayReport:
name: str = ""
@@ -175,19 +186,38 @@ def memory_static_analysis_from_path(sdfg_path: str, detail_report=False) -> str
)


# TODO (floriand): for the timing analysis to be realistic, the reference
# bandwidth of the hardware should be measured with GT4Py & simple in/out copy
# stencils. This both measures the _actual_ deployed hardware and sizes it
# against the current GT4Py version.
# Below we bypass this needed automation by hardcoding the P100 bandwidth on
# Piz Daint, measured with the above strategy.
# A better tool would measure the bandwidth and report the kernel analysis
# in a single command.
_HARDWARE_BW_GB_S = {"P100": 492.0}
# ----------------------------------------------------------
# Theoretical bandwidth from SDFG
# ----------------------------------------------------------
def copy_defn(q_in: FloatField, q_out: FloatField):
with computation(PARALLEL), interval(...):
q_in = q_out


class MaxBandwithBenchmarkProgram:
def __init__(self, size, backend) -> None:
from pace.dsl.dace.orchestration import DaCeOrchestration, orchestrate

dconfig = DaceConfig(None, backend, orchestration=DaCeOrchestration.BuildAndRun)
c = CompilationConfig(backend=backend)
s = StencilConfig(dace_config=dconfig, compilation_config=c)
self.copy_stencil = FrozenStencil(
func=copy_defn,
origin=(0, 0, 0),
domain=size,
stencil_config=s,
)
orchestrate(obj=self, config=dconfig)

def __call__(self, A, B, n: int):
for i in dace.nounroll(range(n)):
self.copy_stencil(A, B)


def kernel_theoretical_timing(
sdfg: dace.sdfg.SDFG, hardware="P100", hardware_bw_in_Gb_s=None
sdfg: dace.sdfg.SDFG,
hardware_bw_in_GB_s=None,
backend=None,
) -> Dict[str, float]:
"""Compute a lower timing bound for kernels with the following hypothesis:
@@ -197,6 +227,39 @@ def kernel_theoretical_timing(
- Memory pressure is mostly in read/write from global memory, inner scalar & shared
memory is not counted towards memory movement.
"""
if not hardware_bw_in_GB_s:
size = np.array(sdfg.arrays["__g_self__w"].shape)
print(
f"Calculating experimental hardware bandwith on {size}"
f" arrays at {Float} precision..."
)
bench = MaxBandwithBenchmarkProgram(size, backend)
if backend == "dace:gpu":
A = cp.ones(size, dtype=Float)
B = cp.ones(size, dtype=Float)
else:
A = np.ones(size, dtype=Float)
B = np.ones(size, dtype=Float)
n = 1000
m = 4
dt = []
# Warm up run (build, allocation)
# to remove from timing the common runtime
bench(A, B, n)
# Time
for _ in range(m):
s = time.time()
bench(A, B, n)
dt.append((time.time() - s) / n)
memory_size_in_b = np.prod(size) * np.dtype(Float).itemsize * 8
bandwidth_in_bytes_s = memory_size_in_b / np.median(dt)
print(
f"Hardware bandwith computed: {bandwidth_in_bytes_s/(1024*1024*1024)} GB/s"
)
else:
bandwidth_in_bytes_s = hardware_bw_in_GB_s * 1024 * 1024 * 1024
print(f"Given hardware bandwith: {bandwidth_in_bytes_s/(1024*1024*1024)} GB/s")

allmaps = [
(me, state)
for me, state in sdfg.all_nodes_recursive()
@@ -228,19 +291,6 @@ def kernel_theoretical_timing(
]
)

# Compute hardware memory bandwidth in bytes/us
if hardware_bw_in_Gb_s and hardware in _HARDWARE_BW_GB_S.keys():
raise NotImplementedError("can't specify hardware bandwidth and hardware")
if hardware_bw_in_Gb_s:
bandwidth_in_bytes_s = hardware_bw_in_Gb_s * 1024 * 1024 * 1024
elif hardware in _HARDWARE_BW_GB_S.keys():
# Time it has to take (at least): bytes / bandwidth_in_bytes_s
bandwidth_in_bytes_s = _HARDWARE_BW_GB_S[hardware] * 1024 * 1024 * 1024
else:
print(
f"Timing analysis: hardware {hardware} unknown and no bandwidth given"
)

in_us = 1000 * 1000

# Theoretical fastest timing
@@ -249,39 +299,67 @@
except TypeError:
newresult_in_us = (alldata_in_bytes / bandwidth_in_bytes_s) * in_us

if node.label in result:
import sympy
        # We keep the sympy import here because sympy is known to be a
        # problematic import and a heavy module that should be avoided
        # if possible.
        # TODO: refactor it out by shadow-coding the sympy.Max/Eval functions
import sympy

if node.label in result:
newresult_in_us = sympy.Max(result[node.label], newresult_in_us).expand()
try:
newresult_in_us = float(newresult_in_us)
except TypeError:
pass

# Bad expansion
if not isinstance(newresult_in_us, float):
if not isinstance(newresult_in_us, sympy.core.numbers.Float) and not isinstance(
newresult_in_us, float
):
continue

result[node.label] = newresult_in_us
result[node.label] = float(newresult_in_us)

return result


def report_kernel_theoretical_timing(
timings: Dict[str, float], human_readable: bool = True, csv: bool = False
timings: Dict[str, float],
human_readable: bool = True,
out_format: Optional[str] = None,
) -> str:
"""Produce a human readable or CSV of the kernel timings"""
result_string = f"Maps processed: {len(timings)}.\n"
if human_readable:
result_string += "Timing in microseconds Map name:\n"
result_string += "\n".join(f"{v:.2f}\t{k}," for k, v in sorted(timings.items()))
if csv:
result_string += "#Map name,timing in microseconds\n"
result_string += "\n".join(f"{k},{v}," for k, v in sorted(timings.items()))
if out_format == "csv":
csv_string = ""
csv_string += "#Map name,timing in microseconds\n"
csv_string += "\n".join(f"{k},{v}," for k, v in sorted(timings.items()))
with open("kernel_theoretical_timing.csv", "w") as f:
f.write(csv_string)
elif out_format == "json":
with open("kernel_theoretical_timing.json", "w") as f:
json.dump(timings, f, indent=2)

return result_string


def kernel_theoretical_timing_from_path(sdfg_path: str) -> str:
def kernel_theoretical_timing_from_path(
sdfg_path: str,
hardware_bw_in_GB_s: Optional[float] = None,
backend: Optional[str] = None,
output_format: Optional[str] = None,
) -> str:
"""Load an SDFG and report the theoretical kernel timings"""
timings = kernel_theoretical_timing(dace.SDFG.from_file(sdfg_path))
return report_kernel_theoretical_timing(timings, human_readable=True, csv=False)
print(f"Running kernel_theoretical_timing for {sdfg_path}")
timings = kernel_theoretical_timing(
dace.SDFG.from_file(sdfg_path),
hardware_bw_in_GB_s=hardware_bw_in_GB_s,
backend=backend,
)
return report_kernel_theoretical_timing(
timings,
human_readable=True,
out_format=output_format,
)
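For intuition, the lower bound computed above is simply bytes moved divided by bandwidth, scaled to microseconds. A self-contained rework of that arithmetic with illustrative numbers (one read plus one write of a 64-bit field; the bandwidth figure is only an example):

import numpy as np

shape = (128, 128, 80)  # illustrative map domain
bytes_moved = 2 * np.prod(shape) * np.dtype(np.float64).itemsize  # read A + write B
bandwidth_in_bytes_s = 492.0 * 1024 * 1024 * 1024  # example hardware bandwidth

in_us = 1000 * 1000
lower_bound_us = (bytes_moved / bandwidth_in_bytes_s) * in_us  # formula from utils.py
print(f"Theoretical fastest timing: {lower_bound_us:.2f} us")
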
5 changes: 3 additions & 2 deletions dsl/pace/dsl/stencil.py
@@ -325,8 +325,9 @@ def __init__(
if "dace" in self.stencil_config.compilation_config.backend:
dace.Config.set(
"default_build_folder",
value="{gt_cache}/dacecache".format(
gt_cache=gt4py.cartesian.config.cache_settings["dir_name"]
value="{gt_root}/{gt_cache}/dacecache".format(
gt_root=gt4py.cartesian.config.cache_settings["root_path"],
gt_cache=gt4py.cartesian.config.cache_settings["dir_name"],
),
)

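The net effect of this change is that the DaCe build cache lands under gt4py's cache root instead of the process working directory. A small sketch of the resulting path, using the cache_settings keys shown in the diff (values are illustrative):

import gt4py.cartesian.config

root = gt4py.cartesian.config.cache_settings["root_path"]  # honors GT_CACHE_ROOT
name = gt4py.cartesian.config.cache_settings["dir_name"]  # e.g. ".gt_cache"
print(f"{root}/{name}/dacecache")  # e.g. /scratch/caches/.gt_cache/dacecache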