Refactor vector API codegen and WideLane conversions

- Introduce IVectorAPIContext abstraction and supporting types for vectorized code generation
- Add Avx2APIContext and UtilityTemplate for AVX2-specific code emission
- Dynamically generate AVX2 sine methods in AVX2Rewriter
- Refactor WideLane<TNumber> to use Unsafe.BitCast for all Vector conversions
- Update all WideLane operators and math methods to use Unsafe.BitCast
- Change MultiplyAdd parameter names for clarity
- Remove static indices field in favor of Vector<TNumber>.Indices
- Add implicit conversion from Vector<TNumber> to WideLane<TNumber>
- Update tests and program files for compatibility
This commit is contained in:
2026-05-06 19:20:15 +09:00
parent c8f78f9d02
commit fd2d60c8f1
8 changed files with 439 additions and 84 deletions

View File

@@ -1,5 +1,6 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Misaki.HighPerformance.HPC.Generator.VectorAPI;
using System;
namespace Misaki.HighPerformance.HPC.Generator
@@ -11,28 +12,39 @@ namespace Misaki.HighPerformance.HPC.Generator
{
context.RegisterPostInitializationOutput(static ctx =>
{
var source = @"
var api = new Avx2APIContext();
var sinFloat_standard = UtilityTemplate.SinFloat_Standard(api);
var sinFloat_fast = UtilityTemplate.SinFloat_Fast(api);
var source = @$"
using System;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
namespace Misaki.HighPerformance.HPC
{
{{
public static class AVX2Utility
{
{{
[MethodImpl(MethodImplOptions.NoInlining)]
{sinFloat_standard.GetFullCode(" ")}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
{sinFloat_fast.GetFullCode(" ")}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector256<float> Asin(Vector256<float> value)
{
{{
// asin(value) = pi/2 - acos(value)
var piOver2 = Vector256.Create(MathF.PI / 2.0f);
return Avx2.Subtract(piOver2, Acos(value));
}
}}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector256<float> Acos(Vector256<float> value)
{
{{
// 0 <= value <= 1 : acos(value) = sqrt(1 - value) * (c0 + c1*value + c2*value^2 + c3*value^3)
// value < 0 : acos(value) = pi - acos(-value)
@@ -54,11 +66,11 @@ namespace Misaki.HighPerformance.HPC
var isNegative = Avx2.CompareLessThan(value, Vector256<float>.Zero);
return Avx2.BlendVariable(pi, Avx2.Subtract(pi, result), isNegative);
}
}}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector256<float> Atan2(Vector256<float> y, Vector256<float> x)
{
{{
var absX = Vector256.Abs(x);
var absY = Vector256.Abs(y);
@@ -103,9 +115,9 @@ namespace Misaki.HighPerformance.HPC
// (This works because our ratio logic effectively computed atan(|y|/|value|) above)
var negativeResult = Avx2.Subtract(Vector256<float>.Zero, result);
return Avx2.BlendVariable(negativeResult, result, yLtZero);
}
}
}";
}}
}}
}}";
ctx.AddSource("AVX2Utility.g.cs", source);
});

View File

@@ -0,0 +1,58 @@
using Misaki.HighPerformance.HPC.Generator.VectorAPI;
namespace Misaki.HighPerformance.HPC.Generator
{
internal static class UtilityTemplate
{
public static Method SinFloat_Standard(IVectorAPIContext api)
{
var body = api.Return(api.Call("Sin", "value"));
return new Method(
modifier: "public static",
returnType: api.GetVectorType<float>(),
name: $"SinFloat_Standard",
parameters: new[] { $"{api.GetVectorType<float>()} value" },
body: body);
}
public static Method SinFloat_Fast(IVectorAPIContext api)
{
var invPi = api.Create("0.318309886f").Assign();
var x_sin = new Expression(api, "value").Assign();
var y_sin = api.Multiply(x_sin, invPi).Assign();
var k_sin = api.Round(y_sin).Assign();
var z_sin = api.Subtract(y_sin, k_sin).Assign();
var half = api.Create("0.5f").Assign();
var two = api.Create("2.0f").Assign();
var k_even_sin = (api.Round(k_sin * half) * two).Assign();
var sign_sin = (api.One<float>() - two * api.Abs(k_sin - k_even_sin)).Assign();
var c1 = api.Create("3.14159265f").Assign();
var c3 = api.Create("-5.16771278f").Assign();
var c5 = api.Create("2.55016404f").Assign();
var c7 = api.Create("-0.59926453f").Assign();
var c9 = api.Create("0.08214589f").Assign();
var z2_sin = (z_sin * z_sin).Assign();
var poly_sin = api.MultiplyAdd(z2_sin, c9, c7).Assign();
var poly_sin_name = api.LastAssignedVariable;
poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c5).Assign(poly_sin_name, false);
poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c3).Assign(poly_sin_name, false);
poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c1).Assign(poly_sin_name, false);
poly_sin = api.Multiply(z_sin, poly_sin).Assign(poly_sin_name, false);
var body = api.Return(poly_sin * sign_sin);
return new Method(
modifier: "public static",
returnType: api.GetVectorType<float>(),
name: $"SinFloat_Fast",
parameters: new[] { $"{api.GetVectorType<float>()} value" },
body: body);
}
}
}

View File

