Compare commits

...

4 Commits

Author SHA1 Message Date
e98ae96dd6 Refactor: switch to IR-based HPC codegen pipeline
Major rewrite replacing Roslyn syntax rewriters with an intermediate representation (IR) architecture. Adds IR nodes, analyzer, type resolver, and node rewriter base for optimization passes (e.g., FMA fusion). Refactors AVX2 backend to emit from IR and updates generator pipeline for analysis, optimization, and emission. Removes legacy rewriter classes. Adds polyfills for modern C# features and updates tests and project settings for latest language version. This enables advanced optimizations, easier backend targeting, and future extensibility.
2026-05-07 00:21:08 +09:00
b9537d91da Refactor APIContext, add AVX2 sin/cos codegen, FMA rewrite
- Refactored namespaces from .VectorAPI to .APIContext for clarity.
- Enhanced Avx2APIContext/IVectorAPIContext to support void returns.
- Added GenerateSinCosUtilityMethods for AVX2, emitting vectorized Sin/Cos/SinCos for float/double.
- Introduced HPCOptimizerRewriter for advanced SPMD type handling.
- Refactored HPCRewriter to use SemanticModel, support FMA pattern rewriting, and delegate SPMD logic.
- Updated AVX2Rewriter for new base and improved math mapping.
- Made UtilityTemplate generic and type-safe for sin/cos.
- Updated NoiseJob3D/NoiseJobVector for [HPCompute] attribute and partial struct.
- Fixed solution file project ordering and inclusion.
2026-05-06 22:27:24 +09:00
fd2d60c8f1 Refactor vector API codegen and WideLane conversions
- Introduce IVectorAPIContext abstraction and supporting types for vectorized code generation
- Add Avx2APIContext and UtilityTemplate for AVX2-specific code emission
- Dynamically generate AVX2 sine methods in AVX2Rewriter
- Refactor WideLane<TNumber> to use Unsafe.BitCast for all Vector conversions
- Update all WideLane operators and math methods to use Unsafe.BitCast
- Change MultiplyAdd parameter names for clarity
- Remove static indices field in favor of Vector<TNumber>.Indices
- Add implicit conversion from Vector<TNumber> to WideLane<TNumber>
- Update tests and program files for compatibility
2026-05-06 19:20:15 +09:00
c8f78f9d02 Refactor SPMD to HPC; add SIMD source generators
Major namespace migration from SPMD to HPC across all code, templates, and projects. Introduced Misaki.HighPerformance.HPC.Generator with Roslyn-based source generators for SIMD code (e.g., AVX2), including attribute and method generators. Renamed MultipleAdd to MultiplyAdd in all lanes and updated usages. Added AVX2 utility methods via codegen. Updated tests, benchmarks, and project references to use the new framework. Improved SIMD memory utilities and modernized project files. Removed legacy SPMD project from the solution.
2026-05-06 13:43:58 +09:00
48 changed files with 2639 additions and 196 deletions

View File

@@ -0,0 +1,119 @@
using System.Collections.Generic;
namespace Misaki.HighPerformance.HPC.Generator.APIContext
{
/// <summary>
/// <see cref="IVectorAPIContext"/> implementation that builds C# source text
/// around <c>Vector256</c>, <c>Avx2</c> and <c>Fma</c> calls.
/// Statements are buffered by <see cref="Assign"/> and <see cref="Return"/>;
/// <see cref="Return"/> hands the buffered statements out as a <see cref="Code"/>
/// object and resets the context so it can be reused for the next method.
/// </summary>
internal class Avx2APIContext : IVectorAPIContext
{
    /// <summary>Statements recorded since the last <see cref="Reset"/>.</summary>
    private readonly List<string> _statements = new();

    /// <summary>Counter used to name auto-generated variables (<c>v0</c>, <c>v1</c>, …).</summary>
    private int _varCount = 0;

    private string? _lastAssignedVariable;

    /// <summary>Name of the variable written by the most recent <see cref="Assign"/>, or <c>null</c> after a reset.</summary>
    public string? LastAssignedVariable => _lastAssignedVariable;

    /// <summary>Returns the non-generic vector type name.</summary>
    public string GetVectorType()
    {
        return "Vector256";
    }

    /// <summary>Returns the closed vector type name for element type <typeparamref name="T"/>.</summary>
    /// <exception cref="System.NotSupportedException">Thrown by <see cref="VectorAPIContext.GetTypeName{T}"/> for unsupported element types.</exception>
    public string GetVectorType<T>()
    {
        return $"Vector256<{VectorAPIContext.GetTypeName<T>()}>";
    }

    /// <summary>Builds a call expression <c>methodName(arg0, arg1, …)</c>.</summary>
    public Expression Call(string methodName, params string[] args)
    {
        return new Expression(this, $"{methodName}({string.Join(", ", args)})");
    }

    /// <summary>
    /// Records an assignment statement for <paramref name="expr"/> and returns an
    /// expression that refers to the assigned variable.
    /// </summary>
    /// <param name="expr">Right-hand side; its code is cleared once it has been recorded.</param>
    /// <param name="varName">Target variable; a fresh <c>vN</c> name is generated when <c>null</c>.</param>
    /// <param name="isNew"><c>true</c> to emit a <c>var</c> declaration, <c>false</c> to assign to an existing variable.</param>
    public Expression Assign(Expression expr, string? varName = null, bool isNew = true)
    {
        varName ??= $"v{_varCount++}";
        var statement = isNew ? $"var {varName} = {expr.Code};" : $"{varName} = {expr.Code};";
        _statements.Add(statement);
        _lastAssignedVariable = varName;
        expr.Clear();
        return new Expression(this, varName);
    }

    /// <summary>
    /// Finishes the current statement sequence. When <paramref name="expr"/> is not
    /// <c>null</c> a <c>return</c> statement is appended first. The context is reset
    /// afterwards.
    /// </summary>
    public Code Return(Expression? expr)
    {
        if (expr != null)
        {
            var statement = $"return {expr.Code};";
            _statements.Add(statement);
            expr.Clear();
        }

        // Code copies the statements in its constructor, so resetting afterwards is safe.
        var fullCode = new Code(_statements);
        Reset();
        return fullCode;
    }

    /// <summary>Builds <c>Vector256.Create(value)</c>.</summary>
    public Expression Create(string value)
    {
        return new Expression(this, $"{GetVectorType()}.Create({value})");
    }

    /// <summary>Builds <c>Vector256&lt;T&gt;.Zero</c>.</summary>
    public Expression Zero<T>()
    {
        return new Expression(this, $"{GetVectorType<T>()}.Zero");
    }

    /// <summary>Builds <c>Vector256&lt;T&gt;.One</c>.</summary>
    public Expression One<T>()
    {
        return new Expression(this, $"{GetVectorType<T>()}.One");
    }

    /// <summary>Builds <c>Vector256&lt;T&gt;.Count</c>.</summary>
    public Expression Count<T>()
    {
        return new Expression(this, $"{GetVectorType<T>()}.Count");
    }

    /// <summary>Builds <c>Avx2.Add(a, b)</c>.</summary>
    public Expression Add(Expression a, Expression b)
    {
        return new Expression(this, $"Avx2.Add({a}, {b})");
    }

    /// <summary>Builds <c>Avx2.Multiply(a, b)</c>.</summary>
    public Expression Multiply(Expression a, Expression b)
    {
        return new Expression(this, $"Avx2.Multiply({a}, {b})");
    }

    /// <summary>Builds <c>Avx2.Subtract(a, b)</c>.</summary>
    public Expression Subtract(Expression a, Expression b)
    {
        return new Expression(this, $"Avx2.Subtract({a}, {b})");
    }

    /// <summary>Builds <c>Avx2.Divide(a, b)</c>.</summary>
    public Expression Divide(Expression a, Expression b)
    {
        return new Expression(this, $"Avx2.Divide({a}, {b})");
    }

    /// <summary>Builds <c>Fma.MultiplyAdd(left, right, addend)</c>.</summary>
    public Expression MultiplyAdd(Expression left, Expression right, Expression addend)
    {
        return new Expression(this, $"Fma.MultiplyAdd({left}, {right}, {addend})");
    }

    /// <summary>Builds <c>Avx2.RoundToNearestInteger(value)</c>.</summary>
    public Expression Round(Expression value)
    {
        return new Expression(this, $"Avx2.RoundToNearestInteger({value})");
    }

    /// <summary>Builds <c>Vector256.Abs(value)</c>.</summary>
    public Expression Abs(Expression value)
    {
        return new Expression(this, $"{GetVectorType()}.Abs({value})");
    }

    /// <summary>
    /// Discards all buffered statements and returns the context to its initial
    /// state, including <see cref="LastAssignedVariable"/>, so that no state from
    /// a previously built method leaks into the next one.
    /// </summary>
    public void Reset()
    {
        _statements.Clear();
        _varCount = 0;
        _lastAssignedVariable = null;
    }
}
}

View File

@@ -0,0 +1,179 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace Misaki.HighPerformance.HPC.Generator.APIContext
{
/// <summary>
/// A fragment of generated source text together with the
/// <see cref="IVectorAPIContext"/> that created it. The arithmetic operators
/// forward to the context of the left operand.
/// </summary>
internal class Expression
{
    /// <summary>Creates an expression holding <paramref name="code"/> for context <paramref name="api"/>.</summary>
    public Expression(IVectorAPIContext api, string code)
    {
        API = api;
        Code = code;
    }

    /// <summary>Context that produced this expression.</summary>
    public IVectorAPIContext API { get; }

    /// <summary>Source text of the expression; becomes empty after <see cref="Clear"/>.</summary>
    public string Code { get; private set; }

    /// <summary>Forwards to <see cref="IVectorAPIContext.Assign"/> with this expression as the value.</summary>
    public Expression Assign(string? varName = null, bool isNew = true)
        => API.Assign(this, varName, isNew);

    /// <summary>Replaces <see cref="Code"/> with the empty string.</summary>
    public void Clear() => Code = string.Empty;

    /// <summary>Returns <see cref="Code"/>, which lets expressions be used directly in interpolated strings.</summary>
    public override string ToString() => Code;

    public static Expression operator +(Expression a, Expression b) => a.API.Add(a, b);

    public static Expression operator -(Expression a, Expression b) => a.API.Subtract(a, b);

    public static Expression operator *(Expression a, Expression b) => a.API.Multiply(a, b);

    public static Expression operator /(Expression a, Expression b) => a.API.Divide(a, b);
}
/// <summary>
/// An ordered list of generated statements. The statements are copied on
/// construction, so later changes to the source sequence do not affect it.
/// </summary>
internal record Code
{
    private readonly string[] _statements;

    /// <summary>Copies <paramref name="statements"/> into a private array.</summary>
    public Code(IEnumerable<string> statements) => _statements = statements.ToArray();

    /// <summary>
    /// Renders every statement on its own line, each prefixed with
    /// <paramref name="lineIndentation"/> and terminated by the platform newline.
    /// </summary>
    public string GetFullCode(string lineIndentation)
    {
        var builder = new StringBuilder();
        for (var index = 0; index < _statements.Length; index++)
        {
            builder.AppendLine(lineIndentation + _statements[index]);
        }

        return builder.ToString();
    }
}
/// <summary>
/// Describes one generated method (modifier, return type, name, parameter
/// list and body) and renders it as source text.
/// </summary>
internal record Method
{
    /// <summary>Text placed before the return type in the signature.</summary>
    public string Modifier { get; }

    /// <summary>Return type text.</summary>
    public string ReturnType { get; }

    /// <summary>Method name.</summary>
    public string Name { get; }

    /// <summary>Parameter declarations; joined with <c>", "</c> when rendered.</summary>
    public string[] Parameters { get; }

    /// <summary>Statements that form the method body.</summary>
    public Code Body { get; }

    public Method(string modifier, string returnType, string name, string[] parameters, Code body)
    {
        Modifier = modifier;
        ReturnType = returnType;
        Name = name;
        Parameters = parameters;
        Body = body;
    }

    /// <summary>
    /// Renders the signature, an opening brace, the body and a closing brace.
    /// The signature and braces use <paramref name="lineIndentation"/>; the body
    /// uses that indentation extended by the literal suffix below.
    /// </summary>
    public string GetFullCode(string lineIndentation)
    {
        var signature = $"{Modifier} {ReturnType} {Name}({string.Join(", ", Parameters)})";

        return new StringBuilder()
            .AppendLine(lineIndentation + signature)
            .AppendLine(lineIndentation + "{")
            .Append(Body.GetFullCode(lineIndentation + " "))
            .AppendLine(lineIndentation + "}")
            .ToString();
    }
}
/// <summary>
/// Abstraction used to build vectorised source text without hard-coding a
/// particular instruction set. Members return <see cref="Expression"/> fragments;
/// <see cref="Assign"/> and <see cref="Return"/> deal with whole statements.
/// </summary>
internal interface IVectorAPIContext
{
    /// <summary>Name of the variable targeted by the most recent <see cref="Assign"/>, if any.</summary>
    string? LastAssignedVariable { get; }

    /// <summary>Returns the non-generic vector type name of the target.</summary>
    string GetVectorType();

    /// <summary>Returns the vector type name closed over element type <typeparamref name="T"/>.</summary>
    string GetVectorType<T>();

    /// <summary>Builds a call to <paramref name="methodName"/> with the given argument texts.</summary>
    Expression Call(string methodName, params string[] args);

    /// <summary>
    /// Assigns <paramref name="expr"/> to <paramref name="varName"/> and returns an
    /// expression referring to that variable. <paramref name="isNew"/> selects
    /// between declaring a new variable and assigning to an existing one.
    /// </summary>
    Expression Assign(Expression expr, string? varName = null, bool isNew = true);

    /// <summary>Completes the current statement sequence, optionally returning <paramref name="expr"/>.</summary>
    Code Return(Expression? expr);

    /// <summary>Builds a vector-creation expression from <paramref name="value"/>.</summary>
    Expression Create(string value);

    /// <summary>Builds the zero vector of element type <typeparamref name="T"/>.</summary>
    Expression Zero<T>();

    /// <summary>Builds the all-ones-valued vector of element type <typeparamref name="T"/>.</summary>
    Expression One<T>();

    /// <summary>Builds the element-count expression for element type <typeparamref name="T"/>.</summary>
    Expression Count<T>();

    /// <summary>Builds <c>a + b</c>.</summary>
    Expression Add(Expression a, Expression b);

    /// <summary>Builds <c>a - b</c>.</summary>
    Expression Subtract(Expression a, Expression b);

    /// <summary>Builds <c>a * b</c>.</summary>
    Expression Multiply(Expression a, Expression b);

    /// <summary>Builds <c>a / b</c>.</summary>
    Expression Divide(Expression a, Expression b);

    /// <summary>Builds a fused multiply-add of <paramref name="left"/>, <paramref name="right"/> and <paramref name="addend"/>.</summary>
    Expression MultiplyAdd(Expression left, Expression right, Expression addend);

    /// <summary>Builds a rounding expression for <paramref name="value"/>.</summary>
    Expression Round(Expression value);

    /// <summary>Builds an absolute-value expression for <paramref name="value"/>.</summary>
    Expression Abs(Expression value);

    /// <summary>Clears the state accumulated by the context.</summary>
    void Reset();
}
/// <summary>Helpers shared by <see cref="IVectorAPIContext"/> implementations.</summary>
internal static class VectorAPIContext
{
    /// <summary>
    /// Returns the C# keyword for the primitive element type <typeparamref name="T"/>.
    /// Supported types are <c>float</c>, <c>double</c>, <c>byte</c>, <c>sbyte</c>,
    /// <c>short</c>, <c>ushort</c>, <c>int</c>, <c>uint</c>, <c>long</c> and <c>ulong</c>.
    /// </summary>
    /// <exception cref="NotSupportedException"><typeparamref name="T"/> is any other type.</exception>
    public static string GetTypeName<T>()
    {
        var type = typeof(T);

        if (type == typeof(float)) return "float";
        if (type == typeof(double)) return "double";
        if (type == typeof(byte)) return "byte";
        // sbyte and ushort were previously rejected although they are valid vector element types.
        if (type == typeof(sbyte)) return "sbyte";
        if (type == typeof(short)) return "short";
        if (type == typeof(ushort)) return "ushort";
        if (type == typeof(int)) return "int";
        if (type == typeof(uint)) return "uint";
        if (type == typeof(long)) return "long";
        if (type == typeof(ulong)) return "ulong";

        throw new NotSupportedException($"Type {typeof(T)} is not supported in vector operations.");
    }
}
}

View File

@@ -0,0 +1,43 @@
using Microsoft.CodeAnalysis;
using Misaki.HighPerformance.HPC.Generator.APIContext;
namespace Misaki.HighPerformance.HPC.Generator
{
/// <summary>
/// Incremental generator that adds the <c>AVX2Utility</c> static class as a
/// post-initialization output named <c>AVX2Utility.g.cs</c>.
///
/// <para>The class body consists of the methods produced by
/// <c>UtilityTemplate.GenerateSinCosUtilityMethods</c> using an
/// <see cref="Avx2APIContext"/>. These helpers are intended for IR nodes such as
/// <see cref="IR.HPCUnaryKind.Sin"/> and <see cref="IR.HPCUnaryKind.Cos"/>,
/// for which the AVX2 backend has no built-in intrinsic.</para>
/// </summary>
[Generator]
public class AVX2UtilityGenerator : IIncrementalGenerator
{
    /// <summary>Registers the post-initialization callback that emits <c>AVX2Utility.g.cs</c>.</summary>
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        context.RegisterPostInitializationOutput(static postInit =>
        {
            var sinCosMethods = UtilityTemplate.GenerateSinCosUtilityMethods(new Avx2APIContext(), " ");

            // The literal below is emitted verbatim; its layout is part of the generated file.
            var source = @$"
using System;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
namespace Misaki.HighPerformance.HPC
{{
public static class AVX2Utility
{{
{sinCosMethods}
}}
}}";

            postInit.AddSource("AVX2Utility.g.cs", source);
        });
    }
}
}

View File

@@ -0,0 +1,461 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Misaki.HighPerformance.HPC.Generator.IR;
using System;
using System.Collections.Generic;
namespace Misaki.HighPerformance.HPC.Generator.Analysis
{
/// <summary>
/// Walks a Roslyn <see cref="MethodDeclarationSyntax"/> and produces an
/// <see cref="HPCMethodIR"/> that is completely independent of Roslyn after
/// this phase completes.
///
/// <para>Uses <see cref="CSharpSyntaxWalker"/> (read-only walk) rather than
/// <see cref="CSharpSyntaxRewriter"/> so that the semantic model is never
/// consulted during rewriting — all queries happen here, once, upfront.</para>
/// </summary>
internal sealed class HPCAnalyzer : CSharpSyntaxWalker
{
// ── State ─────────────────────────────────────────────────────────────
private readonly HPCTypeResolver _typeResolver;
/// <summary>Statement stack — top is the current block being built.</summary>
private readonly Stack<List<HPCStmt>> _blocks = new();
// ── Construction ─────────────────────────────────────────────────────
/// <summary>Creates an analyzer whose type resolver is backed by <paramref name="semanticModel"/>.</summary>
public HPCAnalyzer(SemanticModel semanticModel)
    => _typeResolver = new HPCTypeResolver(semanticModel);
// ── Public API ───────────────────────────────────────────────────────
/// <summary>
/// Analyses <paramref name="method"/> and returns the completed IR.
/// </summary>
/// <param name="method">Method declaration to walk; block bodies and expression bodies are both supported.</param>
/// <param name="info">Supplies the instruction set, precision, mode and the method symbol copied into the IR.</param>
/// <returns>An <see cref="HPCMethodIR"/> holding the analysed body and parameters.</returns>
public HPCMethodIR Analyze(MethodDeclarationSyntax method, HPComputeMethodInfo info)
{
    _typeResolver.RegisterConstraints(method);

    // Start from a clean stack with a single root block.
    _blocks.Clear();
    _blocks.Push(new List<HPCStmt>());

    if (method.Body is not null)
    {
        Visit(method.Body);
    }
    else if (method.ExpressionBody is not null)
    {
        var expr = AnalyzeExpression(method.ExpressionBody.Expression);

        // `void M() => Foo();` must stay an expression statement; wrapping it in
        // `return` would make the generated method invalid.
        var returnsVoid = method.ReturnType is PredefinedTypeSyntax predefined
            && predefined.Keyword.IsKind(SyntaxKind.VoidKeyword);

        if (returnsVoid)
        {
            Emit(new HPCExprStmt(expr));
        }
        else
        {
            // Expression-bodied method: treat as `return <expr>;`
            Emit(new HPCReturn(expr));
        }
    }

    var body = _blocks.Pop().ToArray();

    return new HPCMethodIR
    {
        Name = method.Identifier.Text,
        OriginalName = method.Identifier.Text,
        ReturnType = ResolveReturnType(method),
        Parameters = BuildParameters(method),
        Body = body,
        TargetISA = info.InstructionSet,
        Precision = info.Precision,
        Mode = info.Mode,
        ContainingNamespace = info.MethodSymbol.ContainingNamespace.ToDisplayString(),
        ContainingTypeName = info.MethodSymbol.ContainingType.Name,
    };
}
// ── Statement visitors ───────────────────────────────────────────────
/// <summary>Visits the statements of <paramref name="node"/> in source order.</summary>
public override void VisitBlock(BlockSyntax node)
{
    var statements = node.Statements;
    for (var index = 0; index < statements.Count; index++)
    {
        Visit(statements[index]);
    }
}
/// <summary>
/// Emits one <see cref="HPCVarDecl"/> per declared variable that has an
/// initializer. Variables without an initializer produce no IR statement.
/// </summary>
public override void VisitLocalDeclarationStatement(LocalDeclarationStatementSyntax node)
{
    foreach (var variable in node.Declaration.Variables)
    {
        var valueSyntax = variable.Initializer?.Value;
        if (valueSyntax is null)
        {
            continue;
        }

        var initializer = AnalyzeExpression(valueSyntax);

        // Prefer the semantic type; fall back to the type carried by the analysed initializer.
        var declaredType = _typeResolver.Resolve(valueSyntax) ?? InferFromExpr(initializer);

        Emit(new HPCVarDecl(variable.Identifier.Text, declaredType, initializer));
    }
}
/// <summary>
/// Converts an expression statement into IR.
/// <list type="bullet">
/// <item>Simple assignment <c>x = e</c> becomes <see cref="HPCAssignment"/>.</item>
/// <item>Arithmetic, bitwise and shift compound assignments <c>x op= e</c> are
/// desugared to <c>x = x op e</c>, so the operator is not lost.</item>
/// <item>Everything else, including other assignment forms such as <c>??=</c>,
/// is analysed as a plain expression and wrapped in <see cref="HPCExprStmt"/>.</item>
/// </list>
/// </summary>
public override void VisitExpressionStatement(ExpressionStatementSyntax node)
{
    if (node.Expression is AssignmentExpressionSyntax assign)
    {
        var target = assign.Left.ToString(); // simple name or member access text

        if (assign.IsKind(SyntaxKind.SimpleAssignmentExpression))
        {
            var value = AnalyzeExpression(assign.Right);
            Emit(new HPCAssignment(target, value));
            return;
        }

        HPCBinaryKind? compoundKind = assign.Kind() switch
        {
            SyntaxKind.AddAssignmentExpression => HPCBinaryKind.Add,
            SyntaxKind.SubtractAssignmentExpression => HPCBinaryKind.Subtract,
            SyntaxKind.MultiplyAssignmentExpression => HPCBinaryKind.Multiply,
            SyntaxKind.DivideAssignmentExpression => HPCBinaryKind.Divide,
            SyntaxKind.ModuloAssignmentExpression => HPCBinaryKind.Modulo,
            SyntaxKind.AndAssignmentExpression => HPCBinaryKind.BitwiseAnd,
            SyntaxKind.OrAssignmentExpression => HPCBinaryKind.BitwiseOr,
            SyntaxKind.ExclusiveOrAssignmentExpression => HPCBinaryKind.BitwiseXor,
            SyntaxKind.LeftShiftAssignmentExpression => HPCBinaryKind.ShiftLeft,
            SyntaxKind.RightShiftAssignmentExpression => HPCBinaryKind.ShiftRight,
            _ => null,
        };

        if (compoundKind is not null)
        {
            var left = AnalyzeExpression(assign.Left);
            var right = AnalyzeExpression(assign.Right);
            var hpcType = _typeResolver.Resolve(assign.Left) ?? InferFromExpr(left);
            Emit(new HPCAssignment(target, new HPCBinaryOp(compoundKind.Value, left, right, hpcType)));
            return;
        }

        // Unsupported assignment form: fall through and keep it as an expression statement.
    }

    var expr = AnalyzeExpression(node.Expression);
    Emit(new HPCExprStmt(expr));
}
/// <summary>
/// Emits an <see cref="HPCReturn"/>. The value is <c>null</c> for a bare
/// <c>return;</c>, otherwise the analysed return expression.
/// </summary>
public override void VisitReturnStatement(ReturnStatementSyntax node)
{
    HPCExpr? result = null;
    if (node.Expression is not null)
    {
        result = AnalyzeExpression(node.Expression);
    }

    Emit(new HPCReturn(result));
}
/// <summary>
/// Emits an <see cref="HPCIf"/>. The condition is analysed first, then the
/// <c>then</c> branch and (when present) the <c>else</c> branch are each
/// collected into their own statement array.
/// </summary>
public override void VisitIfStatement(IfStatementSyntax node)
{
    // Visits a branch with a fresh block on the stack and returns what it emitted.
    HPCStmt[] CollectBranch(StatementSyntax branch)
    {
        _blocks.Push(new List<HPCStmt>());
        Visit(branch);
        return _blocks.Pop().ToArray();
    }

    var condition = AnalyzeExpression(node.Condition);
    var thenBody = CollectBranch(node.Statement);

    HPCStmt[]? elseBody = node.Else is null
        ? null
        : CollectBranch(node.Else.Statement);

    Emit(new HPCIf(condition, thenBody, elseBody));
}
/// <summary>
/// Handles <c>for</c> loops. Only the canonical shape (exactly one declared
/// variable, a condition, exactly one incrementor) becomes an
/// <see cref="HPCForLoop"/>. For any other shape the loop header is not
/// represented in the IR: the body statements are emitted once into the
/// enclosing block.
/// </summary>
public override void VisitForStatement(ForStatementSyntax node)
{
    if (node is
        {
            Declaration: { Variables: { Count: 1 } variables },
            Condition: { } conditionSyntax,
            Incrementors: { Count: 1 } incrementors
        })
    {
        var iteratorSyntax = variables[0];

        // A missing initializer is modelled as the int literal 0.
        var start = iteratorSyntax.Initializer is not null
            ? AnalyzeExpression(iteratorSyntax.Initializer.Value)
            : new HPCLiteral("0", HPCType.Int);

        // The iterator is always declared with HPCType.Int.
        var iterator = new HPCVarDecl(iteratorSyntax.Identifier.Text, HPCType.Int, start);
        var condition = AnalyzeExpression(conditionSyntax);
        var increment = AnalyzeExpression(incrementors[0]);

        _blocks.Push(new List<HPCStmt>());
        Visit(node.Statement);
        var loopBody = _blocks.Pop().ToArray();

        Emit(new HPCForLoop(iterator, condition, increment, loopBody));
        return;
    }

    // Fall back: treat body as pass-through
    _blocks.Push(new List<HPCStmt>());
    Visit(node.Statement);
    var hoisted = _blocks.Pop();
    _blocks.Peek().AddRange(hoisted);
}
// ── Expression analysis ───────────────────────────────────────────────
/// <summary>
/// Dispatches on the syntax type of <paramref name="expr"/> to the matching
/// analysis routine. Parentheses are unwrapped; any syntax without a dedicated
/// routine goes to <see cref="FallbackExpr"/>.
/// </summary>
private HPCExpr AnalyzeExpression(ExpressionSyntax expr)
{
    switch (expr)
    {
        case BinaryExpressionSyntax binary:
            return AnalyzeBinary(binary);
        case PrefixUnaryExpressionSyntax prefix:
            return AnalyzePrefixUnary(prefix);
        case PostfixUnaryExpressionSyntax postfix:
            return AnalyzePostfixUnary(postfix);
        case InvocationExpressionSyntax invocation:
            return AnalyzeInvocation(invocation);
        case MemberAccessExpressionSyntax memberAccess:
            return AnalyzeMemberAccess(memberAccess);
        case LiteralExpressionSyntax literal:
            return AnalyzeLiteral(literal);
        case IdentifierNameSyntax identifier:
            return AnalyzeIdentifier(identifier);
        case ParenthesizedExpressionSyntax parenthesized:
            return AnalyzeExpression(parenthesized.Expression);
        case CastExpressionSyntax cast:
            return AnalyzeCast(cast);
        default:
            return FallbackExpr(expr);
    }
}
/// <summary>
/// Analyses both operands and builds an <see cref="HPCBinaryOp"/>. The node type
/// comes from the type resolver, falling back to the left operand's type.
/// </summary>
/// <exception cref="NotSupportedException">The operator has no <see cref="HPCBinaryKind"/> mapping.</exception>
private HPCExpr AnalyzeBinary(BinaryExpressionSyntax node)
{
    var lhs = AnalyzeExpression(node.Left);
    var rhs = AnalyzeExpression(node.Right);
    var resultType = _typeResolver.Resolve(node) ?? InferFromExpr(lhs);

    HPCBinaryKind op;
    switch (node.Kind())
    {
        case SyntaxKind.AddExpression: op = HPCBinaryKind.Add; break;
        case SyntaxKind.SubtractExpression: op = HPCBinaryKind.Subtract; break;
        case SyntaxKind.MultiplyExpression: op = HPCBinaryKind.Multiply; break;
        case SyntaxKind.DivideExpression: op = HPCBinaryKind.Divide; break;
        case SyntaxKind.ModuloExpression: op = HPCBinaryKind.Modulo; break;
        case SyntaxKind.BitwiseAndExpression: op = HPCBinaryKind.BitwiseAnd; break;
        case SyntaxKind.BitwiseOrExpression: op = HPCBinaryKind.BitwiseOr; break;
        case SyntaxKind.ExclusiveOrExpression: op = HPCBinaryKind.BitwiseXor; break;
        case SyntaxKind.LeftShiftExpression: op = HPCBinaryKind.ShiftLeft; break;
        case SyntaxKind.RightShiftExpression: op = HPCBinaryKind.ShiftRight; break;
        case SyntaxKind.EqualsExpression: op = HPCBinaryKind.Equal; break;
        case SyntaxKind.NotEqualsExpression: op = HPCBinaryKind.NotEqual; break;
        case SyntaxKind.LessThanExpression: op = HPCBinaryKind.LessThan; break;
        case SyntaxKind.LessThanOrEqualExpression: op = HPCBinaryKind.LessThanOrEqual; break;
        case SyntaxKind.GreaterThanExpression: op = HPCBinaryKind.GreaterThan; break;
        case SyntaxKind.GreaterThanOrEqualExpression: op = HPCBinaryKind.GreaterThanOrEqual; break;
        default:
            throw new NotSupportedException($"Binary operator not supported: {node.Kind()}");
    }

    return new HPCBinaryOp(op, lhs, rhs, resultType);
}
/// <summary>
/// Analyses a prefix unary expression. Only unary minus and bitwise NOT are mapped.
/// </summary>
/// <exception cref="NotSupportedException">Any other prefix operator.</exception>
private HPCExpr AnalyzePrefixUnary(PrefixUnaryExpressionSyntax node)
{
    var inner = AnalyzeExpression(node.Operand);
    var resultType = _typeResolver.Resolve(node) ?? InferFromExpr(inner);

    HPCUnaryKind op;
    switch (node.Kind())
    {
        case SyntaxKind.UnaryMinusExpression:
            op = HPCUnaryKind.Negate;
            break;
        case SyntaxKind.BitwiseNotExpression:
            op = HPCUnaryKind.BitwiseNot;
            break;
        default:
            throw new NotSupportedException($"Prefix unary not supported: {node.Kind()}");
    }

    return new HPCUnaryOp(op, inner, resultType);
}
/// <summary>
/// Analyses a postfix unary expression.
/// <list type="bullet">
/// <item><c>i++</c> / <c>i--</c> are modelled as the values <c>i + 1</c> / <c>i - 1</c>
/// (no mutation semantics in the IR).</item>
/// <item>The null-forgiving operator <c>x!</c> has no runtime effect and yields the operand unchanged.</item>
/// </list>
/// </summary>
/// <exception cref="NotSupportedException">Any other postfix operator.</exception>
private HPCExpr AnalyzePostfixUnary(PostfixUnaryExpressionSyntax node)
{
    var operand = AnalyzeExpression(node.Operand);

    HPCBinaryKind kind;
    switch (node.Kind())
    {
        case SyntaxKind.PostIncrementExpression:
            kind = HPCBinaryKind.Add;
            break;
        case SyntaxKind.PostDecrementExpression:
            kind = HPCBinaryKind.Subtract;
            break;
        case SyntaxKind.SuppressNullableWarningExpression:
            // Previously lowered to `operand - 1` because every non-increment kind was treated as a decrement.
            return operand;
        default:
            throw new NotSupportedException($"Postfix unary not supported: {node.Kind()}");
    }

    var hpcType = InferFromExpr(operand);
    var one = new HPCLiteral("1", hpcType);
    return new HPCBinaryOp(kind, operand, one, hpcType);
}
/// <summary>
/// Analyses a method call of the form <c>receiver.Method(args)</c>.
/// <list type="bullet">
/// <item>Static call on a registered SPMD type parameter (<c>TSelf.Sin(v)</c>,
/// <c>TSelf.Load(ref p)</c>): the arguments alone become the operands.</item>
/// <item>Instance call on an HPC-typed receiver (<c>lane.Abs()</c>,
/// <c>lane.MultiplyAdd(b, c)</c>): the receiver becomes the first operand.</item>
/// <item>Anything else is preserved through <see cref="FallbackExpr"/>.</item>
/// </list>
/// The static case is tested first. The type resolver also yields a type for a
/// type-parameter identifier, so testing the instance case first would treat the
/// type name itself as the first argument of a static intrinsic call.
/// </summary>
private HPCExpr AnalyzeInvocation(InvocationExpressionSyntax node)
{
    if (node.Expression is not MemberAccessExpressionSyntax memberAccess)
    {
        return FallbackExpr(node);
    }

    var methodName = memberAccess.Name.Identifier.Text;
    var arguments = node.ArgumentList.Arguments;

    // ── Static call on the lane type: TSelf.Sin(v), TSelf.Load(ref p), etc. ──
    if (memberAccess.Expression is IdentifierNameSyntax typeIdent &&
        _typeResolver.IsSPMDTypeParam(typeIdent.Identifier.Text))
    {
        var elemType = _typeResolver.ResolveByName(typeIdent.Identifier.Text) ?? HPCType.Float;

        if (TryMapUnaryMethod(methodName, out var staticUnary) && arguments.Count == 1)
        {
            var operand = AnalyzeExpression(arguments[0].Expression);
            return new HPCUnaryOp(staticUnary, operand, elemType);
        }

        if (TryMapIntrinsicMethod(methodName, out var staticIntrinsic))
        {
            var staticArgs = new List<HPCExpr>(arguments.Count);
            foreach (var argument in arguments)
            {
                staticArgs.Add(AnalyzeExpression(argument.Expression));
            }

            return new HPCIntrinsicCall(staticIntrinsic, staticArgs.ToArray(), elemType);
        }

        // Unknown static member on the lane type: keep verbatim.
        return FallbackExpr(node);
    }

    // ── Instance method on SPMD lane: lane.Add(b), lane.MultiplyAdd(a,b,c) ──
    var receiverType = _typeResolver.Resolve(memberAccess.Expression);
    if (receiverType != null)
    {
        // Unary math on the receiver: lane.Abs(), lane.Sqrt(), etc.
        if (TryMapUnaryMethod(methodName, out var unaryKind) && arguments.Count == 0)
        {
            var receiver = AnalyzeExpression(memberAccess.Expression);
            return new HPCUnaryOp(unaryKind, receiver, receiverType);
        }

        // Multi-arg methods called on an instance: the receiver is operand 0.
        if (TryMapIntrinsicMethod(methodName, out var intrinsic))
        {
            var args = new List<HPCExpr>(arguments.Count + 1)
            {
                AnalyzeExpression(memberAccess.Expression),
            };
            foreach (var argument in arguments)
            {
                args.Add(AnalyzeExpression(argument.Expression));
            }

            return new HPCIntrinsicCall(intrinsic, args.ToArray(), receiverType);
        }
    }

    // ── Fall through: unknown call, emit as pass-through ──────────────
    return FallbackExpr(node);
}
/// <summary>
/// Analyses <c>receiver.Member</c>. When the receiver resolves to an HPC type the
/// access becomes an <see cref="HPCPropertyAccess"/>, with <c>LaneWidth</c>
/// renamed to <c>Count</c>; otherwise the expression is kept verbatim.
/// </summary>
private HPCExpr AnalyzeMemberAccess(MemberAccessExpressionSyntax node)
{
    var receiverType = _typeResolver.Resolve(node.Expression);
    if (receiverType == null)
    {
        return FallbackExpr(node);
    }

    var memberName = node.Name.Identifier.Text;

    // LaneWidth → Count (property remap)
    var emittedName = memberName == "LaneWidth" ? "Count" : memberName;

    var target = AnalyzeExpression(node.Expression);
    return new HPCPropertyAccess(target, emittedName, receiverType);
}
/// <summary>
/// Converts a literal into an <see cref="HPCLiteral"/> whose text is the token's
/// <c>ValueText</c>. Numeric literals are typed from the runtime type of the
/// token value (<c>float</c>, <c>double</c>, <c>int</c>, <c>uint</c>, <c>long</c>);
/// every other literal defaults to <see cref="HPCType.Float"/>.
/// </summary>
private static HPCExpr AnalyzeLiteral(LiteralExpressionSyntax node)
{
    var text = node.Token.ValueText;

    var hpcType = HPCType.Float;
    if (node.IsKind(SyntaxKind.NumericLiteralExpression))
    {
        hpcType = node.Token.Value switch
        {
            float => HPCType.Float,
            double => HPCType.Double,
            int => HPCType.Int,
            // `1u` used to fall through to Float although a UInt type exists.
            uint => HPCType.UInt,
            long => HPCType.Long,
            _ => HPCType.Float,
        };
    }

    return new HPCLiteral(text, hpcType);
}
/// <summary>
/// Builds an <see cref="HPCVarRef"/> for an identifier. The type is looked up
/// semantically, then by name in the type-parameter map, and finally defaults
/// to <see cref="HPCType.Float"/>.
/// </summary>
private HPCExpr AnalyzeIdentifier(IdentifierNameSyntax node)
{
    var identifier = node.Identifier.Text;

    var resolved = _typeResolver.Resolve(node);
    if (resolved == null)
    {
        resolved = _typeResolver.ResolveByName(identifier) ?? HPCType.Float;
    }

    return new HPCVarRef(identifier, resolved);
}
/// <summary>
/// Models a cast as an <see cref="HPCIntrinsic.Cast"/> call with the analysed
/// operand as its single argument. The target type defaults to
/// <see cref="HPCType.Float"/> when it cannot be resolved.
/// </summary>
private HPCExpr AnalyzeCast(CastExpressionSyntax node)
{
    var operand = AnalyzeExpression(node.Expression);
    var targetType = _typeResolver.Resolve(node);
    if (targetType == null)
    {
        targetType = HPCType.Float;
    }

    return new HPCIntrinsicCall(HPCIntrinsic.Cast, [operand], targetType);
}
/// <summary>
/// Wraps syntax the analyzer does not model in an <see cref="HPCPassThroughCall"/>
/// that carries the original source text and no arguments. The type defaults to
/// <see cref="HPCType.Float"/> when it cannot be resolved.
/// </summary>
private HPCExpr FallbackExpr(ExpressionSyntax node)
{
    var resolvedType = _typeResolver.Resolve(node) ?? HPCType.Float;
    var verbatim = node.ToString();
    return new HPCPassThroughCall(verbatim, Array.Empty<HPCExpr>(), resolvedType);
}
// ── Method mapping tables ─────────────────────────────────────────────
/// <summary>Method names that map to a single-operand <see cref="HPCUnaryKind"/>.</summary>
private static readonly Dictionary<string, HPCUnaryKind> s_unaryMethods = new()
{
    { "Abs", HPCUnaryKind.Abs },
    { "Sqrt", HPCUnaryKind.Sqrt },
    { "Floor", HPCUnaryKind.Floor },
    { "Ceil", HPCUnaryKind.Ceil },
    { "Round", HPCUnaryKind.Round },
    { "Trunc", HPCUnaryKind.Trunc },
    { "Frac", HPCUnaryKind.Frac },
    { "Sign", HPCUnaryKind.Sign },
    { "Saturate", HPCUnaryKind.Saturate },
    { "Rcp", HPCUnaryKind.Rcp },
    { "Rsqrt", HPCUnaryKind.Rsqrt },
    { "Sin", HPCUnaryKind.Sin },
    { "Cos", HPCUnaryKind.Cos },
    { "Tan", HPCUnaryKind.Tan },
    { "Asin", HPCUnaryKind.Asin },
    { "Acos", HPCUnaryKind.Acos },
    { "Atan", HPCUnaryKind.Atan },
    { "Exp", HPCUnaryKind.Exp },
    { "Exp2", HPCUnaryKind.Exp2 },
    { "Log", HPCUnaryKind.Log },
    { "Log2", HPCUnaryKind.Log2 },
};
/// <summary>Method names that map to a multi-operand <see cref="HPCIntrinsic"/>.</summary>
private static readonly Dictionary<string, HPCIntrinsic> s_intrinsicMethods = new()
{
    { "MultiplyAdd", HPCIntrinsic.MultiplyAdd },
    { "MultiplySubtract", HPCIntrinsic.MultiplySubtract },
    { "Atan2", HPCIntrinsic.Atan2 },
    { "Pow", HPCIntrinsic.Pow },
    { "SinCos", HPCIntrinsic.SinCos },
    { "Lerp", HPCIntrinsic.Lerp },
    { "Min", HPCIntrinsic.Min },
    { "Max", HPCIntrinsic.Max },
    { "Clamp", HPCIntrinsic.Clamp },
    { "Select", HPCIntrinsic.Select },
    { "CopySign", HPCIntrinsic.CopySign },
    { "ReduceAdd", HPCIntrinsic.ReduceAdd },
    { "ReduceMax", HPCIntrinsic.ReduceMax },
    { "ReduceMin", HPCIntrinsic.ReduceMin },
    { "Load", HPCIntrinsic.Load },
    { "MaskLoad", HPCIntrinsic.MaskLoad },
    { "Gather", HPCIntrinsic.Gather },
    { "MaskGather", HPCIntrinsic.MaskGather },
    { "Store", HPCIntrinsic.Store },
    { "MaskStore", HPCIntrinsic.MaskStore },
    { "Scatter", HPCIntrinsic.Scatter },
    { "MaskScatter", HPCIntrinsic.MaskScatter },
    { "CompressStore", HPCIntrinsic.CompressStore },
};
/// <summary>Looks <paramref name="name"/> up in the unary method table.</summary>
/// <returns><c>true</c> and the mapped kind when the name is known; otherwise <c>false</c>.</returns>
private static bool TryMapUnaryMethod(string name, out HPCUnaryKind kind)
{
    return s_unaryMethods.TryGetValue(name, out kind);
}
/// <summary>Looks <paramref name="name"/> up in the intrinsic method table.</summary>
/// <returns><c>true</c> and the mapped intrinsic when the name is known; otherwise <c>false</c>.</returns>
private static bool TryMapIntrinsicMethod(string name, out HPCIntrinsic intrinsic)
{
    return s_intrinsicMethods.TryGetValue(name, out intrinsic);
}
// ── Helpers ───────────────────────────────────────────────────────────
/// <summary>Appends <paramref name="stmt"/> to the block currently on top of the stack.</summary>
private void Emit(HPCStmt stmt)
{
    var current = _blocks.Peek();
    current.Add(stmt);
}
/// <summary>Returns the type already attached to <paramref name="expr"/>.</summary>
private static HPCType InferFromExpr(HPCExpr expr)
{
    return expr.Type;
}
/// <summary>
/// Maps the declared return type text to an <see cref="HPCType"/>: <c>void</c>
/// gets a dedicated non-floating-point type, registered SPMD type parameters
/// resolve through the type resolver, and anything else is treated as an
/// element type name.
/// </summary>
private HPCType ResolveReturnType(MethodDeclarationSyntax method)
{
    var declared = method.ReturnType.ToString();

    return declared == "void"
        ? new HPCType("void", IsFloatingPoint: false)
        : _typeResolver.ResolveByName(declared) ?? HPCType.FromElementName(declared);
}
/// <summary>
/// Builds the IR parameter list. Parameters whose declared type text is a
/// registered SPMD type parameter are left out; for the rest the <c>out</c> and
/// <c>ref</c> modifiers are recorded.
/// </summary>
private HPCParameter[] BuildParameters(MethodDeclarationSyntax method)
{
    var collected = new List<HPCParameter>();

    foreach (var parameter in method.ParameterList.Parameters)
    {
        var typeText = parameter.Type?.ToString() ?? "object";

        // Skip generic SPMD type parameters (they become the vector type in the backend)
        if (_typeResolver.IsSPMDTypeParam(typeText))
        {
            continue;
        }

        var parameterType = _typeResolver.ResolveByName(typeText)
            ?? HPCType.FromElementName(typeText);

        var hasOut = false;
        var hasRef = false;
        foreach (var modifier in parameter.Modifiers)
        {
            switch (modifier.Kind())
            {
                case SyntaxKind.OutKeyword:
                    hasOut = true;
                    break;
                case SyntaxKind.RefKeyword:
                    hasRef = true;
                    break;
            }
        }

        collected.Add(new HPCParameter(parameter.Identifier.Text, parameterType, hasOut, hasRef));
    }

    return collected.ToArray();
}
}
}

View File

@@ -0,0 +1,133 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Misaki.HighPerformance.HPC.Generator.IR;
using System.Collections.Generic;
namespace Misaki.HighPerformance.HPC.Generator.Analysis
{
/// <summary>
/// Resolves HPC-specific types from Roslyn's semantic model.
/// Centralises all "what is the scalar element type of this expression?"
/// logic that was previously duplicated across <c>HPCRewriter</c> and
/// <c>HPCOptimizerRewriter</c>.
/// </summary>
internal sealed class HPCTypeResolver
{
    private readonly SemanticModel _semanticModel;

    /// <summary>
    /// Maps generic type-parameter names (e.g. <c>"TLane0"</c>) to their
    /// resolved scalar element types (e.g. <c>"float"</c>).
    /// Populated by <see cref="RegisterConstraints"/>.
    /// </summary>
    private readonly Dictionary<string, string> _typeParamToPrimitive = new();

    public HPCTypeResolver(SemanticModel semanticModel)
    {
        _semanticModel = semanticModel;
    }

