Skip to content

Commit

Permalink
Add caching for raytracing shaders and effects
Browse files Browse the repository at this point in the history
  • Loading branch information
hyazinthh committed Jul 13, 2023
1 parent c2d3aa8 commit 46d8063
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 34 deletions.
33 changes: 22 additions & 11 deletions src/Libs/FShade.Core/RaytracingEffect.fs
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,13 @@ module private RaytracingUtilities =
RebuildShapeCombination(o, args)

let prepareShader (sbt : ShaderBindingTableLayout) (slot : ShaderSlot) (shader : RaytracingShader) =
let prepared =
let preparedBody =
shader.Body
|> substituteStubs sbt
|> setShaderSlotForReads slot

shader.Body <- prepared
let prepared = { shader.Shader with shaderBody = preparedBody }
RaytracingShader(shader.Id, prepared, ?definition = shader.SourceDefinition)

let computeHash (shaders : Map<ShaderSlot, RaytracingShader>) =
use hash = SHA1.Create()
Expand Down Expand Up @@ -135,12 +136,10 @@ module private RaytracingUtilities =
hash.Hash |> Convert.ToBase64String


type RaytracingEffect(shaders : Map<ShaderSlot, RaytracingShader>) =
type RaytracingEffect internal (id : string, shaders : Map<ShaderSlot, RaytracingShader>) =
do for KeyValue(slot, shader) in shaders do
if shader.Stage <> slot.Stage then raise <| ArgumentException($"Invalid {slot.Stage} shader in slot {slot}.")

let id = computeHash shaders

let shaderBindingTableLayout =
lazy (
let shaders = shaders |> Map.values |> Array.ofSeq
Expand All @@ -154,14 +153,13 @@ type RaytracingEffect(shaders : Map<ShaderSlot, RaytracingShader>) =
|> Array.fold Map.union Map.empty
)

let shaders =
let preparedShaders =
lazy (
let sbt = shaderBindingTableLayout.Value

for KeyValue(slot, shader) in shaders do
shaders |> Map.map (fun slot shader ->
shader |> prepareShader sbt slot

shaders
)
)

member x.Id = id
Expand All @@ -170,13 +168,26 @@ type RaytracingEffect(shaders : Map<ShaderSlot, RaytracingShader>) =
shaderBindingTableLayout.Value

member x.Shaders =
shaders.Value
preparedShaders.Value

/// Returns the individual unprepared shaders, i.e. the shader binding layout of the effect has not been applied.
member x.ShadersWithStubs =
shaders

member x.Uniforms =
uniforms.Value

[<CompilationRepresentation(CompilationRepresentationFlags.ModuleSuffix)>]
module RaytracingEffect =
open System.Collections.Concurrent

let private cache = ConcurrentDictionary<string, RaytracingEffect>()

let ofShaders (shaders : Map<ShaderSlot, RaytracingShader>) =
let hash = computeHash shaders
cache.GetOrAdd(hash, fun hash ->
RaytracingEffect(hash, shaders)
)

let toModule (effect : RaytracingEffect) =
Serializer.Init()
Expand Down Expand Up @@ -409,7 +420,7 @@ module RaytracingBuilders =
member inline x.Delay f = f()

member x.Run(shaders : Map<ShaderSlot, RaytracingShader>) =
RaytracingEffect(shaders)
RaytracingEffect.ofShaders shaders


member private x.SetRaygen(shader : RaytracingShader) =
Expand Down
43 changes: 28 additions & 15 deletions src/Libs/FShade.Core/RaytracingShader.fs
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,15 @@ open System
open FSharp.Quotations
open FShade.Imperative

type RaytracingShader internal (shader : Shader, ?definition : SourceDefinition) =
static do Serializer.Init()
type RaytracingShader internal (id : string, shader : Shader, ?definition : SourceDefinition) =
do assert (ShaderStage.isRaytracing shader.shaderStage)

let id = Expr.ComputeHash shader.shaderBody
let mutable shader = shader

member x.Id = id
member x.Shader = shader
member x.SourceDefinition = definition

member x.Body
with get() = shader.shaderBody
and internal set(e) = shader <- { shader with shaderBody = e }
member inline x.Body =
x.Shader.shaderBody

member inline x.Stage =
x.Shader.shaderStage
Expand All @@ -36,21 +31,39 @@ type RaytracingShader internal (shader : Shader, ?definition : SourceDefinition)

[<CompilationRepresentation(CompilationRepresentationFlags.ModuleSuffix)>]
module RaytracingShader =
open System.Collections.Concurrent

let private cache = ConcurrentDictionary<string, RaytracingShader>()

