Skip to content

Commit

Permalink
Merge branch 'new_bench' into 'main'
Browse files Browse the repository at this point in the history
Release cuquantum-benchmarks v0.3.0 + Fix a sample

See merge request cuda-hpc-libraries/cuquantum-sdk/cuquantum-public!23
  • Loading branch information
leofang committed Jul 19, 2023
2 parents 6a7fa3b + 92a18e9 commit b75eb45
Show file tree
Hide file tree
Showing 20 changed files with 1,237 additions and 370 deletions.
6 changes: 4 additions & 2 deletions benchmarks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pip install .[all]
```
if running outside of the [cuQuantum Appliance container](https://docs.nvidia.com/cuda/cuquantum/latest/appliance/index.html).

**Note: You may have to build `qsimcirq` and `qiskit-aer` GPU support from source if needed.**
**Note: You may have to build `qsimcirq`, `qiskit-aer`, and `qulacs` GPU support from source if needed.**

Alternatively, you can choose to manage all (required & optional) dependencies yourself via
```
Expand Down Expand Up @@ -44,7 +44,7 @@ Starting v0.2.0, we offer subcommands for performing benchmarks at different lev

Alternatively, you can launch the benchmark program via `python -m cuquantum_benchmarks`. This is equivalent to the standalone command, and is useful when, say, `pip` installs this package to the user site-packages directory (so that the `cuquantum-benchmarks` command may not be available without modifying `$PATH`).

For GPU backends, it is preferred that `--ngpus` is explicitly set.
For GPU backends, it is preferred that `--ngpus N` is explicitly set. On a multi-GPU system, the first `N` GPUs are used. To limit which GPUs can be accessed by the CUDA runtime, set the environment variable `CUDA_VISIBLE_DEVICES` as described in the CUDA documentation.

For backends that support MPI parallelism, it is assumed that `MPI_COMM_WORLD` is the communicator, and that `mpi4py` is installed. You can run the benchmarks as you would normally do to launch MPI processes: `mpiexec -n N cuquantum-benchmarks ...`. It is preferred if you fully specify the problem (explicitly set `--benchmark` & `--nqubits`).

Expand All @@ -70,6 +70,8 @@ Currently all environment variables are reserved for internal use only, and are

* `CUTENSORNET_DUMP_TN=txt`
* `CUTENSORNET_BENCHMARK_TARGET={amplitude,state_vector,expectation}` (pick one)
* `CUTENSORNET_APPROX_TN_UTILS_PATH`
* `CUQUANTUM_BENCHMARKS_DUMP_GATES`

## Development Overview

Expand Down
2 changes: 1 addition & 1 deletion benchmarks/cuquantum_benchmarks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
#
# SPDX-License-Identifier: BSD-3-Clause

__version__ = '0.2.0'
__version__ = '0.3.0'
64 changes: 54 additions & 10 deletions benchmarks/cuquantum_benchmarks/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import argparse
import ctypes
from dataclasses import dataclass
import functools
import math
import json
import hashlib
Expand All @@ -19,6 +20,7 @@

import cupy as cp
import numpy as np
import nvtx
from cuquantum import cudaDataType, ComputeType
from cuquantum.cutensornet._internal.einsum_parser import create_size_dict
import psutil
Expand All @@ -29,6 +31,15 @@
logger = logging.getLogger(logger_name)


def wrap_with_nvtx(func, msg):
    """Return a wrapper that runs every call to ``func`` inside an NVTX range.

    The range is opened with ``nvtx.annotate(msg)``. The wrapper forwards all
    positional and keyword arguments to ``func`` and returns its result;
    ``functools.wraps`` copies the metadata of ``func`` onto the wrapper.
    """
    @functools.wraps(func)
    def annotated(*call_args, **call_kwargs):
        with nvtx.annotate(msg):
            result = func(*call_args, **call_kwargs)
        return result

    return annotated


def reseed(seed=1234):
random.seed(seed)
np.random.seed(seed)
Expand Down Expand Up @@ -162,11 +173,16 @@ def is_running_mpi():
return MPI


def get_num_processes():
def get_mpi_size():
    """Return the size of ``MPI.COMM_WORLD``.

    Falls back to 1 when ``is_running_mpi()`` returns a falsy value.
    """
    mpi = is_running_mpi()
    if not mpi:
        return 1
    return mpi.COMM_WORLD.Get_size()


