using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
namespace Misaki.HighPerformance.Mathematics.SPMD;
/// <summary>
/// Non-generic companion to the generic <c>WideLane</c> lane type: owns the precomputed
/// shuffle-control tables used for lane compaction and reports hardware support.
/// </summary>
public static unsafe class WideLane
{
    // Shuffle-control tables, one per vector width (512/256/128-bit) and element size (32/64-bit).
    // Each is built once by ShuffleTableGenerator and read through pointer arithmetic by
    // WideLane<TNumber>.CompressStore, which indexes them as (moveMask * elementsPerVector).
    internal static readonly uint* s_pShuffleTable512_32bit = ShuffleTableGenerator.ComputeShuffleTable512_32Bit();
    internal static readonly ulong* s_pShuffleTable512_64bit = ShuffleTableGenerator.ComputeShuffleTable512_64Bit();
    internal static readonly uint* s_pShuffleTable256_32bit = ShuffleTableGenerator.ComputeShuffleTable256_32Bit();
    internal static readonly ulong* s_pShuffleTable256_64bit = ShuffleTableGenerator.ComputeShuffleTable256_64Bit();
    internal static readonly uint* s_pShuffleTable128_32bit = ShuffleTableGenerator.ComputeShuffleTable128_32Bit();
    internal static readonly ulong* s_pShuffleTable128_64bit = ShuffleTableGenerator.ComputeShuffleTable128_64Bit();

    /// <summary>
    /// Gets whether WideLane is supported on the current hardware.
    /// </summary>
    public static bool IsSupported => Vector.IsHardwareAccelerated;