    // ── Type-parameter registration ───────────────────────────────────────

    /// <summary>
    /// Scans the generic constraint clauses on <paramref name="method"/> and
    /// registers every type parameter constrained to
    /// <c>ISPMDLane&lt;TSelf, TNumber&gt;</c>. Matching is purely syntactic: the
    /// constraint must be written as the simple generic name <c>ISPMDLane</c>
    /// with exactly two type arguments. Any previous registrations are discarded.
    /// </summary>
    public void RegisterConstraints(MethodDeclarationSyntax method)
    {
        _typeParamToPrimitive.Clear();

        foreach (var clause in method.ConstraintClauses)
        {
            var typeParamName = clause.Name.Identifier.Text;

            foreach (var constraint in clause.Constraints)
            {
                if (constraint is TypeConstraintSyntax typeConstraint &&
                    typeConstraint.Type is GenericNameSyntax generic &&
                    generic.Identifier.Text == "ISPMDLane" &&
                    generic.TypeArgumentList.Arguments.Count == 2)
                {
                    // ISPMDLane<TSelf, TNumber> — TNumber is the scalar element type
                    var primitiveTypeName = generic.TypeArgumentList.Arguments[1].ToString();
                    _typeParamToPrimitive[typeParamName] = primitiveTypeName;
                }
            }
        }
    }

    // ── Expression type resolution ────────────────────────────────────────

    /// <summary>
    /// Returns the <see cref="HPCType"/> for <paramref name="node"/>, or
    /// <c>null</c> if the node is not an HPC-typed expression, i.e. its type is
    /// none of: <c>WideLane&lt;T&gt;</c>, a type parameter constrained to
    /// <c>ISPMDLane&lt;TSelf, TNumber&gt;</c> (or registered through
    /// <see cref="RegisterConstraints"/>), or one of the plain scalars
    /// <c>float</c>, <c>double</c>, <c>int</c>, <c>uint</c>, <c>long</c>.
    /// </summary>
    public HPCType? Resolve(SyntaxNode node)
    {
        var type = _semanticModel.GetTypeInfo(node).Type;
        if (type is null)
        {
            return null;
        }

        // WideLane<float>, WideLane<double>, …
        if (type.Name == "WideLane" &&
            type is INamedTypeSymbol wideLane &&
            wideLane.IsGenericType)
        {
            return HPCType.FromElementName(
                wideLane.TypeArguments[0].ToDisplayString());
        }

        // Generic type parameter constrained to ISPMDLane<TSelf, TNumber>
        if (type is ITypeParameterSymbol typeParam)
        {
            foreach (var constraint in typeParam.ConstraintTypes)
            {
                // Require exactly two type arguments, mirroring RegisterConstraints,
                // so that an unrelated generic named ISPMDLane with a different
                // arity cannot cause an out-of-range access on TypeArguments[1].
                if (constraint.Name == "ISPMDLane" &&
                    constraint is INamedTypeSymbol namedConstraint &&
                    namedConstraint.IsGenericType &&
                    namedConstraint.TypeArguments.Length == 2)
                {
                    return HPCType.FromElementName(
                        namedConstraint.TypeArguments[1].ToDisplayString());
                }
            }

            // Registered from method constraints
            if (_typeParamToPrimitive.TryGetValue(typeParam.Name, out var prim))
            {
                return HPCType.FromElementName(prim);
            }
        }

        // Bare scalar (used when a scalar literal/variable is involved in a
        // mixed scalar-vector expression)
        return type.SpecialType switch
        {
            SpecialType.System_Single => HPCType.Float,
            SpecialType.System_Double => HPCType.Double,
            SpecialType.System_Int32 => HPCType.Int,
            SpecialType.System_UInt32 => HPCType.UInt,
            SpecialType.System_Int64 => HPCType.Long,
            _ => null,
        };
    }

    /// <summary>
    /// Looks <paramref name="typeName"/> up in the registered type-parameter map
    /// and returns the corresponding element type, or <c>null</c> when the name
    /// has not been registered. No semantic-model query is involved.
    /// </summary>
    public HPCType? ResolveByName(string typeName)
    {
        return _typeParamToPrimitive.TryGetValue(typeName, out var prim)
            ? HPCType.FromElementName(prim)
            : null;
    }

    /// <summary>Returns true if <paramref name="typeName"/> is a known SPMD type-param.</summary>
    public bool IsSPMDTypeParam(string typeName) =>
        _typeParamToPrimitive.ContainsKey(typeName);

    /// <summary>
    /// Read-only view of the current type-param → primitive mappings. This is the
    /// live dictionary, not a copy: it changes when
    /// <see cref="RegisterConstraints"/> is called again.
    /// </summary>
    public IReadOnlyDictionary<string, string> TypeParamMap => _typeParamToPrimitive;
}
}

View File

@@ -0,0 +1,365 @@
using Misaki.HighPerformance.HPC.Generator.IR;
using System;
using System.Text;
namespace Misaki.HighPerformance.HPC.Generator.Backend
{
/// <summary>
/// Emits C# source code targeting AVX2 (256-bit vectors) with the bundled
/// AVX2 extensions: FMA3, F16C, and BMI1/2.
///
/// <para>
/// Vector type: <c>Vector256&lt;T&gt;</c><br/>
/// Intrinsic classes used: <c>Avx</c>, <c>Avx2</c>, <c>Fma</c>, <c>F16C</c>,
/// <c>Bmi1</c>, <c>Bmi2</c>, <c>Vector256</c>.
/// </para>
///
/// <para>
/// Math functions with no built-in AVX2 intrinsic (Sin, Cos, Asin, etc.) are
/// delegated to the <c>AVX2Utility</c> class that is generated separately by
/// <c>AVX2UtilityGenerator</c> via <c>IVectorAPIContext</c>.
/// </para>
/// </summary>
internal sealed class AVX2Backend : IHPCBackend
{
// ── IHPCBackend ───────────────────────────────────────────────────────
/// <summary>Backend identifier; also used as the suffix of emitted method names.</summary>
public string Name
{
    get { return "AVX2"; }
}
/// <summary>
/// The <c>using</c> directives that code emitted by this backend relies on.
/// A new array is returned on every access.
/// </summary>
public string[] RequiredUsings
{
    get
    {
        return new string[]
        {
            "using System.Runtime.CompilerServices;",
            "using System.Runtime.Intrinsics;",
            "using System.Runtime.Intrinsics.X86;",
        };
    }
}
/// <summary>
/// Renders <paramref name="method"/> as a <c>public static</c> method marked
/// <c>AggressiveInlining</c>. The emitted name is the original method name
/// followed by <c>_</c> and <see cref="Name"/>.
/// </summary>
public string EmitMethod(HPCMethodIR method)
{
    var indent = " ";
    var bodyIndent = indent + " ";
    var output = new StringBuilder();

    // ── Signature ──────────────────────────────────────────────────────
    output.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]");
    output.Append($"{indent}public static ")
          .Append(EmitReturnType(method.ReturnType))
          .Append(' ')
          .Append($"{method.OriginalName}_{Name}")
          .Append('(')
          .Append(EmitParameterList(method))
          .AppendLine(")");

    // ── Body ───────────────────────────────────────────────────────────
    output.AppendLine($"{indent}{{");
    foreach (var statement in method.Body)
    {
        EmitStatement(output, statement, bodyIndent);
    }
    output.AppendLine($"{indent}}}");

    return output.ToString();
}
// ── Statement emission ────────────────────────────────────────────────
/// <summary>
/// Appends the source text for <paramref name="stmt"/> to <paramref name="sb"/>,
/// prefixing each line with <paramref name="indent"/>.
/// </summary>
/// <exception cref="NotSupportedException">
/// The statement type is not known to this backend. Dropping it silently would
/// produce generated code that is missing statements without any diagnostic,
/// so this mirrors the behaviour of <see cref="EmitExpr"/> for unknown expressions.
/// </exception>
private void EmitStatement(StringBuilder sb, HPCStmt stmt, string indent)
{
    switch (stmt)
    {
        case HPCVarDecl decl:
            sb.AppendLine($"{indent}var {decl.Name} = {EmitExpr(decl.Initializer)};");
            break;

        case HPCAssignment assign:
            sb.AppendLine($"{indent}{assign.Target} = {EmitExpr(assign.Value)};");
            break;

        case HPCExprStmt expr:
            sb.AppendLine($"{indent}{EmitExpr(expr.Expression)};");
            break;

        case HPCReturn ret:
            if (ret.Value is null)
            {
                sb.AppendLine($"{indent}return;");
            }
            else
            {
                sb.AppendLine($"{indent}return {EmitExpr(ret.Value)};");
            }
            break;

        case HPCIf ifStmt:
            EmitIf(sb, ifStmt, indent);
            break;

        case HPCForLoop forLoop:
            EmitForLoop(sb, forLoop, indent);
            break;

        default:
            throw new NotSupportedException($"Unknown IR stmt: {stmt?.GetType().Name ?? "null"}");
    }
}
/// <summary>
/// Emits an <c>if</c> statement with a braced body. The <c>else</c> part is
/// written only when the else body exists and contains at least one statement.
/// </summary>
private void EmitIf(StringBuilder sb, HPCIf node, string indent)
{
    // Writes `{ ... }` around the given statements, one level deeper than `indent`.
    void EmitBracedBody(HPCStmt[] statements)
    {
        sb.AppendLine($"{indent}{{");
        foreach (var statement in statements)
        {
            EmitStatement(sb, statement, indent + " ");
        }
        sb.AppendLine($"{indent}}}");
    }

    sb.AppendLine($"{indent}if ({EmitExpr(node.Condition)})");
    EmitBracedBody(node.ThenBody);

    if (node.ElseBody is { Length: > 0 } elseStatements)
    {
        sb.AppendLine($"{indent}else");
        EmitBracedBody(elseStatements);
    }
}
/// <summary>
/// Emits a <c>for</c> loop with a braced body. The iterator initializer, the
/// condition and the increment are all rendered through <see cref="EmitExpr"/>.
/// </summary>
/// <remarks>
/// NOTE(review): the header clauses use the same expression emission as the loop
/// body — confirm that this is the intended form for the loop counter.
/// </remarks>
private void EmitForLoop(StringBuilder sb, HPCForLoop node, string indent)
{
    var initializer = $"var {node.Iterator.Name} = {EmitExpr(node.Iterator.Initializer)}";
    var condition = EmitExpr(node.Condition);
    var increment = EmitExpr(node.Increment);
    var innerIndent = indent + " ";

    sb.AppendLine($"{indent}for ({initializer}; {condition}; {increment})");
    sb.AppendLine($"{indent}{{");
    for (var index = 0; index < node.Body.Length; index++)
    {
        EmitStatement(sb, node.Body[index], innerIndent);
    }
    sb.AppendLine($"{indent}}}");
}
// ── Expression emission ───────────────────────────────────────────────
/// <summary>
/// Dispatches an IR expression to the emitter for its concrete node type and
/// returns the resulting C# expression text.
/// </summary>
/// <exception cref="NotSupportedException">The node type is not handled here.</exception>
private string EmitExpr(HPCExpr expr)
{
    switch (expr)
    {
        case HPCVarRef variable:
            return variable.Name;
        case HPCLiteral literal:
            return EmitLiteral(literal);
        case HPCBinaryOp binary:
            return EmitBinary(binary);
        case HPCUnaryOp unary:
            return EmitUnary(unary);
        case HPCIntrinsicCall intrinsic:
            return EmitIntrinsic(intrinsic);
        case HPCPropertyAccess property:
            return EmitPropertyAccess(property);
        case HPCPassThroughCall passThrough:
            // verbatim
            // NOTE(review): only MethodName is written; Args are not emitted here.
            // Presumably the analyzer stores the full call text in MethodName — confirm.
            return passThrough.MethodName;
        default:
            throw new NotSupportedException($"Unknown IR expr: {expr.GetType().Name}");
    }
}
/// <summary>
/// Emits a scalar literal as a broadcast to all lanes: the literal text plus the
/// type's suffix (see <c>LiteralSuffix</c>) wrapped in <c>Vector256.Create(...)</c>.
/// </summary>
private string EmitLiteral(HPCLiteral lit) =>
    "Vector256.Create(" + lit.Value + LiteralSuffix(lit.Type) + ")";
/// <summary>
/// Emits the AVX/AVX2 intrinsic call for a binary IR operation.
/// Floating-point nodes are routed to <c>Avx</c>, everything else to <c>Avx2</c>.
/// Comparisons produce a per-lane mask vector rather than a <c>bool</c>.
/// </summary>
/// <param name="node">The binary operation; <c>node.Type</c> selects the float or integer path.</param>
/// <returns>C# expression text.</returns>
/// <exception cref="NotSupportedException">
/// The operation has no mapping in this backend (e.g. modulo, integer divide).
/// </exception>
/// <remarks>
/// Both operand texts are interpolated more than once for some integer
/// comparisons, so the emitted operand expressions are evaluated twice there.
/// </remarks>
private string EmitBinary(HPCBinaryOp node)
{
    var l = EmitExpr(node.Left);
    var r = EmitExpr(node.Right);
    return node.Kind switch
    {
        // Floating-point arithmetic uses Avx / Avx2
        HPCBinaryKind.Add when node.Type.IsFloatingPoint => $"Avx.Add({l}, {r})",
        HPCBinaryKind.Subtract when node.Type.IsFloatingPoint => $"Avx.Subtract({l}, {r})",
        HPCBinaryKind.Multiply when node.Type.IsFloatingPoint => $"Avx.Multiply({l}, {r})",
        HPCBinaryKind.Divide when node.Type.IsFloatingPoint => $"Avx.Divide({l}, {r})",

        // Integer arithmetic uses Avx2
        HPCBinaryKind.Add => $"Avx2.Add({l}, {r})",
        HPCBinaryKind.Subtract => $"Avx2.Subtract({l}, {r})",
        HPCBinaryKind.Multiply => $"Avx2.MultiplyLow({l}, {r})", // 32-bit int multiply

        // Bitwise on float/double lanes: the Avx2 logical intrinsics only have
        // integer overloads, so floating-point vectors must go through Avx.
        HPCBinaryKind.BitwiseAnd when node.Type.IsFloatingPoint => $"Avx.And({l}, {r})",
        HPCBinaryKind.BitwiseOr when node.Type.IsFloatingPoint => $"Avx.Or({l}, {r})",
        HPCBinaryKind.BitwiseXor when node.Type.IsFloatingPoint => $"Avx.Xor({l}, {r})",

        // Bitwise (integer lanes)
        HPCBinaryKind.BitwiseAnd => $"Avx2.And({l}, {r})",
        HPCBinaryKind.BitwiseOr => $"Avx2.Or({l}, {r})",
        HPCBinaryKind.BitwiseXor => $"Avx2.Xor({l}, {r})",

        // NOTE(review): the shift count is emitted as whatever EmitExpr yields for
        // the right operand; confirm it matches an Avx2.Shift*Logical overload, and
        // that a logical (not arithmetic) right shift is intended for signed lanes.
        HPCBinaryKind.ShiftLeft => $"Avx2.ShiftLeftLogical({l}, {r})",
        HPCBinaryKind.ShiftRight => $"Avx2.ShiftRightLogical({l}, {r})",

        // Comparisons — emit as AVX compare returning a mask vector
        HPCBinaryKind.Equal when node.Type.IsFloatingPoint
            => $"Avx.Compare({l}, {r}, FloatComparisonMode.OrderedEqualNonSignaling)",
        // C# `a != b` is true when either side is NaN, so the predicate must be the
        // *unordered* not-equal; the ordered variant yields false for NaN lanes.
        HPCBinaryKind.NotEqual when node.Type.IsFloatingPoint
            => $"Avx.Compare({l}, {r}, FloatComparisonMode.UnorderedNotEqualNonSignaling)",
        HPCBinaryKind.LessThan when node.Type.IsFloatingPoint
            => $"Avx.Compare({l}, {r}, FloatComparisonMode.OrderedLessThanNonSignaling)",
        HPCBinaryKind.LessThanOrEqual when node.Type.IsFloatingPoint
            => $"Avx.Compare({l}, {r}, FloatComparisonMode.OrderedLessThanOrEqualNonSignaling)",
        HPCBinaryKind.GreaterThan when node.Type.IsFloatingPoint
            => $"Avx.Compare({l}, {r}, FloatComparisonMode.OrderedGreaterThanNonSignaling)",
        HPCBinaryKind.GreaterThanOrEqual when node.Type.IsFloatingPoint
            => $"Avx.Compare({l}, {r}, FloatComparisonMode.OrderedGreaterThanOrEqualNonSignaling)",

        // Integer comparisons (only == and > exist natively; the rest are composed)
        HPCBinaryKind.Equal => $"Avx2.CompareEqual({l}, {r})",
        HPCBinaryKind.GreaterThan => $"Avx2.CompareGreaterThan({l}, {r})",
        HPCBinaryKind.LessThan => $"Avx2.CompareGreaterThan({r}, {l})", // reversed
        HPCBinaryKind.GreaterThanOrEqual => $"Avx2.Or(Avx2.CompareGreaterThan({l}, {r}), Avx2.CompareEqual({l}, {r}))",
        HPCBinaryKind.LessThanOrEqual => $"Avx2.Or(Avx2.CompareGreaterThan({r}, {l}), Avx2.CompareEqual({l}, {r}))",
        HPCBinaryKind.NotEqual => $"Avx2.Xor(Avx2.CompareEqual({l}, {r}), Vector256<{CsTypeName(node.Type)}>.AllBitsSet)",

        HPCBinaryKind.Modulo => throw new NotSupportedException("Modulo has no AVX2 intrinsic; consider reformulating."),
        _ => throw new NotSupportedException($"Binary kind {node.Kind} not supported in AVX2 backend")
    };
}
/// <summary>
/// Emits the C# expression for a unary IR operation: negation, bitwise not,
/// rounding/abs/sqrt, approximate reciprocals, transcendental functions routed
/// to <c>AVX2Utility</c>, and a few composed helpers (Frac, Sign, Saturate).
/// </summary>
/// <param name="node">The unary operation; <c>node.Type</c> selects the overload family.</param>
/// <returns>C# expression text.</returns>
/// <exception cref="NotSupportedException">
/// The kind/type combination has no mapping in this backend.
/// </exception>
/// <remarks>
/// The operand text is interpolated twice for Frac, so the emitted operand
/// expression is evaluated twice there.
/// </remarks>
private string EmitUnary(HPCUnaryOp node)
{
    var op = EmitExpr(node.Operand);
    return node.Kind switch
    {
        // Flip the sign bit instead of computing 0 - x: subtraction turns +0 into +0,
        // whereas IEEE negation (and C# unary minus) yields -0.
        HPCUnaryKind.Negate when node.Type.IsFloatingPoint
            => $"Avx.Xor({op}, Vector256.Create(-0.0{LiteralSuffix(node.Type)}))",
        HPCUnaryKind.Negate
            => $"Avx2.Subtract(Vector256<{CsTypeName(node.Type)}>.Zero, {op})",
        HPCUnaryKind.BitwiseNot => $"Avx2.Xor({op}, Vector256<{CsTypeName(node.Type)}>.AllBitsSet)",

        // Math — delegate to Vector256 helpers (available in .NET 7+)
        HPCUnaryKind.Abs when node.Type.IsFloatingPoint => $"Vector256.Abs({op})",
        HPCUnaryKind.Sqrt when node.Type.IsFloatingPoint => $"Avx.Sqrt({op})",
        HPCUnaryKind.Floor when node.Type.IsFloatingPoint => $"Avx.Floor({op})",
        HPCUnaryKind.Ceil when node.Type.IsFloatingPoint => $"Avx.Ceiling({op})",
        HPCUnaryKind.Round when node.Type.IsFloatingPoint
            => $"Avx.RoundToNearestInteger({op})",
        HPCUnaryKind.Trunc when node.Type.IsFloatingPoint
            => $"Avx.RoundToZero({op})",

        // Reciprocal / Rsqrt (float only; hardware approximations — Intel documents
        // a relative error of at most 1.5 * 2^-12 for RCPPS/RSQRTPS)
        HPCUnaryKind.Rcp when node.Type.ElementTypeName == "float"
            => $"Avx.Reciprocal({op})",
        HPCUnaryKind.Rsqrt when node.Type.ElementTypeName == "float"
            => $"Avx.ReciprocalSqrt({op})",

        // Transcendentals — routed to AVX2Utility (generated by UtilityTemplate)
        HPCUnaryKind.Sin => $"AVX2Utility.Sin_{CsTypeNameCap(node.Type)}_{MathModeSuffix()}({op})",
        HPCUnaryKind.Cos => $"AVX2Utility.Cos_{CsTypeNameCap(node.Type)}_{MathModeSuffix()}({op})",
        // NOTE(review): Asin/Atan carry no type/mode suffix, unlike Sin/Cos/Log/Exp —
        // confirm the matching AVX2Utility members exist under these names.
        HPCUnaryKind.Asin => $"AVX2Utility.Asin({op})",
        HPCUnaryKind.Atan => $"AVX2Utility.Atan({op})",
        HPCUnaryKind.Log => $"AVX2Utility.Log_{CsTypeNameCap(node.Type)}_{MathModeSuffix()}({op})",
        HPCUnaryKind.Exp => $"AVX2Utility.Exp_{CsTypeNameCap(node.Type)}_{MathModeSuffix()}({op})",

        // Frac = x - Floor(x)
        HPCUnaryKind.Frac when node.Type.IsFloatingPoint
            => $"Avx.Subtract({op}, Avx.Floor({op}))",

        // Sign, per lane: build ±1 by OR-ing the operand's sign bit into 1.0, then
        // mask the result to 0 in lanes where the operand equals zero. The previous
        // emission compared whole vectors with `==` and picked one vector for all
        // lanes, so mixed-sign or partially-zero inputs produced wrong lanes.
        HPCUnaryKind.Sign when node.Type.ElementTypeName is "float" or "double"
            => $"Avx.And(Avx.CompareNotEqual({op}, Vector256<{CsTypeName(node.Type)}>.Zero), " +
               $"Avx.Or(Avx.And({op}, Vector256.Create(-0.0{LiteralSuffix(node.Type)})), Vector256.Create(1.0{LiteralSuffix(node.Type)})))",

        // Saturate: clamp to [0,1]
        HPCUnaryKind.Saturate when node.Type.IsFloatingPoint
            => $"Avx.Max(Avx.Min({op}, Vector256.Create(1.0{LiteralSuffix(node.Type)})), Vector256<{CsTypeName(node.Type)}>.Zero)",

        _ => throw new NotSupportedException($"Unary kind {node.Kind} not supported in AVX2 backend (type: {node.Type})")
    };
}
/// <summary>
/// Emits the C# expression for a multi-argument IR intrinsic (FMA, math helpers,
/// min/max/clamp/select, reductions, memory access, conversions).
/// </summary>
/// <param name="node">The intrinsic call; arguments are emitted lazily by index.</param>
/// <returns>C# expression text.</returns>
/// <exception cref="NotSupportedException">
/// The intrinsic (or, for <c>Cast</c>, the type pair) has no mapping in this backend.
/// </exception>
/// <remarks>
/// Some arms interpolate the same argument more than once (e.g. Lerp), so the
/// emitted argument expression is evaluated more than once.
/// </remarks>
private string EmitIntrinsic(HPCIntrinsicCall node)
{
    var args = node.Args;
    string A(int i) => EmitExpr(args[i]);

    return node.Intrinsic switch
    {
        // FMA — uses the FMA extension bundled with AVX2
        HPCIntrinsic.MultiplyAdd => $"Fma.MultiplyAdd({A(0)}, {A(1)}, {A(2)})",
        HPCIntrinsic.MultiplySubtract => $"Fma.MultiplySubtract({A(0)}, {A(1)}, {A(2)})",

        // Math
        HPCIntrinsic.Atan2 => $"AVX2Utility.Atan2({A(0)}, {A(1)})",
        HPCIntrinsic.Pow => $"AVX2Utility.Pow({A(0)}, {A(1)})",
        HPCIntrinsic.SinCos => $"AVX2Utility.SinCos_{CsTypeNameCap(node.Type)}_{MathModeSuffix()}({A(0)}, out {A(1)}, out {A(2)})",

        // Compound math: Lerp(a, b, t) = (b - a) * t + a
        HPCIntrinsic.Lerp => $"Fma.MultiplyAdd(Avx.Subtract({A(1)}, {A(0)}), {A(2)}, {A(0)})",
        HPCIntrinsic.Min when node.Type.IsFloatingPoint => $"Avx.Min({A(0)}, {A(1)})",
        HPCIntrinsic.Max when node.Type.IsFloatingPoint => $"Avx.Max({A(0)}, {A(1)})",
        HPCIntrinsic.Min => $"Avx2.Min({A(0)}, {A(1)})",
        HPCIntrinsic.Max => $"Avx2.Max({A(0)}, {A(1)})",
        HPCIntrinsic.Clamp when node.Type.IsFloatingPoint => $"Avx.Max(Avx.Min({A(0)}, {A(2)}), {A(1)})",
        HPCIntrinsic.Clamp => $"Avx2.Max(Avx2.Min({A(0)}, {A(2)}), {A(1)})",

        // Conditional select via bitwise blend: arg 0 is the mask, arg 1 is taken
        // where the mask is set, arg 2 elsewhere.
        HPCIntrinsic.Select when node.Type.IsFloatingPoint
            => $"Avx.BlendVariable({A(2)}, {A(1)}, {A(0)})",
        HPCIntrinsic.Select
            => $"Avx2.BlendVariable({A(2)}.As<{CsTypeName(node.Type)}, byte>(), {A(1)}.As<{CsTypeName(node.Type)}, byte>(), {A(0)}.As<{CsTypeName(node.Type)}, byte>()).As<byte, {CsTypeName(node.Type)}>()",

        // CopySign: compose exponent+mantissa from A(0) and sign from A(1)
        HPCIntrinsic.CopySign when node.Type.ElementTypeName == "float"
            => $"Avx.Or(Avx.And({A(0)}, Vector256.Create(0x7FFFFFFFu).AsSingle()), Avx.And({A(1)}, Vector256.Create(0x80000000u).AsSingle()))",

        // Horizontal reductions
        // NOTE(review): Vector256.Max/Min are documented as two-operand, element-wise
        // operations; confirm these single-argument forms resolve to a real helper.
        HPCIntrinsic.ReduceAdd => $"Vector256.Sum({A(0)})",
        HPCIntrinsic.ReduceMax => $"Vector256.Max({A(0)})",
        HPCIntrinsic.ReduceMin => $"Vector256.Min({A(0)})",

        // Memory — pointer-based (emitted as ref-to-pointer pattern)
        // NOTE(review): the Avx2 mask/gather intrinsics are documented with pointer
        // parameters; confirm the `ref` forms below compile against the target framework.
        HPCIntrinsic.Load => $"Vector256.LoadUnsafe(ref {A(0)})",
        HPCIntrinsic.MaskLoad => $"Avx2.MaskLoad(ref {A(0)}, {A(1)}.AsInt32())",
        HPCIntrinsic.Store => $"Vector256.StoreUnsafe({A(0)}, ref {A(1)})",
        HPCIntrinsic.MaskStore => $"Avx2.MaskStore(ref {A(1)}, {A(2)}.AsInt32(), {A(0)})",
        HPCIntrinsic.Gather => $"Avx2.GatherVector256(ref {A(0)}, {A(1)}, {A(2)})",
        HPCIntrinsic.MaskGather => $"Avx2.GatherMaskVector256({A(0)}, ref {A(1)}, {A(2)}, {A(3)}, {A(4)})",
        HPCIntrinsic.CompressStore
            => $"/* CompressStore requires AVX-512VBMI2; use scalar fallback */ {A(0)}.CompressStore(ref {A(1)}, {A(2)})",

        // Conversions: Cast converts values, BitCast reinterprets bits.
        HPCIntrinsic.Cast => EmitValueCast(node),
        HPCIntrinsic.BitCast => $"{A(0)}.As<{CsTypeName(args[0].Type)}, {CsTypeName(node.Type)}>()",

        _ => throw new NotSupportedException($"Intrinsic {node.Intrinsic} not implemented in AVX2 backend")
    };
}

/// <summary>
/// Emits a value-converting cast. Previously <c>Cast</c> shared the bit
/// reinterpretation used by <c>BitCast</c>, which turns e.g. the integer 1 into
/// a denormal float instead of 1.0f. int↔float now use the AVX conversion
/// intrinsics; integer↔integer keeps the original reinterpretation; any other
/// conversion involving a floating-point type is rejected rather than silently
/// producing garbage lanes.
/// </summary>
private string EmitValueCast(HPCIntrinsicCall node)
{
    // Fold CLR names onto C# keywords so "Single" and "float" compare equal.
    static string Canonical(HPCType type) => type.ElementTypeName switch
    {
        "Single" => "float",
        "Double" => "double",
        "Int32" => "int",
        "UInt32" => "uint",
        "Int64" => "long",
        var other => other
    };

    var source = node.Args[0];
    var operand = EmitExpr(source);
    var from = Canonical(source.Type);
    var to = Canonical(node.Type);

    if (from == "int" && to == "float")
    {
        return $"Avx.ConvertToVector256Single({operand})";
    }

    if (from == "float" && to == "int")
    {
        // Truncation matches the semantics of a C# (int) cast on a float.
        return $"Avx.ConvertToVector256Int32WithTruncation({operand})";
    }

    if (from != to && (source.Type.IsFloatingPoint || node.Type.IsFloatingPoint))
    {
        throw new NotSupportedException(
            $"Cast from {source.Type} to {node.Type} is not implemented in AVX2 backend");
    }

    return $"{operand}.As<{CsTypeName(source.Type)}, {CsTypeName(node.Type)}>()";
}
/// <summary>
/// Emits <c>target.Property</c>. The target is emitted through <c>EmitExpr</c>
/// so that compound targets (binary ops, intrinsic calls, …) become valid C#;
/// interpolating the node directly would fall back to the record's generated
/// <c>ToString()</c> for every node type that does not override it.
/// </summary>
private string EmitPropertyAccess(HPCPropertyAccess node) =>
    $"{EmitExpr(node.Target)}.{node.PropertyName}";
// ── Signature helpers ─────────────────────────────────────────────────
/// <summary>
/// Maps an IR return type to its emitted C# type: <c>void</c> stays
/// <c>void</c>, every other element type becomes <c>Vector256&lt;T&gt;</c>.
/// </summary>
private static string EmitReturnType(HPCType type)
{
    var element = type.ElementTypeName;
    if (element == "void")
    {
        return "void";
    }

    return "Vector256<" + element + ">";
}
/// <summary>
/// Builds the comma-separated parameter list for the emitted signature. Each
/// parameter is written as optional <c>out </c>/<c>ref </c> modifiers, the
/// vector type for its element type (or <c>void</c>), and its name.
/// </summary>
private static string EmitParameterList(HPCMethodIR method)
{
    var list = new StringBuilder();
    foreach (var parameter in method.Parameters)
    {
        if (list.Length > 0)
        {
            list.Append(", ");
        }

        if (parameter.IsOut)
        {
            list.Append("out ");
        }

        if (parameter.IsRef)
        {
            list.Append("ref ");
        }

        var element = parameter.Type.ElementTypeName;
        if (element == "void")
        {
            list.Append("void");
        }
        else
        {
            list.Append("Vector256<").Append(element).Append('>');
        }

        list.Append(' ').Append(parameter.Name);
    }

    return list.ToString();
}
// ── Naming helpers ────────────────────────────────────────────────────
/// <summary>Returns the element type name exactly as stored in the IR type.</summary>
private static string CsTypeName(HPCType t)
{
    return t.ElementTypeName;
}
/// <summary>
/// Returns the capitalised type token used in generated utility method names:
/// <c>Single</c> for float, <c>Double</c> for double, otherwise the element
/// type name unchanged.
/// </summary>
private static string CsTypeNameCap(HPCType t)
{
    switch (t.ElementTypeName)
    {
        case "float":
        case "Single":
            return "Single";
        case "double":
        case "Double":
            return "Double";
        default:
            return t.ElementTypeName;
    }
}
/// <summary>
/// Returns the C# numeric-literal suffix for the element type: <c>f</c> for
/// float, <c>d</c> for double, and an empty string for everything else.
/// </summary>
private static string LiteralSuffix(HPCType t)
{
    switch (t.ElementTypeName)
    {
        case "float":
        case "Single":
            return "f";
        case "double":
        case "Double":
            return "d";
        default:
            return "";
    }
}
// ── Mode-aware suffix ─────────────────────────────────────────────────
// The mode used for utility-method name suffixes lives in a field. The
// two-argument EmitMethod overload sets it only for the duration of one
// emission, so a later call to the single-argument overload on the same
// backend instance is not affected by an earlier caller's mode.
private MathMode _currentMode = MathMode.Standard;

/// <summary>
/// Emits <paramref name="method"/> with transcendental calls routed to the
/// utility variants for <paramref name="mode"/>.
/// </summary>
/// <param name="method">The analysed IR to emit.</param>
/// <param name="mode">Math mode that selects the <c>Fast</c>/<c>Standard</c> suffix.</param>
/// <returns>The method's C# source text.</returns>
public string EmitMethod(HPCMethodIR method, MathMode mode)
{
    var previousMode = _currentMode;
    _currentMode = mode;
    try
    {
        return EmitMethod(method);
    }
    finally
    {
        // Restore even when emission throws, so the mode never leaks.
        _currentMode = previousMode;
    }
}

/// <summary>Name suffix for generated utility methods under the current mode.</summary>
private string MathModeSuffix() => _currentMode == MathMode.Fast ? "Fast" : "Standard";
}
}