def get_mpi_rank():
    """Return this process's rank in ``MPI.COMM_WORLD``.

    Falls back to 0 when ``is_running_mpi()`` returns a falsy value.
    """
    mpi = is_running_mpi()
    if not mpi:
        return 0
    return mpi.COMM_WORLD.Get_rank()


def call_by_root(f, root=0):
""" Call the callable f only by the root process. """
MPI = is_running_mpi()
Expand Down Expand Up @@ -409,7 +425,7 @@ def dump():
return full_data


def load_benchmark_data(filepath, cache_dir, required_subdirs=()):
def load_benchmark_data(filepath):
try:
with open(filepath, 'r') as f:
full_data = json.load(f)
Expand All @@ -419,17 +435,16 @@ def load_benchmark_data(filepath, cache_dir, required_subdirs=()):
full_data = {}
logger.debug(f'{filepath} not found')

# it could be that the cache dirs are not created yet
def create_cache():
for subdir in required_subdirs:
path = os.path.join(cache_dir, subdir)
if not os.path.isdir(path):
os.makedirs(path, exist_ok=True)
call_by_root(create_cache)

return full_data


def create_cache(cache_dir, required_subdirs):
    """Create every required subdirectory below ``cache_dir``.

    Args:
        cache_dir: Path of the cache root directory.
        required_subdirs: Iterable of subdirectory paths, each joined onto
            ``cache_dir``.

    Raises:
        OSError: Propagated from ``os.makedirs``, e.g. when a path already
            exists but is not a directory.
    """
    for subdir in required_subdirs:
        # exist_ok=True already makes this a no-op for directories that exist,
        # so a separate isdir() pre-check is unnecessary.
        os.makedirs(os.path.join(cache_dir, subdir), exist_ok=True)


