using Ghost.Core;
using Ghost.Graphics.D3D12;
using System.Runtime.InteropServices;
using System.Text;
using Win32;
using Win32.Graphics.Direct3D;
using Win32.Graphics.Direct3D.Fxc;
using Win32.Graphics.Direct3D12;

namespace Ghost.Graphics.Data;

/// <summary>
/// Wraps a Direct3D 12 root signature and provides helpers for compiling HLSL
/// and reflecting over compiled shader bytecode.
/// </summary>
public unsafe class Shader
{
    private static readonly Shader s_empty = new("ErrorShader");

    /// <summary>
    /// Shared placeholder instance, constructed with the name "ErrorShader".
    /// </summary>
    public static Shader Empty => s_empty;

    private ComPtr<ID3D12RootSignature> _rootSignature;

    /// <summary>
    /// Read-only view of the native root signature pointer held by this shader.
    /// </summary>
    public ConstPtr<ID3D12RootSignature> RootSignature => new(_rootSignature.Get());

    /// <summary>
    /// Creates a shader instance.
    /// </summary>
    /// <param name="shaderPath">Currently unused; the constructor performs no loading.</param>
    public Shader(string shaderPath)
    {
    }

    /// <summary>
    /// Compiles HLSL source code from a string into shader bytecode.
    /// </summary>
    /// <param name="sourceCode">The string containing the HLSL code.</param>
    /// <param name="entryPoint">The name of the shader entry point function (e.g., "VSMain").</param>
    /// <param name="shaderProfile">The shader model to target (e.g., "vs_5_0", "ps_5_0").</param>
    /// <returns>A byte array containing the compiled shader bytecode.</returns>
    /// <exception cref="ArgumentNullException">Thrown if any argument is <c>null</c>.</exception>
    /// <exception cref="InvalidOperationException">Thrown if shader compilation fails.</exception>
    public static unsafe byte[] CompileShader(string sourceCode, string entryPoint, string shaderProfile)
    {
        ArgumentNullException.ThrowIfNull(sourceCode);
        ArgumentNullException.ThrowIfNull(entryPoint);
        ArgumentNullException.ThrowIfNull(shaderProfile);

        // `using` guarantees both blobs are released on every exit path,
        // including when the failure branch below throws.
        using ComPtr<ID3DBlob> bytecodeBlob = default;
        using ComPtr<ID3DBlob> errorBlob = default;

        // The source is passed together with its length, so it needs no terminator.
        var sourceCodeBytes = Encoding.UTF8.GetBytes(sourceCode);

        // The native D3DCompile reads the entry point and target profile as C strings
        // (no length is supplied), and Encoding.GetBytes does not append a terminator,
        // so add one explicitly. A trailing NUL is harmless if the managed overload
        // already terminates the strings itself.
        var entryPointBytes = Encoding.UTF8.GetBytes(entryPoint + '\0');
        var shaderProfileBytes = Encoding.UTF8.GetBytes(shaderProfile + '\0');

        // Call the D3DCompile function
        var hr = D3DCompile(
            sourceCodeBytes.AsSpan(),
            entryPointBytes.AsSpan(),
            shaderProfileBytes.AsSpan(),
            CompileFlags.EnableStrictness | CompileFlags.Debug,
            bytecodeBlob.GetAddressOf(),
            errorBlob.GetAddressOf()
        );

        if (hr.Failure)
        {
            // If compilation fails, append the compiler output from the error blob (when present).
            var errorMessage = "Shader compilation failed.";
            if (errorBlob.Get() is not null)
            {
                errorMessage += "\n" + ReadBlobText(errorBlob.Get());
            }

            throw new InvalidOperationException(errorMessage);
        }

        // Copy the compiled bytecode from the blob into a managed byte array
        var bytecode = new byte[bytecodeBlob.Get()->GetBufferSize()];
        Marshal.Copy((IntPtr)bytecodeBlob.Get()->GetBufferPointer(), bytecode, 0, bytecode.Length);

        return bytecode;
    }

    /// <summary>
    /// Decodes the contents of a blob as ASCII text.
    /// </summary>
    /// <param name="blob">The blob to read; may be <c>null</c>.</param>
    /// <returns>The decoded text without trailing NUL characters, or an empty string when <paramref name="blob"/> is <c>null</c>.</returns>
    private static string ReadBlobText(ID3DBlob* blob)
    {
        if (blob == null)
        {
            return string.Empty;
        }

        var text = Encoding.ASCII.GetString((byte*)blob->GetBufferPointer(), (int)blob->GetBufferSize());

        // Message blobs typically include the C string terminator in their size; drop it.
        return text.TrimEnd('\0');
    }