View File

@@ -0,0 +1,23 @@
using Misaki.HighPerformance.HPC.Generator.IR;
namespace Misaki.HighPerformance.HPC.Generator.Backend
{
/// <summary>
/// Emits C# source code for one target instruction-set architecture from a
/// fully analysed and optimised <see cref="HPCMethodIR"/>.
/// </summary>
/// <remarks>
/// The generator pipeline obtains backends per requested instruction set,
/// calls <see cref="EmitMethod"/> once per method, and wraps the result in a
/// namespace and partial type together with <see cref="RequiredUsings"/>.
/// </remarks>
internal interface IHPCBackend
{
/// <summary>Short identifier used in generated file/method names (e.g. "AVX2").</summary>
string Name { get; }
/// <summary>
/// Using-directives the emitted code requires. Each entry is written to the
/// generated file as a complete line, so it must be a full directive
/// including the <c>using</c> keyword and trailing semicolon.
/// </summary>
string[] RequiredUsings { get; }
/// <summary>
/// Emits the full body of a specialised method and returns the C# source
/// as a string (method signature + braces included).
/// </summary>
/// <param name="method">The analysed and optimised IR of the method to emit.</param>
/// <returns>The C# source of the method, ready to be placed inside a type declaration.</returns>
string EmitMethod(HPCMethodIR method);
}
}

View File

@@ -0,0 +1,108 @@
using Microsoft.CodeAnalysis;
using System;
namespace Misaki.HighPerformance.HPC.Generator
{
/// <summary>
/// Generator-side copy of the floating-point precision levels. The numeric
/// values are interpolated into the public <c>FloatPrecision</c> enum emitted by
/// <see cref="HPComputeAttributeGenerator"/>, so both enums always agree.
/// </summary>
internal enum FloatPrecision
{
Standard = 0,
High = 1,
Low = 2,
}
/// <summary>
/// Generator-side copy of the math modes. The numeric values are interpolated
/// into the public <c>MathMode</c> enum emitted by
/// <see cref="HPComputeAttributeGenerator"/>.
/// </summary>
internal enum MathMode
{
Standard = 0,
Fast = 1,
}
/// <summary>
/// Generator-side copy of the target instruction-set flags. Each member is a
/// distinct bit so several targets can be combined; the numeric values are
/// interpolated into the public enum emitted by
/// <see cref="HPComputeAttributeGenerator"/>.
/// </summary>
[Flags]
internal enum TargetInstructionSet
{
None = 0,
SSE2 = 1 << 0,
SSE4 = 1 << 1,
AVX = 1 << 2,
AVX2 = 1 << 3,
AVX512 = 1 << 4,
}
/// <summary>
/// Post-initialization generator that injects the public
/// <c>HPComputeAttribute</c> and its supporting enums into the consuming
/// compilation. The enum values are interpolated from the generator's internal
/// enums so the two sets cannot drift apart.
/// </summary>
[Generator]
public class HPComputeAttributeGenerator : IIncrementalGenerator
{
    /// <summary>Registers the fixed attribute source as post-initialization output.</summary>
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        context.RegisterPostInitializationOutput(static ctx =>
        {
            // The XML docs of FloatPrecision.High/Low were corrected: "High" previously
            // claimed 1 ULP accuracy "with reduced precision", which contradicts itself.
            var source = @$"
using System;
namespace Misaki.HighPerformance.HPC
{{
public enum FloatPrecision
{{
/// <summary>
/// Compute with an accuracy of 3.5 ULPs (Units in the Last Place). This is the default precision level for floating-point operations.
/// </summary>
Standard = {(int)FloatPrecision.Standard},
/// <summary>
/// Compute with an accuracy of 1 ULP. This level favors precision over speed and may be slower than the Standard level.
/// </summary>
High = {(int)FloatPrecision.High},
/// <summary>
/// Compute with an accuracy equal to or lower than that of the Standard level. This level may use the most aggressive optimizations, potentially sacrificing precision for maximum performance.
/// </summary>
Low = {(int)FloatPrecision.Low},
}}
public enum MathMode
{{
/// <summary>
/// Use the default math mode, which balances performance and accuracy. This mode may allow certain optimizations that can lead to faster computations while maintaining reasonable precision.
/// </summary>
Standard = {(int)MathMode.Standard},
/// <summary>
/// Use a fast math mode, which prioritizes performance over accuracy. This mode assumes there are no special cases (like NaNs or infinities) and may allow for more aggressive optimizations.
/// </summary>
Fast = {(int)MathMode.Fast},
}}
[Flags]
public enum TargetInstructionSet
{{
None = {(int)TargetInstructionSet.None},
/// <summary>
/// Streaming SIMD Extensions 2.
/// </summary>
SSE2 = {(int)TargetInstructionSet.SSE2},
/// <summary>
/// Streaming SIMD Extensions 4.2.
/// </summary>
SSE4 = {(int)TargetInstructionSet.SSE4},
/// <summary>
/// Advanced Vector Extensions.
/// </summary>
AVX = {(int)TargetInstructionSet.AVX},
/// <summary>
/// Advanced Vector Extensions 2. Includes FMA, F16C and BMI1/2.
/// </summary>
AVX2 = {(int)TargetInstructionSet.AVX2},
/// <summary>
/// Advanced Vector Extensions 512.
/// </summary>
AVX512 = {(int)TargetInstructionSet.AVX512},
}}
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Method, Inherited = false, AllowMultiple = false)]
public sealed class HPComputeAttribute : Attribute
{{
public HPComputeAttribute(TargetInstructionSet instructionSet, FloatPrecision precision = FloatPrecision.Standard, MathMode mode = MathMode.Standard)
{{
}}
}}
}}";
            ctx.AddSource("HPComputeAttribute.g.cs", source);
        });
    }
}
}

View File

@@ -0,0 +1,177 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Misaki.HighPerformance.HPC.Generator.Analysis;
using Misaki.HighPerformance.HPC.Generator.Backend;
using Misaki.HighPerformance.HPC.Generator.Optimization;
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Text;
namespace Misaki.HighPerformance.HPC.Generator
{
/// <summary>
/// Everything the pipeline needs about one attributed method: its syntax,
/// symbol and semantic model, plus the three values read from the attribute's
/// constructor arguments.
/// </summary>
internal class HPComputeMethodInfo
{
/// <summary>Syntax node of the attributed method.</summary>
public MethodDeclarationSyntax MethodDeclaration { get; set; } = null!;
/// <summary>Declared symbol of the attributed method.</summary>
public IMethodSymbol MethodSymbol { get; set; } = null!;
/// <summary>Semantic model handed to the analyzer for this method.</summary>
public SemanticModel SemanticModel { get; set; } = null!;
/// <summary>Target instruction sets (first attribute constructor argument).</summary>
public TargetInstructionSet InstructionSet { get; set; }
/// <summary>Requested precision (second attribute constructor argument).</summary>
public FloatPrecision Precision { get; set; }
/// <summary>Requested math mode (third attribute constructor argument).</summary>
public MathMode Mode { get; set; }
}
/// <summary>
/// Incremental generator that turns every method annotated with
/// <c>HPComputeAttribute</c> into ISA-specialised variants. Each method runs
/// through three phases: analysis into IR, optimisation passes, and emission by
/// one backend per requested instruction set.
/// </summary>
[Generator]
public class HPComputeGenerator : IIncrementalGenerator
{
    private const string AttributeMetadataName = "Misaki.HighPerformance.HPC.HPComputeAttribute";

    // ── Backends (one singleton per ISA) ──────────────────────────────────
    private static readonly AVX2Backend s_avx2Backend = new();

    // Reported when any phase throws, so the user sees the failure in the IDE.
    private static readonly DiagnosticDescriptor s_generationFailed = new(
        id: "HPC0001",
        title: "HPC code generation failed",
        messageFormat: "Failed to generate HPC variant for '{0}': {1}",
        category: "HPCGenerator",
        defaultSeverity: DiagnosticSeverity.Error,
        isEnabledByDefault: true);

    // ── Optimization passes ───────────────────────────────────────────────
    /// <summary>
    /// Returns the ordered list of optimisation passes for the given method.
    /// Passes are cheap objects; creating them per-method is intentional so
    /// future stateful passes (e.g. CSE) can be added without concurrency issues.
    /// </summary>
    private static IEnumerable<IHPCOptimizationPass> GetPasses(HPComputeMethodInfo info)
    {
        // FMA fusion only makes sense on ISAs that have it
        if (info.InstructionSet.HasFlag(TargetInstructionSet.AVX2))
        {
            yield return new FMAFusionPass();
        }
    }

    // ── IIncrementalGenerator ─────────────────────────────────────────────
    /// <summary>
    /// Collects all attributed method declarations and registers the output
    /// step that generates source for them.
    /// </summary>
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        var methodDeclarations = context.SyntaxProvider
            .ForAttributeWithMetadataName(
                AttributeMetadataName,
                static (n, _) => n is MethodDeclarationSyntax,
                static (ctx, _) =>
                {
                    if (ctx.TargetSymbol is not IMethodSymbol methodSymbol)
                    {
                        return null;
                    }

                    var attribute = ctx.Attributes.FirstOrDefault(
                        a => a.AttributeClass?.ToDisplayString() == AttributeMetadataName);
                    if (attribute is null)
                    {
                        return null;
                    }

                    // While the user is mid-edit the attribute may not bind to the
                    // three-parameter constructor; skip it instead of letting an
                    // index or unboxing failure take down the whole generator.
                    var arguments = attribute.ConstructorArguments;
                    if (arguments.Length < 3 ||
                        arguments[0].Value is not int instructionSet ||
                        arguments[1].Value is not int precision ||
                        arguments[2].Value is not int mode)
                    {
                        return null;
                    }

                    return new HPComputeMethodInfo
                    {
                        MethodDeclaration = (MethodDeclarationSyntax)ctx.TargetNode,
                        MethodSymbol = methodSymbol,
                        SemanticModel = ctx.SemanticModel,
                        InstructionSet = (TargetInstructionSet)instructionSet,
                        Precision = (FloatPrecision)precision,
                        Mode = (MathMode)mode,
                    };
                })
            .Collect();

        context.RegisterSourceOutput(methodDeclarations, GenerateHPCMethods);
    }

    // ── Core pipeline ─────────────────────────────────────────────────────
    /// <summary>
    /// Generates source for every collected method. A failure in one method is
    /// reported as diagnostic HPC0001 at that method and does not stop the others.
    /// </summary>
    private static void GenerateHPCMethods(
        SourceProductionContext context,
        ImmutableArray<HPComputeMethodInfo?> array)
    {
        foreach (var info in array)
        {
            if (info is null)
            {
                continue;
            }

            try
            {
                GenerateSingleMethod(context, info);
            }
            catch (Exception ex)
            {
                // Surface analyzer errors as Roslyn diagnostics so the user
                // sees them in the IDE rather than a silent empty output.
                context.ReportDiagnostic(Diagnostic.Create(
                    s_generationFailed,
                    info.MethodDeclaration.GetLocation(),
                    info.MethodDeclaration.Identifier.Text,
                    ex.Message));
            }
        }
    }

    /// <summary>Runs analyse → optimise → emit for one method and adds one file per backend.</summary>
    private static void GenerateSingleMethod(
        SourceProductionContext context,
        HPComputeMethodInfo info)
    {
        // ── Phase 1: Analyse ──────────────────────────────────────────────
        var analyzer = new HPCAnalyzer(info.SemanticModel);
        var ir = analyzer.Analyze(info.MethodDeclaration, info);

        // ── Phase 2: Optimise ─────────────────────────────────────────────
        var optimizedIR = ir;
        foreach (var pass in GetPasses(info))
        {
            optimizedIR = pass.Transform(optimizedIR);
        }

        // ── Phase 3: Emit per-target ──────────────────────────────────────
        var typeKeyword = GetTypeKeyword(info.MethodSymbol.ContainingType);
        foreach (var (backend, _) in GetBackends(info.InstructionSet))
        {
            // IHPCBackend.EmitMethod has no mode parameter, so the attribute's
            // MathMode would never reach the emitter; use the mode-aware overload
            // where the backend offers one.
            var methodSource = backend is AVX2Backend avx2
                ? avx2.EmitMethod(optimizedIR, info.Mode)
                : backend.EmitMethod(optimizedIR);

            var fullSource = WrapSource(methodSource, optimizedIR, backend, typeKeyword);
            context.AddSource(
                hintName: $"{ir.ContainingTypeName}_{ir.OriginalName}_{backend.Name}.g.cs",
                source: fullSource);
        }
    }

    // ── Backend selection ─────────────────────────────────────────────────
    /// <summary>Yields one backend for each supported flag set in <paramref name="instructionSet"/>.</summary>
    private static IEnumerable<(IHPCBackend backend, TargetInstructionSet isa)>
        GetBackends(TargetInstructionSet instructionSet)
    {
        if (instructionSet.HasFlag(TargetInstructionSet.AVX2))
        {
            yield return (s_avx2Backend, TargetInstructionSet.AVX2);
        }
        // Future: SSE4, AVX512, NEON — add here without touching anything else
    }

    /// <summary>
    /// Returns the declaration keyword(s) the generated partial must use. All
    /// parts of a partial type must be the same kind, so a struct container needs
    /// <c>partial struct</c>, not <c>partial class</c>.
    /// </summary>
    private static string GetTypeKeyword(INamedTypeSymbol? containingType) => containingType switch
    {
        { TypeKind: TypeKind.Struct, IsRecord: true } => "record struct",
        { TypeKind: TypeKind.Struct } => "struct",
        { TypeKind: TypeKind.Interface } => "interface",
        { IsRecord: true } => "record",
        _ => "class",
    };

