-
Notifications
You must be signed in to change notification settings - Fork 0
/
bdpt.py
44 lines (34 loc) · 1.26 KB
/
bdpt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import mitsuba as mi
import drjit as dr
import matplotlib.pyplot as plt
from typing import Type, TypeVar, overload
# Select the Mitsuba variant up front; the variant-dependent `mi.*` types
# used below (e.g. `mi.UInt32`) are only available once a variant is set.
mi.set_variant("llvm_ad_rgb")
# Debugging aid, disabled by default:
# dr.set_log_level(dr.LogLevel.Debug)
# Type variable standing for the vertex type stored in a `Path`.
T = TypeVar("T")
class Path:
    """Flat, depth-major storage for the vertices of many light paths.

    All vertices live in a single Dr.Jit array holding
    ``max_depth * n_rays`` entries. The vertex of ray ``r`` at depth ``d``
    is stored at flat index ``d * n_rays + r``.

    Supported indexing:

    * ``path[depth] = value`` scatters ``value`` to every ray at ``depth``.
    * ``path[depth]`` gathers the vertex of every ray at ``depth``.
    * ``path[depth, ray]`` gathers the vertex at an explicit
      ``(depth, ray_index)`` pair.

    ``depth`` and ``ray`` may each be a ``mi.UInt32`` array or a plain
    Python ``int``.
    """

    # Ray indices 0 .. n_rays - 1; added to ``depth * n_rays`` to address
    # every ray at a given depth.
    idx: mi.UInt32

    def __init__(self, dtype: Type[T], n_rays: int, max_depth: int):
        """Allocate zero-initialised storage for ``max_depth * n_rays`` vertices.

        Args:
            dtype: Type of a single vertex; it is passed to ``dr.zeros`` here
                and to ``dr.gather`` on reads.
            n_rays: Number of rays (paths) stored per depth level.
            max_depth: Number of depth levels reserved per ray.
        """
        self.n_rays = n_rays
        self.max_depth = max_depth
        self.idx = dr.arange(mi.UInt32, n_rays)
        self.dtype = dtype
        self.vertices = dr.zeros(dtype, shape=max_depth * n_rays)

    def __setitem__(self, depth: "int | mi.UInt32", value: T):
        """Write ``value`` as the vertex of every ray at ``depth``.

        The scatter index spans all ``n_rays`` rays, so ``value`` is expected
        to provide one entry per ray (or be broadcastable to that width).
        """
        dr.scatter(self.vertices, value, depth * self.n_rays + self.idx)

    # Return vertex at depth
    @overload
    def __getitem__(self, depth: "int | mi.UInt32") -> T:
        ...

    # Return a vertex at (depth, ray_index)
    @overload
    def __getitem__(self, idx: "tuple[int | mi.UInt32, int | mi.UInt32]") -> T:
        ...

    def __getitem__(self, idx):
        """Gather vertices by depth, or by an explicit ``(depth, ray)`` pair.

        Args:
            idx: Either a depth (vertices of all rays at that depth are
                returned) or a 2-tuple ``(depth, ray_index)``.

        Returns:
            The gathered vertices as an instance of ``self.dtype``.

        Raises:
            TypeError: If ``idx`` is neither a depth nor a 2-tuple of
                ``int`` / ``mi.UInt32`` values. Previously such indices
                silently produced ``None``.
        """
        index_types = (int, mi.UInt32)

        if isinstance(idx, tuple):
            if len(idx) == 2 and all(isinstance(i, index_types) for i in idx):
                depth, ray = idx
                # Wrap in mi.UInt32 so that a pair of plain ints still
                # reaches dr.gather as a Dr.Jit index array.
                flat = mi.UInt32(depth * self.n_rays + ray)
                return dr.gather(self.dtype, self.vertices, flat)
        elif isinstance(idx, index_types):
            return dr.gather(self.dtype, self.vertices, idx * self.n_rays + self.idx)

        raise TypeError(
            "Path indices must be a depth or a (depth, ray_index) pair of "
            f"int / mi.UInt32 values, got {idx!r}"
        )