    /// <summary>
    /// Reflects over compiled shader bytecode and collects a root CBV parameter for each
    /// constant buffer binding and a static sampler description for each sampler binding.
    /// </summary>
    /// <param name="byteCode">The compiled shader bytecode to reflect.</param>
    /// <exception cref="InvalidOperationException">Thrown if the reflection interface cannot be created.</exception>
    private void LoadShader(Span<byte> byteCode)
    {
        using ComPtr<ID3D12ShaderReflection> reflector = default;

        fixed (void* codePtr = byteCode)
        {
            var hr = D3DReflect(codePtr, (nuint)byteCode.Length, __uuidof<ID3D12ShaderReflection>(), reflector.GetVoidAddressOf());
            if (hr.Failure)
            {
                // Without this check a failed reflection would surface as a null dereference below.
                throw new InvalidOperationException($"Failed to reflect shader bytecode (HRESULT: {hr}).");
            }
        }

        ShaderDescription shaderDesc;
        reflector.Get()->GetDesc(&shaderDesc);

        // NOTE(review): both lists are populated below but not yet consumed anywhere in this
        // method; presumably they are meant to feed CreateRootSignature - confirm.
        var rootParameters = new List<RootParameter>();
        var staticSamplers = new List<StaticSamplerDescription>();

        for (uint i = 0; i < shaderDesc.BoundResources; i++)
        {
            ShaderInputBindDescription bindDesc;
            reflector.Get()->GetResourceBindingDesc(i, &bindDesc);

            switch (bindDesc.Type)
            {
                case ShaderInputType.ConstantBuffer:
                    // Each constant buffer becomes a root CBV visible to all shader stages.
                    var cbufferParam = new RootParameter();
                    cbufferParam.ParameterType = RootParameterType.Cbv;
                    cbufferParam.ShaderVisibility = ShaderVisibility.All;
                    cbufferParam.Descriptor.RegisterSpace = bindDesc.Space;
                    cbufferParam.Descriptor.ShaderRegister = bindDesc.BindPoint;
                    rootParameters.Add(cbufferParam);

                    var cbuffer = reflector.Get()->GetConstantBufferByName(bindDesc.Name);
                    ShaderBufferDescription cbufferDesc;
                    cbuffer->GetDesc(&cbufferDesc);

                    // Walks every variable of the buffer; the descriptions are read but not stored yet.
                    for (var j = 0u; j < cbufferDesc.Variables; j++)
                    {
                        var variable = cbuffer->GetVariableByIndex(j);
                        ShaderVariableDescription varDesc;
                        variable->GetDesc(&varDesc);
                    }

                    break;

                case ShaderInputType.Sampler:
                    // Every sampler is bound as a static, trilinear, wrapping sampler.
                    var samplerDesc = new StaticSamplerDescription
                    {
                        Filter = Filter.MinMagMipLinear,
                        AddressU = TextureAddressMode.Wrap,
                        AddressV = TextureAddressMode.Wrap,
                        AddressW = TextureAddressMode.Wrap,
                        ShaderVisibility = ShaderVisibility.All,
                        ShaderRegister = bindDesc.BindPoint,
                        RegisterSpace = bindDesc.Space,
                    };
                    staticSamplers.Add(samplerDesc);
                    break;

                // The remaining binding kinds are recognised but not handled yet.
                case ShaderInputType.TextureBuffer:
                case ShaderInputType.Texture:
                case ShaderInputType.UavRwTyped:
                case ShaderInputType.Structured:
                case ShaderInputType.UavRwStructured:
                case ShaderInputType.ByteAddress:
                case ShaderInputType.UavRwByteAddress:
                case ShaderInputType.UavAppendStructured:
                case ShaderInputType.UavConsumeStructured:
                case ShaderInputType.UavRwStructuredWithCounter:
                case ShaderInputType.RtAccelerationStructure:
                case ShaderInputType.UavFeedbackTexture:
                default:
                    break;
            }
        }
    }

    /// <summary>
    /// Serializes a default-initialised root signature description and creates the native
    /// root signature on the graphics device, storing it in this instance.
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown if serialization or creation fails.</exception>
    private void CreateRootSignature()
    {
        var rootSignatureDesc = new RootSignatureDescription();

        using ComPtr<ID3DBlob> signature = default;
        using ComPtr<ID3DBlob> error = default;

        // NOTE(review): D3D12SerializeRootSignature takes the unversioned description;
        // confirm that V1_2 is accepted here, or switch to the versioned serializer.
        var hr = D3D12SerializeRootSignature(&rootSignatureDesc, RootSignatureVersion.V1_2, signature.GetAddressOf(), error.GetAddressOf());
        if (hr.Failure)
        {
            // The error blob is optional, so read it through the null-tolerant helper.
            var errorMessage = ReadBlobText(error.Get());
            throw new InvalidOperationException($"Failed to serialize root signature: {errorMessage}");
        }

        // NOTE(review): this takes the address of a ComPtr field that lives inside a managed
        // object; verify that the binding's GetVoidAddressOf is safe for non-stack storage.
        var createResult = GraphicsPipeline.GetGraphicsDevice().NativeDevice.Ptr->CreateRootSignature(
            0,
            signature.Get()->GetBufferPointer(),
            signature.Get()->GetBufferSize(),
            __uuidof<ID3D12RootSignature>(),
            _rootSignature.GetVoidAddressOf());

        if (createResult.Failure)
        {
            throw new InvalidOperationException($"Failed to create root signature (HRESULT: {createResult}).");
        }
    }
}