    // The explicit static constructor keeps the type from being marked beforefieldinit, so the
    // field initializers above run in declaration order exactly when the type is first used.
    static WideLane()
    {
    }
}
[StructLayout(LayoutKind.Sequential)]
public readonly unsafe partial struct WideLane : ISPMDLane, TNumber>
where TNumber : unmanaged, INumber, IBinaryNumber, IMinMaxValue, IBitwiseOperators
{
/// <summary>The packed per-lane values of this lane bundle.</summary>
public readonly Vector<TNumber> value;

// Lane indices [0, 1, ..., LaneWidth - 1], filled in by the static constructor;
// used, for example, by Sequence's generic fallback path.
private static readonly Vector<TNumber> s_indices;
/// <summary>Number of <typeparamref name="TNumber"/> elements held by one lane bundle.</summary>
public static int LaneWidth
{
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    get => Vector<TNumber>.Count;
}

/// <summary>A lane bundle whose lanes are all zero.</summary>
public static WideLane<TNumber> Zero
{
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    get => new(Vector<TNumber>.Zero);
}

/// <summary>A lane bundle whose lanes are all one.</summary>
public static WideLane<TNumber> One
{
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    get => new(Vector<TNumber>.One);
}

/// <summary>A lane bundle whose lanes all hold <c>TNumber.MinValue</c>.</summary>
public static WideLane<TNumber> MinValue
{
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    get => new(Vector.Create(TNumber.MinValue));
}

/// <summary>A lane bundle whose lanes all hold <c>TNumber.MaxValue</c>.</summary>
public static WideLane<TNumber> MaxValue
{
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    get => new(Vector.Create(TNumber.MaxValue));
}

/// <summary>A lane bundle whose lanes all hold <c>TNumber.AllBitsSet</c>.</summary>
public static WideLane<TNumber> AllBitsSet
{
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    get => new(Vector.Create(TNumber.AllBitsSet));
}
/// <summary>Returns the element stored in lane <paramref name="index"/>.</summary>
/// <param name="index">Zero-based lane index; validation is left to the underlying vector indexer.</param>
public readonly TNumber this[int index]
{
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    get => this.value[index];
}

/// <summary>
/// Precomputes <c>s_indices</c> = [0, 1, ..., LaneWidth - 1] for this element type.
/// </summary>
static WideLane()
{
    var laneCount = LaneWidth;
    TNumber* pLaneIndices = stackalloc TNumber[laneCount];
    for (var lane = 0; lane < laneCount; lane++)
    {
        pLaneIndices[lane] = TNumber.CreateTruncating(lane);
    }

    s_indices = Vector.Load(pLaneIndices);
}

/// <summary>Wraps an existing vector; lane <c>i</c> of the bundle is element <c>i</c> of <paramref name="value"/>.</summary>
public WideLane(Vector<TNumber> value) => this.value = value;
/// <summary>
/// Rounds every lane toward negative infinity when <typeparamref name="TNumber"/> is
/// <see cref="float"/> or <see cref="double"/>; any other element type is returned unchanged.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Vector<TNumber> VectorFloor(Vector<TNumber> vector)
{
    if (typeof(TNumber) == typeof(double))
    {
        var flooredDouble = Vector.Floor(Unsafe.As<Vector<TNumber>, Vector<double>>(ref vector));
        return Unsafe.As<Vector<double>, Vector<TNumber>>(ref flooredDouble);
    }

    if (typeof(TNumber) == typeof(float))
    {
        var flooredSingle = Vector.Floor(Unsafe.As<Vector<TNumber>, Vector<float>>(ref vector));
        return Unsafe.As<Vector<float>, Vector<TNumber>>(ref flooredSingle);
    }

    // Integer lanes carry no fractional part, so flooring is the identity.
    return vector;
}
/// <summary>
/// Rounds every lane toward zero when <typeparamref name="TNumber"/> is
/// <see cref="float"/> or <see cref="double"/>; any other element type is returned unchanged.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Vector<TNumber> VectorTruncate(Vector<TNumber> vector)
{
    if (typeof(TNumber) == typeof(double))
    {
        var truncatedDouble = Vector.Truncate(Unsafe.As<Vector<TNumber>, Vector<double>>(ref vector));
        return Unsafe.As<Vector<double>, Vector<TNumber>>(ref truncatedDouble);
    }

    if (typeof(TNumber) == typeof(float))
    {
        var truncatedSingle = Vector.Truncate(Unsafe.As<Vector<TNumber>, Vector<float>>(ref vector));
        return Unsafe.As<Vector<float>, Vector<TNumber>>(ref truncatedSingle);
    }

    // Integer lanes are already whole numbers; truncation is the identity.
    return vector;
}
/// <summary>Broadcasts <paramref name="value"/> into every lane.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Create(TNumber value) => new(Vector.Create(value));

/// <summary>
/// Builds a lane bundle from <paramref name="values"/>; length requirements are those of
/// <c>Vector.Create</c> for spans.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Create(params ReadOnlySpan<TNumber> values) => new(Vector.Create(values));

/// <summary>Wraps an existing vector as a lane bundle.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Create(Vector<TNumber> value) => new(value);
/// <summary>
/// Returns a lane bundle whose lane <c>i</c> holds <c>start + i * step</c>.
/// </summary>
/// <remarks>
/// Uses the fixed-width <c>CreateSequence</c> whose element count matches <see cref="LaneWidth"/>;
/// if none matches, combines broadcast <paramref name="start"/>/<paramref name="step"/> with the
/// precomputed lane indices.
/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Sequence(TNumber start, TNumber step)
{
    var width = LaneWidth;

    if (width == Vector512<TNumber>.Count)
    {
        var seq512 = Vector512.CreateSequence(start, step);
        return Unsafe.As<Vector512<TNumber>, WideLane<TNumber>>(ref seq512);
    }

    if (width == Vector256<TNumber>.Count)
    {
        var seq256 = Vector256.CreateSequence(start, step);
        return Unsafe.As<Vector256<TNumber>, WideLane<TNumber>>(ref seq256);
    }

    if (width == Vector128<TNumber>.Count)
    {
        var seq128 = Vector128.CreateSequence(start, step);
        return Unsafe.As<Vector128<TNumber>, WideLane<TNumber>>(ref seq128);
    }

    if (width == Vector64<TNumber>.Count)
    {
        var seq64 = Vector64.CreateSequence(start, step);
        return Unsafe.As<Vector64<TNumber>, WideLane<TNumber>>(ref seq64);
    }

    // Generic path: start + step * [0, 1, ..., LaneWidth - 1].
    return new(Vector.Create(start) + (Vector.Create(step) * s_indices));
}
/// <summary>
/// Reads <see cref="LaneWidth"/> consecutive elements starting at <paramref name="value"/>.
/// No bounds checks are performed.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Load(ref TNumber value) => new(Vector.LoadUnsafe(ref value));

/// <summary>
/// Reads <see cref="LaneWidth"/> consecutive elements starting at <paramref name="pValue"/>.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Load(TNumber* pValue) => new(Vector.Load(pValue));
/// <summary>
/// Loads <see cref="LaneWidth"/> elements from <paramref name="value"/> and keeps only the lanes
/// selected by <paramref name="mask"/>; unselected lanes become zero.
/// </summary>
/// <param name="mask">
/// Bitwise selector passed to <c>Vector.ConditionalSelect</c>; lanes are expected to be
/// all-bits-set (keep) or all-zero (discard).
/// </param>
/// <remarks>
/// NOTE(review): the full vector is read regardless of <paramref name="mask"/>, so all
/// <see cref="LaneWidth"/> elements must be readable memory. This is not a fault-suppressing masked load.
/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> MaskLoad(WideLane<TNumber> mask, ref TNumber value)
{
    var loaded = Vector.LoadUnsafe(ref value);
    var selected = Vector.ConditionalSelect(mask.value, loaded, Vector<TNumber>.Zero);
    return new(selected);
}

/// <summary>Pointer overload of <c>MaskLoad</c>; the same full-width read requirement applies.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> MaskLoad(WideLane<TNumber> mask, TNumber* pValue) =>
    MaskLoad(mask, ref Unsafe.AsRef<TNumber>(pValue));
/// <summary>
/// Pointer overload: gathers one element per lane from <paramref name="pData"/> using the index
/// values stored in <paramref name="indices"/>.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Gather(TNumber* pData, WideLane<TNumber> indices, int scale)
{
    return Gather(ref Unsafe.AsRef<TNumber>(pData), indices, scale);
}

/// <summary>
/// Pointer overload: gathers one element per lane from <paramref name="pData"/> using
/// <see cref="LaneWidth"/> consecutive <see cref="int"/> indices read from <paramref name="pIndices"/>.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Gather(TNumber* pData, int* pIndices, int scale)
{
    return Gather(ref Unsafe.AsRef<TNumber>(pData), ref Unsafe.AsRef<int>(pIndices), scale);
}

/// <summary>
/// Gathers lane <c>i</c> from element <c>(index_i * scale) / sizeof(TNumber)</c> relative to
/// <paramref name="baseAddress"/>, where <c>index_i</c> is lane <c>i</c> of <paramref name="indices"/>
/// converted with <c>int.CreateTruncating</c>.
/// </summary>
/// <param name="scale">Byte stride applied to each index (e.g. <c>sizeof(TNumber)</c> for element indices).</param>
/// <remarks>
/// The offset product is computed in native-int arithmetic so large indices do not wrap around in
/// 32-bit math. No bounds checks are performed.
/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Gather(ref TNumber baseAddress, WideLane<TNumber> indices, int scale)
{
    Unsafe.SkipInit(out Vector<TNumber> result);
    var pResult = (TNumber*)&result;
    var pIndices = (TNumber*)&indices;
    var count = Vector<TNumber>.Count;
    for (var i = 0; i < count; i++)
    {
        var idx = int.CreateTruncating(pIndices[i]);
        // Widen before multiplying: idx * scale in int overflows for idx > int.MaxValue / scale.
        pResult[i] = Unsafe.Add(ref baseAddress, (nint)idx * scale / sizeof(TNumber));
    }

    return new WideLane<TNumber>(result);
}

/// <summary>
/// Gathers lane <c>i</c> from element <c>(baseIndex[i] * scale) / sizeof(TNumber)</c> relative to
/// <paramref name="baseAddress"/>, reading <see cref="LaneWidth"/> consecutive indices starting at
/// <paramref name="baseIndex"/>.
/// </summary>
/// <param name="scale">Byte stride applied to each index.</param>
/// <remarks>Offsets are computed in native-int arithmetic; no bounds checks are performed.</remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Gather(ref TNumber baseAddress, ref int baseIndex, int scale)
{
    Unsafe.SkipInit(out Vector<TNumber> result);
    var pResult = (TNumber*)&result;
    var count = Vector<TNumber>.Count;
    for (var i = 0; i < count; i++)
    {
        // Widen before multiplying to avoid 32-bit overflow of index * scale.
        var elementOffset = (nint)Unsafe.Add(ref baseIndex, i) * scale / sizeof(TNumber);
        pResult[i] = Unsafe.Add(ref baseAddress, elementOffset);
    }

    return new WideLane<TNumber>(result);
}
/// <summary>
/// Writes all <see cref="LaneWidth"/> lanes to consecutive elements starting at
/// <paramref name="destination"/>. No bounds checks are performed.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public readonly void Store(ref TNumber destination) => Vector.StoreUnsafe(value, ref destination);

/// <summary>
/// Writes all <see cref="LaneWidth"/> lanes to consecutive elements starting at
/// <paramref name="pDestination"/>.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public readonly void Store(TNumber* pDestination) => Vector.Store(value, pDestination);
/// <summary>
/// Packs the lanes selected by <paramref name="mask"/> into consecutive elements starting at
/// <paramref name="destination"/>, preserving lane order.
/// </summary>
/// <param name="mask">
/// Per-lane selector: a lane is selected when the most significant (sign) bit of its mask lane is set,
/// as in the all-bits-set lanes produced by comparisons.
/// </param>
/// <param name="destination">First element of the output buffer.</param>
/// <returns>The number of selected lanes.</returns>
/// <remarks>
/// The shuffle-table paths store a full vector, so <paramref name="destination"/> must have room for
/// <see cref="LaneWidth"/> elements; elements past the returned count may be overwritten.
/// The scalar fallback writes only the selected elements.
/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public int CompressStore(WideLane<TNumber> mask, ref TNumber destination)
{
    if (LaneWidth == Vector512<TNumber>.Count && Vector512.IsHardwareAccelerated)
    {
        if (sizeof(TNumber) == 4)
        {
            ref var vec = ref Unsafe.As<WideLane<TNumber>, Vector512<uint>>(ref Unsafe.AsRef(in this));
            var m = Unsafe.As<WideLane<TNumber>, Vector512<uint>>(ref mask);
            var moveMask = m.ExtractMostSignificantBits();
            // Offset is (moveMask * 16) because each control vector has 16 elements
            var shuffle = Vector512.Load(WideLane.s_pShuffleTable512_32bit + (moveMask * 16));
            var compressed = Vector512.Shuffle(vec, shuffle);
            compressed.StoreUnsafe(ref Unsafe.As<TNumber, uint>(ref destination));
            return BitOperations.PopCount(moveMask);
        }

        if (sizeof(TNumber) == 8)
        {
            ref var vec = ref Unsafe.As<WideLane<TNumber>, Vector512<ulong>>(ref Unsafe.AsRef(in this));
            var m = Unsafe.As<WideLane<TNumber>, Vector512<ulong>>(ref mask);
            var moveMask = m.ExtractMostSignificantBits();
            // Offset is (moveMask * 8) because each control vector has 8 elements
            var shuffle = Vector512.Load(WideLane.s_pShuffleTable512_64bit + (moveMask * 8));
            var compressed = Vector512.Shuffle(vec, shuffle);
            compressed.StoreUnsafe(ref Unsafe.As<TNumber, ulong>(ref destination));
            return BitOperations.PopCount(moveMask);
        }
    }
    else if (LaneWidth == Vector256<TNumber>.Count && Vector256.IsHardwareAccelerated)
    {
        if (sizeof(TNumber) == 4)
        {
            ref var vec = ref Unsafe.As<WideLane<TNumber>, Vector256<uint>>(ref Unsafe.AsRef(in this));
            var m = Unsafe.As<WideLane<TNumber>, Vector256<uint>>(ref mask);
            var moveMask = m.ExtractMostSignificantBits();
            // Offset is (moveMask * 8) because each control vector has 8 elements
            var shuffle = Vector256.Load(WideLane.s_pShuffleTable256_32bit + (moveMask * 8));
            var compressed = Vector256.Shuffle(vec, shuffle);
            compressed.StoreUnsafe(ref Unsafe.As<TNumber, uint>(ref destination));
            return BitOperations.PopCount(moveMask);
        }

        if (sizeof(TNumber) == 8)
        {
            ref var vec = ref Unsafe.As<WideLane<TNumber>, Vector256<ulong>>(ref Unsafe.AsRef(in this));
            var m = Unsafe.As<WideLane<TNumber>, Vector256<ulong>>(ref mask);
            // For 64-bit, ExtractMostSignificantBits only populates 4 bits (0-15)
            var moveMask = m.ExtractMostSignificantBits();
            // Offset is (moveMask * 4) because each control vector has 4 elements
            var shuffle = Vector256.Load(WideLane.s_pShuffleTable256_64bit + (moveMask * 4));
            var compressed = Vector256.Shuffle(vec, shuffle);
            compressed.StoreUnsafe(ref Unsafe.As<TNumber, ulong>(ref destination));
            return BitOperations.PopCount(moveMask);
        }
    }
    else if (LaneWidth == Vector128<TNumber>.Count && Vector128.IsHardwareAccelerated)
    {
        if (sizeof(TNumber) == 4)
        {
            ref var vec = ref Unsafe.As<WideLane<TNumber>, Vector128<uint>>(ref Unsafe.AsRef(in this));
            var m = Unsafe.As<WideLane<TNumber>, Vector128<uint>>(ref mask);
            var moveMask = m.ExtractMostSignificantBits();
            // Offset is (moveMask * 4) because each control vector has 4 elements
            var shuffle = Vector128.Load(WideLane.s_pShuffleTable128_32bit + (moveMask * 4));
            var compressed = Vector128.Shuffle(vec, shuffle);
            compressed.StoreUnsafe(ref Unsafe.As<TNumber, uint>(ref destination));
            return BitOperations.PopCount(moveMask);
        }

        if (sizeof(TNumber) == 8)
        {
            ref var vec = ref Unsafe.As<WideLane<TNumber>, Vector128<ulong>>(ref Unsafe.AsRef(in this));
            var m = Unsafe.As<WideLane<TNumber>, Vector128<ulong>>(ref mask);
            var moveMask = m.ExtractMostSignificantBits();
            // Offset is (moveMask * 2) because each control vector has 2 elements
            var shuffle = Vector128.Load(WideLane.s_pShuffleTable128_64bit + (moveMask * 2));
            var compressed = Vector128.Shuffle(vec, shuffle);
            compressed.StoreUnsafe(ref Unsafe.As<TNumber, ulong>(ref destination));
            return BitOperations.PopCount(moveMask);
        }
    }

    // Scalar fallback: slow, but correct on any hardware. A lane is selected when the sign bit of
    // its mask lane is set, matching ExtractMostSignificantBits in the SIMD paths above.
    // The raw bits are tested because a floating-point all-bits-set mask is NaN, and
    // `mask == TNumber.AllBitsSet` is therefore always false for float/double lanes.
    var count = 0;
    for (var i = 0; i < LaneWidth; i++)
    {
        if (IsSignBitSet(mask.value[i]))
        {
            Unsafe.Add(ref destination, count++) = value[i];
        }
    }

    return count;

    // Reinterprets the lane as a same-sized signed integer and checks its sign bit.
    static bool IsSignBitSet(TNumber lane)
    {
        if (sizeof(TNumber) == sizeof(long))
        {
            return Unsafe.As<TNumber, long>(ref lane) < 0;
        }

        if (sizeof(TNumber) == sizeof(int))
        {
            return Unsafe.As<TNumber, int>(ref lane) < 0;
        }

        if (sizeof(TNumber) == sizeof(short))
        {
            return Unsafe.As<TNumber, short>(ref lane) < 0;
        }

        if (sizeof(TNumber) == sizeof(sbyte))
        {
            return Unsafe.As<TNumber, sbyte>(ref lane) < 0;
        }

        // Unreachable for the 1/2/4/8-byte element types Vector<T> supports; keep the canonical check.
        return lane == TNumber.AllBitsSet;
    }
}

/// <summary>
/// Pointer overload of <c>CompressStore</c>; <paramref name="pDestination"/> has the same
/// full-vector capacity requirement as the <c>ref</c> overload.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public int CompressStore(WideLane<TNumber> mask, TNumber* pDestination)
{
    return CompressStore(mask, ref Unsafe.AsRef<TNumber>(pDestination));
}
/// <summary>Returns the underlying vector.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public readonly Vector<TNumber> AsVector() => value;

/// <summary>Returns a pointer to lane 0 of this instance's storage.</summary>
/// <remarks>
/// NOTE(review): the pointer refers to this instance's storage and is only meaningful while that
/// storage stays at a fixed address (e.g. a local or pinned copy).
/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public readonly TNumber* GetUnsafePtr() => (TNumber*)Unsafe.AsPointer(ref Unsafe.AsRef(in value));

/// <summary>Reinterprets the raw bits of this lane bundle as another lane type.</summary>
/// <remarks>Size compatibility is enforced by <c>Unsafe.BitCast</c>, not by this method.</remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public TOther BitCast<TOther, TOtherNumber>()
    where TOther : ISPMDLane<TOther, TOtherNumber>
    where TOtherNumber : unmanaged, INumber<TOtherNumber>, IBinaryNumber<TOtherNumber>, IMinMaxValue<TOtherNumber>, IBitwiseOperators<TOtherNumber, TOtherNumber, TOtherNumber>
    => Unsafe.BitCast<WideLane<TNumber>, TOther>(this);
/// <summary>Lane-wise addition.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator +(WideLane<TNumber> a, WideLane<TNumber> b) => new(a.value + b.value);

/// <summary>Lane-wise subtraction.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator -(WideLane<TNumber> a, WideLane<TNumber> b) => new(a.value - b.value);

/// <summary>Lane-wise multiplication.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator *(WideLane<TNumber> a, WideLane<TNumber> b) => new(a.value * b.value);

/// <summary>Lane-wise division (integer lanes use integer division).</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator /(WideLane<TNumber> a, WideLane<TNumber> b) => new(a.value / b.value);
/// <summary>
/// Lane-wise remainder, computed as <c>a - trunc(a / b) * b</c>.
/// </summary>
/// <remarks>
/// Truncating the quotient gives the result the sign of the dividend for every element type,
/// matching C#'s scalar <c>%</c>. Integer division already truncates, so <c>VectorTruncate</c> is the
/// identity for integer lanes, and floating-point lanes agree with them (e.g. <c>-7 % 3 == -1</c>).
/// A floor-based quotient would yield <c>2</c> for floating-point lanes only.
/// For floating-point lanes the quotient is rounded before truncation, so results can differ from an
/// exact <c>fmod</c> in the last bits when <c>|a / b|</c> is large.
/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator %(WideLane<TNumber> a, WideLane<TNumber> b)
{
    return new WideLane<TNumber>(a.value - VectorTruncate(a.value / b.value) * b.value);
}
/// <summary>Lane-wise negation.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator -(WideLane<TNumber> a) => new(-a.value);

/// <summary>Lane-wise bitwise AND.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator &(WideLane<TNumber> a, WideLane<TNumber> b) => new(a.value & b.value);

/// <summary>Lane-wise bitwise OR.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator |(WideLane<TNumber> a, WideLane<TNumber> b) => new(a.value | b.value);

/// <summary>Lane-wise bitwise XOR.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator ^(WideLane<TNumber> a, WideLane<TNumber> b) => new(a.value ^ b.value);

/// <summary>Lane-wise bitwise complement.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator ~(WideLane<TNumber> a) => new(~a.value);
/// <summary>Lane-wise equality mask (delegates to <c>Equal</c>).</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator ==(WideLane<TNumber> a, WideLane<TNumber> b) => Equal(a, b);

/// <summary>Lane-wise inequality mask: the bitwise complement of <c>Equal</c>.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator !=(WideLane<TNumber> a, WideLane<TNumber> b) => ~Equal(a, b);

/// <summary>Lane-wise greater-than mask (delegates to <c>GreaterThan</c>).</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator >(WideLane<TNumber> a, WideLane<TNumber> b) => GreaterThan(a, b);

/// <summary>Lane-wise greater-than-or-equal mask (delegates to <c>GreaterThanOrEqual</c>).</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator >=(WideLane<TNumber> a, WideLane<TNumber> b) => GreaterThanOrEqual(a, b);

/// <summary>Lane-wise less-than mask (delegates to <c>LessThan</c>).</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator <(WideLane<TNumber> a, WideLane<TNumber> b) => LessThan(a, b);

/// <summary>Lane-wise less-than-or-equal mask (delegates to <c>LessThanOrEqual</c>).</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator <=(WideLane<TNumber> a, WideLane<TNumber> b) => LessThanOrEqual(a, b);

/// <summary>Broadcasts a scalar into every lane.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static implicit operator WideLane<TNumber>(TNumber value) => new(Vector.Create(value));
/// <summary>Lane-wise absolute value.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Abs(WideLane<TNumber> value) => new(Vector.Abs(value.value));

/// <summary>
/// Lane-wise floor for <see cref="float"/>/<see cref="double"/> lanes; other element types are
/// returned unchanged.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Floor(WideLane<TNumber> value) => new(VectorFloor(value.value));

/// <summary>
/// Lane-wise fractional part, <c>value - floor(value)</c>. Always zero for integer lanes.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Frac(WideLane<TNumber> value)
{
    var whole = VectorFloor(value.value);
    return new(value.value - whole);
}

/// <summary>Lane-wise square root.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Sqrt(WideLane<TNumber> value) => new(Vector.SquareRoot(value.value));

/// <summary>
/// Lane-wise linear interpolation <c>a + (b - a) * t</c>; <paramref name="t"/> is not clamped.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Lerp(WideLane<TNumber> a, WideLane<TNumber> b, WideLane<TNumber> t)
{
    var delta = b.value - a.value;
    return new(a.value + delta * t.value);
}
/// <summary>
/// Computes <c>a * b + c</c> per lane. For <see cref="float"/>/<see cref="double"/> lanes this uses
/// <c>Vector.FusedMultiplyAdd</c>; other element types use a separate multiply and add.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> MultipleAdd(WideLane<TNumber> a, WideLane<TNumber> b, WideLane<TNumber> c)
{
    if (typeof(TNumber) == typeof(double))
    {
        var fusedDouble = Vector.FusedMultiplyAdd(
            Unsafe.As<WideLane<TNumber>, Vector<double>>(ref a),
            Unsafe.As<WideLane<TNumber>, Vector<double>>(ref b),
            Unsafe.As<WideLane<TNumber>, Vector<double>>(ref c));
        return new(Unsafe.As<Vector<double>, Vector<TNumber>>(ref fusedDouble));
    }

    if (typeof(TNumber) == typeof(float))
    {
        var fusedSingle = Vector.FusedMultiplyAdd(
            Unsafe.As<WideLane<TNumber>, Vector<float>>(ref a),
            Unsafe.As<WideLane<TNumber>, Vector<float>>(ref b),
            Unsafe.As<WideLane<TNumber>, Vector<float>>(ref c));
        return new(Unsafe.As<Vector<float>, Vector<TNumber>>(ref fusedSingle));
    }

    var product = a.value * b.value;
    return new(product + c.value);
}
/// <summary>Lane-wise minimum.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Min(WideLane<TNumber> a, WideLane<TNumber> b) => new(Vector.Min(a.value, b.value));

/// <summary>Lane-wise maximum.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Max(WideLane<TNumber> a, WideLane<TNumber> b) => new(Vector.Max(a.value, b.value));

/// <summary>Lane-wise clamp of <paramref name="value"/> to [<paramref name="min"/>, <paramref name="max"/>] via <c>Vector.Clamp</c>.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Clamp(WideLane<TNumber> value, WideLane<TNumber> min, WideLane<TNumber> max) =>
    new(Vector.Clamp(value.value, min.value, max.value));

/// <summary>Lane-wise clamp to [0, 1].</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Saturate(WideLane<TNumber> value)
{
    var lower = Create(TNumber.Zero);
    var upper = Create(TNumber.One);
    return Clamp(value, lower, upper);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane Sin(WideLane value)
{
#if MHP_FASTMATH
var invPi = Create(TNumber.CreateTruncating(0.318309886f)); // 1 / PI
var x_sin = value;
var y_sin = x_sin * invPi;
var k_sin = Round(y_sin);
var z_sin = y_sin - k_sin;
var half = Create(TNumber.CreateTruncating(0.5f));
var two = Create(TNumber.CreateTruncating(2.0f));
var k_even_sin = Round(k_sin * half) * two;
var sign_sin = One - two * Abs(k_sin - k_even_sin);
var c1 = Create(TNumber.CreateTruncating(3.14159265f)); // PI
var c3 = Create(TNumber.CreateTruncating(-5.16771278f)); // -PI^3 / 6
var c5 = Create(TNumber.CreateTruncating(2.55016404f)); // PI^5 / 120
var c7 = Create(TNumber.CreateTruncating(-0.59926453f)); // -PI^7 / 5040
var c9 = Create(TNumber.CreateTruncating(0.08214589f)); // PI^9 / 362880
var z2_sin = z_sin * z_sin;
var poly_sin = MultipleAdd(z2_sin, c9, c7); // c7 + c9*z^2
poly_sin = MultipleAdd(z2_sin, poly_sin, c5); // c5 + z^2*(...)
poly_sin = MultipleAdd(z2_sin, poly_sin, c3); // c3 + z^2*(...)
poly_sin = MultipleAdd(z2_sin, poly_sin, c1); // c1 + z^2*(...)
poly_sin = z_sin * poly_sin; // z * (...)
return poly_sin * sign_sin;
#else
if (typeof(TNumber) == typeof(float))
{
ref var v = ref Unsafe.As