From fd2d60c8f15b8be4241616250c0483644d079838 Mon Sep 17 00:00:00 2001 From: Misaki Date: Wed, 6 May 2026 19:20:15 +0900 Subject: [PATCH] Refactor vector API codegen and WideLane conversions - Introduce IVectorAPIContext abstraction and supporting types for vectorized code generation - Add Avx2APIContext and UtilityTemplate for AVX2-specific code emission - Dynamically generate AVX2 sine methods in AVX2Rewriter - Refactor WideLane to use Unsafe.BitCast for all Vector conversions - Update all WideLane operators and math methods to use Unsafe.BitCast - Change MultiplyAdd parameter names for clarity - Remove static indices field in favor of Vector.Indices - Add implicit conversion from Vector to WideLane - Update tests and program files for compatibility --- .../AVX2Rewriter.cs | 34 ++-- .../UtilityTemplate.cs | 58 ++++++ .../VectorAPI/Avx2APIContext.cs | 117 ++++++++++++ .../VectorAPI/IVectorAPIContext.cs | 179 ++++++++++++++++++ Misaki.HighPerformance.HPC/ISPMDLane.cs | 8 +- Misaki.HighPerformance.HPC/WideLane.cs | 119 ++++++------ Misaki.HighPerformance.Test/Program.cs | 2 + .../UnitTest/Jobs/SPMDTest.cs | 6 - 8 files changed, 439 insertions(+), 84 deletions(-) create mode 100644 Misaki.HighPerformance.HPC.Generator/UtilityTemplate.cs create mode 100644 Misaki.HighPerformance.HPC.Generator/VectorAPI/Avx2APIContext.cs create mode 100644 Misaki.HighPerformance.HPC.Generator/VectorAPI/IVectorAPIContext.cs diff --git a/Misaki.HighPerformance.HPC.Generator/AVX2Rewriter.cs b/Misaki.HighPerformance.HPC.Generator/AVX2Rewriter.cs index fd88dbf..e4415c7 100644 --- a/Misaki.HighPerformance.HPC.Generator/AVX2Rewriter.cs +++ b/Misaki.HighPerformance.HPC.Generator/AVX2Rewriter.cs @@ -1,5 +1,6 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; +using Misaki.HighPerformance.HPC.Generator.VectorAPI; using System; namespace Misaki.HighPerformance.HPC.Generator @@ -11,28 +12,39 @@ namespace Misaki.HighPerformance.HPC.Generator { context.RegisterPostInitializationOutput(static ctx => { - var source = @" + var api = new Avx2APIContext(); + + var sinFloat_standard = UtilityTemplate.SinFloat_Standard(api); + var sinFloat_fast = UtilityTemplate.SinFloat_Fast(api); + + var source = @$" using System; using System.Runtime.CompilerServices; using System.Runtime.Intrinsics; using System.Runtime.Intrinsics.X86; namespace Misaki.HighPerformance.HPC -{ +{{ public static class AVX2Utility - { + {{ + [MethodImpl(MethodImplOptions.NoInlining)] +{sinFloat_standard.GetFullCode(" ")} + + [MethodImpl(MethodImplOptions.AggressiveInlining)] +{sinFloat_fast.GetFullCode(" ")} + [MethodImpl(MethodImplOptions.AggressiveInlining)] public static Vector256 Asin(Vector256 value) - { + {{ // asin(value) = pi/2 - acos(value) var piOver2 = Vector256.Create(MathF.PI / 2.0f); return Avx2.Subtract(piOver2, Acos(value)); - } + }} [MethodImpl(MethodImplOptions.AggressiveInlining)] public static Vector256 Acos(Vector256 value) - { + {{ // 0 <= value <= 1 : acos(value) = sqrt(1 - value) * (c0 + c1*value + c2*value^2 + c3*value^3) // value < 0 : acos(value) = pi - acos(-value) @@ -54,11 +66,11 @@ namespace Misaki.HighPerformance.HPC var isNegative = Avx2.CompareLessThan(value, Vector256.Zero); return Avx2.BlendVariable(pi, Avx2.Subtract(pi, result), isNegative); - } + }} [MethodImpl(MethodImplOptions.AggressiveInlining)] public static Vector256 Atan2(Vector256 y, Vector256 x) - { + {{ var absX = Vector256.Abs(x); var absY = Vector256.Abs(y); @@ -103,9 +115,9 @@ namespace Misaki.HighPerformance.HPC // (This works because our ratio logic effectively computed atan(|y|/|value|) above) var negativeResult = Avx2.Subtract(Vector256.Zero, result); return Avx2.BlendVariable(negativeResult, result, yLtZero); - } - } -}"; + }} + }} +}}"; ctx.AddSource("AVX2Utility.g.cs", source); }); diff --git a/Misaki.HighPerformance.HPC.Generator/UtilityTemplate.cs b/Misaki.HighPerformance.HPC.Generator/UtilityTemplate.cs new file mode 100644 index 0000000..10ba89b --- /dev/null +++ b/Misaki.HighPerformance.HPC.Generator/UtilityTemplate.cs @@ -0,0 +1,58 @@ +using Misaki.HighPerformance.HPC.Generator.VectorAPI; + +namespace Misaki.HighPerformance.HPC.Generator +{ + internal static class UtilityTemplate + { + public static Method SinFloat_Standard(IVectorAPIContext api) + { + var body = api.Return(api.Call("Sin", "value")); + return new Method( + modifier: "public static", + returnType: api.GetVectorType(), + name: $"SinFloat_Standard", + parameters: new[] { $"{api.GetVectorType()} value" }, + body: body); + } + + public static Method SinFloat_Fast(IVectorAPIContext api) + { + var invPi = api.Create("0.318309886f").Assign(); + + var x_sin = new Expression(api, "value").Assign(); + var y_sin = api.Multiply(x_sin, invPi).Assign(); + var k_sin = api.Round(y_sin).Assign(); + var z_sin = api.Subtract(y_sin, k_sin).Assign(); + + var half = api.Create("0.5f").Assign(); + var two = api.Create("2.0f").Assign(); + + var k_even_sin = (api.Round(k_sin * half) * two).Assign(); + var sign_sin = (api.One() - two * api.Abs(k_sin - k_even_sin)).Assign(); + + var c1 = api.Create("3.14159265f").Assign(); + var c3 = api.Create("-5.16771278f").Assign(); + var c5 = api.Create("2.55016404f").Assign(); + var c7 = api.Create("-0.59926453f").Assign(); + var c9 = api.Create("0.08214589f").Assign(); + + var z2_sin = (z_sin * z_sin).Assign(); + var poly_sin = api.MultiplyAdd(z2_sin, c9, c7).Assign(); + + var poly_sin_name = api.LastAssignedVariable; + poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c5).Assign(poly_sin_name, false); + poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c3).Assign(poly_sin_name, false); + poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c1).Assign(poly_sin_name, false); + poly_sin = api.Multiply(z_sin, poly_sin).Assign(poly_sin_name, false); + + var body = api.Return(poly_sin * sign_sin); + + return new Method( + modifier: "public static", + returnType: api.GetVectorType(), + name: $"SinFloat_Fast", + parameters: new[] { $"{api.GetVectorType()} value" }, + body: body); + } + } +} diff --git a/Misaki.HighPerformance.HPC.Generator/VectorAPI/Avx2APIContext.cs b/Misaki.HighPerformance.HPC.Generator/VectorAPI/Avx2APIContext.cs new file mode 100644 index 0000000..d68e8fc --- /dev/null +++ b/Misaki.HighPerformance.HPC.Generator/VectorAPI/Avx2APIContext.cs @@ -0,0 +1,117 @@ +using System; +using System.Collections.Generic; + +namespace Misaki.HighPerformance.HPC.Generator.VectorAPI +{ + internal class Avx2APIContext : IVectorAPIContext + { + private readonly List _statements = new(); + private int _varCount = 0; + private string? _lastAssignedVariable; + + public string? LastAssignedVariable => _lastAssignedVariable; + + public string GetVectorType() + { + return "Vector256"; + } + + public string GetVectorType() + { + return $"Vector256<{VectorAPIContext.GetTypeName()}>"; + } + + public Expression Call(string methodName, params string[] args) + { + return new Expression(this, $"{GetVectorType()}.{methodName}({string.Join(", ", args)})"); + } + + public Expression Assign(Expression expr, string? varName = null, bool isNew = true) + { + varName ??= $"v{_varCount++}"; + + var statement = isNew ? $"var {varName} = {expr.Code};" : $"{varName} = {expr.Code};"; + _statements.Add(statement); + _lastAssignedVariable = varName; + + expr.Clear(); + return new Expression(this, varName); + } + + public Code Return(Expression expr) + { + var statement = $"return {expr.Code};"; + _statements.Add(statement); + expr.Clear(); + + var fullCode = new Code(_statements); + Reset(); + + return fullCode; + } + + + public Expression Create(string value) + { + return new Expression(this, $"{GetVectorType()}.Create({value})"); + } + + public Expression Zero() + { + return new Expression(this, $"{GetVectorType()}.Zero"); + } + + public Expression One() + { + return new Expression(this, $"{GetVectorType()}.One"); + } + + public Expression Count() + { + return new Expression(this, $"{GetVectorType()}.Count"); + } + + + public Expression Add(Expression a, Expression b) + { + return new Expression(this, $"Avx2.Add({a}, {b})"); + } + + public Expression Multiply(Expression a, Expression b) + { + return new Expression(this, $"Avx2.Multiply({a}, {b})"); + } + + public Expression Subtract(Expression a, Expression b) + { + return new Expression(this, $"Avx2.Subtract({a}, {b})"); + } + + public Expression Divide(Expression a, Expression b) + { + return new Expression(this, $"Avx2.Divide({a}, {b})"); + } + + public Expression MultiplyAdd(Expression left, Expression right, Expression addend) + { + return new Expression(this, $"Fma.MultiplyAdd({left}, {right}, {addend})"); + } + + + public Expression Round(Expression value) + { + return new Expression(this, $"Avx2.RoundToNearestInteger({value})"); + } + + public Expression Abs(Expression value) + { + return new Expression(this, $"{GetVectorType()}.Abs({value})"); + } + + public void Reset() + { + _statements.Clear(); + _varCount = 0; + } + } +} diff --git a/Misaki.HighPerformance.HPC.Generator/VectorAPI/IVectorAPIContext.cs b/Misaki.HighPerformance.HPC.Generator/VectorAPI/IVectorAPIContext.cs new file mode 100644 index 0000000..37eb1a8 --- /dev/null +++ b/Misaki.HighPerformance.HPC.Generator/VectorAPI/IVectorAPIContext.cs @@ -0,0 +1,179 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Misaki.HighPerformance.HPC.Generator.VectorAPI +{ + internal class Expression + { + public IVectorAPIContext API + { + get; + } + + public string Code + { + get; private set; + } + + public Expression(IVectorAPIContext api, string code) + { + API = api; + Code = code; + } + + public Expression Assign(string? varName = null, bool isNew = true) + { + return API.Assign(this, varName, isNew); + } + + public void Clear() + { + Code = string.Empty; + } + + public override string ToString() + { + return Code; + } + + public static Expression operator +(Expression a, Expression b) + { + return a.API.Add(a, b); + } + + public static Expression operator -(Expression a, Expression b) + { + return a.API.Subtract(a, b); + } + + public static Expression operator *(Expression a, Expression b) + { + return a.API.Multiply(a, b); + } + + public static Expression operator /(Expression a, Expression b) + { + return a.API.Divide(a, b); + } + } + + internal record Code + { + private readonly string[] _statements; + + public Code(IEnumerable statements) + { + _statements = statements.ToArray(); + } + + public string GetFullCode(string lineIndentation) + { + var sb = new StringBuilder(); + foreach (var stmt in _statements) + { + sb.AppendLine(lineIndentation + stmt); + } + + return sb.ToString(); + } + } + + internal record Method + { + public string Modifier + { + get; + } + + public string ReturnType + { + get; + } + + public string Name + { + get; + } + + public string[] Parameters + { + get; + } + + public Code Body + { + get; + } + + public Method(string modifier, string returnType, string name, string[] parameters, Code body) + { + Modifier = modifier; + ReturnType = returnType; + Name = name; + Parameters = parameters; + Body = body; + } + + public string GetFullCode(string lineIndentation) + { + var sb = new StringBuilder(); + sb.AppendLine(lineIndentation + $"{Modifier} {ReturnType} {Name}({string.Join(", ", Parameters)})"); + sb.AppendLine(lineIndentation + "{"); + sb.Append(Body.GetFullCode(lineIndentation + " ")); + sb.AppendLine(lineIndentation + "}"); + return sb.ToString(); + } + } + + internal interface IVectorAPIContext + { + string? LastAssignedVariable + { + get; + } + + string GetVectorType(); + string GetVectorType(); + + Expression Call(string methodName, params string[] args); + Expression Assign(Expression expr, string? varName = null, bool isNew = true); + Code Return(Expression expr); + + Expression Create(string value); + Expression Zero(); + Expression One(); + Expression Count(); + + Expression Add(Expression a, Expression b); + Expression Subtract(Expression a, Expression b); + Expression Multiply(Expression a, Expression b); + Expression Divide(Expression a, Expression b); + Expression MultiplyAdd(Expression left, Expression right, Expression addend); + + Expression Round(Expression value); + Expression Abs(Expression value); + + void Reset(); + } + + internal static class VectorAPIContext + { + public static string GetTypeName() + { + return typeof(T) switch + { + _ when typeof(T) == typeof(float) => "float", + _ when typeof(T) == typeof(double) => "double", + _ when typeof(T) == typeof(byte) => "byte", + _ when typeof(T) == typeof(short) => "short", + _ when typeof(T) == typeof(int) => "int", + _ when typeof(T) == typeof(uint) => "uint", + _ when typeof(T) == typeof(long) => "long", + _ when typeof(T) == typeof(ulong) => "ulong", + _ => throw new NotSupportedException($"Type {typeof(T)} is not supported in vector operations.") + }; + } + } +} diff --git a/Misaki.HighPerformance.HPC/ISPMDLane.cs b/Misaki.HighPerformance.HPC/ISPMDLane.cs index a456d99..b380920 100644 --- a/Misaki.HighPerformance.HPC/ISPMDLane.cs +++ b/Misaki.HighPerformance.HPC/ISPMDLane.cs @@ -416,14 +416,14 @@ public unsafe interface ISPMDLane : ISPMDLane, IEquatable /// /// Computes a * b + c element-wise. /// - /// The first multiplier. - /// The second multiplier. - /// The addend. + /// The first multiplier. + /// The second multiplier. + /// The addend. /// The result of the fused multiply-add operation. /// /// Float and double implementations should use fused multiply-add instructions when available for both accuracy and performance. /// - static abstract TSelf MultiplyAdd(TSelf a, TSelf b, TSelf c); + static abstract TSelf MultiplyAdd(TSelf left, TSelf right, TSelf addend); /// /// Returns the minimum of the two lane values element-wise. /// diff --git a/Misaki.HighPerformance.HPC/WideLane.cs b/Misaki.HighPerformance.HPC/WideLane.cs index c603ffe..f985c8f 100644 --- a/Misaki.HighPerformance.HPC/WideLane.cs +++ b/Misaki.HighPerformance.HPC/WideLane.cs @@ -40,8 +40,6 @@ public static unsafe class WideLane public readonly unsafe partial struct WideLane : ISPMDLane, TNumber> where TNumber : unmanaged, INumber, IBinaryNumber, IMinMaxValue, IBitwiseOperators { - private static readonly Vector s_indices; - public readonly Vector value; public static int LaneWidth @@ -53,13 +51,13 @@ public readonly unsafe partial struct WideLane : ISPMDLane Zero { [MethodImpl(MethodImplOptions.AggressiveInlining)] - get => new WideLane(Vector.Zero); + get => Unsafe.BitCast, WideLane>(Vector.Zero); } public static WideLane One { [MethodImpl(MethodImplOptions.AggressiveInlining)] - get => new WideLane(Vector.One); + get => Unsafe.BitCast, WideLane>(Vector.One); } public static WideLane MinValue @@ -86,17 +84,6 @@ public readonly unsafe partial struct WideLane : ISPMDLane value[index]; } - static WideLane() - { - var pValues = stackalloc TNumber[LaneWidth]; - for (var i = 0; i < LaneWidth; i++) - { - pValues[i] = TNumber.CreateTruncating(i); - } - - s_indices = Vector.Load(pValues); - } - public WideLane(Vector value) { this.value = value; @@ -145,19 +132,19 @@ public readonly unsafe partial struct WideLane : ISPMDLane Create(TNumber value) { - return new WideLane(Vector.Create(value)); + return Unsafe.BitCast, WideLane>(Vector.Create(value)); } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static WideLane Create(params ReadOnlySpan values) { - return new WideLane(Vector.Create(values)); + return Unsafe.BitCast, WideLane>(Vector.Create(values)); } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static WideLane Create(Vector value) { - return new WideLane(value); + return Unsafe.BitCast, WideLane>(value); } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -185,20 +172,20 @@ public readonly unsafe partial struct WideLane : ISPMDLane(Vector.Create(start) + (Vector.Create(step) * s_indices)); + return Unsafe.BitCast, WideLane>(Vector.Create(start) + (Vector.Create(step) * Vector.Indices)); } } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static WideLane Load(ref TNumber value) { - return new WideLane(Vector.LoadUnsafe(ref value)); + return Unsafe.BitCast, WideLane>(Vector.LoadUnsafe(ref value)); } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static WideLane Load(TNumber* pValue) { - return new WideLane(Vector.Load(pValue)); + return Unsafe.BitCast, WideLane>(Vector.Load(pValue)); } @@ -302,7 +289,7 @@ public readonly unsafe partial struct WideLane : ISPMDLane(result); + return Unsafe.BitCast, WideLane>(result); } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -352,7 +339,7 @@ public readonly unsafe partial struct WideLane : ISPMDLane(result); + return Unsafe.BitCast, WideLane>(result); } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -419,7 +406,7 @@ public readonly unsafe partial struct WideLane : ISPMDLane(result); + return Unsafe.BitCast, WideLane>(result); } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -473,7 +460,7 @@ public readonly unsafe partial struct WideLane : ISPMDLane(result); + return Unsafe.BitCast, WideLane>(result); } @@ -777,61 +764,61 @@ public readonly unsafe partial struct WideLane : ISPMDLane operator +(WideLane a, WideLane b) { - return new WideLane(a.value + b.value); + return Unsafe.BitCast, WideLane>(a.value + b.value); } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static WideLane operator -(WideLane a, WideLane b) { - return new WideLane(a.value - b.value); + return Unsafe.BitCast, WideLane>(a.value - b.value); } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static WideLane operator *(WideLane a, WideLane b) { - return new WideLane(a.value * b.value); + return Unsafe.BitCast, WideLane>(a.value * b.value); } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static WideLane operator /(WideLane a, WideLane b) { - return new WideLane(a.value / b.value); + return Unsafe.BitCast, WideLane>(a.value / b.value); } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static WideLane operator %(WideLane a, WideLane b) { - return new WideLane(a.value - VectorFloor(a.value / b.value) * b.value); + return Unsafe.BitCast, WideLane>(a.value - VectorFloor(a.value / b.value) * b.value); } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static WideLane operator -(WideLane a) { - return new WideLane(-a.value); + return Unsafe.BitCast, WideLane>(-a.value); } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static WideLane operator &(WideLane a, WideLane b) { - return new WideLane(a.value & b.value); + return Unsafe.BitCast, WideLane>(a.value & b.value); } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static WideLane operator |(WideLane a, WideLane b) { - return new WideLane(a.value | b.value); + return Unsafe.BitCast, WideLane>(a.value | b.value); } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static WideLane operator ^(WideLane a, WideLane b) { - return new WideLane(a.value ^ b.value); + return Unsafe.BitCast, WideLane>(a.value ^ b.value); } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static WideLane operator ~(WideLane a) { - return new WideLane(~a.value); + return Unsafe.BitCast, WideLane>(~a.value); } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -881,7 +868,7 @@ public readonly unsafe partial struct WideLane : ISPMDLane Abs(WideLane value) { - return new WideLane(Vector.Abs(value.value)); + return Unsafe.BitCast, WideLane>(Vector.Abs(value.value)); } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -897,7 +884,7 @@ public readonly unsafe partial struct WideLane : ISPMDLane, Vector>(value); var floored = Vector.Floor(v); - return new WideLane(Unsafe.BitCast, Vector>(floored)); + return Unsafe.BitCast, WideLane>(Unsafe.BitCast, Vector>(floored)); } return value; @@ -906,60 +893,60 @@ public readonly unsafe partial struct WideLane : ISPMDLane Frac(WideLane value) { - return new WideLane(value.value - VectorFloor(value.value)); + return Unsafe.BitCast, WideLane>(value.value - VectorFloor(value.value)); } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static WideLane Sqrt(WideLane value) { - return new WideLane(Vector.SquareRoot(value.value)); + return Unsafe.BitCast, WideLane>(Vector.SquareRoot(value.value)); } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static WideLane Lerp(WideLane a, WideLane b, WideLane t) { - return new WideLane(a.value + (b.value - a.value) * t.value); + return Unsafe.BitCast, WideLane>(a.value + (b.value - a.value) * t.value); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static WideLane MultiplyAdd(WideLane a, WideLane b, WideLane c) + public static WideLane MultiplyAdd(WideLane left, WideLane right, WideLane addend) { if (typeof(TNumber) == typeof(float)) { - var va = Unsafe.BitCast, Vector>(a); - var vb = Unsafe.BitCast, Vector>(b); - var vc = Unsafe.BitCast, Vector>(c); + var va = Unsafe.BitCast, Vector>(left); + var vb = Unsafe.BitCast, Vector>(right); + var vc = Unsafe.BitCast, Vector>(addend); var result = Vector.FusedMultiplyAdd(va, vb, vc); return Unsafe.BitCast, WideLane>(result); } else if (typeof(TNumber) == typeof(double)) { - var va = Unsafe.BitCast, Vector>(a); - var vb = Unsafe.BitCast, Vector>(b); - var vc = Unsafe.BitCast, Vector>(c); + var va = Unsafe.BitCast, Vector>(left); + var vb = Unsafe.BitCast, Vector>(right); + var vc = Unsafe.BitCast, Vector>(addend); var result = Vector.FusedMultiplyAdd(va, vb, vc); return Unsafe.BitCast, WideLane>(result); } - return new WideLane((a.value * b.value) + c.value); + return Unsafe.BitCast, WideLane>((left.value * right.value) + addend.value); } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static WideLane Min(WideLane a, WideLane b) { - return new WideLane(Vector.Min(a.value, b.value)); + return Unsafe.BitCast, WideLane>(Vector.Min(a.value, b.value)); } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static WideLane Max(WideLane a, WideLane b) { - return new WideLane(Vector.Max(a.value, b.value)); + return Unsafe.BitCast, WideLane>(Vector.Max(a.value, b.value)); } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static WideLane Clamp(WideLane value, WideLane min, WideLane max) { - return new WideLane(Vector.Clamp(value.value, min.value, max.value)); + return Unsafe.BitCast, WideLane>(Vector.Clamp(value.value, min.value, max.value)); } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -1010,7 +997,7 @@ public readonly unsafe partial struct WideLane : ISPMDLane, Vector>(value); var result = Vector.Sin(v); - return new WideLane(Unsafe.BitCast, Vector>(result)); + return Unsafe.BitCast, WideLane>(result); } return value; @@ -1060,7 +1047,7 @@ public readonly unsafe partial struct WideLane : ISPMDLane, Vector>(value); var result = Vector.Cos(v); - return new WideLane(Unsafe.BitCast, Vector>(result)); + return Unsafe.BitCast, WideLane>(result); } return value; @@ -1144,8 +1131,8 @@ public readonly unsafe partial struct WideLane : ISPMDLane, Vector>(value); var (sinResult, cosResult) = Vector.SinCos(v); - sin = new WideLane(Unsafe.BitCast, Vector>(sinResult)); - cos = new WideLane(Unsafe.BitCast, Vector>(cosResult)); + sin = Unsafe.BitCast, WideLane>(sinResult); + cos = Unsafe.BitCast, WideLane>(cosResult); } else { @@ -1296,7 +1283,7 @@ public readonly unsafe partial struct WideLane : ISPMDLane, Vector>(value); var result = Vector.Exp(v); - return new WideLane(Unsafe.BitCast, Vector>(result)); + return Unsafe.BitCast, WideLane>(result); } return value; @@ -1418,7 +1405,7 @@ public readonly unsafe partial struct WideLane : ISPMDLane CopySign(WideLane magnitude, WideLane sign) { - return new WideLane(Vector.CopySign(magnitude.value, sign.value)); + return Unsafe.BitCast, WideLane>(Vector.CopySign(magnitude.value, sign.value)); } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -1538,7 +1525,7 @@ public readonly unsafe partial struct WideLane : ISPMDLane Select(WideLane conditionMask, WideLane ifTrue, WideLane ifFalse) { - return new WideLane(Vector.ConditionalSelect( + return Unsafe.BitCast, WideLane>(Vector.ConditionalSelect( conditionMask.value, ifTrue.value, ifFalse.value)); @@ -1553,31 +1540,31 @@ public readonly unsafe partial struct WideLane : ISPMDLane GreaterThan(WideLane a, WideLane b) { - return new WideLane(Vector.GreaterThan(a.value, b.value)); + return Unsafe.BitCast, WideLane>(Vector.GreaterThan(a.value, b.value)); } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static WideLane GreaterThanOrEqual(WideLane a, WideLane b) { - return new WideLane(Vector.GreaterThanOrEqual(a.value, b.value)); + return Unsafe.BitCast, WideLane>(Vector.GreaterThanOrEqual(a.value, b.value)); } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static WideLane LessThan(WideLane a, WideLane b) { - return new WideLane(Vector.LessThan(a.value, b.value)); + return Unsafe.BitCast, WideLane>(Vector.LessThan(a.value, b.value)); } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static WideLane LessThanOrEqual(WideLane a, WideLane b) { - return new WideLane(Vector.LessThanOrEqual(a.value, b.value)); + return Unsafe.BitCast, WideLane>(Vector.LessThanOrEqual(a.value, b.value)); } [MethodImpl(MethodImplOptions.AggressiveInlining)] public static WideLane Equal(WideLane a, WideLane b) { - return new WideLane(Vector.Equals(a.value, b.value)); + return Unsafe.BitCast, WideLane>(Vector.Equals(a.value, b.value)); } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -1623,4 +1610,10 @@ public readonly unsafe partial struct WideLane : ISPMDLane(Vector v) + { + return Unsafe.BitCast, WideLane>(v); + } } diff --git a/Misaki.HighPerformance.Test/Program.cs b/Misaki.HighPerformance.Test/Program.cs index b84ade9..ed60605 100644 --- a/Misaki.HighPerformance.Test/Program.cs +++ b/Misaki.HighPerformance.Test/Program.cs @@ -5,6 +5,8 @@ using Misaki.HighPerformance.Test.Benchmark; using Misaki.HighPerformance.Test.UnitTest; using Misaki.HighPerformance.Test.UnitTest.Jobs; using System.Buffers; +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.X86; //BenchmarkRunner.Run(); diff --git a/Misaki.HighPerformance.Test/UnitTest/Jobs/SPMDTest.cs b/Misaki.HighPerformance.Test/UnitTest/Jobs/SPMDTest.cs index 13b8db0..e556557 100644 --- a/Misaki.HighPerformance.Test/UnitTest/Jobs/SPMDTest.cs +++ b/Misaki.HighPerformance.Test/UnitTest/Jobs/SPMDTest.cs @@ -118,12 +118,6 @@ internal struct DistanceJob : IJobSPMD [TestClass] public partial class SPMDTest { - [HPCompute(TargetInstructionSet.AVX2)] - private static WideLane Test(WideLane a, WideLane b, WideLane c) - { - return WideLane.MultiplyAdd(a, b, c); - } - [HPCompute(TargetInstructionSet.AVX2)] private static (TFloat, TFloat) Test_SPMD(TFloat a, TFloat b, TFloat c) where TFloat : unmanaged, ISPMDLane