let ofShader (shader : Shader) =
RaytracingShader shader
Serializer.Init()
let hash = Expr.ComputeHash shader.shaderBody

cache.GetOrAdd(hash, fun hash ->
RaytracingShader(hash, shader)
)

let ofExpr (inputTypes : Type list) (expr : Expr) =
let shaders = expr |> Shader.ofExpr inputTypes
let definition = expr |> SourceDefinition.ofExpr inputTypes
RaytracingShader(shaders.Head, definition)
Serializer.Init()
let hash = Expr.ComputeHash expr

cache.GetOrAdd(hash, fun hash ->
let shaders = expr |> Shader.ofExpr inputTypes
let definition = expr |> SourceDefinition.ofExpr inputTypes
RaytracingShader(hash, shaders.Head, definition)
)

let ofFunction (shaderFunction : 'a -> Expr<'b>) =
match Shader.Utils.tryExtractExpr shaderFunction with
| Some (expr, types) ->
let shader = expr |> Shader.ofExpr types |> List.head
let definition = expr |> SourceDefinition.create types shaderFunction
RaytracingShader(shader, definition)
Serializer.Init()
let hash = Expr.ComputeHash expr

cache.GetOrAdd(hash, fun hash ->
let shader = expr |> Shader.ofExpr types |> List.head
let definition = expr |> SourceDefinition.create types shaderFunction
RaytracingShader(hash, shader, definition)
)
| _ ->
failwithf "[FShade] cannot create raytracing shader using function: %A" shaderFunction

Expand Down
4 changes: 2 additions & 2 deletions src/Libs/FShade.Debug/Debugger.fs
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,10 @@ module ShaderDebugger =
"raytracing effect"
(fun effect -> effect.Id)
(fun shader -> shader.Id)
(fun effect -> effect.Shaders)
(fun effect -> effect.ShadersWithStubs)
ShaderDefinition.ofRaytracingShader
ShaderDefinition.tryResolveRaytracingShader
(fun shaders -> RaytracingEffect shaders)
RaytracingEffect.ofShaders

let private tryRegisterComputeShader : ComputeShader -> aval<ComputeShader> option =
tryRegister<ComputeShader, ComputeShader, int>
Expand Down
12 changes: 6 additions & 6 deletions src/Libs/FShade.Imperative/Ast.fs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ type CType =

| CArray of elementType : CType * length : int
| CPointer of modifier : CPointerModifier * elementType : CType
| CStruct of name : string * fields : list<CType * string> * original : Option<Type>
| CStruct of name : string * fields : list<CType * string> * original : Option<Type> // TODO: Remove original, breaks equality when using the shader debugger
| CIntrinsic of CIntrinsicType

[<AllowNullLiteral>]
Expand Down Expand Up @@ -125,11 +125,11 @@ module CType =
let name = typeName t
if FSharpType.IsRecord(t, true) then
let fields = FSharpType.GetRecordFields(t, true) |> Array.toList |> List.map (fun pi -> ofTypeInternal seen b pi.PropertyType, pi.Name)
CStruct(name, fields, Some t)
CStruct(name, fields, None)

elif FSharpType.IsTuple t then
let fields = FSharpType.GetTupleElements(t) |> Array.toList |> List.mapi (fun i t -> ofTypeInternal seen b t, sprintf "Item%d" i)
CStruct(name, fields, Some t)
CStruct(name, fields, None)

elif FSharpType.IsUnion(t, true) then
let caseFields =
Expand All @@ -141,17 +141,17 @@ module CType =
)

let tagField = (CType.CInt(true, 32), "tag")
CStruct(name, tagField :: caseFields, Some t)
CStruct(name, tagField :: caseFields, None)

elif t.IsValueType then
let fields = t.GetFields(BindingFlags.NonPublic ||| BindingFlags.Public ||| BindingFlags.Instance)
let fields = fields |> Array.sortBy (fun f -> System.Runtime.InteropServices.Marshal.OffsetOf(f.DeclaringType, f.Name)) |> Array.toList |> List.map (fun fi -> ofTypeInternal seen b fi.FieldType, fi.Name)
CStruct(name, fields, Some t)
CStruct(name, fields, None)

else
let fields = t.GetFields(BindingFlags.NonPublic ||| BindingFlags.Public ||| BindingFlags.Instance)
let fields = fields |> Array.toList |> List.map (fun fi -> ofTypeInternal seen b fi.FieldType, fi.Name)
CStruct(name, fields, Some t)
CStruct(name, fields, None)

/// creates a c representation for a given system type
let ofType (b : IBackend) (t : Type) : CType =
Expand Down

0 comments on commit 46d8063

Please sign in to comment.