    // ── Source wrapping ───────────────────────────────────────────────────
    /// <summary>
    /// Wraps an emitted method in the file header, the backend's using
    /// directives, the containing namespace (omitted when empty) and a partial
    /// declaration of the containing type.
    /// </summary>
    /// <param name="methodBody">Method source produced by the backend.</param>
    /// <param name="ir">IR providing the containing namespace and type name.</param>
    /// <param name="backend">Backend whose using directives are written.</param>
    /// <param name="typeKeyword">Declaration keyword of the containing type; defaults to <c>class</c>.</param>
    private static string WrapSource(
        string methodBody,
        IR.HPCMethodIR ir,
        IHPCBackend backend,
        string typeKeyword = "class")
    {
        var sb = new StringBuilder();
        sb.AppendLine("// <auto-generated/>");
        sb.AppendLine("#nullable enable");
        foreach (var u in backend.RequiredUsings)
        {
            sb.AppendLine(u);
        }
        sb.AppendLine("using Misaki.HighPerformance.HPC;");
        sb.AppendLine();

        // "namespace" followed by nothing is a syntax error, so a type without a
        // namespace name is emitted at file level.
        var hasNamespace = !string.IsNullOrWhiteSpace(ir.ContainingNamespace);
        if (hasNamespace)
        {
            sb.AppendLine($"namespace {ir.ContainingNamespace}");
            sb.AppendLine("{");
        }

        sb.AppendLine($" partial {typeKeyword} {ir.ContainingTypeName}");
        sb.AppendLine(" {");
        sb.AppendLine(methodBody);
        sb.AppendLine(" }");

        if (hasNamespace)
        {
            sb.AppendLine("}");
        }

        return sb.ToString();
    }
}
}

View File

@@ -0,0 +1,181 @@
using Misaki.HighPerformance.HPC.Generator.IR;
using System.Linq;
namespace Misaki.HighPerformance.HPC.Generator.IR
{
/// <summary>
/// Base class for IR-to-IR transformation passes.
/// Every rewrite method is non-destructive: it hands back the very same node
/// instance when none of its children changed, and otherwise a <c>with</c>-copy
/// carrying the rewritten children. The input tree is never modified.
/// </summary>
internal abstract class HPCNodeRewriter
{
    // ── Entry points ─────────────────────────────────────────────────────

    /// <summary>Routes an expression to the rewrite method for its node type; unknown nodes pass through.</summary>
    public virtual HPCExpr RewriteExpr(HPCExpr expr)
    {
        switch (expr)
        {
            case HPCBinaryOp binary: return RewriteBinaryOp(binary);
            case HPCUnaryOp unary: return RewriteUnaryOp(unary);
            case HPCIntrinsicCall intrinsic: return RewriteIntrinsicCall(intrinsic);
            case HPCPassThroughCall passThrough: return RewritePassThroughCall(passThrough);
            case HPCPropertyAccess property: return RewritePropertyAccess(property);
            case HPCVarRef variable: return RewriteVarRef(variable);
            case HPCLiteral literal: return RewriteLiteral(literal);
            default: return expr;
        }
    }

    /// <summary>Routes a statement to the rewrite method for its node type; unknown nodes pass through.</summary>
    public virtual HPCStmt RewriteStmt(HPCStmt stmt)
    {
        switch (stmt)
        {
            case HPCVarDecl declaration: return RewriteVarDecl(declaration);
            case HPCAssignment assignment: return RewriteAssignment(assignment);
            case HPCExprStmt expression: return RewriteExprStmt(expression);
            case HPCReturn ret: return RewriteReturn(ret);
            case HPCIf conditional: return RewriteIf(conditional);
            case HPCForLoop loop: return RewriteForLoop(loop);
            default: return stmt;
        }
    }

    // ── Expression rewrites (override to modify specific patterns) ───────

    /// <summary>Rewrites both operands, left first.</summary>
    protected virtual HPCExpr RewriteBinaryOp(HPCBinaryOp node)
    {
        var newLeft = RewriteExpr(node.Left);
        var newRight = RewriteExpr(node.Right);

        var unchanged = ReferenceEquals(newLeft, node.Left) && ReferenceEquals(newRight, node.Right);
        if (unchanged)
        {
            return node;
        }

        return node with { Left = newLeft, Right = newRight };
    }

    /// <summary>Rewrites the single operand.</summary>
    protected virtual HPCExpr RewriteUnaryOp(HPCUnaryOp node)
    {
        var newOperand = RewriteExpr(node.Operand);
        if (ReferenceEquals(newOperand, node.Operand))
        {
            return node;
        }

        return node with { Operand = newOperand };
    }

    /// <summary>Rewrites every argument of an intrinsic call.</summary>
    protected virtual HPCExpr RewriteIntrinsicCall(HPCIntrinsicCall node)
    {
        var newArgs = RewriteArgs(node.Args);
        if (ReferenceEquals(newArgs, node.Args))
        {
            return node;
        }

        return node with { Args = newArgs };
    }

    /// <summary>Rewrites every argument of a pass-through call.</summary>
    protected virtual HPCExpr RewritePassThroughCall(HPCPassThroughCall node)
    {
        var newArgs = RewriteArgs(node.Args);
        if (ReferenceEquals(newArgs, node.Args))
        {
            return node;
        }

        return node with { Args = newArgs };
    }

    /// <summary>Rewrites the target expression of a property access.</summary>
    protected virtual HPCExpr RewritePropertyAccess(HPCPropertyAccess node)
    {
        var newTarget = RewriteExpr(node.Target);
        if (ReferenceEquals(newTarget, node.Target))
        {
            return node;
        }

        return node with { Target = newTarget };
    }

    /// <summary>Leaf node; returned as-is by default.</summary>
    protected virtual HPCExpr RewriteVarRef(HPCVarRef node)
    {
        return node;
    }

    /// <summary>Leaf node; returned as-is by default.</summary>
    protected virtual HPCExpr RewriteLiteral(HPCLiteral node)
    {
        return node;
    }

    // ── Statement rewrites ───────────────────────────────────────────────

    /// <summary>Rewrites the initializer of a variable declaration.</summary>
    protected virtual HPCStmt RewriteVarDecl(HPCVarDecl node)
    {
        var newInitializer = RewriteExpr(node.Initializer);
        if (ReferenceEquals(newInitializer, node.Initializer))
        {
            return node;
        }

        return node with { Initializer = newInitializer };
    }

    /// <summary>Rewrites the assigned value.</summary>
    protected virtual HPCStmt RewriteAssignment(HPCAssignment node)
    {
        var newValue = RewriteExpr(node.Value);
        if (ReferenceEquals(newValue, node.Value))
        {
            return node;
        }

        return node with { Value = newValue };
    }

    /// <summary>Rewrites the wrapped expression.</summary>
    protected virtual HPCStmt RewriteExprStmt(HPCExprStmt node)
    {
        var newExpression = RewriteExpr(node.Expression);
        if (ReferenceEquals(newExpression, node.Expression))
        {
            return node;
        }

        return node with { Expression = newExpression };
    }

    /// <summary>Rewrites the returned value, if there is one.</summary>
    protected virtual HPCStmt RewriteReturn(HPCReturn node)
    {
        if (node.Value is null)
        {
            return node;
        }

        var newValue = RewriteExpr(node.Value);
        if (ReferenceEquals(newValue, node.Value))
        {
            return node;
        }

        return node with { Value = newValue };
    }

    /// <summary>Rewrites the condition, then the then-body, then the else-body (when present).</summary>
    protected virtual HPCStmt RewriteIf(HPCIf node)
    {
        var newCondition = RewriteExpr(node.Condition);
        var newThen = RewriteBody(node.ThenBody);
        HPCStmt[]? newElse = null;
        if (node.ElseBody is not null)
        {
            newElse = RewriteBody(node.ElseBody);
        }

        var unchanged =
            ReferenceEquals(newCondition, node.Condition) &&
            ReferenceEquals(newThen, node.ThenBody) &&
            ReferenceEquals(newElse, node.ElseBody);
        if (unchanged)
        {
            return node;
        }

        return node with { Condition = newCondition, ThenBody = newThen, ElseBody = newElse };
    }

    /// <summary>Rewrites the iterator declaration, condition, increment and body, in that order.</summary>
    protected virtual HPCStmt RewriteForLoop(HPCForLoop node)
    {
        // The cast throws if an override of RewriteVarDecl returns another statement kind.
        var newIterator = (HPCVarDecl)RewriteVarDecl(node.Iterator);
        var newCondition = RewriteExpr(node.Condition);
        var newIncrement = RewriteExpr(node.Increment);
        var newBody = RewriteBody(node.Body);

        var unchanged =
            ReferenceEquals(newIterator, node.Iterator) &&
            ReferenceEquals(newCondition, node.Condition) &&
            ReferenceEquals(newIncrement, node.Increment) &&
            ReferenceEquals(newBody, node.Body);
        if (unchanged)
        {
            return node;
        }

        return node with
        {
            Iterator = newIterator,
            Condition = newCondition,
            Increment = newIncrement,
            Body = newBody
        };
    }

    // ── Helpers ──────────────────────────────────────────────────────────

    /// <summary>
    /// Rewrites each statement of <paramref name="body"/>. The input array is
    /// returned untouched when no statement changed; a copy is allocated only on
    /// the first difference.
    /// </summary>
    protected HPCStmt[] RewriteBody(HPCStmt[] body)
    {
        HPCStmt[]? copy = null;
        for (var index = 0; index < body.Length; index++)
        {
            var before = body[index];
            var after = RewriteStmt(before);
            if (ReferenceEquals(before, after))
            {
                continue;
            }

            if (copy is null)
            {
                copy = (HPCStmt[])body.Clone();
            }
            copy[index] = after;
        }

        return copy ?? body;
    }

    /// <summary>Same copy-on-first-change scheme as <see cref="RewriteBody"/>, for argument lists.</summary>
    private HPCExpr[] RewriteArgs(HPCExpr[] args)
    {
        HPCExpr[]? copy = null;
        for (var index = 0; index < args.Length; index++)
        {
            var before = args[index];
            var after = RewriteExpr(before);
            if (ReferenceEquals(before, after))
            {
                continue;
            }

            if (copy is null)
            {
                copy = (HPCExpr[])args.Clone();
            }
            copy[index] = after;
        }

        return copy ?? args;
    }
}
}

View File

@@ -0,0 +1,205 @@
using System;
namespace Misaki.HighPerformance.HPC.Generator.IR
{
// ─────────────────────────────────────────────────────────────────────────
// Type system
// ─────────────────────────────────────────────────────────────────────────
/// <summary>
/// The scalar element type of an IR value, such as "float", "double" or "int".
/// The IR never names a vector type; choosing the concrete vector width is left
/// to the backend.
/// </summary>
/// <param name="ElementTypeName">Name of the element type.</param>
/// <param name="IsFloatingPoint">Whether the element type is a floating-point type.</param>
internal sealed record HPCType(string ElementTypeName, bool IsFloatingPoint)
{
    // Shared instances for the most common element types.
    public static readonly HPCType Float = new("float", IsFloatingPoint: true);
    public static readonly HPCType Double = new("double", IsFloatingPoint: true);
    public static readonly HPCType Int = new("int", IsFloatingPoint: false);
    public static readonly HPCType UInt = new("uint", IsFloatingPoint: false);
    public static readonly HPCType Long = new("long", IsFloatingPoint: false);

    /// <summary>
    /// Resolves a C# keyword or CLR type name to one of the shared instances.
    /// Any other name yields a new non-floating-point type carrying that name.
    /// </summary>
    public static HPCType FromElementName(string name)
    {
        switch (name)
        {
            case "float":
            case "Single":
                return Float;
            case "double":
            case "Double":
                return Double;
            case "int":
            case "Int32":
                return Int;
            case "uint":
            case "UInt32":
                return UInt;
            case "long":
            case "Int64":
                return Long;
            default:
                return new HPCType(name, IsFloatingPoint: false);
        }
    }

    /// <summary>Returns the element type name.</summary>
    public override string ToString()
    {
        return ElementTypeName;
    }
}
// ─────────────────────────────────────────────────────────────────────────
// Expression nodes
// ─────────────────────────────────────────────────────────────────────────
/// <summary>Base type for all HPC IR expression nodes.</summary>
/// <param name="Type">Element type of the value this expression produces.</param>
internal abstract record HPCExpr(HPCType Type);
/// <summary>A reference to a local variable or parameter by name.</summary>
/// <param name="Name">Identifier of the referenced variable or parameter.</param>
/// <param name="Type">Element type of the referenced value.</param>
internal sealed record HPCVarRef(string Name, HPCType Type) : HPCExpr(Type)
{
/// <summary>Returns the variable name.</summary>
public override string ToString() => Name;
}
/// <summary>A scalar literal value broadcast to all lanes.</summary>
/// <param name="Value">Literal text; the AVX2 backend appends a type suffix when emitting it.</param>
/// <param name="Type">Element type of the literal.</param>
internal sealed record HPCLiteral(string Value, HPCType Type) : HPCExpr(Type)
{
/// <summary>Returns the literal text.</summary>
public override string ToString() => Value;
}
// ── Binary ops ───────────────────────────────────────────────────────────
/// <summary>
/// Operator of an <see cref="HPCBinaryOp"/>: arithmetic, bitwise/shift, and
/// comparison kinds.
/// </summary>
internal enum HPCBinaryKind
{
// Arithmetic
Add, Subtract, Multiply, Divide, Modulo,
// Bitwise and shifts
BitwiseAnd, BitwiseOr, BitwiseXor, ShiftLeft, ShiftRight,
// Comparisons
Equal, NotEqual,
LessThan, LessThanOrEqual,
GreaterThan, GreaterThanOrEqual,
}
/// <summary>A two-operand operation.</summary>
/// <param name="Kind">The operator.</param>
/// <param name="Left">Left operand.</param>
/// <param name="Right">Right operand.</param>
/// <param name="Type">Element type of the node; the AVX2 backend uses it to pick float or integer intrinsics.</param>
internal sealed record HPCBinaryOp(
HPCBinaryKind Kind,
HPCExpr Left,
HPCExpr Right,
HPCType Type) : HPCExpr(Type);
// ── Unary ops ────────────────────────────────────────────────────────────
/// <summary>
/// Operator of an <see cref="HPCUnaryOp"/>: negation, bitwise not, and
/// single-argument math functions.
/// </summary>
internal enum HPCUnaryKind
{
Negate, BitwiseNot,
// Math functions that map to a single ISPMDLane static method
Abs, Sqrt, Floor, Ceil, Round, Trunc, Frac, Sign, Saturate,
Rcp, Rsqrt,
Sin, Cos, Tan, Asin, Acos, Atan,
Exp, Exp2, Log, Log2,
}
/// <summary>A single-operand operation.</summary>
/// <param name="Kind">The operator or math function.</param>
/// <param name="Operand">The operand expression.</param>
/// <param name="Type">Element type of the node.</param>
internal sealed record HPCUnaryOp(
HPCUnaryKind Kind,
HPCExpr Operand,
HPCType Type) : HPCExpr(Type);
// ── Intrinsic calls (multi-argument) ────────────────────────────────────
/// <summary>
/// Identifies the operation of an <see cref="HPCIntrinsicCall"/>. Grouped into
/// arithmetic, math, reductions, memory access and conversions.
/// </summary>
internal enum HPCIntrinsic
{
// Arithmetic
MultiplyAdd, // a * b + c (FMA)
MultiplySubtract, // a * b - c
// Math
Atan2, Pow,
SinCos, // out sin, out cos simultaneously
Lerp, Min, Max, Clamp, Select, CopySign,
// Reduction
ReduceAdd, ReduceMax, ReduceMin,
// Memory
Load, Store,
MaskLoad, MaskStore,
Gather, MaskGather,
Scatter, MaskScatter,
CompressStore,
// Conversion
Cast, BitCast,
}
/// <summary>A call to a recognised multi-argument intrinsic.</summary>
/// <param name="Intrinsic">Which intrinsic is called.</param>
/// <param name="Args">Argument expressions, in the positional order the backend expects for that intrinsic.</param>
/// <param name="Type">Element type of the node.</param>
internal sealed record HPCIntrinsicCall(
HPCIntrinsic Intrinsic,
HPCExpr[] Args,
HPCType Type) : HPCExpr(Type);
/// <summary>
/// A method call that the HPC pipeline does not recognise as an intrinsic —
/// emitted verbatim so user code can still call arbitrary helpers.
/// </summary>
/// <param name="MethodName">Text the AVX2 backend writes out for this call.</param>
/// <param name="Args">Argument expressions; visited by rewriter passes.</param>
/// <param name="Type">Element type of the call's result.</param>
internal sealed record HPCPassThroughCall(
string MethodName,
HPCExpr[] Args,
HPCType Type) : HPCExpr(Type);
/// <summary>Member-property access, e.g. <c>lane.LaneWidth</c> → <c>Count</c>.</summary>
/// <param name="Target">Expression whose property is read.</param>
/// <param name="PropertyName">Name of the property as it is emitted.</param>
/// <param name="Type">Element type of the node.</param>
internal sealed record HPCPropertyAccess(
HPCExpr Target,
string PropertyName,
HPCType Type) : HPCExpr(Type);
// ─────────────────────────────────────────────────────────────────────────
// Statement nodes
// ─────────────────────────────────────────────────────────────────────────
/// <summary>
/// Base type for all HPC IR statement nodes. Unlike expressions, statements
/// carry no type of their own.
/// </summary>
internal abstract record HPCStmt;
/// <summary>Local variable declaration with mandatory initialiser.</summary>
/// <param name="Name">Name of the declared variable.</param>
/// <param name="Type">Element type of the variable.</param>
/// <param name="Initializer">Expression providing the initial value.</param>
internal sealed record HPCVarDecl(
string Name,
HPCType Type,
HPCExpr Initializer) : HPCStmt;
/// <summary>Assignment to an existing variable (including out-parameters).</summary>
/// <param name="Target">Text of the assignment target, emitted as-is on the left-hand side.</param>
/// <param name="Value">Expression being assigned.</param>
internal sealed record HPCAssignment(string Target, HPCExpr Value) : HPCStmt;
/// <summary>Expression evaluated purely for its side-effect (e.g. Store calls).</summary>
/// <param name="Expression">The expression to evaluate.</param>
internal sealed record HPCExprStmt(HPCExpr Expression) : HPCStmt;
/// <summary>Return statement.</summary>
/// <param name="Value">Returned expression, or <c>null</c> for a bare <c>return;</c>.</param>
internal sealed record HPCReturn(HPCExpr? Value) : HPCStmt;
/// <summary>
/// Conditional branch. In a vectorised context the backend may lower this
/// to a predicated Select or a masked execution block.
/// </summary>
/// <param name="Condition">Branch condition.</param>
/// <param name="ThenBody">Statements executed when the condition holds.</param>
/// <param name="ElseBody">Statements of the else branch, or <c>null</c> when there is none.</param>
internal sealed record HPCIf(
HPCExpr Condition,
HPCStmt[] ThenBody,
HPCStmt[]? ElseBody) : HPCStmt;
/// <summary>Simple counted for-loop.</summary>
/// <param name="Iterator">Declaration of the loop variable, placed in the loop header.</param>
/// <param name="Condition">Loop-continuation condition.</param>
/// <param name="Increment">Expression evaluated after each iteration.</param>
/// <param name="Body">Statements of the loop body.</param>
internal sealed record HPCForLoop(
HPCVarDecl Iterator,
HPCExpr Condition,
HPCExpr Increment,
HPCStmt[] Body) : HPCStmt;
// ─────────────────────────────────────────────────────────────────────────
// Method-level IR
// ─────────────────────────────────────────────────────────────────────────
/// <summary>One parameter of an HPC method.</summary>
/// <param name="Name">Parameter name.</param>
/// <param name="Type">Element type; the AVX2 backend emits it as a 256-bit vector of this element.</param>
/// <param name="IsOut">Whether the parameter is emitted with the <c>out</c> modifier.</param>
/// <param name="IsRef">Whether the parameter is emitted with the <c>ref</c> modifier.</param>
internal sealed record HPCParameter(
string Name,
HPCType Type,
bool IsOut,
bool IsRef);
/// <summary>
/// The complete IR representation of a single <c>[HPCompute]</c>-annotated method,
/// fully detached from Roslyn syntax after the analysis phase.
/// </summary>
internal sealed record HPCMethodIR
{
// NOTE(review): the visible backend and generator read OriginalName, not Name;
// confirm how the two are meant to differ.
public required string Name { get; init; }
// Element type of the return value ("void" for methods without one).
public required HPCType ReturnType { get; init; }
// Parameters in declaration order.
public required HPCParameter[] Parameters { get; init; }
// Top-level statements of the method body.
public required HPCStmt[] Body { get; init; }
// Compilation metadata from the attribute
public required TargetInstructionSet TargetISA { get; init; }
public required FloatPrecision Precision { get; init; }
public required MathMode Mode { get; init; }
// Emission metadata
// Namespace and type the generated partial declaration is placed in.
public required string ContainingNamespace { get; init; }
public required string ContainingTypeName { get; init; }
// Source method name; the backend name is appended to it for the emitted variant.
public required string OriginalName { get; init; }
}
}

View File

@@ -0,0 +1,19 @@
<Project Sdk="Microsoft.NET.Sdk">
<!-- Build settings for the HPC source-generator project. -->
<PropertyGroup>
<!-- Targets netstandard2.0; modern C# features rely on LangVersion below plus polyfills. -->
<TargetFramework>netstandard2.0</TargetFramework>
<Nullable>enable</Nullable>
<!-- Turns on the additional analyzer rules intended for analyzer/generator projects. -->
<EnforceExtendedAnalyzerRules>True</EnforceExtendedAnalyzerRules>
<AllowUnsafeBlocks>True</AllowUnsafeBlocks>
<!-- Newest language version the installed compiler supports. -->
<LangVersion>latest</LangVersion>
</PropertyGroup>
<ItemGroup>
<!-- Analyzer-authoring diagnostics; private to this project. -->
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="5.3.0">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<!-- Roslyn C# APIs used by the generators. -->
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="5.3.0" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,65 @@
using Misaki.HighPerformance.HPC.Generator.IR;
using System.Linq;
namespace Misaki.HighPerformance.HPC.Generator.Optimization
{
/// <summary>
/// Detects floating-point <c>(a * b) + c</c>, <c>c + (a * b)</c> and
/// <c>(a * b) - c</c> binary expression patterns in the IR and replaces them with
/// <see cref="HPCIntrinsic.MultiplyAdd"/> / <see cref="HPCIntrinsic.MultiplySubtract"/>
/// calls, enabling the backend to emit a single FMA instruction instead of two.
///
/// <para>This pass only runs when the target ISA supports FMA (e.g. AVX2 which
/// includes FMA3 via the FMA extension, and AVX-512F). The pipeline checks
/// the requested instruction set before adding the pass.</para>
///
/// <para>Integer expressions are left untouched: the backend maps the fused
/// intrinsics to <c>Fma.*</c>, which only exists for float and double lanes.</para>
/// </summary>
internal sealed class FMAFusionPass : HPCNodeRewriter, IHPCOptimizationPass
{
    /// <inheritdoc/>
    public string Name => "FMA Fusion";

    /// <summary>
    /// Rewrites the method body; returns <paramref name="method"/> itself when no
    /// pattern matched.
    /// </summary>
    public HPCMethodIR Transform(HPCMethodIR method)
    {
        var newBody = RewriteBody(method.Body);
        return ReferenceEquals(newBody, method.Body)
            ? method
            : method with { Body = newBody };
    }

