Refactor SPMD to HPC; add SIMD source generators

Major namespace migration from SPMD to HPC across all code, templates, and projects. Introduced Misaki.HighPerformance.HPC.Generator with Roslyn-based source generators for SIMD code (e.g., AVX2), including attribute and method generators. Renamed MultipleAdd to MultiplyAdd in all lanes and updated usages. Added AVX2 utility methods via codegen. Updated tests, benchmarks, and project references to use the new framework. Improved SIMD memory utilities and modernized project files. Removed legacy SPMD project from the solution.
This commit is contained in:
2026-05-06 13:43:58 +09:00
parent d3e497c7d8
commit c8f78f9d02
36 changed files with 895 additions and 130 deletions

View File

@@ -0,0 +1,164 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System;
namespace Misaki.HighPerformance.HPC.Generator
{
[Generator]
internal class AVX2UtilityGenerator : IIncrementalGenerator
{
public void Initialize(IncrementalGeneratorInitializationContext context)
{
context.RegisterPostInitializationOutput(static ctx =>
{
var source = @"
using System;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
namespace Misaki.HighPerformance.HPC
{
public static class AVX2Utility
{
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector256<float> Asin(Vector256<float> value)
{
// asin(value) = pi/2 - acos(value)
var piOver2 = Vector256.Create(MathF.PI / 2.0f);
return Avx2.Subtract(piOver2, Acos(value));
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector256<float> Acos(Vector256<float> value)
{
// 0 <= value <= 1 : acos(value) = sqrt(1 - value) * (c0 + c1*value + c2*value^2 + c3*value^3)
// value < 0 : acos(value) = pi - acos(-value)
var x = Vector256.Abs(value);
var c0 = Vector256.Create(1.5707288f); // pi/2
var c1 = Vector256.Create(-0.2121144f);
var c2 = Vector256.Create(0.0742610f);
var c3 = Vector256.Create(-0.0187293f);
var term1 = Fma.MultiplyAdd(x, c3, c2);
var term2 = Fma.MultiplyAdd(x, term1, c1);
var poly = Fma.MultiplyAdd(x, term2, c0);
var sqrtTerm = Avx2.Sqrt(Avx2.Subtract(Vector256<float>.One, x));
var result = Avx2.Multiply(poly, sqrtTerm);
var pi = Vector256.Create(MathF.PI);
var isNegative = Avx2.CompareLessThan(value, Vector256<float>.Zero);
return Avx2.BlendVariable(pi, Avx2.Subtract(pi, result), isNegative);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector256<float> Atan2(Vector256<float> y, Vector256<float> x)
{
var absX = Vector256.Abs(x);
var absY = Vector256.Abs(y);
// 1. Determine the ratio (input to Atan)
// If |value| > |y|, we are in the ""shallow"" region, ratio = y/value
// If |y| > |value|, we are in the ""steep"" region, ratio = value/y (and we transform result)
var yGtX = Avx2.CompareGreaterThan(absY, absX);
// Select numerator and denominator to ensure ratio is always in [-1, 1]
var num = Avx2.BlendVariable(absX, absY, yGtX);
var den = Avx2.BlendVariable(absY, absX, yGtX);
var t = Avx2.Multiply(num, Avx2.Reciprocal(den)); // t is now in [0, 1]
var t2 = Avx2.Multiply(t, t);
// 2. Polynomial Approximation (Odd function: value * (c1 + c2*value^2))
var c1 = Vector256.Create(0.97239411f);
var c2 = Vector256.Create(-0.19194795f);
// (c1 + c2 * t2)
var poly = Fma.MultiplyAdd(c2, t2, c1);
// result = Avx2.Multiply(t, poly)
var result = Avx2.Multiply(t, poly);
// 3. Reconstruct the angle
// If we swapped value/y (yGtX), the identity is: atan(value/y) = PI/2 - atan(y/value)
var halfPi = Vector256.Create(1.570796327f);
result = Avx2.BlendVariable(halfPi - result, result, yGtX);
// 4. Adjust for Quadrants (Signs)
// If value < 0, we are in quadrants 2 or 3, so we need to add PI
var pi = Vector256.Create(3.141592654f);
var xLtZero = Avx2.CompareLessThan(x, Vector256<float>.Zero);
result = Avx2.BlendVariable(pi - result, result, xLtZero);
// If y < 0, the result should be negative (standard atan2 convention)
// NOTE: This sign flip strategy depends on exact polynomial range mapping,
// but typically just copy the sign of Y to the result.
var yLtZero = Avx2.CompareLessThan(y, Vector256<float>.Zero);
// If original Y was negative, negate the result
// (This works because our ratio logic effectively computed atan(|y|/|value|) above)
var negativeResult = Avx2.Subtract(Vector256<float>.Zero, result);
return Avx2.BlendVariable(negativeResult, result, yLtZero);
}
}
}";
ctx.AddSource("AVX2Utility.g.cs", source);
});
}
}
internal class AVX2Rewriter : HPCRewriter
{
public override string Name => "AVX2";
public override string GetNesessaryUsing()
{
return "using System.Runtime.Intrinsics;\nusing System.Runtime.Intrinsics.X86;";
}
protected override MathExpression RewriteMathExpression(SIMDInstruction instruction, bool isFloatingPoint)
{
switch (instruction)
{
case SIMDInstruction.Add:
break;
case SIMDInstruction.Subtract:
break;
case SIMDInstruction.Multiply:
break;
case SIMDInstruction.MultiplyAdd:
return new MathExpression
{
Expression = "Fma",
Name = "MultiplyAdd"
};
case SIMDInstruction.Asin:
return new MathExpression
{
Expression = "AVX2Utility",
Name = "Asin"
};
case SIMDInstruction.Atan2:
return new MathExpression
{
Expression = "AVX2Utility",
Name = "Atan2"
};
default:
break;
}
return default;
}
protected override void RewriteMathArguments(SIMDInstruction instruction, Span<ArgumentSyntax> originalArgs)
{
return;
}
}
}

View File

@@ -0,0 +1,274 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System;
using System.Collections.Generic;
using System.Linq;
namespace Misaki.HighPerformance.HPC.Generator
{
internal enum SIMDInstruction
{
Add,
Subtract,
Multiply,
MultiplyAdd,
Asin,
Atan2,
}
internal abstract class HPCRewriter : CSharpSyntaxRewriter
{
protected struct MathExpression
{
public string Expression
{
get; set;
}
public string Name
{
get; set;
}
}
public static IReadOnlyCollection<HPCRewriter> GetRewriter(TargetInstructionSet instructionSet)
{
var rewriters = new List<HPCRewriter>();
// TODO: Add more rewriters for different instruction sets
if (instructionSet.HasFlag(TargetInstructionSet.AVX2))
{
rewriters.Add(new AVX2Rewriter());
}
return rewriters;
}
private static readonly Dictionary<string, string> s_remapProperties = new()
{
["LaneWidth"] = "Count",
};
private static readonly Dictionary<string, SIMDInstruction> s_remapMath = new()
{
["Add"] = SIMDInstruction.Add,
["Subtract"] = SIMDInstruction.Subtract,
["Multiply"] = SIMDInstruction.Multiply,
["MultiplyAdd"] = SIMDInstruction.MultiplyAdd,
["Asin"] = SIMDInstruction.Asin,
["Atan2"] = SIMDInstruction.Atan2,
};
protected readonly Dictionary<string, string> spmdTypes = new();
public abstract string Name
{
get;
}
public virtual string GetNesessaryUsing()
{
return string.Empty;
}
public override SyntaxNode? VisitAttributeList(AttributeListSyntax node)
{
var filteredAttributes = SyntaxFactory.SeparatedList(
node.Attributes.Where(a => !a.Name.ToString().Contains("HPCompute"))
);
if (filteredAttributes.Count == 0)
{
return null;
}
return node.WithAttributes(filteredAttributes).WithTriviaFrom(node);
}
public override SyntaxNode? VisitMethodDeclaration(MethodDeclarationSyntax node)
{
var typesToRemove = new HashSet<string>();
// 1. Analyze constraints to identify ISPMDLane generics
foreach (var clause in node.ConstraintClauses)
{
var typeNameStr = clause.Name.Identifier.Text;
foreach (var constraint in clause.Constraints.OfType<TypeConstraintSyntax>())
{
if (constraint.Type is GenericNameSyntax genericType &&
genericType.Identifier.Text == "ISPMDLane" &&
genericType.TypeArgumentList.Arguments.Count == 2)
{
var primType = genericType.TypeArgumentList.Arguments[1].ToString();
spmdTypes[typeNameStr] = primType;
typesToRemove.Add(typeNameStr);
}
}
}
var methodToVisit = node;
// 2. Strip type parameter and constraints BEFORE visiting so VisitIdentifierName doesn't touch them
if (typesToRemove.Count > 0)
{
// Remove from <TLane0, ...> generics list
if (methodToVisit.TypeParameterList != null)
{
var newParams = methodToVisit.TypeParameterList.Parameters
.Where(p => !typesToRemove.Contains(p.Identifier.Text))
.ToList();
if (newParams.Any())
{
methodToVisit = methodToVisit.WithTypeParameterList(
SyntaxFactory.TypeParameterList(SyntaxFactory.SeparatedList(newParams))
);
}
else
{
methodToVisit = methodToVisit.WithTypeParameterList(null); // Removes angle brackets entirely
}
}
// Remove the matching `where TLane0 : ...` clause
var newConstraints = methodToVisit.ConstraintClauses
.Where(c => !typesToRemove.Contains(c.Name.Identifier.Text))
.ToList();
methodToVisit = methodToVisit.WithConstraintClauses(
SyntaxFactory.List(newConstraints)
);
}
// 3. Fallback to base to rewrite method arguments, return types, and body via our updated visitors
return base.VisitMethodDeclaration(methodToVisit);
}
public override SyntaxNode? VisitGenericName(GenericNameSyntax node)
{
if (node.Identifier.Text == "WideLane" &&
node.TypeArgumentList.Arguments.Count == 1)
{
return SyntaxFactory.GenericName("Vector256")
.WithTypeArgumentList(node.TypeArgumentList)
.WithTriviaFrom(node);
}
return base.VisitGenericName(node);
}
public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node)
{
// Rewrites signature types and generic types from `TLane0` to `Vector256<float>`
if (spmdTypes.TryGetValue(node.Identifier.Text, out var primType))
{
return SyntaxFactory.GenericName("Vector256")
.WithTypeArgumentList(
SyntaxFactory.TypeArgumentList(
SyntaxFactory.SingletonSeparatedList<TypeSyntax>(
SyntaxFactory.IdentifierName(primType))))
.WithTriviaFrom(node);
}
return base.VisitIdentifierName(node);
}
public override SyntaxNode? VisitMemberAccessExpression(MemberAccessExpressionSyntax node)
{
var isSpmdOrWideLane = false;
var isFloatingPoint = false;
// 1. Check if the left-side expression is WideLane<...> or a tracked generic SPMD type
if (node.Expression is GenericNameSyntax genericName &&
genericName.Identifier.Text == "WideLane" &&
genericName.TypeArgumentList.Arguments.Count == 1)
{
isSpmdOrWideLane = true;
var argTypeStr = genericName.TypeArgumentList.Arguments[0].ToString();
isFloatingPoint = argTypeStr == "float" || argTypeStr == "double";
}
else if (node.Expression is IdentifierNameSyntax idName &&
spmdTypes.TryGetValue(idName.Identifier.Text, out var mappedPrimType))
{
isSpmdOrWideLane = true;
isFloatingPoint = mappedPrimType == "float" || mappedPrimType == "double";
}
if (isSpmdOrWideLane)
{
if (s_remapProperties.TryGetValue(node.Name.Identifier.Text, out var remappedName))
{
// Keep the evaluated left-hand side (TLane0 -> Vector256<float>) but change the property
var rewrittenExpression = (ExpressionSyntax)Visit(node.Expression);
return SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
rewrittenExpression,
SyntaxFactory.IdentifierName(remappedName)
).WithTriviaFrom(node);
}
if (s_remapMath.TryGetValue(node.Name.Identifier.Text, out var instruction))
{
var rewritResult = RewriteMathExpression(instruction, isFloatingPoint);
return SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
SyntaxFactory.IdentifierName(rewritResult.Expression),
SyntaxFactory.IdentifierName(rewritResult.Name)
).WithTriviaFrom(node);
}
}
return base.VisitMemberAccessExpression(node);
}
public override SyntaxNode? VisitInvocationExpression(InvocationExpressionSyntax node)
{
if (node.Expression is MemberAccessExpressionSyntax memberAccess)
{
bool isSpmdOrWideLane = false;
if (memberAccess.Expression is GenericNameSyntax genericName
&& genericName.Identifier.Text == "WideLane"
&& genericName.TypeArgumentList.Arguments.Count == 1)
{
isSpmdOrWideLane = true;
}
else if (memberAccess.Expression is IdentifierNameSyntax idName
&& spmdTypes.TryGetValue(idName.Identifier.Text, out var mappedPrimType))
{
isSpmdOrWideLane = true;
}
if (isSpmdOrWideLane)
{
var args = node.ArgumentList.Arguments;
var argList = new ArgumentSyntax[args.Count];
for (var i = 0; i < args.Count; i++)
{
argList[i] = (ArgumentSyntax)Visit(args[i]);
}
if (s_remapMath.TryGetValue(memberAccess.Name.Identifier.Text, out var instruction))
{
RewriteMathArguments(instruction, argList);
var arguments = SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(argList));
var newExpression = (ExpressionSyntax)Visit(memberAccess);
return SyntaxFactory.InvocationExpression(newExpression, arguments)
.WithTriviaFrom(node);
}
}
}
return base.VisitInvocationExpression(node);
}
protected abstract MathExpression RewriteMathExpression(SIMDInstruction instruction, bool isFloatingPoint);
protected abstract void RewriteMathArguments(SIMDInstruction instruction, Span<ArgumentSyntax> originalArgs);
}
}

