Improve performance and safety
This commit is contained in:
@@ -81,7 +81,6 @@ namespace {typeInfo.TypeSymbol.ContainingNamespace.ToDisplayString()}
|
||||
protected virtual void GenerateTypeStart()
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
//[global::System.Runtime.CompilerServices.SkipLocalsInit]
|
||||
public partial struct {typeInfo.TypeSymbol.Name} : global::System.IEquatable<{typeInfo.TypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}>
|
||||
{{");
|
||||
}
|
||||
|
||||
@@ -61,7 +61,8 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
}
|
||||
|
||||
sourceBuilder.AppendLine();
|
||||
sourceBuilder.AppendLine($@"
|
||||
|
||||
sourceBuilder.AppendLine(@$"
|
||||
public unsafe ref {typeInfo.ComponentTypeFullName} this[int index]
|
||||
{{
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
@@ -72,7 +73,7 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
}}
|
||||
}}");
|
||||
|
||||
sourceBuilder.AppendLine($@"
|
||||
sourceBuilder.AppendLine(@$"
|
||||
[global::System.Diagnostics.Conditional(""ENABLE_COLLECTION_CHECKS"")]
|
||||
private void RangeCheck(int index)
|
||||
{{
|
||||
|
||||
@@ -9,6 +9,7 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
private int _vectorBitsSize;
|
||||
private int _missingComponentsCount;
|
||||
private string _componentTypePrefix = null!;
|
||||
private bool CanUseVectorStorage => _missingComponentsCount == 0;
|
||||
|
||||
private readonly List<(string signature, List<string> assignment)> _constructorSignatures = new();
|
||||
|
||||
@@ -42,6 +43,20 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
};
|
||||
}
|
||||
|
||||
protected override void GenerateTypeStart()
|
||||
{
|
||||
if (CanUseVectorStorage)
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
[global::System.Runtime.InteropServices.StructLayout(global::System.Runtime.InteropServices.LayoutKind.Explicit)]
|
||||
public partial struct {typeInfo.TypeSymbol.Name} : global::System.IEquatable<{typeInfo.TypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}>
|
||||
{{");
|
||||
return;
|
||||
}
|
||||
|
||||
base.GenerateTypeStart();
|
||||
}
|
||||
|
||||
protected override void GenerateBody()
|
||||
{
|
||||
GenerateField();
|
||||
@@ -81,14 +96,32 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
|
||||
private void GenerateField()
|
||||
{
|
||||
for (var i = 0; i < typeInfo.Row; i++)
|
||||
var componentType = typeInfo.ComponentTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
|
||||
|
||||
if (CanUseVectorStorage)
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
public {typeInfo.ComponentTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)} {s_vectorComponents[i]};");
|
||||
sourceBuilder.AppendLine($@"
|
||||
[global::System.Runtime.InteropServices.FieldOffset(0)]
|
||||
internal global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}<{typeInfo.ComponentTypeFullName}> __v;");
|
||||
|
||||
for (var i = 0; i < typeInfo.Row; i++)
|
||||
{
|
||||
sourceBuilder.AppendLine($@"
|
||||
[global::System.Runtime.InteropServices.FieldOffset({i * typeInfo.ComponentSize})]
|
||||
public {componentType} {s_vectorComponents[i]};");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (var i = 0; i < typeInfo.Row; i++)
|
||||
{
|
||||
sourceBuilder.AppendLine($@"
|
||||
public {componentType} {s_vectorComponents[i]};");
|
||||
}
|
||||
}
|
||||
|
||||
sourceBuilder.AppendLine();
|
||||
sourceBuilder.AppendLine($@"
|
||||
sourceBuilder.AppendLine(@$"
|
||||
public unsafe ref {typeInfo.ComponentTypeFullName} this[int index]
|
||||
{{
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
@@ -99,7 +132,7 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
}}
|
||||
}}");
|
||||
|
||||
sourceBuilder.AppendLine($@"
|
||||
sourceBuilder.AppendLine(@$"
|
||||
[global::System.Diagnostics.Conditional(""ENABLE_COLLECTION_CHECKS"")]
|
||||
private void RangeCheck(int index)
|
||||
{{
|
||||
@@ -387,16 +420,59 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
var componentType = typeInfo.ComponentTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
|
||||
var asResult = $"As{typeSimpleName}()";
|
||||
|
||||
var canVectorizeBinaryArithmetic = CanUseVectorStorage;
|
||||
var canVectorizeDivide = CanUseVectorStorage && (typeInfo.ComponentTypeSymbol.SpecialType == SpecialType.System_Single || typeInfo.ComponentTypeSymbol.SpecialType == SpecialType.System_Double);
|
||||
|
||||
StartRegion("Arithmetic Operators");
|
||||
|
||||
// Add
|
||||
sourceBuilder.Append($@"
|
||||
if (canVectorizeBinaryArithmetic)
|
||||
{
|
||||
sourceBuilder.AppendLine($@"
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeName} operator +({typeName} lhs, {typeName} rhs)
|
||||
{{
|
||||
return (global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Add(lhs.__v, rhs.__v)).{asResult};
|
||||
}}");
|
||||
}
|
||||
else
|
||||
{
|
||||
sourceBuilder.AppendLine($@"
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeName} operator +({typeName} lhs, {typeName} rhs)
|
||||
{{
|
||||
return new {typeName}({string.Join(", ", s_vectorComponents.Take(typeInfo.Row).Select(c => $"lhs.{c} + rhs.{c}"))});
|
||||
}}");
|
||||
}
|
||||
|
||||
if (canVectorizeBinaryArithmetic)
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeName} operator +({typeName} lhs, {componentType} rhs)
|
||||
{{
|
||||
return (global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Add(lhs.__v, global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Create(rhs))).{asResult};
|
||||
}}
|
||||
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeName} operator +({componentType} lhs, {typeName} rhs)
|
||||
{{
|
||||
return (global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Add(global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Create(lhs), rhs.__v)).{asResult};
|
||||
}}
|
||||
|
||||
#if NET10_0_OR_GREATER
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public void operator +=({typeName} other)
|
||||
{{
|
||||
this.__v = global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Add(this.__v, other.__v);
|
||||
}}
|
||||
#endif");
|
||||
}
|
||||
else
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeName} operator +({typeName} lhs, {componentType} rhs)
|
||||
{{
|
||||
@@ -413,23 +489,64 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public void operator +=({typeName} other)
|
||||
{{");
|
||||
for (var i = 0; i < typeInfo.Row; i++)
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
for (var i = 0; i < typeInfo.Row; i++)
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
this.{s_vectorComponents[i]} += other.{s_vectorComponents[i]};");
|
||||
}
|
||||
sourceBuilder.AppendLine($@"
|
||||
}
|
||||
sourceBuilder.AppendLine($@"
|
||||
}}
|
||||
#endif");
|
||||
}
|
||||
|
||||
// Subtract
|
||||
sourceBuilder.AppendLine($@"
|
||||
if (canVectorizeBinaryArithmetic)
|
||||
{
|
||||
sourceBuilder.AppendLine($@"
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeName} operator -({typeName} lhs, {typeName} rhs)
|
||||
{{
|
||||
return (global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Subtract(lhs.__v, rhs.__v)).{asResult};
|
||||
}}");
|
||||
}
|
||||
else
|
||||
{
|
||||
sourceBuilder.AppendLine($@"
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeName} operator -({typeName} lhs, {typeName} rhs)
|
||||
{{
|
||||
return new {typeName}({string.Join(", ", s_vectorComponents.Take(typeInfo.Row).Select(c => $"lhs.{c} - rhs.{c}"))});
|
||||
}}");
|
||||
}
|
||||
|
||||
if (canVectorizeBinaryArithmetic)
|
||||
{
|
||||
sourceBuilder.AppendLine($@"
|
||||
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeName} operator -({typeName} lhs, {componentType} rhs)
|
||||
{{
|
||||
return (global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Subtract(lhs.__v, global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Create(rhs))).{asResult};
|
||||
}}
|
||||
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeName} operator -({componentType} lhs, {typeName} rhs)
|
||||
{{
|
||||
return (global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Subtract(global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Create(lhs), rhs.__v)).{asResult};
|
||||
}}
|
||||
|
||||
#if NET10_0_OR_GREATER
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public void operator -=({typeName} other)
|
||||
{{
|
||||
this.__v = global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Subtract(this.__v, other.__v);
|
||||
}}
|
||||
#endif");
|
||||
}
|
||||
else
|
||||
{
|
||||
sourceBuilder.AppendLine($@"
|
||||
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeName} operator -({typeName} lhs, {componentType} rhs)
|
||||
{{
|
||||
@@ -447,23 +564,65 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
public void operator -=({typeName} other)
|
||||
{{");
|
||||
|
||||
for (var i = 0; i < typeInfo.Row; i++)
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
for (var i = 0; i < typeInfo.Row; i++)
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
this.{s_vectorComponents[i]} -= other.{s_vectorComponents[i]};");
|
||||
}
|
||||
sourceBuilder.AppendLine($@"
|
||||
}
|
||||
sourceBuilder.AppendLine($@"
|
||||
}}
|
||||
#endif");
|
||||
}
|
||||
|
||||
// Multiply
|
||||
sourceBuilder.AppendLine($@"
|
||||
if (canVectorizeBinaryArithmetic)
|
||||
{
|
||||
sourceBuilder.AppendLine($@"
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeName} operator *({typeName} lhs, {typeName} rhs)
|
||||
{{
|
||||
return (global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Multiply(lhs.__v, rhs.__v)).{asResult};
|
||||
}}");
|
||||
}
|
||||
else
|
||||
{
|
||||
sourceBuilder.AppendLine($@"
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeName} operator *({typeName} lhs, {typeName} rhs)
|
||||
{{
|
||||
return new {typeName}({string.Join(", ", s_vectorComponents.Take(typeInfo.Row).Select(c => $"lhs.{c} * rhs.{c}"))});
|
||||
}}");
|
||||
}
|
||||
|
||||
if (canVectorizeBinaryArithmetic)
|
||||
{
|
||||
sourceBuilder.AppendLine($@"
|
||||
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeName} operator *({typeName} lhs, {componentType} rhs)
|
||||
{{
|
||||
return (global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Multiply(lhs.__v, global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Create(rhs))).{asResult};
|
||||
}}
|
||||
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeName} operator *({componentType} lhs, {typeName} rhs)
|
||||
{{
|
||||
return (global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Multiply(global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Create(lhs), rhs.__v)).{asResult};
|
||||
}}
|
||||
|
||||
#if NET10_0_OR_GREATER
|
||||
// Use scaler here to let JIT handle the simd optimization since we can not do a in-place vectorlization manually.
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public void operator *=({typeName} other)
|
||||
{{
|
||||
this.__v = global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Multiply(this.__v, other.__v);
|
||||
}}
|
||||
#endif");
|
||||
}
|
||||
else
|
||||
{
|
||||
sourceBuilder.AppendLine($@"
|
||||
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeName} operator *({typeName} lhs, {componentType} rhs)
|
||||
{{
|
||||
@@ -482,23 +641,65 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
public void operator *=({typeName} other)
|
||||
{{");
|
||||
|
||||
for (var i = 0; i < typeInfo.Row; i++)
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
for (var i = 0; i < typeInfo.Row; i++)
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
this.{s_vectorComponents[i]} *= other.{s_vectorComponents[i]};");
|
||||
}
|
||||
sourceBuilder.AppendLine($@"
|
||||
}
|
||||
sourceBuilder.AppendLine($@"
|
||||
}}
|
||||
#endif");
|
||||
}
|
||||
|
||||
// Divide
|
||||
sourceBuilder.AppendLine($@"
|
||||
if (canVectorizeDivide)
|
||||
{
|
||||
sourceBuilder.AppendLine($@"
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeName} operator /({typeName} lhs, {typeName} rhs)
|
||||
{{
|
||||
return (global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Divide(lhs.__v, rhs.__v)).{asResult};
|
||||
}}");
|
||||
}
|
||||
else
|
||||
{
|
||||
sourceBuilder.AppendLine($@"
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeName} operator /({typeName} lhs, {typeName} rhs)
|
||||
{{
|
||||
return new {typeName}({string.Join(", ", s_vectorComponents.Take(typeInfo.Row).Select(c => $"lhs.{c} / rhs.{c}"))});
|
||||
}}");
|
||||
}
|
||||
|
||||
if (canVectorizeDivide)
|
||||
{
|
||||
sourceBuilder.AppendLine($@"
|
||||
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeName} operator /({typeName} lhs, {componentType} rhs)
|
||||
{{
|
||||
return (global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Divide(lhs.__v, global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Create(rhs))).{asResult};
|
||||
}}
|
||||
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeName} operator /({componentType} lhs, {typeName} rhs)
|
||||
{{
|
||||
return (global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Divide(global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Create(lhs), rhs.__v)).{asResult};
|
||||
}}
|
||||
|
||||
#if NET10_0_OR_GREATER
|
||||
// Use scaler here to let JIT handle the simd optimization since we can not do a in-place vectorlization manually.
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public void operator /=({typeName} other)
|
||||
{{
|
||||
this.__v = global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}.Divide(this.__v, other.__v);
|
||||
}}
|
||||
#endif");
|
||||
}
|
||||
else
|
||||
{
|
||||
sourceBuilder.AppendLine($@"
|
||||
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public static {typeName} operator /({typeName} lhs, {componentType} rhs)
|
||||
{{
|
||||
@@ -517,14 +718,15 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
public void operator /=({typeName} other)
|
||||
{{");
|
||||
|
||||
for (var i = 0; i < typeInfo.Row; i++)
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
for (var i = 0; i < typeInfo.Row; i++)
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
this.{s_vectorComponents[i]} /= other.{s_vectorComponents[i]};");
|
||||
}
|
||||
sourceBuilder.AppendLine($@"
|
||||
}
|
||||
sourceBuilder.AppendLine($@"
|
||||
}}
|
||||
#endif");
|
||||
}
|
||||
|
||||
// Modulus
|
||||
sourceBuilder.AppendLine($@"
|
||||
@@ -735,13 +937,14 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
public static partial class VectorInterop
|
||||
{{
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public unsafe static global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}<{componentType}> AsVector{_vectorBitsSize}(this {typeName} value)
|
||||
public unsafe static global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}<{componentType}> AsVector{_vectorBitsSize}(this in {typeName} value)
|
||||
{{");
|
||||
|
||||
if (typeInfo.Row == 4)
|
||||
if (CanUseVectorStorage)
|
||||
{
|
||||
sourceBuilder.Append($@"
|
||||
return global::System.Runtime.CompilerServices.Unsafe.BitCast<{typeName}, global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}<{componentType}>>(value);");
|
||||
ref var v = ref global::System.Runtime.CompilerServices.Unsafe.AsRef(in value);
|
||||
return global::System.Runtime.CompilerServices.Unsafe.As<{typeName}, global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}<{componentType}>>(ref v);");
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -754,11 +957,23 @@ namespace Misaki.HighPerformance.Mathematics.CodeGen.Generators
|
||||
{INLINE_METHOD_ATTRIBUTE}
|
||||
public unsafe static {typeName} As{typeSimpleName}(this global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}<{componentType}> value)
|
||||
{{");
|
||||
sourceBuilder.AppendLine($@"
|
||||
if (CanUseVectorStorage)
|
||||
{
|
||||
sourceBuilder.AppendLine($@"
|
||||
var result = default({typeName});
|
||||
result.__v = value;
|
||||
return result;
|
||||
}}
|
||||
}}");
|
||||
}
|
||||
else
|
||||
{
|
||||
sourceBuilder.AppendLine($@"
|
||||
ref var address = ref global::System.Runtime.CompilerServices.Unsafe.As<global::System.Runtime.Intrinsics.Vector{_vectorBitsSize}<{componentType}>, byte>(ref value);
|
||||
return global::System.Runtime.CompilerServices.Unsafe.ReadUnaligned<{typeName}>(ref address);
|
||||
}}
|
||||
}}");
|
||||
}
|
||||
}
|
||||
|
||||
private void GenerateMathMethod()
|
||||
|
||||
Reference in New Issue
Block a user