Refactor APIContext, add AVX2 sin/cos codegen, FMA rewrite

- Refactored namespaces from .VectorAPI to .APIContext for clarity.
- Enhanced Avx2APIContext/IVectorAPIContext to support void returns.
- Added GenerateSinCosUtilityMethods for AVX2, emitting vectorized Sin/Cos/SinCos for float/double.
- Introduced HPCOptimizerRewriter for advanced SPMD type handling.
- Refactored HPCRewriter to use SemanticModel, support FMA pattern rewriting, and delegate SPMD logic.
- Updated AVX2Rewriter for new base and improved math mapping.
- Made UtilityTemplate generic and type-safe for sin/cos.
- Updated NoiseJob3D/NoiseJobVector for [HPCompute] attribute and partial struct.
- Fixed solution file project ordering and inclusion.
This commit is contained in:
2026-05-06 22:27:24 +09:00
parent fd2d60c8f1
commit b9537d91da
10 changed files with 439 additions and 161 deletions

View File

@@ -1,7 +1,6 @@
using System;
using System.Collections.Generic;
namespace Misaki.HighPerformance.HPC.Generator.VectorAPI
namespace Misaki.HighPerformance.HPC.Generator.APIContext
{
internal class Avx2APIContext : IVectorAPIContext
{
@@ -23,7 +22,7 @@ namespace Misaki.HighPerformance.HPC.Generator.VectorAPI
public Expression Call(string methodName, params string[] args)
{
return new Expression(this, $"{GetVectorType()}.{methodName}({string.Join(", ", args)})");
return new Expression(this, $"{methodName}({string.Join(", ", args)})");
}
public Expression Assign(Expression expr, string? varName = null, bool isNew = true)
@@ -38,11 +37,14 @@ namespace Misaki.HighPerformance.HPC.Generator.VectorAPI
return new Expression(this, varName);
}
public Code Return(Expression expr)
public Code Return(Expression? expr)
{
var statement = $"return {expr.Code};";
_statements.Add(statement);
expr.Clear();
if (expr != null)
{
var statement = $"return {expr.Code};";
_statements.Add(statement);
expr.Clear();
}
var fullCode = new Code(_statements);
Reset();

View File

@@ -3,7 +3,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace Misaki.HighPerformance.HPC.Generator.VectorAPI
namespace Misaki.HighPerformance.HPC.Generator.APIContext
{
internal class Expression
{
@@ -139,7 +139,7 @@ namespace Misaki.HighPerformance.HPC.Generator.VectorAPI
Expression Call(string methodName, params string[] args);
Expression Assign(Expression expr, string? varName = null, bool isNew = true);
Code Return(Expression expr);
Code Return(Expression? expr);
Expression Create(string value);
Expression Zero<T>();

View File

@@ -1,6 +1,6 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Misaki.HighPerformance.HPC.Generator.VectorAPI;
using Misaki.HighPerformance.HPC.Generator.APIContext;
using System;
namespace Misaki.HighPerformance.HPC.Generator
@@ -14,8 +14,7 @@ namespace Misaki.HighPerformance.HPC.Generator
{
var api = new Avx2APIContext();
var sinFloat_standard = UtilityTemplate.SinFloat_Standard(api);
var sinFloat_fast = UtilityTemplate.SinFloat_Fast(api);
var sinCosMethods = UtilityTemplate.GenerateSinCosUtilityMethods(api, " ");
var source = @$"
using System;
@@ -27,95 +26,7 @@ namespace Misaki.HighPerformance.HPC
{{
public static class AVX2Utility
{{
[MethodImpl(MethodImplOptions.NoInlining)]
{sinFloat_standard.GetFullCode(" ")}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
{sinFloat_fast.GetFullCode(" ")}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector256<float> Asin(Vector256<float> value)
{{
// asin(value) = pi/2 - acos(value)
var piOver2 = Vector256.Create(MathF.PI / 2.0f);
return Avx2.Subtract(piOver2, Acos(value));
}}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector256<float> Acos(Vector256<float> value)
{{
// 0 <= value <= 1 : acos(value) = sqrt(1 - value) * (c0 + c1*value + c2*value^2 + c3*value^3)
// value < 0 : acos(value) = pi - acos(-value)
var x = Vector256.Abs(value);
var c0 = Vector256.Create(1.5707288f); // pi/2
var c1 = Vector256.Create(-0.2121144f);
var c2 = Vector256.Create(0.0742610f);
var c3 = Vector256.Create(-0.0187293f);
var term1 = Fma.MultiplyAdd(x, c3, c2);
var term2 = Fma.MultiplyAdd(x, term1, c1);
var poly = Fma.MultiplyAdd(x, term2, c0);
var sqrtTerm = Avx2.Sqrt(Avx2.Subtract(Vector256<float>.One, x));
var result = Avx2.Multiply(poly, sqrtTerm);
var pi = Vector256.Create(MathF.PI);
var isNegative = Avx2.CompareLessThan(value, Vector256<float>.Zero);
return Avx2.BlendVariable(pi, Avx2.Subtract(pi, result), isNegative);
}}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector256<float> Atan2(Vector256<float> y, Vector256<float> x)
{{
var absX = Vector256.Abs(x);
var absY = Vector256.Abs(y);
// 1. Determine the ratio (input to Atan)
// If |value| > |y|, we are in the ""shallow"" region, ratio = y/value
// If |y| > |value|, we are in the ""steep"" region, ratio = value/y (and we transform result)
var yGtX = Avx2.CompareGreaterThan(absY, absX);
// Select numerator and denominator to ensure ratio is always in [-1, 1]
var num = Avx2.BlendVariable(absX, absY, yGtX);
var den = Avx2.BlendVariable(absY, absX, yGtX);
var t = Avx2.Multiply(num, Avx2.Reciprocal(den)); // t is now in [0, 1]
var t2 = Avx2.Multiply(t, t);
// 2. Polynomial Approximation (Odd function: value * (c1 + c2*value^2))
var c1 = Vector256.Create(0.97239411f);
var c2 = Vector256.Create(-0.19194795f);
// (c1 + c2 * t2)
var poly = Fma.MultiplyAdd(c2, t2, c1);
// result = Avx2.Multiply(t, poly)
var result = Avx2.Multiply(t, poly);
// 3. Reconstruct the angle
// If we swapped value/y (yGtX), the identity is: atan(value/y) = PI/2 - atan(y/value)
var halfPi = Vector256.Create(1.570796327f);
result = Avx2.BlendVariable(halfPi - result, result, yGtX);
// 4. Adjust for Quadrants (Signs)
// If value < 0, we are in quadrants 2 or 3, so we need to add PI
var pi = Vector256.Create(3.141592654f);
var xLtZero = Avx2.CompareLessThan(x, Vector256<float>.Zero);
result = Avx2.BlendVariable(pi - result, result, xLtZero);
// If y < 0, the result should be negative (standard atan2 convention)
// NOTE: This sign flip strategy depends on exact polynomial range mapping,
// but typically just copy the sign of Y to the result.
var yLtZero = Avx2.CompareLessThan(y, Vector256<float>.Zero);
// If original Y was negative, negate the result
// (This works because our ratio logic effectively computed atan(|y|/|value|) above)
var negativeResult = Avx2.Subtract(Vector256<float>.Zero, result);
return Avx2.BlendVariable(negativeResult, result, yLtZero);
}}
{sinCosMethods}
}}
}}";
@@ -126,6 +37,11 @@ namespace Misaki.HighPerformance.HPC
internal class AVX2Rewriter : HPCRewriter
{
public AVX2Rewriter(SemanticModel semanticModel)
: base(semanticModel)
{
}
public override string Name => "AVX2";
public override string GetNesessaryUsing()
@@ -133,16 +49,33 @@ namespace Misaki.HighPerformance.HPC
return "using System.Runtime.Intrinsics;\nusing System.Runtime.Intrinsics.X86;";
}
protected override MathExpression RewriteMathExpression(SIMDInstruction instruction, bool isFloatingPoint)
protected override void RewriteMathArguments(SIMDInstruction instruction, Span<ArgumentSyntax> originalArgs)
{
throw new NotImplementedException();
}
protected override MathExpression RewriteMathExpression(SIMDInstruction instruction)
{
switch (instruction)
{
case SIMDInstruction.Add:
break;
return new MathExpression
{
Expression = "Avx2",
Name = "Add"
};
case SIMDInstruction.Subtract:
break;
return new MathExpression
{
Expression = "Avx2",
Name = "Subtract"
};
case SIMDInstruction.Multiply:
break;
return new MathExpression
{
Expression = "Avx2",
Name = "Multiply"
};
case SIMDInstruction.MultiplyAdd:
return new MathExpression
{
@@ -167,10 +100,5 @@ namespace Misaki.HighPerformance.HPC
return default;
}
protected override void RewriteMathArguments(SIMDInstruction instruction, Span<ArgumentSyntax> originalArgs)
{
return;
}
}
}

View File

@@ -0,0 +1,99 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System;
using System.Collections.Generic;
using System.Text;
namespace Misaki.HighPerformance.HPC.Generator
{
internal class HPCOptimizerRewriter : CSharpSyntaxRewriter
{
private readonly Dictionary<string, string> _spmdTypes = new();
private readonly SemanticModel _semanticModel;
public HPCOptimizerRewriter(SemanticModel semanticModel)
{
_semanticModel = semanticModel;
}
private bool IsKnownHpcType(ITypeSymbol? type)
{
if (type == null)
{
return false;
}
// Check if it's WideLane, or one of the mapped TLane0 constraints
if (type.Name == "WideLane")
{
return true;
}
if (_spmdTypes.ContainsKey(type.Name))
{
return true;
}
return false;
}
protected string? GetHpcPrimitiveType(SyntaxNode originalNode)
{
var typeInfo = semanticModel.GetTypeInfo(originalNode);
var type = typeInfo.Type;
if (type == null)
{
return null;
}
if (string.Equals(type.Name, "WideLane") && type is INamedTypeSymbol namedType && namedType.IsGenericType)
{
// Returns "Single" (float) or "Double" (double)
return namedType.TypeArguments[0].ToDisplayString();
}
if (type is ITypeParameterSymbol typeParam)
{
// Inspect the `where TLane0 : ISPMDLane<TLane0, float>` constraints!
foreach (var constraint in typeParam.ConstraintTypes)
{
if (constraint.Name == "ISPMDLane" && constraint is INamedTypeSymbol constraintNamed && constraintNamed.IsGenericType)
{
// The second generic argument is the primitive format (float/double)
return constraintNamed.TypeArguments[1].ToDisplayString();
}
}
}
if (type.SpecialType == SpecialType.System_Single)
{
return "float";
}
if (type.SpecialType == SpecialType.System_Double)
{
return "double";
}
return null;
}
public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node)
{
// Rewrites signature types and generic types from `TLane0` to `Vector256<float>`
if (_spmdTypes.TryGetValue(node.Identifier.Text, out var primType))
{
return SyntaxFactory.GenericName("Vector256")
.WithTypeArgumentList(
SyntaxFactory.TypeArgumentList(
SyntaxFactory.SingletonSeparatedList<TypeSyntax>(
SyntaxFactory.IdentifierName(primType))))
.WithTriviaFrom(node);
}
return base.VisitIdentifierName(node);
}
}
}

View File

@@ -31,16 +31,28 @@ namespace Misaki.HighPerformance.HPC.Generator
{
get; set;
}
public int[]? ArgumentOrder
{
get; set;
}
}
public static IReadOnlyCollection<HPCRewriter> GetRewriter(TargetInstructionSet instructionSet)
protected readonly SemanticModel semanticModel;
protected HPCRewriter(SemanticModel semanticModel)
{
this.semanticModel = semanticModel;
}
public static IReadOnlyCollection<HPCRewriter> GetRewriter(TargetInstructionSet instructionSet, SemanticModel semanticModel)
{
var rewriters = new List<HPCRewriter>();
// TODO: Add more rewriters for different instruction sets
if (instructionSet.HasFlag(TargetInstructionSet.AVX2))
{
rewriters.Add(new AVX2Rewriter());
rewriters.Add(new AVX2Rewriter(semanticModel));
}
return rewriters;
@@ -61,8 +73,6 @@ namespace Misaki.HighPerformance.HPC.Generator
["Atan2"] = SIMDInstruction.Atan2,
};
protected readonly Dictionary<string, string> spmdTypes = new();
public abstract string Name
{
get;
@@ -159,28 +169,10 @@ namespace Misaki.HighPerformance.HPC.Generator
return base.VisitGenericName(node);
}
public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node)
{
// Rewrites signature types and generic types from `TLane0` to `Vector256<float>`
if (spmdTypes.TryGetValue(node.Identifier.Text, out var primType))
{
return SyntaxFactory.GenericName("Vector256")
.WithTypeArgumentList(
SyntaxFactory.TypeArgumentList(
SyntaxFactory.SingletonSeparatedList<TypeSyntax>(
SyntaxFactory.IdentifierName(primType))))
.WithTriviaFrom(node);
}
return base.VisitIdentifierName(node);
}
public override SyntaxNode? VisitMemberAccessExpression(MemberAccessExpressionSyntax node)
{
var isSpmdOrWideLane = false;
var isFloatingPoint = false;
// 1. Check if the left-side expression is WideLane<...> or a tracked generic SPMD type
if (node.Expression is GenericNameSyntax genericName &&
genericName.Identifier.Text == "WideLane" &&
genericName.TypeArgumentList.Arguments.Count == 1)
@@ -188,13 +180,11 @@ namespace Misaki.HighPerformance.HPC.Generator
isSpmdOrWideLane = true;
var argTypeStr = genericName.TypeArgumentList.Arguments[0].ToString();
isFloatingPoint = argTypeStr == "float" || argTypeStr == "double";
}
else if (node.Expression is IdentifierNameSyntax idName &&
spmdTypes.TryGetValue(idName.Identifier.Text, out var mappedPrimType))
{
isSpmdOrWideLane = true;
isFloatingPoint = mappedPrimType == "float" || mappedPrimType == "double";
}
if (isSpmdOrWideLane)
@@ -213,7 +203,7 @@ namespace Misaki.HighPerformance.HPC.Generator
if (s_remapMath.TryGetValue(node.Name.Identifier.Text, out var instruction))
{
var rewritResult = RewriteMathExpression(instruction, isFloatingPoint);
var rewritResult = RewriteMathExpression(instruction);
return SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
SyntaxFactory.IdentifierName(rewritResult.Expression),
@@ -229,7 +219,7 @@ namespace Misaki.HighPerformance.HPC.Generator
{
if (node.Expression is MemberAccessExpressionSyntax memberAccess)
{
bool isSpmdOrWideLane = false;
var isSpmdOrWideLane = false;
if (memberAccess.Expression is GenericNameSyntax genericName
&& genericName.Identifier.Text == "WideLane"
@@ -268,7 +258,78 @@ namespace Misaki.HighPerformance.HPC.Generator
return base.VisitInvocationExpression(node);
}
protected abstract MathExpression RewriteMathExpression(SIMDInstruction instruction, bool isFloatingPoint);
public override SyntaxNode? VisitBinaryExpression(BinaryExpressionSyntax node)
{
var type = GetHpcPrimitiveType(node);
var ifFloatingPoint = type == "float" || type == "double";
// Optimize (a * b) + c -> MultiplyAdd(a, b, c)
if (node.IsKind(SyntaxKind.AddExpression))
{
var typeInfo = semanticModel.GetTypeInfo(node);
if (IsKnownHpcType(typeInfo.Type) && ifFloatingPoint)
{
if (node.Left.IsKind(SyntaxKind.MultiplyExpression))
{
var mulNode = (BinaryExpressionSyntax)node.Left;
var a = (ExpressionSyntax)Visit(mulNode.Left)!;
var b = (ExpressionSyntax)Visit(mulNode.Right)!;
var c = (ExpressionSyntax)Visit(node.Right)!;
// Assuming floating point by default for FMA, though you can expand this logic
return InvokeMathRewrite(SIMDInstruction.MultiplyAdd, a, b, c).WithTriviaFrom(node);
}
if (node.Right.IsKind(SyntaxKind.MultiplyExpression) && ifFloatingPoint)
{
var mulNode = (BinaryExpressionSyntax)node.Right;
var c = (ExpressionSyntax)Visit(node.Left)!;
var a = (ExpressionSyntax)Visit(mulNode.Left)!;
var b = (ExpressionSyntax)Visit(mulNode.Right)!;
return InvokeMathRewrite(SIMDInstruction.MultiplyAdd, a, b, c).WithTriviaFrom(node);
}
}
}
return base.VisitBinaryExpression(node);
}
protected ExpressionSyntax InvokeMathRewrite(SIMDInstruction instruction, params ExpressionSyntax[] args)
{
var rewriteResult = RewriteMathExpression(instruction);
var finalArgs = new ArgumentSyntax[args.Length];
// Reorder arguments if the instruction set backend specifies an order
if (rewriteResult.ArgumentOrder != null)
{
for (var i = 0; i < rewriteResult.ArgumentOrder.Length; i++)
{
finalArgs[i] = SyntaxFactory.Argument(args[rewriteResult.ArgumentOrder[i]]);
}
}
else
{
for (var i = 0; i < args.Length; i++)
{
finalArgs[i] = SyntaxFactory.Argument(args[i]);
}
}
return SyntaxFactory.InvocationExpression(
SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
SyntaxFactory.IdentifierName(rewriteResult.Expression),
SyntaxFactory.IdentifierName(rewriteResult.Name)
),
SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(finalArgs))
);
}
protected abstract MathExpression RewriteMathExpression(SIMDInstruction instruction);
protected abstract void RewriteMathArguments(SIMDInstruction instruction, Span<ArgumentSyntax> originalArgs);
}
}

View File

@@ -21,6 +21,11 @@ namespace Misaki.HighPerformance.HPC.Generator
get; set;
} = null!;
public SemanticModel SemanticModel
{
get; set;
} = null!;
public TargetInstructionSet InstructionSet
{
get; set;
@@ -55,6 +60,7 @@ namespace Misaki.HighPerformance.HPC.Generator
{
MethodDeclaration = (MethodDeclarationSyntax)ctx.TargetNode,
MethodSymbol = methodSymbol,
SemanticModel = ctx.SemanticModel,
InstructionSet = (TargetInstructionSet)attributes.ConstructorArguments[0].Value!,
Precision = (FloatPrecision)attributes.ConstructorArguments[1].Value!,
Mode = (MathMode)attributes.ConstructorArguments[2].Value!,
@@ -82,7 +88,7 @@ namespace Misaki.HighPerformance.HPC.Generator
continue;
}
var rewriters = HPCRewriter.GetRewriter(info.InstructionSet);
var rewriters = HPCRewriter.GetRewriter(info.InstructionSet, info.SemanticModel);
foreach (var writer in rewriters)
{

View File

@@ -1,40 +1,46 @@
using Misaki.HighPerformance.HPC.Generator.VectorAPI;
using Misaki.HighPerformance.HPC.Generator.APIContext;
using System.Text;
namespace Misaki.HighPerformance.HPC.Generator
{
internal static class UtilityTemplate
{
public static Method SinFloat_Standard(IVectorAPIContext api)
public static Method Sin_Standard<T>(IVectorAPIContext api)
{
var body = api.Return(api.Call("Sin", "value"));
var body = api.Return(api.Call($"{api.GetVectorType()}.Sin", "value"));
return new Method(
modifier: "public static",
returnType: api.GetVectorType<float>(),
name: $"SinFloat_Standard",
parameters: new[] { $"{api.GetVectorType<float>()} value" },
returnType: api.GetVectorType<T>(),
name: $"Sin_{typeof(T).Name}_Standard",
parameters: new[] { $"{api.GetVectorType<T>()} value" },
body: body);
}
public static Method SinFloat_Fast(IVectorAPIContext api)
public static Method Sin_Fast<T>(IVectorAPIContext api)
{
var invPi = api.Create("0.318309886f").Assign();
var isFloat = typeof(T) == typeof(float);
var typePrefix = isFloat ? "f" : "d";
var x_sin = new Expression(api, "value").Assign();
var input = new Expression(api, "value");
var invPi = api.Create($"0.318309886{typePrefix}").Assign();
var x_sin = input;
var y_sin = api.Multiply(x_sin, invPi).Assign();
var k_sin = api.Round(y_sin).Assign();
var z_sin = api.Subtract(y_sin, k_sin).Assign();
var half = api.Create("0.5f").Assign();
var two = api.Create("2.0f").Assign();
var half = api.Create($"0.5{typePrefix}").Assign();
var two = api.Create($"2.0{typePrefix}").Assign();
var k_even_sin = (api.Round(k_sin * half) * two).Assign();
var sign_sin = (api.One<float>() - two * api.Abs(k_sin - k_even_sin)).Assign();
var sign_sin = (api.One<T>() - two * api.Abs(k_sin - k_even_sin)).Assign();
var c1 = api.Create("3.14159265f").Assign();
var c3 = api.Create("-5.16771278f").Assign();
var c5 = api.Create("2.55016404f").Assign();
var c7 = api.Create("-0.59926453f").Assign();
var c9 = api.Create("0.08214589f").Assign();
var c1 = api.Create($"3.14159265{typePrefix}").Assign();
var c3 = api.Create($"-5.16771278{typePrefix}").Assign();
var c5 = api.Create($"2.55016404{typePrefix}").Assign();
var c7 = api.Create($"-0.59926453{typePrefix}").Assign();
var c9 = api.Create($"0.08214589{typePrefix}").Assign();
var z2_sin = (z_sin * z_sin).Assign();
var poly_sin = api.MultiplyAdd(z2_sin, c9, c7).Assign();
@@ -49,10 +55,179 @@ namespace Misaki.HighPerformance.HPC.Generator
return new Method(
modifier: "public static",
returnType: api.GetVectorType<float>(),
name: $"SinFloat_Fast",
parameters: new[] { $"{api.GetVectorType<float>()} value" },
returnType: api.GetVectorType<T>(),
name: $"Sin_{typeof(T).Name}_Fast",
parameters: new[] { $"{api.GetVectorType<T>()} {input.Code}" },
body: body);
}
public static Method Cos_Standard<T>(IVectorAPIContext api)
{
var body = api.Return(api.Call($"{api.GetVectorType()}.Cos", "value"));
return new Method(
modifier: "public static",
returnType: api.GetVectorType<T>(),
name: $"Cos_{typeof(T).Name}_Standard",
parameters: new[] { $"{api.GetVectorType<T>()} value" },
body: body);
}
public static Method Cos_Fast<T>(IVectorAPIContext api)
{
var isFloat = typeof(T) == typeof(float);
var typePrefix = isFloat ? "f" : "d";
var input = new Expression(api, "value");
var halfPi = api.Create($"1.570796327{typePrefix}").Assign();
var invPi = api.Create($"0.318309886{typePrefix}").Assign();
var x_cos = api.Add(input, halfPi).Assign();
var y_cos = api.Multiply(x_cos, invPi).Assign();
var k_cos = api.Round(y_cos).Assign();
var z_cos = api.Subtract(y_cos, k_cos).Assign();
var half = api.Create($"0.5{typePrefix}").Assign();
var two = api.Create($"2.0{typePrefix}").Assign();
var k_even_cos = api.Multiply(api.Round(api.Multiply(k_cos, half)), two).Assign();
var sign_cos = api.Subtract(api.One<T>(), api.Multiply(two, api.Abs(api.Subtract(k_cos, k_even_cos)))).Assign();
var c1 = api.Create($"3.14159265{typePrefix}").Assign();
var c3 = api.Create($"-5.16771278{typePrefix}").Assign();
var c5 = api.Create($"2.55016404{typePrefix}").Assign();
var c7 = api.Create($"-0.59926453{typePrefix}").Assign();
var c9 = api.Create($"0.08214589{typePrefix}").Assign();
var z2_cos = api.Multiply(z_cos, z_cos).Assign();
var poly_cos = api.MultiplyAdd(z2_cos, c9, c7).Assign();
var poly_cos_name = api.LastAssignedVariable;
poly_cos = api.MultiplyAdd(z2_cos, poly_cos, c5).Assign(poly_cos_name, false);
poly_cos = api.MultiplyAdd(z2_cos, poly_cos, c3).Assign(poly_cos_name, false);
poly_cos = api.MultiplyAdd(z2_cos, poly_cos, c1).Assign(poly_cos_name, false);
poly_cos = api.Multiply(z_cos, poly_cos).Assign(poly_cos_name, false);
var body = api.Return(poly_cos * sign_cos);
return new Method(
modifier: "public static",
returnType: api.GetVectorType<T>(),
name: $"Cos_{typeof(T).Name}_Fast",
parameters: new[] { $"{api.GetVectorType<T>()} {input.Code}" },
body: body);
}
public static Method SinCos_Standard<T>(IVectorAPIContext api)
{
var sin_cos = api.Return(api.Call($"{api.GetVectorType()}.SinCos", "value"));
return new Method(
modifier: "public static",
returnType: "void",
name: $"SinCos_{typeof(T).Name}_Standard",
parameters: new[] { $"{api.GetVectorType<T>()} value", $"out {api.GetVectorType<T>()} sin", $"out {api.GetVectorType<T>()} cos" },
body: sin_cos);
}
public static Method SinCos_Fast<T>(IVectorAPIContext api)
{
var isFloat = typeof(T) == typeof(float);
var typePrefix = isFloat ? "f" : "d";
var input = new Expression(api, "value");
var sinOut = new Expression(api, "sin");
var cosOut = new Expression(api, "cos");
var halfPi = api.Create($"1.570796327{typePrefix}").Assign();
var invPi = api.Create($"0.318309886{typePrefix}").Assign();
var x_sin = input;
var x_cos = api.Add(x_sin, halfPi).Assign();
var y_sin = api.Multiply(x_sin, invPi).Assign();
var y_cos = api.Multiply(x_cos, invPi).Assign();
var k_sin = api.Round(y_sin).Assign();
var k_cos = api.Round(y_cos).Assign();
var z_sin = api.Subtract(y_sin, k_sin).Assign();
var z_cos = api.Subtract(y_cos, k_cos).Assign();
var half = api.Create($"0.5{typePrefix}").Assign();
var two = api.Create($"2.0{typePrefix}").Assign();
var one = api.One<T>();
var k_even_sin = api.Multiply(api.Round(api.Multiply(k_sin, half)), two).Assign();
var sign_sin = api.Subtract(one, api.Multiply(two, api.Abs(api.Subtract(k_sin, k_even_sin)))).Assign();
var k_even_cos = api.Multiply(api.Round(api.Multiply(k_cos, half)), two).Assign();
var sign_cos = api.Subtract(one, api.Multiply(two, api.Abs(api.Subtract(k_cos, k_even_cos)))).Assign();
var c1 = api.Create($"3.14159265{typePrefix}").Assign();
var c3 = api.Create($"-5.16771278{typePrefix}").Assign();
var c5 = api.Create($"2.55016404{typePrefix}").Assign();
var c7 = api.Create($"-0.59926453{typePrefix}").Assign();
var c9 = api.Create($"0.08214589{typePrefix}").Assign();
var z2_sin = api.Multiply(z_sin, z_sin).Assign();
var poly_sin = api.MultiplyAdd(z2_sin, c9, c7).Assign();
var poly_sin_name = api.LastAssignedVariable;
poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c5).Assign(poly_sin_name, false);
poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c3).Assign(poly_sin_name, false);
poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c1).Assign(poly_sin_name, false);
poly_sin = api.Multiply(z_sin, poly_sin).Assign(poly_sin_name, false);
var z2_cos = api.Multiply(z_cos, z_cos).Assign();
var poly_cos = api.MultiplyAdd(z2_cos, c9, c7).Assign();
var poly_cos_name = api.LastAssignedVariable;
poly_cos = api.MultiplyAdd(z2_cos, poly_cos, c5).Assign(poly_cos_name, false);
poly_cos = api.MultiplyAdd(z2_cos, poly_cos, c3).Assign(poly_cos_name, false);
poly_cos = api.MultiplyAdd(z2_cos, poly_cos, c1).Assign(poly_cos_name, false);
poly_cos = api.Multiply(z_cos, poly_cos).Assign(poly_cos_name, false);
sinOut = api.Multiply(poly_sin, sign_sin).Assign(sinOut.Code, false);
cosOut = api.Multiply(poly_cos, sign_cos).Assign(cosOut.Code, false);
var body = api.Return(api.Create(""));
return new Method(
modifier: "public static",
returnType: "void",
name: $"SinCos_{typeof(T).Name}_Fast",
parameters: new[] { $"{api.GetVectorType<T>()} {input.Code}", $"out {api.GetVectorType<T>()} {sinOut.Code}", $"out {api.GetVectorType<T>()} {cosOut.Code}" },
body: body);
}
public static string GenerateSinCosUtilityMethods(IVectorAPIContext api, string identation)
{
var methods = new Method[]
{
Sin_Standard<float>(api),
Sin_Fast<float>(api),
Cos_Standard<float>(api),
Cos_Fast<float>(api),
SinCos_Standard<float>(api),
SinCos_Fast<float>(api),
Sin_Standard<double>(api),
Sin_Fast<double>(api),
Cos_Standard<double>(api),
Cos_Fast<double>(api),
SinCos_Standard<double>(api),
SinCos_Fast<double>(api)
};
var sb = new StringBuilder();
var inlineAttr = identation + "[MethodImpl(MethodImplOptions.AggressiveInlining)]";
foreach (var method in methods)
{
sb.AppendLine(inlineAttr);
sb.AppendLine(method.GetFullCode(identation));
}
return sb.ToString();
}
}
}