View File

@@ -0,0 +1,108 @@
using Microsoft.CodeAnalysis;
using System;
namespace Misaki.HighPerformance.HPC.Generator
{
internal enum FloatPrecision
{
Standard = 0,
High = 1,
Low = 2,
}
internal enum MathMode
{
Standard = 0,
Fast = 1,
}
[Flags]
internal enum TargetInstructionSet
{
None = 0,
SSE2 = 1 << 0,
SSE4 = 1 << 1,
AVX = 1 << 2,
AVX2 = 1 << 3,
AVX512 = 1 << 4,
}
[Generator]
public class HPComputeAttributeGenerator : IIncrementalGenerator
{
public void Initialize(IncrementalGeneratorInitializationContext context)
{
context.RegisterPostInitializationOutput(static ctx =>
{
var source = @$"
using System;
namespace Misaki.HighPerformance.HPC
{{
public enum FloatPrecision
{{
/// <summary>
/// Compute with an accuracy of 3.5 ULPs (Units in the Last Place). This is the default precision level for floating-point operations.
/// </summary>
Standard = {(int)FloatPrecision.Standard},
/// <summary>
/// Compute with an accuracy of 1 ULP. This level may use more aggressive optimizations that can lead to faster computations but with reduced precision.
/// </summary>
High = {(int)FloatPrecision.High},
/// <summary>
/// Compute with an accuracy that equals or lower than 3.5 ULPs. This level may use the most aggressive optimizations, potentially sacrificing precision for maximum performance.
/// </summary>
Low = {(int)FloatPrecision.Low},
}}
public enum MathMode
{{
/// <summary>
/// Use the default math mode, which balances performance and accuracy. This mode may allow certain optimizations that can lead to faster computations while maintaining reasonable precision.
/// </summary>
Standard = {(int)MathMode.Standard},
/// <summary>
/// Use a fast math mode, which prioritizes performance over accuracy. This mode assumes there are no special cases (like NaNs or infinities) and may allow for more aggressive optimizations.
/// </summary>
Fast = {(int)MathMode.Fast},
}}
[Flags]
public enum TargetInstructionSet
{{
None = {(int)TargetInstructionSet.None},
/// <summary>
/// Streaming SIMD Extensions 2.
/// </summary>
SSE2 = {(int)TargetInstructionSet.SSE2},
/// <summary>
/// Streaming SIMD Extensions 4.2.
/// </summary>
SSE4 = {(int)TargetInstructionSet.SSE4},
/// <summary>
/// Advanced Vector Extensions.
/// </summary>
AVX = {(int)TargetInstructionSet.AVX},
/// <summary>
/// Advanced Vector Extensions 2. Includes FMA, F16C and BMI1/2.
/// </summary>
AVX2 = {(int)TargetInstructionSet.AVX2},
/// <summary>
/// Advanced Vector Extensions 512.
/// </summary>
AVX512 = {(int)TargetInstructionSet.AVX512},
}}
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Method, Inherited = false, AllowMultiple = false)]
public sealed class HPComputeAttribute : Attribute
{{
public HPComputeAttribute(TargetInstructionSet instructionSet, FloatPrecision precision = FloatPrecision.Standard, MathMode mode = MathMode.Standard)
{{
}}
}}
}}";
ctx.AddSource("HPComputeAttribute.g.cs", source);
});
}
}
}

