Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Hineven committed May 15, 2024
1 parent e7b9da1 commit cb2393a
Show file tree
Hide file tree
Showing 5 changed files with 515 additions and 194 deletions.
331 changes: 316 additions & 15 deletions src/core/src/render_techniques/migi/migi.comp
Original file line number Diff line number Diff line change
Expand Up @@ -268,15 +268,14 @@ void SSRC_ClearCounters() {
g_RWAllocatedProbeSGCountBuffer[0] = 0;
}


// Allocate fixed probes for current frame
// Allocate fixed uniform probes for current frame
[numthreads(WAVE_SIZE, 1, 1)]
void SSRC_AllocateFixedProbes (int DispatchID : SV_DispatchThreadID) {
ProbeHeader Header = GetScreenProbeHeader(DispatchID);
float Depth = g_DepthTexture.Load(int3(Header.ScreenCoords, 0)).x;
void SSRC_AllocateUniformProbes (int DispatchID : SV_DispatchThreadID) {
ProbeHeader Header;
Header.ScreenPosition = GetUniformScreenProbeScreenPosition(DispatchID);
float Depth = g_DepthTexture.Load(int3(Header.ScreenPosition, 0)).x;
bool bValid = Depth < 1.f;
Header.bValid = bValid;
if(bValid) Header.Rank = ComputeProbeRankFromSplattedError(Header.ScreenCoords);
if(bValid) Header.Rank = ComputeProbeRankFromSplattedError(Header.ScreenPosition);
int BasisCount = bValid ? GetProbeBasisCountFromRank(Header.Rank) : 0;
int BasisOffset = WavePrefixSum(BasisCount);
int BasisCountSum = WaveActiveSum(BasisCount);
Expand All @@ -285,32 +284,334 @@ void SSRC_AllocateFixedProbes (int DispatchID : SV_DispatchThreadID) {
InterlockedAdd(g_RWAllocatedProbeSGCountBuffer[0], BasisCountSum, BasisGroupOffset);
}
BasisGroupOffset = WaveReadLaneFirst(BasisGroupOffset);
Header.BasisOffset = BasisGroupOffset + BasisOffset;
Header.BasisOffset = BasisGroupOffset + BasisOffset;
// A negative linear depth marks an invalid probe
Header.LinearDepth = (bValid ? 1 : -1) * GetLinearDepth(Depth);
int2 TileCoords = int2(DispatchID % MI.TileDimensions.x, DispatchID / MI.TileDimensions.x);
Header.Position = RecoverWorldPositionHiRes(Header.ScreenPosition);
Header.Normal = normalize(g_GeometryNormalTexture.Load(int3(Header.ScreenPosition, 0)).xyz * 2.f - 1.f);
WriteScreenProbeHeader(DispatchID, Header);
}

int g_ProbeDownSampleFactor;
// Reconstructs a position for the uniform screen probe of tile `TileCoords`
// from the probe's screen UV and its stored linear depth.
//   TileCoords  - tile coordinates of the uniform probe.
//   LinearDepth - linear depth stored for that probe.
//   bPrevious   - forwarded to GetUniformScreenProbeScreenUV to select the UV layout.
// Returns the xyz of the depth-scaled position transformed by MI.CameraProjViewInv.
// NOTE(review): MI.CameraProjViewInv is used even when bPrevious is true —
// confirm this is intended for previous-frame probes.
float3 GetUniformScreenProbePositionFromGBuffer (int2 TileCoords, float LinearDepth, bool bPrevious) {
    float2 ProbeNDC = UV2NDC2(GetUniformScreenProbeScreenUV(TileCoords, bPrevious));
    // Pre-multiplying by the linear depth stands in for the perspective divide,
    // so the transformed result is used directly without dividing by w.
    float4 DepthScaledPosition = float4(ProbeNDC * LinearDepth, LinearDepth, 1);
    float4 Transformed = mul(DepthScaledPosition, MI.CameraProjViewInv);
    return Transformed.xyz;
}

// Computes interpolation weights of the 2x2 block of uniform screen probes
// surrounding a pixel.
//   ScreenCoords         - pixel position in screen space.
//   Position             - position of the pixel, used to build the pixel plane.
//   LinearDepth          - linear depth of the pixel; normalises the plane distances.
//   Normal               - surface normal of the pixel, used to build the pixel plane.
//   ScreenTileBaseCoords - (out) tile coordinates of the top-left probe of the 2x2 block.
//   Weights              - (out) per-corner weights, ordered (0,0), (1,0), (0,1), (1,1).
//                          Not normalised.
//   bPrevious            - forwarded to the jitter / depth / position helpers.
void CalculateSSRCSampleWeightsForUniformScreenProbes (
    float2 ScreenCoords,
    float3 Position,
    float LinearDepth,
    float3 Normal,
    out int2 ScreenTileBaseCoords,
    out float4 Weights,
    bool bPrevious = false
) {
    int2 ScreenCoordsProbeGrid = clamp(ScreenCoords - GetTileJitter(SSRC_TILE_SIZE, bPrevious), 0, MI.ScreenDimensions - 1.xx);
    // Clamp so that the (+1, +1) neighbour is still inside the tile grid.
    int2 TileCoordsX00 = min(ScreenCoordsProbeGrid / SSRC_TILE_SIZE, MI.TileDimensions - 2);
    // Fix: the out parameter was never written, leaving the caller's base coordinates undefined.
    ScreenTileBaseCoords = TileCoordsX00;

    // Pad the bilinear filtering weights
    int BilinearExpand = 1;
    float2 Bilinear =
        (ScreenCoordsProbeGrid - TileCoordsX00 * SSRC_TILE_SIZE + BilinearExpand)
        / (float)(SSRC_TILE_SIZE + 2 * BilinearExpand);
    Weights = float4(
        (1.f - Bilinear.x) * (1.f - Bilinear.y),
        Bilinear.x * (1.f - Bilinear.y),
        (1.f - Bilinear.x) * Bilinear.y,
        Bilinear.x * Bilinear.y
    );

    float4 CornerLinearDepths;
    CornerLinearDepths.x = GetScreenProbeLinearDepth(TileCoordsX00 + int2(0, 0), bPrevious);
    CornerLinearDepths.y = GetScreenProbeLinearDepth(TileCoordsX00 + int2(1, 0), bPrevious);
    CornerLinearDepths.z = GetScreenProbeLinearDepth(TileCoordsX00 + int2(0, 1), bPrevious);
    CornerLinearDepths.w = GetScreenProbeLinearDepth(TileCoordsX00 + int2(1, 1), bPrevious);

    float3 PositionX00 = GetUniformScreenProbePositionFromGBuffer(TileCoordsX00 + int2(0, 0), CornerLinearDepths.x, bPrevious);
    float3 PositionX10 = GetUniformScreenProbePositionFromGBuffer(TileCoordsX00 + int2(1, 0), CornerLinearDepths.y, bPrevious);
    float3 PositionX01 = GetUniformScreenProbePositionFromGBuffer(TileCoordsX00 + int2(0, 1), CornerLinearDepths.z, bPrevious);
    float3 PositionX11 = GetUniformScreenProbePositionFromGBuffer(TileCoordsX00 + int2(1, 1), CornerLinearDepths.w, bPrevious);

    // Distance of every corner probe to the plane through the pixel.
    float4 PixelPlane = float4(Normal, dot(Position, Normal));
    float4 PlaneDistances;
    PlaneDistances.x = abs(dot(PixelPlane, float4(PositionX00, -1.f)));
    PlaneDistances.y = abs(dot(PixelPlane, float4(PositionX10, -1.f)));
    PlaneDistances.z = abs(dot(PixelPlane, float4(PositionX01, -1.f)));
    PlaneDistances.w = abs(dot(PixelPlane, float4(PositionX11, -1.f)));
    float4 RelativeDepthDifference = PlaneDistances / LinearDepth;

    // A non-positive stored linear depth marks an invalid probe, which gets zero weight.
    float4 DepthWeights = select(
        CornerLinearDepths > 0,
        exp2(-10000.0f * (RelativeDepthDifference * RelativeDepthDifference)),
        0.0
    );
    Weights *= DepthWeights;
}

// Selects up to four screen probes for a pixel and computes their weights.
// Starts from the 2x2 block of uniform probes. For every corner whose weight
// is at most a small epsilon, the adaptive probes listed for that corner's
// tile are searched and the best-weighted one replaces the corner.
//   ScreenCoords - pixel position in screen space.
//   Position     - position of the pixel, used to build the pixel plane.
//   LinearDepth  - linear depth of the pixel; normalises the plane distances.
//   Normal       - surface normal of the pixel.
//   Sample       - (out) Index[0..3] and Weights of the chosen probes. Weights are not normalised.
//   bPrevious    - selects the previous-frame adaptive count texture and is
//                  forwarded to the probe lookup helpers.
void CalculateSSRCSampleWeights (
    float2 ScreenCoords,
    float3 Position,
    float LinearDepth,
    float3 Normal,
    out SSRC_SampleData Sample,
    bool bPrevious = false
) {
    // int2 matches the `out int2` parameter of the uniform-probe helper.
    int2 TileCoordsX00;
    CalculateSSRCSampleWeightsForUniformScreenProbes(
        ScreenCoords,
        Position,
        LinearDepth,
        Normal,
        TileCoordsX00,
        Sample.Weights,
        bPrevious
    );
    Sample.Index[0] = TileCoordsX00 + int2(0, 0);
    Sample.Index[1] = TileCoordsX00 + int2(1, 0);
    Sample.Index[2] = TileCoordsX00 + int2(0, 1);
    Sample.Index[3] = TileCoordsX00 + int2(1, 1);

    // Weight the adaptive probes and search within the nearest tiles
    const float Epsilon = .01f;
    float4 PixelPlane = float4(Normal, dot(Position, Normal));
    for (int CornerIndex = 0; CornerIndex < 4; CornerIndex++)
    {
        // Corners that already have a usable uniform probe are left alone.
        if (Sample.Weights[CornerIndex] > Epsilon)
        {
            continue;
        }

        int2 TileCoords = TileCoordsX00 + int2(CornerIndex % 2, CornerIndex / 2);
        int NumAdaptiveProbes =
            bPrevious ? g_RWPreviousTileAdaptiveProbeCountTexture[TileCoords]
                      : g_RWTileAdaptiveProbeCountTexture[TileCoords];

        for (int AdaptiveProbeListIndex = 0; AdaptiveProbeListIndex < NumAdaptiveProbes; AdaptiveProbeListIndex++)
        {
            // TODO reconstruct probe data from G-Buffer to reduce VRAM bandwidth when shading

            // Adaptive probes are indexed after the uniform probes.
            int AdaptiveProbeIndex1 = GetAdaptiveProbeIndex(TileCoords, AdaptiveProbeListIndex, bPrevious);
            int ScreenProbeIndex1 = AdaptiveProbeIndex1 + MI.UniformScreenProbeCount;
            int2 ScreenProbeIndex = int2(ScreenProbeIndex1 % MI.TileDimensions.x, ScreenProbeIndex1 / MI.TileDimensions.x);

            ProbeHeader Header = GetScreenProbeHeader(ScreenProbeIndex1, bPrevious);
            int2 ProbeScreenPosition = Header.ScreenPosition;

            // Same plane-distance depth weight as used for the uniform probes.
            float PlaneDistance = abs(dot(float4(Header.Position, -1), PixelPlane));
            float RelativeDepthDifference = PlaneDistance / LinearDepth;
            float NewDepthWeight = exp2(-10000.0f * (RelativeDepthDifference * RelativeDepthDifference));

            // Screen-space falloff over one tile, using the smaller of the two axis distances.
            float2 DistanceToScreenProbe = abs(ProbeScreenPosition - ScreenCoords);
            float NewCornerWeight = 1.0f - saturate(min(DistanceToScreenProbe.x, DistanceToScreenProbe.y) / (float)(SSRC_TILE_SIZE));
            float NewInterpolationWeight = NewDepthWeight * NewCornerWeight;

            if (NewInterpolationWeight > Sample.Weights[CornerIndex])
            {
                Sample.Weights[CornerIndex] = NewInterpolationWeight;
                Sample.Index[CornerIndex] = ScreenProbeIndex;
            }
        }
    }
}


// Number of adaptive probes requested by the current thread group.
groupshared int LocalNumProbesToAllocate;
// Offset of this group's probes, as returned by the atomic add on g_RWAdaptiveProbeCountBuffer[0].
groupshared int LocalAdaptiveProbeOffset;
// Screen positions of the probes requested by the current thread group.
groupshared int2 LocalProbeScreenPositionsToAllocate[WAVE_SIZE];

// Allocates adaptive probes for the current frame.
// Phase 1: each thread tests one candidate position on the grid downsampled by
//          g_AdaptiveProbeDownsampleFactor, and requests a probe when the
//          existing probes do not give the pixel a usable interpolation.
// Phase 2: the group reserves a contiguous range with one atomic add.
// Phase 3: the first LocalNumProbesToAllocate threads each write one probe
//          header and reserve basis storage for it.
// NOTE(review): the use of WaveIsFirstLane() with groupshared state assumes one
// wave per thread group — confirm WAVE_SIZE matches the hardware wave size.
[numthreads(WAVE_SIZE, 1, 1)]
void SSRC_AllocateAdaptiveProbes (int DispatchID : SV_DispatchThreadID, int LocalID : SV_GroupThreadID) {
    if (WaveIsFirstLane()) {
        LocalNumProbesToAllocate = 0;
    }
    GroupMemoryBarrierWithGroupSync();

    {
        int2 DownsampledTileDimensions = MI.TileDimensions / g_AdaptiveProbeDownsampleFactor;
        int2 TileCoords = int2(
            DispatchID % DownsampledTileDimensions.x,
            DispatchID / DownsampledTileDimensions.x
        );
        // Compute the screen coords for current adaptive probe
        int2 AdaptiveProbeScreenPosition = TileCoords * g_AdaptiveProbeDownsampleFactor + GetTileJitter(g_AdaptiveProbeDownsampleFactor);
        float Depth = g_DepthTexture.Load(int3(AdaptiveProbeScreenPosition, 0)).x;
        bool bValid = Depth < 1.f;
        if (bValid) {
            float3 WorldPosition = RecoverWorldPositionHiRes(AdaptiveProbeScreenPosition);
            float LinearDepth = GetLinearDepth(Depth);
            float3 GeometryNormal = normalize(g_GeometryNormalTexture.Load(int3(AdaptiveProbeScreenPosition, 0)).xyz * 2.f - 1.f);

            SSRC_SampleData Sample;
            CalculateSSRCSampleWeights(
                AdaptiveProbeScreenPosition,
                WorldPosition,
                LinearDepth,
                GeometryNormal,
                Sample
            );

            // After normalising with a floored denominator, the weights sum to
            // (almost) one only if the raw sum was at least Epsilon.
            float Epsilon = .01f;
            Sample.Weights /= max(dot(Sample.Weights, 1), Epsilon);
            bool bLightingIsValid = !(dot(Sample.Weights, 1) < 1.0f - Epsilon);

            if (!bLightingIsValid)
            {
                int ListIndex;
                InterlockedAdd(LocalNumProbesToAllocate, 1, ListIndex);
                LocalProbeScreenPositionsToAllocate[ListIndex] = AdaptiveProbeScreenPosition;
            }
        }
    }

    GroupMemoryBarrierWithGroupSync();

    if (WaveIsFirstLane()) {
        InterlockedAdd(g_RWAdaptiveProbeCountBuffer[0], LocalNumProbesToAllocate, LocalAdaptiveProbeOffset);
    }

    GroupMemoryBarrierWithGroupSync();

    int AdaptiveProbeIndex = LocalID + LocalAdaptiveProbeOffset;

    // Requests beyond MI.MaxAdaptiveProbeCount are dropped; the counter itself is not clamped here.
    if (LocalID < LocalNumProbesToAllocate && AdaptiveProbeIndex < MI.MaxAdaptiveProbeCount) {
        // Adaptive probes are stored after the uniform probes.
        int ScreenProbeIndex1 = AdaptiveProbeIndex + MI.UniformScreenProbeCount;
        ProbeHeader Header;
        Header.ScreenPosition = LocalProbeScreenPositionsToAllocate[LocalID];
        // The probe must be valid upon allocation.
        // Fix: bValid was previously left unset before the header was written.
        Header.bValid = true;
        float Depth = g_DepthTexture.Load(int3(Header.ScreenPosition, 0)).x;
        Header.Rank = ComputeProbeRankFromSplattedError(Header.ScreenPosition);

        // Reserve basis storage for all allocating lanes with a single atomic add.
        int BasisCount = GetProbeBasisCountFromRank(Header.Rank);
        int BasisOffset = WavePrefixSum(BasisCount);
        int BasisCountSum = WaveActiveSum(BasisCount);
        int BasisGroupOffset;
        if (WaveIsFirstLane()) {
            InterlockedAdd(g_RWAllocatedProbeSGCountBuffer[0], BasisCountSum, BasisGroupOffset);
        }
        BasisGroupOffset = WaveReadLaneFirst(BasisGroupOffset);
        Header.BasisOffset = BasisGroupOffset + BasisOffset;

        Header.LinearDepth = GetLinearDepth(Depth);
        Header.Position = RecoverWorldPositionHiRes(Header.ScreenPosition);
        Header.Normal = normalize(g_GeometryNormalTexture.Load(int3(Header.ScreenPosition, 0)).xyz * 2.f - 1.f);
        int2 ScreenProbeIndex = int2(ScreenProbeIndex1 % MI.TileDimensions.x, ScreenProbeIndex1 / MI.TileDimensions.x);
        WriteScreenProbeHeader(ScreenProbeIndex, Header);
    }
}

// Fills the indirect dispatch arguments for the per-probe SSRC passes:
// a 1D dispatch whose X extent is the uniform probe count plus the counter in
// g_RWActiveProbeCountBuffer[0].
// NOTE(review): the allocation pass increments g_RWAdaptiveProbeCountBuffer and
// does not clamp it to MI.MaxAdaptiveProbeCount — confirm that the counter read
// here is the intended one and cannot exceed the allocated probe storage.
[numthreads(1, 1, 1)]
void SSRC_WriteProbeDispatchParameters () {
    DispatchCommand Command;
    // Only the X dimension varies.
    Command.num_groups_z = 1;
    Command.num_groups_y = 1;
    Command.num_groups_x = MI.UniformScreenProbeCount + g_RWActiveProbeCountBuffer[0];
    g_RWDispatchCommandBuffer[0] = Command;
}

// Per-group scratch storage for the candidate SGs gathered from up to four
// previous-frame probes. Fix: declared groupshared, because every thread of the
// group writes these arrays and the code synchronises them with group barriers.
groupshared float LocalSGSize[SSRC_MAX_NUM_BASIS_PER_PROBE * 4];
groupshared SGData LocalSGData[SSRC_MAX_NUM_BASIS_PER_PROBE * 4];

// Initialize probe cache from the previous frame, one group per probe.
// Steps:
//   1. Find the (up to four) previous-frame probes that interpolate to this probe.
//   2. Gather all of their basis functions into LocalSGData. Each gets a "size"
//      (interpolation weight * SGIntegrate(Lambda) * sum of Color) in LocalSGSize.
//   3. Sort LocalSGData by descending size; ties keep gather order.
//   4. (TODO) merge the sorted candidates down to this probe's basis count.
// NOTE(review): LocalSGSize is not permuted with LocalSGData, so after step 3
// it no longer lines up with the data — permute it too if the merge needs it.
[numthreads(WAVE_SIZE, 1, 1)]
void SSRC_ReprojectHistory (int LocalID : SV_GroupThreadID, int GroupID : SV_GroupID) {
    int2 ProbeIndex = int2(GroupID % MI.TileDimensions.x, GroupID / MI.TileDimensions.x);
    ProbeHeader Header = GetScreenProbeHeader(ProbeIndex);
    SSRC_SampleData Sample;
    CalculateSSRCSampleWeights(
        Header.ScreenPosition,
        Header.Position,
        Header.LinearDepth,
        Header.Normal,
        Sample,
        true
    );

    int BasisOffsets[4];
    BasisOffsets[0] = GetScreenProbeBasisOffset(Sample.Index[0], true);
    BasisOffsets[1] = GetScreenProbeBasisOffset(Sample.Index[1], true);
    BasisOffsets[2] = GetScreenProbeBasisOffset(Sample.Index[2], true);
    BasisOffsets[3] = GetScreenProbeBasisOffset(Sample.Index[3], true);

    int BasisCount[4];
    BasisCount[0] = GetProbeBasisCountFromRank(GetScreenProbeHeader(Sample.Index[0], true).Rank);
    BasisCount[1] = GetProbeBasisCountFromRank(GetScreenProbeHeader(Sample.Index[1], true).Rank);
    BasisCount[2] = GetProbeBasisCountFromRank(GetScreenProbeHeader(Sample.Index[2], true).Rank);
    BasisCount[3] = GetProbeBasisCountFromRank(GetScreenProbeHeader(Sample.Index[3], true).Rank);

    // The candidates of the four source probes are laid out back to back:
    // SegmentStart[c] is the first candidate slot belonging to corner c.
    int SegmentStart[4];
    SegmentStart[0] = 0;
    SegmentStart[1] = SegmentStart[0] + BasisCount[0];
    SegmentStart[2] = SegmentStart[1] + BasisCount[1];
    SegmentStart[3] = SegmentStart[2] + BasisCount[2];
    int NumBasis = SegmentStart[3] + BasisCount[3];

    // Fetch the basis data from the previous frame.
    // Fix: always iterate in WAVE_SIZE strides. The former single-pass shortcut
    // was keyed on SSRC_MAX_BASIS_PER_TILE, a different macro from the one
    // sizing the arrays, so candidates beyond WAVE_SIZE could be skipped.
    // NOTE(review): FetchBasisData takes no "previous" flag here — confirm it
    // reads the previous frame's basis buffer.
    for (int FetchBase = 0; FetchBase < NumBasis; FetchBase += WAVE_SIZE) {
        int BasisRank = FetchBase + LocalID;
        if (BasisRank < NumBasis) {
            // Later non-empty segments override earlier ones, so empty segments are skipped.
            int Corner = 0;
            if (BasisRank >= SegmentStart[1]) Corner = 1;
            if (BasisRank >= SegmentStart[2]) Corner = 2;
            if (BasisRank >= SegmentStart[3]) Corner = 3;
            // Fix: index relative to the owning probe's segment. The old code
            // added BasisRank twice and did not rebase corners 1 and 3.
            int BasisOffset = BasisOffsets[Corner] + (BasisRank - SegmentStart[Corner]);
            SGData SG = FetchBasisData(BasisOffset);
            LocalSGData[BasisRank] = SG;
            LocalSGSize[BasisRank] = Sample.Weights[Corner] * SGIntegrate(SG.Lambda) * dot(SG.Color, 1.f.xxx);
        }
    }
    GroupMemoryBarrierWithGroupSync();

    // Simple N^2 sort: every thread counts how many candidates precede its own
    // and then scatters its candidate to that slot.
    {
        SGData ThreadSG[(SSRC_MAX_NUM_BASIS_PER_PROBE*4 + WAVE_SIZE - 1) / WAVE_SIZE];
        int ThreadSGRank[(SSRC_MAX_NUM_BASIS_PER_PROBE*4 + WAVE_SIZE - 1) / WAVE_SIZE];
        for (int GatherPass = 0; GatherPass * WAVE_SIZE < NumBasis; GatherPass++) {
            int BasisRank = GatherPass * WAVE_SIZE + LocalID;
            if (BasisRank < NumBasis) {
                float Weight = LocalSGSize[BasisRank];
                int SortedRank = 0;
                for (int Other = 0; Other < NumBasis; Other++) {
                    // Larger sizes come first; equal sizes keep their gather order.
                    if (LocalSGSize[Other] > Weight || (LocalSGSize[Other] == Weight && Other < BasisRank)) {
                        SortedRank++;
                    }
                }
                ThreadSGRank[GatherPass] = SortedRank;
                ThreadSG[GatherPass] = LocalSGData[BasisRank];
            }
        }
        // All reads of the unsorted data must finish before any slot is overwritten.
        GroupMemoryBarrierWithGroupSync();
        for (int ScatterPass = 0; ScatterPass * WAVE_SIZE < NumBasis; ScatterPass++) {
            int BasisRank = ScatterPass * WAVE_SIZE + LocalID;
            if (BasisRank < NumBasis) {
                LocalSGData[ThreadSGRank[ScatterPass]] = ThreadSG[ScatterPass];
            }
        }
    }
    GroupMemoryBarrierWithGroupSync();

    int CurrentProbeBasisCount = GetProbeBasisCountFromRank(Header.Rank);
    // Progressively merges the candidate SGs.
    {
        // TODO use python to test the performance of different merging strategies
    }
}

Expand Down
Loading

0 comments on commit cb2393a

Please sign in to comment.