From 46d8063e7e6daf9a980205c5fee20c689f7b2a06 Mon Sep 17 00:00:00 2001 From: Martin Date: Wed, 12 Jul 2023 11:30:33 +0200 Subject: [PATCH] Add caching for raytracing shaders and effects --- src/Libs/FShade.Core/RaytracingEffect.fs | 33 ++++++++++++------ src/Libs/FShade.Core/RaytracingShader.fs | 43 +++++++++++++++--------- src/Libs/FShade.Debug/Debugger.fs | 4 +-- src/Libs/FShade.Imperative/Ast.fs | 12 +++---- 4 files changed, 58 insertions(+), 34 deletions(-) diff --git a/src/Libs/FShade.Core/RaytracingEffect.fs b/src/Libs/FShade.Core/RaytracingEffect.fs index 68b8d71..7f02947 100644 --- a/src/Libs/FShade.Core/RaytracingEffect.fs +++ b/src/Libs/FShade.Core/RaytracingEffect.fs @@ -97,12 +97,13 @@ module private RaytracingUtilities = RebuildShapeCombination(o, args) let prepareShader (sbt : ShaderBindingTableLayout) (slot : ShaderSlot) (shader : RaytracingShader) = - let prepared = + let preparedBody = shader.Body |> substituteStubs sbt |> setShaderSlotForReads slot - shader.Body <- prepared + let prepared = { shader.Shader with shaderBody = preparedBody } + RaytracingShader(shader.Id, prepared, ?definition = shader.SourceDefinition) let computeHash (shaders : Map) = use hash = SHA1.Create() @@ -135,12 +136,10 @@ module private RaytracingUtilities = hash.Hash |> Convert.ToBase64String -type RaytracingEffect(shaders : Map) = +type RaytracingEffect internal (id : string, shaders : Map) = do for KeyValue(slot, shader) in shaders do if shader.Stage <> slot.Stage then raise <| ArgumentException($"Invalid {slot.Stage} shader in slot {slot}.") - let id = computeHash shaders - let shaderBindingTableLayout = lazy ( let shaders = shaders |> Map.values |> Array.ofSeq @@ -154,14 +153,13 @@ type RaytracingEffect(shaders : Map) = |> Array.fold Map.union Map.empty ) - let shaders = + let preparedShaders = lazy ( let sbt = shaderBindingTableLayout.Value - for KeyValue(slot, shader) in shaders do + shaders |> Map.map (fun slot shader -> shader |> prepareShader sbt slot - - shaders + ) ) member x.Id = id @@ -170,13 +168,26 @@ type RaytracingEffect(shaders : Map) = shaderBindingTableLayout.Value member x.Shaders = - shaders.Value + preparedShaders.Value + + /// Returns the individual unprepared shaders, i.e. the shader binding layout of the effect has not been applied. + member x.ShadersWithStubs = + shaders member x.Uniforms = uniforms.Value [] module RaytracingEffect = + open System.Collections.Concurrent + + let private cache = ConcurrentDictionary() + + let ofShaders (shaders : Map) = + let hash = computeHash shaders + cache.GetOrAdd(hash, fun hash -> + RaytracingEffect(hash, shaders) + ) let toModule (effect : RaytracingEffect) = Serializer.Init() @@ -409,7 +420,7 @@ module RaytracingBuilders = member inline x.Delay f = f() member x.Run(shaders : Map) = - RaytracingEffect(shaders) + RaytracingEffect.ofShaders shaders member private x.SetRaygen(shader : RaytracingShader) = diff --git a/src/Libs/FShade.Core/RaytracingShader.fs b/src/Libs/FShade.Core/RaytracingShader.fs index bfc4496..7a49220 100644 --- a/src/Libs/FShade.Core/RaytracingShader.fs +++ b/src/Libs/FShade.Core/RaytracingShader.fs @@ -4,20 +4,15 @@ open System open FSharp.Quotations open FShade.Imperative -type RaytracingShader internal (shader : Shader, ?definition : SourceDefinition) = - static do Serializer.Init() +type RaytracingShader internal (id : string, shader : Shader, ?definition : SourceDefinition) = do assert (ShaderStage.isRaytracing shader.shaderStage) - let id = Expr.ComputeHash shader.shaderBody - let mutable shader = shader - member x.Id = id member x.Shader = shader member x.SourceDefinition = definition - member x.Body - with get() = shader.shaderBody - and internal set(e) = shader <- { shader with shaderBody = e } + member inline x.Body = + x.Shader.shaderBody member inline x.Stage = x.Shader.shaderStage @@ -36,21 +31,39 @@ type RaytracingShader internal (shader : Shader, ?definition : SourceDefinition) [] module RaytracingShader = + open System.Collections.Concurrent + + let private cache = ConcurrentDictionary() let ofShader (shader : Shader) = - RaytracingShader shader + Serializer.Init() + let hash = Expr.ComputeHash shader.shaderBody + + cache.GetOrAdd(hash, fun hash -> + RaytracingShader(hash, shader) + ) let ofExpr (inputTypes : Type list) (expr : Expr) = - let shaders = expr |> Shader.ofExpr inputTypes - let definition = expr |> SourceDefinition.ofExpr inputTypes - RaytracingShader(shaders.Head, definition) + Serializer.Init() + let hash = Expr.ComputeHash expr + + cache.GetOrAdd(hash, fun hash -> + let shaders = expr |> Shader.ofExpr inputTypes + let definition = expr |> SourceDefinition.ofExpr inputTypes + RaytracingShader(hash, shaders.Head, definition) + ) let ofFunction (shaderFunction : 'a -> Expr<'b>) = match Shader.Utils.tryExtractExpr shaderFunction with | Some (expr, types) -> - let shader = expr |> Shader.ofExpr types |> List.head - let definition = expr |> SourceDefinition.create types shaderFunction - RaytracingShader(shader, definition) + Serializer.Init() + let hash = Expr.ComputeHash expr + + cache.GetOrAdd(hash, fun hash -> + let shader = expr |> Shader.ofExpr types |> List.head + let definition = expr |> SourceDefinition.create types shaderFunction + RaytracingShader(hash, shader, definition) + ) | _ -> failwithf "[FShade] cannot create raytracing shader using function: %A" shaderFunction diff --git a/src/Libs/FShade.Debug/Debugger.fs b/src/Libs/FShade.Debug/Debugger.fs index 4f22e97..802fbe1 100644 --- a/src/Libs/FShade.Debug/Debugger.fs +++ b/src/Libs/FShade.Debug/Debugger.fs @@ -310,10 +310,10 @@ module ShaderDebugger = "raytracing effect" (fun effect -> effect.Id) (fun shader -> shader.Id) - (fun effect -> effect.Shaders) + (fun effect -> effect.ShadersWithStubs) ShaderDefinition.ofRaytracingShader ShaderDefinition.tryResolveRaytracingShader - (fun shaders -> RaytracingEffect shaders) + RaytracingEffect.ofShaders let private tryRegisterComputeShader : ComputeShader -> aval option = tryRegister diff --git a/src/Libs/FShade.Imperative/Ast.fs b/src/Libs/FShade.Imperative/Ast.fs index e53266e..fe73bb6 100644 --- a/src/Libs/FShade.Imperative/Ast.fs +++ b/src/Libs/FShade.Imperative/Ast.fs @@ -55,7 +55,7 @@ type CType = | CArray of elementType : CType * length : int | CPointer of modifier : CPointerModifier * elementType : CType - | CStruct of name : string * fields : list * original : Option + | CStruct of name : string * fields : list * original : Option // TODO: Remove original, breaks equality when using the shader debugger | CIntrinsic of CIntrinsicType [] @@ -125,11 +125,11 @@ module CType = let name = typeName t if FSharpType.IsRecord(t, true) then let fields = FSharpType.GetRecordFields(t, true) |> Array.toList |> List.map (fun pi -> ofTypeInternal seen b pi.PropertyType, pi.Name) - CStruct(name, fields, Some t) + CStruct(name, fields, None) elif FSharpType.IsTuple t then let fields = FSharpType.GetTupleElements(t) |> Array.toList |> List.mapi (fun i t -> ofTypeInternal seen b t, sprintf "Item%d" i) - CStruct(name, fields, Some t) + CStruct(name, fields, None) elif FSharpType.IsUnion(t, true) then let caseFields = @@ -141,17 +141,17 @@ module CType = ) let tagField = (CType.CInt(true, 32), "tag") - CStruct(name, tagField :: caseFields, Some t) + CStruct(name, tagField :: caseFields, None) elif t.IsValueType then let fields = t.GetFields(BindingFlags.NonPublic ||| BindingFlags.Public ||| BindingFlags.Instance) let fields = fields |> Array.sortBy (fun f -> System.Runtime.InteropServices.Marshal.OffsetOf(f.DeclaringType, f.Name)) |> Array.toList |> List.map (fun fi -> ofTypeInternal seen b fi.FieldType, fi.Name) - CStruct(name, fields, Some t) + CStruct(name, fields, None) else let fields = t.GetFields(BindingFlags.NonPublic ||| BindingFlags.Public ||| BindingFlags.Instance) let fields = fields |> Array.toList |> List.map (fun fi -> ofTypeInternal seen b fi.FieldType, fi.Name) - CStruct(name, fields, Some t) + CStruct(name, fields, None) /// creates a c representation for a given system type let ofType (b : IBackend) (t : Type) : CType =