Refactor APIContext, add AVX2 sin/cos codegen, FMA rewrite
- Refactored namespaces from .VectorAPI to .APIContext for clarity.
- Enhanced Avx2APIContext/IVectorAPIContext to support void returns.
- Added GenerateSinCosUtilityMethods for AVX2, emitting vectorized Sin/Cos/SinCos for float/double.
- Introduced HPCOptimizerRewriter for advanced SPMD type handling.
- Refactored HPCRewriter to use SemanticModel, support FMA pattern rewriting, and delegate SPMD logic.
- Updated AVX2Rewriter for new base and improved math mapping.
- Made UtilityTemplate generic and type-safe for sin/cos.
- Updated NoiseJob3D/NoiseJobVector for [HPCompute] attribute and partial struct.
- Fixed solution file project ordering and inclusion.
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace Misaki.HighPerformance.HPC.Generator.VectorAPI
|
||||
namespace Misaki.HighPerformance.HPC.Generator.APIContext
|
||||
{
|
||||
internal class Avx2APIContext : IVectorAPIContext
|
||||
{
|
||||
@@ -23,7 +22,7 @@ namespace Misaki.HighPerformance.HPC.Generator.VectorAPI
|
||||
|
||||
public Expression Call(string methodName, params string[] args)
|
||||
{
|
||||
return new Expression(this, $"{GetVectorType()}.{methodName}({string.Join(", ", args)})");
|
||||
return new Expression(this, $"{methodName}({string.Join(", ", args)})");
|
||||
}
|
||||
|
||||
public Expression Assign(Expression expr, string? varName = null, bool isNew = true)
|
||||
@@ -38,11 +37,14 @@ namespace Misaki.HighPerformance.HPC.Generator.VectorAPI
|
||||
return new Expression(this, varName);
|
||||
}
|
||||
|
||||
public Code Return(Expression expr)
|
||||
public Code Return(Expression? expr)
|
||||
{
|
||||
var statement = $"return {expr.Code};";
|
||||
_statements.Add(statement);
|
||||
expr.Clear();
|
||||
if (expr != null)
|
||||
{
|
||||
var statement = $"return {expr.Code};";
|
||||
_statements.Add(statement);
|
||||
expr.Clear();
|
||||
}
|
||||
|
||||
var fullCode = new Code(_statements);
|
||||
Reset();
|
||||
@@ -3,7 +3,7 @@ using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
|
||||
namespace Misaki.HighPerformance.HPC.Generator.VectorAPI
|
||||
namespace Misaki.HighPerformance.HPC.Generator.APIContext
|
||||
{
|
||||
internal class Expression
|
||||
{
|
||||
@@ -139,7 +139,7 @@ namespace Misaki.HighPerformance.HPC.Generator.VectorAPI
|
||||
|
||||
Expression Call(string methodName, params string[] args);
|
||||
Expression Assign(Expression expr, string? varName = null, bool isNew = true);
|
||||
Code Return(Expression expr);
|
||||
Code Return(Expression? expr);
|
||||
|
||||
Expression Create(string value);
|
||||
Expression Zero<T>();
|
||||
@@ -1,6 +1,6 @@
|
||||
using Microsoft.CodeAnalysis;
|
||||
using Microsoft.CodeAnalysis.CSharp.Syntax;
|
||||
using Misaki.HighPerformance.HPC.Generator.VectorAPI;
|
||||
using Misaki.HighPerformance.HPC.Generator.APIContext;
|
||||
using System;
|
||||
|
||||
namespace Misaki.HighPerformance.HPC.Generator
|
||||
@@ -14,8 +14,7 @@ namespace Misaki.HighPerformance.HPC.Generator
|
||||
{
|
||||
var api = new Avx2APIContext();
|
||||
|
||||
var sinFloat_standard = UtilityTemplate.SinFloat_Standard(api);
|
||||
var sinFloat_fast = UtilityTemplate.SinFloat_Fast(api);
|
||||
var sinCosMethods = UtilityTemplate.GenerateSinCosUtilityMethods(api, " ");
|
||||
|
||||
var source = @$"
|
||||
using System;
|
||||
@@ -27,95 +26,7 @@ namespace Misaki.HighPerformance.HPC
|
||||
{{
|
||||
public static class AVX2Utility
|
||||
{{
|
||||
[MethodImpl(MethodImplOptions.NoInlining)]
|
||||
{sinFloat_standard.GetFullCode(" ")}
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
{sinFloat_fast.GetFullCode(" ")}
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public static Vector256<float> Asin(Vector256<float> value)
|
||||
{{
|
||||
// asin(value) = pi/2 - acos(value)
|
||||
|
||||
var piOver2 = Vector256.Create(MathF.PI / 2.0f);
|
||||
return Avx2.Subtract(piOver2, Acos(value));
|
||||
}}
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public static Vector256<float> Acos(Vector256<float> value)
|
||||
{{
|
||||
// 0 <= value <= 1 : acos(value) = sqrt(1 - value) * (c0 + c1*value + c2*value^2 + c3*value^3)
|
||||
// value < 0 : acos(value) = pi - acos(-value)
|
||||
|
||||
var x = Vector256.Abs(value);
|
||||
|
||||
var c0 = Vector256.Create(1.5707288f); // pi/2
|
||||
var c1 = Vector256.Create(-0.2121144f);
|
||||
var c2 = Vector256.Create(0.0742610f);
|
||||
var c3 = Vector256.Create(-0.0187293f);
|
||||
|
||||
var term1 = Fma.MultiplyAdd(x, c3, c2);
|
||||
var term2 = Fma.MultiplyAdd(x, term1, c1);
|
||||
var poly = Fma.MultiplyAdd(x, term2, c0);
|
||||
|
||||
var sqrtTerm = Avx2.Sqrt(Avx2.Subtract(Vector256<float>.One, x));
|
||||
var result = Avx2.Multiply(poly, sqrtTerm);
|
||||
|
||||
var pi = Vector256.Create(MathF.PI);
|
||||
var isNegative = Avx2.CompareLessThan(value, Vector256<float>.Zero);
|
||||
|
||||
return Avx2.BlendVariable(pi, Avx2.Subtract(pi, result), isNegative);
|
||||
}}
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public static Vector256<float> Atan2(Vector256<float> y, Vector256<float> x)
|
||||
{{
|
||||
var absX = Vector256.Abs(x);
|
||||
var absY = Vector256.Abs(y);
|
||||
|
||||
// 1. Determine the ratio (input to Atan)
|
||||
// If |value| > |y|, we are in the ""shallow"" region, ratio = y/value
|
||||
// If |y| > |value|, we are in the ""steep"" region, ratio = value/y (and we transform result)
|
||||
var yGtX = Avx2.CompareGreaterThan(absY, absX);
|
||||
|
||||
// Select numerator and denominator to ensure ratio is always in [-1, 1]
|
||||
var num = Avx2.BlendVariable(absX, absY, yGtX);
|
||||
var den = Avx2.BlendVariable(absY, absX, yGtX);
|
||||
|
||||
var t = Avx2.Multiply(num, Avx2.Reciprocal(den)); // t is now in [0, 1]
|
||||
var t2 = Avx2.Multiply(t, t);
|
||||
|
||||
// 2. Polynomial Approximation (Odd function: value * (c1 + c2*value^2))
|
||||
var c1 = Vector256.Create(0.97239411f);
|
||||
var c2 = Vector256.Create(-0.19194795f);
|
||||
|
||||
// (c1 + c2 * t2)
|
||||
var poly = Fma.MultiplyAdd(c2, t2, c1);
|
||||
|
||||
// result = Avx2.Multiply(t, poly)
|
||||
var result = Avx2.Multiply(t, poly);
|
||||
|
||||
// 3. Reconstruct the angle
|
||||
// If we swapped value/y (yGtX), the identity is: atan(value/y) = PI/2 - atan(y/value)
|
||||
var halfPi = Vector256.Create(1.570796327f);
|
||||
result = Avx2.BlendVariable(halfPi - result, result, yGtX);
|
||||
|
||||
// 4. Adjust for Quadrants (Signs)
|
||||
// If value < 0, we are in quadrants 2 or 3, so we need to add PI
|
||||
var pi = Vector256.Create(3.141592654f);
|
||||
var xLtZero = Avx2.CompareLessThan(x, Vector256<float>.Zero);
|
||||
result = Avx2.BlendVariable(pi - result, result, xLtZero);
|
||||
|
||||
// If y < 0, the result should be negative (standard atan2 convention)
|
||||
// NOTE: This sign flip strategy depends on exact polynomial range mapping,
|
||||
// but typically just copy the sign of Y to the result.
|
||||
var yLtZero = Avx2.CompareLessThan(y, Vector256<float>.Zero);
|
||||
// If original Y was negative, negate the result
|
||||
// (This works because our ratio logic effectively computed atan(|y|/|value|) above)
|
||||
var negativeResult = Avx2.Subtract(Vector256<float>.Zero, result);
|
||||
return Avx2.BlendVariable(negativeResult, result, yLtZero);
|
||||
}}
|
||||
{sinCosMethods}
|
||||
}}
|
||||
}}";
|
||||
|
||||
@@ -126,6 +37,11 @@ namespace Misaki.HighPerformance.HPC
|
||||
|
||||
internal class AVX2Rewriter : HPCRewriter
|
||||
{
|
||||
public AVX2Rewriter(SemanticModel semanticModel)
|
||||
: base(semanticModel)
|
||||
{
|
||||
}
|
||||
|
||||
public override string Name => "AVX2";
|
||||
|
||||
public override string GetNesessaryUsing()
|
||||
@@ -133,16 +49,33 @@ namespace Misaki.HighPerformance.HPC
|
||||
return "using System.Runtime.Intrinsics;\nusing System.Runtime.Intrinsics.X86;";
|
||||
}
|
||||
|
||||
protected override MathExpression RewriteMathExpression(SIMDInstruction instruction, bool isFloatingPoint)
|
||||
protected override void RewriteMathArguments(SIMDInstruction instruction, Span<ArgumentSyntax> originalArgs)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
protected override MathExpression RewriteMathExpression(SIMDInstruction instruction)
|
||||
{
|
||||
switch (instruction)
|
||||
{
|
||||
case SIMDInstruction.Add:
|
||||
break;
|
||||
return new MathExpression
|
||||
{
|
||||
Expression = "Avx2",
|
||||
Name = "Add"
|
||||
};
|
||||
case SIMDInstruction.Subtract:
|
||||
break;
|
||||
return new MathExpression
|
||||
{
|
||||
Expression = "Avx2",
|
||||
Name = "Subtract"
|
||||
};
|
||||
case SIMDInstruction.Multiply:
|
||||
break;
|
||||
return new MathExpression
|
||||
{
|
||||
Expression = "Avx2",
|
||||
Name = "Multiply"
|
||||
};
|
||||
case SIMDInstruction.MultiplyAdd:
|
||||
return new MathExpression
|
||||
{
|
||||
@@ -167,10 +100,5 @@ namespace Misaki.HighPerformance.HPC
|
||||
|
||||
return default;
|
||||
}
|
||||
|
||||
protected override void RewriteMathArguments(SIMDInstruction instruction, Span<ArgumentSyntax> originalArgs)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
99
Misaki.HighPerformance.HPC.Generator/HPCOptimizerRewriter.cs
Normal file
99
Misaki.HighPerformance.HPC.Generator/HPCOptimizerRewriter.cs
Normal file
@@ -0,0 +1,99 @@
|
||||
using Microsoft.CodeAnalysis;
|
||||
using Microsoft.CodeAnalysis.CSharp;
|
||||
using Microsoft.CodeAnalysis.CSharp.Syntax;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Text;
|
||||
|
||||
namespace Misaki.HighPerformance.HPC.Generator
|
||||
{
|
||||
/// <summary>
/// Syntax rewriter that resolves the primitive element type behind HPC lane types and
/// replaces identifiers registered as SPMD types with <c>Vector256&lt;primitive&gt;</c>.
/// </summary>
internal class HPCOptimizerRewriter : CSharpSyntaxRewriter
{
    // SPMD type name -> primitive element type name (for example "TLane0" -> "float").
    // Nothing visible in this class adds entries.
    // NOTE(review): presumably populated by code not shown here, or still work in progress - confirm.
    private readonly Dictionary<string, string> _spmdTypes = new();

    private readonly SemanticModel _semanticModel;

    /// <summary>
    /// Creates the rewriter.
    /// </summary>
    /// <param name="semanticModel">Semantic model used to resolve the type of visited nodes.</param>
    public HPCOptimizerRewriter(SemanticModel semanticModel)
    {
        _semanticModel = semanticModel;
    }

    /// <summary>
    /// Tells whether <paramref name="type"/> is named <c>WideLane</c> or is one of the tracked SPMD type names.
    /// </summary>
    /// <param name="type">Type symbol to test; <c>null</c> yields <c>false</c>.</param>
    private bool IsKnownHpcType(ITypeSymbol? type)
    {
        if (type == null)
        {
            return false;
        }

        return type.Name == "WideLane" || _spmdTypes.ContainsKey(type.Name);
    }

    /// <summary>
    /// Resolves the primitive element type name of the expression or type represented by <paramref name="originalNode"/>.
    /// </summary>
    /// <param name="originalNode">A node belonging to the tree of the semantic model given to the constructor.</param>
    /// <returns>
    /// The display string of the element type for a generic <c>WideLane</c> or for a type parameter constrained
    /// by a generic <c>ISPMDLane</c>, <c>"float"</c> / <c>"double"</c> for the plain primitives, otherwise <c>null</c>.
    /// </returns>
    protected string? GetHpcPrimitiveType(SyntaxNode originalNode)
    {
        // The model lives in the _semanticModel field; the bare name "semanticModel" only exists
        // as the constructor parameter and is not in scope here.
        var type = _semanticModel.GetTypeInfo(originalNode).Type;

        if (type == null)
        {
            return null;
        }

        if (string.Equals(type.Name, "WideLane") && type is INamedTypeSymbol namedType && namedType.IsGenericType)
        {
            // The first type argument is the element type. The default display format is expected to
            // print special types with their C# keyword ("float", "double"), which is what callers compare against.
            return namedType.TypeArguments[0].ToDisplayString();
        }

        if (type is ITypeParameterSymbol typeParam)
        {
            // Inspect `where TLane0 : ISPMDLane<TLane0, float>` style constraints.
            foreach (var constraint in typeParam.ConstraintTypes)
            {
                if (constraint.Name == "ISPMDLane"
                    && constraint is INamedTypeSymbol constraintNamed
                    && constraintNamed.IsGenericType
                    && constraintNamed.TypeArguments.Length > 1)
                {
                    // The second generic argument is the primitive format (float/double).
                    return constraintNamed.TypeArguments[1].ToDisplayString();
                }
            }
        }

        if (type.SpecialType == SpecialType.System_Single)
        {
            return "float";
        }

        if (type.SpecialType == SpecialType.System_Double)
        {
            return "double";
        }

        return null;
    }

    /// <summary>
    /// Replaces an identifier that names a tracked SPMD type (e.g. <c>TLane0</c>) with
    /// <c>Vector256&lt;primitive&gt;</c>, keeping the trivia of the original node.
    /// </summary>
    public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node)
    {
        if (_spmdTypes.TryGetValue(node.Identifier.Text, out var primType))
        {
            return SyntaxFactory.GenericName("Vector256")
                .WithTypeArgumentList(
                    SyntaxFactory.TypeArgumentList(
                        SyntaxFactory.SingletonSeparatedList<TypeSyntax>(
                            SyntaxFactory.IdentifierName(primType))))
                .WithTriviaFrom(node);
        }

        return base.VisitIdentifierName(node);
    }
}
|
||||
}
|
||||
@@ -31,16 +31,28 @@ namespace Misaki.HighPerformance.HPC.Generator
|
||||
{
|
||||
get; set;
|
||||
}
|
||||
|
||||
public int[]? ArgumentOrder
|
||||
{
|
||||
get; set;
|
||||
}
|
||||
}
|
||||
|
||||
public static IReadOnlyCollection<HPCRewriter> GetRewriter(TargetInstructionSet instructionSet)
|
||||
protected readonly SemanticModel semanticModel;
|
||||
|
||||
protected HPCRewriter(SemanticModel semanticModel)
|
||||
{
|
||||
this.semanticModel = semanticModel;
|
||||
}
|
||||
|
||||
public static IReadOnlyCollection<HPCRewriter> GetRewriter(TargetInstructionSet instructionSet, SemanticModel semanticModel)
|
||||
{
|
||||
var rewriters = new List<HPCRewriter>();
|
||||
|
||||
// TODO: Add more rewriters for different instruction sets
|
||||
if (instructionSet.HasFlag(TargetInstructionSet.AVX2))
|
||||
{
|
||||
rewriters.Add(new AVX2Rewriter());
|
||||
rewriters.Add(new AVX2Rewriter(semanticModel));
|
||||
}
|
||||
|
||||
return rewriters;
|
||||
@@ -61,8 +73,6 @@ namespace Misaki.HighPerformance.HPC.Generator
|
||||
["Atan2"] = SIMDInstruction.Atan2,
|
||||
};
|
||||
|
||||
protected readonly Dictionary<string, string> spmdTypes = new();
|
||||
|
||||
public abstract string Name
|
||||
{
|
||||
get;
|
||||
@@ -159,28 +169,10 @@ namespace Misaki.HighPerformance.HPC.Generator
|
||||
return base.VisitGenericName(node);
|
||||
}
|
||||
|
||||
public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node)
|
||||
{
|
||||
// Rewrites signature types and generic types from `TLane0` to `Vector256<float>`
|
||||
if (spmdTypes.TryGetValue(node.Identifier.Text, out var primType))
|
||||
{
|
||||
return SyntaxFactory.GenericName("Vector256")
|
||||
.WithTypeArgumentList(
|
||||
SyntaxFactory.TypeArgumentList(
|
||||
SyntaxFactory.SingletonSeparatedList<TypeSyntax>(
|
||||
SyntaxFactory.IdentifierName(primType))))
|
||||
.WithTriviaFrom(node);
|
||||
}
|
||||
|
||||
return base.VisitIdentifierName(node);
|
||||
}
|
||||
|
||||
public override SyntaxNode? VisitMemberAccessExpression(MemberAccessExpressionSyntax node)
|
||||
{
|
||||
var isSpmdOrWideLane = false;
|
||||
var isFloatingPoint = false;
|
||||
|
||||
// 1. Check if the left-side expression is WideLane<...> or a tracked generic SPMD type
|
||||
if (node.Expression is GenericNameSyntax genericName &&
|
||||
genericName.Identifier.Text == "WideLane" &&
|
||||
genericName.TypeArgumentList.Arguments.Count == 1)
|
||||
@@ -188,13 +180,11 @@ namespace Misaki.HighPerformance.HPC.Generator
|
||||
isSpmdOrWideLane = true;
|
||||
|
||||
var argTypeStr = genericName.TypeArgumentList.Arguments[0].ToString();
|
||||
isFloatingPoint = argTypeStr == "float" || argTypeStr == "double";
|
||||
}
|
||||
else if (node.Expression is IdentifierNameSyntax idName &&
|
||||
spmdTypes.TryGetValue(idName.Identifier.Text, out var mappedPrimType))
|
||||
{
|
||||
isSpmdOrWideLane = true;
|
||||
isFloatingPoint = mappedPrimType == "float" || mappedPrimType == "double";
|
||||
}
|
||||
|
||||
if (isSpmdOrWideLane)
|
||||
@@ -213,7 +203,7 @@ namespace Misaki.HighPerformance.HPC.Generator
|
||||
|
||||
if (s_remapMath.TryGetValue(node.Name.Identifier.Text, out var instruction))
|
||||
{
|
||||
var rewritResult = RewriteMathExpression(instruction, isFloatingPoint);
|
||||
var rewritResult = RewriteMathExpression(instruction);
|
||||
return SyntaxFactory.MemberAccessExpression(
|
||||
SyntaxKind.SimpleMemberAccessExpression,
|
||||
SyntaxFactory.IdentifierName(rewritResult.Expression),
|
||||
@@ -229,7 +219,7 @@ namespace Misaki.HighPerformance.HPC.Generator
|
||||
{
|
||||
if (node.Expression is MemberAccessExpressionSyntax memberAccess)
|
||||
{
|
||||
bool isSpmdOrWideLane = false;
|
||||
var isSpmdOrWideLane = false;
|
||||
|
||||
if (memberAccess.Expression is GenericNameSyntax genericName
|
||||
&& genericName.Identifier.Text == "WideLane"
|
||||
@@ -268,7 +258,78 @@ namespace Misaki.HighPerformance.HPC.Generator
|
||||
return base.VisitInvocationExpression(node);
|
||||
}
|
||||
|
||||
protected abstract MathExpression RewriteMathExpression(SIMDInstruction instruction, bool isFloatingPoint);
|
||||
/// <summary>
/// Fuses a floating-point addition whose left or right operand is a multiplication into a single
/// <see cref="SIMDInstruction.MultiplyAdd"/> invocation: <c>a * b + c</c> or <c>c + a * b</c>
/// becomes <c>MultiplyAdd(a, b, c)</c>. Every other binary expression is handled by the base visitor.
/// </summary>
/// <param name="node">The binary expression being visited.</param>
/// <returns>The fused invocation (with the trivia of <paramref name="node"/>) or the base visitor's result.</returns>
public override SyntaxNode? VisitBinaryExpression(BinaryExpressionSyntax node)
{
    var primitiveType = GetHpcPrimitiveType(node);
    var isFloatingPoint = primitiveType == "float" || primitiveType == "double";

    // Only additions typed as a known floating-point HPC type are candidates for fusing.
    if (!node.IsKind(SyntaxKind.AddExpression)
        || !IsKnownHpcType(semanticModel.GetTypeInfo(node).Type)
        || !isFloatingPoint)
    {
        return base.VisitBinaryExpression(node);
    }

    // (a * b) + c
    if (node.Left is BinaryExpressionSyntax leftProduct && leftProduct.IsKind(SyntaxKind.MultiplyExpression))
    {
        var multiplicand = (ExpressionSyntax)Visit(leftProduct.Left)!;
        var multiplier = (ExpressionSyntax)Visit(leftProduct.Right)!;
        var addend = (ExpressionSyntax)Visit(node.Right)!;

        return InvokeMathRewrite(SIMDInstruction.MultiplyAdd, multiplicand, multiplier, addend).WithTriviaFrom(node);
    }

    // c + (a * b); the addend is visited first because it appears first in the source.
    if (node.Right is BinaryExpressionSyntax rightProduct && rightProduct.IsKind(SyntaxKind.MultiplyExpression))
    {
        var addend = (ExpressionSyntax)Visit(node.Left)!;
        var multiplicand = (ExpressionSyntax)Visit(rightProduct.Left)!;
        var multiplier = (ExpressionSyntax)Visit(rightProduct.Right)!;

        return InvokeMathRewrite(SIMDInstruction.MultiplyAdd, multiplicand, multiplier, addend).WithTriviaFrom(node);
    }

    return base.VisitBinaryExpression(node);
}
|
||||
|
||||
/// <summary>
/// Builds the invocation <c>Expression.Name(args...)</c> for <paramref name="instruction"/> using the
/// mapping returned by <see cref="RewriteMathExpression(SIMDInstruction)"/>.
/// </summary>
/// <param name="instruction">The SIMD instruction to emit.</param>
/// <param name="args">Already-rewritten argument expressions, in source order.</param>
/// <returns>The invocation expression for the backend-specific method.</returns>
/// <remarks>
/// When the mapping supplies an <c>ArgumentOrder</c>, emitted argument <c>i</c> is <c>args[ArgumentOrder[i]]</c>
/// and the emitted argument count equals the length of that order; otherwise the arguments are passed through
/// unchanged. An index outside <paramref name="args"/> still fails with an out-of-range exception.
/// </remarks>
protected ExpressionSyntax InvokeMathRewrite(SIMDInstruction instruction, params ExpressionSyntax[] args)
{
    var rewriteResult = RewriteMathExpression(instruction);
    var order = rewriteResult.ArgumentOrder;

    // Size the list by the order when there is one: sizing by args.Length left null slots when the
    // order was shorter than args and overran the array when it was longer.
    var finalArgs = new ArgumentSyntax[order != null ? order.Length : args.Length];

    for (var i = 0; i < finalArgs.Length; i++)
    {
        var sourceIndex = order != null ? order[i] : i;
        finalArgs[i] = SyntaxFactory.Argument(args[sourceIndex]);
    }

    return SyntaxFactory.InvocationExpression(
        SyntaxFactory.MemberAccessExpression(
            SyntaxKind.SimpleMemberAccessExpression,
            SyntaxFactory.IdentifierName(rewriteResult.Expression),
            SyntaxFactory.IdentifierName(rewriteResult.Name)
        ),
        SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(finalArgs))
    );
}
|
||||
|
||||
protected abstract MathExpression RewriteMathExpression(SIMDInstruction instruction);
|
||||
protected abstract void RewriteMathArguments(SIMDInstruction instruction, Span<ArgumentSyntax> originalArgs);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,6 +21,11 @@ namespace Misaki.HighPerformance.HPC.Generator
|
||||
get; set;
|
||||
} = null!;
|
||||
|
||||
public SemanticModel SemanticModel
|
||||
{
|
||||
get; set;
|
||||
} = null!;
|
||||
|
||||
public TargetInstructionSet InstructionSet
|
||||
{
|
||||
get; set;
|
||||
@@ -55,6 +60,7 @@ namespace Misaki.HighPerformance.HPC.Generator
|
||||
{
|
||||
MethodDeclaration = (MethodDeclarationSyntax)ctx.TargetNode,
|
||||
MethodSymbol = methodSymbol,
|
||||
SemanticModel = ctx.SemanticModel,
|
||||
InstructionSet = (TargetInstructionSet)attributes.ConstructorArguments[0].Value!,
|
||||
Precision = (FloatPrecision)attributes.ConstructorArguments[1].Value!,
|
||||
Mode = (MathMode)attributes.ConstructorArguments[2].Value!,
|
||||
@@ -82,7 +88,7 @@ namespace Misaki.HighPerformance.HPC.Generator
|
||||
continue;
|
||||
}
|
||||
|
||||
var rewriters = HPCRewriter.GetRewriter(info.InstructionSet);
|
||||
var rewriters = HPCRewriter.GetRewriter(info.InstructionSet, info.SemanticModel);
|
||||
|
||||
foreach (var writer in rewriters)
|
||||
{
|
||||
|
||||
@@ -1,40 +1,46 @@
|
||||
using Misaki.HighPerformance.HPC.Generator.VectorAPI;
|
||||
using Misaki.HighPerformance.HPC.Generator.APIContext;
|
||||
using System.Text;
|
||||
|
||||
namespace Misaki.HighPerformance.HPC.Generator
|
||||
{
|
||||
internal static class UtilityTemplate
|
||||
{
|
||||
public static Method SinFloat_Standard(IVectorAPIContext api)
|
||||
public static Method Sin_Standard<T>(IVectorAPIContext api)
|
||||
{
|
||||
var body = api.Return(api.Call("Sin", "value"));
|
||||
var body = api.Return(api.Call($"{api.GetVectorType()}.Sin", "value"));
|
||||
return new Method(
|
||||
modifier: "public static",
|
||||
returnType: api.GetVectorType<float>(),
|
||||
name: $"SinFloat_Standard",
|
||||
parameters: new[] { $"{api.GetVectorType<float>()} value" },
|
||||
returnType: api.GetVectorType<T>(),
|
||||
name: $"Sin_{typeof(T).Name}_Standard",
|
||||
parameters: new[] { $"{api.GetVectorType<T>()} value" },
|
||||
body: body);
|
||||
}
|
||||
|
||||
public static Method SinFloat_Fast(IVectorAPIContext api)
|
||||
public static Method Sin_Fast<T>(IVectorAPIContext api)
|
||||
{
|
||||
var invPi = api.Create("0.318309886f").Assign();
|
||||
var isFloat = typeof(T) == typeof(float);
|
||||
var typePrefix = isFloat ? "f" : "d";
|
||||
|
||||
var x_sin = new Expression(api, "value").Assign();
|
||||
var input = new Expression(api, "value");
|
||||
|
||||
var invPi = api.Create($"0.318309886{typePrefix}").Assign();
|
||||
|
||||
var x_sin = input;
|
||||
var y_sin = api.Multiply(x_sin, invPi).Assign();
|
||||
var k_sin = api.Round(y_sin).Assign();
|
||||
var z_sin = api.Subtract(y_sin, k_sin).Assign();
|
||||
|
||||
var half = api.Create("0.5f").Assign();
|
||||
var two = api.Create("2.0f").Assign();
|
||||
var half = api.Create($"0.5{typePrefix}").Assign();
|
||||
var two = api.Create($"2.0{typePrefix}").Assign();
|
||||
|
||||
var k_even_sin = (api.Round(k_sin * half) * two).Assign();
|
||||
var sign_sin = (api.One<float>() - two * api.Abs(k_sin - k_even_sin)).Assign();
|
||||
var sign_sin = (api.One<T>() - two * api.Abs(k_sin - k_even_sin)).Assign();
|
||||
|
||||
var c1 = api.Create("3.14159265f").Assign();
|
||||
var c3 = api.Create("-5.16771278f").Assign();
|
||||
var c5 = api.Create("2.55016404f").Assign();
|
||||
var c7 = api.Create("-0.59926453f").Assign();
|
||||
var c9 = api.Create("0.08214589f").Assign();
|
||||
var c1 = api.Create($"3.14159265{typePrefix}").Assign();
|
||||
var c3 = api.Create($"-5.16771278{typePrefix}").Assign();
|
||||
var c5 = api.Create($"2.55016404{typePrefix}").Assign();
|
||||
var c7 = api.Create($"-0.59926453{typePrefix}").Assign();
|
||||
var c9 = api.Create($"0.08214589{typePrefix}").Assign();
|
||||
|
||||
var z2_sin = (z_sin * z_sin).Assign();
|
||||
var poly_sin = api.MultiplyAdd(z2_sin, c9, c7).Assign();
|
||||
@@ -49,10 +55,179 @@ namespace Misaki.HighPerformance.HPC.Generator
|
||||
|
||||
return new Method(
|
||||
modifier: "public static",
|
||||
returnType: api.GetVectorType<float>(),
|
||||
name: $"SinFloat_Fast",
|
||||
parameters: new[] { $"{api.GetVectorType<float>()} value" },
|
||||
returnType: api.GetVectorType<T>(),
|
||||
name: $"Sin_{typeof(T).Name}_Fast",
|
||||
parameters: new[] { $"{api.GetVectorType<T>()} {input.Code}" },
|
||||
body: body);
|
||||
}
|
||||
|
||||
/// <summary>
/// Describes a <c>public static</c> method named <c>Cos_{T}_Standard</c> that takes one vector parameter
/// called <c>value</c> and returns the result of calling <c>Cos</c> on the API's vector type.
/// </summary>
/// <typeparam name="T">Element type; its CLR name becomes part of the generated method name.</typeparam>
/// <param name="api">Vector API context used to build the call and the return statement.</param>
public static Method Cos_Standard<T>(IVectorAPIContext api)
{
    // Build the body first, exactly once: `return <VectorType>.Cos(value);`
    var cosineCall = api.Call($"{api.GetVectorType()}.Cos", "value");
    var methodBody = api.Return(cosineCall);

    var vectorType = api.GetVectorType<T>();
    var methodName = $"Cos_{typeof(T).Name}_Standard";

    return new Method(
        modifier: "public static",
        returnType: vectorType,
        name: methodName,
        parameters: new[] { $"{vectorType} value" },
        body: methodBody);
}
|
||||
|
||||
/// <summary>
/// Describes a <c>public static</c> method named <c>Cos_{T}_Fast</c> that approximates cosine with a polynomial.
/// The emitted code shifts the input by pi/2, scales by 1/pi, splits the result into a rounded part <c>k</c>
/// and a remainder <c>z</c>, evaluates an odd degree-9 polynomial in <c>z</c> (coefficients match the Taylor
/// series of sin(pi*z)) and multiplies by a sign derived from the parity of <c>k</c>.
/// </summary>
/// <typeparam name="T">Element type; <c>float</c> literals get the <c>f</c> suffix, anything else <c>d</c>.</typeparam>
/// <param name="api">Vector API context that records the emitted statements in call order.</param>
public static Method Cos_Fast<T>(IVectorAPIContext api)
{
    var literalSuffix = typeof(T) == typeof(float) ? "f" : "d";

    var value = new Expression(api, "value");

    var halfPi = api.Create($"1.570796327{literalSuffix}").Assign();
    var invPi = api.Create($"0.318309886{literalSuffix}").Assign();

    // cos(v) is evaluated as sin(v + pi/2); work in units of pi from here on.
    var shifted = api.Add(value, halfPi).Assign();
    var scaled = api.Multiply(shifted, invPi).Assign();
    var k = api.Round(scaled).Assign();
    var z = api.Subtract(scaled, k).Assign();

    var half = api.Create($"0.5{literalSuffix}").Assign();
    var two = api.Create($"2.0{literalSuffix}").Assign();

    // sign = 1 - 2 * |k - 2 * round(k / 2)|
    // NOTE(review): yields +1 for even k and -1 for odd k provided Round rounds to the nearest integer - confirm.
    var kEven = api.Multiply(api.Round(api.Multiply(k, half)), two).Assign();
    var sign = api.Subtract(api.One<T>(), api.Multiply(two, api.Abs(api.Subtract(k, kEven)))).Assign();

    var c1 = api.Create($"3.14159265{literalSuffix}").Assign();
    var c3 = api.Create($"-5.16771278{literalSuffix}").Assign();
    var c5 = api.Create($"2.55016404{literalSuffix}").Assign();
    var c7 = api.Create($"-0.59926453{literalSuffix}").Assign();
    var c9 = api.Create($"0.08214589{literalSuffix}").Assign();

    // Horner evaluation in z^2, re-assigning the same emitted variable at every step.
    var zSquared = api.Multiply(z, z).Assign();
    var polynomial = api.MultiplyAdd(zSquared, c9, c7).Assign();
    var polynomialName = api.LastAssignedVariable;

    foreach (var coefficient in new[] { c5, c3, c1 })
    {
        polynomial = api.MultiplyAdd(zSquared, polynomial, coefficient).Assign(polynomialName, false);
    }

    polynomial = api.Multiply(z, polynomial).Assign(polynomialName, false);

    var methodBody = api.Return(polynomial * sign);

    var vectorType = api.GetVectorType<T>();

    return new Method(
        modifier: "public static",
        returnType: vectorType,
        name: $"Cos_{typeof(T).Name}_Fast",
        parameters: new[] { $"{vectorType} {value.Code}" },
        body: methodBody);
}
|
||||
|
||||
/// <summary>
/// Describes a <c>public static void</c> method named <c>SinCos_{T}_Standard</c> with the parameters
/// <c>value</c>, <c>out sin</c> and <c>out cos</c> that forwards to <c>SinCos</c> on the API's vector type.
/// </summary>
/// <typeparam name="T">Element type; its CLR name becomes part of the generated method name.</typeparam>
/// <param name="api">Vector API context used to build the body.</param>
/// <remarks>
/// The generated method is <c>void</c>, so the body must assign both <c>out</c> parameters rather than
/// return the call result: returning a value from a void method does not compile, and neither does
/// leaving an <c>out</c> parameter unassigned.
/// </remarks>
public static Method SinCos_Standard<T>(IVectorAPIContext api)
{
    // Emits `(sin, cos) = <VectorType>.SinCos(value);`
    // NOTE(review): relies on Assign(name, false) emitting `name = expr;` for an arbitrary left-hand
    // side (as it is used for the out parameters in SinCos_Fast) and on SinCos returning a
    // (sin, cos) tuple - confirm against the API context and the target framework.
    api.Call($"{api.GetVectorType()}.SinCos", "value").Assign("(sin, cos)", false);

    // A null expression closes the body without emitting a `return <expr>;` statement.
    var sin_cos = api.Return(null);

    return new Method(
        modifier: "public static",
        returnType: "void",
        name: $"SinCos_{typeof(T).Name}_Standard",
        parameters: new[] { $"{api.GetVectorType<T>()} value", $"out {api.GetVectorType<T>()} sin", $"out {api.GetVectorType<T>()} cos" },
        body: sin_cos);
}
|
||||
|
||||
/// <summary>
/// Describes a <c>public static void</c> method named <c>SinCos_{T}_Fast</c> that computes polynomial
/// approximations of sine and cosine of <c>value</c> and stores them in the <c>out</c> parameters
/// <c>sin</c> and <c>cos</c>. Cosine is evaluated as the sine of <c>value + pi/2</c>; both share the
/// same constants and the same odd degree-9 polynomial (coefficients match the Taylor series of sin(pi*z)).
/// </summary>
/// <typeparam name="T">Element type; <c>float</c> literals get the <c>f</c> suffix, anything else <c>d</c>.</typeparam>
/// <param name="api">Vector API context that records the emitted statements in call order.</param>
public static Method SinCos_Fast<T>(IVectorAPIContext api)
{
    var isFloat = typeof(T) == typeof(float);
    var typePrefix = isFloat ? "f" : "d";

    var input = new Expression(api, "value");
    var sinOut = new Expression(api, "sin");
    var cosOut = new Expression(api, "cos");

    var halfPi = api.Create($"1.570796327{typePrefix}").Assign();
    var invPi = api.Create($"0.318309886{typePrefix}").Assign();

    // Range reduction in units of pi: y = x / pi, k = round(y), z = y - k.
    var x_sin = input;
    var x_cos = api.Add(x_sin, halfPi).Assign();

    var y_sin = api.Multiply(x_sin, invPi).Assign();
    var y_cos = api.Multiply(x_cos, invPi).Assign();

    var k_sin = api.Round(y_sin).Assign();
    var k_cos = api.Round(y_cos).Assign();

    var z_sin = api.Subtract(y_sin, k_sin).Assign();
    var z_cos = api.Subtract(y_cos, k_cos).Assign();

    var half = api.Create($"0.5{typePrefix}").Assign();
    var two = api.Create($"2.0{typePrefix}").Assign();
    var one = api.One<T>();

    // sign = 1 - 2 * |k - 2 * round(k / 2)|
    // NOTE(review): yields +1 for even k and -1 for odd k provided Round rounds to the nearest integer - confirm.
    var k_even_sin = api.Multiply(api.Round(api.Multiply(k_sin, half)), two).Assign();
    var sign_sin = api.Subtract(one, api.Multiply(two, api.Abs(api.Subtract(k_sin, k_even_sin)))).Assign();

    var k_even_cos = api.Multiply(api.Round(api.Multiply(k_cos, half)), two).Assign();
    var sign_cos = api.Subtract(one, api.Multiply(two, api.Abs(api.Subtract(k_cos, k_even_cos)))).Assign();

    var c1 = api.Create($"3.14159265{typePrefix}").Assign();
    var c3 = api.Create($"-5.16771278{typePrefix}").Assign();
    var c5 = api.Create($"2.55016404{typePrefix}").Assign();
    var c7 = api.Create($"-0.59926453{typePrefix}").Assign();
    var c9 = api.Create($"0.08214589{typePrefix}").Assign();

    // Horner evaluation in z^2 for the sine branch, re-assigning the same emitted variable.
    var z2_sin = api.Multiply(z_sin, z_sin).Assign();
    var poly_sin = api.MultiplyAdd(z2_sin, c9, c7).Assign();

    var poly_sin_name = api.LastAssignedVariable;
    poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c5).Assign(poly_sin_name, false);
    poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c3).Assign(poly_sin_name, false);
    poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c1).Assign(poly_sin_name, false);
    poly_sin = api.Multiply(z_sin, poly_sin).Assign(poly_sin_name, false);

    // Same evaluation for the cosine branch.
    var z2_cos = api.Multiply(z_cos, z_cos).Assign();
    var poly_cos = api.MultiplyAdd(z2_cos, c9, c7).Assign();

    var poly_cos_name = api.LastAssignedVariable;
    poly_cos = api.MultiplyAdd(z2_cos, poly_cos, c5).Assign(poly_cos_name, false);
    poly_cos = api.MultiplyAdd(z2_cos, poly_cos, c3).Assign(poly_cos_name, false);
    poly_cos = api.MultiplyAdd(z2_cos, poly_cos, c1).Assign(poly_cos_name, false);
    poly_cos = api.Multiply(z_cos, poly_cos).Assign(poly_cos_name, false);

    // Store the results in the existing out parameters instead of declaring new variables.
    sinOut = api.Multiply(poly_sin, sign_sin).Assign(sinOut.Code, false);
    cosOut = api.Multiply(poly_cos, sign_cos).Assign(cosOut.Code, false);

    // The generated method is void: pass null so no `return <expr>;` statement is emitted.
    // Passing a created expression here put a value-returning statement into a void method.
    var body = api.Return(null);

    return new Method(
        modifier: "public static",
        returnType: "void",
        name: $"SinCos_{typeof(T).Name}_Fast",
        parameters: new[] { $"{api.GetVectorType<T>()} {input.Code}", $"out {api.GetVectorType<T>()} {sinOut.Code}", $"out {api.GetVectorType<T>()} {cosOut.Code}" },
        body: body);
}
|
||||
|
||||
/// <summary>
/// Generates the source text of all Sin / Cos / SinCos utility methods (standard and fast variants)
/// for <c>float</c> followed by <c>double</c>, each preceded by an aggressive-inlining attribute line.
/// </summary>
/// <param name="api">Vector API context used to build every method.</param>
/// <param name="identation">Indentation prepended to the attribute line and passed on to each method's code.</param>
/// <returns>The concatenated attribute lines and method sources.</returns>
public static string GenerateSinCosUtilityMethods(IVectorAPIContext api, string identation)
{
    // All descriptions are built up front, in this exact order, before any source text is rendered.
    Method[] generated =
    {
        Sin_Standard<float>(api),
        Sin_Fast<float>(api),
        Cos_Standard<float>(api),
        Cos_Fast<float>(api),
        SinCos_Standard<float>(api),
        SinCos_Fast<float>(api),
        Sin_Standard<double>(api),
        Sin_Fast<double>(api),
        Cos_Standard<double>(api),
        Cos_Fast<double>(api),
        SinCos_Standard<double>(api),
        SinCos_Fast<double>(api)
    };

    var attributeLine = identation + "[MethodImpl(MethodImplOptions.AggressiveInlining)]";
    var output = new StringBuilder();

    for (var index = 0; index < generated.Length; index++)
    {
        output.AppendLine(attributeLine);
        output.AppendLine(generated[index].GetFullCode(identation));
    }

    return output.ToString();
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user