Refactor: switch to IR-based HPC codegen pipeline
Major rewrite replacing Roslyn syntax rewriters with an intermediate representation (IR) architecture. Adds IR nodes, analyzer, type resolver, and node rewriter base for optimization passes (e.g., FMA fusion). Refactors AVX2 backend to emit from IR and updates generator pipeline for analysis, optimization, and emission. Removes legacy rewriter classes. Adds polyfills for modern C# features and updates tests and project settings for latest language version. This enables advanced optimizations, easier backend targeting, and future extensibility.
This commit is contained in:
@@ -1,12 +1,19 @@
|
||||
using Microsoft.CodeAnalysis;
|
||||
using Microsoft.CodeAnalysis.CSharp.Syntax;
|
||||
using Misaki.HighPerformance.HPC.Generator.APIContext;
|
||||
using System;
|
||||
|
||||
namespace Misaki.HighPerformance.HPC.Generator
|
||||
{
|
||||
/// <summary>
|
||||
/// Generates the <c>AVX2Utility</c> static class containing polynomial
|
||||
/// approximations for transcendental functions (Sin, Cos, SinCos, etc.)
|
||||
/// that have no built-in AVX2 hardware intrinsic.
|
||||
///
|
||||
/// <para>These methods are called by the <c>AVX2Backend</c> emitter when it
|
||||
/// encounters <see cref="IR.HPCUnaryKind.Sin"/>, <see cref="IR.HPCUnaryKind.Cos"/>,
|
||||
/// and similar IR nodes.</para>
|
||||
/// </summary>
|
||||
[Generator]
|
||||
internal class AVX2UtilityGenerator : IIncrementalGenerator
|
||||
public class AVX2UtilityGenerator : IIncrementalGenerator
|
||||
{
|
||||
public void Initialize(IncrementalGeneratorInitializationContext context)
|
||||
{
|
||||
@@ -29,76 +36,8 @@ namespace Misaki.HighPerformance.HPC
|
||||
{sinCosMethods}
|
||||
}}
|
||||
}}";
|
||||
|
||||
ctx.AddSource("AVX2Utility.g.cs", source);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
internal class AVX2Rewriter : HPCRewriter
|
||||
{
|
||||
public AVX2Rewriter(SemanticModel semanticModel)
|
||||
: base(semanticModel)
|
||||
{
|
||||
}
|
||||
|
||||
public override string Name => "AVX2";
|
||||
|
||||
public override string GetNesessaryUsing()
|
||||
{
|
||||
return "using System.Runtime.Intrinsics;\nusing System.Runtime.Intrinsics.X86;";
|
||||
}
|
||||
|
||||
protected override void RewriteMathArguments(SIMDInstruction instruction, Span<ArgumentSyntax> originalArgs)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
protected override MathExpression RewriteMathExpression(SIMDInstruction instruction)
|
||||
{
|
||||
switch (instruction)
|
||||
{
|
||||
case SIMDInstruction.Add:
|
||||
return new MathExpression
|
||||
{
|
||||
Expression = "Avx2",
|
||||
Name = "Add"
|
||||
};
|
||||
case SIMDInstruction.Subtract:
|
||||
return new MathExpression
|
||||
{
|
||||
Expression = "Avx2",
|
||||
Name = "Subtract"
|
||||
};
|
||||
case SIMDInstruction.Multiply:
|
||||
return new MathExpression
|
||||
{
|
||||
Expression = "Avx2",
|
||||
Name = "Multiply"
|
||||
};
|
||||
case SIMDInstruction.MultiplyAdd:
|
||||
return new MathExpression
|
||||
{
|
||||
Expression = "Fma",
|
||||
Name = "MultiplyAdd"
|
||||
};
|
||||
case SIMDInstruction.Asin:
|
||||
return new MathExpression
|
||||
{
|
||||
Expression = "AVX2Utility",
|
||||
Name = "Asin"
|
||||
};
|
||||
case SIMDInstruction.Atan2:
|
||||
return new MathExpression
|
||||
{
|
||||
Expression = "AVX2Utility",
|
||||
Name = "Atan2"
|
||||
};
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return default;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
461
Misaki.HighPerformance.HPC.Generator/Analysis/HPCAnalyzer.cs
Normal file
461
Misaki.HighPerformance.HPC.Generator/Analysis/HPCAnalyzer.cs
Normal file
@@ -0,0 +1,461 @@
|
||||
using Microsoft.CodeAnalysis;
|
||||
using Microsoft.CodeAnalysis.CSharp;
|
||||
using Microsoft.CodeAnalysis.CSharp.Syntax;
|
||||
using Misaki.HighPerformance.HPC.Generator.IR;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace Misaki.HighPerformance.HPC.Generator.Analysis
|
||||
{
|
||||
/// <summary>
|
||||
/// Walks a Roslyn <see cref="MethodDeclarationSyntax"/> and produces an
|
||||
/// <see cref="HPCMethodIR"/> that is completely independent of Roslyn after
|
||||
/// this phase completes.
|
||||
///
|
||||
/// <para>Uses <see cref="CSharpSyntaxWalker"/> (read-only walk) rather than
|
||||
/// <see cref="CSharpSyntaxRewriter"/> so that the semantic model is never
|
||||
/// consulted during rewriting — all queries happen here, once, upfront.</para>
|
||||
/// </summary>
|
||||
internal sealed class HPCAnalyzer : CSharpSyntaxWalker
|
||||
{
|
||||
// ── State ─────────────────────────────────────────────────────────────
|
||||
|
||||
private readonly HPCTypeResolver _typeResolver;
|
||||
|
||||
/// <summary>Statement stack — top is the current block being built.</summary>
|
||||
private readonly Stack<List<HPCStmt>> _blocks = new();
|
||||
|
||||
// ── Construction ─────────────────────────────────────────────────────
|
||||
|
||||
public HPCAnalyzer(SemanticModel semanticModel)
|
||||
{
|
||||
_typeResolver = new HPCTypeResolver(semanticModel);
|
||||
}
|
||||
|
||||
// ── Public API ───────────────────────────────────────────────────────
|
||||
|
||||
/// <summary>
|
||||
/// Analyses <paramref name="method"/> and returns the completed IR.
|
||||
/// </summary>
|
||||
public HPCMethodIR Analyze(MethodDeclarationSyntax method, HPComputeMethodInfo info)
|
||||
{
|
||||
_typeResolver.RegisterConstraints(method);
|
||||
_blocks.Clear();
|
||||
|
||||
// Push the root block
|
||||
_blocks.Push(new List<HPCStmt>());
|
||||
|
||||
if (method.Body is not null)
|
||||
Visit(method.Body);
|
||||
else if (method.ExpressionBody is not null)
|
||||
{
|
||||
// Expression-bodied method: treat as `return <expr>;`
|
||||
var expr = AnalyzeExpression(method.ExpressionBody.Expression);
|
||||
var returnType = _typeResolver.Resolve(method.ExpressionBody.Expression)
|
||||
?? HPCType.Float;
|
||||
Emit(new HPCReturn(expr));
|
||||
}
|
||||
|
||||
var body = _blocks.Pop().ToArray();
|
||||
|
||||
return new HPCMethodIR
|
||||
{
|
||||
Name = method.Identifier.Text,
|
||||
OriginalName = method.Identifier.Text,
|
||||
ReturnType = ResolveReturnType(method),
|
||||
Parameters = BuildParameters(method),
|
||||
Body = body,
|
||||
TargetISA = info.InstructionSet,
|
||||
Precision = info.Precision,
|
||||
Mode = info.Mode,
|
||||
ContainingNamespace = info.MethodSymbol.ContainingNamespace.ToDisplayString(),
|
||||
ContainingTypeName = info.MethodSymbol.ContainingType.Name,
|
||||
};
|
||||
}
|
||||
|
||||
// ── Statement visitors ───────────────────────────────────────────────
|
||||
|
||||
public override void VisitBlock(BlockSyntax node)
|
||||
{
|
||||
// Visit each statement inside the block in order
|
||||
foreach (var stmt in node.Statements)
|
||||
Visit(stmt);
|
||||
}
|
||||
|
||||
public override void VisitLocalDeclarationStatement(LocalDeclarationStatementSyntax node)
|
||||
{
|
||||
foreach (var declarator in node.Declaration.Variables)
|
||||
{
|
||||
if (declarator.Initializer is null)
|
||||
continue;
|
||||
|
||||
var init = AnalyzeExpression(declarator.Initializer.Value);
|
||||
var hpcType = _typeResolver.Resolve(declarator.Initializer.Value)
|
||||
?? InferFromExpr(init);
|
||||
|
||||
Emit(new HPCVarDecl(declarator.Identifier.Text, hpcType, init));
|
||||
}
|
||||
}
|
||||
|
||||
public override void VisitExpressionStatement(ExpressionStatementSyntax node)
|
||||
{
|
||||
if (node.Expression is AssignmentExpressionSyntax assign)
|
||||
{
|
||||
var target = assign.Left.ToString(); // simple name or member access text
|
||||
var value = AnalyzeExpression(assign.Right);
|
||||
Emit(new HPCAssignment(target, value));
|
||||
}
|
||||
else
|
||||
{
|
||||
var expr = AnalyzeExpression(node.Expression);
|
||||
Emit(new HPCExprStmt(expr));
|
||||
}
|
||||
}
|
||||
|
||||
public override void VisitReturnStatement(ReturnStatementSyntax node)
|
||||
{
|
||||
var value = node.Expression is null ? null : AnalyzeExpression(node.Expression);
|
||||
Emit(new HPCReturn(value));
|
||||
}
|
||||
|
||||
public override void VisitIfStatement(IfStatementSyntax node)
|
||||
{
|
||||
var condition = AnalyzeExpression(node.Condition);
|
||||
|
||||
_blocks.Push(new List<HPCStmt>());
|
||||
Visit(node.Statement);
|
||||
var thenBody = _blocks.Pop().ToArray();
|
||||
|
||||
HPCStmt[]? elseBody = null;
|
||||
if (node.Else is not null)
|
||||
{
|
||||
_blocks.Push(new List<HPCStmt>());
|
||||
Visit(node.Else.Statement);
|
||||
elseBody = _blocks.Pop().ToArray();
|
||||
}
|
||||
|
||||
Emit(new HPCIf(condition, thenBody, elseBody));
|
||||
}
|
||||
|
||||
public override void VisitForStatement(ForStatementSyntax node)
|
||||
{
|
||||
// Only handles the simple canonical for (var i = init; i <cond>; i++) form.
|
||||
if (node.Declaration?.Variables.Count != 1 ||
|
||||
node.Condition is null ||
|
||||
node.Incrementors.Count != 1)
|
||||
{
|
||||
// Fall back: treat body as pass-through
|
||||
_blocks.Push(new List<HPCStmt>());
|
||||
Visit(node.Statement);
|
||||
var rawBody = _blocks.Pop().ToArray();
|
||||
foreach (var s in rawBody)
|
||||
Emit(s);
|
||||
return;
|
||||
}
|
||||
|
||||
var decl = node.Declaration.Variables[0];
|
||||
var initExpr = decl.Initializer is not null
|
||||
? AnalyzeExpression(decl.Initializer.Value)
|
||||
: new HPCLiteral("0", HPCType.Int);
|
||||
var iterDecl = new HPCVarDecl(decl.Identifier.Text, HPCType.Int, initExpr);
|
||||
var cond = AnalyzeExpression(node.Condition);
|
||||
var incr = AnalyzeExpression(node.Incrementors[0]);
|
||||
|
||||
_blocks.Push(new List<HPCStmt>());
|
||||
Visit(node.Statement);
|
||||
var body = _blocks.Pop().ToArray();
|
||||
|
||||
Emit(new HPCForLoop(iterDecl, cond, incr, body));
|
||||
}
|
||||
|
||||
// ── Expression analysis ───────────────────────────────────────────────
|
||||
|
||||
private HPCExpr AnalyzeExpression(ExpressionSyntax expr) => expr switch
|
||||
{
|
||||
BinaryExpressionSyntax n => AnalyzeBinary(n),
|
||||
PrefixUnaryExpressionSyntax n => AnalyzePrefixUnary(n),
|
||||
PostfixUnaryExpressionSyntax n => AnalyzePostfixUnary(n),
|
||||
InvocationExpressionSyntax n => AnalyzeInvocation(n),
|
||||
MemberAccessExpressionSyntax n => AnalyzeMemberAccess(n),
|
||||
LiteralExpressionSyntax n => AnalyzeLiteral(n),
|
||||
IdentifierNameSyntax n => AnalyzeIdentifier(n),
|
||||
ParenthesizedExpressionSyntax n => AnalyzeExpression(n.Expression),
|
||||
CastExpressionSyntax n => AnalyzeCast(n),
|
||||
_ => FallbackExpr(expr)
|
||||
};
|
||||
|
||||
private HPCExpr AnalyzeBinary(BinaryExpressionSyntax node)
|
||||
{
|
||||
var left = AnalyzeExpression(node.Left);
|
||||
var right = AnalyzeExpression(node.Right);
|
||||
var hpcType = _typeResolver.Resolve(node) ?? InferFromExpr(left);
|
||||
|
||||
var kind = node.Kind() switch
|
||||
{
|
||||
SyntaxKind.AddExpression => HPCBinaryKind.Add,
|
||||
SyntaxKind.SubtractExpression => HPCBinaryKind.Subtract,
|
||||
SyntaxKind.MultiplyExpression => HPCBinaryKind.Multiply,
|
||||
SyntaxKind.DivideExpression => HPCBinaryKind.Divide,
|
||||
SyntaxKind.ModuloExpression => HPCBinaryKind.Modulo,
|
||||
SyntaxKind.BitwiseAndExpression => HPCBinaryKind.BitwiseAnd,
|
||||
SyntaxKind.BitwiseOrExpression => HPCBinaryKind.BitwiseOr,
|
||||
SyntaxKind.ExclusiveOrExpression => HPCBinaryKind.BitwiseXor,
|
||||
SyntaxKind.LeftShiftExpression => HPCBinaryKind.ShiftLeft,
|
||||
SyntaxKind.RightShiftExpression => HPCBinaryKind.ShiftRight,
|
||||
SyntaxKind.EqualsExpression => HPCBinaryKind.Equal,
|
||||
SyntaxKind.NotEqualsExpression => HPCBinaryKind.NotEqual,
|
||||
SyntaxKind.LessThanExpression => HPCBinaryKind.LessThan,
|
||||
SyntaxKind.LessThanOrEqualExpression => HPCBinaryKind.LessThanOrEqual,
|
||||
SyntaxKind.GreaterThanExpression => HPCBinaryKind.GreaterThan,
|
||||
SyntaxKind.GreaterThanOrEqualExpression => HPCBinaryKind.GreaterThanOrEqual,
|
||||
_ => throw new NotSupportedException($"Binary operator not supported: {node.Kind()}")
|
||||
};
|
||||
|
||||
return new HPCBinaryOp(kind, left, right, hpcType);
|
||||
}
|
||||
|
||||
private HPCExpr AnalyzePrefixUnary(PrefixUnaryExpressionSyntax node)
|
||||
{
|
||||
var operand = AnalyzeExpression(node.Operand);
|
||||
var hpcType = _typeResolver.Resolve(node) ?? InferFromExpr(operand);
|
||||
|
||||
var kind = node.Kind() switch
|
||||
{
|
||||
SyntaxKind.UnaryMinusExpression => HPCUnaryKind.Negate,
|
||||
SyntaxKind.BitwiseNotExpression => HPCUnaryKind.BitwiseNot,
|
||||
_ => throw new NotSupportedException($"Prefix unary not supported: {node.Kind()}")
|
||||
};
|
||||
|
||||
return new HPCUnaryOp(kind, operand, hpcType);
|
||||
}
|
||||
|
||||
private HPCExpr AnalyzePostfixUnary(PostfixUnaryExpressionSyntax node)
|
||||
{
|
||||
// i++ / i-- — treat as i + 1 / i - 1 in an IR context (no mutation semantics)
|
||||
var operand = AnalyzeExpression(node.Operand);
|
||||
var hpcType = InferFromExpr(operand);
|
||||
var one = new HPCLiteral("1", hpcType);
|
||||
|
||||
return node.Kind() == SyntaxKind.PostIncrementExpression
|
||||
? new HPCBinaryOp(HPCBinaryKind.Add, operand, one, hpcType)
|
||||
: new HPCBinaryOp(HPCBinaryKind.Subtract, operand, one, hpcType);
|
||||
}
|
||||
|
||||
private HPCExpr AnalyzeInvocation(InvocationExpressionSyntax node)
|
||||
{
|
||||
// ── Instance method on SPMD lane: lane.Add(b), lane.MultiplyAdd(a,b,c) ──
|
||||
if (node.Expression is MemberAccessExpressionSyntax memberAccess)
|
||||
{
|
||||
var receiverType = _typeResolver.Resolve(memberAccess.Expression);
|
||||
var methodName = memberAccess.Name.Identifier.Text;
|
||||
|
||||
if (receiverType != null)
|
||||
{
|
||||
// Unary math on the receiver: lane.Abs(), lane.Sqrt(), etc.
|
||||
if (TryMapUnaryMethod(methodName, out var unaryKind) &&
|
||||
node.ArgumentList.Arguments.Count == 0)
|
||||
{
|
||||
var receiver = AnalyzeExpression(memberAccess.Expression);
|
||||
return new HPCUnaryOp(unaryKind, receiver, receiverType);
|
||||
}
|
||||
|
||||
// Multi-arg static-style methods called as instance methods via ISPMDLane
|
||||
if (TryMapIntrinsicMethod(methodName, out var intrinsic))
|
||||
{
|
||||
var args = new List<HPCExpr> { AnalyzeExpression(memberAccess.Expression) };
|
||||
foreach (var arg in node.ArgumentList.Arguments)
|
||||
args.Add(AnalyzeExpression(arg.Expression));
|
||||
return new HPCIntrinsicCall(intrinsic, args.ToArray(), receiverType);
|
||||
}
|
||||
}
|
||||
|
||||
// Static call on the lane type: TSelf.Sin(v), TSelf.Load(ref p), etc.
|
||||
if (memberAccess.Expression is IdentifierNameSyntax typeIdent &&
|
||||
_typeResolver.IsSPMDTypeParam(typeIdent.Identifier.Text))
|
||||
{
|
||||
var elemType = _typeResolver.ResolveByName(typeIdent.Identifier.Text) ?? HPCType.Float;
|
||||
|
||||
if (TryMapUnaryMethod(methodName, out var unaryKind) &&
|
||||
node.ArgumentList.Arguments.Count == 1)
|
||||
{
|
||||
var arg = AnalyzeExpression(node.ArgumentList.Arguments[0].Expression);
|
||||
return new HPCUnaryOp(unaryKind, arg, elemType);
|
||||
}
|
||||
|
||||
if (TryMapIntrinsicMethod(methodName, out var intrinsic))
|
||||
{
|
||||
var args = new List<HPCExpr>();
|
||||
foreach (var arg in node.ArgumentList.Arguments)
|
||||
args.Add(AnalyzeExpression(arg.Expression));
|
||||
return new HPCIntrinsicCall(intrinsic, args.ToArray(), elemType);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Fall through: unknown call, emit as pass-through ──────────────
|
||||
return FallbackExpr(node);
|
||||
}
|
||||
|
||||
private HPCExpr AnalyzeMemberAccess(MemberAccessExpressionSyntax node)
|
||||
{
|
||||
var receiverType = _typeResolver.Resolve(node.Expression);
|
||||
var memberName = node.Name.Identifier.Text;
|
||||
|
||||
if (receiverType != null)
|
||||
{
|
||||
// LaneWidth → Count (property remap)
|
||||
var mapped = memberName switch
|
||||
{
|
||||
"LaneWidth" => "Count",
|
||||
_ => memberName
|
||||
};
|
||||
var receiver = AnalyzeExpression(node.Expression);
|
||||
return new HPCPropertyAccess(receiver, mapped, receiverType);
|
||||
}
|
||||
|
||||
return FallbackExpr(node);
|
||||
}
|
||||
|
||||
private static HPCExpr AnalyzeLiteral(LiteralExpressionSyntax node)
|
||||
{
|
||||
var text = node.Token.ValueText;
|
||||
var hpcType = node.Kind() switch
|
||||
{
|
||||
SyntaxKind.NumericLiteralExpression when node.Token.Value is float => HPCType.Float,
|
||||
SyntaxKind.NumericLiteralExpression when node.Token.Value is double => HPCType.Double,
|
||||
SyntaxKind.NumericLiteralExpression when node.Token.Value is int => HPCType.Int,
|
||||
SyntaxKind.NumericLiteralExpression when node.Token.Value is long => HPCType.Long,
|
||||
_ => HPCType.Float
|
||||
};
|
||||
return new HPCLiteral(text, hpcType);
|
||||
}
|
||||
|
||||
private HPCExpr AnalyzeIdentifier(IdentifierNameSyntax node)
|
||||
{
|
||||
var name = node.Identifier.Text;
|
||||
var hpcType = _typeResolver.Resolve(node)
|
||||
?? _typeResolver.ResolveByName(name)
|
||||
?? HPCType.Float;
|
||||
return new HPCVarRef(name, hpcType);
|
||||
}
|
||||
|
||||
private HPCExpr AnalyzeCast(CastExpressionSyntax node)
|
||||
{
|
||||
var operand = AnalyzeExpression(node.Expression);
|
||||
var hpcType = _typeResolver.Resolve(node) ?? HPCType.Float;
|
||||
return new HPCIntrinsicCall(HPCIntrinsic.Cast, [operand], hpcType);
|
||||
}
|
||||
|
||||
private HPCExpr FallbackExpr(ExpressionSyntax node)
|
||||
{
|
||||
// Preserve anything we don't understand as a verbatim pass-through
|
||||
var hpcType = _typeResolver.Resolve(node) ?? HPCType.Float;
|
||||
return new HPCPassThroughCall(
|
||||
node.ToString(),
|
||||
Array.Empty<HPCExpr>(),
|
||||
hpcType);
|
||||
}
|
||||
|
||||
// ── Method mapping tables ─────────────────────────────────────────────
|
||||
|
||||
private static readonly Dictionary<string, HPCUnaryKind> s_unaryMethods = new()
|
||||
{
|
||||
["Abs"] = HPCUnaryKind.Abs,
|
||||
["Sqrt"] = HPCUnaryKind.Sqrt,
|
||||
["Floor"] = HPCUnaryKind.Floor,
|
||||
["Ceil"] = HPCUnaryKind.Ceil,
|
||||
["Round"] = HPCUnaryKind.Round,
|
||||
["Trunc"] = HPCUnaryKind.Trunc,
|
||||
["Frac"] = HPCUnaryKind.Frac,
|
||||
["Sign"] = HPCUnaryKind.Sign,
|
||||
["Saturate"] = HPCUnaryKind.Saturate,
|
||||
["Rcp"] = HPCUnaryKind.Rcp,
|
||||
["Rsqrt"] = HPCUnaryKind.Rsqrt,
|
||||
["Sin"] = HPCUnaryKind.Sin,
|
||||
["Cos"] = HPCUnaryKind.Cos,
|
||||
["Tan"] = HPCUnaryKind.Tan,
|
||||
["Asin"] = HPCUnaryKind.Asin,
|
||||
["Acos"] = HPCUnaryKind.Acos,
|
||||
["Atan"] = HPCUnaryKind.Atan,
|
||||
["Exp"] = HPCUnaryKind.Exp,
|
||||
["Exp2"] = HPCUnaryKind.Exp2,
|
||||
["Log"] = HPCUnaryKind.Log,
|
||||
["Log2"] = HPCUnaryKind.Log2,
|
||||
};
|
||||
|
||||
private static readonly Dictionary<string, HPCIntrinsic> s_intrinsicMethods = new()
|
||||
{
|
||||
["MultiplyAdd"] = HPCIntrinsic.MultiplyAdd,
|
||||
["MultiplySubtract"] = HPCIntrinsic.MultiplySubtract,
|
||||
["Atan2"] = HPCIntrinsic.Atan2,
|
||||
["Pow"] = HPCIntrinsic.Pow,
|
||||
["SinCos"] = HPCIntrinsic.SinCos,
|
||||
["Lerp"] = HPCIntrinsic.Lerp,
|
||||
["Min"] = HPCIntrinsic.Min,
|
||||
["Max"] = HPCIntrinsic.Max,
|
||||
["Clamp"] = HPCIntrinsic.Clamp,
|
||||
["Select"] = HPCIntrinsic.Select,
|
||||
["CopySign"] = HPCIntrinsic.CopySign,
|
||||
["ReduceAdd"] = HPCIntrinsic.ReduceAdd,
|
||||
["ReduceMax"] = HPCIntrinsic.ReduceMax,
|
||||
["ReduceMin"] = HPCIntrinsic.ReduceMin,
|
||||
["Load"] = HPCIntrinsic.Load,
|
||||
["MaskLoad"] = HPCIntrinsic.MaskLoad,
|
||||
["Gather"] = HPCIntrinsic.Gather,
|
||||
["MaskGather"] = HPCIntrinsic.MaskGather,
|
||||
["Store"] = HPCIntrinsic.Store,
|
||||
["MaskStore"] = HPCIntrinsic.MaskStore,
|
||||
["Scatter"] = HPCIntrinsic.Scatter,
|
||||
["MaskScatter"] = HPCIntrinsic.MaskScatter,
|
||||
["CompressStore"] = HPCIntrinsic.CompressStore,
|
||||
};
|
||||
|
||||
private static bool TryMapUnaryMethod(string name, out HPCUnaryKind kind)
|
||||
=> s_unaryMethods.TryGetValue(name, out kind);
|
||||
|
||||
private static bool TryMapIntrinsicMethod(string name, out HPCIntrinsic intrinsic)
|
||||
=> s_intrinsicMethods.TryGetValue(name, out intrinsic);
|
||||
|
||||
// ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
private void Emit(HPCStmt stmt) => _blocks.Peek().Add(stmt);
|
||||
|
||||
private static HPCType InferFromExpr(HPCExpr expr) => expr.Type;
|
||||
|
||||
private HPCType ResolveReturnType(MethodDeclarationSyntax method)
|
||||
{
|
||||
var typeText = method.ReturnType.ToString();
|
||||
if (typeText == "void")
|
||||
return new HPCType("void", IsFloatingPoint: false);
|
||||
return _typeResolver.ResolveByName(typeText)
|
||||
?? HPCType.FromElementName(typeText);
|
||||
}
|
||||
|
||||
private HPCParameter[] BuildParameters(MethodDeclarationSyntax method)
|
||||
{
|
||||
var result = new List<HPCParameter>();
|
||||
foreach (var param in method.ParameterList.Parameters)
|
||||
{
|
||||
// Skip generic SPMD type parameters (they become the vector type in the backend)
|
||||
var paramTypeName = param.Type?.ToString() ?? "object";
|
||||
if (_typeResolver.IsSPMDTypeParam(paramTypeName))
|
||||
continue;
|
||||
|
||||
var hpcType = _typeResolver.ResolveByName(paramTypeName)
|
||||
?? HPCType.FromElementName(paramTypeName);
|
||||
|
||||
bool isOut = false, isRef = false;
|
||||
foreach (var mod in param.Modifiers)
|
||||
{
|
||||
if (mod.IsKind(SyntaxKind.OutKeyword))
|
||||
isOut = true;
|
||||
if (mod.IsKind(SyntaxKind.RefKeyword))
|
||||
isRef = true;
|
||||
}
|
||||
|
||||
result.Add(new HPCParameter(param.Identifier.Text, hpcType, isOut, isRef));
|
||||
}
|
||||
return result.ToArray();
|
||||
}
|
||||
}
|
||||
}
|
||||
133
Misaki.HighPerformance.HPC.Generator/Analysis/HPCTypeResolver.cs
Normal file
133
Misaki.HighPerformance.HPC.Generator/Analysis/HPCTypeResolver.cs
Normal file
@@ -0,0 +1,133 @@
|
||||
using Microsoft.CodeAnalysis;
|
||||
using Microsoft.CodeAnalysis.CSharp.Syntax;
|
||||
using Misaki.HighPerformance.HPC.Generator.IR;
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace Misaki.HighPerformance.HPC.Generator.Analysis
|
||||
{
|
||||
/// <summary>
|
||||
/// Resolves HPC-specific types from Roslyn's semantic model.
|
||||
/// Centralises all "what is the scalar element type of this expression?"
|
||||
/// logic that was previously duplicated across <c>HPCRewriter</c> and
|
||||
/// <c>HPCOptimizerRewriter</c>.
|
||||
/// </summary>
|
||||
internal sealed class HPCTypeResolver
|
||||
{
|
||||
private readonly SemanticModel _semanticModel;
|
||||
|
||||
/// <summary>
|
||||
/// Maps generic type-parameter names (e.g. <c>"TLane0"</c>) to their
|
||||
/// resolved scalar element types (e.g. <c>"float"</c>).
|
||||
/// Populated by <see cref="RegisterConstraints"/>.
|
||||
/// </summary>
|
||||
private readonly Dictionary<string, string> _typeParamToPrimitive = new();
|
||||
|
||||
public HPCTypeResolver(SemanticModel semanticModel)
|
||||
{
|
||||
_semanticModel = semanticModel;
|
||||
}
|
||||
|
||||
// ── Type-parameter registration ───────────────────────────────────────
|
||||
|
||||
/// <summary>
|
||||
/// Scans the generic constraints on <paramref name="method"/> and
|
||||
/// registers every type parameter constrained to
|
||||
/// <c>ISPMDLane<TSelf, TNumber></c>.
|
||||
/// Must be called before <see cref="Resolve"/> is used on the method body.
|
||||
/// </summary>
|
||||
public void RegisterConstraints(MethodDeclarationSyntax method)
|
||||
{
|
||||
_typeParamToPrimitive.Clear();
|
||||
|
||||
foreach (var clause in method.ConstraintClauses)
|
||||
{
|
||||
var typeParamName = clause.Name.Identifier.Text;
|
||||
|
||||
foreach (var constraint in clause.Constraints)
|
||||
{
|
||||
if (constraint is TypeConstraintSyntax typeConstraint &&
|
||||
typeConstraint.Type is GenericNameSyntax generic &&
|
||||
generic.Identifier.Text == "ISPMDLane" &&
|
||||
generic.TypeArgumentList.Arguments.Count == 2)
|
||||
{
|
||||
// ISPMDLane<TSelf, TNumber> — TNumber is the scalar element type
|
||||
var primitiveTypeName = generic.TypeArgumentList.Arguments[1].ToString();
|
||||
_typeParamToPrimitive[typeParamName] = primitiveTypeName;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Expression type resolution ────────────────────────────────────────
|
||||
|
||||
/// <summary>
|
||||
/// Returns the <see cref="HPCType"/> for <paramref name="node"/>, or
|
||||
/// <c>null</c> if the node is not an HPC-typed expression (i.e. not a
|
||||
/// <c>WideLane<T></c>, a constrained SPMD type-parameter, or a
|
||||
/// plain scalar float/double).
|
||||
/// </summary>
|
||||
public HPCType? Resolve(SyntaxNode node)
|
||||
{
|
||||
var typeInfo = _semanticModel.GetTypeInfo(node);
|
||||
var type = typeInfo.Type;
|
||||
if (type is null) return null;
|
||||
|
||||
// WideLane<float>, WideLane<double>, …
|
||||
if (type.Name == "WideLane" &&
|
||||
type is INamedTypeSymbol wideLane &&
|
||||
wideLane.IsGenericType)
|
||||
{
|
||||
return HPCType.FromElementName(
|
||||
wideLane.TypeArguments[0].ToDisplayString());
|
||||
}
|
||||
|
||||
// Generic type parameter constrained to ISPMDLane<TSelf, TNumber>
|
||||
if (type is ITypeParameterSymbol typeParam)
|
||||
{
|
||||
foreach (var constraint in typeParam.ConstraintTypes)
|
||||
{
|
||||
if (constraint.Name == "ISPMDLane" &&
|
||||
constraint is INamedTypeSymbol namedConstraint &&
|
||||
namedConstraint.IsGenericType)
|
||||
{
|
||||
return HPCType.FromElementName(
|
||||
namedConstraint.TypeArguments[1].ToDisplayString());
|
||||
}
|
||||
}
|
||||
|
||||
// Registered from method constraints
|
||||
if (_typeParamToPrimitive.TryGetValue(typeParam.Name, out var prim))
|
||||
return HPCType.FromElementName(prim);
|
||||
}
|
||||
|
||||
// Bare scalar (used when a scalar literal/variable is involved in a
|
||||
// mixed scalar-vector expression)
|
||||
if (type.SpecialType == SpecialType.System_Single) return HPCType.Float;
|
||||
if (type.SpecialType == SpecialType.System_Double) return HPCType.Double;
|
||||
if (type.SpecialType == SpecialType.System_Int32) return HPCType.Int;
|
||||
if (type.SpecialType == SpecialType.System_UInt32) return HPCType.UInt;
|
||||
if (type.SpecialType == SpecialType.System_Int64) return HPCType.Long;
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Resolves the type of the expression using the registered type-parameter
|
||||
/// map as a fallback (for simple identifier references where the semantic
|
||||
/// model yields a type-parameter symbol rather than a concrete type).
|
||||
/// </summary>
|
||||
public HPCType? ResolveByName(string typeName)
|
||||
{
|
||||
if (_typeParamToPrimitive.TryGetValue(typeName, out var prim))
|
||||
return HPCType.FromElementName(prim);
|
||||
return null;
|
||||
}
|
||||
|
||||
/// <summary>Returns true if <paramref name="typeName"/> is a known SPMD type-param.</summary>
|
||||
public bool IsSPMDTypeParam(string typeName) =>
|
||||
_typeParamToPrimitive.ContainsKey(typeName);
|
||||
|
||||
/// <summary>Snapshot of current type-param → primitive mappings (for reference).</summary>
|
||||
public IReadOnlyDictionary<string, string> TypeParamMap => _typeParamToPrimitive;
|
||||
}
|
||||
}
|
||||
365
Misaki.HighPerformance.HPC.Generator/Backend/AVX2Backend.cs
Normal file
365
Misaki.HighPerformance.HPC.Generator/Backend/AVX2Backend.cs
Normal file
@@ -0,0 +1,365 @@
|
||||
using Misaki.HighPerformance.HPC.Generator.IR;
|
||||
using System;
|
||||
using System.Text;
|
||||
|
||||
namespace Misaki.HighPerformance.HPC.Generator.Backend
|
||||
{
|
||||
/// <summary>
|
||||
/// Emits C# source code targeting AVX2 (256-bit vectors) with the bundled
|
||||
/// AVX2 extensions: FMA3, F16C, and BMI1/2.
|
||||
///
|
||||
/// <para>
|
||||
/// Vector type: <c>Vector256<T></c><br/>
|
||||
/// Intrinsic classes used: <c>Avx</c>, <c>Avx2</c>, <c>Fma</c>, <c>F16C</c>,
|
||||
/// <c>Bmi1</c>, <c>Bmi2</c>, <c>Vector256</c>.
|
||||
/// </para>
|
||||
///
|
||||
/// <para>
|
||||
/// Math functions with no built-in AVX2 intrinsic (Sin, Cos, Asin, etc.) are
|
||||
/// delegated to the <c>AVX2Utility</c> class that is generated separately by
|
||||
/// <c>AVX2UtilityGenerator</c> via <c>IVectorAPIContext</c>.
|
||||
/// </para>
|
||||
/// </summary>
|
||||
internal sealed class AVX2Backend : IHPCBackend
|
||||
{
|
||||
// ── IHPCBackend ───────────────────────────────────────────────────────
|
||||
|
||||
public string Name => "AVX2";
|
||||
|
||||
public string[] RequiredUsings => new[]
|
||||
{
|
||||
"using System.Runtime.CompilerServices;",
|
||||
"using System.Runtime.Intrinsics;",
|
||||
"using System.Runtime.Intrinsics.X86;",
|
||||
};
|
||||
|
||||
public string EmitMethod(HPCMethodIR method)
|
||||
{
|
||||
var sb = new StringBuilder();
|
||||
var indent = " ";
|
||||
|
||||
// ── Signature ──────────────────────────────────────────────────────
|
||||
sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]");
|
||||
sb.Append($"{indent}public static ");
|
||||
sb.Append(EmitReturnType(method.ReturnType));
|
||||
sb.Append(' ');
|
||||
sb.Append($"{method.OriginalName}_{Name}");
|
||||
sb.Append('(');
|
||||
sb.Append(EmitParameterList(method));
|
||||
sb.AppendLine(")");
|
||||
|
||||
// ── Body ───────────────────────────────────────────────────────────
|
||||
sb.AppendLine($"{indent}{{");
|
||||
foreach (var stmt in method.Body)
|
||||
EmitStatement(sb, stmt, indent + " ");
|
||||
sb.AppendLine($"{indent}}}");
|
||||
|
||||
return sb.ToString();
|
||||
}
|
||||
|
||||
// ── Statement emission ────────────────────────────────────────────────
|
||||
|
||||
private void EmitStatement(StringBuilder sb, HPCStmt stmt, string indent)
|
||||
{
|
||||
switch (stmt)
|
||||
{
|
||||
case HPCVarDecl decl:
|
||||
sb.AppendLine($"{indent}var {decl.Name} = {EmitExpr(decl.Initializer)};");
|
||||
break;
|
||||
|
||||
case HPCAssignment assign:
|
||||
sb.AppendLine($"{indent}{assign.Target} = {EmitExpr(assign.Value)};");
|
||||
break;
|
||||
|
||||
case HPCExprStmt expr:
|
||||
sb.AppendLine($"{indent}{EmitExpr(expr.Expression)};");
|
||||
break;
|
||||
|
||||
case HPCReturn ret:
|
||||
if (ret.Value is null)
|
||||
sb.AppendLine($"{indent}return;");
|
||||
else
|
||||
sb.AppendLine($"{indent}return {EmitExpr(ret.Value)};");
|
||||
break;
|
||||
|
||||
case HPCIf ifStmt:
|
||||
EmitIf(sb, ifStmt, indent);
|
||||
break;
|
||||
|
||||
case HPCForLoop forLoop:
|
||||
EmitForLoop(sb, forLoop, indent);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
private void EmitIf(StringBuilder sb, HPCIf node, string indent)
|
||||
{
|
||||
sb.AppendLine($"{indent}if ({EmitExpr(node.Condition)})");
|
||||
sb.AppendLine($"{indent}{{");
|
||||
foreach (var s in node.ThenBody)
|
||||
EmitStatement(sb, s, indent + " ");
|
||||
sb.AppendLine($"{indent}}}");
|
||||
|
||||
if (node.ElseBody is { Length: > 0 })
|
||||
{
|
||||
sb.AppendLine($"{indent}else");
|
||||
sb.AppendLine($"{indent}{{");
|
||||
foreach (var s in node.ElseBody)
|
||||
EmitStatement(sb, s, indent + " ");
|
||||
sb.AppendLine($"{indent}}}");
|
||||
}
|
||||
}
|
||||
|
||||
private void EmitForLoop(StringBuilder sb, HPCForLoop node, string indent)
|
||||
{
|
||||
var init = $"var {node.Iterator.Name} = {EmitExpr(node.Iterator.Initializer)}";
|
||||
var cond = EmitExpr(node.Condition);
|
||||
var incr = EmitExpr(node.Increment);
|
||||
sb.AppendLine($"{indent}for ({init}; {cond}; {incr})");
|
||||
sb.AppendLine($"{indent}{{");
|
||||
foreach (var s in node.Body)
|
||||
EmitStatement(sb, s, indent + " ");
|
||||
sb.AppendLine($"{indent}}}");
|
||||
}
|
||||
|
||||
// ── Expression emission ───────────────────────────────────────────────
|
||||
|
||||
private string EmitExpr(HPCExpr expr) => expr switch
|
||||
{
|
||||
HPCVarRef v => v.Name,
|
||||
HPCLiteral l => EmitLiteral(l),
|
||||
HPCBinaryOp b => EmitBinary(b),
|
||||
HPCUnaryOp u => EmitUnary(u),
|
||||
HPCIntrinsicCall c => EmitIntrinsic(c),
|
||||
HPCPropertyAccess p => EmitPropertyAccess(p),
|
||||
HPCPassThroughCall pt => pt.MethodName, // verbatim
|
||||
_ => throw new NotSupportedException($"Unknown IR expr: {expr.GetType().Name}")
|
||||
};
|
||||
|
||||
private string EmitLiteral(HPCLiteral lit)
|
||||
{
|
||||
// Scalar literal → broadcast to all lanes
|
||||
return $"Vector256.Create({lit.Value}{LiteralSuffix(lit.Type)})";
|
||||
}
|
||||
|
||||
private string EmitBinary(HPCBinaryOp node)
|
||||
{
|
||||
var l = EmitExpr(node.Left);
|
||||
var r = EmitExpr(node.Right);
|
||||
|
||||
return node.Kind switch
|
||||
{
|
||||
// Floating-point arithmetic uses Avx / Avx2
|
||||
HPCBinaryKind.Add when node.Type.IsFloatingPoint => $"Avx.Add({l}, {r})",
|
||||
HPCBinaryKind.Subtract when node.Type.IsFloatingPoint => $"Avx.Subtract({l}, {r})",
|
||||
HPCBinaryKind.Multiply when node.Type.IsFloatingPoint => $"Avx.Multiply({l}, {r})",
|
||||
HPCBinaryKind.Divide when node.Type.IsFloatingPoint => $"Avx.Divide({l}, {r})",
|
||||
|
||||
// Integer arithmetic uses Avx2
|
||||
HPCBinaryKind.Add => $"Avx2.Add({l}, {r})",
|
||||
HPCBinaryKind.Subtract => $"Avx2.Subtract({l}, {r})",
|
||||
HPCBinaryKind.Multiply => $"Avx2.MultiplyLow({l}, {r})", // 32-bit int multiply
|
||||
|
||||
// Bitwise
|
||||
HPCBinaryKind.BitwiseAnd => $"Avx2.And({l}, {r})",
|
||||
HPCBinaryKind.BitwiseOr => $"Avx2.Or({l}, {r})",
|
||||
HPCBinaryKind.BitwiseXor => $"Avx2.Xor({l}, {r})",
|
||||
HPCBinaryKind.ShiftLeft => $"Avx2.ShiftLeftLogical({l}, {r})",
|
||||
HPCBinaryKind.ShiftRight => $"Avx2.ShiftRightLogical({l}, {r})",
|
||||
|
||||
// Comparisons — emit as AVX compare returning a mask vector
|
||||
HPCBinaryKind.Equal when node.Type.IsFloatingPoint
|
||||
=> $"Avx.Compare({l}, {r}, FloatComparisonMode.OrderedEqualNonSignaling)",
|
||||
HPCBinaryKind.NotEqual when node.Type.IsFloatingPoint
|
||||
=> $"Avx.Compare({l}, {r}, FloatComparisonMode.OrderedNotEqualNonSignaling)",
|
||||
HPCBinaryKind.LessThan when node.Type.IsFloatingPoint
|
||||
=> $"Avx.Compare({l}, {r}, FloatComparisonMode.OrderedLessThanNonSignaling)",
|
||||
HPCBinaryKind.LessThanOrEqual when node.Type.IsFloatingPoint
|
||||
=> $"Avx.Compare({l}, {r}, FloatComparisonMode.OrderedLessThanOrEqualNonSignaling)",
|
||||
HPCBinaryKind.GreaterThan when node.Type.IsFloatingPoint
|
||||
=> $"Avx.Compare({l}, {r}, FloatComparisonMode.OrderedGreaterThanNonSignaling)",
|
||||
HPCBinaryKind.GreaterThanOrEqual when node.Type.IsFloatingPoint
|
||||
=> $"Avx.Compare({l}, {r}, FloatComparisonMode.OrderedGreaterThanOrEqualNonSignaling)",
|
||||
|
||||
// Integer comparisons
|
||||
HPCBinaryKind.Equal => $"Avx2.CompareEqual({l}, {r})",
|
||||
HPCBinaryKind.GreaterThan => $"Avx2.CompareGreaterThan({l}, {r})",
|
||||
HPCBinaryKind.LessThan => $"Avx2.CompareGreaterThan({r}, {l})", // reversed
|
||||
HPCBinaryKind.GreaterThanOrEqual => $"Avx2.Or(Avx2.CompareGreaterThan({l}, {r}), Avx2.CompareEqual({l}, {r}))",
|
||||
HPCBinaryKind.LessThanOrEqual => $"Avx2.Or(Avx2.CompareGreaterThan({r}, {l}), Avx2.CompareEqual({l}, {r}))",
|
||||
HPCBinaryKind.NotEqual => $"Avx2.Xor(Avx2.CompareEqual({l}, {r}), Vector256<{CsTypeName(node.Type)}>.AllBitsSet)",
|
||||
|
||||
HPCBinaryKind.Modulo => throw new NotSupportedException("Modulo has no AVX2 intrinsic; consider reformulating."),
|
||||
|
||||
_ => throw new NotSupportedException($"Binary kind {node.Kind} not supported in AVX2 backend")
|
||||
};
|
||||
}
|
||||
|
||||
private string EmitUnary(HPCUnaryOp node)
|
||||
{
|
||||
var op = EmitExpr(node.Operand);
|
||||
return node.Kind switch
|
||||
{
|
||||
HPCUnaryKind.Negate when node.Type.IsFloatingPoint
|
||||
=> $"Avx.Subtract(Vector256<{CsTypeName(node.Type)}>.Zero, {op})",
|
||||
HPCUnaryKind.Negate
|
||||
=> $"Avx2.Subtract(Vector256<{CsTypeName(node.Type)}>.Zero, {op})",
|
||||
HPCUnaryKind.BitwiseNot => $"Avx2.Xor({op}, Vector256<{CsTypeName(node.Type)}>.AllBitsSet)",
|
||||
|
||||
// Math — delegate to Vector256 helpers (available in .NET 7+)
|
||||
HPCUnaryKind.Abs when node.Type.IsFloatingPoint => $"Vector256.Abs({op})",
|
||||
HPCUnaryKind.Sqrt when node.Type.IsFloatingPoint => $"Avx.Sqrt({op})",
|
||||
HPCUnaryKind.Floor when node.Type.IsFloatingPoint => $"Avx.Floor({op})",
|
||||
HPCUnaryKind.Ceil when node.Type.IsFloatingPoint => $"Avx.Ceiling({op})",
|
||||
HPCUnaryKind.Round when node.Type.IsFloatingPoint
|
||||
=> $"Avx.RoundToNearestInteger({op})",
|
||||
HPCUnaryKind.Trunc when node.Type.IsFloatingPoint
|
||||
=> $"Avx.RoundToZero({op})",
|
||||
|
||||
// Reciprocal / Rsqrt (float only; approximate 14-bit variants)
|
||||
HPCUnaryKind.Rcp when node.Type.ElementTypeName == "float"
|
||||
=> $"Avx.Reciprocal({op})",
|
||||
HPCUnaryKind.Rsqrt when node.Type.ElementTypeName == "float"
|
||||
=> $"Avx.ReciprocalSqrt({op})",
|
||||
|
||||
// Transcendentals — routed to AVX2Utility (generated by UtilityTemplate)
|
||||
HPCUnaryKind.Sin => $"AVX2Utility.Sin_{CsTypeNameCap(node.Type)}_{MathModeSuffix()}({op})",
|
||||
HPCUnaryKind.Cos => $"AVX2Utility.Cos_{CsTypeNameCap(node.Type)}_{MathModeSuffix()}({op})",
|
||||
HPCUnaryKind.Asin => $"AVX2Utility.Asin({op})",
|
||||
HPCUnaryKind.Atan => $"AVX2Utility.Atan({op})",
|
||||
HPCUnaryKind.Log => $"AVX2Utility.Log_{CsTypeNameCap(node.Type)}_{MathModeSuffix()}({op})",
|
||||
HPCUnaryKind.Exp => $"AVX2Utility.Exp_{CsTypeNameCap(node.Type)}_{MathModeSuffix()}({op})",
|
||||
|
||||
// Frac = x - Floor(x)
|
||||
HPCUnaryKind.Frac when node.Type.IsFloatingPoint
|
||||
=> $"Avx.Subtract({op}, Avx.Floor({op}))",
|
||||
|
||||
// Sign: extract and normalise sign bit
|
||||
HPCUnaryKind.Sign when node.Type.ElementTypeName == "float"
|
||||
=> $"Avx.And(Avx.CompareEqual({op}, Vector256<float>.Zero) == Vector256<float>.Zero ? Vector256.Create(1.0f) : Vector256<float>.Zero, Avx.Or(Avx.And({op}, Vector256.Create(-0.0f)), Vector256.Create(1.0f)))",
|
||||
|
||||
// Saturate: clamp to [0,1]
|
||||
HPCUnaryKind.Saturate when node.Type.IsFloatingPoint
|
||||
=> $"Avx.Max(Avx.Min({op}, Vector256.Create(1.0{LiteralSuffix(node.Type)})), Vector256<{CsTypeName(node.Type)}>.Zero)",
|
||||
|
||||
_ => throw new NotSupportedException($"Unary kind {node.Kind} not supported in AVX2 backend (type: {node.Type})")
|
||||
};
|
||||
}
|
||||
|
||||
private string EmitIntrinsic(HPCIntrinsicCall node)
|
||||
{
|
||||
var args = node.Args;
|
||||
string A(int i) => EmitExpr(args[i]);
|
||||
|
||||
return node.Intrinsic switch
|
||||
{
|
||||
// FMA — uses the FMA extension bundled with AVX2
|
||||
HPCIntrinsic.MultiplyAdd => $"Fma.MultiplyAdd({A(0)}, {A(1)}, {A(2)})",
|
||||
HPCIntrinsic.MultiplySubtract => $"Fma.MultiplySubtract({A(0)}, {A(1)}, {A(2)})",
|
||||
|
||||
// Math
|
||||
HPCIntrinsic.Atan2 => $"AVX2Utility.Atan2({A(0)}, {A(1)})",
|
||||
HPCIntrinsic.Pow => $"AVX2Utility.Pow({A(0)}, {A(1)})",
|
||||
HPCIntrinsic.SinCos => $"AVX2Utility.SinCos_{CsTypeNameCap(node.Type)}_{MathModeSuffix()}({A(0)}, out {A(1)}, out {A(2)})",
|
||||
|
||||
// Compound math
|
||||
HPCIntrinsic.Lerp => $"Fma.MultiplyAdd(Avx.Subtract({A(1)}, {A(0)}), {A(2)}, {A(0)})",
|
||||
HPCIntrinsic.Min when node.Type.IsFloatingPoint => $"Avx.Min({A(0)}, {A(1)})",
|
||||
HPCIntrinsic.Max when node.Type.IsFloatingPoint => $"Avx.Max({A(0)}, {A(1)})",
|
||||
HPCIntrinsic.Min => $"Avx2.Min({A(0)}, {A(1)})",
|
||||
HPCIntrinsic.Max => $"Avx2.Max({A(0)}, {A(1)})",
|
||||
HPCIntrinsic.Clamp when node.Type.IsFloatingPoint => $"Avx.Max(Avx.Min({A(0)}, {A(2)}), {A(1)})",
|
||||
HPCIntrinsic.Clamp => $"Avx2.Max(Avx2.Min({A(0)}, {A(2)}), {A(1)})",
|
||||
|
||||
// Conditional select via bitwise blend
|
||||
HPCIntrinsic.Select when node.Type.IsFloatingPoint
|
||||
=> $"Avx.BlendVariable({A(2)}, {A(1)}, {A(0)})",
|
||||
HPCIntrinsic.Select
|
||||
=> $"Avx2.BlendVariable({A(2)}.As<{CsTypeName(node.Type)}, byte>(), {A(1)}.As<{CsTypeName(node.Type)}, byte>(), {A(0)}.As<{CsTypeName(node.Type)}, byte>()).As<byte, {CsTypeName(node.Type)}>()",
|
||||
|
||||
// CopySign: compose exponent+mantissa from A(0) and sign from A(1)
|
||||
HPCIntrinsic.CopySign when node.Type.ElementTypeName == "float"
|
||||
=> $"Avx.Or(Avx.And({A(0)}, Vector256.Create(0x7FFFFFFFu).AsSingle()), Avx.And({A(1)}, Vector256.Create(0x80000000u).AsSingle()))",
|
||||
|
||||
// Horizontal reductions
|
||||
HPCIntrinsic.ReduceAdd => $"Vector256.Sum({A(0)})",
|
||||
HPCIntrinsic.ReduceMax => $"Vector256.Max({A(0)})",
|
||||
HPCIntrinsic.ReduceMin => $"Vector256.Min({A(0)})",
|
||||
|
||||
// Memory — pointer-based (emitted as ref-to-pointer pattern)
|
||||
HPCIntrinsic.Load => $"Vector256.LoadUnsafe(ref {A(0)})",
|
||||
HPCIntrinsic.MaskLoad => $"Avx2.MaskLoad(ref {A(0)}, {A(1)}.AsInt32())",
|
||||
HPCIntrinsic.Store => $"Vector256.StoreUnsafe({A(0)}, ref {A(1)})",
|
||||
HPCIntrinsic.MaskStore => $"Avx2.MaskStore(ref {A(1)}, {A(2)}.AsInt32(), {A(0)})",
|
||||
|
||||
HPCIntrinsic.Gather => $"Avx2.GatherVector256(ref {A(0)}, {A(1)}, {A(2)})",
|
||||
HPCIntrinsic.MaskGather => $"Avx2.GatherMaskVector256({A(0)}, ref {A(1)}, {A(2)}, {A(3)}, {A(4)})",
|
||||
|
||||
HPCIntrinsic.CompressStore
|
||||
=> $"/* CompressStore requires AVX-512VBMI2; use scalar fallback */ {A(0)}.CompressStore(ref {A(1)}, {A(2)})",
|
||||
|
||||
// Conversions
|
||||
HPCIntrinsic.Cast => $"{A(0)}.As<{CsTypeName(args[0].Type)}, {CsTypeName(node.Type)}>()",
|
||||
HPCIntrinsic.BitCast => $"{A(0)}.As<{CsTypeName(args[0].Type)}, {CsTypeName(node.Type)}>()",
|
||||
|
||||
_ => throw new NotSupportedException($"Intrinsic {node.Intrinsic} not implemented in AVX2 backend")
|
||||
};
|
||||
}
|
||||
|
||||
private static string EmitPropertyAccess(HPCPropertyAccess node) =>
|
||||
$"{node.Target}.{node.PropertyName}";
|
||||
|
||||
// ── Signature helpers ─────────────────────────────────────────────────
|
||||
|
||||
private static string EmitReturnType(HPCType type) =>
|
||||
type.ElementTypeName == "void"
|
||||
? "void"
|
||||
: $"Vector256<{type.ElementTypeName}>";
|
||||
|
||||
private static string EmitParameterList(HPCMethodIR method)
|
||||
{
|
||||
var parts = new System.Collections.Generic.List<string>();
|
||||
foreach (var p in method.Parameters)
|
||||
{
|
||||
var prefix = (p.IsOut ? "out " : "") + (p.IsRef ? "ref " : "");
|
||||
var typeName = p.Type.ElementTypeName == "void"
|
||||
? "void"
|
||||
: $"Vector256<{p.Type.ElementTypeName}>";
|
||||
parts.Add($"{prefix}{typeName} {p.Name}");
|
||||
}
|
||||
return string.Join(", ", parts);
|
||||
}
|
||||
|
||||
// ── Naming helpers ────────────────────────────────────────────────────
|
||||
|
||||
private static string CsTypeName(HPCType t) => t.ElementTypeName;
|
||||
|
||||
private static string CsTypeNameCap(HPCType t) => t.ElementTypeName switch
|
||||
{
|
||||
"float" or "Single" => "Single",
|
||||
"double" or "Double" => "Double",
|
||||
_ => t.ElementTypeName
|
||||
};
|
||||
|
||||
private static string LiteralSuffix(HPCType t) => t.ElementTypeName switch
|
||||
{
|
||||
"float" or "Single" => "f",
|
||||
"double" or "Double" => "d",
|
||||
_ => ""
|
||||
};
|
||||
|
||||
// ── Mode-aware suffix ─────────────────────────────────────────────────
|
||||
|
||||
// The backend does not hold state for the current method's MathMode by default;
|
||||
// instead EmitMethod passes context through a field set per-call.
|
||||
private MathMode _currentMode = MathMode.Standard;
|
||||
|
||||
public string EmitMethod(HPCMethodIR method, MathMode mode)
|
||||
{
|
||||
_currentMode = mode;
|
||||
return EmitMethod(method);
|
||||
}
|
||||
|
||||
private string MathModeSuffix() => _currentMode == MathMode.Fast ? "Fast" : "Standard";
|
||||
}
|
||||
}
|
||||
23
Misaki.HighPerformance.HPC.Generator/Backend/IHPCBackend.cs
Normal file
23
Misaki.HighPerformance.HPC.Generator/Backend/IHPCBackend.cs
Normal file
@@ -0,0 +1,23 @@
|
||||
using Misaki.HighPerformance.HPC.Generator.IR;
|
||||
|
||||
namespace Misaki.HighPerformance.HPC.Generator.Backend
|
||||
{
|
||||
/// <summary>
|
||||
/// Emits C# source code for one target instruction-set architecture from a
|
||||
/// fully analysed and optimised <see cref="HPCMethodIR"/>.
|
||||
/// </summary>
|
||||
internal interface IHPCBackend
|
||||
{
|
||||
/// <summary>Short identifier used in generated file/method names (e.g. "AVX2").</summary>
|
||||
string Name { get; }
|
||||
|
||||
/// <summary>Using-directives the emitted code requires.</summary>
|
||||
string[] RequiredUsings { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Emits the full body of a specialised method and returns the C# source
|
||||
/// as a string (method signature + braces included).
|
||||
/// </summary>
|
||||
string EmitMethod(HPCMethodIR method);
|
||||
}
|
||||
}
|
||||
@@ -1,99 +0,0 @@
|
||||
using Microsoft.CodeAnalysis;
|
||||
using Microsoft.CodeAnalysis.CSharp;
|
||||
using Microsoft.CodeAnalysis.CSharp.Syntax;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Text;
|
||||
|
||||
namespace Misaki.HighPerformance.HPC.Generator
|
||||
{
|
||||
internal class HPCOptimizerRewriter : CSharpSyntaxRewriter
|
||||
{
|
||||
private readonly Dictionary<string, string> _spmdTypes = new();
|
||||
private readonly SemanticModel _semanticModel;
|
||||
|
||||
public HPCOptimizerRewriter(SemanticModel semanticModel)
|
||||
{
|
||||
_semanticModel = semanticModel;
|
||||
}
|
||||
|
||||
private bool IsKnownHpcType(ITypeSymbol? type)
|
||||
{
|
||||
if (type == null)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if it's WideLane, or one of the mapped TLane0 constraints
|
||||
if (type.Name == "WideLane")
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
if (_spmdTypes.ContainsKey(type.Name))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
protected string? GetHpcPrimitiveType(SyntaxNode originalNode)
|
||||
{
|
||||
var typeInfo = semanticModel.GetTypeInfo(originalNode);
|
||||
var type = typeInfo.Type;
|
||||
|
||||
if (type == null)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
if (string.Equals(type.Name, "WideLane") && type is INamedTypeSymbol namedType && namedType.IsGenericType)
|
||||
{
|
||||
// Returns "Single" (float) or "Double" (double)
|
||||
return namedType.TypeArguments[0].ToDisplayString();
|
||||
}
|
||||
|
||||
if (type is ITypeParameterSymbol typeParam)
|
||||
{
|
||||
// Inspect the `where TLane0 : ISPMDLane<TLane0, float>` constraints!
|
||||
foreach (var constraint in typeParam.ConstraintTypes)
|
||||
{
|
||||
if (constraint.Name == "ISPMDLane" && constraint is INamedTypeSymbol constraintNamed && constraintNamed.IsGenericType)
|
||||
{
|
||||
// The second generic argument is the primitive format (float/double)
|
||||
return constraintNamed.TypeArguments[1].ToDisplayString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (type.SpecialType == SpecialType.System_Single)
|
||||
{
|
||||
return "float";
|
||||
}
|
||||
|
||||
if (type.SpecialType == SpecialType.System_Double)
|
||||
{
|
||||
return "double";
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node)
|
||||
{
|
||||
// Rewrites signature types and generic types from `TLane0` to `Vector256<float>`
|
||||
if (_spmdTypes.TryGetValue(node.Identifier.Text, out var primType))
|
||||
{
|
||||
return SyntaxFactory.GenericName("Vector256")
|
||||
.WithTypeArgumentList(
|
||||
SyntaxFactory.TypeArgumentList(
|
||||
SyntaxFactory.SingletonSeparatedList<TypeSyntax>(
|
||||
SyntaxFactory.IdentifierName(primType))))
|
||||
.WithTriviaFrom(node);
|
||||
}
|
||||
|
||||
return base.VisitIdentifierName(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,335 +0,0 @@
|
||||
using Microsoft.CodeAnalysis;
|
||||
using Microsoft.CodeAnalysis.CSharp;
|
||||
using Microsoft.CodeAnalysis.CSharp.Syntax;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
|
||||
namespace Misaki.HighPerformance.HPC.Generator
|
||||
{
|
||||
internal enum SIMDInstruction
|
||||
{
|
||||
Add,
|
||||
Subtract,
|
||||
Multiply,
|
||||
MultiplyAdd,
|
||||
|
||||
Asin,
|
||||
Atan2,
|
||||
}
|
||||
|
||||
internal abstract class HPCRewriter : CSharpSyntaxRewriter
|
||||
{
|
||||
protected struct MathExpression
|
||||
{
|
||||
public string Expression
|
||||
{
|
||||
get; set;
|
||||
}
|
||||
|
||||
public string Name
|
||||
{
|
||||
get; set;
|
||||
}
|
||||
|
||||
public int[]? ArgumentOrder
|
||||
{
|
||||
get; set;
|
||||
}
|
||||
}
|
||||
|
||||
protected readonly SemanticModel semanticModel;
|
||||
|
||||
protected HPCRewriter(SemanticModel semanticModel)
|
||||
{
|
||||
this.semanticModel = semanticModel;
|
||||
}
|
||||
|
||||
public static IReadOnlyCollection<HPCRewriter> GetRewriter(TargetInstructionSet instructionSet, SemanticModel semanticModel)
|
||||
{
|
||||
var rewriters = new List<HPCRewriter>();
|
||||
|
||||
// TODO: Add more rewriters for different instruction sets
|
||||
if (instructionSet.HasFlag(TargetInstructionSet.AVX2))
|
||||
{
|
||||
rewriters.Add(new AVX2Rewriter(semanticModel));
|
||||
}
|
||||
|
||||
return rewriters;
|
||||
}
|
||||
|
||||
private static readonly Dictionary<string, string> s_remapProperties = new()
|
||||
{
|
||||
["LaneWidth"] = "Count",
|
||||
};
|
||||
|
||||
private static readonly Dictionary<string, SIMDInstruction> s_remapMath = new()
|
||||
{
|
||||
["Add"] = SIMDInstruction.Add,
|
||||
["Subtract"] = SIMDInstruction.Subtract,
|
||||
["Multiply"] = SIMDInstruction.Multiply,
|
||||
["MultiplyAdd"] = SIMDInstruction.MultiplyAdd,
|
||||
["Asin"] = SIMDInstruction.Asin,
|
||||
["Atan2"] = SIMDInstruction.Atan2,
|
||||
};
|
||||
|
||||
public abstract string Name
|
||||
{
|
||||
get;
|
||||
}
|
||||
|
||||
public virtual string GetNesessaryUsing()
|
||||
{
|
||||
return string.Empty;
|
||||
}
|
||||
|
||||
public override SyntaxNode? VisitAttributeList(AttributeListSyntax node)
|
||||
{
|
||||
var filteredAttributes = SyntaxFactory.SeparatedList(
|
||||
node.Attributes.Where(a => !a.Name.ToString().Contains("HPCompute"))
|
||||
);
|
||||
|
||||
if (filteredAttributes.Count == 0)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
return node.WithAttributes(filteredAttributes).WithTriviaFrom(node);
|
||||
}
|
||||
|
||||
public override SyntaxNode? VisitMethodDeclaration(MethodDeclarationSyntax node)
|
||||
{
|
||||
var typesToRemove = new HashSet<string>();
|
||||
|
||||
// 1. Analyze constraints to identify ISPMDLane generics
|
||||
foreach (var clause in node.ConstraintClauses)
|
||||
{
|
||||
var typeNameStr = clause.Name.Identifier.Text;
|
||||
foreach (var constraint in clause.Constraints.OfType<TypeConstraintSyntax>())
|
||||
{
|
||||
if (constraint.Type is GenericNameSyntax genericType &&
|
||||
genericType.Identifier.Text == "ISPMDLane" &&
|
||||
genericType.TypeArgumentList.Arguments.Count == 2)
|
||||
{
|
||||
var primType = genericType.TypeArgumentList.Arguments[1].ToString();
|
||||
spmdTypes[typeNameStr] = primType;
|
||||
typesToRemove.Add(typeNameStr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var methodToVisit = node;
|
||||
|
||||
// 2. Strip type parameter and constraints BEFORE visiting so VisitIdentifierName doesn't touch them
|
||||
if (typesToRemove.Count > 0)
|
||||
{
|
||||
// Remove from <TLane0, ...> generics list
|
||||
if (methodToVisit.TypeParameterList != null)
|
||||
{
|
||||
var newParams = methodToVisit.TypeParameterList.Parameters
|
||||
.Where(p => !typesToRemove.Contains(p.Identifier.Text))
|
||||
.ToList();
|
||||
|
||||
if (newParams.Any())
|
||||
{
|
||||
methodToVisit = methodToVisit.WithTypeParameterList(
|
||||
SyntaxFactory.TypeParameterList(SyntaxFactory.SeparatedList(newParams))
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
methodToVisit = methodToVisit.WithTypeParameterList(null); // Removes angle brackets entirely
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the matching `where TLane0 : ...` clause
|
||||
var newConstraints = methodToVisit.ConstraintClauses
|
||||
.Where(c => !typesToRemove.Contains(c.Name.Identifier.Text))
|
||||
.ToList();
|
||||
|
||||
methodToVisit = methodToVisit.WithConstraintClauses(
|
||||
SyntaxFactory.List(newConstraints)
|
||||
);
|
||||
}
|
||||
|
||||
// 3. Fallback to base to rewrite method arguments, return types, and body via our updated visitors
|
||||
return base.VisitMethodDeclaration(methodToVisit);
|
||||
}
|
||||
|
||||
public override SyntaxNode? VisitGenericName(GenericNameSyntax node)
|
||||
{
|
||||
if (node.Identifier.Text == "WideLane" &&
|
||||
node.TypeArgumentList.Arguments.Count == 1)
|
||||
{
|
||||
return SyntaxFactory.GenericName("Vector256")
|
||||
.WithTypeArgumentList(node.TypeArgumentList)
|
||||
.WithTriviaFrom(node);
|
||||
}
|
||||
|
||||
return base.VisitGenericName(node);
|
||||
}
|
||||
|
||||
public override SyntaxNode? VisitMemberAccessExpression(MemberAccessExpressionSyntax node)
|
||||
{
|
||||
var isSpmdOrWideLane = false;
|
||||
|
||||
if (node.Expression is GenericNameSyntax genericName &&
|
||||
genericName.Identifier.Text == "WideLane" &&
|
||||
genericName.TypeArgumentList.Arguments.Count == 1)
|
||||
{
|
||||
isSpmdOrWideLane = true;
|
||||
|
||||
var argTypeStr = genericName.TypeArgumentList.Arguments[0].ToString();
|
||||
}
|
||||
else if (node.Expression is IdentifierNameSyntax idName &&
|
||||
spmdTypes.TryGetValue(idName.Identifier.Text, out var mappedPrimType))
|
||||
{
|
||||
isSpmdOrWideLane = true;
|
||||
}
|
||||
|
||||
if (isSpmdOrWideLane)
|
||||
{
|
||||
if (s_remapProperties.TryGetValue(node.Name.Identifier.Text, out var remappedName))
|
||||
{
|
||||
// Keep the evaluated left-hand side (TLane0 -> Vector256<float>) but change the property
|
||||
var rewrittenExpression = (ExpressionSyntax)Visit(node.Expression);
|
||||
|
||||
return SyntaxFactory.MemberAccessExpression(
|
||||
SyntaxKind.SimpleMemberAccessExpression,
|
||||
rewrittenExpression,
|
||||
SyntaxFactory.IdentifierName(remappedName)
|
||||
).WithTriviaFrom(node);
|
||||
}
|
||||
|
||||
if (s_remapMath.TryGetValue(node.Name.Identifier.Text, out var instruction))
|
||||
{
|
||||
var rewritResult = RewriteMathExpression(instruction);
|
||||
return SyntaxFactory.MemberAccessExpression(
|
||||
SyntaxKind.SimpleMemberAccessExpression,
|
||||
SyntaxFactory.IdentifierName(rewritResult.Expression),
|
||||
SyntaxFactory.IdentifierName(rewritResult.Name)
|
||||
).WithTriviaFrom(node);
|
||||
}
|
||||
}
|
||||
|
||||
return base.VisitMemberAccessExpression(node);
|
||||
}
|
||||
|
||||
public override SyntaxNode? VisitInvocationExpression(InvocationExpressionSyntax node)
|
||||
{
|
||||
if (node.Expression is MemberAccessExpressionSyntax memberAccess)
|
||||
{
|
||||
var isSpmdOrWideLane = false;
|
||||
|
||||
if (memberAccess.Expression is GenericNameSyntax genericName
|
||||
&& genericName.Identifier.Text == "WideLane"
|
||||
&& genericName.TypeArgumentList.Arguments.Count == 1)
|
||||
{
|
||||
isSpmdOrWideLane = true;
|
||||
}
|
||||
else if (memberAccess.Expression is IdentifierNameSyntax idName
|
||||
&& spmdTypes.TryGetValue(idName.Identifier.Text, out var mappedPrimType))
|
||||
{
|
||||
isSpmdOrWideLane = true;
|
||||
}
|
||||
|
||||
if (isSpmdOrWideLane)
|
||||
{
|
||||
var args = node.ArgumentList.Arguments;
|
||||
var argList = new ArgumentSyntax[args.Count];
|
||||
|
||||
for (var i = 0; i < args.Count; i++)
|
||||
{
|
||||
argList[i] = (ArgumentSyntax)Visit(args[i]);
|
||||
}
|
||||
|
||||
if (s_remapMath.TryGetValue(memberAccess.Name.Identifier.Text, out var instruction))
|
||||
{
|
||||
RewriteMathArguments(instruction, argList);
|
||||
var arguments = SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(argList));
|
||||
|
||||
var newExpression = (ExpressionSyntax)Visit(memberAccess);
|
||||
return SyntaxFactory.InvocationExpression(newExpression, arguments)
|
||||
.WithTriviaFrom(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return base.VisitInvocationExpression(node);
|
||||
}
|
||||
|
||||
public override SyntaxNode? VisitBinaryExpression(BinaryExpressionSyntax node)
|
||||
{
|
||||
var type = GetHpcPrimitiveType(node);
|
||||
var ifFloatingPoint = type == "float" || type == "double";
|
||||
|
||||
// Optimize (a * b) + c -> MultiplyAdd(a, b, c)
|
||||
if (node.IsKind(SyntaxKind.AddExpression))
|
||||
{
|
||||
var typeInfo = semanticModel.GetTypeInfo(node);
|
||||
|
||||
if (IsKnownHpcType(typeInfo.Type) && ifFloatingPoint)
|
||||
{
|
||||
if (node.Left.IsKind(SyntaxKind.MultiplyExpression))
|
||||
{
|
||||
var mulNode = (BinaryExpressionSyntax)node.Left;
|
||||
|
||||
var a = (ExpressionSyntax)Visit(mulNode.Left)!;
|
||||
var b = (ExpressionSyntax)Visit(mulNode.Right)!;
|
||||
var c = (ExpressionSyntax)Visit(node.Right)!;
|
||||
|
||||
// Assuming floating point by default for FMA, though you can expand this logic
|
||||
return InvokeMathRewrite(SIMDInstruction.MultiplyAdd, a, b, c).WithTriviaFrom(node);
|
||||
}
|
||||
|
||||
if (node.Right.IsKind(SyntaxKind.MultiplyExpression) && ifFloatingPoint)
|
||||
{
|
||||
var mulNode = (BinaryExpressionSyntax)node.Right;
|
||||
var c = (ExpressionSyntax)Visit(node.Left)!;
|
||||
var a = (ExpressionSyntax)Visit(mulNode.Left)!;
|
||||
var b = (ExpressionSyntax)Visit(mulNode.Right)!;
|
||||
|
||||
return InvokeMathRewrite(SIMDInstruction.MultiplyAdd, a, b, c).WithTriviaFrom(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return base.VisitBinaryExpression(node);
|
||||
}
|
||||
|
||||
protected ExpressionSyntax InvokeMathRewrite(SIMDInstruction instruction, params ExpressionSyntax[] args)
|
||||
{
|
||||
var rewriteResult = RewriteMathExpression(instruction);
|
||||
|
||||
var finalArgs = new ArgumentSyntax[args.Length];
|
||||
|
||||
// Reorder arguments if the instruction set backend specifies an order
|
||||
if (rewriteResult.ArgumentOrder != null)
|
||||
{
|
||||
for (var i = 0; i < rewriteResult.ArgumentOrder.Length; i++)
|
||||
{
|
||||
finalArgs[i] = SyntaxFactory.Argument(args[rewriteResult.ArgumentOrder[i]]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (var i = 0; i < args.Length; i++)
|
||||
{
|
||||
finalArgs[i] = SyntaxFactory.Argument(args[i]);
|
||||
}
|
||||
}
|
||||
|
||||
return SyntaxFactory.InvocationExpression(
|
||||
SyntaxFactory.MemberAccessExpression(
|
||||
SyntaxKind.SimpleMemberAccessExpression,
|
||||
SyntaxFactory.IdentifierName(rewriteResult.Expression),
|
||||
SyntaxFactory.IdentifierName(rewriteResult.Name)
|
||||
),
|
||||
SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(finalArgs))
|
||||
);
|
||||
}
|
||||
|
||||
protected abstract MathExpression RewriteMathExpression(SIMDInstruction instruction);
|
||||
protected abstract void RewriteMathArguments(SIMDInstruction instruction, Span<ArgumentSyntax> originalArgs);
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,10 @@
|
||||
using Microsoft.CodeAnalysis;
|
||||
using Microsoft.CodeAnalysis.CSharp;
|
||||
using Microsoft.CodeAnalysis.CSharp.Syntax;
|
||||
using Microsoft.CodeAnalysis.Text;
|
||||
using Misaki.HighPerformance.HPC.Generator.Analysis;
|
||||
using Misaki.HighPerformance.HPC.Generator.Backend;
|
||||
using Misaki.HighPerformance.HPC.Generator.Optimization;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Collections.Immutable;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
@@ -11,105 +13,165 @@ namespace Misaki.HighPerformance.HPC.Generator
|
||||
{
|
||||
internal class HPComputeMethodInfo
|
||||
{
|
||||
public MethodDeclarationSyntax MethodDeclaration
|
||||
{
|
||||
get; set;
|
||||
} = null!;
|
||||
|
||||
public IMethodSymbol MethodSymbol
|
||||
{
|
||||
get; set;
|
||||
} = null!;
|
||||
|
||||
public SemanticModel SemanticModel
|
||||
{
|
||||
get; set;
|
||||
} = null!;
|
||||
|
||||
public TargetInstructionSet InstructionSet
|
||||
{
|
||||
get; set;
|
||||
}
|
||||
|
||||
public FloatPrecision Precision
|
||||
{
|
||||
get; set;
|
||||
}
|
||||
|
||||
public MathMode Mode
|
||||
{
|
||||
get; set;
|
||||
}
|
||||
public MethodDeclarationSyntax MethodDeclaration { get; set; } = null!;
|
||||
public IMethodSymbol MethodSymbol { get; set; } = null!;
|
||||
public SemanticModel SemanticModel { get; set; } = null!;
|
||||
public TargetInstructionSet InstructionSet { get; set; }
|
||||
public FloatPrecision Precision { get; set; }
|
||||
public MathMode Mode { get; set; }
|
||||
}
|
||||
|
||||
[Generator]
|
||||
public class HPComputeGenerator : IIncrementalGenerator
|
||||
{
|
||||
// ── Backends (one singleton per ISA, stateless between methods) ───────
|
||||
|
||||
private static readonly AVX2Backend s_avx2Backend = new();
|
||||
|
||||
// ── Optimization passes ───────────────────────────────────────────────
|
||||
|
||||
/// <summary>
|
||||
/// Returns the ordered list of optimisation passes for the given method.
|
||||
/// Passes are cheap objects; creating them per-method is intentional so
|
||||
/// future stateful passes (e.g. CSE) can be added without concurrency issues.
|
||||
/// </summary>
|
||||
private static IEnumerable<IHPCOptimizationPass> GetPasses(HPComputeMethodInfo info)
|
||||
{
|
||||
// FMA fusion only makes sense on ISAs that have it
|
||||
if (info.InstructionSet.HasFlag(TargetInstructionSet.AVX2))
|
||||
yield return new FMAFusionPass();
|
||||
}
|
||||
|
||||
// ── IIncrementalGenerator ─────────────────────────────────────────────
|
||||
|
||||
public void Initialize(IncrementalGeneratorInitializationContext context)
|
||||
{
|
||||
var methodDeclarations = context.SyntaxProvider
|
||||
.ForAttributeWithMetadataName(
|
||||
"Misaki.HighPerformance.HPC.HPComputeAttribute",
|
||||
static (n, ct) => n is MethodDeclarationSyntax,
|
||||
static (ctx, ct) =>
|
||||
static (n, _) => n is MethodDeclarationSyntax,
|
||||
static (ctx, _) =>
|
||||
{
|
||||
var attributes = ctx.Attributes.FirstOrDefault(a => a.AttributeClass?.ToDisplayString() == "Misaki.HighPerformance.HPC.HPComputeAttribute");
|
||||
if (attributes != null && ctx.TargetSymbol is IMethodSymbol methodSymbol)
|
||||
{
|
||||
return new HPComputeMethodInfo
|
||||
{
|
||||
MethodDeclaration = (MethodDeclarationSyntax)ctx.TargetNode,
|
||||
MethodSymbol = methodSymbol,
|
||||
SemanticModel = ctx.SemanticModel,
|
||||
InstructionSet = (TargetInstructionSet)attributes.ConstructorArguments[0].Value!,
|
||||
Precision = (FloatPrecision)attributes.ConstructorArguments[1].Value!,
|
||||
Mode = (MathMode)attributes.ConstructorArguments[2].Value!,
|
||||
};
|
||||
}
|
||||
var attribute = ctx.Attributes.FirstOrDefault(
|
||||
a => a.AttributeClass?.ToDisplayString() ==
|
||||
"Misaki.HighPerformance.HPC.HPComputeAttribute");
|
||||
|
||||
return null;
|
||||
if (attribute is null || ctx.TargetSymbol is not IMethodSymbol methodSymbol)
|
||||
return null;
|
||||
|
||||
return new HPComputeMethodInfo
|
||||
{
|
||||
MethodDeclaration = (MethodDeclarationSyntax)ctx.TargetNode,
|
||||
MethodSymbol = methodSymbol,
|
||||
SemanticModel = ctx.SemanticModel,
|
||||
InstructionSet = (TargetInstructionSet)attribute.ConstructorArguments[0].Value!,
|
||||
Precision = (FloatPrecision)attribute.ConstructorArguments[1].Value!,
|
||||
Mode = (MathMode)attribute.ConstructorArguments[2].Value!,
|
||||
};
|
||||
})
|
||||
.Collect();
|
||||
|
||||
context.RegisterSourceOutput(methodDeclarations, GenerateHPCMethod);
|
||||
context.RegisterSourceOutput(methodDeclarations, GenerateHPCMethods);
|
||||
}
|
||||
|
||||
private void GenerateHPCMethod(SourceProductionContext context, ImmutableArray<HPComputeMethodInfo?> array)
|
||||
// ── Core pipeline ─────────────────────────────────────────────────────
|
||||
|
||||
private static void GenerateHPCMethods(
|
||||
SourceProductionContext context,
|
||||
ImmutableArray<HPComputeMethodInfo?> array)
|
||||
{
|
||||
if (array.IsEmpty)
|
||||
{
|
||||
return;
|
||||
}
|
||||
if (array.IsEmpty) return;
|
||||
|
||||
foreach (var info in array)
|
||||
{
|
||||
if (info == null)
|
||||
if (info is null) continue;
|
||||
|
||||
try
|
||||
{
|
||||
continue;
|
||||
GenerateSingleMethod(context, info);
|
||||
}
|
||||
|
||||
var rewriters = HPCRewriter.GetRewriter(info.InstructionSet, info.SemanticModel);
|
||||
|
||||
foreach (var writer in rewriters)
|
||||
catch (Exception ex)
|
||||
{
|
||||
var rewrittenMethod = (MethodDeclarationSyntax)writer.Visit(info.MethodDeclaration);
|
||||
var newMethod = rewrittenMethod
|
||||
.WithIdentifier(SyntaxFactory.Identifier($"{info.MethodDeclaration.Identifier.Text}_{writer.Name}"));
|
||||
|
||||
var source = $@"
|
||||
using Misaki.HighPerformance.HPC;
|
||||
{writer.GetNesessaryUsing()}
|
||||
|
||||
namespace {info.MethodSymbol.ContainingNamespace.ToDisplayString()}
|
||||
{{
|
||||
partial class {info.MethodSymbol.ContainingType.Name}
|
||||
{{
|
||||
{newMethod.NormalizeWhitespace().ToFullString()}
|
||||
}}
|
||||
}}";
|
||||
context.AddSource($"{info.MethodSymbol.ContainingType.Name}_{info.MethodDeclaration.Identifier.Text}_{writer.Name}.g.cs", SourceText.From(source, Encoding.UTF8));
|
||||
// Surface analyzer errors as Roslyn diagnostics so the user
|
||||
// sees them in the IDE rather than a silent empty output.
|
||||
context.ReportDiagnostic(Diagnostic.Create(
|
||||
new DiagnosticDescriptor(
|
||||
id: "HPC0001",
|
||||
title: "HPC code generation failed",
|
||||
messageFormat: "Failed to generate HPC variant for '{0}': {1}",
|
||||
category: "HPCGenerator",
|
||||
defaultSeverity: DiagnosticSeverity.Error,
|
||||
isEnabledByDefault: true),
|
||||
info.MethodDeclaration.GetLocation(),
|
||||
info.MethodDeclaration.Identifier.Text,
|
||||
ex.Message));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void GenerateSingleMethod(
|
||||
SourceProductionContext context,
|
||||
HPComputeMethodInfo info)
|
||||
{
|
||||
// ── Phase 1: Analyse ──────────────────────────────────────────────
|
||||
var analyzer = new HPCAnalyzer(info.SemanticModel);
|
||||
var ir = analyzer.Analyze(info.MethodDeclaration, info);
|
||||
|
||||
// ── Phase 2: Optimise ─────────────────────────────────────────────
|
||||
var optimizedIR = ir;
|
||||
foreach (var pass in GetPasses(info))
|
||||
optimizedIR = pass.Transform(optimizedIR);
|
||||
|
||||
// ── Phase 3: Emit per-target ──────────────────────────────────────
|
||||
foreach (var (backend, isa) in GetBackends(info.InstructionSet))
|
||||
{
|
||||
var methodSource = backend.EmitMethod(optimizedIR);
|
||||
var fullSource = WrapSource(methodSource, optimizedIR, backend);
|
||||
|
||||
context.AddSource(
|
||||
hintName: $"{ir.ContainingTypeName}_{ir.OriginalName}_{backend.Name}.g.cs",
|
||||
source: fullSource);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Backend selection ─────────────────────────────────────────────────
|
||||
|
||||
private static IEnumerable<(IHPCBackend backend, TargetInstructionSet isa)>
|
||||
GetBackends(TargetInstructionSet instructionSet)
|
||||
{
|
||||
if (instructionSet.HasFlag(TargetInstructionSet.AVX2))
|
||||
yield return (s_avx2Backend, TargetInstructionSet.AVX2);
|
||||
|
||||
// Future: SSE4, AVX512, NEON — add here without touching anything else
|
||||
}
|
||||
|
||||
// ── Source wrapping ───────────────────────────────────────────────────
|
||||
|
||||
private static string WrapSource(
|
||||
string methodBody,
|
||||
IR.HPCMethodIR ir,
|
||||
IHPCBackend backend)
|
||||
{
|
||||
var sb = new StringBuilder();
|
||||
|
||||
sb.AppendLine("// <auto-generated/>");
|
||||
sb.AppendLine("#nullable enable");
|
||||
|
||||
foreach (var u in backend.RequiredUsings)
|
||||
sb.AppendLine(u);
|
||||
|
||||
sb.AppendLine("using Misaki.HighPerformance.HPC;");
|
||||
sb.AppendLine();
|
||||
|
||||
sb.AppendLine($"namespace {ir.ContainingNamespace}");
|
||||
sb.AppendLine("{");
|
||||
sb.AppendLine($" partial class {ir.ContainingTypeName}");
|
||||
sb.AppendLine(" {");
|
||||
sb.AppendLine(methodBody);
|
||||
sb.AppendLine(" }");
|
||||
sb.AppendLine("}");
|
||||
|
||||
return sb.ToString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
181
Misaki.HighPerformance.HPC.Generator/IR/HPCNodeRewriter.cs
Normal file
181
Misaki.HighPerformance.HPC.Generator/IR/HPCNodeRewriter.cs
Normal file
@@ -0,0 +1,181 @@
|
||||
using Misaki.HighPerformance.HPC.Generator.IR;
|
||||
using System.Linq;
|
||||
|
||||
namespace Misaki.HighPerformance.HPC.Generator.IR
|
||||
{
|
||||
/// <summary>
|
||||
/// Base class for passes that transform the IR tree.
|
||||
/// Uses the <em>immutable rewrite</em> pattern: each Visit method returns
|
||||
/// either the original node (if nothing changed) or a new <c>with</c>-expression
|
||||
/// copy, leaving the input tree untouched.
|
||||
/// </summary>
|
||||
internal abstract class HPCNodeRewriter
|
||||
{
|
||||
// ── Entry points ─────────────────────────────────────────────────────
|
||||
|
||||
public virtual HPCExpr RewriteExpr(HPCExpr expr) => expr switch
|
||||
{
|
||||
HPCBinaryOp n => RewriteBinaryOp(n),
|
||||
HPCUnaryOp n => RewriteUnaryOp(n),
|
||||
HPCIntrinsicCall n => RewriteIntrinsicCall(n),
|
||||
HPCPassThroughCall n => RewritePassThroughCall(n),
|
||||
HPCPropertyAccess n => RewritePropertyAccess(n),
|
||||
HPCVarRef n => RewriteVarRef(n),
|
||||
HPCLiteral n => RewriteLiteral(n),
|
||||
_ => expr
|
||||
};
|
||||
|
||||
public virtual HPCStmt RewriteStmt(HPCStmt stmt) => stmt switch
|
||||
{
|
||||
HPCVarDecl n => RewriteVarDecl(n),
|
||||
HPCAssignment n => RewriteAssignment(n),
|
||||
HPCExprStmt n => RewriteExprStmt(n),
|
||||
HPCReturn n => RewriteReturn(n),
|
||||
HPCIf n => RewriteIf(n),
|
||||
HPCForLoop n => RewriteForLoop(n),
|
||||
_ => stmt
|
||||
};
|
||||
|
||||
// ── Expression rewrites (override to modify specific patterns) ───────
|
||||
|
||||
protected virtual HPCExpr RewriteBinaryOp(HPCBinaryOp node)
|
||||
{
|
||||
var left = RewriteExpr(node.Left);
|
||||
var right = RewriteExpr(node.Right);
|
||||
return ReferenceEquals(left, node.Left) && ReferenceEquals(right, node.Right)
|
||||
? node
|
||||
: node with { Left = left, Right = right };
|
||||
}
|
||||
|
||||
protected virtual HPCExpr RewriteUnaryOp(HPCUnaryOp node)
|
||||
{
|
||||
var operand = RewriteExpr(node.Operand);
|
||||
return ReferenceEquals(operand, node.Operand)
|
||||
? node
|
||||
: node with { Operand = operand };
|
||||
}
|
||||
|
||||
protected virtual HPCExpr RewriteIntrinsicCall(HPCIntrinsicCall node)
|
||||
{
|
||||
var args = RewriteArgs(node.Args);
|
||||
return ReferenceEquals(args, node.Args)
|
||||
? node
|
||||
: node with { Args = args };
|
||||
}
|
||||
|
||||
protected virtual HPCExpr RewritePassThroughCall(HPCPassThroughCall node)
|
||||
{
|
||||
var args = RewriteArgs(node.Args);
|
||||
return ReferenceEquals(args, node.Args)
|
||||
? node
|
||||
: node with { Args = args };
|
||||
}
|
||||
|
||||
protected virtual HPCExpr RewritePropertyAccess(HPCPropertyAccess node)
|
||||
{
|
||||
var target = RewriteExpr(node.Target);
|
||||
return ReferenceEquals(target, node.Target)
|
||||
? node
|
||||
: node with { Target = target };
|
||||
}
|
||||
|
||||
protected virtual HPCExpr RewriteVarRef(HPCVarRef node) => node;
|
||||
|
||||
protected virtual HPCExpr RewriteLiteral(HPCLiteral node) => node;
|
||||
|
||||
// ── Statement rewrites ───────────────────────────────────────────────
|
||||
|
||||
protected virtual HPCStmt RewriteVarDecl(HPCVarDecl node)
|
||||
{
|
||||
var init = RewriteExpr(node.Initializer);
|
||||
return ReferenceEquals(init, node.Initializer)
|
||||
? node
|
||||
: node with { Initializer = init };
|
||||
}
|
||||
|
||||
protected virtual HPCStmt RewriteAssignment(HPCAssignment node)
|
||||
{
|
||||
var value = RewriteExpr(node.Value);
|
||||
return ReferenceEquals(value, node.Value)
|
||||
? node
|
||||
: node with { Value = value };
|
||||
}
|
||||
|
||||
protected virtual HPCStmt RewriteExprStmt(HPCExprStmt node)
|
||||
{
|
||||
var expr = RewriteExpr(node.Expression);
|
||||
return ReferenceEquals(expr, node.Expression)
|
||||
? node
|
||||
: node with { Expression = expr };
|
||||
}
|
||||
|
||||
protected virtual HPCStmt RewriteReturn(HPCReturn node)
|
||||
{
|
||||
if (node.Value is null) return node;
|
||||
var value = RewriteExpr(node.Value);
|
||||
return ReferenceEquals(value, node.Value)
|
||||
? node
|
||||
: node with { Value = value };
|
||||
}
|
||||
|
||||
protected virtual HPCStmt RewriteIf(HPCIf node)
|
||||
{
|
||||
var cond = RewriteExpr(node.Condition);
|
||||
var thenBody = RewriteBody(node.ThenBody);
|
||||
var elseBody = node.ElseBody is null ? null : RewriteBody(node.ElseBody);
|
||||
return ReferenceEquals(cond, node.Condition) &&
|
||||
ReferenceEquals(thenBody, node.ThenBody) &&
|
||||
ReferenceEquals(elseBody, node.ElseBody)
|
||||
? node
|
||||
: node with { Condition = cond, ThenBody = thenBody, ElseBody = elseBody };
|
||||
}
|
||||
|
||||
protected virtual HPCStmt RewriteForLoop(HPCForLoop node)
|
||||
{
|
||||
var iter = (HPCVarDecl)RewriteVarDecl(node.Iterator);
|
||||
var cond = RewriteExpr(node.Condition);
|
||||
var incr = RewriteExpr(node.Increment);
|
||||
var body = RewriteBody(node.Body);
|
||||
return ReferenceEquals(iter, node.Iterator) &&
|
||||
ReferenceEquals(cond, node.Condition) &&
|
||||
ReferenceEquals(incr, node.Increment) &&
|
||||
ReferenceEquals(body, node.Body)
|
||||
? node
|
||||
: node with { Iterator = iter, Condition = cond, Increment = incr, Body = body };
|
||||
}
|
||||
|
||||
// ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
protected HPCStmt[] RewriteBody(HPCStmt[] body)
|
||||
{
|
||||
HPCStmt[]? result = null;
|
||||
for (int i = 0; i < body.Length; i++)
|
||||
{
|
||||
var original = body[i];
|
||||
var rewritten = RewriteStmt(original);
|
||||
if (!ReferenceEquals(original, rewritten))
|
||||
{
|
||||
result ??= body.ToArray();
|
||||
result[i] = rewritten;
|
||||
}
|
||||
}
|
||||
return result ?? body;
|
||||
}
|
||||
|
||||
private HPCExpr[] RewriteArgs(HPCExpr[] args)
|
||||
{
|
||||
HPCExpr[]? result = null;
|
||||
for (int i = 0; i < args.Length; i++)
|
||||
{
|
||||
var original = args[i];
|
||||
var rewritten = RewriteExpr(original);
|
||||
if (!ReferenceEquals(original, rewritten))
|
||||
{
|
||||
result ??= args.ToArray();
|
||||
result[i] = rewritten;
|
||||
}
|
||||
}
|
||||
return result ?? args;
|
||||
}
|
||||
}
|
||||
}
|
||||
205
Misaki.HighPerformance.HPC.Generator/IR/HPCNodes.cs
Normal file
205
Misaki.HighPerformance.HPC.Generator/IR/HPCNodes.cs
Normal file
@@ -0,0 +1,205 @@
|
||||
using System;
|
||||
|
||||
namespace Misaki.HighPerformance.HPC.Generator.IR
|
||||
{
|
||||
// ─────────────────────────────────────────────────────────────────────────
|
||||
// Type system
|
||||
// ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// <summary>
|
||||
/// A resolved primitive type inside the HPC IR (always the element scalar type).
|
||||
/// E.g. "float", "double", "int". Never a vector type name — the IR is
|
||||
/// element-type–centric; the backend decides the concrete vector width.
|
||||
/// </summary>
|
||||
internal sealed record HPCType(string ElementTypeName, bool IsFloatingPoint)
|
||||
{
|
||||
// Convenience singletons for the most common cases
|
||||
public static readonly HPCType Float = new("float", IsFloatingPoint: true);
|
||||
public static readonly HPCType Double = new("double", IsFloatingPoint: true);
|
||||
public static readonly HPCType Int = new("int", IsFloatingPoint: false);
|
||||
public static readonly HPCType UInt = new("uint", IsFloatingPoint: false);
|
||||
public static readonly HPCType Long = new("long", IsFloatingPoint: false);
|
||||
|
||||
public static HPCType FromElementName(string name) => name switch
|
||||
{
|
||||
"float" or "Single" => Float,
|
||||
"double" or "Double" => Double,
|
||||
"int" or "Int32" => Int,
|
||||
"uint" or "UInt32" => UInt,
|
||||
"long" or "Int64" => Long,
|
||||
_ => new(name, IsFloatingPoint: false)
|
||||
};
|
||||
|
||||
public override string ToString() => ElementTypeName;
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────
|
||||
// Expression nodes
|
||||
// ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// <summary>Base type for all HPC IR expression nodes.</summary>
|
||||
internal abstract record HPCExpr(HPCType Type);
|
||||
|
||||
/// <summary>A reference to a local variable or parameter by name.</summary>
|
||||
internal sealed record HPCVarRef(string Name, HPCType Type) : HPCExpr(Type)
|
||||
{
|
||||
public override string ToString() => Name;
|
||||
}
|
||||
|
||||
/// <summary>A scalar literal value broadcast to all lanes.</summary>
|
||||
internal sealed record HPCLiteral(string Value, HPCType Type) : HPCExpr(Type)
|
||||
{
|
||||
public override string ToString() => Value;
|
||||
}
|
||||
|
||||
// ── Binary ops ───────────────────────────────────────────────────────────
|
||||
|
||||
internal enum HPCBinaryKind
|
||||
{
|
||||
Add, Subtract, Multiply, Divide, Modulo,
|
||||
BitwiseAnd, BitwiseOr, BitwiseXor, ShiftLeft, ShiftRight,
|
||||
Equal, NotEqual,
|
||||
LessThan, LessThanOrEqual,
|
||||
GreaterThan, GreaterThanOrEqual,
|
||||
}
|
||||
|
||||
internal sealed record HPCBinaryOp(
|
||||
HPCBinaryKind Kind,
|
||||
HPCExpr Left,
|
||||
HPCExpr Right,
|
||||
HPCType Type) : HPCExpr(Type);
|
||||
|
||||
// ── Unary ops ────────────────────────────────────────────────────────────
|
||||
|
||||
internal enum HPCUnaryKind
|
||||
{
|
||||
Negate, BitwiseNot,
|
||||
// Math functions that map to a single ISPMDLane static method
|
||||
Abs, Sqrt, Floor, Ceil, Round, Trunc, Frac, Sign, Saturate,
|
||||
Rcp, Rsqrt,
|
||||
Sin, Cos, Tan, Asin, Acos, Atan,
|
||||
Exp, Exp2, Log, Log2,
|
||||
}
|
||||
|
||||
internal sealed record HPCUnaryOp(
|
||||
HPCUnaryKind Kind,
|
||||
HPCExpr Operand,
|
||||
HPCType Type) : HPCExpr(Type);
|
||||
|
||||
// ── Intrinsic calls (multi-argument) ────────────────────────────────────
|
||||
|
||||
internal enum HPCIntrinsic
|
||||
{
|
||||
// Arithmetic
|
||||
MultiplyAdd, // a * b + c (FMA)
|
||||
MultiplySubtract, // a * b - c
|
||||
|
||||
// Math
|
||||
Atan2, Pow,
|
||||
SinCos, // out sin, out cos simultaneously
|
||||
Lerp, Min, Max, Clamp, Select, CopySign,
|
||||
|
||||
// Reduction
|
||||
ReduceAdd, ReduceMax, ReduceMin,
|
||||
|
||||
// Memory
|
||||
Load, Store,
|
||||
MaskLoad, MaskStore,
|
||||
Gather, MaskGather,
|
||||
Scatter, MaskScatter,
|
||||
CompressStore,
|
||||
|
||||
// Conversion
|
||||
Cast, BitCast,
|
||||
}
|
||||
|
||||
internal sealed record HPCIntrinsicCall(
|
||||
HPCIntrinsic Intrinsic,
|
||||
HPCExpr[] Args,
|
||||
HPCType Type) : HPCExpr(Type);
|
||||
|
||||
/// <summary>
|
||||
/// A method call that the HPC pipeline does not recognise as an intrinsic —
|
||||
/// emitted verbatim so user code can still call arbitrary helpers.
|
||||
/// </summary>
|
||||
internal sealed record HPCPassThroughCall(
|
||||
string MethodName,
|
||||
HPCExpr[] Args,
|
||||
HPCType Type) : HPCExpr(Type);
|
||||
|
||||
/// <summary>Member-property access, e.g. <c>lane.LaneWidth</c> → <c>Count</c>.</summary>
|
||||
internal sealed record HPCPropertyAccess(
|
||||
HPCExpr Target,
|
||||
string PropertyName,
|
||||
HPCType Type) : HPCExpr(Type);
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────
|
||||
// Statement nodes
|
||||
// ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// <summary>Base type for all HPC IR statement nodes.</summary>
|
||||
internal abstract record HPCStmt;
|
||||
|
||||
/// <summary>Local variable declaration with mandatory initialiser.</summary>
|
||||
internal sealed record HPCVarDecl(
|
||||
string Name,
|
||||
HPCType Type,
|
||||
HPCExpr Initializer) : HPCStmt;
|
||||
|
||||
/// <summary>Assignment to an existing variable (including out-parameters).</summary>
|
||||
internal sealed record HPCAssignment(string Target, HPCExpr Value) : HPCStmt;
|
||||
|
||||
/// <summary>Expression evaluated purely for its side-effect (e.g. Store calls).</summary>
|
||||
internal sealed record HPCExprStmt(HPCExpr Expression) : HPCStmt;
|
||||
|
||||
/// <summary>Return statement.</summary>
|
||||
internal sealed record HPCReturn(HPCExpr? Value) : HPCStmt;
|
||||
|
||||
/// <summary>
|
||||
/// Conditional branch. In a vectorised context the backend may lower this
|
||||
/// to a predicated Select or a masked execution block.
|
||||
/// </summary>
|
||||
internal sealed record HPCIf(
|
||||
HPCExpr Condition,
|
||||
HPCStmt[] ThenBody,
|
||||
HPCStmt[]? ElseBody) : HPCStmt;
|
||||
|
||||
/// <summary>Simple counted for-loop.</summary>
|
||||
internal sealed record HPCForLoop(
|
||||
HPCVarDecl Iterator,
|
||||
HPCExpr Condition,
|
||||
HPCExpr Increment,
|
||||
HPCStmt[] Body) : HPCStmt;
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────
|
||||
// Method-level IR
|
||||
// ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
internal sealed record HPCParameter(
|
||||
string Name,
|
||||
HPCType Type,
|
||||
bool IsOut,
|
||||
bool IsRef);
|
||||
|
||||
/// <summary>
|
||||
/// The complete IR representation of a single <c>[HPCompute]</c>-annotated method,
|
||||
/// fully detached from Roslyn syntax after the analysis phase.
|
||||
/// </summary>
|
||||
internal sealed record HPCMethodIR
|
||||
{
|
||||
public required string Name { get; init; }
|
||||
public required HPCType ReturnType { get; init; }
|
||||
public required HPCParameter[] Parameters { get; init; }
|
||||
public required HPCStmt[] Body { get; init; }
|
||||
|
||||
// Compilation metadata from the attribute
|
||||
public required TargetInstructionSet TargetISA { get; init; }
|
||||
public required FloatPrecision Precision { get; init; }
|
||||
public required MathMode Mode { get; init; }
|
||||
|
||||
// Emission metadata
|
||||
public required string ContainingNamespace { get; init; }
|
||||
public required string ContainingTypeName { get; init; }
|
||||
public required string OriginalName { get; init; }
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,11 @@
|
||||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<TargetFramework>netstandard2.0</TargetFramework>
|
||||
<Nullable>enable</Nullable>
|
||||
<EnforceExtendedAnalyzerRules>True</EnforceExtendedAnalyzerRules>
|
||||
<AllowUnsafeBlocks>True</AllowUnsafeBlocks>
|
||||
<LangVersion>9.0</LangVersion>
|
||||
<LangVersion>latest</LangVersion>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
using Misaki.HighPerformance.HPC.Generator.IR;
|
||||
using System.Linq;
|
||||
|
||||
namespace Misaki.HighPerformance.HPC.Generator.Optimization
|
||||
{
|
||||
/// <summary>
|
||||
/// Detects <c>(a * b) + c</c> and <c>c + (a * b)</c> binary expression patterns
|
||||
/// in the IR and replaces them with <see cref="HPCIntrinsic.MultiplyAdd"/> calls,
|
||||
/// enabling the backend to emit a single FMA instruction instead of two.
|
||||
///
|
||||
/// <para>This pass only runs when the target ISA supports FMA (e.g. AVX2 which
|
||||
/// includes FMA3 via the FMA extension, and AVX-512F). The pipeline checks
|
||||
/// <see cref="HPCMethodIR.TargetISA"/> before adding the pass.</para>
|
||||
/// </summary>
|
||||
internal sealed class FMAFusionPass : HPCNodeRewriter, IHPCOptimizationPass
|
||||
{
|
||||
public string Name => "FMA Fusion";
|
||||
|
||||
public HPCMethodIR Transform(HPCMethodIR method)
|
||||
{
|
||||
var newBody = RewriteBody(method.Body);
|
||||
return ReferenceEquals(newBody, method.Body)
|
||||
? method
|
||||
: method with { Body = newBody };
|
||||
}
|
||||
|
||||
protected override HPCExpr RewriteBinaryOp(HPCBinaryOp node)
|
||||
{
|
||||
if (node.Kind == HPCBinaryKind.Add)
|
||||
{
|
||||
// (a * b) + c → FMA(a, b, c)
|
||||
if (node.Left is HPCBinaryOp { Kind: HPCBinaryKind.Multiply } mulLeft)
|
||||
{
|
||||
return new HPCIntrinsicCall(
|
||||
HPCIntrinsic.MultiplyAdd,
|
||||
[RewriteExpr(mulLeft.Left), RewriteExpr(mulLeft.Right), RewriteExpr(node.Right)],
|
||||
node.Type);
|
||||
}
|
||||
|
||||
// c + (a * b) → FMA(a, b, c)
|
||||
if (node.Right is HPCBinaryOp { Kind: HPCBinaryKind.Multiply } mulRight)
|
||||
{
|
||||
return new HPCIntrinsicCall(
|
||||
HPCIntrinsic.MultiplyAdd,
|
||||
[RewriteExpr(mulRight.Left), RewriteExpr(mulRight.Right), RewriteExpr(node.Left)],
|
||||
node.Type);
|
||||
}
|
||||
}
|
||||
|
||||
if (node.Kind == HPCBinaryKind.Subtract)
|
||||
{
|
||||
// (a * b) - c → FMA(a, b, -c) ... or MultiplySubtract if available
|
||||
if (node.Left is HPCBinaryOp { Kind: HPCBinaryKind.Multiply } mulLeft)
|
||||
{
|
||||
return new HPCIntrinsicCall(
|
||||
HPCIntrinsic.MultiplySubtract,
|
||||
[RewriteExpr(mulLeft.Left), RewriteExpr(mulLeft.Right), RewriteExpr(node.Right)],
|
||||
node.Type);
|
||||
}
|
||||
}
|
||||
|
||||
return base.RewriteBinaryOp(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
using Misaki.HighPerformance.HPC.Generator.IR;
|
||||
|
||||
namespace Misaki.HighPerformance.HPC.Generator.Optimization
|
||||
{
|
||||
/// <summary>
|
||||
/// A single optimization pass that transforms an <see cref="HPCMethodIR"/>
|
||||
/// into another <see cref="HPCMethodIR"/>. Passes must be pure functions —
|
||||
/// they must not mutate the input and should return the original object
|
||||
/// unchanged when nothing was modified (enabling cheap "no-change" detection
|
||||
/// in the pipeline).
|
||||
/// </summary>
|
||||
internal interface IHPCOptimizationPass
|
||||
{
|
||||
/// <summary>Human-readable name for diagnostics.</summary>
|
||||
string Name { get; }
|
||||
|
||||
/// <summary>Transforms the IR. Returns the same object if unchanged.</summary>
|
||||
HPCMethodIR Transform(HPCMethodIR method);
|
||||
}
|
||||
}
|
||||
31
Misaki.HighPerformance.HPC.Generator/Polyfills.cs
Normal file
31
Misaki.HighPerformance.HPC.Generator/Polyfills.cs
Normal file
@@ -0,0 +1,31 @@
|
||||
// Polyfills needed for modern C# features on netstandard2.0 target.
|
||||
// These types are built into .NET 5+ runtimes but must be declared manually
|
||||
// when targeting netstandard2.0, which source generators must do.
|
||||
|
||||
namespace System.Runtime.CompilerServices
|
||||
{
|
||||
// Enables `init` accessors and `record` types (C# 9)
|
||||
internal static class IsExternalInit { }
|
||||
|
||||
// Enables `required` members (C# 11)
|
||||
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct |
|
||||
AttributeTargets.Field | AttributeTargets.Property,
|
||||
AllowMultiple = false, Inherited = false)]
|
||||
internal sealed class RequiredMemberAttribute : Attribute { }
|
||||
|
||||
[AttributeUsage(AttributeTargets.All, AllowMultiple = true, Inherited = false)]
|
||||
internal sealed class CompilerFeatureRequiredAttribute : Attribute
|
||||
{
|
||||
public CompilerFeatureRequiredAttribute(string featureName)
|
||||
=> FeatureName = featureName;
|
||||
public string FeatureName { get; }
|
||||
public bool IsOptional { get; init; }
|
||||
}
|
||||
}
|
||||
|
||||
namespace System.Diagnostics.CodeAnalysis
|
||||
{
|
||||
// Enables `required` constructor attribution (C# 11)
|
||||
[AttributeUsage(AttributeTargets.Constructor, AllowMultiple = false, Inherited = false)]
|
||||
internal sealed class SetsRequiredMembersAttribute : Attribute { }
|
||||
}
|
||||
Reference in New Issue
Block a user