Improve vector and matrix performance and add swizzle support to .net build-int VectorX type.

This commit is contained in:
2025-12-17 16:55:28 +09:00
parent ef2a3a37bd
commit a1ad0bd2da
15 changed files with 2960 additions and 269 deletions

View File

@@ -19,6 +19,12 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
this.typeInfo = typeInfo;
sourceBuilder.Clear();
var message = Validation();
if (message != null)
{
return message;
}
Initialize();
GenerateHeader();
@@ -45,6 +51,11 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
#endregion");
}
protected virtual string? Validation()
{
return null;
}
protected virtual void Initialize()
{
}

View File

@@ -1,13 +1,14 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis;
using System.Collections.Generic;
using System.Linq;
using System.Reflection.Metadata;
using System.Text;
namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
{
internal class MatrixGenerator : GeneratorBase
{
private readonly List<(string signature, string assignment)> _constructorSignatures = new();
private readonly List<(string signature, List<string> assignment)> _constructorSignatures = new();
private string GetConversionFromTemplate(string template, int componentIndex)
{
@@ -15,6 +16,16 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
.Replace("{c}", s_matrixComponents[componentIndex]);
}
protected override string? Validation()
{
if (typeInfo.ElementTypeSymbol == null)
{
return "You must specify 'elementType' in NumericTypeAttribute for matrix types.";
}
return null;
}
protected override void GenerateBody()
{
GenerateField();
@@ -142,7 +153,7 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
_constructorSignatures.Add((
signature: $"{typeInfo.ElementTypeFullName} value",
assignment: string.Join(", ", Enumerable.Range(0, typeInfo.Column).Select(_ => $"new {typeInfo.ComponentTypeFullName}(value)"))));
assignment: Enumerable.Range(0, typeInfo.Column).Select(_ => $"new {typeInfo.ComponentTypeFullName}(value)").ToList()));
var tempSB = new StringBuilder();
for (var r = 0; r < typeInfo.Row; r++)
@@ -159,16 +170,16 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
_constructorSignatures.Add((
signature: tempSB.ToString(),
assignment: string.Join(", ", Enumerable.Range(0, typeInfo.Column).Select(c => $"new {typeInfo.ComponentTypeFullName}({string.Join(", ", Enumerable.Range(0, typeInfo.Row).Select(r => $"m{r}{c}"))})"))));
assignment: Enumerable.Range(0, typeInfo.Column).Select(c => $"new {typeInfo.ComponentTypeFullName}({string.Join(", ", Enumerable.Range(0, typeInfo.Row).Select(r => $"m{r}{c}"))})").ToList()));
}
_constructorSignatures.Add((
signature: $"{typeInfo.ComponentTypeFullName} value",
assignment: string.Join(", ", Enumerable.Range(0, typeInfo.Column).Select(i => "value"))));
assignment: Enumerable.Range(0, typeInfo.Column).Select(i => "value").ToList()));
_constructorSignatures.Add((
signature: string.Join(", ", Enumerable.Range(0, typeInfo.Column).Select(i => $"{typeInfo.ComponentTypeFullName} c{i}")),
assignment: string.Join(", ", Enumerable.Range(0, typeInfo.Column).Select(i => $"c{i}"))));
assignment: Enumerable.Range(0, typeInfo.Column).Select(i => $"c{i}").ToList()));
if (typeInfo.ConvertableTypes != null)
{
@@ -177,53 +188,34 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
var targetTemplate = kv.Key;
var targetTypes = kv.Value;
var tempSB = new StringBuilder();
foreach (var type in targetTypes)
{
var assignments = new List<string>();
for (var i = 0; i < typeInfo.Column; i++)
{
tempSB.Append(GetConversionFromTemplate(targetTemplate, i));
if (i < typeInfo.Column - 1)
{
tempSB.Append(", ");
}
assignments.Add(GetConversionFromTemplate(targetTemplate, i));
}
_constructorSignatures.Add((
signature: $"{type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)} v",
assignment: tempSB.ToString()));
tempSB.Clear();
assignment: assignments));
}
}
}
foreach (var (signature, assignment) in _constructorSignatures)
{
sourceBuilder.AppendLine($@"
public {typeInfo.TypeName}({signature})
{{
this = Create({assignment});
}}");
}
sourceBuilder.Append($@"
{INLINE_METHOD_ATTRIBUTE}
public static {typeInfo.TypeName} Create({string.Join(", ", Enumerable.Range(0, typeInfo.Column).Select((_, i) => $"{typeInfo.ComponentTypeFullName} {s_matrixComponents[i]}"))})
{{
global::System.Runtime.CompilerServices.Unsafe.SkipInit(out {typeInfo.TypeFullName} result);
");
for (var i = 0; i < typeInfo.Column; i++)
{
sourceBuilder.Append($@"
result.{s_matrixComponents[i]} = {s_matrixComponents[i]};");
}
sourceBuilder.AppendLine($@"
return result;
public {typeInfo.TypeName}({signature})
{{");
for (var i = 0; i < typeInfo.Column; i++)
{
sourceBuilder.Append($@"
this.{s_matrixComponents[i]} = {assignment[i]};");
}
sourceBuilder.AppendLine($@"
}}");
}
}
private void GenerateUnsafeMethod()
@@ -467,7 +459,7 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
{INLINE_METHOD_ATTRIBUTE}
public static {typeInfo.TypeFullName} {typeInfo.TypeName}({signature})
{{
return {typeInfo.TypeFullName}.Create({assignment});
return new {typeInfo.TypeFullName}({string.Join(", ", assignment)});
}}");
}
@@ -766,11 +758,47 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
var rhsVectorType = $"{typePrefix}{lhsCols}";
var resultVectorType = $"{typePrefix}{lhsRows}";
sourceBuilder.AppendLine($@"
var columnSizeBytes = lhsRows * typeInfo.ComponentSize;
var vectorBits = columnSizeBytes > 16 ? 256 : 128;
bool isFloatingPoint = typeInfo.ElementTypeSymbol!.SpecialType == SpecialType.System_Single||
typeInfo.ElementTypeSymbol!.SpecialType == SpecialType.System_Double;
sourceBuilder.Append($@"
{INLINE_METHOD_ATTRIBUTE}
public static {resultVectorType} mul({lhsType} m, {rhsVectorType} v)
{{
return {string.Join(" + ", Enumerable.Range(0, lhsCols).Select(c => $"m.{s_matrixComponents[c]} * v.{s_vectorComponents[c]}"))};
{{");
for (int i = 0; i < lhsCols; i++)
{
var component = s_vectorComponents[i];
sourceBuilder.Append($@"
var v{component} = global::System.Runtime.Intrinsics.Vector{vectorBits}.Create(v.{component});");
}
sourceBuilder.Append($@"
var sum = global::System.Runtime.Intrinsics.Vector{vectorBits}.Multiply(m.c0.AsVector{vectorBits}(), vx);");
for (int i = 1; i < lhsCols; i++)
{
var component = s_vectorComponents[i];
var col = s_matrixComponents[i];
if (isFloatingPoint)
{
sourceBuilder.Append($@"
sum = global::System.Runtime.Intrinsics.Vector{vectorBits}.FusedMultiplyAdd(m.{col}.AsVector{vectorBits}(), v{component}, sum);");
}
else
{
sourceBuilder.Append($@"
sum = global::System.Runtime.Intrinsics.Vector{vectorBits}.Add(sum,
global::System.Runtime.Intrinsics.Vector{vectorBits}.Multiply(m.{col}.AsVector{vectorBits}(), v{component}));");
}
}
sourceBuilder.AppendLine($@"
return sum.As{typeInfo.ComponentTypeName}();
}}");
// Vector-Matrix Multiplication: R-element vector * RxC = C-element vector

View File

@@ -7,7 +7,7 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
internal class VectorGenerator : GeneratorBase
{
private int _vectorBitsSize;
private int _missingComponents;
private int _missingComponentsCount;
private string _componentTypePrefix = null!;
private readonly List<(string signature, List<string> assignment)> _constructorSignatures = new();
@@ -21,7 +21,7 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
protected override void Initialize()
{
var componentSize = typeInfo.ComponentSize;
var typeSize = componentSize * typeInfo.Row * typeInfo.Column;
var typeSize = componentSize * typeInfo.Row;
var vectorBytesSize = typeSize switch
{
//<= 8 => 8,
@@ -31,7 +31,7 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
};
_vectorBitsSize = vectorBytesSize * 8;
_missingComponents = (vectorBytesSize - typeSize) / componentSize;
_missingComponentsCount = (vectorBytesSize - typeSize) / componentSize;
_componentTypePrefix = typeInfo.ComponentTypeSymbol.SpecialType switch
{
@@ -224,6 +224,7 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
var paramNames = new List<string>();
var paramList = new List<string>();
var assignments = new List<string>();
var fieldOffset = 0;
for (var i = 0; i < perm.Count; i++)
@@ -426,7 +427,6 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
{INLINE_METHOD_ATTRIBUTE}
public void operator +=({typeName} other)
{{");
for (var i = 0; i < typeInfo.Row; i++)
{
sourceBuilder.Append($@"
@@ -759,7 +759,7 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
else
{
sourceBuilder.Append($@"
return global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Create({string.Join(", ", Enumerable.Range(0, typeInfo.Row + _missingComponents).Select(i => i < typeInfo.Row ? $"value.{s_vectorComponents[i]}" : "default"))});");
return global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Create({string.Join(", ", Enumerable.Range(0, typeInfo.Row + _missingComponentsCount).Select(i => i < typeInfo.Row ? $"value.{s_vectorComponents[i]}" : "default"))});");
}
sourceBuilder.Append($@"
}}