    /// <summary>
    /// Fuses a floating-point multiply feeding an add/subtract into one intrinsic
    /// call. The operands are rewritten recursively so nested patterns fuse too.
    /// </summary>
    protected override HPCExpr RewriteBinaryOp(HPCBinaryOp node)
    {
        // Fusing integer arithmetic would emit Fma calls on integer vectors,
        // which do not compile; only floating-point nodes are candidates.
        if (!node.Type.IsFloatingPoint)
        {
            return base.RewriteBinaryOp(node);
        }

        if (node.Kind == HPCBinaryKind.Add)
        {
            // (a * b) + c → FMA(a, b, c)
            if (node.Left is HPCBinaryOp { Kind: HPCBinaryKind.Multiply, Type: { IsFloatingPoint: true } } mulLeft)
            {
                return new HPCIntrinsicCall(
                    HPCIntrinsic.MultiplyAdd,
                    [RewriteExpr(mulLeft.Left), RewriteExpr(mulLeft.Right), RewriteExpr(node.Right)],
                    node.Type);
            }

            // c + (a * b) → FMA(a, b, c)
            if (node.Right is HPCBinaryOp { Kind: HPCBinaryKind.Multiply, Type: { IsFloatingPoint: true } } mulRight)
            {
                return new HPCIntrinsicCall(
                    HPCIntrinsic.MultiplyAdd,
                    [RewriteExpr(mulRight.Left), RewriteExpr(mulRight.Right), RewriteExpr(node.Left)],
                    node.Type);
            }
        }

        if (node.Kind == HPCBinaryKind.Subtract)
        {
            // (a * b) - c → MultiplySubtract(a, b, c)
            // c - (a * b) is deliberately not fused: it needs a negated-multiply form.
            if (node.Left is HPCBinaryOp { Kind: HPCBinaryKind.Multiply, Type: { IsFloatingPoint: true } } mulLeft)
            {
                return new HPCIntrinsicCall(
                    HPCIntrinsic.MultiplySubtract,
                    [RewriteExpr(mulLeft.Left), RewriteExpr(mulLeft.Right), RewriteExpr(node.Right)],
                    node.Type);
            }
        }

        return base.RewriteBinaryOp(node);
    }
}
}

View File

@@ -0,0 +1,20 @@
using Misaki.HighPerformance.HPC.Generator.IR;
namespace Misaki.HighPerformance.HPC.Generator.Optimization
{
/// <summary>
/// A single optimization pass that transforms an <see cref="HPCMethodIR"/>
/// into another <see cref="HPCMethodIR"/>. Passes must be pure functions:
/// they must not mutate the input, and they should return the original object
/// unchanged when nothing was modified, so the pipeline can detect
/// "no change" with a cheap reference comparison.
/// </summary>
internal interface IHPCOptimizationPass
{
/// <summary>Human-readable name of the pass, used in diagnostics.</summary>
string Name { get; }
/// <summary>Transforms the IR of one method.</summary>
/// <param name="method">The method IR to transform; implementations must not mutate it.</param>
/// <returns>
/// The transformed method IR, or the same <paramref name="method"/> instance
/// when the pass made no change.
/// </returns>
HPCMethodIR Transform(HPCMethodIR method);
}
}

View File

@@ -0,0 +1,31 @@
// Polyfills needed for modern C# features on netstandard2.0 target.
// These types are built into .NET 5+ runtimes but must be declared manually
// when targeting netstandard2.0, which source generators must do.
namespace System.Runtime.CompilerServices
{
    /// <summary>
    /// Marker type that enables <c>init</c> accessors and <c>record</c> types (C# 9).
    /// </summary>
    internal static class IsExternalInit
    {
    }

    /// <summary>
    /// Marker attribute that enables <c>required</c> members (C# 11).
    /// </summary>
    [AttributeUsage(
        AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Field | AttributeTargets.Property,
        AllowMultiple = false,
        Inherited = false)]
    internal sealed class RequiredMemberAttribute : Attribute
    {
    }

    /// <summary>
    /// Attribute that names a compiler feature; declared here alongside
    /// <see cref="RequiredMemberAttribute"/> as part of the <c>required</c>-member polyfill.
    /// </summary>
    [AttributeUsage(AttributeTargets.All, AllowMultiple = true, Inherited = false)]
    internal sealed class CompilerFeatureRequiredAttribute : Attribute
    {
        /// <summary>Creates the attribute for the given feature name.</summary>
        /// <param name="featureName">Name of the compiler feature.</param>
        public CompilerFeatureRequiredAttribute(string featureName)
        {
            FeatureName = featureName;
        }

        /// <summary>Name of the compiler feature passed to the constructor.</summary>
        public string FeatureName { get; }

        /// <summary>Whether the feature is optional; settable only during initialization.</summary>
        public bool IsOptional { get; init; }
    }
}
namespace System.Diagnostics.CodeAnalysis
{
// Polyfill for the constructor-level attribute that accompanies `required` members (C# 11).
// It can be applied to constructors only, at most once, and is not inherited.
[AttributeUsage(AttributeTargets.Constructor, AllowMultiple = false, Inherited = false)]
internal sealed class SetsRequiredMembersAttribute : Attribute { }
}

View File

@@ -0,0 +1,233 @@
using Misaki.HighPerformance.HPC.Generator.APIContext;
using System.Text;
namespace Misaki.HighPerformance.HPC.Generator
{
    /// <summary>
    /// Builds <see cref="Method"/> descriptors for vectorized sine/cosine helper
    /// methods by driving an <see cref="IVectorAPIContext"/>.
    /// </summary>
    /// <remarks>
    /// <para>The <c>*_Standard</c> variants forward to the <c>Sin</c>/<c>Cos</c>/<c>SinCos</c>
    /// member of the context's vector type. The <c>*_Fast</c> variants build a
    /// polynomial approximation: the input is scaled by 1/pi, split into the
    /// nearest integer <c>k</c> and a remainder <c>z</c>, a degree-9 odd polynomial
    /// in <c>z</c> is evaluated with Horner's scheme via <c>MultiplyAdd</c>, and the
    /// result is multiplied by a sign of +1 for even <c>k</c> and -1 for odd <c>k</c>.
    /// Cosine is obtained by adding pi/2 to the input first.</para>
    /// <para>The order of calls on the context is significant in the fast
    /// variants: <c>LastAssignedVariable</c> is read immediately after the
    /// assignment it refers to, and later steps re-assign that same variable.
    /// Do not reorder these statements.</para>
    /// <para>Numeric literals are written with an <c>f</c> suffix when
    /// <c>T</c> is <see cref="float"/> and a <c>d</c> suffix for any other <c>T</c>.</para>
    /// </remarks>
    internal static class UtilityTemplate
    {
        /// <summary>
        /// Builds <c>Sin_{T}_Standard</c>: a one-parameter method whose body returns
        /// the context's vector-type <c>Sin</c> call applied to <c>value</c>.
        /// </summary>
        /// <typeparam name="T">Element type of the vector.</typeparam>
        /// <param name="api">Context used to build expressions and the body.</param>
        /// <returns>The method descriptor.</returns>
        public static Method Sin_Standard<T>(IVectorAPIContext api)
        {
            var body = api.Return(api.Call($"{api.GetVectorType()}.Sin", "value"));
            return new Method(
                modifier: "public static",
                returnType: api.GetVectorType<T>(),
                name: $"Sin_{typeof(T).Name}_Standard",
                parameters: new[] { $"{api.GetVectorType<T>()} value" },
                body: body);
        }

        /// <summary>
        /// Builds <c>Sin_{T}_Fast</c>: the polynomial sine approximation described on the class.
        /// </summary>
        /// <typeparam name="T">Element type of the vector; selects the literal suffix.</typeparam>
        /// <param name="api">Context used to build expressions and the body.</param>
        /// <returns>The method descriptor.</returns>
        public static Method Sin_Fast<T>(IVectorAPIContext api)
        {
            var isFloat = typeof(T) == typeof(float);
            // Despite the name, this is appended to each literal as a suffix.
            var typePrefix = isFloat ? "f" : "d";
            var input = new Expression(api, "value");
            var invPi = api.Create($"0.318309886{typePrefix}").Assign();
            var x_sin = input;
            // y = x / pi, k = round(y), z = y - k
            var y_sin = api.Multiply(x_sin, invPi).Assign();
            var k_sin = api.Round(y_sin).Assign();
            var z_sin = api.Subtract(y_sin, k_sin).Assign();
            var half = api.Create($"0.5{typePrefix}").Assign();
            var two = api.Create($"2.0{typePrefix}").Assign();
            // sign = 1 - 2 * |k - 2 * round(k / 2)|: +1 for even k, -1 for odd k
            var k_even_sin = (api.Round(k_sin * half) * two).Assign();
            var sign_sin = (api.One<T>() - two * api.Abs(k_sin - k_even_sin)).Assign();
            // Polynomial coefficients for the odd powers z^1 .. z^9.
            var c1 = api.Create($"3.14159265{typePrefix}").Assign();
            var c3 = api.Create($"-5.16771278{typePrefix}").Assign();
            var c5 = api.Create($"2.55016404{typePrefix}").Assign();
            var c7 = api.Create($"-0.59926453{typePrefix}").Assign();
            var c9 = api.Create($"0.08214589{typePrefix}").Assign();
            var z2_sin = (z_sin * z_sin).Assign();
            // Horner evaluation; each step is assigned back to the variable captured here.
            var poly_sin = api.MultiplyAdd(z2_sin, c9, c7).Assign();
            var poly_sin_name = api.LastAssignedVariable;
            poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c5).Assign(poly_sin_name, false);
            poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c3).Assign(poly_sin_name, false);
            poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c1).Assign(poly_sin_name, false);
            poly_sin = api.Multiply(z_sin, poly_sin).Assign(poly_sin_name, false);
            var body = api.Return(poly_sin * sign_sin);
            return new Method(
                modifier: "public static",
                returnType: api.GetVectorType<T>(),
                name: $"Sin_{typeof(T).Name}_Fast",
                parameters: new[] { $"{api.GetVectorType<T>()} {input.Code}" },
                body: body);
        }

        /// <summary>
        /// Builds <c>Cos_{T}_Standard</c>: a one-parameter method whose body returns
        /// the context's vector-type <c>Cos</c> call applied to <c>value</c>.
        /// </summary>
        /// <typeparam name="T">Element type of the vector.</typeparam>
        /// <param name="api">Context used to build expressions and the body.</param>
        /// <returns>The method descriptor.</returns>
        public static Method Cos_Standard<T>(IVectorAPIContext api)
        {
            var body = api.Return(api.Call($"{api.GetVectorType()}.Cos", "value"));
            return new Method(
                modifier: "public static",
                returnType: api.GetVectorType<T>(),
                name: $"Cos_{typeof(T).Name}_Standard",
                parameters: new[] { $"{api.GetVectorType<T>()} value" },
                body: body);
        }

        /// <summary>
        /// Builds <c>Cos_{T}_Fast</c>: the same polynomial as <see cref="Sin_Fast{T}"/>
        /// applied to <c>value + pi/2</c>.
        /// </summary>
        /// <typeparam name="T">Element type of the vector; selects the literal suffix.</typeparam>
        /// <param name="api">Context used to build expressions and the body.</param>
        /// <returns>The method descriptor.</returns>
        public static Method Cos_Fast<T>(IVectorAPIContext api)
        {
            var isFloat = typeof(T) == typeof(float);
            // Despite the name, this is appended to each literal as a suffix.
            var typePrefix = isFloat ? "f" : "d";
            var input = new Expression(api, "value");
            var halfPi = api.Create($"1.570796327{typePrefix}").Assign();
            var invPi = api.Create($"0.318309886{typePrefix}").Assign();
            // Shift by pi/2 so the sine polynomial yields the cosine.
            var x_cos = api.Add(input, halfPi).Assign();
            var y_cos = api.Multiply(x_cos, invPi).Assign();
            var k_cos = api.Round(y_cos).Assign();
            var z_cos = api.Subtract(y_cos, k_cos).Assign();
            var half = api.Create($"0.5{typePrefix}").Assign();
            var two = api.Create($"2.0{typePrefix}").Assign();
            // sign = 1 - 2 * |k - 2 * round(k / 2)|: +1 for even k, -1 for odd k
            var k_even_cos = api.Multiply(api.Round(api.Multiply(k_cos, half)), two).Assign();
            var sign_cos = api.Subtract(api.One<T>(), api.Multiply(two, api.Abs(api.Subtract(k_cos, k_even_cos)))).Assign();
            var c1 = api.Create($"3.14159265{typePrefix}").Assign();
            var c3 = api.Create($"-5.16771278{typePrefix}").Assign();
            var c5 = api.Create($"2.55016404{typePrefix}").Assign();
            var c7 = api.Create($"-0.59926453{typePrefix}").Assign();
            var c9 = api.Create($"0.08214589{typePrefix}").Assign();
            var z2_cos = api.Multiply(z_cos, z_cos).Assign();
            // Horner evaluation; each step is assigned back to the variable captured here.
            var poly_cos = api.MultiplyAdd(z2_cos, c9, c7).Assign();
            var poly_cos_name = api.LastAssignedVariable;
            poly_cos = api.MultiplyAdd(z2_cos, poly_cos, c5).Assign(poly_cos_name, false);
            poly_cos = api.MultiplyAdd(z2_cos, poly_cos, c3).Assign(poly_cos_name, false);
            poly_cos = api.MultiplyAdd(z2_cos, poly_cos, c1).Assign(poly_cos_name, false);
            poly_cos = api.Multiply(z_cos, poly_cos).Assign(poly_cos_name, false);
            var body = api.Return(poly_cos * sign_cos);
            return new Method(
                modifier: "public static",
                returnType: api.GetVectorType<T>(),
                name: $"Cos_{typeof(T).Name}_Fast",
                parameters: new[] { $"{api.GetVectorType<T>()} {input.Code}" },
                body: body);
        }

        /// <summary>
        /// Builds <c>SinCos_{T}_Standard</c>: a <c>void</c> method taking <c>value</c>
        /// and writing its results to the <c>out</c> parameters <c>sin</c> and <c>cos</c>
        /// from a single vector-type <c>SinCos</c> call.
        /// </summary>
        /// <typeparam name="T">Element type of the vector.</typeparam>
        /// <param name="api">Context used to build expressions and the body.</param>
        /// <returns>The method descriptor.</returns>
        public static Method SinCos_Standard<T>(IVectorAPIContext api)
        {
            // The method is declared void with two out parameters, so the pair
            // produced by SinCos is deconstructed into them instead of being
            // returned (returning it would leave both outs unassigned).
            // NOTE(review): relies on Assign(name, false) emitting a plain
            // "name = expr" assignment, as the fast variants use it — confirm.
            api.Call($"{api.GetVectorType()}.SinCos", "value").Assign("(sin, cos)", false);
            // Same empty-return idiom as SinCos_Fast.
            var body = api.Return(api.Create(""));
            return new Method(
                modifier: "public static",
                returnType: "void",
                name: $"SinCos_{typeof(T).Name}_Standard",
                parameters: new[] { $"{api.GetVectorType<T>()} value", $"out {api.GetVectorType<T>()} sin", $"out {api.GetVectorType<T>()} cos" },
                body: body);
        }

        /// <summary>
        /// Builds <c>SinCos_{T}_Fast</c>: a <c>void</c> method that evaluates the fast
        /// sine and cosine polynomials side by side, sharing constants, and assigns
        /// the results to the <c>out</c> parameters <c>sin</c> and <c>cos</c>.
        /// </summary>
        /// <typeparam name="T">Element type of the vector; selects the literal suffix.</typeparam>
        /// <param name="api">Context used to build expressions and the body.</param>
        /// <returns>The method descriptor.</returns>
        public static Method SinCos_Fast<T>(IVectorAPIContext api)
        {
            var isFloat = typeof(T) == typeof(float);
            // Despite the name, this is appended to each literal as a suffix.
            var typePrefix = isFloat ? "f" : "d";
            var input = new Expression(api, "value");
            var sinOut = new Expression(api, "sin");
            var cosOut = new Expression(api, "cos");
            var halfPi = api.Create($"1.570796327{typePrefix}").Assign();
            var invPi = api.Create($"0.318309886{typePrefix}").Assign();
            var x_sin = input;
            var x_cos = api.Add(x_sin, halfPi).Assign();
            var y_sin = api.Multiply(x_sin, invPi).Assign();
            var y_cos = api.Multiply(x_cos, invPi).Assign();
            var k_sin = api.Round(y_sin).Assign();
            var k_cos = api.Round(y_cos).Assign();
            var z_sin = api.Subtract(y_sin, k_sin).Assign();
            var z_cos = api.Subtract(y_cos, k_cos).Assign();
            var half = api.Create($"0.5{typePrefix}").Assign();
            var two = api.Create($"2.0{typePrefix}").Assign();
            var one = api.One<T>();
            // sign = 1 - 2 * |k - 2 * round(k / 2)|: +1 for even k, -1 for odd k
            var k_even_sin = api.Multiply(api.Round(api.Multiply(k_sin, half)), two).Assign();
            var sign_sin = api.Subtract(one, api.Multiply(two, api.Abs(api.Subtract(k_sin, k_even_sin)))).Assign();
            var k_even_cos = api.Multiply(api.Round(api.Multiply(k_cos, half)), two).Assign();
            var sign_cos = api.Subtract(one, api.Multiply(two, api.Abs(api.Subtract(k_cos, k_even_cos)))).Assign();
            var c1 = api.Create($"3.14159265{typePrefix}").Assign();
            var c3 = api.Create($"-5.16771278{typePrefix}").Assign();
            var c5 = api.Create($"2.55016404{typePrefix}").Assign();
            var c7 = api.Create($"-0.59926453{typePrefix}").Assign();
            var c9 = api.Create($"0.08214589{typePrefix}").Assign();
            // Sine polynomial (Horner), re-assigning the variable captured right after its first assignment.
            var z2_sin = api.Multiply(z_sin, z_sin).Assign();
            var poly_sin = api.MultiplyAdd(z2_sin, c9, c7).Assign();
            var poly_sin_name = api.LastAssignedVariable;
            poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c5).Assign(poly_sin_name, false);
            poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c3).Assign(poly_sin_name, false);
            poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c1).Assign(poly_sin_name, false);
            poly_sin = api.Multiply(z_sin, poly_sin).Assign(poly_sin_name, false);
            // Cosine polynomial (Horner), same scheme on the shifted input.
            var z2_cos = api.Multiply(z_cos, z_cos).Assign();
            var poly_cos = api.MultiplyAdd(z2_cos, c9, c7).Assign();
            var poly_cos_name = api.LastAssignedVariable;
            poly_cos = api.MultiplyAdd(z2_cos, poly_cos, c5).Assign(poly_cos_name, false);
            poly_cos = api.MultiplyAdd(z2_cos, poly_cos, c3).Assign(poly_cos_name, false);
            poly_cos = api.MultiplyAdd(z2_cos, poly_cos, c1).Assign(poly_cos_name, false);
            poly_cos = api.Multiply(z_cos, poly_cos).Assign(poly_cos_name, false);
            // Write the signed results to the out parameters.
            sinOut = api.Multiply(poly_sin, sign_sin).Assign(sinOut.Code, false);
            cosOut = api.Multiply(poly_cos, sign_cos).Assign(cosOut.Code, false);
            // Empty expression: the method is void.
            var body = api.Return(api.Create(""));
            return new Method(
                modifier: "public static",
                returnType: "void",
                name: $"SinCos_{typeof(T).Name}_Fast",
                parameters: new[] { $"{api.GetVectorType<T>()} {input.Code}", $"out {api.GetVectorType<T>()} {sinOut.Code}", $"out {api.GetVectorType<T>()} {cosOut.Code}" },
                body: body);
        }

        /// <summary>
        /// Builds all twelve helpers (the six variants for <see cref="float"/>,
        /// then the six for <see cref="double"/>) and concatenates their source text.
        /// </summary>
        /// <param name="api">Context used to build every method.</param>
        /// <param name="identation">
        /// Text placed before each attribute line and passed to
        /// <c>Method.GetFullCode</c> for each method.
        /// </param>
        /// <returns>
        /// The source text; every method is preceded by a
        /// <c>[MethodImpl(MethodImplOptions.AggressiveInlining)]</c> line.
        /// </returns>
        public static string GenerateSinCosUtilityMethods(IVectorAPIContext api, string identation)
        {
            var methods = new Method[]
            {
                Sin_Standard<float>(api),
                Sin_Fast<float>(api),
                Cos_Standard<float>(api),
                Cos_Fast<float>(api),
                SinCos_Standard<float>(api),
                SinCos_Fast<float>(api),
                Sin_Standard<double>(api),
                Sin_Fast<double>(api),
                Cos_Standard<double>(api),
                Cos_Fast<double>(api),
                SinCos_Standard<double>(api),
                SinCos_Fast<double>(api)
            };
            var sb = new StringBuilder();
            var inlineAttr = identation + "[MethodImpl(MethodImplOptions.AggressiveInlining)]";
            foreach (var method in methods)
            {
                sb.AppendLine(inlineAttr);
                sb.AppendLine(method.GetFullCode(identation));
            }
            return sb.ToString();
        }
    }
}

View File

@@ -1,7 +1,7 @@
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using System.Numerics; using System.Numerics;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
/// <summary> /// <summary>
/// Common marker interface for SPMD lane types. /// Common marker interface for SPMD lane types.
@@ -416,14 +416,14 @@ public unsafe interface ISPMDLane<TSelf, TNumber> : ISPMDLane, IEquatable<TSelf>
/// <summary> /// <summary>
/// Computes a * b + c element-wise. /// Computes a * b + c element-wise.
/// </summary> /// </summary>
/// <param name="a">The first multiplier.</param> /// <param name="left">The first multiplier.</param>
/// <param name="b">The second multiplier.</param> /// <param name="right">The second multiplier.</param>
/// <param name="c">The addend.</param> /// <param name="addend">The addend.</param>
/// <returns>The result of the fused multiply-add operation.</returns> /// <returns>The result of the fused multiply-add operation.</returns>
/// <remarks> /// <remarks>
/// Float and double implementations should use fused multiply-add instructions when available for both accuracy and performance. /// Float and double implementations should use fused multiply-add instructions when available for both accuracy and performance.
/// </remarks> /// </remarks>
static abstract TSelf MultipleAdd(TSelf a, TSelf b, TSelf c); static abstract TSelf MultiplyAdd(TSelf left, TSelf right, TSelf addend);
/// <summary> /// <summary>
/// Returns the minimum of the two lane values element-wise. /// Returns the minimum of the two lane values element-wise.
/// </summary> /// </summary>

View File

@@ -35,7 +35,7 @@
<ItemGroup> <ItemGroup>
<Content Include="**\*.cs" Exclude="obj\**;bin\**"> <Content Include="**\*.cs" Exclude="obj\**;bin\**">
<Pack>true</Pack> <Pack>true</Pack>
<PackagePath>contentFiles\cs\any\Misaki.HighPerformance.Mathematics.SPMD\</PackagePath> <PackagePath>contentFiles\cs\any\Misaki.HighPerformance.HPC\</PackagePath>
<PackageCopyToOutput>false</PackageCopyToOutput> <PackageCopyToOutput>false</PackageCopyToOutput>
<BuildAction>Compile</BuildAction> <BuildAction>Compile</BuildAction>
</Content> </Content>

View File