View File

@@ -0,0 +1,109 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;
using System;
using System.Collections.Immutable;
using System.Linq;
using System.Text;
namespace Misaki.HighPerformance.HPC.Generator
{
internal class HPComputeMethodInfo
{
public MethodDeclarationSyntax MethodDeclaration
{
get; set;
} = null!;
public IMethodSymbol MethodSymbol
{
get; set;
} = null!;
public TargetInstructionSet InstructionSet
{
get; set;
}
public FloatPrecision Precision
{
get; set;
}
public MathMode Mode
{
get; set;
}
}
[Generator]
public class HPComputeGenerator : IIncrementalGenerator
{
public void Initialize(IncrementalGeneratorInitializationContext context)
{
var methodDeclarations = context.SyntaxProvider
.ForAttributeWithMetadataName(
"Misaki.HighPerformance.HPC.HPComputeAttribute",
static (n, ct) => n is MethodDeclarationSyntax,
static (ctx, ct) =>
{
var attributes = ctx.Attributes.FirstOrDefault(a => a.AttributeClass?.ToDisplayString() == "Misaki.HighPerformance.HPC.HPComputeAttribute");
if (attributes != null && ctx.TargetSymbol is IMethodSymbol methodSymbol)
{
return new HPComputeMethodInfo
{
MethodDeclaration = (MethodDeclarationSyntax)ctx.TargetNode,
MethodSymbol = methodSymbol,
InstructionSet = (TargetInstructionSet)attributes.ConstructorArguments[0].Value!,
Precision = (FloatPrecision)attributes.ConstructorArguments[1].Value!,
Mode = (MathMode)attributes.ConstructorArguments[2].Value!,
};
}
return null;
})
.Collect();
context.RegisterSourceOutput(methodDeclarations, GenerateHPCMethod);
}
private void GenerateHPCMethod(SourceProductionContext context, ImmutableArray<HPComputeMethodInfo?> array)
{
if (array.IsEmpty)
{
return;
}
foreach (var info in array)
{
if (info == null)
{
continue;
}
var rewriters = HPCRewriter.GetRewriter(info.InstructionSet);
foreach (var writer in rewriters)
{
var rewrittenMethod = (MethodDeclarationSyntax)writer.Visit(info.MethodDeclaration);
var newMethod = rewrittenMethod
.WithIdentifier(SyntaxFactory.Identifier($"{info.MethodDeclaration.Identifier.Text}_{writer.Name}"));
var source = $@"
using Misaki.HighPerformance.HPC;
{writer.GetNesessaryUsing()}
namespace {info.MethodSymbol.ContainingNamespace.ToDisplayString()}
{{
partial class {info.MethodSymbol.ContainingType.Name}
{{
{newMethod.NormalizeWhitespace().ToFullString()}
}}
}}";
context.AddSource($"{info.MethodSymbol.ContainingType.Name}_{info.MethodDeclaration.Identifier.Text}_{writer.Name}.g.cs", SourceText.From(source, Encoding.UTF8));
}
}
}
}
}

View File

@@ -0,0 +1,19 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<Nullable>enable</Nullable>
<EnforceExtendedAnalyzerRules>True</EnforceExtendedAnalyzerRules>
<AllowUnsafeBlocks>True</AllowUnsafeBlocks>
<LangVersion>9.0</LangVersion>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="5.3.0">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="5.3.0" />
</ItemGroup>
</Project>