Improve vector and matrix performance and add swizzle support to .net build-int VectorX type.
This commit is contained in:
@@ -19,6 +19,12 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
this.typeInfo = typeInfo;
|
||||
sourceBuilder.Clear();
|
||||
|
||||
var message = Validation();
|
||||
if (message != null)
|
||||
{
|
||||
return message;
|
||||
}
|
||||
|
||||
Initialize();
|
||||
|
||||
GenerateHeader();
|
||||
@@ -45,6 +51,11 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
#endregion");
|
||||
}
|
||||
|
||||
protected virtual string? Validation()
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
protected virtual void Initialize()
|
||||
{
|
||||
}
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
using Microsoft.CodeAnalysis;
|
||||
using Microsoft.CodeAnalysis;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Reflection.Metadata;
|
||||
using System.Text;
|
||||
|
||||
namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
{
|
||||
internal class MatrixGenerator : GeneratorBase
|
||||
{
|
||||
private readonly List<(string signature, string assignment)> _constructorSignatures = new();
|
||||
private readonly List<(string signature, List<string> assignment)> _constructorSignatures = new();
|
||||
|
||||
private string GetConversionFromTemplate(string template, int componentIndex)
|
||||
{
|
||||
@@ -15,6 +16,16 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
.Replace("{c}", s_matrixComponents[componentIndex]);
|
||||
}
|
||||
|
||||
protected override string? Validation()
|
||||
{
|
||||
if (typeInfo.ElementTypeSymbol == null)
|
||||
{
|
||||
return "You must specify 'elementType' in NumericTypeAttribute for matrix types.";
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
protected override void GenerateBody()
|
||||
{
|
||||
GenerateField();
|
||||
@@ -142,7 +153,7 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
|
||||
_constructorSignatures.Add((
|
||||
signature: $"{typeInfo.ElementTypeFullName} value",
|
||||
assignment: string.Join(", ", Enumerable.Range(0, typeInfo.Column).Select(_ => $"new {typeInfo.ComponentTypeFullName}(value)"))));
|
||||
assignment: Enumerable.Range(0, typeInfo.Column).Select(_ => $"new {typeInfo.ComponentTypeFullName}(value)").ToList()));
|
||||
|
||||
var tempSB = new StringBuilder();
|
||||
for (var r = 0; r < typeInfo.Row; r++)
|
||||
@@ -159,16 +170,16 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
|
||||
_constructorSignatures.Add((
|
||||
signature: tempSB.ToString(),
|
||||
assignment: string.Join(", ", Enumerable.Range(0, typeInfo.Column).Select(c => $"new {typeInfo.ComponentTypeFullName}({string.Join(", ", Enumerable.Range(0, typeInfo.Row).Select(r => $"m{r}{c}"))})"))));
|
||||
assignment: Enumerable.Range(0, typeInfo.Column).Select(c => $"new {typeInfo.ComponentTypeFullName}({string.Join(", ", Enumerable.Range(0, typeInfo.Row).Select(r => $"m{r}{c}"))})").ToList()));
|
||||
}
|
||||
|
||||
_constructorSignatures.Add((
|
||||
signature: $"{typeInfo.ComponentTypeFullName} value",
|
||||
assignment: string.Join(", ", Enumerable.Range(0, typeInfo.Column).Select(i => "value"))));
|
||||
assignment: Enumerable.Range(0, typeInfo.Column).Select(i => "value").ToList()));
|
||||
|
||||
_constructorSignatures.Add((
|
||||
signature: string.Join(", ", Enumerable.Range(0, typeInfo.Column).Select(i => $"{typeInfo.ComponentTypeFullName} c{i}")),
|
||||
assignment: string.Join(", ", Enumerable.Range(0, typeInfo.Column).Select(i => $"c{i}"))));
|
||||
assignment: Enumerable.Range(0, typeInfo.Column).Select(i => $"c{i}").ToList()));
|
||||
|
||||
if (typeInfo.ConvertableTypes != null)
|
||||
{
|
||||
@@ -177,53 +188,34 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
var targetTemplate = kv.Key;
|
||||
var targetTypes = kv.Value;
|
||||
|
||||
var tempSB = new StringBuilder();
|
||||
foreach (var type in targetTypes)
|
||||
{
|
||||
var assignments = new List<string>();
|
||||
for (var i = 0; i < typeInfo.Column; i++)
|
||||
{
|
||||
tempSB.Append(GetConversionFromTemplate(targetTemplate, i));
|
||||
if (i < typeInfo.Column - 1)
|
||||
{
|
||||
tempSB.Append(", ");
|
||||
}
|
||||
assignments.Add(GetConversionFromTemplate(targetTemplate, i));
|
||||
}
|
||||
|
||||
_constructorSignatures.Add((
|
||||
signature: $"{type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)} v",
|
||||
assignment: tempSB.ToString()));
|
||||
|
||||
tempSB.Clear();
|
||||
assignment: assignments));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
foreach (var (signature, assignment) in _constructorSignatures)
|
||||
{
|
||||
sourceBuilder.AppendLine($@"
|
||||
public {typeInfo.TypeName}({signature})
|
||||
{{
|
||||
this = Create({assignment});
|
||||
}}");
|
||||
}
|
||||
|
||||
sourceBuilder.Append($@"
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeInfo.TypeName} Create({string.Join(", ", Enumerable.Range(0, typeInfo.Column).Select((_, i) => $"{typeInfo.ComponentTypeFullName} {s_matrixComponents[i]}"))})
|
||||
{{
|
||||
global::System.Runtime.CompilerServices.Unsafe.SkipInit(out {typeInfo.TypeFullName} result);
|
||||
");
|
||||
|
||||
for (var i = 0; i < typeInfo.Column; i++)
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
result.{s_matrixComponents[i]} = {s_matrixComponents[i]};");
|
||||
}
|
||||
|
||||
sourceBuilder.AppendLine($@"
|
||||
|
||||
return result;
|
||||
public {typeInfo.TypeName}({signature})
|
||||
{{");
|
||||
for (var i = 0; i < typeInfo.Column; i++)
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
this.{s_matrixComponents[i]} = {assignment[i]};");
|
||||
}
|
||||
sourceBuilder.AppendLine($@"
|
||||
}}");
|
||||
}
|
||||
}
|
||||
|
||||
private void GenerateUnsafeMethod()
|
||||
@@ -467,7 +459,7 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeInfo.TypeFullName} {typeInfo.TypeName}({signature})
|
||||
{{
|
||||
return {typeInfo.TypeFullName}.Create({assignment});
|
||||
return new {typeInfo.TypeFullName}({string.Join(", ", assignment)});
|
||||
}}");
|
||||
}
|
||||
|
||||
@@ -766,11 +758,47 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
var rhsVectorType = $"{typePrefix}{lhsCols}";
|
||||
var resultVectorType = $"{typePrefix}{lhsRows}";
|
||||
|
||||
sourceBuilder.AppendLine($@"
|
||||
var columnSizeBytes = lhsRows * typeInfo.ComponentSize;
|
||||
var vectorBits = columnSizeBytes > 16 ? 256 : 128;
|
||||
bool isFloatingPoint = typeInfo.ElementTypeSymbol!.SpecialType == SpecialType.System_Single||
|
||||
typeInfo.ElementTypeSymbol!.SpecialType == SpecialType.System_Double;
|
||||
|
||||
sourceBuilder.Append($@"
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {resultVectorType} mul({lhsType} m, {rhsVectorType} v)
|
||||
{{
|
||||
return {string.Join(" + ", Enumerable.Range(0, lhsCols).Select(c => $"m.{s_matrixComponents[c]} * v.{s_vectorComponents[c]}"))};
|
||||
{{");
|
||||
|
||||
for (int i = 0; i < lhsCols; i++)
|
||||
{
|
||||
var component = s_vectorComponents[i];
|
||||
sourceBuilder.Append($@"
|
||||
var v{component} = global::System.Runtime.Intrinsics.Vector{vectorBits}.Create(v.{component});");
|
||||
}
|
||||
|
||||
sourceBuilder.Append($@"
|
||||
|
||||
var sum = global::System.Runtime.Intrinsics.Vector{vectorBits}.Multiply(m.c0.AsVector{vectorBits}(), vx);");
|
||||
|
||||
for (int i = 1; i < lhsCols; i++)
|
||||
{
|
||||
var component = s_vectorComponents[i];
|
||||
var col = s_matrixComponents[i];
|
||||
|
||||
if (isFloatingPoint)
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
sum = global::System.Runtime.Intrinsics.Vector{vectorBits}.FusedMultiplyAdd(m.{col}.AsVector{vectorBits}(), v{component}, sum);");
|
||||
}
|
||||
else
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
sum = global::System.Runtime.Intrinsics.Vector{vectorBits}.Add(sum,
|
||||
global::System.Runtime.Intrinsics.Vector{vectorBits}.Multiply(m.{col}.AsVector{vectorBits}(), v{component}));");
|
||||
}
|
||||
}
|
||||
|
||||
sourceBuilder.AppendLine($@"
|
||||
return sum.As{typeInfo.ComponentTypeName}();
|
||||
}}");
|
||||
|
||||
// Vector-Matrix Multiplication: R-element vector * RxC = C-element vector
|
||||
|
||||
@@ -7,7 +7,7 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
internal class VectorGenerator : GeneratorBase
|
||||
{
|
||||
private int _vectorBitsSize;
|
||||
private int _missingComponents;
|
||||
private int _missingComponentsCount;
|
||||
private string _componentTypePrefix = null!;
|
||||
|
||||
private readonly List<(string signature, List<string> assignment)> _constructorSignatures = new();
|
||||
@@ -21,7 +21,7 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
protected override void Initialize()
|
||||
{
|
||||
var componentSize = typeInfo.ComponentSize;
|
||||
var typeSize = componentSize * typeInfo.Row * typeInfo.Column;
|
||||
var typeSize = componentSize * typeInfo.Row;
|
||||
var vectorBytesSize = typeSize switch
|
||||
{
|
||||
//<= 8 => 8,
|
||||
@@ -31,7 +31,7 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
};
|
||||
|
||||
_vectorBitsSize = vectorBytesSize * 8;
|
||||
_missingComponents = (vectorBytesSize - typeSize) / componentSize;
|
||||
_missingComponentsCount = (vectorBytesSize - typeSize) / componentSize;
|
||||
|
||||
_componentTypePrefix = typeInfo.ComponentTypeSymbol.SpecialType switch
|
||||
{
|
||||
@@ -224,6 +224,7 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
var paramNames = new List<string>();
|
||||
var paramList = new List<string>();
|
||||
var assignments = new List<string>();
|
||||
|
||||
var fieldOffset = 0;
|
||||
|
||||
for (var i = 0; i < perm.Count; i++)
|
||||
@@ -426,7 +427,6 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public void operator +=({typeName} other)
|
||||
{{");
|
||||
|
||||
for (var i = 0; i < typeInfo.Row; i++)
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
@@ -759,7 +759,7 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
else
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
return global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Create({string.Join(", ", Enumerable.Range(0, typeInfo.Row + _missingComponents).Select(i => i < typeInfo.Row ? $"value.{s_vectorComponents[i]}" : "default"))});");
|
||||
return global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Create({string.Join(", ", Enumerable.Range(0, typeInfo.Row + _missingComponentsCount).Select(i => i < typeInfo.Row ? $"value.{s_vectorComponents[i]}" : "default"))});");
|
||||
}
|
||||
sourceBuilder.Append($@"
|
||||
}}
|
||||
|
||||
Reference in New Issue
Block a user