Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
massimim committed Aug 28, 2024
1 parent 32e1779 commit 7f11a78
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 5 deletions.
10 changes: 5 additions & 5 deletions libNeonPy/src/Neon/py/CudaDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ auto CudaDriver::run_kernel(
int const ndevs = backend.getDeviceCount();
// #pragma omp parallel for num_threads(ndevs)
for (int setIdx = 0; setIdx < ndevs; setIdx++) {
backend.devSet().setActiveDevContext(setIdx);
cudaStream_t const& cuda_stream = streamSet.cudaStream(setIdx);
CUstream driverStream = (CUstream)cuda_stream;
CUfunction function = static_cast<CUfunction>(kernelSet[setIdx]);
backend.devSet().setActiveDevContext(setIdx);
auto& launch_info = launch_params[setIdx];

// auto const cudaGrid = launch_info.cudaGrid();
Expand Down Expand Up @@ -110,10 +110,10 @@ auto CudaDriver::run_kernel(
// }
// int block_dim = 256;
// int grid_dim = (n + block_dim - 1) / block_dim;
// std::cout << "block_dim " << block_dim << std::endl;
// std::cout << "grid_dim " << grid_dim << std::endl;
// std::cout << "n " << n << std::endl;
// std::cout << "cuLaunchKernel" << std::endl;
// std::cout << "block_dim " << launch_info.toString()<< std::endl;
// std::cout << "grid_dim " << launch_info << std::endl;
// std::cout << "n " << n << std::endl;
// std::cout << "cuLaunchKernel" << std::endl;

res = cuLaunchKernel(
function,
Expand Down
1 change: 1 addition & 0 deletions py_neon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .dataview import DataView
from .execution import Execution
from .index_3d import Index_3d
from .ngh_idx import Ngh_idx

from .dense.__init__ import *
from .block.__init__ import *
Expand Down
43 changes: 43 additions & 0 deletions py_neon/ngh_idx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import ctypes
import typing


class Ngh_idx(ctypes.Structure):
    """Three-component index laid out as a C struct of signed 8-bit integers.

    The ``_fields_`` layout (``x``, ``y``, ``z`` as ``ctypes.c_int8``) lets an
    instance be handed to native code through ctypes. The object also behaves
    like a read-only 3-element sequence via ``len()`` and ``[]``.
    """

    _fields_ = [("x", ctypes.c_int8),
                ("y", ctypes.c_int8),
                ("z", ctypes.c_int8)]

    def __init__(self,
                 x: int,
                 y: int,
                 z: int):
        """Store the three components in the underlying ``c_int8`` fields."""
        self.x = x
        self.y = y
        self.z = z

    def __len__(self):
        """Return the number of components, which is always 3."""
        return 3

    def __getitem__(self, index):
        """Return component 0 (x), 1 (y) or 2 (z).

        Raises:
            IndexError: for any other index value (negative indices and
                slices are not supported).
        """
        if index == 0:
            return self.x
        if index == 1:
            return self.y
        if index == 2:
            return self.z
        raise IndexError("Index out of range")

    def to_wp_kernel_dim(self) -> typing.Tuple[int, int, int]:
        """Return the components as a plain ``(x, y, z)`` tuple of ints."""
        # NOTE(review): "wp" presumably refers to NVIDIA Warp -- confirm.
        return (self.x, self.y, self.z)

    def __str__(self):
        """Return a multi-line description with the struct address and fields."""
        # The tag names this class; it previously read "Index_3d", a leftover
        # from the class this file was copied from.
        text = "<Ngh_idx: addr=%ld>" % (ctypes.addressof(self))
        text += f"\n\tx: {self.x}"
        text += f"\n\ty: {self.y}"
        text += f"\n\tz: {self.z}"
        return text

    def __eq__(self, other):
        """Component-wise equality against another ``Ngh_idx``.

        Returns ``NotImplemented`` for any other type so Python can fall back
        to its default comparison handling.
        """
        # Compare against this class: ``Index_3d`` is not imported in this
        # module, so referencing it here raised NameError on every comparison.
        if not isinstance(other, Ngh_idx):
            return NotImplemented
        return (self.x == other.x and self.y == other.y and self.z == other.z)

0 comments on commit 7f11a78

Please sign in to comment.