@@ -0,0 +1,117 @@
using System;
using System.Collections.Generic;
namespace Misaki.HighPerformance.HPC.Generator.VectorAPI
{
internal class Avx2APIContext : IVectorAPIContext
{
private readonly List<string> _statements = new();
private int _varCount = 0;
private string? _lastAssignedVariable;
public string? LastAssignedVariable => _lastAssignedVariable;
public string GetVectorType()
{
return "Vector256";
}
public string GetVectorType<T>()
{
return $"Vector256<{VectorAPIContext.GetTypeName<T>()}>";
}
public Expression Call(string methodName, params string[] args)
{
return new Expression(this, $"{GetVectorType()}.{methodName}({string.Join(", ", args)})");
}
public Expression Assign(Expression expr, string? varName = null, bool isNew = true)
{
varName ??= $"v{_varCount++}";
var statement = isNew ? $"var {varName} = {expr.Code};" : $"{varName} = {expr.Code};";
_statements.Add(statement);
_lastAssignedVariable = varName;
expr.Clear();
return new Expression(this, varName);
}
public Code Return(Expression expr)
{
var statement = $"return {expr.Code};";
_statements.Add(statement);
expr.Clear();
var fullCode = new Code(_statements);
Reset();
return fullCode;
}
public Expression Create(string value)
{
return new Expression(this, $"{GetVectorType()}.Create({value})");
}
public Expression Zero<T>()
{
return new Expression(this, $"{GetVectorType<T>()}.Zero");
}
public Expression One<T>()
{
return new Expression(this, $"{GetVectorType<T>()}.One");
}
public Expression Count<T>()
{
return new Expression(this, $"{GetVectorType<T>()}.Count");
}
public Expression Add(Expression a, Expression b)
{
return new Expression(this, $"Avx2.Add({a}, {b})");
}
public Expression Multiply(Expression a, Expression b)
{
return new Expression(this, $"Avx2.Multiply({a}, {b})");
}
public Expression Subtract(Expression a, Expression b)
{
return new Expression(this, $"Avx2.Subtract({a}, {b})");
}
public Expression Divide(Expression a, Expression b)
{
return new Expression(this, $"Avx2.Divide({a}, {b})");
}
public Expression MultiplyAdd(Expression left, Expression right, Expression addend)
{
return new Expression(this, $"Fma.MultiplyAdd({left}, {right}, {addend})");
}
public Expression Round(Expression value)
{
return new Expression(this, $"Avx2.RoundToNearestInteger({value})");
}
public Expression Abs(Expression value)
{
return new Expression(this, $"{GetVectorType()}.Abs({value})");
}
public void Reset()
{
_statements.Clear();
_varCount = 0;
}
}
}

View File

@@ -0,0 +1,179 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace Misaki.HighPerformance.HPC.Generator.VectorAPI
{
internal class Expression
{
public IVectorAPIContext API
{
get;
}
public string Code
{
get; private set;
}
public Expression(IVectorAPIContext api, string code)
{
API = api;
Code = code;
}
public Expression Assign(string? varName = null, bool isNew = true)
{
return API.Assign(this, varName, isNew);
}
public void Clear()
{
Code = string.Empty;
}
public override string ToString()
{
return Code;
}
public static Expression operator +(Expression a, Expression b)
{
return a.API.Add(a, b);
}
public static Expression operator -(Expression a, Expression b)
{
return a.API.Subtract(a, b);
}
public static Expression operator *(Expression a, Expression b)
{
return a.API.Multiply(a, b);
}
public static Expression operator /(Expression a, Expression b)
{
return a.API.Divide(a, b);
}
}
internal record Code
{
private readonly string[] _statements;
public Code(IEnumerable<string> statements)
{
_statements = statements.ToArray();
}
public string GetFullCode(string lineIndentation)
{
var sb = new StringBuilder();
foreach (var stmt in _statements)
{
sb.AppendLine(lineIndentation + stmt);
}
return sb.ToString();
}
}
internal record Method
{
public string Modifier
{
get;
}
public string ReturnType
{
get;
}
public string Name
{
get;
}
public string[] Parameters
{
get;
}
public Code Body
{
get;
}
public Method(string modifier, string returnType, string name, string[] parameters, Code body)
{
Modifier = modifier;
ReturnType = returnType;
Name = name;
Parameters = parameters;
Body = body;
}
public string GetFullCode(string lineIndentation)
{
var sb = new StringBuilder();
sb.AppendLine(lineIndentation + $"{Modifier} {ReturnType} {Name}({string.Join(", ", Parameters)})");
sb.AppendLine(lineIndentation + "{");
sb.Append(Body.GetFullCode(lineIndentation + " "));
sb.AppendLine(lineIndentation + "}");
return sb.ToString();
}
}
internal interface IVectorAPIContext
{
string? LastAssignedVariable
{
get;
}
string GetVectorType();
string GetVectorType<T>();
Expression Call(string methodName, params string[] args);
Expression Assign(Expression expr, string? varName = null, bool isNew = true);
Code Return(Expression expr);
Expression Create(string value);
Expression Zero<T>();
Expression One<T>();
Expression Count<T>();
Expression Add(Expression a, Expression b);
Expression Subtract(Expression a, Expression b);
Expression Multiply(Expression a, Expression b);
Expression Divide(Expression a, Expression b);
Expression MultiplyAdd(Expression left, Expression right, Expression addend);
Expression Round(Expression value);
Expression Abs(Expression value);
void Reset();
}
internal static class VectorAPIContext
{
public static string GetTypeName<T>()
{
return typeof(T) switch
{
_ when typeof(T) == typeof(float) => "float",
_ when typeof(T) == typeof(double) => "double",
_ when typeof(T) == typeof(byte) => "byte",
_ when typeof(T) == typeof(short) => "short",
_ when typeof(T) == typeof(int) => "int",
_ when typeof(T) == typeof(uint) => "uint",
_ when typeof(T) == typeof(long) => "long",
_ when typeof(T) == typeof(ulong) => "ulong",
_ => throw new NotSupportedException($"Type {typeof(T)} is not supported in vector operations.")
};
}
}
}