feat(shader): add compute shader support and refactor pipeline

Refactored shader system to support both graphics and compute shaders.
- Updated ANTLR grammars and parser logic for explicit shader model and compute shader entry points.
- Split shader models and descriptors for graphics and compute.
- Refactored pipeline key generation and D3D12 pipeline library for compute support.
- Updated push constant layouts and HLSL includes for both shader types.
- Improved error handling and test coverage with new example files.

BREAKING CHANGE: Shader model, descriptor, and pipeline APIs have changed. Existing shader and pipeline code must be updated to use the new types and conventions.
This commit is contained in:
2026-04-10 02:53:40 +09:00
parent 68fda03aa9
commit 4ed5572ce7
29 changed files with 742 additions and 290 deletions

View File

@@ -7,7 +7,7 @@ namespace Ghost.DSL.ShaderParser;
public class AntlrShaderCompiler
{
public static List<ShaderModel> ParseShaders(string source, out List<DSLShaderError> errors)
public static List<GraphicsShaderModel> ParseShaders(string source, out List<DSLShaderError> errors)
{
errors = new List<DSLShaderError>();
@@ -33,7 +33,7 @@ public class AntlrShaderCompiler
if (errors.Count > 0)
{
return new List<ShaderModel>();
return new List<GraphicsShaderModel>();
}
var visitor = new ShaderVisitor();
@@ -49,11 +49,165 @@ public class AntlrShaderCompiler
line = -1,
column = -1
});
return new List<ShaderModel>();
return new List<GraphicsShaderModel>();
}
}
public static DSLShaderSemantics? ConvertToSemantics(ShaderModel model, out List<DSLShaderError> errors)
public static List<ComputeShaderModel> ParseComputeShaders(string source, out List<DSLShaderError> errors)
{
errors = new List<DSLShaderError>();
try
{
var inputStream = new AntlrInputStream(source);
var lexer = new GhostShaderLexer(inputStream);
// Capture lexer errors
lexer.RemoveErrorListeners();
var lexerErrorListener = new ErrorListener(errors);
lexer.AddErrorListener(lexerErrorListener);
var tokenStream = new CommonTokenStream(lexer);
var parser = new GhostComputeShaderParser(tokenStream);
// Capture parser errors
parser.RemoveErrorListeners();
var parserErrorListener = new ErrorListener(errors);
parser.AddErrorListener(parserErrorListener);
var tree = parser.computeFile();
if (errors.Count > 0)
{
return new List<ComputeShaderModel>();
}
var visitor = new ComputeShaderVisitor();
visitor.Visit(tree);
return visitor.ComputeShaders;
}
catch (Exception ex)
{
errors.Add(new DSLShaderError
{
message = $"Unexpected error during parsing: {ex.Message}",
line = -1,
column = -1
});
return new List<ComputeShaderModel>();
}
}
public static DSLComputeShaderSemantics? ConvertToComputeSemantics(ComputeShaderModel model, out List<DSLShaderError> errors)
{
errors = new List<DSLShaderError>();
if (string.IsNullOrWhiteSpace(model.Name))
{
errors.Add(new DSLShaderError
{
message = "Compute shader name cannot be empty.",
line = 0,
column = 0
});
return null;
}
var semantics = new DSLComputeShaderSemantics
{
name = model.Name,
defines = model.Defines?.Defines,
includes = model.Includes?.Includes,
hlsl = model.Hlsl?.Code
};
if (string.IsNullOrEmpty(model.SM))
{
semantics.shaderModel = ShaderModel.SM_6_8; // Default to highest supported shader model
}
else
{
semantics.shaderModel = model.SM.ToLower() switch
{
"6_6" => ShaderModel.SM_6_6,
"6_7" => ShaderModel.SM_6_7,
"6_8" => ShaderModel.SM_6_8,
_ => ShaderModel.Invalid
};
if (semantics.shaderModel == ShaderModel.Invalid)
{
errors.Add(new DSLShaderError
{
message = $"Unknown shader model '{model.SM}'.",
line = 0,
column = 0
});
}
}
if (model.Keywords != null)
{
semantics.keywords = new List<KeywordsGroup>();
foreach (var group in model.Keywords.Groups)
{
var keywordGroup = new KeywordsGroup
{
space = group.Scope?.ToLower() == "global" ? KeywordSpace.Global : KeywordSpace.Local,
keywords = group.Keywords
};
semantics.keywords.Add(keywordGroup);
}
}
foreach (var entry in model.ShaderEntries)
{
var entryType = entry.EntryType.ToLower();
if (entryType == "cs")
{
semantics.entryPoints ??= new List<ShaderEntryPoint>();
semantics.entryPoints.Add(new ShaderEntryPoint
{
shader = entry.ShaderPath,
entry = entry.EntryPoint
});
}
else
{
errors.Add(new DSLShaderError
{
message = $"Unknown compute shader entry type '{entry.EntryType}'. Expected 'compute' or 'cs'.",
line = 0,
column = 0
});
}
}
if (semantics.entryPoints == null)
{
errors.Add(new DSLShaderError
{
message = $"Compute shader '{model.Name}' must contain a compute/cs entry declaration.",
line = 0,
column = 0
});
}
if (semantics.entryPoints != null && semantics.entryPoints.Count > 8)
{
errors.Add(new DSLShaderError
{
message = $"Compute shader '{model.Name}' cannot have more than 8 entry points.",
line = 0,
column = 0
});
}
return semantics;
}
public static DSLShaderSemantics? ConvertToSemantics(GraphicsShaderModel model, out List<DSLShaderError> errors)
{
errors = new List<DSLShaderError>();
@@ -74,6 +228,31 @@ public class AntlrShaderCompiler
pipeline = ConvertPipeline(model.Pipeline, errors)
};
if (string.IsNullOrEmpty(model.SM))
{
semantics.shaderModel = ShaderModel.SM_6_8; // Default to highest supported shader model
}
else
{
semantics.shaderModel = model.SM.ToLower() switch
{
"6_6" => ShaderModel.SM_6_6,
"6_7" => ShaderModel.SM_6_7,
"6_8" => ShaderModel.SM_6_8,
_ => ShaderModel.Invalid
};
if (semantics.shaderModel == ShaderModel.Invalid)
{
errors.Add(new DSLShaderError
{
message = $"Unknown shader model '{model.SM}'.",
line = 0,
column = 0
});
}
}
foreach (var pass in model.Passes)
{
var passSemantic = ConvertPass(pass, errors);
@@ -87,99 +266,6 @@ public class AntlrShaderCompiler
return semantics;
}
private static ShaderPropertyType ParsePropertyType(string type, List<DSLShaderError> errors)
{
return type.ToLower() switch
{
"float" => ShaderPropertyType.Float,
"float2" => ShaderPropertyType.Float2,
"float3" => ShaderPropertyType.Float3,
"float4" => ShaderPropertyType.Float4,
"float4x4" => ShaderPropertyType.Float4x4,
"int" => ShaderPropertyType.Int,
"int2" => ShaderPropertyType.Int2,
"int3" => ShaderPropertyType.Int3,
"int4" => ShaderPropertyType.Int4,
"uint" => ShaderPropertyType.UInt,
"uint2" => ShaderPropertyType.UInt2,
"uint3" => ShaderPropertyType.UInt3,
"uint4" => ShaderPropertyType.UInt4,
"bool" => ShaderPropertyType.Bool,
"bool2" => ShaderPropertyType.Bool2,
"bool3" => ShaderPropertyType.Bool3,
"bool4" => ShaderPropertyType.Bool4,
"tex2d" => ShaderPropertyType.Texture2D,
"tex3d" => ShaderPropertyType.Texture3D,
"texcube" => ShaderPropertyType.TextureCube,
"texcube_arr" => ShaderPropertyType.TextureCubeArray,
"tex2d_arr" => ShaderPropertyType.Texture2DArray,
"sampler" => ShaderPropertyType.Sampler,
_ => ShaderPropertyType.None
};
}
private static object? ParsePropertyValue(ShaderPropertyType type, List<string> values, List<DSLShaderError> errors)
{
// For textures, the value is an identifier (e.g., "white", "black")
if (type is ShaderPropertyType.Texture2D or ShaderPropertyType.Texture3D or ShaderPropertyType.TextureCube)
{
return values.Count > 0 ? values[0] : null;
}
// For samplers, no default value
if (type == ShaderPropertyType.Sampler)
{
return null;
}
// For numeric types, parse the values
try
{
return type switch
{
ShaderPropertyType.Float => values.Count > 0 ? float.Parse(values[0], System.Globalization.CultureInfo.InvariantCulture) : 0f,
ShaderPropertyType.Float2 => values.Count >= 2 ? new Misaki.HighPerformance.Mathematics.float2(
float.Parse(values[0], System.Globalization.CultureInfo.InvariantCulture),
float.Parse(values[1], System.Globalization.CultureInfo.InvariantCulture)) : default,
ShaderPropertyType.Float3 => values.Count >= 3 ? new Misaki.HighPerformance.Mathematics.float3(
float.Parse(values[0], System.Globalization.CultureInfo.InvariantCulture),
float.Parse(values[1], System.Globalization.CultureInfo.InvariantCulture),
float.Parse(values[2], System.Globalization.CultureInfo.InvariantCulture)) : default,
ShaderPropertyType.Float4 => values.Count >= 4 ? new Misaki.HighPerformance.Mathematics.float4(
float.Parse(values[0], System.Globalization.CultureInfo.InvariantCulture),
float.Parse(values[1], System.Globalization.CultureInfo.InvariantCulture),
float.Parse(values[2], System.Globalization.CultureInfo.InvariantCulture),
float.Parse(values[3], System.Globalization.CultureInfo.InvariantCulture)) : default,
ShaderPropertyType.Int => values.Count > 0 ? int.Parse(values[0], System.Globalization.CultureInfo.InvariantCulture) : 0,
ShaderPropertyType.Int2 => values.Count >= 2 ? new Misaki.HighPerformance.Mathematics.int2(
int.Parse(values[0], System.Globalization.CultureInfo.InvariantCulture),
int.Parse(values[1], System.Globalization.CultureInfo.InvariantCulture)) : default,
ShaderPropertyType.Int3 => values.Count >= 3 ? new Misaki.HighPerformance.Mathematics.int3(
int.Parse(values[0], System.Globalization.CultureInfo.InvariantCulture),
int.Parse(values[1], System.Globalization.CultureInfo.InvariantCulture),
int.Parse(values[2], System.Globalization.CultureInfo.InvariantCulture)) : default,
ShaderPropertyType.Int4 => values.Count >= 4 ? new Misaki.HighPerformance.Mathematics.int4(
int.Parse(values[0], System.Globalization.CultureInfo.InvariantCulture),
int.Parse(values[1], System.Globalization.CultureInfo.InvariantCulture),
int.Parse(values[2], System.Globalization.CultureInfo.InvariantCulture),
int.Parse(values[3], System.Globalization.CultureInfo.InvariantCulture)) : default,
ShaderPropertyType.UInt => values.Count > 0 ? uint.Parse(values[0], System.Globalization.CultureInfo.InvariantCulture) : 0u,
ShaderPropertyType.Bool => values.Count > 0 && (values[0] == "1" || values[0].ToLower() == "true"),
_ => null
};
}
catch (Exception ex)
{
errors.Add(new DSLShaderError
{
message = $"Failed to parse property value: {ex.Message}",
line = 0,
column = 0
});
return null;
}
}
private static PipelineSemantic? ConvertPipeline(PipelineBlockModel? pipeline, List<DSLShaderError> errors)
{
if (pipeline == null || pipeline.Statements.Count == 0)
@@ -275,13 +361,13 @@ public class AntlrShaderCompiler
switch (entryType)
{
case "mesh" or "ms":
case "ms":
semantic.meshShader = shaderEntry;
break;
case "pixel" or "ps":
case "ps":
semantic.pixelShader = shaderEntry;
break;
case "task" or "ts":
case "as":
semantic.taskShader = shaderEntry;
break;
default: