Refactor SPMD to HPC; add SIMD source generators
Major namespace migration from SPMD to HPC across all code, templates, and projects. Introduced Misaki.HighPerformance.HPC.Generator with Roslyn-based source generators for SIMD code (e.g., AVX2), including attribute and method generators. Renamed MultipleAdd to MultiplyAdd in all lanes and updated usages. Added AVX2 utility methods via codegen. Updated tests, benchmarks, and project references to use the new framework. Improved SIMD memory utilities and modernized project files. Removed legacy SPMD project from the solution.
This commit is contained in:
164
Misaki.HighPerformance.HPC.Generator/AVX2Rewriter.cs
Normal file
164
Misaki.HighPerformance.HPC.Generator/AVX2Rewriter.cs
Normal file
@@ -0,0 +1,164 @@
|
||||
using Microsoft.CodeAnalysis;
|
||||
using Microsoft.CodeAnalysis.CSharp.Syntax;
|
||||
using System;
|
||||
|
||||
namespace Misaki.HighPerformance.HPC.Generator
|
||||
{
|
||||
[Generator]
|
||||
internal class AVX2UtilityGenerator : IIncrementalGenerator
|
||||
{
|
||||
public void Initialize(IncrementalGeneratorInitializationContext context)
|
||||
{
|
||||
context.RegisterPostInitializationOutput(static ctx =>
|
||||
{
|
||||
var source = @"
|
||||
using System;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Runtime.Intrinsics;
|
||||
using System.Runtime.Intrinsics.X86;
|
||||
|
||||
namespace Misaki.HighPerformance.HPC
|
||||
{
|
||||
public static class AVX2Utility
|
||||
{
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public static Vector256<float> Asin(Vector256<float> value)
|
||||
{
|
||||
// asin(value) = pi/2 - acos(value)
|
||||
|
||||
var piOver2 = Vector256.Create(MathF.PI / 2.0f);
|
||||
return Avx2.Subtract(piOver2, Acos(value));
|
||||
}
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public static Vector256<float> Acos(Vector256<float> value)
|
||||
{
|
||||
// 0 <= value <= 1 : acos(value) = sqrt(1 - value) * (c0 + c1*value + c2*value^2 + c3*value^3)
|
||||
// value < 0 : acos(value) = pi - acos(-value)
|
||||
|
||||
var x = Vector256.Abs(value);
|
||||
|
||||
var c0 = Vector256.Create(1.5707288f); // pi/2
|
||||
var c1 = Vector256.Create(-0.2121144f);
|
||||
var c2 = Vector256.Create(0.0742610f);
|
||||
var c3 = Vector256.Create(-0.0187293f);
|
||||
|
||||
var term1 = Fma.MultiplyAdd(x, c3, c2);
|
||||
var term2 = Fma.MultiplyAdd(x, term1, c1);
|
||||
var poly = Fma.MultiplyAdd(x, term2, c0);
|
||||
|
||||
var sqrtTerm = Avx2.Sqrt(Avx2.Subtract(Vector256<float>.One, x));
|
||||
var result = Avx2.Multiply(poly, sqrtTerm);
|
||||
|
||||
var pi = Vector256.Create(MathF.PI);
|
||||
var isNegative = Avx2.CompareLessThan(value, Vector256<float>.Zero);
|
||||
|
||||
return Avx2.BlendVariable(pi, Avx2.Subtract(pi, result), isNegative);
|
||||
}
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public static Vector256<float> Atan2(Vector256<float> y, Vector256<float> x)
|
||||
{
|
||||
var absX = Vector256.Abs(x);
|
||||
var absY = Vector256.Abs(y);
|
||||
|
||||
// 1. Determine the ratio (input to Atan)
|
||||
// If |value| > |y|, we are in the ""shallow"" region, ratio = y/value
|
||||
// If |y| > |value|, we are in the ""steep"" region, ratio = value/y (and we transform result)
|
||||
var yGtX = Avx2.CompareGreaterThan(absY, absX);
|
||||
|
||||
// Select numerator and denominator to ensure ratio is always in [-1, 1]
|
||||
var num = Avx2.BlendVariable(absX, absY, yGtX);
|
||||
var den = Avx2.BlendVariable(absY, absX, yGtX);
|
||||
|
||||
var t = Avx2.Multiply(num, Avx2.Reciprocal(den)); // t is now in [0, 1]
|
||||
var t2 = Avx2.Multiply(t, t);
|
||||
|
||||
// 2. Polynomial Approximation (Odd function: value * (c1 + c2*value^2))
|
||||
var c1 = Vector256.Create(0.97239411f);
|
||||
var c2 = Vector256.Create(-0.19194795f);
|
||||
|
||||
// (c1 + c2 * t2)
|
||||
var poly = Fma.MultiplyAdd(c2, t2, c1);
|
||||
|
||||
// result = Avx2.Multiply(t, poly)
|
||||
var result = Avx2.Multiply(t, poly);
|
||||
|
||||
// 3. Reconstruct the angle
|
||||
// If we swapped value/y (yGtX), the identity is: atan(value/y) = PI/2 - atan(y/value)
|
||||
var halfPi = Vector256.Create(1.570796327f);
|
||||
result = Avx2.BlendVariable(halfPi - result, result, yGtX);
|
||||
|
||||
// 4. Adjust for Quadrants (Signs)
|
||||
// If value < 0, we are in quadrants 2 or 3, so we need to add PI
|
||||
var pi = Vector256.Create(3.141592654f);
|
||||
var xLtZero = Avx2.CompareLessThan(x, Vector256<float>.Zero);
|
||||
result = Avx2.BlendVariable(pi - result, result, xLtZero);
|
||||
|
||||
// If y < 0, the result should be negative (standard atan2 convention)
|
||||
// NOTE: This sign flip strategy depends on exact polynomial range mapping,
|
||||
// but typically just copy the sign of Y to the result.
|
||||
var yLtZero = Avx2.CompareLessThan(y, Vector256<float>.Zero);
|
||||
// If original Y was negative, negate the result
|
||||
// (This works because our ratio logic effectively computed atan(|y|/|value|) above)
|
||||
var negativeResult = Avx2.Subtract(Vector256<float>.Zero, result);
|
||||
return Avx2.BlendVariable(negativeResult, result, yLtZero);
|
||||
}
|
||||
}
|
||||
}";
|
||||
|
||||
ctx.AddSource("AVX2Utility.g.cs", source);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
internal class AVX2Rewriter : HPCRewriter
|
||||
{
|
||||
public override string Name => "AVX2";
|
||||
|
||||
public override string GetNesessaryUsing()
|
||||
{
|
||||
return "using System.Runtime.Intrinsics;\nusing System.Runtime.Intrinsics.X86;";
|
||||
}
|
||||
|
||||
protected override MathExpression RewriteMathExpression(SIMDInstruction instruction, bool isFloatingPoint)
|
||||
{
|
||||
switch (instruction)
|
||||
{
|
||||
case SIMDInstruction.Add:
|
||||
break;
|
||||
case SIMDInstruction.Subtract:
|
||||
break;
|
||||
case SIMDInstruction.Multiply:
|
||||
break;
|
||||
case SIMDInstruction.MultiplyAdd:
|
||||
return new MathExpression
|
||||
{
|
||||
Expression = "Fma",
|
||||
Name = "MultiplyAdd"
|
||||
};
|
||||
case SIMDInstruction.Asin:
|
||||
return new MathExpression
|
||||
{
|
||||
Expression = "AVX2Utility",
|
||||
Name = "Asin"
|
||||
};
|
||||
case SIMDInstruction.Atan2:
|
||||
return new MathExpression
|
||||
{
|
||||
Expression = "AVX2Utility",
|
||||
Name = "Atan2"
|
||||
};
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return default;
|
||||
}
|
||||
|
||||
protected override void RewriteMathArguments(SIMDInstruction instruction, Span<ArgumentSyntax> originalArgs)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
274
Misaki.HighPerformance.HPC.Generator/HPCRewriter.cs
Normal file
274
Misaki.HighPerformance.HPC.Generator/HPCRewriter.cs
Normal file
@@ -0,0 +1,274 @@
|
||||
using Microsoft.CodeAnalysis;
|
||||
using Microsoft.CodeAnalysis.CSharp;
|
||||
using Microsoft.CodeAnalysis.CSharp.Syntax;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
|
||||
namespace Misaki.HighPerformance.HPC.Generator
|
||||
{
|
||||
internal enum SIMDInstruction
|
||||
{
|
||||
Add,
|
||||
Subtract,
|
||||
Multiply,
|
||||
MultiplyAdd,
|
||||
|
||||
Asin,
|
||||
Atan2,
|
||||
}
|
||||
|
||||
internal abstract class HPCRewriter : CSharpSyntaxRewriter
|
||||
{
|
||||
protected struct MathExpression
|
||||
{
|
||||
public string Expression
|
||||
{
|
||||
get; set;
|
||||
}
|
||||
|
||||
public string Name
|
||||
{
|
||||
get; set;
|
||||
}
|
||||
}
|
||||
|
||||
public static IReadOnlyCollection<HPCRewriter> GetRewriter(TargetInstructionSet instructionSet)
|
||||
{
|
||||
var rewriters = new List<HPCRewriter>();
|
||||
|
||||
// TODO: Add more rewriters for different instruction sets
|
||||
if (instructionSet.HasFlag(TargetInstructionSet.AVX2))
|
||||
{
|
||||
rewriters.Add(new AVX2Rewriter());
|
||||
}
|
||||
|
||||
return rewriters;
|
||||
}
|
||||
|
||||
private static readonly Dictionary<string, string> s_remapProperties = new()
|
||||
{
|
||||
["LaneWidth"] = "Count",
|
||||
};
|
||||
|
||||
private static readonly Dictionary<string, SIMDInstruction> s_remapMath = new()
|
||||
{
|
||||
["Add"] = SIMDInstruction.Add,
|
||||
["Subtract"] = SIMDInstruction.Subtract,
|
||||
["Multiply"] = SIMDInstruction.Multiply,
|
||||
["MultiplyAdd"] = SIMDInstruction.MultiplyAdd,
|
||||
["Asin"] = SIMDInstruction.Asin,
|
||||
["Atan2"] = SIMDInstruction.Atan2,
|
||||
};
|
||||
|
||||
protected readonly Dictionary<string, string> spmdTypes = new();
|
||||
|
||||
public abstract string Name
|
||||
{
|
||||
get;
|
||||
}
|
||||
|
||||
public virtual string GetNesessaryUsing()
|
||||
{
|
||||
return string.Empty;
|
||||
}
|
||||
|
||||
public override SyntaxNode? VisitAttributeList(AttributeListSyntax node)
|
||||
{
|
||||
var filteredAttributes = SyntaxFactory.SeparatedList(
|
||||
node.Attributes.Where(a => !a.Name.ToString().Contains("HPCompute"))
|
||||
);
|
||||
|
||||
if (filteredAttributes.Count == 0)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
return node.WithAttributes(filteredAttributes).WithTriviaFrom(node);
|
||||
}
|
||||
|
||||
public override SyntaxNode? VisitMethodDeclaration(MethodDeclarationSyntax node)
|
||||
{
|
||||
var typesToRemove = new HashSet<string>();
|
||||
|
||||
// 1. Analyze constraints to identify ISPMDLane generics
|
||||
foreach (var clause in node.ConstraintClauses)
|
||||
{
|
||||
var typeNameStr = clause.Name.Identifier.Text;
|
||||
foreach (var constraint in clause.Constraints.OfType<TypeConstraintSyntax>())
|
||||
{
|
||||
if (constraint.Type is GenericNameSyntax genericType &&
|
||||
genericType.Identifier.Text == "ISPMDLane" &&
|
||||
genericType.TypeArgumentList.Arguments.Count == 2)
|
||||
{
|
||||
var primType = genericType.TypeArgumentList.Arguments[1].ToString();
|
||||
spmdTypes[typeNameStr] = primType;
|
||||
typesToRemove.Add(typeNameStr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var methodToVisit = node;
|
||||
|
||||
// 2. Strip type parameter and constraints BEFORE visiting so VisitIdentifierName doesn't touch them
|
||||
if (typesToRemove.Count > 0)
|
||||
{
|
||||
// Remove from <TLane0, ...> generics list
|
||||
if (methodToVisit.TypeParameterList != null)
|
||||
{
|
||||
var newParams = methodToVisit.TypeParameterList.Parameters
|
||||
.Where(p => !typesToRemove.Contains(p.Identifier.Text))
|
||||
.ToList();
|
||||
|
||||
if (newParams.Any())
|
||||
{
|
||||
methodToVisit = methodToVisit.WithTypeParameterList(
|
||||
SyntaxFactory.TypeParameterList(SyntaxFactory.SeparatedList(newParams))
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
methodToVisit = methodToVisit.WithTypeParameterList(null); // Removes angle brackets entirely
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the matching `where TLane0 : ...` clause
|
||||
var newConstraints = methodToVisit.ConstraintClauses
|
||||
.Where(c => !typesToRemove.Contains(c.Name.Identifier.Text))
|
||||
.ToList();
|
||||
|
||||
methodToVisit = methodToVisit.WithConstraintClauses(
|
||||
SyntaxFactory.List(newConstraints)
|
||||
);
|
||||
}
|
||||
|
||||
// 3. Fallback to base to rewrite method arguments, return types, and body via our updated visitors
|
||||
return base.VisitMethodDeclaration(methodToVisit);
|
||||
}
|
||||
|
||||
public override SyntaxNode? VisitGenericName(GenericNameSyntax node)
|
||||
{
|
||||
if (node.Identifier.Text == "WideLane" &&
|
||||
node.TypeArgumentList.Arguments.Count == 1)
|
||||
{
|
||||
return SyntaxFactory.GenericName("Vector256")
|
||||
.WithTypeArgumentList(node.TypeArgumentList)
|
||||
.WithTriviaFrom(node);
|
||||
}
|
||||
|
||||
return base.VisitGenericName(node);
|
||||
}
|
||||
|
||||
public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node)
|
||||
{
|
||||
// Rewrites signature types and generic types from `TLane0` to `Vector256<float>`
|
||||
if (spmdTypes.TryGetValue(node.Identifier.Text, out var primType))
|
||||
{
|
||||
return SyntaxFactory.GenericName("Vector256")
|
||||
.WithTypeArgumentList(
|
||||
SyntaxFactory.TypeArgumentList(
|
||||
SyntaxFactory.SingletonSeparatedList<TypeSyntax>(
|
||||
SyntaxFactory.IdentifierName(primType))))
|
||||
.WithTriviaFrom(node);
|
||||
}
|
||||
|
||||
return base.VisitIdentifierName(node);
|
||||
}
|
||||
|
||||
public override SyntaxNode? VisitMemberAccessExpression(MemberAccessExpressionSyntax node)
|
||||
{
|
||||
var isSpmdOrWideLane = false;
|
||||
var isFloatingPoint = false;
|
||||
|
||||
// 1. Check if the left-side expression is WideLane<...> or a tracked generic SPMD type
|
||||
if (node.Expression is GenericNameSyntax genericName &&
|
||||
genericName.Identifier.Text == "WideLane" &&
|
||||
genericName.TypeArgumentList.Arguments.Count == 1)
|
||||
{
|
||||
isSpmdOrWideLane = true;
|
||||
|
||||
var argTypeStr = genericName.TypeArgumentList.Arguments[0].ToString();
|
||||
isFloatingPoint = argTypeStr == "float" || argTypeStr == "double";
|
||||
}
|
||||
else if (node.Expression is IdentifierNameSyntax idName &&
|
||||
spmdTypes.TryGetValue(idName.Identifier.Text, out var mappedPrimType))
|
||||
{
|
||||
isSpmdOrWideLane = true;
|
||||
isFloatingPoint = mappedPrimType == "float" || mappedPrimType == "double";
|
||||
}
|
||||
|
||||
if (isSpmdOrWideLane)
|
||||
{
|
||||
if (s_remapProperties.TryGetValue(node.Name.Identifier.Text, out var remappedName))
|
||||
{
|
||||
// Keep the evaluated left-hand side (TLane0 -> Vector256<float>) but change the property
|
||||
var rewrittenExpression = (ExpressionSyntax)Visit(node.Expression);
|
||||
|
||||
return SyntaxFactory.MemberAccessExpression(
|
||||
SyntaxKind.SimpleMemberAccessExpression,
|
||||
rewrittenExpression,
|
||||
SyntaxFactory.IdentifierName(remappedName)
|
||||
).WithTriviaFrom(node);
|
||||
}
|
||||
|
||||
if (s_remapMath.TryGetValue(node.Name.Identifier.Text, out var instruction))
|
||||
{
|
||||
var rewritResult = RewriteMathExpression(instruction, isFloatingPoint);
|
||||
return SyntaxFactory.MemberAccessExpression(
|
||||
SyntaxKind.SimpleMemberAccessExpression,
|
||||
SyntaxFactory.IdentifierName(rewritResult.Expression),
|
||||
SyntaxFactory.IdentifierName(rewritResult.Name)
|
||||
).WithTriviaFrom(node);
|
||||
}
|
||||
}
|
||||
|
||||
return base.VisitMemberAccessExpression(node);
|
||||
}
|
||||
|
||||
public override SyntaxNode? VisitInvocationExpression(InvocationExpressionSyntax node)
|
||||
{
|
||||
if (node.Expression is MemberAccessExpressionSyntax memberAccess)
|
||||
{
|
||||
bool isSpmdOrWideLane = false;
|
||||
|
||||
if (memberAccess.Expression is GenericNameSyntax genericName
|
||||
&& genericName.Identifier.Text == "WideLane"
|
||||
&& genericName.TypeArgumentList.Arguments.Count == 1)
|
||||
{
|
||||
isSpmdOrWideLane = true;
|
||||
}
|
||||
else if (memberAccess.Expression is IdentifierNameSyntax idName
|
||||
&& spmdTypes.TryGetValue(idName.Identifier.Text, out var mappedPrimType))
|
||||
{
|
||||
isSpmdOrWideLane = true;
|
||||
}
|
||||
|
||||
if (isSpmdOrWideLane)
|
||||
{
|
||||
var args = node.ArgumentList.Arguments;
|
||||
var argList = new ArgumentSyntax[args.Count];
|
||||
|
||||
for (var i = 0; i < args.Count; i++)
|
||||
{
|
||||
argList[i] = (ArgumentSyntax)Visit(args[i]);
|
||||
}
|
||||
|
||||
if (s_remapMath.TryGetValue(memberAccess.Name.Identifier.Text, out var instruction))
|
||||
{
|
||||
RewriteMathArguments(instruction, argList);
|
||||
var arguments = SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(argList));
|
||||
|
||||
var newExpression = (ExpressionSyntax)Visit(memberAccess);
|
||||
return SyntaxFactory.InvocationExpression(newExpression, arguments)
|
||||
.WithTriviaFrom(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return base.VisitInvocationExpression(node);
|
||||
}
|
||||
|
||||
protected abstract MathExpression RewriteMathExpression(SIMDInstruction instruction, bool isFloatingPoint);
|
||||
protected abstract void RewriteMathArguments(SIMDInstruction instruction, Span<ArgumentSyntax> originalArgs);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
using Microsoft.CodeAnalysis;
|
||||
using System;
|
||||
|
||||
namespace Misaki.HighPerformance.HPC.Generator
|
||||
{
|
||||
internal enum FloatPrecision
|
||||
{
|
||||
Standard = 0,
|
||||
High = 1,
|
||||
Low = 2,
|
||||
}
|
||||
|
||||
internal enum MathMode
|
||||
{
|
||||
Standard = 0,
|
||||
Fast = 1,
|
||||
}
|
||||
|
||||
[Flags]
|
||||
internal enum TargetInstructionSet
|
||||
{
|
||||
None = 0,
|
||||
SSE2 = 1 << 0,
|
||||
SSE4 = 1 << 1,
|
||||
AVX = 1 << 2,
|
||||
AVX2 = 1 << 3,
|
||||
AVX512 = 1 << 4,
|
||||
}
|
||||
|
||||
[Generator]
|
||||
public class HPComputeAttributeGenerator : IIncrementalGenerator
|
||||
{
|
||||
public void Initialize(IncrementalGeneratorInitializationContext context)
|
||||
{
|
||||
context.RegisterPostInitializationOutput(static ctx =>
|
||||
{
|
||||
var source = @$"
|
||||
using System;
|
||||
|
||||
namespace Misaki.HighPerformance.HPC
|
||||
{{
|
||||
public enum FloatPrecision
|
||||
{{
|
||||
/// <summary>
|
||||
/// Compute with an accuracy of 3.5 ULPs (Units in the Last Place). This is the default precision level for floating-point operations.
|
||||
/// </summary>
|
||||
Standard = {(int)FloatPrecision.Standard},
|
||||
/// <summary>
|
||||
/// Compute with an accuracy of 1 ULP. This level may use more aggressive optimizations that can lead to faster computations but with reduced precision.
|
||||
/// </summary>
|
||||
High = {(int)FloatPrecision.High},
|
||||
/// <summary>
|
||||
/// Compute with an accuracy that equals or lower than 3.5 ULPs. This level may use the most aggressive optimizations, potentially sacrificing precision for maximum performance.
|
||||
/// </summary>
|
||||
Low = {(int)FloatPrecision.Low},
|
||||
}}
|
||||
|
||||
public enum MathMode
|
||||
{{
|
||||
/// <summary>
|
||||
/// Use the default math mode, which balances performance and accuracy. This mode may allow certain optimizations that can lead to faster computations while maintaining reasonable precision.
|
||||
/// </summary>
|
||||
Standard = {(int)MathMode.Standard},
|
||||
/// <summary>
|
||||
/// Use a fast math mode, which prioritizes performance over accuracy. This mode assumes there are no special cases (like NaNs or infinities) and may allow for more aggressive optimizations.
|
||||
/// </summary>
|
||||
Fast = {(int)MathMode.Fast},
|
||||
}}
|
||||
|
||||
[Flags]
|
||||
public enum TargetInstructionSet
|
||||
{{
|
||||
None = {(int)TargetInstructionSet.None},
|
||||
/// <summary>
|
||||
/// Streaming SIMD Extensions 2.
|
||||
/// </summary>
|
||||
SSE2 = {(int)TargetInstructionSet.SSE2},
|
||||
/// <summary>
|
||||
/// Streaming SIMD Extensions 4.2.
|
||||
/// </summary>
|
||||
SSE4 = {(int)TargetInstructionSet.SSE4},
|
||||
/// <summary>
|
||||
/// Advanced Vector Extensions.
|
||||
/// </summary>
|
||||
AVX = {(int)TargetInstructionSet.AVX},
|
||||
/// <summary>
|
||||
/// Advanced Vector Extensions 2. Includes FMA, F16C and BMI1/2.
|
||||
/// </summary>
|
||||
AVX2 = {(int)TargetInstructionSet.AVX2},
|
||||
/// <summary>
|
||||
/// Advanced Vector Extensions 512.
|
||||
/// </summary>
|
||||
AVX512 = {(int)TargetInstructionSet.AVX512},
|
||||
}}
|
||||
|
||||
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Method, Inherited = false, AllowMultiple = false)]
|
||||
public sealed class HPComputeAttribute : Attribute
|
||||
{{
|
||||
public HPComputeAttribute(TargetInstructionSet instructionSet, FloatPrecision precision = FloatPrecision.Standard, MathMode mode = MathMode.Standard)
|
||||
{{
|
||||
}}
|
||||
}}
|
||||
}}";
|
||||
ctx.AddSource("HPComputeAttribute.g.cs", source);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
109
Misaki.HighPerformance.HPC.Generator/HPComputeGenerator.cs
Normal file
109
Misaki.HighPerformance.HPC.Generator/HPComputeGenerator.cs
Normal file
@@ -0,0 +1,109 @@
|
||||
using Microsoft.CodeAnalysis;
|
||||
using Microsoft.CodeAnalysis.CSharp;
|
||||
using Microsoft.CodeAnalysis.CSharp.Syntax;
|
||||
using Microsoft.CodeAnalysis.Text;
|
||||
using System;
|
||||
using System.Collections.Immutable;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
|
||||
namespace Misaki.HighPerformance.HPC.Generator
|
||||
{
|
||||
internal class HPComputeMethodInfo
|
||||
{
|
||||
public MethodDeclarationSyntax MethodDeclaration
|
||||
{
|
||||
get; set;
|
||||
} = null!;
|
||||
|
||||
public IMethodSymbol MethodSymbol
|
||||
{
|
||||
get; set;
|
||||
} = null!;
|
||||
|
||||
public TargetInstructionSet InstructionSet
|
||||
{
|
||||
get; set;
|
||||
}
|
||||
|
||||
public FloatPrecision Precision
|
||||
{
|
||||
get; set;
|
||||
}
|
||||
|
||||
public MathMode Mode
|
||||
{
|
||||
get; set;
|
||||
}
|
||||
}
|
||||
|
||||
[Generator]
|
||||
public class HPComputeGenerator : IIncrementalGenerator
|
||||
{
|
||||
public void Initialize(IncrementalGeneratorInitializationContext context)
|
||||
{
|
||||
var methodDeclarations = context.SyntaxProvider
|
||||
.ForAttributeWithMetadataName(
|
||||
"Misaki.HighPerformance.HPC.HPComputeAttribute",
|
||||
static (n, ct) => n is MethodDeclarationSyntax,
|
||||
static (ctx, ct) =>
|
||||
{
|
||||
var attributes = ctx.Attributes.FirstOrDefault(a => a.AttributeClass?.ToDisplayString() == "Misaki.HighPerformance.HPC.HPComputeAttribute");
|
||||
if (attributes != null && ctx.TargetSymbol is IMethodSymbol methodSymbol)
|
||||
{
|
||||
return new HPComputeMethodInfo
|
||||
{
|
||||
MethodDeclaration = (MethodDeclarationSyntax)ctx.TargetNode,
|
||||
MethodSymbol = methodSymbol,
|
||||
InstructionSet = (TargetInstructionSet)attributes.ConstructorArguments[0].Value!,
|
||||
Precision = (FloatPrecision)attributes.ConstructorArguments[1].Value!,
|
||||
Mode = (MathMode)attributes.ConstructorArguments[2].Value!,
|
||||
};
|
||||
}
|
||||
|
||||
return null;
|
||||
})
|
||||
.Collect();
|
||||
|
||||
context.RegisterSourceOutput(methodDeclarations, GenerateHPCMethod);
|
||||
}
|
||||
|
||||
private void GenerateHPCMethod(SourceProductionContext context, ImmutableArray<HPComputeMethodInfo?> array)
|
||||
{
|
||||
if (array.IsEmpty)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
foreach (var info in array)
|
||||
{
|
||||
if (info == null)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
var rewriters = HPCRewriter.GetRewriter(info.InstructionSet);
|
||||
|
||||
foreach (var writer in rewriters)
|
||||
{
|
||||
var rewrittenMethod = (MethodDeclarationSyntax)writer.Visit(info.MethodDeclaration);
|
||||
var newMethod = rewrittenMethod
|
||||
.WithIdentifier(SyntaxFactory.Identifier($"{info.MethodDeclaration.Identifier.Text}_{writer.Name}"));
|
||||
|
||||
var source = $@"
|
||||
using Misaki.HighPerformance.HPC;
|
||||
{writer.GetNesessaryUsing()}
|
||||
|
||||
namespace {info.MethodSymbol.ContainingNamespace.ToDisplayString()}
|
||||
{{
|
||||
partial class {info.MethodSymbol.ContainingType.Name}
|
||||
{{
|
||||
{newMethod.NormalizeWhitespace().ToFullString()}
|
||||
}}
|
||||
}}";
|
||||
context.AddSource($"{info.MethodSymbol.ContainingType.Name}_{info.MethodDeclaration.Identifier.Text}_{writer.Name}.g.cs", SourceText.From(source, Encoding.UTF8));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<TargetFramework>netstandard2.0</TargetFramework>
|
||||
<Nullable>enable</Nullable>
|
||||
<EnforceExtendedAnalyzerRules>True</EnforceExtendedAnalyzerRules>
|
||||
<AllowUnsafeBlocks>True</AllowUnsafeBlocks>
|
||||
<LangVersion>9.0</LangVersion>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="5.3.0">
|
||||
<PrivateAssets>all</PrivateAssets>
|
||||
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
|
||||
</PackageReference>
|
||||
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="5.3.0" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
Reference in New Issue
Block a user