@@ -2,7 +2,7 @@ using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics; using System.Runtime.Intrinsics;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
internal static unsafe class SPMDUtility internal static unsafe class SPMDUtility
{ {

View File

@@ -3,7 +3,7 @@ using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
[StructLayout(LayoutKind.Sequential)] [StructLayout(LayoutKind.Sequential)]
public readonly unsafe struct ScalarLane<TNumber> : ISPMDLane<ScalarLane<TNumber>, TNumber> public readonly unsafe struct ScalarLane<TNumber> : ISPMDLane<ScalarLane<TNumber>, TNumber>
@@ -446,7 +446,7 @@ public readonly unsafe struct ScalarLane<TNumber> : ISPMDLane<ScalarLane<TNumber
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static ScalarLane<TNumber> MultipleAdd(ScalarLane<TNumber> a, ScalarLane<TNumber> b, ScalarLane<TNumber> c) public static ScalarLane<TNumber> MultiplyAdd(ScalarLane<TNumber> a, ScalarLane<TNumber> b, ScalarLane<TNumber> c)
{ {
return new ScalarLane<TNumber>(TNumber.MultiplyAddEstimate(a.value, b.value, c.value)); return new ScalarLane<TNumber>(TNumber.MultiplyAddEstimate(a.value, b.value, c.value));
} }

View File

@@ -1,6 +1,6 @@
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
internal static unsafe class ShuffleTableGenerator internal static unsafe class ShuffleTableGenerator
{ {

View File

@@ -1,7 +1,7 @@
using Misaki.HighPerformance.Jobs; using Misaki.HighPerformance.Jobs;
using System.Numerics; using System.Numerics;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
/// <summary> /// <summary>
/// A job interface for Single Program Multiple Data (SPMD) execution, allowing for efficient parallel processing of data across multiple lanes. /// A job interface for Single Program Multiple Data (SPMD) execution, allowing for efficient parallel processing of data across multiple lanes.

View File

@@ -7,7 +7,7 @@
using Misaki.HighPerformance.Jobs; using Misaki.HighPerformance.Jobs;
using System.Numerics; using System.Numerics;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
<# <#
const string TLane = "TLane"; const string TLane = "TLane";

View File

@@ -7,10 +7,9 @@ using System.Diagnostics.CodeAnalysis;
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics; using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.Arm;
using System.Runtime.Intrinsics.X86; using System.Runtime.Intrinsics.X86;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
public static unsafe partial class MathV public static unsafe partial class MathV
{ {

View File

@@ -15,7 +15,7 @@ using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics; using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86; using System.Runtime.Intrinsics.X86;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
<# <#
const string TLane = "TLane"; const string TLane = "TLane";
const string TNumber = "TNumber"; const string TNumber = "TNumber";

View File

@@ -3,7 +3,7 @@ using System.Diagnostics;
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
public unsafe struct Vector2<TLane, TNumber> : IEquatable<Vector2<TLane, TNumber>> public unsafe struct Vector2<TLane, TNumber> : IEquatable<Vector2<TLane, TNumber>>
where TLane : ISPMDLane<TLane, TNumber> where TLane : ISPMDLane<TLane, TNumber>

View File

@@ -9,5 +9,5 @@ using System.Diagnostics;
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
<#= code #> <#= code #>

View File

@@ -3,7 +3,7 @@ using System.Diagnostics;
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
public unsafe struct Vector3<TLane, TNumber> : IEquatable<Vector3<TLane, TNumber>> public unsafe struct Vector3<TLane, TNumber> : IEquatable<Vector3<TLane, TNumber>>
where TLane : ISPMDLane<TLane, TNumber> where TLane : ISPMDLane<TLane, TNumber>

View File

@@ -9,5 +9,5 @@ using System.Diagnostics;
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
<#= code #> <#= code #>

View File

@@ -3,7 +3,7 @@ using System.Diagnostics;
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
public unsafe struct Vector4<TLane, TNumber> : IEquatable<Vector4<TLane, TNumber>> public unsafe struct Vector4<TLane, TNumber> : IEquatable<Vector4<TLane, TNumber>>
where TLane : ISPMDLane<TLane, TNumber> where TLane : ISPMDLane<TLane, TNumber>

View File

@@ -9,5 +9,5 @@ using System.Diagnostics;
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
<#= code #> <#= code #>

View File

@@ -1,7 +1,7 @@
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNumber>, TNumber> public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNumber>, TNumber>
where TNumber : unmanaged, INumber<TNumber>, IBinaryNumber<TNumber>, IMinMaxValue<TNumber>, IBitwiseOperators<TNumber, TNumber, TNumber> where TNumber : unmanaged, INumber<TNumber>, IBinaryNumber<TNumber>, IMinMaxValue<TNumber>, IBitwiseOperators<TNumber, TNumber, TNumber>

View File

@@ -7,7 +7,7 @@
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
<# <#
var conversions = new CastRoute[] var conversions = new CastRoute[]
{ {

View File

@@ -5,7 +5,7 @@ using System.Runtime.InteropServices;
using System.Runtime.Intrinsics; using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86; using System.Runtime.Intrinsics.X86;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
public static unsafe class WideLane public static unsafe class WideLane
{ {
@@ -40,8 +40,6 @@ public static unsafe class WideLane
public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNumber>, TNumber> public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNumber>, TNumber>
where TNumber : unmanaged, INumber<TNumber>, IBinaryNumber<TNumber>, IMinMaxValue<TNumber>, IBitwiseOperators<TNumber, TNumber, TNumber> where TNumber : unmanaged, INumber<TNumber>, IBinaryNumber<TNumber>, IMinMaxValue<TNumber>, IBitwiseOperators<TNumber, TNumber, TNumber>
{ {
private static readonly Vector<TNumber> s_indices;
public readonly Vector<TNumber> value; public readonly Vector<TNumber> value;
public static int LaneWidth public static int LaneWidth
@@ -53,13 +51,13 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
public static WideLane<TNumber> Zero public static WideLane<TNumber> Zero
{ {
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
get => new WideLane<TNumber>(Vector<TNumber>.Zero); get => Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector<TNumber>.Zero);
} }
public static WideLane<TNumber> One public static WideLane<TNumber> One
{ {
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
get => new WideLane<TNumber>(Vector<TNumber>.One); get => Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector<TNumber>.One);
} }
public static WideLane<TNumber> MinValue public static WideLane<TNumber> MinValue
@@ -86,17 +84,6 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
get => value[index]; get => value[index];
} }
static WideLane()
{
var pValues = stackalloc TNumber[LaneWidth];
for (var i = 0; i < LaneWidth; i++)
{
pValues[i] = TNumber.CreateTruncating(i);
}
s_indices = Vector.Load(pValues);
}
public WideLane(Vector<TNumber> value) public WideLane(Vector<TNumber> value)
{ {
this.value = value; this.value = value;
@@ -145,19 +132,19 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Create(TNumber value) public static WideLane<TNumber> Create(TNumber value)
{ {
return new WideLane<TNumber>(Vector.Create(value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.Create(value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Create(params ReadOnlySpan<TNumber> values) public static WideLane<TNumber> Create(params ReadOnlySpan<TNumber> values)
{ {
return new WideLane<TNumber>(Vector.Create(values)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.Create(values));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Create(Vector<TNumber> value) public static WideLane<TNumber> Create(Vector<TNumber> value)
{ {
return new WideLane<TNumber>(value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -185,20 +172,20 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
} }
else else
{ {
return new WideLane<TNumber>(Vector.Create(start) + (Vector.Create(step) * s_indices)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.Create(start) + (Vector.Create(step) * Vector<TNumber>.Indices));
} }
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Load(ref TNumber value) public static WideLane<TNumber> Load(ref TNumber value)
{ {
return new WideLane<TNumber>(Vector.LoadUnsafe(ref value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.LoadUnsafe(ref value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Load(TNumber* pValue) public static WideLane<TNumber> Load(TNumber* pValue)
{ {
return new WideLane<TNumber>(Vector.Load(pValue)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.Load(pValue));
} }
@@ -302,7 +289,7 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
pResult[i] = *(TNumber*)((byte*)pData + (idx * scale)); pResult[i] = *(TNumber*)((byte*)pData + (idx * scale));
} }
return new WideLane<TNumber>(result); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(result);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -352,7 +339,7 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
pResult[i] = *(TNumber*)((byte*)pData + (pIndices[i] * scale)); pResult[i] = *(TNumber*)((byte*)pData + (pIndices[i] * scale));
} }
return new WideLane<TNumber>(result); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(result);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -419,7 +406,7 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
pResult[i] = *(TNumber*)((byte*)pData + (idx * scale)); pResult[i] = *(TNumber*)((byte*)pData + (idx * scale));
} }
return new WideLane<TNumber>(result); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(result);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -473,7 +460,7 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
pResult[i] = *(TNumber*)((byte*)pData + (pIndices[i] * scale)); pResult[i] = *(TNumber*)((byte*)pData + (pIndices[i] * scale));
} }
return new WideLane<TNumber>(result); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(result);
} }
@@ -777,61 +764,61 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator +(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> operator +(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(a.value + b.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(a.value + b.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator -(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> operator -(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(a.value - b.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(a.value - b.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator *(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> operator *(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(a.value * b.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(a.value * b.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator /(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> operator /(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(a.value / b.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(a.value / b.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator %(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> operator %(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(a.value - VectorFloor(a.value / b.value) * b.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(a.value - VectorFloor(a.value / b.value) * b.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator -(WideLane<TNumber> a) public static WideLane<TNumber> operator -(WideLane<TNumber> a)
{ {
return new WideLane<TNumber>(-a.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(-a.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator &(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> operator &(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(a.value & b.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(a.value & b.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator |(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> operator |(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(a.value | b.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(a.value | b.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator ^(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> operator ^(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(a.value ^ b.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(a.value ^ b.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator ~(WideLane<TNumber> a) public static WideLane<TNumber> operator ~(WideLane<TNumber> a)
{ {
return new WideLane<TNumber>(~a.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(~a.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -881,7 +868,7 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Abs(WideLane<TNumber> value) public static WideLane<TNumber> Abs(WideLane<TNumber> value)
{ {
return new WideLane<TNumber>(Vector.Abs(value.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.Abs(value.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -891,13 +878,13 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value);
var floored = Vector.Floor(v); var floored = Vector.Floor(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(floored)); return Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(floored);
} }
else if (typeof(TNumber) == typeof(double)) else if (typeof(TNumber) == typeof(double))
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value);
var floored = Vector.Floor(v); var floored = Vector.Floor(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(floored)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(floored));
} }
return value; return value;
@@ -906,60 +893,60 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Frac(WideLane<TNumber> value) public static WideLane<TNumber> Frac(WideLane<TNumber> value)
{ {
return new WideLane<TNumber>(value.value - VectorFloor(value.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(value.value - VectorFloor(value.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Sqrt(WideLane<TNumber> value) public static WideLane<TNumber> Sqrt(WideLane<TNumber> value)
{ {
return new WideLane<TNumber>(Vector.SquareRoot(value.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.SquareRoot(value.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Lerp(WideLane<TNumber> a, WideLane<TNumber> b, WideLane<TNumber> t) public static WideLane<TNumber> Lerp(WideLane<TNumber> a, WideLane<TNumber> b, WideLane<TNumber> t)
{ {
return new WideLane<TNumber>(a.value + (b.value - a.value) * t.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(a.value + (b.value - a.value) * t.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> MultipleAdd(WideLane<TNumber> a, WideLane<TNumber> b, WideLane<TNumber> c) public static WideLane<TNumber> MultiplyAdd(WideLane<TNumber> left, WideLane<TNumber> right, WideLane<TNumber> addend)
{ {
if (typeof(TNumber) == typeof(float)) if (typeof(TNumber) == typeof(float))
{ {
var va = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(a); var va = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(left);
var vb = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(b); var vb = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(right);
var vc = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(c); var vc = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(addend);
var result = Vector.FusedMultiplyAdd(va, vb, vc); var result = Vector.FusedMultiplyAdd(va, vb, vc);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(result);
} }
else if (typeof(TNumber) == typeof(double)) else if (typeof(TNumber) == typeof(double))
{ {
var va = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(a); var va = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(left);
var vb = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(b); var vb = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(right);
var vc = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(c); var vc = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(addend);
var result = Vector.FusedMultiplyAdd(va, vb, vc); var result = Vector.FusedMultiplyAdd(va, vb, vc);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<double>, WideLane<TNumber>>(result);
} }
return new WideLane<TNumber>((a.value * b.value) + c.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>((left.value * right.value) + addend.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Min(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> Min(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(Vector.Min(a.value, b.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.Min(a.value, b.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Max(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> Max(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(Vector.Max(a.value, b.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.Max(a.value, b.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Clamp(WideLane<TNumber> value, WideLane<TNumber> min, WideLane<TNumber> max) public static WideLane<TNumber> Clamp(WideLane<TNumber> value, WideLane<TNumber> min, WideLane<TNumber> max)
{ {
return new WideLane<TNumber>(Vector.Clamp(value.value, min.value, max.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.Clamp(value.value, min.value, max.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -992,10 +979,10 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
var c9 = Create(TNumber.CreateTruncating(0.08214589f)); // PI^9 / 362880 var c9 = Create(TNumber.CreateTruncating(0.08214589f)); // PI^9 / 362880
var z2_sin = z_sin * z_sin; var z2_sin = z_sin * z_sin;
var poly_sin = MultipleAdd(z2_sin, c9, c7); // c7 + c9*z^2 var poly_sin = MultiplyAdd(z2_sin, c9, c7); // c7 + c9*z^2
poly_sin = MultipleAdd(z2_sin, poly_sin, c5); // c5 + z^2*(...) poly_sin = MultiplyAdd(z2_sin, poly_sin, c5); // c5 + z^2*(...)
poly_sin = MultipleAdd(z2_sin, poly_sin, c3); // c3 + z^2*(...) poly_sin = MultiplyAdd(z2_sin, poly_sin, c3); // c3 + z^2*(...)
poly_sin = MultipleAdd(z2_sin, poly_sin, c1); // c1 + z^2*(...) poly_sin = MultiplyAdd(z2_sin, poly_sin, c1); // c1 + z^2*(...)
poly_sin = z_sin * poly_sin; // z * (...) poly_sin = z_sin * poly_sin; // z * (...)
return poly_sin * sign_sin; return poly_sin * sign_sin;
@@ -1004,13 +991,13 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value);
var result = Vector.Sin(v); var result = Vector.Sin(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(result);
} }
else if (typeof(TNumber) == typeof(double)) else if (typeof(TNumber) == typeof(double))
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value);
var result = Vector.Sin(v); var result = Vector.Sin(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<double>, WideLane<TNumber>>(result);
} }
return value; return value;
@@ -1042,10 +1029,10 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
var c9 = Create(TNumber.CreateTruncating(0.08214589f)); // PI^9 / 362880 var c9 = Create(TNumber.CreateTruncating(0.08214589f)); // PI^9 / 362880
var z2_cos = z_cos * z_cos; var z2_cos = z_cos * z_cos;
var poly_cos = MultipleAdd(z2_cos, c9, c7); var poly_cos = MultiplyAdd(z2_cos, c9, c7);
poly_cos = MultipleAdd(z2_cos, poly_cos, c5); poly_cos = MultiplyAdd(z2_cos, poly_cos, c5);
poly_cos = MultipleAdd(z2_cos, poly_cos, c3); poly_cos = MultiplyAdd(z2_cos, poly_cos, c3);
poly_cos = MultipleAdd(z2_cos, poly_cos, c1); poly_cos = MultiplyAdd(z2_cos, poly_cos, c1);
poly_cos = z_cos * poly_cos; poly_cos = z_cos * poly_cos;
return poly_cos * sign_cos; return poly_cos * sign_cos;
@@ -1054,13 +1041,13 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value);
var result = Vector.Cos(v); var result = Vector.Cos(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(result);
} }
else if (typeof(TNumber) == typeof(double)) else if (typeof(TNumber) == typeof(double))
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value);
var result = Vector.Cos(v); var result = Vector.Cos(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<double>, WideLane<TNumber>>(result);
} }
return value; return value;
@@ -1117,17 +1104,17 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
var c9 = Create(TNumber.CreateTruncating(0.08214589f)); // PI^9 / 362880 var c9 = Create(TNumber.CreateTruncating(0.08214589f)); // PI^9 / 362880
var z2_sin = z_sin * z_sin; var z2_sin = z_sin * z_sin;
var poly_sin = MultipleAdd(z2_sin, c9, c7); // c7 + c9*z^2 var poly_sin = MultiplyAdd(z2_sin, c9, c7); // c7 + c9*z^2
poly_sin = MultipleAdd(z2_sin, poly_sin, c5); // c5 + z^2*(...) poly_sin = MultiplyAdd(z2_sin, poly_sin, c5); // c5 + z^2*(...)
poly_sin = MultipleAdd(z2_sin, poly_sin, c3); // c3 + z^2*(...) poly_sin = MultiplyAdd(z2_sin, poly_sin, c3); // c3 + z^2*(...)
poly_sin = MultipleAdd(z2_sin, poly_sin, c1); // c1 + z^2*(...) poly_sin = MultiplyAdd(z2_sin, poly_sin, c1); // c1 + z^2*(...)
poly_sin = z_sin * poly_sin; // z * (...) poly_sin = z_sin * poly_sin; // z * (...)
var z2_cos = z_cos * z_cos; var z2_cos = z_cos * z_cos;
var poly_cos = MultipleAdd(z2_cos, c9, c7); var poly_cos = MultiplyAdd(z2_cos, c9, c7);
poly_cos = MultipleAdd(z2_cos, poly_cos, c5); poly_cos = MultiplyAdd(z2_cos, poly_cos, c5);
poly_cos = MultipleAdd(z2_cos, poly_cos, c3); poly_cos = MultiplyAdd(z2_cos, poly_cos, c3);
poly_cos = MultipleAdd(z2_cos, poly_cos, c1); poly_cos = MultiplyAdd(z2_cos, poly_cos, c1);
poly_cos = z_cos * poly_cos; poly_cos = z_cos * poly_cos;
sin = poly_sin * sign_sin; sin = poly_sin * sign_sin;
@@ -1137,15 +1124,15 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value);
var (sinResult, cosResult) = Vector.SinCos(v); var (sinResult, cosResult) = Vector.SinCos(v);
sin = new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(sinResult)); sin = Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(sinResult);
cos = new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(cosResult)); cos = Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(cosResult);
} }
else if (typeof(TNumber) == typeof(double)) else if (typeof(TNumber) == typeof(double))
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value);
var (sinResult, cosResult) = Vector.SinCos(v); var (sinResult, cosResult) = Vector.SinCos(v);
sin = new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(sinResult)); sin = Unsafe.BitCast<Vector<double>, WideLane<TNumber>>(sinResult);
cos = new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(cosResult)); cos = Unsafe.BitCast<Vector<double>, WideLane<TNumber>>(cosResult);
} }
else else
{ {
@@ -1175,9 +1162,9 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
var vc2 = Create(TNumber.CreateTruncating(0.1333923995)); // 2/15 var vc2 = Create(TNumber.CreateTruncating(0.1333923995)); // 2/15
// x2 * (c1 + c2 * x2) // x2 * (c1 + c2 * x2)
var poly = MultipleAdd(x2, vc2, vc1); var poly = MultiplyAdd(x2, vc2, vc1);
// value * (1 + x2 * poly) // value * (1 + x2 * poly)
return MultipleAdd(x, MultipleAdd(x2, poly, One), Zero); return MultiplyAdd(x, MultiplyAdd(x2, poly, One), Zero);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -1202,9 +1189,9 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
var c2 = Create(TNumber.CreateTruncating(0.0742610f)); var c2 = Create(TNumber.CreateTruncating(0.0742610f));
var c3 = Create(TNumber.CreateTruncating(-0.0187293f)); var c3 = Create(TNumber.CreateTruncating(-0.0187293f));
var term1 = MultipleAdd(x, c3, c2); var term1 = MultiplyAdd(x, c3, c2);
var term2 = MultipleAdd(x, term1, c1); var term2 = MultiplyAdd(x, term1, c1);
var poly = MultipleAdd(x, term2, c0); var poly = MultiplyAdd(x, term2, c0);
var sqrtTerm = Sqrt(One - x); var sqrtTerm = Sqrt(One - x);
var result = poly * sqrtTerm; var result = poly * sqrtTerm;
@@ -1224,7 +1211,7 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
var c2 = Create(TNumber.CreateTruncating(-0.19194795f)); var c2 = Create(TNumber.CreateTruncating(-0.19194795f));
var x2 = value * value; var x2 = value * value;
var poly = MultipleAdd(x2, c2, c1); var poly = MultiplyAdd(x2, c2, c1);
return value * poly; return value * poly;
} }
@@ -1251,7 +1238,7 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
var c2 = Create(TNumber.CreateTruncating(-0.19194795f)); var c2 = Create(TNumber.CreateTruncating(-0.19194795f));
// (c1 + c2 * t2) // (c1 + c2 * t2)
var poly = MultipleAdd(c2, t2, c1); var poly = MultiplyAdd(c2, t2, c1);
// result = t * poly // result = t * poly
var result = t * poly; var result = t * poly;
@@ -1290,13 +1277,13 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value);
var result = Vector.Exp(v); var result = Vector.Exp(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(result);
} }
else if (typeof(TNumber) == typeof(double)) else if (typeof(TNumber) == typeof(double))
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value);
var result = Vector.Exp(v); var result = Vector.Exp(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<double>, WideLane<TNumber>>(result);
} }
return value; return value;
@@ -1315,13 +1302,13 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value);
var result = Vector.Log(v); var result = Vector.Log(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(result);
} }
else if (typeof(TNumber) == typeof(double)) else if (typeof(TNumber) == typeof(double))
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value);
var result = Vector.Log(v); var result = Vector.Log(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<double>, WideLane<TNumber>>(result);
} }
return value; return value;
@@ -1334,13 +1321,13 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value);
var result = Vector.Log2(v); var result = Vector.Log2(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(result);
} }
else if (typeof(TNumber) == typeof(double)) else if (typeof(TNumber) == typeof(double))
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value);
var result = Vector.Log2(v); var result = Vector.Log2(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<double>, WideLane<TNumber>>(result);
} }
return value; return value;
@@ -1353,13 +1340,13 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value);
var result = Vector.Ceiling(v); var result = Vector.Ceiling(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(result);
} }
else if (typeof(TNumber) == typeof(double)) else if (typeof(TNumber) == typeof(double))
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value);
var result = Vector.Ceiling(v); var result = Vector.Ceiling(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<double>, WideLane<TNumber>>(result);
} }
return value; return value;
@@ -1372,13 +1359,13 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value);
var result = Vector.Round(v); var result = Vector.Round(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(result);
} }
else if (typeof(TNumber) == typeof(double)) else if (typeof(TNumber) == typeof(double))
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value);
var result = Vector.Round(v); var result = Vector.Round(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<double>, WideLane<TNumber>>(result);
} }
return value; return value;
@@ -1391,13 +1378,13 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value);
var result = Vector.Truncate(v); var result = Vector.Truncate(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(result);
} }
else if (typeof(TNumber) == typeof(double)) else if (typeof(TNumber) == typeof(double))
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value);
var result = Vector.Truncate(v); var result = Vector.Truncate(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<double>, WideLane<TNumber>>(result);
} }
return value; return value;
@@ -1418,7 +1405,7 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> CopySign(WideLane<TNumber> magnitude, WideLane<TNumber> sign) public static WideLane<TNumber> CopySign(WideLane<TNumber> magnitude, WideLane<TNumber> sign)
{ {
return new WideLane<TNumber>(Vector.CopySign(magnitude.value, sign.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.CopySign(magnitude.value, sign.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -1538,40 +1525,46 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Select(WideLane<TNumber> conditionMask, WideLane<TNumber> ifTrue, WideLane<TNumber> ifFalse) public static WideLane<TNumber> Select(WideLane<TNumber> conditionMask, WideLane<TNumber> ifTrue, WideLane<TNumber> ifFalse)
{ {
return new WideLane<TNumber>(Vector.ConditionalSelect( return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.ConditionalSelect(
conditionMask.value, conditionMask.value,
ifTrue.value, ifTrue.value,
ifFalse.value)); ifFalse.value));
} }
/// <summary>
/// Overload of <c>Select</c> whose condition is supplied as a single <see cref="byte"/>
/// instead of a lane-wide mask. This overload is currently a stub.
/// </summary>
/// <param name="conditionMask">Condition value; unused because the method is not implemented.</param>
/// <param name="ifTrue">Lane values intended for a true condition; unused.</param>
/// <param name="ifFalse">Lane values intended for a false condition; unused.</param>
/// <returns>Never returns normally.</returns>
/// <exception cref="NotImplementedException">Always thrown.</exception>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Select(byte conditionMask, WideLane<TNumber> ifTrue, WideLane<TNumber> ifFalse)
    => throw new NotImplementedException();
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> GreaterThan(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> GreaterThan(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(Vector.GreaterThan(a.value, b.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.GreaterThan(a.value, b.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> GreaterThanOrEqual(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> GreaterThanOrEqual(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(Vector.GreaterThanOrEqual(a.value, b.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.GreaterThanOrEqual(a.value, b.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> LessThan(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> LessThan(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(Vector.LessThan(a.value, b.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.LessThan(a.value, b.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> LessThanOrEqual(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> LessThanOrEqual(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(Vector.LessThanOrEqual(a.value, b.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.LessThanOrEqual(a.value, b.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Equal(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> Equal(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(Vector.Equals(a.value, b.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.Equals(a.value, b.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -1617,4 +1610,10 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
{ {
return value.ToString(); return value.ToString();
} }
/// <summary>
/// Implicitly converts a <see cref="Vector{T}"/> into a <c>WideLane&lt;TNumber&gt;</c>
/// by reinterpreting the vector's bits with <c>Unsafe.BitCast</c>.
/// </summary>
/// <param name="v">The vector to reinterpret.</param>
/// <returns>The result of bit-casting <paramref name="v"/> to <c>WideLane&lt;TNumber&gt;</c>.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static implicit operator WideLane<TNumber>(Vector<TNumber> v)
    => Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(v);
} }

View File

@@ -405,20 +405,54 @@ public static unsafe partial class MemoryUtility
} }
var i = 0; var i = 0;
if (Vector.IsHardwareAccelerated && a.Length >= Vector<byte>.Count) if (Vector512.IsHardwareAccelerated && a.Length >= Vector512<byte>.Count)
{ {
ref var ptrA = ref MemoryMarshal.GetReference(a); ref var ptrA = ref MemoryMarshal.GetReference(a);
ref var ptrB = ref MemoryMarshal.GetReference(b); ref var ptrB = ref MemoryMarshal.GetReference(b);
var limit = a.Length - Vector<byte>.Count; var limit = a.Length - Vector512<byte>.Count;
for (; i <= limit; i += Vector<byte>.Count) for (; i <= limit; i += Vector512<byte>.Count)
{ {
var vecA = Vector.LoadUnsafe(ref ptrA, (nuint)i); var vecA = Vector512.LoadUnsafe(ref ptrA, (nuint)i);
var vecB = Vector.LoadUnsafe(ref ptrB, (nuint)i); var vecB = Vector512.LoadUnsafe(ref ptrB, (nuint)i);
var mask = Vector.Equals(vecA, Vector<byte>.Zero); var mask = Vector512.Equals(vecA, Vector512<byte>.Zero);
var result = Vector.ConditionalSelect(mask, vecB, vecA); var result = Vector512.ConditionalSelect(mask, vecB, vecA);
result.StoreUnsafe(ref ptrA, (nuint)i);
}
}
else if (Vector256.IsHardwareAccelerated && a.Length >= Vector256<byte>.Count)
{
ref var ptrA = ref MemoryMarshal.GetReference(a);
ref var ptrB = ref MemoryMarshal.GetReference(b);
var limit = a.Length - Vector256<byte>.Count;
for (; i <= limit; i += Vector256<byte>.Count)
{
var vecA = Vector256.LoadUnsafe(ref ptrA, (nuint)i);
var vecB = Vector256.LoadUnsafe(ref ptrB, (nuint)i);
var mask = Vector256.Equals(vecA, Vector256<byte>.Zero);
var result = Vector256.ConditionalSelect(mask, vecB, vecA);
result.StoreUnsafe(ref ptrA, (nuint)i);
}
}
else if (Vector128.IsHardwareAccelerated && a.Length >= Vector128<byte>.Count)
{
ref var ptrA = ref MemoryMarshal.GetReference(a);
ref var ptrB = ref MemoryMarshal.GetReference(b);
var limit = a.Length - Vector128<byte>.Count;
for (; i <= limit; i += Vector128<byte>.Count)
{
var vecA = Vector128.LoadUnsafe(ref ptrA, (nuint)i);
var vecB = Vector128.LoadUnsafe(ref ptrB, (nuint)i);
var mask = Vector128.Equals(vecA, Vector128<byte>.Zero);
var result = Vector128.ConditionalSelect(mask, vecB, vecA);
result.StoreUnsafe(ref ptrA, (nuint)i); result.StoreUnsafe(ref ptrA, (nuint)i);
} }
} }
@@ -440,18 +474,48 @@ public static unsafe partial class MemoryUtility
nuint i = 0u; nuint i = 0u;
if (Vector.IsHardwareAccelerated && length >= (nuint)Vector<byte>.Count) if (Vector512.IsHardwareAccelerated && length >= (nuint)Vector512<byte>.Count)
{ {
var vectorSize = (nuint)Vector<byte>.Count; var vectorSize = (nuint)Vector512<byte>.Count;
var limit = length - vectorSize; var limit = length - vectorSize;
for (; i <= limit; i += vectorSize) for (; i <= limit; i += vectorSize)
{ {
var vecA = Vector.Load(ptrA + i); var vecA = Vector512.Load(ptrA + i);
var vecB = Vector.Load(ptrB + i); var vecB = Vector512.Load(ptrB + i);
var mask = Vector.Equals(vecA, Vector<byte>.Zero); var mask = Vector512.Equals(vecA, Vector512<byte>.Zero);
var result = Vector.ConditionalSelect(mask, vecB, vecA); var result = Vector512.ConditionalSelect(mask, vecB, vecA);
result.Store(ptrA + i);
}
}
else if (Vector256.IsHardwareAccelerated && length >= (nuint)Vector256<byte>.Count)
{
var vectorSize = (nuint)Vector256<byte>.Count;
var limit = length - vectorSize;
for (; i <= limit; i += vectorSize)
{
var vecA = Vector256.Load(ptrA + i);
var vecB = Vector256.Load(ptrB + i);
var mask = Vector256.Equals(vecA, Vector256<byte>.Zero);
var result = Vector256.ConditionalSelect(mask, vecB, vecA);
result.Store(ptrA + i);
}
}
else if (Vector128.IsHardwareAccelerated && length >= (nuint)Vector128<byte>.Count)
{
var vectorSize = (nuint)Vector128<byte>.Count;
var limit = length - vectorSize;
for (; i <= limit; i += vectorSize)
{
var vecA = Vector128.Load(ptrA + i);
var vecB = Vector128.Load(ptrB + i);
var mask = Vector128.Equals(vecA, Vector128<byte>.Zero);
var result = Vector128.ConditionalSelect(mask, vecB, vecA);
result.Store(ptrA + i); result.Store(ptrA + i);
} }
} }

View File

@@ -3,17 +3,17 @@
<PropertyGroup> <PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework> <TargetFramework>netstandard2.0</TargetFramework>
<Nullable>enable</Nullable> <Nullable>enable</Nullable>
<EnforceExtendedAnalyzerRules>true</EnforceExtendedAnalyzerRules> <EnforceExtendedAnalyzerRules>True</EnforceExtendedAnalyzerRules>
<AllowUnsafeBlocks>True</AllowUnsafeBlocks> <AllowUnsafeBlocks>True</AllowUnsafeBlocks>
<LangVersion>9.0</LangVersion> <LangVersion>9.0</LangVersion>
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="4.14.0"> <PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="5.3.0">
<PrivateAssets>all</PrivateAssets> <PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference> </PackageReference>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="4.14.0" /> <PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="5.3.0" />
</ItemGroup> </ItemGroup>
</Project> </Project>

View File

@@ -1,9 +1,8 @@
using BenchmarkDotNet.Attributes; using BenchmarkDotNet.Attributes;
using BenchmarkDotNet.Engines; using BenchmarkDotNet.Engines;
using Misaki.HighPerformance.HPC;
using Misaki.HighPerformance.Image; using Misaki.HighPerformance.Image;
using Misaki.HighPerformance.Jobs; using Misaki.HighPerformance.Jobs;
using Misaki.HighPerformance.Mathematics;
using Misaki.HighPerformance.Mathematics.SPMD;
using SkiaSharp; using SkiaSharp;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
@@ -23,6 +22,7 @@ internal unsafe struct MipLevel
public float roughness; public float roughness;
} }
[HPCompute(TargetInstructionSet.AVX2)]
internal unsafe struct GGXMipGenerationJobSPMD : IJobSPMD<float, int> internal unsafe struct GGXMipGenerationJobSPMD : IJobSPMD<float, int>
{ {
public ImageResultFloat image; public ImageResultFloat image;
@@ -47,7 +47,7 @@ internal unsafe struct GGXMipGenerationJobSPMD : IJobSPMD<float, int>
var phi = 2.0f * PI * Xi.x; var phi = 2.0f * PI * Xi.x;
var cosTheta = TFloat.Sqrt((1.0f - Xi.y) / TFloat.MultipleAdd(a * a - 1.0f, Xi.y, 1.0f)); var cosTheta = TFloat.Sqrt((1.0f - Xi.y) / TFloat.MultiplyAdd(a * a - 1.0f, Xi.y, 1.0f));
var sinTheta = TFloat.Sqrt(1.0f - cosTheta * cosTheta); var sinTheta = TFloat.Sqrt(1.0f - cosTheta * cosTheta);
// Spherical to Cartesian coordinates (Halfway vector) // Spherical to Cartesian coordinates (Halfway vector)
@@ -198,7 +198,7 @@ internal unsafe struct GGXMipGenerationJobSPMD<TFloat, TInt> : IJobParallelFor
var phi = 2.0f * PI * Xi.x; var phi = 2.0f * PI * Xi.x;
var cosTheta = TFloat.Sqrt((1.0f - Xi.y) / TFloat.MultipleAdd(a * a - 1.0f, Xi.y, 1.0f)); var cosTheta = TFloat.Sqrt((1.0f - Xi.y) / TFloat.MultiplyAdd(a * a - 1.0f, Xi.y, 1.0f));
var sinTheta = TFloat.Sqrt(1.0f - cosTheta * cosTheta); var sinTheta = TFloat.Sqrt(1.0f - cosTheta * cosTheta);
// Spherical to Cartesian coordinates (Halfway vector) // Spherical to Cartesian coordinates (Halfway vector)

View File

@@ -1,6 +1,6 @@
using BenchmarkDotNet.Attributes; using BenchmarkDotNet.Attributes;
using Misaki.HighPerformance.Jobs; using Misaki.HighPerformance.Jobs;
using Misaki.HighPerformance.Mathematics.SPMD; using Misaki.HighPerformance.HPC;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
namespace Misaki.HighPerformance.Test.Benchmark; namespace Misaki.HighPerformance.Test.Benchmark;

View File

@@ -1,9 +1,11 @@
using Misaki.HighPerformance.HPC;
using Misaki.HighPerformance.Mathematics; using Misaki.HighPerformance.Mathematics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using static Misaki.HighPerformance.Mathematics.math; using static Misaki.HighPerformance.Mathematics.math;
namespace Misaki.HighPerformance.Test.Jobs; namespace Misaki.HighPerformance.Test.Jobs;
[HPCompute(TargetInstructionSet.AVX2)]
public static partial class noise public static partial class noise
{ {
// Modulo 289 without a division (only multiplications) // Modulo 289 without a division (only multiplications)

View File

@@ -1,6 +1,6 @@
using Misaki.HighPerformance.Jobs; using Misaki.HighPerformance.Jobs;
using Misaki.HighPerformance.Mathematics; using Misaki.HighPerformance.Mathematics;
using Misaki.HighPerformance.Mathematics.SPMD; using Misaki.HighPerformance.HPC;
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics; using System.Runtime.Intrinsics;
@@ -77,13 +77,14 @@ internal unsafe struct NoiseJobVector : IJobParallel
} }
} }
internal unsafe struct NoiseJobMath : IJobParallel internal unsafe partial struct NoiseJobMath : IJobParallel
{ {
public float* buffers; public float* buffers;
public int width; public int width;
public int height; public int height;
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
[HPCompute(TargetInstructionSet.AVX2)]
private static float2 GradientNoiseDirect(float2 uv) private static float2 GradientNoiseDirect(float2 uv)
{ {
uv = noise.mod289(uv); uv = noise.mod289(uv);

View File

@@ -31,7 +31,8 @@
<ProjectReference Include="..\Misaki.HighPerformance.Image\Misaki.HighPerformance.Image.csproj" /> <ProjectReference Include="..\Misaki.HighPerformance.Image\Misaki.HighPerformance.Image.csproj" />
<ProjectReference Include="..\Misaki.HighPerformance.Jobs\Misaki.HighPerformance.Jobs.csproj" /> <ProjectReference Include="..\Misaki.HighPerformance.Jobs\Misaki.HighPerformance.Jobs.csproj" />
<ProjectReference Include="..\Misaki.HighPerformance.LowLevel\Misaki.HighPerformance.LowLevel.csproj" /> <ProjectReference Include="..\Misaki.HighPerformance.LowLevel\Misaki.HighPerformance.LowLevel.csproj" />
<ProjectReference Include="..\Misaki.HighPerformance.Mathematics.SPMD\Misaki.HighPerformance.Mathematics.SPMD.csproj" /> <ProjectReference Include="..\Misaki.HighPerformance.HPC.Generator\Misaki.HighPerformance.HPC.Generator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
<ProjectReference Include="..\Misaki.HighPerformance.HPC\Misaki.HighPerformance.HPC.csproj" />
<ProjectReference Include="..\Misaki.HighPerformance.Mathematics\Misaki.HighPerformance.Mathematics.csproj" /> <ProjectReference Include="..\Misaki.HighPerformance.Mathematics\Misaki.HighPerformance.Mathematics.csproj" />
<ProjectReference Include="..\Misaki.HighPerformance\Misaki.HighPerformance.csproj" /> <ProjectReference Include="..\Misaki.HighPerformance\Misaki.HighPerformance.csproj" />
<ProjectReference Include="..\Misaki.HighPerformance.Analyzer\Misaki.HighPerformance.Analyzer\Misaki.HighPerformance.Analyzer.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" /> <ProjectReference Include="..\Misaki.HighPerformance.Analyzer\Misaki.HighPerformance.Analyzer\Misaki.HighPerformance.Analyzer.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />

View File

@@ -1,56 +1,59 @@
using Misaki.HighPerformance.HPC;
using Misaki.HighPerformance.LowLevel.Buffer; using Misaki.HighPerformance.LowLevel.Buffer;
using Misaki.HighPerformance.LowLevel.Collections; using Misaki.HighPerformance.LowLevel.Collections;
using Misaki.HighPerformance.Test.Benchmark; using Misaki.HighPerformance.Test.Benchmark;
using Misaki.HighPerformance.Test.UnitTest; using Misaki.HighPerformance.Test.UnitTest;
using Misaki.HighPerformance.Test.UnitTest.Jobs; using Misaki.HighPerformance.Test.UnitTest.Jobs;
using System.Buffers; using System.Buffers;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
//BenchmarkRunner.Run<GGXMipGenerationBenchmark>(); //BenchmarkRunner.Run<GGXMipGenerationBenchmark>();
//const int count = 16; const int count = 16;
//var bench = new GGXMipGenerationBenchmark(); var bench = new GGXMipGenerationBenchmark();
//bench.Setup(); bench.Setup();
//for (var i = 0; i < count; i++) for (var i = 0; i < count; i++)
//{ {
// bench.JobGGX(); bench.JobGGX();
//} }
//var sw = System.Diagnostics.Stopwatch.StartNew(); var sw = System.Diagnostics.Stopwatch.StartNew();
//for (var i = 0; i < count; i++) for (var i = 0; i < count; i++)
//{ {
// bench.JobGGX(); bench.JobGGX();
//} }
//sw.Stop(); sw.Stop();
//var avgTime = sw.Elapsed.TotalMilliseconds / count; var avgTime = sw.Elapsed.TotalMilliseconds / count;
//Console.WriteLine($"GGX Mip Generation (Inline): {avgTime} ms"); Console.WriteLine($"GGX Mip Generation (Inline): {avgTime} ms");
//bench.Cleanup(); bench.Cleanup();
//GlobalSetup.GlobalInitialize(null!); //GlobalSetup.GlobalInitialize(null!);
//TestJobSystem.Initialize(null!); //TestJobSystem.Initialize(null!);
AllocationManager.Initialize(); //AllocationManager.Initialize();
Console.WriteLine(0); //Console.WriteLine(0);
for (var i = 0; i < 64; i++) //for (var i = 0; i < 64; i++)
{ //{
var size = Random.Shared.Next(2048, 8192 * 2); // var size = Random.Shared.Next(2048, 8192 * 2);
var arr = new UnsafeArray<Guid>(size, AllocationHandle.TLSF); // AllocationHandle.FreeList // var arr = new UnsafeArray<Guid>(size, AllocationHandle.TLSF); // AllocationHandle.FreeList
arr.Dispose(); // arr.Dispose();
} //}
Thread.Sleep(1000); //Thread.Sleep(1000);
Console.WriteLine(1); //Console.WriteLine(1);
for (var i = 0; i < 64; i++) //for (var i = 0; i < 64; i++)
{ //{
var size = Random.Shared.Next(2048, 8192 * 2); // var size = Random.Shared.Next(2048, 8192 * 2);
var arr = new UnsafeArray<Guid>(size, AllocationHandle.TLSF); // AllocationHandle.FreeList // var arr = new UnsafeArray<Guid>(size, AllocationHandle.TLSF); // AllocationHandle.FreeList
arr.Dispose(); // arr.Dispose();
} //}
AllocationManager.Dispose(); //AllocationManager.Dispose();
Console.Read(); //Console.Read();

View File

@@ -1,5 +1,6 @@
using Misaki.HighPerformance.Mathematics.SPMD; using Misaki.HighPerformance.HPC;
using System.Numerics; using System.Numerics;
using System.Runtime.Intrinsics.X86;
namespace Misaki.HighPerformance.Test.UnitTest.Jobs; namespace Misaki.HighPerformance.Test.UnitTest.Jobs;

View File

@@ -1,6 +1,6 @@
using Misaki.HighPerformance.Jobs; using Misaki.HighPerformance.Jobs;
using Misaki.HighPerformance.Mathematics; using Misaki.HighPerformance.Mathematics;
using Misaki.HighPerformance.Mathematics.SPMD; using Misaki.HighPerformance.HPC;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
namespace Misaki.HighPerformance.Test.UnitTest.Jobs; namespace Misaki.HighPerformance.Test.UnitTest.Jobs;
@@ -116,8 +116,14 @@ internal struct DistanceJob : IJobSPMD<float>
} }
[TestClass] [TestClass]
public class SPMDTest public partial class SPMDTest
{ {
[HPCompute(TargetInstructionSet.AVX2)]
private static void Test_SPMD(float a, out float sin, out float cos)
{
math.sincos(a, out sin, out cos);
}
[TestMethod] [TestMethod]
public unsafe void TestSPMDVectorDot() public unsafe void TestSPMDVectorDot()
{ {

View File

@@ -2,7 +2,7 @@ using Misaki.HighPerformance.Jobs;
using Misaki.HighPerformance.LowLevel.Buffer; using Misaki.HighPerformance.LowLevel.Buffer;
using Misaki.HighPerformance.LowLevel.Collections; using Misaki.HighPerformance.LowLevel.Collections;
using Misaki.HighPerformance.LowLevel.Utilities; using Misaki.HighPerformance.LowLevel.Utilities;
using Misaki.HighPerformance.Mathematics.SPMD; using Misaki.HighPerformance.HPC;
using Misaki.HighPerformance.Test.Jobs; using Misaki.HighPerformance.Test.Jobs;
namespace Misaki.HighPerformance.Test.UnitTest.Jobs; namespace Misaki.HighPerformance.Test.UnitTest.Jobs;

View File

@@ -4,11 +4,16 @@
<Project Path="Misaki.HighPerformance.Analyzer/Misaki.HighPerformance.Analyzer.Package/Misaki.HighPerformance.Analyzer.Package.csproj" /> <Project Path="Misaki.HighPerformance.Analyzer/Misaki.HighPerformance.Analyzer.Package/Misaki.HighPerformance.Analyzer.Package.csproj" />
<Project Path="Misaki.HighPerformance.Analyzer/Misaki.HighPerformance.Analyzer/Misaki.HighPerformance.Analyzer.csproj" /> <Project Path="Misaki.HighPerformance.Analyzer/Misaki.HighPerformance.Analyzer/Misaki.HighPerformance.Analyzer.csproj" />
</Folder> </Folder>
<Project Path="Misaki.HighPerformance.HPC.Generator/Misaki.HighPerformance.HPC.Generator.csproj" Id="2b8a9c0d-ce6d-4064-8bcb-517001f631d3">
<Build Solution="Release|*" Project="false" />
</Project>
<Project Path="Misaki.HighPerformance.HPC/Misaki.HighPerformance.HPC.csproj">
<Build Solution="Release|*" Project="false" />
</Project>
<Project Path="Misaki.HighPerformance.Image/Misaki.HighPerformance.Image.csproj" /> <Project Path="Misaki.HighPerformance.Image/Misaki.HighPerformance.Image.csproj" />
<Project Path="Misaki.HighPerformance.Jobs/Misaki.HighPerformance.Jobs.csproj" /> <Project Path="Misaki.HighPerformance.Jobs/Misaki.HighPerformance.Jobs.csproj" />
<Project Path="Misaki.HighPerformance.LowLevel/Misaki.HighPerformance.LowLevel.csproj" /> <Project Path="Misaki.HighPerformance.LowLevel/Misaki.HighPerformance.LowLevel.csproj" />
<Project Path="Misaki.HighPerformance.Mathematics.CodeGen/Misaki.HighPerformance.Mathematics.CodeGen.csproj" /> <Project Path="Misaki.HighPerformance.Mathematics.CodeGen/Misaki.HighPerformance.Mathematics.CodeGen.csproj" />
<Project Path="Misaki.HighPerformance.Mathematics.SPMD/Misaki.HighPerformance.Mathematics.SPMD.csproj" />
<Project Path="Misaki.HighPerformance.Mathematics/Misaki.HighPerformance.Mathematics.csproj" /> <Project Path="Misaki.HighPerformance.Mathematics/Misaki.HighPerformance.Mathematics.csproj" />
<Project Path="Misaki.HighPerformance.Test/Misaki.HighPerformance.Test.csproj" /> <Project Path="Misaki.HighPerformance.Test/Misaki.HighPerformance.Test.csproj" />
<Project Path="Misaki.HighPerformance/Misaki.HighPerformance.csproj" /> <Project Path="Misaki.HighPerformance/Misaki.HighPerformance.csproj" />