Improve vector and matrix performance and add swizzle support to .net build-int VectorX type.
This commit is contained in:
@@ -1,13 +1,14 @@
|
||||
using Microsoft.CodeAnalysis;
|
||||
using Microsoft.CodeAnalysis;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Reflection.Metadata;
|
||||
using System.Text;
|
||||
|
||||
namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
{
|
||||
internal class MatrixGenerator : GeneratorBase
|
||||
{
|
||||
private readonly List<(string signature, string assignment)> _constructorSignatures = new();
|
||||
private readonly List<(string signature, List<string> assignment)> _constructorSignatures = new();
|
||||
|
||||
private string GetConversionFromTemplate(string template, int componentIndex)
|
||||
{
|
||||
@@ -15,6 +16,16 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
.Replace("{c}", s_matrixComponents[componentIndex]);
|
||||
}
|
||||
|
||||
protected override string? Validation()
|
||||
{
|
||||
if (typeInfo.ElementTypeSymbol == null)
|
||||
{
|
||||
return "You must specify 'elementType' in NumericTypeAttribute for matrix types.";
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
protected override void GenerateBody()
|
||||
{
|
||||
GenerateField();
|
||||
@@ -142,7 +153,7 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
|
||||
_constructorSignatures.Add((
|
||||
signature: $"{typeInfo.ElementTypeFullName} value",
|
||||
assignment: string.Join(", ", Enumerable.Range(0, typeInfo.Column).Select(_ => $"new {typeInfo.ComponentTypeFullName}(value)"))));
|
||||
assignment: Enumerable.Range(0, typeInfo.Column).Select(_ => $"new {typeInfo.ComponentTypeFullName}(value)").ToList()));
|
||||
|
||||
var tempSB = new StringBuilder();
|
||||
for (var r = 0; r < typeInfo.Row; r++)
|
||||
@@ -159,16 +170,16 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
|
||||
_constructorSignatures.Add((
|
||||
signature: tempSB.ToString(),
|
||||
assignment: string.Join(", ", Enumerable.Range(0, typeInfo.Column).Select(c => $"new {typeInfo.ComponentTypeFullName}({string.Join(", ", Enumerable.Range(0, typeInfo.Row).Select(r => $"m{r}{c}"))})"))));
|
||||
assignment: Enumerable.Range(0, typeInfo.Column).Select(c => $"new {typeInfo.ComponentTypeFullName}({string.Join(", ", Enumerable.Range(0, typeInfo.Row).Select(r => $"m{r}{c}"))})").ToList()));
|
||||
}
|
||||
|
||||
_constructorSignatures.Add((
|
||||
signature: $"{typeInfo.ComponentTypeFullName} value",
|
||||
assignment: string.Join(", ", Enumerable.Range(0, typeInfo.Column).Select(i => "value"))));
|
||||
assignment: Enumerable.Range(0, typeInfo.Column).Select(i => "value").ToList()));
|
||||
|
||||
_constructorSignatures.Add((
|
||||
signature: string.Join(", ", Enumerable.Range(0, typeInfo.Column).Select(i => $"{typeInfo.ComponentTypeFullName} c{i}")),
|
||||
assignment: string.Join(", ", Enumerable.Range(0, typeInfo.Column).Select(i => $"c{i}"))));
|
||||
assignment: Enumerable.Range(0, typeInfo.Column).Select(i => $"c{i}").ToList()));
|
||||
|
||||
if (typeInfo.ConvertableTypes != null)
|
||||
{
|
||||
@@ -177,53 +188,34 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
var targetTemplate = kv.Key;
|
||||
var targetTypes = kv.Value;
|
||||
|
||||
var tempSB = new StringBuilder();
|
||||
foreach (var type in targetTypes)
|
||||
{
|
||||
var assignments = new List<string>();
|
||||
for (var i = 0; i < typeInfo.Column; i++)
|
||||
{
|
||||
tempSB.Append(GetConversionFromTemplate(targetTemplate, i));
|
||||
if (i < typeInfo.Column - 1)
|
||||
{
|
||||
tempSB.Append(", ");
|
||||
}
|
||||
assignments.Add(GetConversionFromTemplate(targetTemplate, i));
|
||||
}
|
||||
|
||||
_constructorSignatures.Add((
|
||||
signature: $"{type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)} v",
|
||||
assignment: tempSB.ToString()));
|
||||
|
||||
tempSB.Clear();
|
||||
assignment: assignments));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
foreach (var (signature, assignment) in _constructorSignatures)
|
||||
{
|
||||
sourceBuilder.AppendLine($@"
|
||||
public {typeInfo.TypeName}({signature})
|
||||
{{
|
||||
this = Create({assignment});
|
||||
}}");
|
||||
}
|
||||
|
||||
sourceBuilder.Append($@"
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeInfo.TypeName} Create({string.Join(", ", Enumerable.Range(0, typeInfo.Column).Select((_, i) => $"{typeInfo.ComponentTypeFullName} {s_matrixComponents[i]}"))})
|
||||
{{
|
||||
global::System.Runtime.CompilerServices.Unsafe.SkipInit(out {typeInfo.TypeFullName} result);
|
||||
");
|
||||
|
||||
for (var i = 0; i < typeInfo.Column; i++)
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
result.{s_matrixComponents[i]} = {s_matrixComponents[i]};");
|
||||
}
|
||||
|
||||
sourceBuilder.AppendLine($@"
|
||||
|
||||
return result;
|
||||
public {typeInfo.TypeName}({signature})
|
||||
{{");
|
||||
for (var i = 0; i < typeInfo.Column; i++)
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
this.{s_matrixComponents[i]} = {assignment[i]};");
|
||||
}
|
||||
sourceBuilder.AppendLine($@"
|
||||
}}");
|
||||
}
|
||||
}
|
||||
|
||||
private void GenerateUnsafeMethod()
|
||||
@@ -467,7 +459,7 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeInfo.TypeFullName} {typeInfo.TypeName}({signature})
|
||||
{{
|
||||
return {typeInfo.TypeFullName}.Create({assignment});
|
||||
return new {typeInfo.TypeFullName}({string.Join(", ", assignment)});
|
||||
}}");
|
||||
}
|
||||
|
||||
@@ -766,11 +758,47 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
var rhsVectorType = $"{typePrefix}{lhsCols}";
|
||||
var resultVectorType = $"{typePrefix}{lhsRows}";
|
||||
|
||||
sourceBuilder.AppendLine($@"
|
||||
var columnSizeBytes = lhsRows * typeInfo.ComponentSize;
|
||||
var vectorBits = columnSizeBytes > 16 ? 256 : 128;
|
||||
bool isFloatingPoint = typeInfo.ElementTypeSymbol!.SpecialType == SpecialType.System_Single||
|
||||
typeInfo.ElementTypeSymbol!.SpecialType == SpecialType.System_Double;
|
||||
|
||||
sourceBuilder.Append($@"
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {resultVectorType} mul({lhsType} m, {rhsVectorType} v)
|
||||
{{
|
||||
return {string.Join(" + ", Enumerable.Range(0, lhsCols).Select(c => $"m.{s_matrixComponents[c]} * v.{s_vectorComponents[c]}"))};
|
||||
{{");
|
||||
|
||||
for (int i = 0; i < lhsCols; i++)
|
||||
{
|
||||
var component = s_vectorComponents[i];
|
||||
sourceBuilder.Append($@"
|
||||
var v{component} = global::System.Runtime.Intrinsics.Vector{vectorBits}.Create(v.{component});");
|
||||
}
|
||||
|
||||
sourceBuilder.Append($@"
|
||||
|
||||
var sum = global::System.Runtime.Intrinsics.Vector{vectorBits}.Multiply(m.c0.AsVector{vectorBits}(), vx);");
|
||||
|
||||
for (int i = 1; i < lhsCols; i++)
|
||||
{
|
||||
var component = s_vectorComponents[i];
|
||||
var col = s_matrixComponents[i];
|
||||
|
||||
if (isFloatingPoint)
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
sum = global::System.Runtime.Intrinsics.Vector{vectorBits}.FusedMultiplyAdd(m.{col}.AsVector{vectorBits}(), v{component}, sum);");
|
||||
}
|
||||
else
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
sum = global::System.Runtime.Intrinsics.Vector{vectorBits}.Add(sum,
|
||||
global::System.Runtime.Intrinsics.Vector{vectorBits}.Multiply(m.{col}.AsVector{vectorBits}(), v{component}));");
|
||||
}
|
||||
}
|
||||
|
||||
sourceBuilder.AppendLine($@"
|
||||
return sum.As{typeInfo.ComponentTypeName}();
|
||||
}}");
|
||||
|
||||
// Vector-Matrix Multiplication: R-element vector * RxC = C-element vector
|
||||
|
||||
Reference in New Issue
Block a user