# TODO: upstream this to cupyx.profiler.benchmark
class L2flush:
""" Handy utility for flushing the current device's L2 cache.
Expand Down Expand Up @@ -496,3 +511,32 @@ class _Result: pass
result.gpu_times = gpu_times

return result


class EarlyReturnError(RuntimeError):
    """Raised to stop a run early, e.g. once a dump-only task has completed."""


def is_unique(a):
    """Return True if the sized collection ``a`` contains no repeated entries."""
    return len(set(a)) == len(a)


def is_disjoint(a, b):
    """Return True if ``a`` and ``b`` have no entry in common."""
    return not (set(a) & set(b))


def check_targets_controls(targets, controls, n_qubits):
    """Run simple sanity checks on target and control qubit indices.

    Args:
        targets: Sized collection of target qubit indices; at least one.
        controls: Sized collection of control qubit indices; may be empty.
        n_qubits: Number of qubits; every index must lie in ``[0, n_qubits)``.

    Raises:
        AssertionError: If any check fails. The checks are raised explicitly
            (instead of via ``assert``) so they still run under ``python -O``
            while keeping the exception type unchanged.
    """
    if len(targets) < 1:
        raise AssertionError("must have at least 1 target qubit")

    # Build each set once and reuse it for the uniqueness and overlap checks.
    target_set = set(targets)
    control_set = set(controls)
    if len(target_set) != len(targets):
        raise AssertionError("qubit indices in targets must be unique")
    if len(control_set) != len(controls):
        raise AssertionError("qubit indices in controls must be unique")
    if target_set & control_set:
        raise AssertionError("qubit indices in targets and controls must be disjoint")

    # Unpack into one tuple rather than ``targets + controls`` so that mixed
    # sequence types (e.g. a tuple and a list) do not raise TypeError.
    if not all(0 <= q < n_qubits for q in (*targets, *controls)):
        raise AssertionError(
            f"target and control qubit indices must be in range [0, {n_qubits})")


def check_sequence(seq, expected_size=None, max_size=None, name=''):
    """Validate that ``seq`` holds unique indices within an allowed range.

    Args:
        seq: Sized collection of integer indices.
        expected_size: If given, ``seq`` must have exactly this length and
            every entry must lie in ``[0, expected_size)``. Takes precedence
            over ``max_size``.
        max_size: If given (and ``expected_size`` is not), ``seq`` must have at
            most this length and every entry must lie in ``[0, max_size)``.
        name: Label for ``seq`` used in the error messages.

    Raises:
        AssertionError: If neither size is given or any check fails. The
            checks are raised explicitly (instead of via ``assert``) so they
            still run under ``python -O`` while keeping the exception type
            unchanged.
    """
    if expected_size is not None:
        if len(seq) != expected_size:
            raise AssertionError(
                f"the provided {name} must be of length {expected_size}")
        size = expected_size
    elif max_size is not None:
        if len(seq) > max_size:
            raise AssertionError(
                f"the provided {name} must have length <= {max_size}")
        size = max_size
    else:
        # Without a size there is no valid index range to check against.
        raise AssertionError("either expected_size or max_size must be provided")

    if len(set(seq)) != len(seq):
        raise AssertionError(f"the provided {name} must have non-repetitive entries")
    if not all(0 <= i < size for i in seq):
        raise AssertionError(f"entries in the {name} must be in [0, {size})")
4 changes: 3 additions & 1 deletion benchmarks/cuquantum_benchmarks/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from .backend_cirq import Cirq
from .backend_cutn import cuTensorNet
from .backend_pny import Pny, PnyLightningGpu, PnyLightningCpu, PnyLightningKokkos
from .backend_pny import (Pny, PnyLightningGpu, PnyLightningCpu,
PnyLightningKokkos, PnyDumper)
from .backend_qsim import Qsim, QsimCuda, QsimCusv, QsimMgpu
from .backend_qiskit import Aer, AerCuda, AerCusv, CusvAer
from .backend_qulacs import QulacsGpu, QulacsCpu
Expand All @@ -29,6 +30,7 @@
'pennylane-lightning-gpu': PnyLightningGpu,
'pennylane-lightning-qubit': PnyLightningCpu,
'pennylane-lightning-kokkos': PnyLightningKokkos,
'pennylane-dumper': PnyDumper,
'qulacs-cpu': QulacsCpu,
'qulacs-gpu': QulacsGpu,
}
Expand Down
5 changes: 4 additions & 1 deletion benchmarks/cuquantum_benchmarks/backends/backend_cutn.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(self, ngpus, ncpu_threads, precision, **kwargs):
# cuQuantum Python 22.07 or below
opts = cutn.NetworkOptions(handle=self.handle)
self.network_opts = opts
self.n_samples = kwargs.pop('nhypersamples')

def __del__(self):
cutn.destroy(self.handle)
Expand Down Expand Up @@ -104,10 +105,12 @@ def preprocess_circuit(self, circuit, *args, **kwargs):
t1 = time.perf_counter()
path, opt_info = self.network.contract_path(
# TODO: samples may be too large for small circuits
optimize={'samples': 512, 'threads': self.ncpu_threads})
optimize={'samples': self.n_samples, 'threads': self.ncpu_threads})
t2 = time.perf_counter()
time_path = t2 - t1
logger.info(f'contract_path() took {time_path} s')
logger.debug(f'# samples: {self.n_samples}')
logger.debug(opt_info)

self.path = path
self.opt_info = opt_info
Expand Down
25 changes: 22 additions & 3 deletions benchmarks/cuquantum_benchmarks/backends/backend_pny.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import time
import warnings
import sys

import numpy as np
try:
Expand All @@ -15,7 +16,7 @@
pennylane = None

from .backend import Backend
from .._utils import is_running_mpi
from .._utils import call_by_root, EarlyReturnError, is_running_mpi


# set up a logger
Expand Down Expand Up @@ -80,6 +81,23 @@ def _make_qnode(self, circuit, nshots=1024, **kwargs):
if self.ngpus != 0:
raise ValueError(f"cannot specify --ngpus for the backend {self.identifier}")
dev = pennylane.device("default.qubit", wires=self.nqubits, shots=nshots, c_dtype=self.dtype)
elif self.identifier == "pennylane-dumper":
import cloudpickle
import cuquantum_benchmarks
cloudpickle.register_pickle_by_value(cuquantum_benchmarks)

# note: before loading the pickle, one should check if the Python version agrees
# (probably pennylane's version too)
py_major_minor = f'{sys.version_info.major}.{sys.version_info.minor}'
circuit_filename = kwargs.pop('circuit_filename')
circuit_filename += f"_pny_raw_py{py_major_minor}.pickle"
def dump():
logger.info(f"dumping pennylane (raw) circuit as {circuit_filename} ...")
with open(circuit_filename, 'wb') as f:
cloudpickle.dump(circuit, f) # use highest protocol
logger.info("early exiting as the dumper task is completed")
call_by_root(dump)
raise EarlyReturnError
else:
raise ValueError(f"the backend {self.identifier} is not recognized")

Expand All @@ -89,9 +107,9 @@ def _make_qnode(self, circuit, nshots=1024, **kwargs):
def preprocess_circuit(self, circuit, *args, **kwargs):
nshots = kwargs.get('nshots', 1024)
t1 = time.perf_counter()
self.circuit = self._make_qnode(circuit, nshots)
self.circuit = self._make_qnode(circuit, nshots, **kwargs)
t2 = time.perf_counter()
time_make_qnode = t2-t1
time_make_qnode = t2 - t1
logger.info(f'make qnode took {time_make_qnode} s')
return {'make_qnode': time_make_qnode}

Expand All @@ -107,3 +125,4 @@ def run(self, circuit, nshots=1024):
PnyLightningCpu = functools.partial(Pennylane, identifier='pennylane-lightning-qubit')
PnyLightningKokkos = functools.partial(Pennylane, identifier='pennylane-lightning-kokkos')
Pny = functools.partial(Pennylane, identifier='pennylane')
PnyDumper = functools.partial(Pennylane, identifier='pennylane-dumper')
18 changes: 12 additions & 6 deletions benchmarks/cuquantum_benchmarks/backends/backend_qiskit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import functools
import logging
import time
from importlib.metadata import version

import numpy as np
import cupy as cp
Expand All @@ -15,7 +16,7 @@
qiskit = None

from .backend import Backend
from .._utils import get_num_processes
from .._utils import get_mpi_size, get_mpi_rank


# set up a logger
Expand Down Expand Up @@ -48,8 +49,8 @@ def run(self, circuit, nshots=1024):
results = self.backend.run(transpiled_qc, shots=nshots, memory=True)
else:
results = self.backend.run(transpiled_qc, shots=0, memory=True)
# workaround for memory allocation failure for cusvaer 22.11
if self.identifier == 'cusvaer':
# workaround for memory allocation failure for cusvaer 22.11/23.03
if self.identifier == 'cusvaer' and self._need_sync():
self._synchronize()

post_res_list = results.result().get_memory()
Expand Down Expand Up @@ -169,7 +170,7 @@ def create_aer_backend(self, identifier, ngpus, ncpu_threads, *args, **kwargs):
return backend

def get_aer_blocking_setup(self, ngpus=None):
size = get_num_processes() # check if running MPI
size = get_mpi_size() # check if running MPI
if size > 1:
blocking_enable = True
if self.identifier == 'aer':
Expand All @@ -182,11 +183,16 @@ def get_aer_blocking_setup(self, ngpus=None):
blocking_qubits = None
return blocking_enable, blocking_qubits

def _need_sync(self):
    """Return True if the installed ``cusvaer`` version is 0.x with x <= 2.

    ``run()`` uses this to decide whether the explicit device
    synchronization workaround is applied for the ``cusvaer`` backend.

    Raises:
        importlib.metadata.PackageNotFoundError: If ``cusvaer`` is not
            installed (propagated from ``version``).
        ValueError: If the version string does not start with a number.
    """
    import re

    ver_str = version('cusvaer')
    # Only the leading major/minor numbers matter. Parsing just those avoids
    # failures on suffixed components (e.g. "0.3.0rc1") and on version
    # strings with a single component.
    match = re.match(r'\s*(\d+)(?:\.(\d+))?', ver_str)
    if match is None:
        raise ValueError(f"cannot parse cusvaer version {ver_str!r}")
    major = int(match.group(1))
    minor = int(match.group(2) or 0)
    return major == 0 and minor <= 2

def _synchronize(self):
nprocs = get_num_processes()
my_rank = get_mpi_rank()
ndevices_in_node = cp.cuda.runtime.getDeviceCount()
# GPU selected in this process
device_id = nprocs % ndevices_in_node
device_id = my_rank % ndevices_in_node
cp.cuda.Device(device_id).synchronize()


Expand Down
Loading

0 comments on commit b75eb45

Please sign in to comment.