From 1cc65e8218375294a989722103f94feef2b54ff3 Mon Sep 17 00:00:00 2001 From: Misaki Date: Fri, 8 May 2026 19:36:24 +0900 Subject: [PATCH] refactor(shader): rewrite editor shader compilation bridge with keyword resolution - Add AssetCatalog.EnumerateByTypes for filtered SQL queries - Thread LocalKeywordSet through IShaderCompilationBridge API so the bridge can resolve keyword bitmask to string defines at compile time - Eager/lazy popluation of shader-id-to-asset-id map eliminating full catalog scan per compilation miss - Build keyword mapping from PassDescriptor groups to reconstruct localIndex -> keywordName in the bridge - Use CompileShaderPass extension for multi-stage AS/MS/PS compilation with correct ShaderModel from descriptor - Remove double-load of shader asset in compilation flow - Update test mock to match new interface signature --- .../Services/AssetCatalog.cs | 30 +- .../Services/EditorShaderCompilerBridge.cs | 377 ++++++++++++------ src/Editor/Ghost.Editor/App.xaml.cs | 2 + src/Runtime/Ghost.Graphics.RHI/Common.cs | 2 +- .../IShaderCompilationBridge.cs | 4 +- .../Ghost.Graphics/Core/RenderContext.cs | 2 +- src/Runtime/Ghost.Graphics/Core/Shader.cs | 2 +- .../RenderGraphModule/RenderGraphContext.cs | 6 +- .../Ghost.Graphics/Services/ShaderLibrary.cs | 4 +- .../Graphics/ShaderLibraryTest.cs | 6 +- 10 files changed, 306 insertions(+), 129 deletions(-) diff --git a/src/Editor/Ghost.Editor.Core/Services/AssetCatalog.cs b/src/Editor/Ghost.Editor.Core/Services/AssetCatalog.cs index d2612bc..976aa42 100644 --- a/src/Editor/Ghost.Editor.Core/Services/AssetCatalog.cs +++ b/src/Editor/Ghost.Editor.Core/Services/AssetCatalog.cs @@ -89,6 +89,7 @@ public sealed partial class AssetCatalog ); CREATE UNIQUE INDEX IF NOT EXISTS idx_assets_path ON assets(source_path); CREATE INDEX IF NOT EXISTS idx_assets_parent ON assets(parent_guid); + CREATE INDEX IF NOT EXISTS idx_assets_type_id ON assets(asset_type_id); CREATE TABLE IF NOT EXISTS dependencies ( from_guid BLOB(16) NOT NULL REFERENCES assets(guid) ON DELETE CASCADE, @@ -272,7 +273,7 @@ public sealed partial class AssetCatalog { using var connection = OpenConnection(); using var cmd = connection.CreateCommand(); - + cmd.CommandText = SqlEnumerate; using var reader = cmd.ExecuteReader(); while (reader.Read()) @@ -281,6 +282,33 @@ public sealed partial class AssetCatalog } } + public IEnumerable EnumerateByTypes(params Guid[] assetTypeIds) + { + if (assetTypeIds.Length == 0) + { + yield break; + } + + using var connection = OpenConnection(); + using var cmd = connection.CreateCommand(); + + var parameterNames = new List(assetTypeIds.Length); + for (int i = 0; i < assetTypeIds.Length; i++) + { + string paramName = $"@typeId{i}"; + parameterNames.Add(paramName); + cmd.Parameters.AddWithValue(paramName, assetTypeIds[i].ToByteArray()); + } + + cmd.CommandText = $"SELECT guid FROM assets WHERE asset_type_id IN ({string.Join(", ", parameterNames)})"; + + using var reader = cmd.ExecuteReader(); + while (reader.Read()) + { + yield return new Guid((byte[])reader[0]); + } + } + public List GetSubAssets(Guid parentGuid) { using var connection = OpenConnection(); diff --git a/src/Editor/Ghost.Editor.Core/Services/EditorShaderCompilerBridge.cs b/src/Editor/Ghost.Editor.Core/Services/EditorShaderCompilerBridge.cs index ef26416..d25427c 100644 --- a/src/Editor/Ghost.Editor.Core/Services/EditorShaderCompilerBridge.cs +++ b/src/Editor/Ghost.Editor.Core/Services/EditorShaderCompilerBridge.cs @@ -3,24 +3,26 @@ using Ghost.Core.Graphics; using Ghost.Editor.Core.Assets; using Ghost.Editor.Core.Contracts; using Ghost.Editor.Core.Utilities; +using Ghost.Engine; using Ghost.Graphics.Core; using Ghost.Graphics.RHI; using Ghost.Graphics.Services; -using Ghost.Engine; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; +using Misaki.HighPerformance.LowLevel.Buffer; using System.Collections.Concurrent; -using Misaki.HighPerformance.LowLevel.Collections; +using System.Runtime.CompilerServices; namespace Ghost.Editor.Core.Services; -[EditorInjection(EditorInjectionAttribute.ServiceLifetime.Singleton, typeof(IShaderCompilationBridge))] -internal sealed class EditorShaderCompilerBridge : IShaderCompilationBridge, IDisposable +internal sealed class EditorShaderCompilerBridge : IShaderCompilationBridge { private readonly IAssetRegistry _assetRegistry; - private readonly IShaderCompiler _compiler; - private readonly ConcurrentDictionary _shaderIdToAssetId = new(); private readonly IServiceProvider _serviceProvider; + private readonly IShaderCompiler _compiler; + + private readonly ConcurrentDictionary _shaderIdToAssetId = new(); + private readonly ConcurrentDictionary[]> _assetKeywordMappings = new(); + private Task? _shaderDictionaryPopulated; public event Action, ulong>? OnShaderVariantCompiled; @@ -29,7 +31,7 @@ internal sealed class EditorShaderCompilerBridge : IShaderCompilationBridge, IDi _assetRegistry = assetRegistry; _serviceProvider = serviceProvider; _compiler = new DXCShaderCompiler(); - + _assetRegistry.OnAssetImported += OnAssetImported; } @@ -41,20 +43,12 @@ internal sealed class EditorShaderCompilerBridge : IShaderCompilationBridge, IDi var result = _assetRegistry.LoadAssetAsync(guid).AsTask().Result; if (result.IsSuccess) { - ulong nameHash = 0; - if (result.Value is GraphicsShaderAsset graphicsAsset) - { - nameHash = RHIUtility.GetShaderID(graphicsAsset.Descriptor.Name); - } - else if (result.Value is ComputeShaderAsset computeAsset) - { - nameHash = RHIUtility.GetShaderID(computeAsset.Descriptor.Name); - } - + var nameHash = ExtractNameHash(result.Value); if (nameHash != 0) { _shaderIdToAssetId[nameHash] = guid; - + BuildKeywordMappings(result.Value, guid); + var engineCore = _serviceProvider.GetService(); if (engineCore != null) { @@ -67,143 +61,296 @@ internal sealed class EditorShaderCompilerBridge : IShaderCompilationBridge, IDi } } - public void RequestCompilation(ulong shaderId, int passIndex, Key64 variantKey) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ulong ExtractNameHash(IAsset asset) { + if (asset is GraphicsShaderAsset graphicsAsset) + { + return RHIUtility.GetShaderID(graphicsAsset.Descriptor.Name); + } + + if (asset is ComputeShaderAsset computeAsset) + { + return RHIUtility.GetShaderID(computeAsset.Descriptor.Name); + } + + return 0; + } + + private Task EnsureShaderDictionaryPopulatedAsync() + { + var existing = Volatile.Read(ref _shaderDictionaryPopulated); + if (existing != null) + { + return existing; + } + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var original = Interlocked.CompareExchange(ref _shaderDictionaryPopulated, tcs.Task, null); + if (original != null) + { + return original; + } + Task.Run(async () => { - if (!_shaderIdToAssetId.TryGetValue(shaderId, out var guid)) + try { var catalog = _assetRegistry.GetAssetCatalog(); - foreach (var (assetGuid, path) in catalog.EnumerateAll()) + var assetGuids = catalog.EnumerateByTypes(typeof(GraphicsShaderAsset).GUID, typeof(ComputeShaderAsset).GUID); + + foreach (var assetGuid in assetGuids) { - if (path.EndsWith(".gshdr") || path.EndsWith(".gcomp")) + var result = await _assetRegistry.LoadAssetAsync(assetGuid); + if (result.IsSuccess) { - var result = await _assetRegistry.LoadAssetAsync(assetGuid); - if (result.IsSuccess) + var nameHash = ExtractNameHash(result.Value); + if (nameHash != 0) { - ulong nameHash = 0; - if (result.Value is GraphicsShaderAsset graphicsAsset) - { - nameHash = RHIUtility.GetShaderID(graphicsAsset.Descriptor.Name); - } - else if (result.Value is ComputeShaderAsset computeAsset) - { - nameHash = RHIUtility.GetShaderID(computeAsset.Descriptor.Name); - } - if (nameHash != 0) - { - _shaderIdToAssetId[nameHash] = assetGuid; - } + _shaderIdToAssetId[nameHash] = assetGuid; + BuildKeywordMappings(result.Value, assetGuid); } } } + + tcs.SetResult(); + } + catch (Exception ex) + { + tcs.SetException(ex); + } + }); + + return tcs.Task; + } + + private void BuildKeywordMappings(IAsset asset, Guid assetId) + { + if (asset is GraphicsShaderAsset graphicsAsset) + { + var passes = graphicsAsset.Descriptor.Passes; + var mappings = new Dictionary[passes.Length]; + for (var i = 0; i < passes.Length; i++) + { + mappings[i] = BuildKeywordMappingFromGroups(passes[i].keywords); } - if (_shaderIdToAssetId.TryGetValue(shaderId, out var assetId)) + _assetKeywordMappings[assetId] = mappings; + } + else if (asset is ComputeShaderAsset computeAsset) + { + var entryCount = computeAsset.Descriptor.ShaderCodes.Length; + var mappings = new Dictionary[entryCount]; + var sharedMapping = BuildKeywordMappingFromGroups(computeAsset.Descriptor.Keywords); + for (var i = 0; i < entryCount; i++) { - var assetResult = await _assetRegistry.LoadAssetAsync(assetId); - if (assetResult.IsSuccess) - { - if (assetResult.Value is GraphicsShaderAsset graphicsAsset) - { - var pass = graphicsAsset.Descriptor.Passes[passIndex]; - await CompileGraphicsPassAsync(shaderId, passIndex, variantKey, pass); - } - else if (assetResult.Value is ComputeShaderAsset computeAsset) - { - var code = computeAsset.Descriptor.ShaderCodes[passIndex]; - await CompileComputePassAsync(shaderId, passIndex, variantKey, code); - } - } + mappings[i] = sharedMapping; + } + + _assetKeywordMappings[assetId] = mappings; + } + } + + private static Dictionary BuildKeywordMappingFromGroups(KeywordsGroup[] groups) + { + var mapping = new Dictionary(); + var localIndex = 0; + + foreach (var group in groups) + { + if (group.keywords == null) + { + continue; + } + + if (group.space != KeywordSpace.Local) + { + continue; + } + + foreach (var kw in group.keywords) + { + mapping[localIndex++] = kw; + } + } + + return mapping; + } + + private static string[] BuildVariantDefines(LocalKeywordSet keywordMask, Dictionary? keywordMapping) + { + if (keywordMapping == null || keywordMapping.Count == 0) + { + return Array.Empty(); + } + + var defines = new List(keywordMapping.Count); + foreach (var (localIndex, keywordName) in keywordMapping) + { + if (keywordMask.IsKeywordEnabled(localIndex)) + { + defines.Add(keywordName); + } + } + + return defines.ToArray(); + } + + private static ReadOnlySpan CombineDefines(ReadOnlySpan staticDefines, ReadOnlySpan variantDefines) + { + if (variantDefines.Length == 0) + { + return staticDefines; + } + + if (staticDefines.Length == 0) + { + return variantDefines; + } + + var combined = new string[staticDefines.Length + variantDefines.Length]; + staticDefines.CopyTo(combined); + variantDefines.CopyTo(combined.AsSpan(staticDefines.Length)); + return combined; + } + + public void RequestCompilation(ulong shaderId, int passIndex, Key64 variantKey, LocalKeywordSet keywordMask) + { + Task.Run(async () => + { + await EnsureShaderDictionaryPopulatedAsync(); + + if (!_shaderIdToAssetId.TryGetValue(shaderId, out var assetId)) + { + return; + } + + var assetResult = await _assetRegistry.LoadAssetAsync(assetId); + if (assetResult.IsFailure) + { + return; + } + + Dictionary? keywordMapping = null; + if (_assetKeywordMappings.TryGetValue(assetId, out var mappings) && passIndex < mappings.Length) + { + keywordMapping = mappings[passIndex]; + } + + if (assetResult.Value is GraphicsShaderAsset graphicsAsset) + { + var pass = graphicsAsset.Descriptor.Passes[passIndex]; + await CompileGraphicsPassAsync(shaderId, passIndex, variantKey, keywordMask, pass, graphicsAsset.Descriptor.ShaderModel, keywordMapping); + } + else if (assetResult.Value is ComputeShaderAsset computeAsset) + { + await CompileComputePassAsync(shaderId, passIndex, variantKey, keywordMask, computeAsset.Descriptor, passIndex, keywordMapping); } }); } - private unsafe Task CompileGraphicsPassAsync(ulong shaderId, int passIndex, Key64 variantKey, PassDescriptor pass) + private unsafe Task CompileGraphicsPassAsync(ulong shaderId, int passIndex, Key64 variantKey, LocalKeywordSet keywordMask, PassDescriptor descriptor, ShaderModel shaderModel, Dictionary? keywordMapping) { - // For simplicity, just compile the pixel shader. A real implementation would compile - // all stages (Mesh/Amp/Vertex/Pixel) defined in the pass descriptor. - var config = new ShaderCompilationConfig + var variantDefines = BuildVariantDefines(keywordMask, keywordMapping); + + var additionalConfig = new ShaderCompilationConfig { - shaderCode = pass.pixelShaderCode.code, - entryPoint = pass.pixelShaderCode.entryPoint, - stage = ShaderStage.PixelShader, - defines = pass.defines, - model = ShaderModel.SM_6_6 + defines = variantDefines, + model = shaderModel, + optimizeLevel = CompilerOptimizeLevel.O3, + options = CompilerOption.None }; - var compileResult = _compiler.Compile(in config, Misaki.HighPerformance.LowLevel.Buffer.AllocationHandle.Persistent); - if (compileResult.IsSuccess) - { - var engineCore = _serviceProvider.GetService(); - if (engineCore != null) - { - using var bytecodeArray = compileResult.Value; - - var byteCode = new ShaderByteCode - { - pCode = (byte*)bytecodeArray.GetUnsafePtr(), - size = (ulong)bytecodeArray.Length - }; - - // Assume 1 stage for now. In reality, we'd pass an array of ShaderByteCode for all stages. - var byteCodes = new Span(ref byteCode); - - engineCore.RenderSystem.ShaderLibrary.CacheCompiledResult(shaderId, passIndex, variantKey, byteCodes); - - // Get the generated hash to fire the event - var dataSpan = new ReadOnlySpan(byteCode.pCode, (int)byteCode.size); - var hash = System.IO.Hashing.XxHash64.HashToUInt64(dataSpan); - OnShaderVariantCompiled?.Invoke(variantKey, hash); - } - } - else + var compileResult = _compiler.CompileShaderPass(ref descriptor, ref additionalConfig, AllocationHandle.Persistent); + if (compileResult.IsFailure) { Ghost.Core.Logger.Error($"Failed to compile graphics shader {shaderId}: {compileResult.Message}"); + return Task.CompletedTask; } + var engineCore = _serviceProvider.GetService(); + if (engineCore == null) + { + return Task.CompletedTask; + } + + using var compiled = compileResult.Value; + + var stageCount = 0; + if (compiled.asResult.IsCreated) stageCount++; + if (compiled.msResult.IsCreated) stageCount++; + if (compiled.psResult.IsCreated) stageCount++; + + var byteCodes = stackalloc ShaderByteCode[stageCount]; + var idx = 0; + if (compiled.asResult.IsCreated) + { + byteCodes[idx++] = new ShaderByteCode { pCode = (byte*)compiled.asResult.GetUnsafePtr(), size = (ulong)compiled.asResult.Length }; + } + + if (compiled.msResult.IsCreated) + { + byteCodes[idx++] = new ShaderByteCode { pCode = (byte*)compiled.msResult.GetUnsafePtr(), size = (ulong)compiled.msResult.Length }; + } + + if (compiled.psResult.IsCreated) + { + byteCodes[idx++] = new ShaderByteCode { pCode = (byte*)compiled.psResult.GetUnsafePtr(), size = (ulong)compiled.psResult.Length }; + } + + var shaderLibrary = engineCore.RenderSystem.ShaderLibrary; + shaderLibrary.CacheCompiledResult(shaderId, passIndex, variantKey, new ReadOnlySpan(byteCodes, stageCount)); + + var (compiledHash, _) = shaderLibrary.GetCompiledHash(shaderId, passIndex, variantKey); + OnShaderVariantCompiled?.Invoke(variantKey, compiledHash); + return Task.CompletedTask; } - private unsafe Task CompileComputePassAsync(ulong shaderId, int passIndex, Key64 variantKey, ShaderCode code) + private unsafe Task CompileComputePassAsync(ulong shaderId, int passIndex, Key64 variantKey, LocalKeywordSet keywordMask, ComputeShaderDescriptor descriptor, int entryIndex, Dictionary? keywordMapping) { + var variantDefines = BuildVariantDefines(keywordMask, keywordMapping); + var fullDefines = CombineDefines(descriptor.Defines, variantDefines); + + var code = descriptor.ShaderCodes[entryIndex]; var config = new ShaderCompilationConfig { shaderCode = code.code, entryPoint = code.entryPoint, stage = ShaderStage.ComputeShader, - defines = Array.Empty(), - model = ShaderModel.SM_6_6 + defines = fullDefines, + model = descriptor.ShaderModel, + optimizeLevel = CompilerOptimizeLevel.O3, + options = CompilerOption.None }; - var compileResult = _compiler.Compile(in config, Misaki.HighPerformance.LowLevel.Buffer.AllocationHandle.Persistent); - if (compileResult.IsSuccess) - { - var engineCore = _serviceProvider.GetService(); - if (engineCore != null) - { - using var bytecodeArray = compileResult.Value; - - var byteCode = new ShaderByteCode - { - pCode = (byte*)bytecodeArray.GetUnsafePtr(), - size = (ulong)bytecodeArray.Length - }; - - var byteCodes = new Span(ref byteCode); - - engineCore.RenderSystem.ShaderLibrary.CacheCompiledResult(shaderId, passIndex, variantKey, byteCodes); - - var dataSpan = new ReadOnlySpan(byteCode.pCode, (int)byteCode.size); - var hash = System.IO.Hashing.XxHash64.HashToUInt64(dataSpan); - OnShaderVariantCompiled?.Invoke(variantKey, hash); - } - } - else + var compileResult = _compiler.Compile(ref config, AllocationHandle.Persistent); + if (compileResult.IsFailure) { Ghost.Core.Logger.Error($"Failed to compile compute shader {shaderId}: {compileResult.Message}"); + return Task.CompletedTask; } + var engineCore = _serviceProvider.GetService(); + if (engineCore == null) + { + return Task.CompletedTask; + } + + using var bytecodeArray = compileResult.Value; + + var byteCode = new ShaderByteCode + { + pCode = (byte*)bytecodeArray.GetUnsafePtr(), + size = (ulong)bytecodeArray.Length + }; + + var shaderLibrary = engineCore.RenderSystem.ShaderLibrary; + shaderLibrary.CacheCompiledResult(shaderId, passIndex, variantKey, new ReadOnlySpan(ref byteCode)); + + var (compiledHash, _) = shaderLibrary.GetCompiledHash(shaderId, passIndex, variantKey); + OnShaderVariantCompiled?.Invoke(variantKey, compiledHash); + return Task.CompletedTask; } diff --git a/src/Editor/Ghost.Editor/App.xaml.cs b/src/Editor/Ghost.Editor/App.xaml.cs index faaf93f..88e9cc4 100644 --- a/src/Editor/Ghost.Editor/App.xaml.cs +++ b/src/Editor/Ghost.Editor/App.xaml.cs @@ -7,6 +7,7 @@ using Ghost.Editor.ViewModels.Controls; using Ghost.Editor.ViewModels.Windows; using Ghost.Editor.Views.Windows; using Ghost.Engine; +using Ghost.Graphics.RHI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using Microsoft.UI.Dispatching; @@ -66,6 +67,7 @@ public partial class App : Application services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(); diff --git a/src/Runtime/Ghost.Graphics.RHI/Common.cs b/src/Runtime/Ghost.Graphics.RHI/Common.cs index cf7dba0..2f05072 100644 --- a/src/Runtime/Ghost.Graphics.RHI/Common.cs +++ b/src/Runtime/Ghost.Graphics.RHI/Common.cs @@ -179,7 +179,7 @@ public readonly struct ShaderPass get; init; } - public LocalKeywordSet KeywordIDs + public LocalKeywordSet DefinedKeywords { get; init; } diff --git a/src/Runtime/Ghost.Graphics.RHI/IShaderCompilationBridge.cs b/src/Runtime/Ghost.Graphics.RHI/IShaderCompilationBridge.cs index d43d8d8..0c085ea 100644 --- a/src/Runtime/Ghost.Graphics.RHI/IShaderCompilationBridge.cs +++ b/src/Runtime/Ghost.Graphics.RHI/IShaderCompilationBridge.cs @@ -2,13 +2,13 @@ using Ghost.Core; namespace Ghost.Graphics.RHI; -public interface IShaderCompilationBridge +public interface IShaderCompilationBridge : IDisposable { /// /// Request the bridge to recompile a shader variant or handle cache misses. /// This is typically called by the ShaderLibrary when a variant hash is not found. /// - void RequestCompilation(ulong shaderId, int passIndex, Key64 variantKey); + void RequestCompilation(ulong shaderId, int passIndex, Key64 variantKey, LocalKeywordSet keywordMask); /// /// Event triggered when a shader variant has been successfully compiled and updated. diff --git a/src/Runtime/Ghost.Graphics/Core/RenderContext.cs b/src/Runtime/Ghost.Graphics/Core/RenderContext.cs index ca520a0..5bec211 100644 --- a/src/Runtime/Ghost.Graphics/Core/RenderContext.cs +++ b/src/Runtime/Ghost.Graphics/Core/RenderContext.cs @@ -379,7 +379,7 @@ public readonly unsafe ref struct RenderContext var variantKey = RHIUtility.CreateShaderVariantKey(entryHash, in keywordSet); // TODO: Refactor this into a helper method. - var (compiledHash, error) = ShaderLibrary.GetCompiledHash(shader.UniqueID, entryIndex, variantKey); + var (compiledHash, error) = ShaderLibrary.GetCompiledHash(shader.UniqueID, entryIndex, variantKey, keywordSet); if (error.IsFailure) { // TODO: Fallback to an error material. diff --git a/src/Runtime/Ghost.Graphics/Core/Shader.cs b/src/Runtime/Ghost.Graphics/Core/Shader.cs index 6f25841..c81983c 100644 --- a/src/Runtime/Ghost.Graphics/Core/Shader.cs +++ b/src/Runtime/Ghost.Graphics/Core/Shader.cs @@ -119,7 +119,7 @@ public partial struct Shader : IResourceReleasable { Key = RHIUtility.GetPassID(_nameHash, i), DefaultState = pass.localPipeline, - KeywordIDs = keywords, + DefinedKeywords = keywords, }; _passIDToLocal[GetPassID(pass.name)] = (ushort)i; diff --git a/src/Runtime/Ghost.Graphics/RenderGraphModule/RenderGraphContext.cs b/src/Runtime/Ghost.Graphics/RenderGraphModule/RenderGraphContext.cs index dbaecb3..50681a3 100644 --- a/src/Runtime/Ghost.Graphics/RenderGraphModule/RenderGraphContext.cs +++ b/src/Runtime/Ghost.Graphics/RenderGraphModule/RenderGraphContext.cs @@ -169,10 +169,10 @@ internal sealed class RenderGraphContext : IUnsafeRenderContext var materialPipeline = material.GetPassPipelineOverride(material.ActivePassIndex); // Mask out the keywords that are not used in this pass. - var variantMask = material._keywordMask & pass.KeywordIDs; + var variantMask = material._keywordMask & pass.DefinedKeywords; var variantKey = RHIUtility.CreateShaderVariantKey(pass.Key, in variantMask); - var (compiledHash, error) = _shaderLibrary.GetCompiledHash(shader.UniqueID, material.ActivePassIndex, variantKey); + var (compiledHash, error) = _shaderLibrary.GetCompiledHash(shader.UniqueID, material.ActivePassIndex, variantKey, variantMask); if (error.IsFailure) { // TODO: Fallback to a default shader or show an error material. @@ -277,7 +277,7 @@ internal sealed class RenderGraphContext : IUnsafeRenderContext var keywordSet = new LocalKeywordSet(); // TODO: Support keywords in compute shader. var variantKey = RHIUtility.CreateShaderVariantKey(entryHash, in keywordSet); - var (compiledHash, error) = _shaderLibrary.GetCompiledHash(shader.UniqueID, entryIndex, variantKey); + var (compiledHash, error) = _shaderLibrary.GetCompiledHash(shader.UniqueID, entryIndex, variantKey, keywordSet); if (error.IsFailure) { // TODO: Fallback to a default shader or show an error material. diff --git a/src/Runtime/Ghost.Graphics/Services/ShaderLibrary.cs b/src/Runtime/Ghost.Graphics/Services/ShaderLibrary.cs index 16ccd3e..248214d 100644 --- a/src/Runtime/Ghost.Graphics/Services/ShaderLibrary.cs +++ b/src/Runtime/Ghost.Graphics/Services/ShaderLibrary.cs @@ -195,14 +195,14 @@ internal unsafe class ShaderLibrary : IDisposable } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public Result GetCompiledHash(ulong id, int passIndex, Key64 variantKey) + public Result GetCompiledHash(ulong id, int passIndex, Key64 variantKey, LocalKeywordSet keywordMask = default) { if (_variantToCompiledHash.TryGetValue(variantKey, out var compiledHash)) { return compiledHash; } - _shaderCompilationBridge?.RequestCompilation(id, passIndex, variantKey); + _shaderCompilationBridge?.RequestCompilation(id, passIndex, variantKey, keywordMask); return Error.NotFound; } diff --git a/src/Test/Ghost.UnitTest/Graphics/ShaderLibraryTest.cs b/src/Test/Ghost.UnitTest/Graphics/ShaderLibraryTest.cs index 662a926..9e5ecb0 100644 --- a/src/Test/Ghost.UnitTest/Graphics/ShaderLibraryTest.cs +++ b/src/Test/Ghost.UnitTest/Graphics/ShaderLibraryTest.cs @@ -30,12 +30,12 @@ public class ShaderLibraryTest private class MockShaderCompilationBridge : IShaderCompilationBridge { - public List<(ulong id, int passIndex, Key64 variantKey)> Requests { get; } = new(); + public List<(ulong id, int passIndex, Key64 variantKey, LocalKeywordSet keywordMask)> Requests { get; } = new(); public event Action, ulong>? OnShaderVariantCompiled; - public void RequestCompilation(ulong shaderId, int passIndex, Key64 variantKey) + public void RequestCompilation(ulong shaderId, int passIndex, Key64 variantKey, LocalKeywordSet keywordMask) { - Requests.Add((shaderId, passIndex, variantKey)); + Requests.Add((shaderId, passIndex, variantKey, keywordMask)); } public void TriggerCompiled(Key64 variantKey, ulong newHash)