4 Commits

Author SHA1 Message Date
e98ae96dd6 Refactor: switch to IR-based HPC codegen pipeline
Major rewrite replacing Roslyn syntax rewriters with an intermediate representation (IR) architecture. Adds IR nodes, analyzer, type resolver, and node rewriter base for optimization passes (e.g., FMA fusion). Refactors AVX2 backend to emit from IR and updates generator pipeline for analysis, optimization, and emission. Removes legacy rewriter classes. Adds polyfills for modern C# features and updates tests and project settings for latest language version. This enables advanced optimizations, easier backend targeting, and future extensibility.
2026-05-07 00:21:08 +09:00
b9537d91da Refactor APIContext, add AVX2 sin/cos codegen, FMA rewrite
- Refactored namespaces from .VectorAPI to .APIContext for clarity.
- Enhanced Avx2APIContext/IVectorAPIContext to support void returns.
- Added GenerateSinCosUtilityMethods for AVX2, emitting vectorized Sin/Cos/SinCos for float/double.
- Introduced HPCOptimizerRewriter for advanced SPMD type handling.
- Refactored HPCRewriter to use SemanticModel, support FMA pattern rewriting, and delegate SPMD logic.
- Updated AVX2Rewriter for new base and improved math mapping.
- Made UtilityTemplate generic and type-safe for sin/cos.
- Updated NoiseJob3D/NoiseJobVector for [HPCompute] attribute and partial struct.
- Fixed solution file project ordering and inclusion.
2026-05-06 22:27:24 +09:00
fd2d60c8f1 Refactor vector API codegen and WideLane conversions
- Introduce IVectorAPIContext abstraction and supporting types for vectorized code generation
- Add Avx2APIContext and UtilityTemplate for AVX2-specific code emission
- Dynamically generate AVX2 sine methods in AVX2Rewriter
- Refactor WideLane<TNumber> to use Unsafe.BitCast for all Vector conversions
- Update all WideLane operators and math methods to use Unsafe.BitCast
- Change MultiplyAdd parameter names for clarity
- Remove static indices field in favor of Vector<TNumber>.Indices
- Add implicit conversion from Vector<TNumber> to WideLane<TNumber>
- Update tests and program files for compatibility
2026-05-06 19:20:15 +09:00
c8f78f9d02 Refactor SPMD to HPC; add SIMD source generators
Major namespace migration from SPMD to HPC across all code, templates, and projects. Introduced Misaki.HighPerformance.HPC.Generator with Roslyn-based source generators for SIMD code (e.g., AVX2), including attribute and method generators. Renamed MultipleAdd to MultiplyAdd in all lanes and updated usages. Added AVX2 utility methods via codegen. Updated tests, benchmarks, and project references to use the new framework. Improved SIMD memory utilities and modernized project files. Removed legacy SPMD project from the solution.
2026-05-06 13:43:58 +09:00
500 changed files with 5011 additions and 495698 deletions

View File

@@ -9,9 +9,6 @@ on:
jobs: jobs:
publish: publish:
runs-on: ubuntu-latest runs-on: ubuntu-latest
defaults:
run:
working-directory: ./src
steps: steps:
- name: Checkout repository - name: Checkout repository
@@ -23,8 +20,7 @@ jobs:
dotnet-version: 10.0.x dotnet-version: 10.0.x
- name: Restore dependencies - name: Restore dependencies
run: | run: dotnet restore
dotnet restore
- name: Run tests - name: Run tests
# Run all test projects in the repository. If any test fails, this step will exit non-zero # Run all test projects in the repository. If any test fails, this step will exit non-zero

2
.gitignore vendored
View File

@@ -12,8 +12,6 @@
.code-review-graph/ .code-review-graph/
.github/instructions/ .github/instructions/
docfx/
# User-specific files (MonoDevelop/Xamarin Studio) # User-specific files (MonoDevelop/Xamarin Studio)
*.userprefs *.userprefs

View File

@@ -0,0 +1,119 @@
using System.Collections.Generic;
namespace Misaki.HighPerformance.HPC.Generator.APIContext
{
internal class Avx2APIContext : IVectorAPIContext
{
private readonly List<string> _statements = new();
private int _varCount = 0;
private string? _lastAssignedVariable;
public string? LastAssignedVariable => _lastAssignedVariable;
public string GetVectorType()
{
return "Vector256";
}
public string GetVectorType<T>()
{
return $"Vector256<{VectorAPIContext.GetTypeName<T>()}>";
}
public Expression Call(string methodName, params string[] args)
{
return new Expression(this, $"{methodName}({string.Join(", ", args)})");
}
public Expression Assign(Expression expr, string? varName = null, bool isNew = true)
{
varName ??= $"v{_varCount++}";
var statement = isNew ? $"var {varName} = {expr.Code};" : $"{varName} = {expr.Code};";
_statements.Add(statement);
_lastAssignedVariable = varName;
expr.Clear();
return new Expression(this, varName);
}
public Code Return(Expression? expr)
{
if (expr != null)
{
var statement = $"return {expr.Code};";
_statements.Add(statement);
expr.Clear();
}
var fullCode = new Code(_statements);
Reset();
return fullCode;
}
public Expression Create(string value)
{
return new Expression(this, $"{GetVectorType()}.Create({value})");
}
public Expression Zero<T>()
{
return new Expression(this, $"{GetVectorType<T>()}.Zero");
}
public Expression One<T>()
{
return new Expression(this, $"{GetVectorType<T>()}.One");
}
public Expression Count<T>()
{
return new Expression(this, $"{GetVectorType<T>()}.Count");
}
public Expression Add(Expression a, Expression b)
{
return new Expression(this, $"Avx2.Add({a}, {b})");
}
public Expression Multiply(Expression a, Expression b)
{
return new Expression(this, $"Avx2.Multiply({a}, {b})");
}
public Expression Subtract(Expression a, Expression b)
{
return new Expression(this, $"Avx2.Subtract({a}, {b})");
}
public Expression Divide(Expression a, Expression b)
{
return new Expression(this, $"Avx2.Divide({a}, {b})");
}
public Expression MultiplyAdd(Expression left, Expression right, Expression addend)
{
return new Expression(this, $"Fma.MultiplyAdd({left}, {right}, {addend})");
}
public Expression Round(Expression value)
{
return new Expression(this, $"Avx2.RoundToNearestInteger({value})");
}
public Expression Abs(Expression value)
{
return new Expression(this, $"{GetVectorType()}.Abs({value})");
}
public void Reset()
{
_statements.Clear();
_varCount = 0;
}
}
}

View File

@@ -0,0 +1,179 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace Misaki.HighPerformance.HPC.Generator.APIContext
{
internal class Expression
{
public IVectorAPIContext API
{
get;
}
public string Code
{
get; private set;
}
public Expression(IVectorAPIContext api, string code)
{
API = api;
Code = code;
}
public Expression Assign(string? varName = null, bool isNew = true)
{
return API.Assign(this, varName, isNew);
}
public void Clear()
{
Code = string.Empty;
}
public override string ToString()
{
return Code;
}
public static Expression operator +(Expression a, Expression b)
{
return a.API.Add(a, b);
}
public static Expression operator -(Expression a, Expression b)
{
return a.API.Subtract(a, b);
}
public static Expression operator *(Expression a, Expression b)
{
return a.API.Multiply(a, b);
}
public static Expression operator /(Expression a, Expression b)
{
return a.API.Divide(a, b);
}
}
internal record Code
{
private readonly string[] _statements;
public Code(IEnumerable<string> statements)
{
_statements = statements.ToArray();
}
public string GetFullCode(string lineIndentation)
{
var sb = new StringBuilder();
foreach (var stmt in _statements)
{
sb.AppendLine(lineIndentation + stmt);
}
return sb.ToString();
}
}
internal record Method
{
public string Modifier
{
get;
}
public string ReturnType
{
get;
}
public string Name
{
get;
}
public string[] Parameters
{
get;
}
public Code Body
{
get;
}
public Method(string modifier, string returnType, string name, string[] parameters, Code body)
{
Modifier = modifier;
ReturnType = returnType;
Name = name;
Parameters = parameters;
Body = body;
}
public string GetFullCode(string lineIndentation)
{
var sb = new StringBuilder();
sb.AppendLine(lineIndentation + $"{Modifier} {ReturnType} {Name}({string.Join(", ", Parameters)})");
sb.AppendLine(lineIndentation + "{");
sb.Append(Body.GetFullCode(lineIndentation + " "));
sb.AppendLine(lineIndentation + "}");
return sb.ToString();
}
}
internal interface IVectorAPIContext
{
string? LastAssignedVariable
{
get;
}
string GetVectorType();
string GetVectorType<T>();
Expression Call(string methodName, params string[] args);
Expression Assign(Expression expr, string? varName = null, bool isNew = true);
Code Return(Expression? expr);
Expression Create(string value);
Expression Zero<T>();
Expression One<T>();
Expression Count<T>();
Expression Add(Expression a, Expression b);
Expression Subtract(Expression a, Expression b);
Expression Multiply(Expression a, Expression b);
Expression Divide(Expression a, Expression b);
Expression MultiplyAdd(Expression left, Expression right, Expression addend);
Expression Round(Expression value);
Expression Abs(Expression value);
void Reset();
}
internal static class VectorAPIContext
{
public static string GetTypeName<T>()
{
return typeof(T) switch
{
_ when typeof(T) == typeof(float) => "float",
_ when typeof(T) == typeof(double) => "double",
_ when typeof(T) == typeof(byte) => "byte",
_ when typeof(T) == typeof(short) => "short",
_ when typeof(T) == typeof(int) => "int",
_ when typeof(T) == typeof(uint) => "uint",
_ when typeof(T) == typeof(long) => "long",
_ when typeof(T) == typeof(ulong) => "ulong",
_ => throw new NotSupportedException($"Type {typeof(T)} is not supported in vector operations.")
};
}
}
}

View File

@@ -0,0 +1,43 @@
using Microsoft.CodeAnalysis;
using Misaki.HighPerformance.HPC.Generator.APIContext;
namespace Misaki.HighPerformance.HPC.Generator
{
/// <summary>
/// Generates the <c>AVX2Utility</c> static class containing polynomial
/// approximations for transcendental functions (Sin, Cos, SinCos, etc.)
/// that have no built-in AVX2 hardware intrinsic.
///
/// <para>These methods are called by the <c>AVX2Backend</c> emitter when it
/// encounters <see cref="IR.HPCUnaryKind.Sin"/>, <see cref="IR.HPCUnaryKind.Cos"/>,
/// and similar IR nodes.</para>
/// </summary>
[Generator]
public class AVX2UtilityGenerator : IIncrementalGenerator
{
public void Initialize(IncrementalGeneratorInitializationContext context)
{
context.RegisterPostInitializationOutput(static ctx =>
{
var api = new Avx2APIContext();
var sinCosMethods = UtilityTemplate.GenerateSinCosUtilityMethods(api, " ");
var source = @$"
using System;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
namespace Misaki.HighPerformance.HPC
{{
public static class AVX2Utility
{{
{sinCosMethods}
}}
}}";
ctx.AddSource("AVX2Utility.g.cs", source);
});
}
}
}

View File

@@ -0,0 +1,461 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Misaki.HighPerformance.HPC.Generator.IR;
using System;
using System.Collections.Generic;
namespace Misaki.HighPerformance.HPC.Generator.Analysis
{
/// <summary>
/// Walks a Roslyn <see cref="MethodDeclarationSyntax"/> and produces an
/// <see cref="HPCMethodIR"/> that is completely independent of Roslyn after
/// this phase completes.
///
/// <para>Uses <see cref="CSharpSyntaxWalker"/> (read-only walk) rather than
/// <see cref="CSharpSyntaxRewriter"/> so that the semantic model is never
/// consulted during rewriting — all queries happen here, once, upfront.</para>
/// </summary>
internal sealed class HPCAnalyzer : CSharpSyntaxWalker
{
// ── State ─────────────────────────────────────────────────────────────
private readonly HPCTypeResolver _typeResolver;
/// <summary>Statement stack — top is the current block being built.</summary>
private readonly Stack<List<HPCStmt>> _blocks = new();
// ── Construction ─────────────────────────────────────────────────────
public HPCAnalyzer(SemanticModel semanticModel)
{
_typeResolver = new HPCTypeResolver(semanticModel);
}
// ── Public API ───────────────────────────────────────────────────────
/// <summary>
/// Analyses <paramref name="method"/> and returns the completed IR.
/// </summary>
public HPCMethodIR Analyze(MethodDeclarationSyntax method, HPComputeMethodInfo info)
{
_typeResolver.RegisterConstraints(method);
_blocks.Clear();
// Push the root block
_blocks.Push(new List<HPCStmt>());
if (method.Body is not null)
Visit(method.Body);
else if (method.ExpressionBody is not null)
{
// Expression-bodied method: treat as `return <expr>;`
var expr = AnalyzeExpression(method.ExpressionBody.Expression);
var returnType = _typeResolver.Resolve(method.ExpressionBody.Expression)
?? HPCType.Float;
Emit(new HPCReturn(expr));
}
var body = _blocks.Pop().ToArray();
return new HPCMethodIR
{
Name = method.Identifier.Text,
OriginalName = method.Identifier.Text,
ReturnType = ResolveReturnType(method),
Parameters = BuildParameters(method),
Body = body,
TargetISA = info.InstructionSet,
Precision = info.Precision,
Mode = info.Mode,
ContainingNamespace = info.MethodSymbol.ContainingNamespace.ToDisplayString(),
ContainingTypeName = info.MethodSymbol.ContainingType.Name,
};
}
// ── Statement visitors ───────────────────────────────────────────────
public override void VisitBlock(BlockSyntax node)
{
// Visit each statement inside the block in order
foreach (var stmt in node.Statements)
Visit(stmt);
}
public override void VisitLocalDeclarationStatement(LocalDeclarationStatementSyntax node)
{
foreach (var declarator in node.Declaration.Variables)
{
if (declarator.Initializer is null)
continue;
var init = AnalyzeExpression(declarator.Initializer.Value);
var hpcType = _typeResolver.Resolve(declarator.Initializer.Value)
?? InferFromExpr(init);
Emit(new HPCVarDecl(declarator.Identifier.Text, hpcType, init));
}
}
public override void VisitExpressionStatement(ExpressionStatementSyntax node)
{
if (node.Expression is AssignmentExpressionSyntax assign)
{
var target = assign.Left.ToString(); // simple name or member access text
var value = AnalyzeExpression(assign.Right);
Emit(new HPCAssignment(target, value));
}
else
{
var expr = AnalyzeExpression(node.Expression);
Emit(new HPCExprStmt(expr));
}
}
public override void VisitReturnStatement(ReturnStatementSyntax node)
{
var value = node.Expression is null ? null : AnalyzeExpression(node.Expression);
Emit(new HPCReturn(value));
}
public override void VisitIfStatement(IfStatementSyntax node)
{
var condition = AnalyzeExpression(node.Condition);
_blocks.Push(new List<HPCStmt>());
Visit(node.Statement);
var thenBody = _blocks.Pop().ToArray();
HPCStmt[]? elseBody = null;
if (node.Else is not null)
{
_blocks.Push(new List<HPCStmt>());
Visit(node.Else.Statement);
elseBody = _blocks.Pop().ToArray();
}
Emit(new HPCIf(condition, thenBody, elseBody));
}
public override void VisitForStatement(ForStatementSyntax node)
{
// Only handles the simple canonical for (var i = init; i <cond>; i++) form.
if (node.Declaration?.Variables.Count != 1 ||
node.Condition is null ||
node.Incrementors.Count != 1)
{
// Fall back: treat body as pass-through
_blocks.Push(new List<HPCStmt>());
Visit(node.Statement);
var rawBody = _blocks.Pop().ToArray();
foreach (var s in rawBody)
Emit(s);
return;
}
var decl = node.Declaration.Variables[0];
var initExpr = decl.Initializer is not null
? AnalyzeExpression(decl.Initializer.Value)
: new HPCLiteral("0", HPCType.Int);
var iterDecl = new HPCVarDecl(decl.Identifier.Text, HPCType.Int, initExpr);
var cond = AnalyzeExpression(node.Condition);
var incr = AnalyzeExpression(node.Incrementors[0]);
_blocks.Push(new List<HPCStmt>());
Visit(node.Statement);
var body = _blocks.Pop().ToArray();
Emit(new HPCForLoop(iterDecl, cond, incr, body));
}
// ── Expression analysis ───────────────────────────────────────────────
private HPCExpr AnalyzeExpression(ExpressionSyntax expr) => expr switch
{
BinaryExpressionSyntax n => AnalyzeBinary(n),
PrefixUnaryExpressionSyntax n => AnalyzePrefixUnary(n),
PostfixUnaryExpressionSyntax n => AnalyzePostfixUnary(n),
InvocationExpressionSyntax n => AnalyzeInvocation(n),
MemberAccessExpressionSyntax n => AnalyzeMemberAccess(n),
LiteralExpressionSyntax n => AnalyzeLiteral(n),
IdentifierNameSyntax n => AnalyzeIdentifier(n),
ParenthesizedExpressionSyntax n => AnalyzeExpression(n.Expression),
CastExpressionSyntax n => AnalyzeCast(n),
_ => FallbackExpr(expr)
};
private HPCExpr AnalyzeBinary(BinaryExpressionSyntax node)
{
var left = AnalyzeExpression(node.Left);
var right = AnalyzeExpression(node.Right);
var hpcType = _typeResolver.Resolve(node) ?? InferFromExpr(left);
var kind = node.Kind() switch
{
SyntaxKind.AddExpression => HPCBinaryKind.Add,
SyntaxKind.SubtractExpression => HPCBinaryKind.Subtract,
SyntaxKind.MultiplyExpression => HPCBinaryKind.Multiply,
SyntaxKind.DivideExpression => HPCBinaryKind.Divide,
SyntaxKind.ModuloExpression => HPCBinaryKind.Modulo,
SyntaxKind.BitwiseAndExpression => HPCBinaryKind.BitwiseAnd,
SyntaxKind.BitwiseOrExpression => HPCBinaryKind.BitwiseOr,
SyntaxKind.ExclusiveOrExpression => HPCBinaryKind.BitwiseXor,
SyntaxKind.LeftShiftExpression => HPCBinaryKind.ShiftLeft,
SyntaxKind.RightShiftExpression => HPCBinaryKind.ShiftRight,
SyntaxKind.EqualsExpression => HPCBinaryKind.Equal,
SyntaxKind.NotEqualsExpression => HPCBinaryKind.NotEqual,
SyntaxKind.LessThanExpression => HPCBinaryKind.LessThan,
SyntaxKind.LessThanOrEqualExpression => HPCBinaryKind.LessThanOrEqual,
SyntaxKind.GreaterThanExpression => HPCBinaryKind.GreaterThan,
SyntaxKind.GreaterThanOrEqualExpression => HPCBinaryKind.GreaterThanOrEqual,
_ => throw new NotSupportedException($"Binary operator not supported: {node.Kind()}")
};
return new HPCBinaryOp(kind, left, right, hpcType);
}
private HPCExpr AnalyzePrefixUnary(PrefixUnaryExpressionSyntax node)
{
var operand = AnalyzeExpression(node.Operand);
var hpcType = _typeResolver.Resolve(node) ?? InferFromExpr(operand);
var kind = node.Kind() switch
{
SyntaxKind.UnaryMinusExpression => HPCUnaryKind.Negate,
SyntaxKind.BitwiseNotExpression => HPCUnaryKind.BitwiseNot,
_ => throw new NotSupportedException($"Prefix unary not supported: {node.Kind()}")
};
return new HPCUnaryOp(kind, operand, hpcType);
}
private HPCExpr AnalyzePostfixUnary(PostfixUnaryExpressionSyntax node)
{
// i++ / i-- — treat as i + 1 / i - 1 in an IR context (no mutation semantics)
var operand = AnalyzeExpression(node.Operand);
var hpcType = InferFromExpr(operand);
var one = new HPCLiteral("1", hpcType);
return node.Kind() == SyntaxKind.PostIncrementExpression
? new HPCBinaryOp(HPCBinaryKind.Add, operand, one, hpcType)
: new HPCBinaryOp(HPCBinaryKind.Subtract, operand, one, hpcType);
}
private HPCExpr AnalyzeInvocation(InvocationExpressionSyntax node)
{
// ── Instance method on SPMD lane: lane.Add(b), lane.MultiplyAdd(a,b,c) ──
if (node.Expression is MemberAccessExpressionSyntax memberAccess)
{
var receiverType = _typeResolver.Resolve(memberAccess.Expression);
var methodName = memberAccess.Name.Identifier.Text;
if (receiverType != null)
{
// Unary math on the receiver: lane.Abs(), lane.Sqrt(), etc.
if (TryMapUnaryMethod(methodName, out var unaryKind) &&
node.ArgumentList.Arguments.Count == 0)
{
var receiver = AnalyzeExpression(memberAccess.Expression);
return new HPCUnaryOp(unaryKind, receiver, receiverType);
}
// Multi-arg static-style methods called as instance methods via ISPMDLane
if (TryMapIntrinsicMethod(methodName, out var intrinsic))
{
var args = new List<HPCExpr> { AnalyzeExpression(memberAccess.Expression) };
foreach (var arg in node.ArgumentList.Arguments)
args.Add(AnalyzeExpression(arg.Expression));
return new HPCIntrinsicCall(intrinsic, args.ToArray(), receiverType);
}
}
// Static call on the lane type: TSelf.Sin(v), TSelf.Load(ref p), etc.
if (memberAccess.Expression is IdentifierNameSyntax typeIdent &&
_typeResolver.IsSPMDTypeParam(typeIdent.Identifier.Text))
{
var elemType = _typeResolver.ResolveByName(typeIdent.Identifier.Text) ?? HPCType.Float;
if (TryMapUnaryMethod(methodName, out var unaryKind) &&
node.ArgumentList.Arguments.Count == 1)
{
var arg = AnalyzeExpression(node.ArgumentList.Arguments[0].Expression);
return new HPCUnaryOp(unaryKind, arg, elemType);
}
if (TryMapIntrinsicMethod(methodName, out var intrinsic))
{
var args = new List<HPCExpr>();
foreach (var arg in node.ArgumentList.Arguments)
args.Add(AnalyzeExpression(arg.Expression));
return new HPCIntrinsicCall(intrinsic, args.ToArray(), elemType);
}
}
}
// ── Fall through: unknown call, emit as pass-through ──────────────
return FallbackExpr(node);
}
private HPCExpr AnalyzeMemberAccess(MemberAccessExpressionSyntax node)
{
var receiverType = _typeResolver.Resolve(node.Expression);
var memberName = node.Name.Identifier.Text;
if (receiverType != null)
{
// LaneWidth → Count (property remap)
var mapped = memberName switch
{
"LaneWidth" => "Count",
_ => memberName
};
var receiver = AnalyzeExpression(node.Expression);
return new HPCPropertyAccess(receiver, mapped, receiverType);
}
return FallbackExpr(node);
}
private static HPCExpr AnalyzeLiteral(LiteralExpressionSyntax node)
{
var text = node.Token.ValueText;
var hpcType = node.Kind() switch
{
SyntaxKind.NumericLiteralExpression when node.Token.Value is float => HPCType.Float,
SyntaxKind.NumericLiteralExpression when node.Token.Value is double => HPCType.Double,
SyntaxKind.NumericLiteralExpression when node.Token.Value is int => HPCType.Int,
SyntaxKind.NumericLiteralExpression when node.Token.Value is long => HPCType.Long,
_ => HPCType.Float
};
return new HPCLiteral(text, hpcType);
}
private HPCExpr AnalyzeIdentifier(IdentifierNameSyntax node)
{
var name = node.Identifier.Text;
var hpcType = _typeResolver.Resolve(node)
?? _typeResolver.ResolveByName(name)
?? HPCType.Float;
return new HPCVarRef(name, hpcType);
}
private HPCExpr AnalyzeCast(CastExpressionSyntax node)
{
var operand = AnalyzeExpression(node.Expression);
var hpcType = _typeResolver.Resolve(node) ?? HPCType.Float;
return new HPCIntrinsicCall(HPCIntrinsic.Cast, [operand], hpcType);
}
private HPCExpr FallbackExpr(ExpressionSyntax node)
{
// Preserve anything we don't understand as a verbatim pass-through
var hpcType = _typeResolver.Resolve(node) ?? HPCType.Float;
return new HPCPassThroughCall(
node.ToString(),
Array.Empty<HPCExpr>(),
hpcType);
}
// ── Method mapping tables ─────────────────────────────────────────────
private static readonly Dictionary<string, HPCUnaryKind> s_unaryMethods = new()
{
["Abs"] = HPCUnaryKind.Abs,
["Sqrt"] = HPCUnaryKind.Sqrt,
["Floor"] = HPCUnaryKind.Floor,
["Ceil"] = HPCUnaryKind.Ceil,
["Round"] = HPCUnaryKind.Round,
["Trunc"] = HPCUnaryKind.Trunc,
["Frac"] = HPCUnaryKind.Frac,
["Sign"] = HPCUnaryKind.Sign,
["Saturate"] = HPCUnaryKind.Saturate,
["Rcp"] = HPCUnaryKind.Rcp,
["Rsqrt"] = HPCUnaryKind.Rsqrt,
["Sin"] = HPCUnaryKind.Sin,
["Cos"] = HPCUnaryKind.Cos,
["Tan"] = HPCUnaryKind.Tan,
["Asin"] = HPCUnaryKind.Asin,
["Acos"] = HPCUnaryKind.Acos,
["Atan"] = HPCUnaryKind.Atan,
["Exp"] = HPCUnaryKind.Exp,
["Exp2"] = HPCUnaryKind.Exp2,
["Log"] = HPCUnaryKind.Log,
["Log2"] = HPCUnaryKind.Log2,
};
private static readonly Dictionary<string, HPCIntrinsic> s_intrinsicMethods = new()
{
["MultiplyAdd"] = HPCIntrinsic.MultiplyAdd,
["MultiplySubtract"] = HPCIntrinsic.MultiplySubtract,
["Atan2"] = HPCIntrinsic.Atan2,
["Pow"] = HPCIntrinsic.Pow,
["SinCos"] = HPCIntrinsic.SinCos,
["Lerp"] = HPCIntrinsic.Lerp,
["Min"] = HPCIntrinsic.Min,
["Max"] = HPCIntrinsic.Max,
["Clamp"] = HPCIntrinsic.Clamp,
["Select"] = HPCIntrinsic.Select,
["CopySign"] = HPCIntrinsic.CopySign,
["ReduceAdd"] = HPCIntrinsic.ReduceAdd,
["ReduceMax"] = HPCIntrinsic.ReduceMax,
["ReduceMin"] = HPCIntrinsic.ReduceMin,
["Load"] = HPCIntrinsic.Load,
["MaskLoad"] = HPCIntrinsic.MaskLoad,
["Gather"] = HPCIntrinsic.Gather,
["MaskGather"] = HPCIntrinsic.MaskGather,
["Store"] = HPCIntrinsic.Store,
["MaskStore"] = HPCIntrinsic.MaskStore,
["Scatter"] = HPCIntrinsic.Scatter,
["MaskScatter"] = HPCIntrinsic.MaskScatter,
["CompressStore"] = HPCIntrinsic.CompressStore,
};
private static bool TryMapUnaryMethod(string name, out HPCUnaryKind kind)
=> s_unaryMethods.TryGetValue(name, out kind);
private static bool TryMapIntrinsicMethod(string name, out HPCIntrinsic intrinsic)
=> s_intrinsicMethods.TryGetValue(name, out intrinsic);
// ── Helpers ───────────────────────────────────────────────────────────
private void Emit(HPCStmt stmt) => _blocks.Peek().Add(stmt);
private static HPCType InferFromExpr(HPCExpr expr) => expr.Type;
private HPCType ResolveReturnType(MethodDeclarationSyntax method)
{
var typeText = method.ReturnType.ToString();
if (typeText == "void")
return new HPCType("void", IsFloatingPoint: false);
return _typeResolver.ResolveByName(typeText)
?? HPCType.FromElementName(typeText);
}
private HPCParameter[] BuildParameters(MethodDeclarationSyntax method)
{
var result = new List<HPCParameter>();
foreach (var param in method.ParameterList.Parameters)
{
// Skip generic SPMD type parameters (they become the vector type in the backend)
var paramTypeName = param.Type?.ToString() ?? "object";
if (_typeResolver.IsSPMDTypeParam(paramTypeName))
continue;
var hpcType = _typeResolver.ResolveByName(paramTypeName)
?? HPCType.FromElementName(paramTypeName);
bool isOut = false, isRef = false;
foreach (var mod in param.Modifiers)
{
if (mod.IsKind(SyntaxKind.OutKeyword))
isOut = true;
if (mod.IsKind(SyntaxKind.RefKeyword))
isRef = true;
}
result.Add(new HPCParameter(param.Identifier.Text, hpcType, isOut, isRef));
}
return result.ToArray();
}
}
}

View File

@@ -0,0 +1,133 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Misaki.HighPerformance.HPC.Generator.IR;
using System.Collections.Generic;
namespace Misaki.HighPerformance.HPC.Generator.Analysis
{
/// <summary>
/// Resolves HPC-specific types from Roslyn's semantic model.
/// Centralises all "what is the scalar element type of this expression?"
/// logic that was previously duplicated across <c>HPCRewriter</c> and
/// <c>HPCOptimizerRewriter</c>.
/// </summary>
internal sealed class HPCTypeResolver
{
private readonly SemanticModel _semanticModel;
/// <summary>
/// Maps generic type-parameter names (e.g. <c>"TLane0"</c>) to their
/// resolved scalar element types (e.g. <c>"float"</c>).
/// Populated by <see cref="RegisterConstraints"/>.
/// </summary>
private readonly Dictionary<string, string> _typeParamToPrimitive = new();
public HPCTypeResolver(SemanticModel semanticModel)
{
_semanticModel = semanticModel;
}
// ── Type-parameter registration ───────────────────────────────────────
/// <summary>
/// Scans the generic constraints on <paramref name="method"/> and
/// registers every type parameter constrained to
/// <c>ISPMDLane&lt;TSelf, TNumber&gt;</c>.
/// Must be called before <see cref="Resolve"/> is used on the method body.
/// </summary>
public void RegisterConstraints(MethodDeclarationSyntax method)
{
_typeParamToPrimitive.Clear();
foreach (var clause in method.ConstraintClauses)
{
var typeParamName = clause.Name.Identifier.Text;
foreach (var constraint in clause.Constraints)
{
if (constraint is TypeConstraintSyntax typeConstraint &&
typeConstraint.Type is GenericNameSyntax generic &&
generic.Identifier.Text == "ISPMDLane" &&
generic.TypeArgumentList.Arguments.Count == 2)
{
// ISPMDLane<TSelf, TNumber> — TNumber is the scalar element type
var primitiveTypeName = generic.TypeArgumentList.Arguments[1].ToString();
_typeParamToPrimitive[typeParamName] = primitiveTypeName;
}
}
}
}
// ── Expression type resolution ────────────────────────────────────────
/// <summary>
/// Returns the <see cref="HPCType"/> for <paramref name="node"/>, or
/// <c>null</c> if the node is not an HPC-typed expression (i.e. not a
/// <c>WideLane&lt;T&gt;</c>, a constrained SPMD type-parameter, or a
/// plain scalar float/double).
/// </summary>
public HPCType? Resolve(SyntaxNode node)
{
var typeInfo = _semanticModel.GetTypeInfo(node);
var type = typeInfo.Type;
if (type is null) return null;
// WideLane<float>, WideLane<double>, …
if (type.Name == "WideLane" &&
type is INamedTypeSymbol wideLane &&
wideLane.IsGenericType)
{
return HPCType.FromElementName(
wideLane.TypeArguments[0].ToDisplayString());
}
// Generic type parameter constrained to ISPMDLane<TSelf, TNumber>
if (type is ITypeParameterSymbol typeParam)
{
foreach (var constraint in typeParam.ConstraintTypes)
{
if (constraint.Name == "ISPMDLane" &&
constraint is INamedTypeSymbol namedConstraint &&
namedConstraint.IsGenericType)
{
return HPCType.FromElementName(
namedConstraint.TypeArguments[1].ToDisplayString());
}
}
// Registered from method constraints
if (_typeParamToPrimitive.TryGetValue(typeParam.Name, out var prim))
return HPCType.FromElementName(prim);
}
// Bare scalar (used when a scalar literal/variable is involved in a
// mixed scalar-vector expression)
if (type.SpecialType == SpecialType.System_Single) return HPCType.Float;
if (type.SpecialType == SpecialType.System_Double) return HPCType.Double;
if (type.SpecialType == SpecialType.System_Int32) return HPCType.Int;
if (type.SpecialType == SpecialType.System_UInt32) return HPCType.UInt;
if (type.SpecialType == SpecialType.System_Int64) return HPCType.Long;
return null;
}
/// <summary>
/// Resolves the type of the expression using the registered type-parameter
/// map as a fallback (for simple identifier references where the semantic
/// model yields a type-parameter symbol rather than a concrete type).
/// </summary>
public HPCType? ResolveByName(string typeName)
{
if (_typeParamToPrimitive.TryGetValue(typeName, out var prim))
return HPCType.FromElementName(prim);
return null;
}
/// <summary>Returns true if <paramref name="typeName"/> is a known SPMD type-param.</summary>
public bool IsSPMDTypeParam(string typeName) =>
_typeParamToPrimitive.ContainsKey(typeName);
/// <summary>Snapshot of current type-param → primitive mappings (for reference).</summary>
public IReadOnlyDictionary<string, string> TypeParamMap => _typeParamToPrimitive;
}
}

View File

@@ -0,0 +1,365 @@
using Misaki.HighPerformance.HPC.Generator.IR;
using System;
using System.Text;
namespace Misaki.HighPerformance.HPC.Generator.Backend
{
/// <summary>
/// Emits C# source code targeting AVX2 (256-bit vectors) with the bundled
/// AVX2 extensions: FMA3, F16C, and BMI1/2.
///
/// <para>
/// Vector type: <c>Vector256&lt;T&gt;</c><br/>
/// Intrinsic classes used: <c>Avx</c>, <c>Avx2</c>, <c>Fma</c>, <c>F16C</c>,
/// <c>Bmi1</c>, <c>Bmi2</c>, <c>Vector256</c>.
/// </para>
///
/// <para>
/// Math functions with no built-in AVX2 intrinsic (Sin, Cos, Asin, etc.) are
/// delegated to the <c>AVX2Utility</c> class that is generated separately by
/// <c>AVX2UtilityGenerator</c> via <c>IVectorAPIContext</c>.
/// </para>
/// </summary>
internal sealed class AVX2Backend : IHPCBackend
{
// ── IHPCBackend ───────────────────────────────────────────────────────
public string Name => "AVX2";
public string[] RequiredUsings => new[]
{
"using System.Runtime.CompilerServices;",
"using System.Runtime.Intrinsics;",
"using System.Runtime.Intrinsics.X86;",
};
public string EmitMethod(HPCMethodIR method)
{
var sb = new StringBuilder();
var indent = " ";
// ── Signature ──────────────────────────────────────────────────────
sb.AppendLine($"{indent}[MethodImpl(MethodImplOptions.AggressiveInlining)]");
sb.Append($"{indent}public static ");
sb.Append(EmitReturnType(method.ReturnType));
sb.Append(' ');
sb.Append($"{method.OriginalName}_{Name}");
sb.Append('(');
sb.Append(EmitParameterList(method));
sb.AppendLine(")");
// ── Body ───────────────────────────────────────────────────────────
sb.AppendLine($"{indent}{{");
foreach (var stmt in method.Body)
EmitStatement(sb, stmt, indent + " ");
sb.AppendLine($"{indent}}}");
return sb.ToString();
}
// ── Statement emission ────────────────────────────────────────────────
private void EmitStatement(StringBuilder sb, HPCStmt stmt, string indent)
{
switch (stmt)
{
case HPCVarDecl decl:
sb.AppendLine($"{indent}var {decl.Name} = {EmitExpr(decl.Initializer)};");
break;
case HPCAssignment assign:
sb.AppendLine($"{indent}{assign.Target} = {EmitExpr(assign.Value)};");
break;
case HPCExprStmt expr:
sb.AppendLine($"{indent}{EmitExpr(expr.Expression)};");
break;
case HPCReturn ret:
if (ret.Value is null)
sb.AppendLine($"{indent}return;");
else
sb.AppendLine($"{indent}return {EmitExpr(ret.Value)};");
break;
case HPCIf ifStmt:
EmitIf(sb, ifStmt, indent);
break;
case HPCForLoop forLoop:
EmitForLoop(sb, forLoop, indent);
break;
}
}
private void EmitIf(StringBuilder sb, HPCIf node, string indent)
{
sb.AppendLine($"{indent}if ({EmitExpr(node.Condition)})");
sb.AppendLine($"{indent}{{");
foreach (var s in node.ThenBody)
EmitStatement(sb, s, indent + " ");
sb.AppendLine($"{indent}}}");
if (node.ElseBody is { Length: > 0 })
{
sb.AppendLine($"{indent}else");
sb.AppendLine($"{indent}{{");
foreach (var s in node.ElseBody)
EmitStatement(sb, s, indent + " ");
sb.AppendLine($"{indent}}}");
}
}
private void EmitForLoop(StringBuilder sb, HPCForLoop node, string indent)
{
var init = $"var {node.Iterator.Name} = {EmitExpr(node.Iterator.Initializer)}";
var cond = EmitExpr(node.Condition);
var incr = EmitExpr(node.Increment);
sb.AppendLine($"{indent}for ({init}; {cond}; {incr})");
sb.AppendLine($"{indent}{{");
foreach (var s in node.Body)
EmitStatement(sb, s, indent + " ");
sb.AppendLine($"{indent}}}");
}
// ── Expression emission ───────────────────────────────────────────────
private string EmitExpr(HPCExpr expr) => expr switch
{
HPCVarRef v => v.Name,
HPCLiteral l => EmitLiteral(l),
HPCBinaryOp b => EmitBinary(b),
HPCUnaryOp u => EmitUnary(u),
HPCIntrinsicCall c => EmitIntrinsic(c),
HPCPropertyAccess p => EmitPropertyAccess(p),
HPCPassThroughCall pt => pt.MethodName, // verbatim
_ => throw new NotSupportedException($"Unknown IR expr: {expr.GetType().Name}")
};
private string EmitLiteral(HPCLiteral lit)
{
// Scalar literal → broadcast to all lanes
return $"Vector256.Create({lit.Value}{LiteralSuffix(lit.Type)})";
}
private string EmitBinary(HPCBinaryOp node)
{
var l = EmitExpr(node.Left);
var r = EmitExpr(node.Right);
return node.Kind switch
{
// Floating-point arithmetic uses Avx / Avx2
HPCBinaryKind.Add when node.Type.IsFloatingPoint => $"Avx.Add({l}, {r})",
HPCBinaryKind.Subtract when node.Type.IsFloatingPoint => $"Avx.Subtract({l}, {r})",
HPCBinaryKind.Multiply when node.Type.IsFloatingPoint => $"Avx.Multiply({l}, {r})",
HPCBinaryKind.Divide when node.Type.IsFloatingPoint => $"Avx.Divide({l}, {r})",
// Integer arithmetic uses Avx2
HPCBinaryKind.Add => $"Avx2.Add({l}, {r})",
HPCBinaryKind.Subtract => $"Avx2.Subtract({l}, {r})",
HPCBinaryKind.Multiply => $"Avx2.MultiplyLow({l}, {r})", // 32-bit int multiply
// Bitwise
HPCBinaryKind.BitwiseAnd => $"Avx2.And({l}, {r})",
HPCBinaryKind.BitwiseOr => $"Avx2.Or({l}, {r})",
HPCBinaryKind.BitwiseXor => $"Avx2.Xor({l}, {r})",
HPCBinaryKind.ShiftLeft => $"Avx2.ShiftLeftLogical({l}, {r})",
HPCBinaryKind.ShiftRight => $"Avx2.ShiftRightLogical({l}, {r})",
// Comparisons — emit as AVX compare returning a mask vector
HPCBinaryKind.Equal when node.Type.IsFloatingPoint
=> $"Avx.Compare({l}, {r}, FloatComparisonMode.OrderedEqualNonSignaling)",
HPCBinaryKind.NotEqual when node.Type.IsFloatingPoint
=> $"Avx.Compare({l}, {r}, FloatComparisonMode.OrderedNotEqualNonSignaling)",
HPCBinaryKind.LessThan when node.Type.IsFloatingPoint
=> $"Avx.Compare({l}, {r}, FloatComparisonMode.OrderedLessThanNonSignaling)",
HPCBinaryKind.LessThanOrEqual when node.Type.IsFloatingPoint
=> $"Avx.Compare({l}, {r}, FloatComparisonMode.OrderedLessThanOrEqualNonSignaling)",
HPCBinaryKind.GreaterThan when node.Type.IsFloatingPoint
=> $"Avx.Compare({l}, {r}, FloatComparisonMode.OrderedGreaterThanNonSignaling)",
HPCBinaryKind.GreaterThanOrEqual when node.Type.IsFloatingPoint
=> $"Avx.Compare({l}, {r}, FloatComparisonMode.OrderedGreaterThanOrEqualNonSignaling)",
// Integer comparisons
HPCBinaryKind.Equal => $"Avx2.CompareEqual({l}, {r})",
HPCBinaryKind.GreaterThan => $"Avx2.CompareGreaterThan({l}, {r})",
HPCBinaryKind.LessThan => $"Avx2.CompareGreaterThan({r}, {l})", // reversed
HPCBinaryKind.GreaterThanOrEqual => $"Avx2.Or(Avx2.CompareGreaterThan({l}, {r}), Avx2.CompareEqual({l}, {r}))",
HPCBinaryKind.LessThanOrEqual => $"Avx2.Or(Avx2.CompareGreaterThan({r}, {l}), Avx2.CompareEqual({l}, {r}))",
HPCBinaryKind.NotEqual => $"Avx2.Xor(Avx2.CompareEqual({l}, {r}), Vector256<{CsTypeName(node.Type)}>.AllBitsSet)",
HPCBinaryKind.Modulo => throw new NotSupportedException("Modulo has no AVX2 intrinsic; consider reformulating."),
_ => throw new NotSupportedException($"Binary kind {node.Kind} not supported in AVX2 backend")
};
}
private string EmitUnary(HPCUnaryOp node)
{
var op = EmitExpr(node.Operand);
return node.Kind switch
{
HPCUnaryKind.Negate when node.Type.IsFloatingPoint
=> $"Avx.Subtract(Vector256<{CsTypeName(node.Type)}>.Zero, {op})",
HPCUnaryKind.Negate
=> $"Avx2.Subtract(Vector256<{CsTypeName(node.Type)}>.Zero, {op})",
HPCUnaryKind.BitwiseNot => $"Avx2.Xor({op}, Vector256<{CsTypeName(node.Type)}>.AllBitsSet)",
// Math — delegate to Vector256 helpers (available in .NET 7+)
HPCUnaryKind.Abs when node.Type.IsFloatingPoint => $"Vector256.Abs({op})",
HPCUnaryKind.Sqrt when node.Type.IsFloatingPoint => $"Avx.Sqrt({op})",
HPCUnaryKind.Floor when node.Type.IsFloatingPoint => $"Avx.Floor({op})",
HPCUnaryKind.Ceil when node.Type.IsFloatingPoint => $"Avx.Ceiling({op})",
HPCUnaryKind.Round when node.Type.IsFloatingPoint
=> $"Avx.RoundToNearestInteger({op})",
HPCUnaryKind.Trunc when node.Type.IsFloatingPoint
=> $"Avx.RoundToZero({op})",
// Reciprocal / Rsqrt (float only; approximate 14-bit variants)
HPCUnaryKind.Rcp when node.Type.ElementTypeName == "float"
=> $"Avx.Reciprocal({op})",
HPCUnaryKind.Rsqrt when node.Type.ElementTypeName == "float"
=> $"Avx.ReciprocalSqrt({op})",
// Transcendentals — routed to AVX2Utility (generated by UtilityTemplate)
HPCUnaryKind.Sin => $"AVX2Utility.Sin_{CsTypeNameCap(node.Type)}_{MathModeSuffix()}({op})",
HPCUnaryKind.Cos => $"AVX2Utility.Cos_{CsTypeNameCap(node.Type)}_{MathModeSuffix()}({op})",
HPCUnaryKind.Asin => $"AVX2Utility.Asin({op})",
HPCUnaryKind.Atan => $"AVX2Utility.Atan({op})",
HPCUnaryKind.Log => $"AVX2Utility.Log_{CsTypeNameCap(node.Type)}_{MathModeSuffix()}({op})",
HPCUnaryKind.Exp => $"AVX2Utility.Exp_{CsTypeNameCap(node.Type)}_{MathModeSuffix()}({op})",
// Frac = x - Floor(x)
HPCUnaryKind.Frac when node.Type.IsFloatingPoint
=> $"Avx.Subtract({op}, Avx.Floor({op}))",
// Sign: extract and normalise sign bit
HPCUnaryKind.Sign when node.Type.ElementTypeName == "float"
=> $"Avx.And(Avx.CompareEqual({op}, Vector256<float>.Zero) == Vector256<float>.Zero ? Vector256.Create(1.0f) : Vector256<float>.Zero, Avx.Or(Avx.And({op}, Vector256.Create(-0.0f)), Vector256.Create(1.0f)))",
// Saturate: clamp to [0,1]
HPCUnaryKind.Saturate when node.Type.IsFloatingPoint
=> $"Avx.Max(Avx.Min({op}, Vector256.Create(1.0{LiteralSuffix(node.Type)})), Vector256<{CsTypeName(node.Type)}>.Zero)",
_ => throw new NotSupportedException($"Unary kind {node.Kind} not supported in AVX2 backend (type: {node.Type})")
};
}
private string EmitIntrinsic(HPCIntrinsicCall node)
{
var args = node.Args;
string A(int i) => EmitExpr(args[i]);
return node.Intrinsic switch
{
// FMA — uses the FMA extension bundled with AVX2
HPCIntrinsic.MultiplyAdd => $"Fma.MultiplyAdd({A(0)}, {A(1)}, {A(2)})",
HPCIntrinsic.MultiplySubtract => $"Fma.MultiplySubtract({A(0)}, {A(1)}, {A(2)})",
// Math
HPCIntrinsic.Atan2 => $"AVX2Utility.Atan2({A(0)}, {A(1)})",
HPCIntrinsic.Pow => $"AVX2Utility.Pow({A(0)}, {A(1)})",
HPCIntrinsic.SinCos => $"AVX2Utility.SinCos_{CsTypeNameCap(node.Type)}_{MathModeSuffix()}({A(0)}, out {A(1)}, out {A(2)})",
// Compound math
HPCIntrinsic.Lerp => $"Fma.MultiplyAdd(Avx.Subtract({A(1)}, {A(0)}), {A(2)}, {A(0)})",
HPCIntrinsic.Min when node.Type.IsFloatingPoint => $"Avx.Min({A(0)}, {A(1)})",
HPCIntrinsic.Max when node.Type.IsFloatingPoint => $"Avx.Max({A(0)}, {A(1)})",
HPCIntrinsic.Min => $"Avx2.Min({A(0)}, {A(1)})",
HPCIntrinsic.Max => $"Avx2.Max({A(0)}, {A(1)})",
HPCIntrinsic.Clamp when node.Type.IsFloatingPoint => $"Avx.Max(Avx.Min({A(0)}, {A(2)}), {A(1)})",
HPCIntrinsic.Clamp => $"Avx2.Max(Avx2.Min({A(0)}, {A(2)}), {A(1)})",
// Conditional select via bitwise blend
HPCIntrinsic.Select when node.Type.IsFloatingPoint
=> $"Avx.BlendVariable({A(2)}, {A(1)}, {A(0)})",
HPCIntrinsic.Select
=> $"Avx2.BlendVariable({A(2)}.As<{CsTypeName(node.Type)}, byte>(), {A(1)}.As<{CsTypeName(node.Type)}, byte>(), {A(0)}.As<{CsTypeName(node.Type)}, byte>()).As<byte, {CsTypeName(node.Type)}>()",
// CopySign: compose exponent+mantissa from A(0) and sign from A(1)
HPCIntrinsic.CopySign when node.Type.ElementTypeName == "float"
=> $"Avx.Or(Avx.And({A(0)}, Vector256.Create(0x7FFFFFFFu).AsSingle()), Avx.And({A(1)}, Vector256.Create(0x80000000u).AsSingle()))",
// Horizontal reductions
HPCIntrinsic.ReduceAdd => $"Vector256.Sum({A(0)})",
HPCIntrinsic.ReduceMax => $"Vector256.Max({A(0)})",
HPCIntrinsic.ReduceMin => $"Vector256.Min({A(0)})",
// Memory — pointer-based (emitted as ref-to-pointer pattern)
HPCIntrinsic.Load => $"Vector256.LoadUnsafe(ref {A(0)})",
HPCIntrinsic.MaskLoad => $"Avx2.MaskLoad(ref {A(0)}, {A(1)}.AsInt32())",
HPCIntrinsic.Store => $"Vector256.StoreUnsafe({A(0)}, ref {A(1)})",
HPCIntrinsic.MaskStore => $"Avx2.MaskStore(ref {A(1)}, {A(2)}.AsInt32(), {A(0)})",
HPCIntrinsic.Gather => $"Avx2.GatherVector256(ref {A(0)}, {A(1)}, {A(2)})",
HPCIntrinsic.MaskGather => $"Avx2.GatherMaskVector256({A(0)}, ref {A(1)}, {A(2)}, {A(3)}, {A(4)})",
HPCIntrinsic.CompressStore
=> $"/* CompressStore requires AVX-512VBMI2; use scalar fallback */ {A(0)}.CompressStore(ref {A(1)}, {A(2)})",
// Conversions
HPCIntrinsic.Cast => $"{A(0)}.As<{CsTypeName(args[0].Type)}, {CsTypeName(node.Type)}>()",
HPCIntrinsic.BitCast => $"{A(0)}.As<{CsTypeName(args[0].Type)}, {CsTypeName(node.Type)}>()",
_ => throw new NotSupportedException($"Intrinsic {node.Intrinsic} not implemented in AVX2 backend")
};
}
private static string EmitPropertyAccess(HPCPropertyAccess node) =>
$"{node.Target}.{node.PropertyName}";
// ── Signature helpers ─────────────────────────────────────────────────
private static string EmitReturnType(HPCType type) =>
type.ElementTypeName == "void"
? "void"
: $"Vector256<{type.ElementTypeName}>";
private static string EmitParameterList(HPCMethodIR method)
{
var parts = new System.Collections.Generic.List<string>();
foreach (var p in method.Parameters)
{
var prefix = (p.IsOut ? "out " : "") + (p.IsRef ? "ref " : "");
var typeName = p.Type.ElementTypeName == "void"
? "void"
: $"Vector256<{p.Type.ElementTypeName}>";
parts.Add($"{prefix}{typeName} {p.Name}");
}
return string.Join(", ", parts);
}
// ── Naming helpers ────────────────────────────────────────────────────
private static string CsTypeName(HPCType t) => t.ElementTypeName;
private static string CsTypeNameCap(HPCType t) => t.ElementTypeName switch
{
"float" or "Single" => "Single",
"double" or "Double" => "Double",
_ => t.ElementTypeName
};
private static string LiteralSuffix(HPCType t) => t.ElementTypeName switch
{
"float" or "Single" => "f",
"double" or "Double" => "d",
_ => ""
};
// ── Mode-aware suffix ─────────────────────────────────────────────────
// The backend does not hold state for the current method's MathMode by default;
// instead EmitMethod passes context through a field set per-call.
private MathMode _currentMode = MathMode.Standard;
public string EmitMethod(HPCMethodIR method, MathMode mode)
{
_currentMode = mode;
return EmitMethod(method);
}
private string MathModeSuffix() => _currentMode == MathMode.Fast ? "Fast" : "Standard";
}
}

View File

@@ -0,0 +1,23 @@
using Misaki.HighPerformance.HPC.Generator.IR;
namespace Misaki.HighPerformance.HPC.Generator.Backend
{
/// <summary>
/// Emits C# source code for one target instruction-set architecture from a
/// fully analysed and optimised <see cref="HPCMethodIR"/>.
/// </summary>
internal interface IHPCBackend
{
/// <summary>Short identifier used in generated file/method names (e.g. "AVX2").</summary>
string Name { get; }
/// <summary>Using-directives the emitted code requires.</summary>
string[] RequiredUsings { get; }
/// <summary>
/// Emits the full body of a specialised method and returns the C# source
/// as a string (method signature + braces included).
/// </summary>
string EmitMethod(HPCMethodIR method);
}
}

View File

@@ -0,0 +1,108 @@
using Microsoft.CodeAnalysis;
using System;
namespace Misaki.HighPerformance.HPC.Generator
{
internal enum FloatPrecision
{
Standard = 0,
High = 1,
Low = 2,
}
internal enum MathMode
{
Standard = 0,
Fast = 1,
}
[Flags]
internal enum TargetInstructionSet
{
None = 0,
SSE2 = 1 << 0,
SSE4 = 1 << 1,
AVX = 1 << 2,
AVX2 = 1 << 3,
AVX512 = 1 << 4,
}
[Generator]
public class HPComputeAttributeGenerator : IIncrementalGenerator
{
public void Initialize(IncrementalGeneratorInitializationContext context)
{
context.RegisterPostInitializationOutput(static ctx =>
{
var source = @$"
using System;
namespace Misaki.HighPerformance.HPC
{{
public enum FloatPrecision
{{
/// <summary>
/// Compute with an accuracy of 3.5 ULPs (Units in the Last Place). This is the default precision level for floating-point operations.
/// </summary>
Standard = {(int)FloatPrecision.Standard},
/// <summary>
/// Compute with an accuracy of 1 ULP. This level may use more aggressive optimizations that can lead to faster computations but with reduced precision.
/// </summary>
High = {(int)FloatPrecision.High},
/// <summary>
/// Compute with an accuracy that equals or lower than 3.5 ULPs. This level may use the most aggressive optimizations, potentially sacrificing precision for maximum performance.
/// </summary>
Low = {(int)FloatPrecision.Low},
}}
public enum MathMode
{{
/// <summary>
/// Use the default math mode, which balances performance and accuracy. This mode may allow certain optimizations that can lead to faster computations while maintaining reasonable precision.
/// </summary>
Standard = {(int)MathMode.Standard},
/// <summary>
/// Use a fast math mode, which prioritizes performance over accuracy. This mode assumes there are no special cases (like NaNs or infinities) and may allow for more aggressive optimizations.
/// </summary>
Fast = {(int)MathMode.Fast},
}}
[Flags]
public enum TargetInstructionSet
{{
None = {(int)TargetInstructionSet.None},
/// <summary>
/// Streaming SIMD Extensions 2.
/// </summary>
SSE2 = {(int)TargetInstructionSet.SSE2},
/// <summary>
/// Streaming SIMD Extensions 4.2.
/// </summary>
SSE4 = {(int)TargetInstructionSet.SSE4},
/// <summary>
/// Advanced Vector Extensions.
/// </summary>
AVX = {(int)TargetInstructionSet.AVX},
/// <summary>
/// Advanced Vector Extensions 2. Includes FMA, F16C and BMI1/2.
/// </summary>
AVX2 = {(int)TargetInstructionSet.AVX2},
/// <summary>
/// Advanced Vector Extensions 512.
/// </summary>
AVX512 = {(int)TargetInstructionSet.AVX512},
}}
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Method, Inherited = false, AllowMultiple = false)]
public sealed class HPComputeAttribute : Attribute
{{
public HPComputeAttribute(TargetInstructionSet instructionSet, FloatPrecision precision = FloatPrecision.Standard, MathMode mode = MathMode.Standard)
{{
}}
}}
}}";
ctx.AddSource("HPComputeAttribute.g.cs", source);
});
}
}
}

View File

@@ -0,0 +1,177 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Misaki.HighPerformance.HPC.Generator.Analysis;
using Misaki.HighPerformance.HPC.Generator.Backend;
using Misaki.HighPerformance.HPC.Generator.Optimization;
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Text;
namespace Misaki.HighPerformance.HPC.Generator
{
internal class HPComputeMethodInfo
{
public MethodDeclarationSyntax MethodDeclaration { get; set; } = null!;
public IMethodSymbol MethodSymbol { get; set; } = null!;
public SemanticModel SemanticModel { get; set; } = null!;
public TargetInstructionSet InstructionSet { get; set; }
public FloatPrecision Precision { get; set; }
public MathMode Mode { get; set; }
}
[Generator]
public class HPComputeGenerator : IIncrementalGenerator
{
// ── Backends (one singleton per ISA, stateless between methods) ───────
private static readonly AVX2Backend s_avx2Backend = new();
// ── Optimization passes ───────────────────────────────────────────────
/// <summary>
/// Returns the ordered list of optimisation passes for the given method.
/// Passes are cheap objects; creating them per-method is intentional so
/// future stateful passes (e.g. CSE) can be added without concurrency issues.
/// </summary>
private static IEnumerable<IHPCOptimizationPass> GetPasses(HPComputeMethodInfo info)
{
// FMA fusion only makes sense on ISAs that have it
if (info.InstructionSet.HasFlag(TargetInstructionSet.AVX2))
yield return new FMAFusionPass();
}
// ── IIncrementalGenerator ─────────────────────────────────────────────
public void Initialize(IncrementalGeneratorInitializationContext context)
{
var methodDeclarations = context.SyntaxProvider
.ForAttributeWithMetadataName(
"Misaki.HighPerformance.HPC.HPComputeAttribute",
static (n, _) => n is MethodDeclarationSyntax,
static (ctx, _) =>
{
var attribute = ctx.Attributes.FirstOrDefault(
a => a.AttributeClass?.ToDisplayString() ==
"Misaki.HighPerformance.HPC.HPComputeAttribute");
if (attribute is null || ctx.TargetSymbol is not IMethodSymbol methodSymbol)
return null;
return new HPComputeMethodInfo
{
MethodDeclaration = (MethodDeclarationSyntax)ctx.TargetNode,
MethodSymbol = methodSymbol,
SemanticModel = ctx.SemanticModel,
InstructionSet = (TargetInstructionSet)attribute.ConstructorArguments[0].Value!,
Precision = (FloatPrecision)attribute.ConstructorArguments[1].Value!,
Mode = (MathMode)attribute.ConstructorArguments[2].Value!,
};
})
.Collect();
context.RegisterSourceOutput(methodDeclarations, GenerateHPCMethods);
}
// ── Core pipeline ─────────────────────────────────────────────────────
private static void GenerateHPCMethods(
SourceProductionContext context,
ImmutableArray<HPComputeMethodInfo?> array)
{
if (array.IsEmpty) return;
foreach (var info in array)
{
if (info is null) continue;
try
{
GenerateSingleMethod(context, info);
}
catch (Exception ex)
{
// Surface analyzer errors as Roslyn diagnostics so the user
// sees them in the IDE rather than a silent empty output.
context.ReportDiagnostic(Diagnostic.Create(
new DiagnosticDescriptor(
id: "HPC0001",
title: "HPC code generation failed",
messageFormat: "Failed to generate HPC variant for '{0}': {1}",
category: "HPCGenerator",
defaultSeverity: DiagnosticSeverity.Error,
isEnabledByDefault: true),
info.MethodDeclaration.GetLocation(),
info.MethodDeclaration.Identifier.Text,
ex.Message));
}
}
}
private static void GenerateSingleMethod(
SourceProductionContext context,
HPComputeMethodInfo info)
{
// ── Phase 1: Analyse ──────────────────────────────────────────────
var analyzer = new HPCAnalyzer(info.SemanticModel);
var ir = analyzer.Analyze(info.MethodDeclaration, info);
// ── Phase 2: Optimise ─────────────────────────────────────────────
var optimizedIR = ir;
foreach (var pass in GetPasses(info))
optimizedIR = pass.Transform(optimizedIR);
// ── Phase 3: Emit per-target ──────────────────────────────────────
foreach (var (backend, isa) in GetBackends(info.InstructionSet))
{
var methodSource = backend.EmitMethod(optimizedIR);
var fullSource = WrapSource(methodSource, optimizedIR, backend);
context.AddSource(
hintName: $"{ir.ContainingTypeName}_{ir.OriginalName}_{backend.Name}.g.cs",
source: fullSource);
}
}
// ── Backend selection ─────────────────────────────────────────────────
private static IEnumerable<(IHPCBackend backend, TargetInstructionSet isa)>
GetBackends(TargetInstructionSet instructionSet)
{
if (instructionSet.HasFlag(TargetInstructionSet.AVX2))
yield return (s_avx2Backend, TargetInstructionSet.AVX2);
// Future: SSE4, AVX512, NEON — add here without touching anything else
}
// ── Source wrapping ───────────────────────────────────────────────────
private static string WrapSource(
string methodBody,
IR.HPCMethodIR ir,
IHPCBackend backend)
{
var sb = new StringBuilder();
sb.AppendLine("// <auto-generated/>");
sb.AppendLine("#nullable enable");
foreach (var u in backend.RequiredUsings)
sb.AppendLine(u);
sb.AppendLine("using Misaki.HighPerformance.HPC;");
sb.AppendLine();
sb.AppendLine($"namespace {ir.ContainingNamespace}");
sb.AppendLine("{");
sb.AppendLine($" partial class {ir.ContainingTypeName}");
sb.AppendLine(" {");
sb.AppendLine(methodBody);
sb.AppendLine(" }");
sb.AppendLine("}");
return sb.ToString();
}
}
}

View File

@@ -0,0 +1,181 @@
using Misaki.HighPerformance.HPC.Generator.IR;
using System.Linq;
namespace Misaki.HighPerformance.HPC.Generator.IR
{
/// <summary>
/// Base class for passes that transform the IR tree.
/// Uses the <em>immutable rewrite</em> pattern: each Visit method returns
/// either the original node (if nothing changed) or a new <c>with</c>-expression
/// copy, leaving the input tree untouched.
/// </summary>
internal abstract class HPCNodeRewriter
{
// ── Entry points ─────────────────────────────────────────────────────
public virtual HPCExpr RewriteExpr(HPCExpr expr) => expr switch
{
HPCBinaryOp n => RewriteBinaryOp(n),
HPCUnaryOp n => RewriteUnaryOp(n),
HPCIntrinsicCall n => RewriteIntrinsicCall(n),
HPCPassThroughCall n => RewritePassThroughCall(n),
HPCPropertyAccess n => RewritePropertyAccess(n),
HPCVarRef n => RewriteVarRef(n),
HPCLiteral n => RewriteLiteral(n),
_ => expr
};
public virtual HPCStmt RewriteStmt(HPCStmt stmt) => stmt switch
{
HPCVarDecl n => RewriteVarDecl(n),
HPCAssignment n => RewriteAssignment(n),
HPCExprStmt n => RewriteExprStmt(n),
HPCReturn n => RewriteReturn(n),
HPCIf n => RewriteIf(n),
HPCForLoop n => RewriteForLoop(n),
_ => stmt
};
// ── Expression rewrites (override to modify specific patterns) ───────
protected virtual HPCExpr RewriteBinaryOp(HPCBinaryOp node)
{
var left = RewriteExpr(node.Left);
var right = RewriteExpr(node.Right);
return ReferenceEquals(left, node.Left) && ReferenceEquals(right, node.Right)
? node
: node with { Left = left, Right = right };
}
protected virtual HPCExpr RewriteUnaryOp(HPCUnaryOp node)
{
var operand = RewriteExpr(node.Operand);
return ReferenceEquals(operand, node.Operand)
? node
: node with { Operand = operand };
}
protected virtual HPCExpr RewriteIntrinsicCall(HPCIntrinsicCall node)
{
var args = RewriteArgs(node.Args);
return ReferenceEquals(args, node.Args)
? node
: node with { Args = args };
}
protected virtual HPCExpr RewritePassThroughCall(HPCPassThroughCall node)
{
var args = RewriteArgs(node.Args);
return ReferenceEquals(args, node.Args)
? node
: node with { Args = args };
}
protected virtual HPCExpr RewritePropertyAccess(HPCPropertyAccess node)
{
var target = RewriteExpr(node.Target);
return ReferenceEquals(target, node.Target)
? node
: node with { Target = target };
}
protected virtual HPCExpr RewriteVarRef(HPCVarRef node) => node;
protected virtual HPCExpr RewriteLiteral(HPCLiteral node) => node;
// ── Statement rewrites ───────────────────────────────────────────────
protected virtual HPCStmt RewriteVarDecl(HPCVarDecl node)
{
var init = RewriteExpr(node.Initializer);
return ReferenceEquals(init, node.Initializer)
? node
: node with { Initializer = init };
}
protected virtual HPCStmt RewriteAssignment(HPCAssignment node)
{
var value = RewriteExpr(node.Value);
return ReferenceEquals(value, node.Value)
? node
: node with { Value = value };
}
protected virtual HPCStmt RewriteExprStmt(HPCExprStmt node)
{
var expr = RewriteExpr(node.Expression);
return ReferenceEquals(expr, node.Expression)
? node
: node with { Expression = expr };
}
protected virtual HPCStmt RewriteReturn(HPCReturn node)
{
if (node.Value is null) return node;
var value = RewriteExpr(node.Value);
return ReferenceEquals(value, node.Value)
? node
: node with { Value = value };
}
protected virtual HPCStmt RewriteIf(HPCIf node)
{
var cond = RewriteExpr(node.Condition);
var thenBody = RewriteBody(node.ThenBody);
var elseBody = node.ElseBody is null ? null : RewriteBody(node.ElseBody);
return ReferenceEquals(cond, node.Condition) &&
ReferenceEquals(thenBody, node.ThenBody) &&
ReferenceEquals(elseBody, node.ElseBody)
? node
: node with { Condition = cond, ThenBody = thenBody, ElseBody = elseBody };
}
protected virtual HPCStmt RewriteForLoop(HPCForLoop node)
{
var iter = (HPCVarDecl)RewriteVarDecl(node.Iterator);
var cond = RewriteExpr(node.Condition);
var incr = RewriteExpr(node.Increment);
var body = RewriteBody(node.Body);
return ReferenceEquals(iter, node.Iterator) &&
ReferenceEquals(cond, node.Condition) &&
ReferenceEquals(incr, node.Increment) &&
ReferenceEquals(body, node.Body)
? node
: node with { Iterator = iter, Condition = cond, Increment = incr, Body = body };
}
// ── Helpers ──────────────────────────────────────────────────────────
protected HPCStmt[] RewriteBody(HPCStmt[] body)
{
HPCStmt[]? result = null;
for (int i = 0; i < body.Length; i++)
{
var original = body[i];
var rewritten = RewriteStmt(original);
if (!ReferenceEquals(original, rewritten))
{
result ??= body.ToArray();
result[i] = rewritten;
}
}
return result ?? body;
}
private HPCExpr[] RewriteArgs(HPCExpr[] args)
{
HPCExpr[]? result = null;
for (int i = 0; i < args.Length; i++)
{
var original = args[i];
var rewritten = RewriteExpr(original);
if (!ReferenceEquals(original, rewritten))
{
result ??= args.ToArray();
result[i] = rewritten;
}
}
return result ?? args;
}
}
}

View File

@@ -0,0 +1,205 @@
using System;
namespace Misaki.HighPerformance.HPC.Generator.IR
{
// ─────────────────────────────────────────────────────────────────────────
// Type system
// ─────────────────────────────────────────────────────────────────────────
/// <summary>
/// A resolved primitive type inside the HPC IR (always the element scalar type).
/// E.g. "float", "double", "int". Never a vector type name — the IR is
/// element-typecentric; the backend decides the concrete vector width.
/// </summary>
internal sealed record HPCType(string ElementTypeName, bool IsFloatingPoint)
{
// Convenience singletons for the most common cases
public static readonly HPCType Float = new("float", IsFloatingPoint: true);
public static readonly HPCType Double = new("double", IsFloatingPoint: true);
public static readonly HPCType Int = new("int", IsFloatingPoint: false);
public static readonly HPCType UInt = new("uint", IsFloatingPoint: false);
public static readonly HPCType Long = new("long", IsFloatingPoint: false);
public static HPCType FromElementName(string name) => name switch
{
"float" or "Single" => Float,
"double" or "Double" => Double,
"int" or "Int32" => Int,
"uint" or "UInt32" => UInt,
"long" or "Int64" => Long,
_ => new(name, IsFloatingPoint: false)
};
public override string ToString() => ElementTypeName;
}
// ─────────────────────────────────────────────────────────────────────────
// Expression nodes
// ─────────────────────────────────────────────────────────────────────────
/// <summary>Base type for all HPC IR expression nodes.</summary>
internal abstract record HPCExpr(HPCType Type);
/// <summary>A reference to a local variable or parameter by name.</summary>
internal sealed record HPCVarRef(string Name, HPCType Type) : HPCExpr(Type)
{
public override string ToString() => Name;
}
/// <summary>A scalar literal value broadcast to all lanes.</summary>
internal sealed record HPCLiteral(string Value, HPCType Type) : HPCExpr(Type)
{
public override string ToString() => Value;
}
// ── Binary ops ───────────────────────────────────────────────────────────
internal enum HPCBinaryKind
{
Add, Subtract, Multiply, Divide, Modulo,
BitwiseAnd, BitwiseOr, BitwiseXor, ShiftLeft, ShiftRight,
Equal, NotEqual,
LessThan, LessThanOrEqual,
GreaterThan, GreaterThanOrEqual,
}
internal sealed record HPCBinaryOp(
HPCBinaryKind Kind,
HPCExpr Left,
HPCExpr Right,
HPCType Type) : HPCExpr(Type);
// ── Unary ops ────────────────────────────────────────────────────────────
internal enum HPCUnaryKind
{
Negate, BitwiseNot,
// Math functions that map to a single ISPMDLane static method
Abs, Sqrt, Floor, Ceil, Round, Trunc, Frac, Sign, Saturate,
Rcp, Rsqrt,
Sin, Cos, Tan, Asin, Acos, Atan,
Exp, Exp2, Log, Log2,
}
internal sealed record HPCUnaryOp(
HPCUnaryKind Kind,
HPCExpr Operand,
HPCType Type) : HPCExpr(Type);
// ── Intrinsic calls (multi-argument) ────────────────────────────────────
internal enum HPCIntrinsic
{
// Arithmetic
MultiplyAdd, // a * b + c (FMA)
MultiplySubtract, // a * b - c
// Math
Atan2, Pow,
SinCos, // out sin, out cos simultaneously
Lerp, Min, Max, Clamp, Select, CopySign,
// Reduction
ReduceAdd, ReduceMax, ReduceMin,
// Memory
Load, Store,
MaskLoad, MaskStore,
Gather, MaskGather,
Scatter, MaskScatter,
CompressStore,
// Conversion
Cast, BitCast,
}
internal sealed record HPCIntrinsicCall(
HPCIntrinsic Intrinsic,
HPCExpr[] Args,
HPCType Type) : HPCExpr(Type);
/// <summary>
/// A method call that the HPC pipeline does not recognise as an intrinsic —
/// emitted verbatim so user code can still call arbitrary helpers.
/// </summary>
internal sealed record HPCPassThroughCall(
string MethodName,
HPCExpr[] Args,
HPCType Type) : HPCExpr(Type);
/// <summary>Member-property access, e.g. <c>lane.LaneWidth</c> → <c>Count</c>.</summary>
internal sealed record HPCPropertyAccess(
HPCExpr Target,
string PropertyName,
HPCType Type) : HPCExpr(Type);
// ─────────────────────────────────────────────────────────────────────────
// Statement nodes
// ─────────────────────────────────────────────────────────────────────────
/// <summary>Base type for all HPC IR statement nodes.</summary>
internal abstract record HPCStmt;
/// <summary>Local variable declaration with mandatory initialiser.</summary>
internal sealed record HPCVarDecl(
string Name,
HPCType Type,
HPCExpr Initializer) : HPCStmt;
/// <summary>Assignment to an existing variable (including out-parameters).</summary>
internal sealed record HPCAssignment(string Target, HPCExpr Value) : HPCStmt;
/// <summary>Expression evaluated purely for its side-effect (e.g. Store calls).</summary>
internal sealed record HPCExprStmt(HPCExpr Expression) : HPCStmt;
/// <summary>Return statement.</summary>
internal sealed record HPCReturn(HPCExpr? Value) : HPCStmt;
/// <summary>
/// Conditional branch. In a vectorised context the backend may lower this
/// to a predicated Select or a masked execution block.
/// </summary>
internal sealed record HPCIf(
HPCExpr Condition,
HPCStmt[] ThenBody,
HPCStmt[]? ElseBody) : HPCStmt;
/// <summary>Simple counted for-loop.</summary>
internal sealed record HPCForLoop(
HPCVarDecl Iterator,
HPCExpr Condition,
HPCExpr Increment,
HPCStmt[] Body) : HPCStmt;
// ─────────────────────────────────────────────────────────────────────────
// Method-level IR
// ─────────────────────────────────────────────────────────────────────────
internal sealed record HPCParameter(
string Name,
HPCType Type,
bool IsOut,
bool IsRef);
/// <summary>
/// The complete IR representation of a single <c>[HPCompute]</c>-annotated method,
/// fully detached from Roslyn syntax after the analysis phase.
/// </summary>
internal sealed record HPCMethodIR
{
public required string Name { get; init; }
public required HPCType ReturnType { get; init; }
public required HPCParameter[] Parameters { get; init; }
public required HPCStmt[] Body { get; init; }
// Compilation metadata from the attribute
public required TargetInstructionSet TargetISA { get; init; }
public required FloatPrecision Precision { get; init; }
public required MathMode Mode { get; init; }
// Emission metadata
public required string ContainingNamespace { get; init; }
public required string ContainingTypeName { get; init; }
public required string OriginalName { get; init; }
}
}

View File

@@ -0,0 +1,19 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<Nullable>enable</Nullable>
<EnforceExtendedAnalyzerRules>True</EnforceExtendedAnalyzerRules>
<AllowUnsafeBlocks>True</AllowUnsafeBlocks>
<LangVersion>latest</LangVersion>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="5.3.0">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="5.3.0" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,65 @@
using Misaki.HighPerformance.HPC.Generator.IR;
using System.Linq;
namespace Misaki.HighPerformance.HPC.Generator.Optimization
{
/// <summary>
/// Detects <c>(a * b) + c</c> and <c>c + (a * b)</c> binary expression patterns
/// in the IR and replaces them with <see cref="HPCIntrinsic.MultiplyAdd"/> calls,
/// enabling the backend to emit a single FMA instruction instead of two.
///
/// <para>This pass only runs when the target ISA supports FMA (e.g. AVX2 which
/// includes FMA3 via the FMA extension, and AVX-512F). The pipeline checks
/// <see cref="HPCMethodIR.TargetISA"/> before adding the pass.</para>
/// </summary>
internal sealed class FMAFusionPass : HPCNodeRewriter, IHPCOptimizationPass
{
public string Name => "FMA Fusion";
public HPCMethodIR Transform(HPCMethodIR method)
{
var newBody = RewriteBody(method.Body);
return ReferenceEquals(newBody, method.Body)
? method
: method with { Body = newBody };
}
protected override HPCExpr RewriteBinaryOp(HPCBinaryOp node)
{
if (node.Kind == HPCBinaryKind.Add)
{
// (a * b) + c → FMA(a, b, c)
if (node.Left is HPCBinaryOp { Kind: HPCBinaryKind.Multiply } mulLeft)
{
return new HPCIntrinsicCall(
HPCIntrinsic.MultiplyAdd,
[RewriteExpr(mulLeft.Left), RewriteExpr(mulLeft.Right), RewriteExpr(node.Right)],
node.Type);
}
// c + (a * b) → FMA(a, b, c)
if (node.Right is HPCBinaryOp { Kind: HPCBinaryKind.Multiply } mulRight)
{
return new HPCIntrinsicCall(
HPCIntrinsic.MultiplyAdd,
[RewriteExpr(mulRight.Left), RewriteExpr(mulRight.Right), RewriteExpr(node.Left)],
node.Type);
}
}
if (node.Kind == HPCBinaryKind.Subtract)
{
// (a * b) - c → FMA(a, b, -c) ... or MultiplySubtract if available
if (node.Left is HPCBinaryOp { Kind: HPCBinaryKind.Multiply } mulLeft)
{
return new HPCIntrinsicCall(
HPCIntrinsic.MultiplySubtract,
[RewriteExpr(mulLeft.Left), RewriteExpr(mulLeft.Right), RewriteExpr(node.Right)],
node.Type);
}
}
return base.RewriteBinaryOp(node);
}
}
}

View File

@@ -0,0 +1,20 @@
using Misaki.HighPerformance.HPC.Generator.IR;
namespace Misaki.HighPerformance.HPC.Generator.Optimization
{
/// <summary>
/// A single optimization pass that transforms an <see cref="HPCMethodIR"/>
/// into another <see cref="HPCMethodIR"/>. Passes must be pure functions —
/// they must not mutate the input and should return the original object
/// unchanged when nothing was modified (enabling cheap "no-change" detection
/// in the pipeline).
/// </summary>
internal interface IHPCOptimizationPass
{
/// <summary>Human-readable name for diagnostics.</summary>
string Name { get; }
/// <summary>Transforms the IR. Returns the same object if unchanged.</summary>
HPCMethodIR Transform(HPCMethodIR method);
}
}

View File

@@ -0,0 +1,31 @@
// Polyfills needed for modern C# features on netstandard2.0 target.
// These types are built into .NET 5+ runtimes but must be declared manually
// when targeting netstandard2.0, which source generators must do.
namespace System.Runtime.CompilerServices
{
// Enables `init` accessors and `record` types (C# 9)
internal static class IsExternalInit { }
// Enables `required` members (C# 11)
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct |
AttributeTargets.Field | AttributeTargets.Property,
AllowMultiple = false, Inherited = false)]
internal sealed class RequiredMemberAttribute : Attribute { }
[AttributeUsage(AttributeTargets.All, AllowMultiple = true, Inherited = false)]
internal sealed class CompilerFeatureRequiredAttribute : Attribute
{
public CompilerFeatureRequiredAttribute(string featureName)
=> FeatureName = featureName;
public string FeatureName { get; }
public bool IsOptional { get; init; }
}
}
namespace System.Diagnostics.CodeAnalysis
{
// Enables `required` constructor attribution (C# 11)
[AttributeUsage(AttributeTargets.Constructor, AllowMultiple = false, Inherited = false)]
internal sealed class SetsRequiredMembersAttribute : Attribute { }
}

View File

@@ -0,0 +1,233 @@
using Misaki.HighPerformance.HPC.Generator.APIContext;
using System.Text;
namespace Misaki.HighPerformance.HPC.Generator
{
internal static class UtilityTemplate
{
public static Method Sin_Standard<T>(IVectorAPIContext api)
{
var body = api.Return(api.Call($"{api.GetVectorType()}.Sin", "value"));
return new Method(
modifier: "public static",
returnType: api.GetVectorType<T>(),
name: $"Sin_{typeof(T).Name}_Standard",
parameters: new[] { $"{api.GetVectorType<T>()} value" },
body: body);
}
public static Method Sin_Fast<T>(IVectorAPIContext api)
{
var isFloat = typeof(T) == typeof(float);
var typePrefix = isFloat ? "f" : "d";
var input = new Expression(api, "value");
var invPi = api.Create($"0.318309886{typePrefix}").Assign();
var x_sin = input;
var y_sin = api.Multiply(x_sin, invPi).Assign();
var k_sin = api.Round(y_sin).Assign();
var z_sin = api.Subtract(y_sin, k_sin).Assign();
var half = api.Create($"0.5{typePrefix}").Assign();
var two = api.Create($"2.0{typePrefix}").Assign();
var k_even_sin = (api.Round(k_sin * half) * two).Assign();
var sign_sin = (api.One<T>() - two * api.Abs(k_sin - k_even_sin)).Assign();
var c1 = api.Create($"3.14159265{typePrefix}").Assign();
var c3 = api.Create($"-5.16771278{typePrefix}").Assign();
var c5 = api.Create($"2.55016404{typePrefix}").Assign();
var c7 = api.Create($"-0.59926453{typePrefix}").Assign();
var c9 = api.Create($"0.08214589{typePrefix}").Assign();
var z2_sin = (z_sin * z_sin).Assign();
var poly_sin = api.MultiplyAdd(z2_sin, c9, c7).Assign();
var poly_sin_name = api.LastAssignedVariable;
poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c5).Assign(poly_sin_name, false);
poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c3).Assign(poly_sin_name, false);
poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c1).Assign(poly_sin_name, false);
poly_sin = api.Multiply(z_sin, poly_sin).Assign(poly_sin_name, false);
var body = api.Return(poly_sin * sign_sin);
return new Method(
modifier: "public static",
returnType: api.GetVectorType<T>(),
name: $"Sin_{typeof(T).Name}_Fast",
parameters: new[] { $"{api.GetVectorType<T>()} {input.Code}" },
body: body);
}
public static Method Cos_Standard<T>(IVectorAPIContext api)
{
var body = api.Return(api.Call($"{api.GetVectorType()}.Cos", "value"));
return new Method(
modifier: "public static",
returnType: api.GetVectorType<T>(),
name: $"Cos_{typeof(T).Name}_Standard",
parameters: new[] { $"{api.GetVectorType<T>()} value" },
body: body);
}
public static Method Cos_Fast<T>(IVectorAPIContext api)
{
var isFloat = typeof(T) == typeof(float);
var typePrefix = isFloat ? "f" : "d";
var input = new Expression(api, "value");
var halfPi = api.Create($"1.570796327{typePrefix}").Assign();
var invPi = api.Create($"0.318309886{typePrefix}").Assign();
var x_cos = api.Add(input, halfPi).Assign();
var y_cos = api.Multiply(x_cos, invPi).Assign();
var k_cos = api.Round(y_cos).Assign();
var z_cos = api.Subtract(y_cos, k_cos).Assign();
var half = api.Create($"0.5{typePrefix}").Assign();
var two = api.Create($"2.0{typePrefix}").Assign();
var k_even_cos = api.Multiply(api.Round(api.Multiply(k_cos, half)), two).Assign();
var sign_cos = api.Subtract(api.One<T>(), api.Multiply(two, api.Abs(api.Subtract(k_cos, k_even_cos)))).Assign();
var c1 = api.Create($"3.14159265{typePrefix}").Assign();
var c3 = api.Create($"-5.16771278{typePrefix}").Assign();
var c5 = api.Create($"2.55016404{typePrefix}").Assign();
var c7 = api.Create($"-0.59926453{typePrefix}").Assign();
var c9 = api.Create($"0.08214589{typePrefix}").Assign();
var z2_cos = api.Multiply(z_cos, z_cos).Assign();
var poly_cos = api.MultiplyAdd(z2_cos, c9, c7).Assign();
var poly_cos_name = api.LastAssignedVariable;
poly_cos = api.MultiplyAdd(z2_cos, poly_cos, c5).Assign(poly_cos_name, false);
poly_cos = api.MultiplyAdd(z2_cos, poly_cos, c3).Assign(poly_cos_name, false);
poly_cos = api.MultiplyAdd(z2_cos, poly_cos, c1).Assign(poly_cos_name, false);
poly_cos = api.Multiply(z_cos, poly_cos).Assign(poly_cos_name, false);
var body = api.Return(poly_cos * sign_cos);
return new Method(
modifier: "public static",
returnType: api.GetVectorType<T>(),
name: $"Cos_{typeof(T).Name}_Fast",
parameters: new[] { $"{api.GetVectorType<T>()} {input.Code}" },
body: body);
}
public static Method SinCos_Standard<T>(IVectorAPIContext api)
{
var sin_cos = api.Return(api.Call($"{api.GetVectorType()}.SinCos", "value"));
return new Method(
modifier: "public static",
returnType: "void",
name: $"SinCos_{typeof(T).Name}_Standard",
parameters: new[] { $"{api.GetVectorType<T>()} value", $"out {api.GetVectorType<T>()} sin", $"out {api.GetVectorType<T>()} cos" },
body: sin_cos);
}
public static Method SinCos_Fast<T>(IVectorAPIContext api)
{
var isFloat = typeof(T) == typeof(float);
var typePrefix = isFloat ? "f" : "d";
var input = new Expression(api, "value");
var sinOut = new Expression(api, "sin");
var cosOut = new Expression(api, "cos");
var halfPi = api.Create($"1.570796327{typePrefix}").Assign();
var invPi = api.Create($"0.318309886{typePrefix}").Assign();
var x_sin = input;
var x_cos = api.Add(x_sin, halfPi).Assign();
var y_sin = api.Multiply(x_sin, invPi).Assign();
var y_cos = api.Multiply(x_cos, invPi).Assign();
var k_sin = api.Round(y_sin).Assign();
var k_cos = api.Round(y_cos).Assign();
var z_sin = api.Subtract(y_sin, k_sin).Assign();
var z_cos = api.Subtract(y_cos, k_cos).Assign();
var half = api.Create($"0.5{typePrefix}").Assign();
var two = api.Create($"2.0{typePrefix}").Assign();
var one = api.One<T>();
var k_even_sin = api.Multiply(api.Round(api.Multiply(k_sin, half)), two).Assign();
var sign_sin = api.Subtract(one, api.Multiply(two, api.Abs(api.Subtract(k_sin, k_even_sin)))).Assign();
var k_even_cos = api.Multiply(api.Round(api.Multiply(k_cos, half)), two).Assign();
var sign_cos = api.Subtract(one, api.Multiply(two, api.Abs(api.Subtract(k_cos, k_even_cos)))).Assign();
var c1 = api.Create($"3.14159265{typePrefix}").Assign();
var c3 = api.Create($"-5.16771278{typePrefix}").Assign();
var c5 = api.Create($"2.55016404{typePrefix}").Assign();
var c7 = api.Create($"-0.59926453{typePrefix}").Assign();
var c9 = api.Create($"0.08214589{typePrefix}").Assign();
var z2_sin = api.Multiply(z_sin, z_sin).Assign();
var poly_sin = api.MultiplyAdd(z2_sin, c9, c7).Assign();
var poly_sin_name = api.LastAssignedVariable;
poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c5).Assign(poly_sin_name, false);
poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c3).Assign(poly_sin_name, false);
poly_sin = api.MultiplyAdd(z2_sin, poly_sin, c1).Assign(poly_sin_name, false);
poly_sin = api.Multiply(z_sin, poly_sin).Assign(poly_sin_name, false);
var z2_cos = api.Multiply(z_cos, z_cos).Assign();
var poly_cos = api.MultiplyAdd(z2_cos, c9, c7).Assign();
var poly_cos_name = api.LastAssignedVariable;
poly_cos = api.MultiplyAdd(z2_cos, poly_cos, c5).Assign(poly_cos_name, false);
poly_cos = api.MultiplyAdd(z2_cos, poly_cos, c3).Assign(poly_cos_name, false);
poly_cos = api.MultiplyAdd(z2_cos, poly_cos, c1).Assign(poly_cos_name, false);
poly_cos = api.Multiply(z_cos, poly_cos).Assign(poly_cos_name, false);
sinOut = api.Multiply(poly_sin, sign_sin).Assign(sinOut.Code, false);
cosOut = api.Multiply(poly_cos, sign_cos).Assign(cosOut.Code, false);
var body = api.Return(api.Create(""));
return new Method(
modifier: "public static",
returnType: "void",
name: $"SinCos_{typeof(T).Name}_Fast",
parameters: new[] { $"{api.GetVectorType<T>()} {input.Code}", $"out {api.GetVectorType<T>()} {sinOut.Code}", $"out {api.GetVectorType<T>()} {cosOut.Code}" },
body: body);
}
public static string GenerateSinCosUtilityMethods(IVectorAPIContext api, string identation)
{
var methods = new Method[]
{
Sin_Standard<float>(api),
Sin_Fast<float>(api),
Cos_Standard<float>(api),
Cos_Fast<float>(api),
SinCos_Standard<float>(api),
SinCos_Fast<float>(api),
Sin_Standard<double>(api),
Sin_Fast<double>(api),
Cos_Standard<double>(api),
Cos_Fast<double>(api),
SinCos_Standard<double>(api),
SinCos_Fast<double>(api)
};
var sb = new StringBuilder();
var inlineAttr = identation + "[MethodImpl(MethodImplOptions.AggressiveInlining)]";
foreach (var method in methods)
{
sb.AppendLine(inlineAttr);
sb.AppendLine(method.GetFullCode(identation));
}
return sb.ToString();
}
}
}

View File

@@ -1,7 +1,7 @@
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using System.Numerics; using System.Numerics;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
/// <summary> /// <summary>
/// Common marker interface for SPMD lane types. /// Common marker interface for SPMD lane types.
@@ -100,7 +100,7 @@ public unsafe interface ISPMDLane<TSelf, TNumber> : ISPMDLane, IEquatable<TSelf>
/// <param name="step">The step value for the sequence.</param> /// <param name="step">The step value for the sequence.</param>
/// <returns>The lane value containing the arithmetic sequence.</returns> /// <returns>The lane value containing the arithmetic sequence.</returns>
/// <remarks> /// <remarks>
/// Implementations may rely on vector creation helpers and assume that the resulting sequence length matches <see cref="ISPMDLane.LaneWidth"/>. /// Implementations may rely on vector creation helpers and assume that the resulting sequence length matches <see cref="LaneWidth"/>.
/// </remarks> /// </remarks>
static abstract TSelf Sequence(TNumber start, TNumber step); static abstract TSelf Sequence(TNumber start, TNumber step);
/// <summary> /// <summary>
@@ -217,70 +217,17 @@ public unsafe interface ISPMDLane<TSelf, TNumber> : ISPMDLane, IEquatable<TSelf>
/// Implementations may use hardware-specific shuffle tables to reorder the selected lanes before storing, falling back to a scalar loop otherwise. /// Implementations may use hardware-specific shuffle tables to reorder the selected lanes before storing, falling back to a scalar loop otherwise.
/// </remarks> /// </remarks>
int CompressStore(TNumber* pDestination, TSelf mask); int CompressStore(TNumber* pDestination, TSelf mask);
/// <summary>
/// Masks the lane value with the specified mask and stores the result to the given reference, where masked lanes are stored and unmasked lanes are left unchanged in the destination.
/// </summary>
/// <param name="pDestination">A pointer to the variable where the masked data will be stored.</param>
/// <param name="mask">A mask value that determines which elements are included in the masking operation.</param>
void MaskStore(TNumber* pDestination, TSelf mask); void MaskStore(TNumber* pDestination, TSelf mask);
/// <summary>
/// Masks the lane value with the specified mask and stores the result to the given reference, where masked lanes are stored and unmasked lanes are left unchanged in the destination.
/// </summary>
/// <param name="destination">A reference to the variable where the masked data will be stored.</param>
/// <param name="mask">A mask value that determines which elements are included in the masking operation.</param>
void MaskStore(ref TNumber destination, TSelf mask); void MaskStore(ref TNumber destination, TSelf mask);
/// <summary>
/// Scatters the lane value to the specified base address and indices, where each lane is stored to the address computed by adding the corresponding index (multiplied by the scale) to the base address.
/// </summary>
/// <param name="pDst">A pointer to the base address where the data will be scattered.</param>
/// <param name="indices">A vector of indices that determine the destinations of each lane.</param>
void Scatter(TNumber* pDst, TSelf indices); void Scatter(TNumber* pDst, TSelf indices);
/// <summary>
/// Scatters the lane value to the specified base address and indices, where each lane is stored to the address computed by adding the corresponding index (multiplied by the scale) to the base address.
/// </summary>
/// <param name="destination">A reference to the variable where the scattered data will be stored.</param>
/// <param name="indices">A vector of indices that determine the destinations of each lane.</param>
void Scatter(ref TNumber destination, TSelf indices); void Scatter(ref TNumber destination, TSelf indices);
/// <summary>
/// Scatters the lane value to the specified base address and indices, where each lane is stored to the address computed by adding the corresponding index (multiplied by the scale) to the base address.
/// </summary>
/// <param name="pDst">A pointer to the base address where the data will be scattered.</param>
/// <param name="pIndices">A pointer to the array of indices that determine the destinations of each lane.</param>
void Scatter(TNumber* pDst, int* pIndices); void Scatter(TNumber* pDst, int* pIndices);
/// <summary>
/// Scatters the lane value to the specified base address and indices, where each lane is stored to the address computed by adding the corresponding index (multiplied by the scale) to the base address.
/// </summary>
/// <param name="destination">A reference to the variable where the scattered data will be stored.</param>
/// <param name="pIndices">A pointer to the array of indices that determine the destinations of each lane.</param>
void Scatter(ref TNumber destination, int* pIndices); void Scatter(ref TNumber destination, int* pIndices);
/// <summary>
/// Masks the lane value with the specified mask and scatters the result to the given base address and indices, where masked lanes are stored to the address computed by adding the corresponding index (multiplied by the scale) to the base address, and unmasked lanes are left unchanged in the destination.
/// </summary>
/// <param name="pDst">A pointer to the base address where the data will be scattered.</param>
/// <param name="indices">A vector of indices that determine the destinations of each lane.</param>
/// <param name="mask">A vector of boolean values that determine which lanes to scatter.</param>
void MaskScatter(TNumber* pDst, TSelf indices, TSelf mask); void MaskScatter(TNumber* pDst, TSelf indices, TSelf mask);
/// <summary>
/// Masks the lane value with the specified mask and scatters the result to the given base address and indices, where masked lanes are stored to the address computed by adding the corresponding index (multiplied by the scale) to the base address, and unmasked lanes are left unchanged in the destination.
/// </summary>
/// <param name="destination">A reference to the variable where the scattered data will be stored.</param>
/// <param name="indices">A vector of indices that determine the destinations of each lane.</param>
/// <param name="mask">A vector of boolean values that determine which lanes to scatter.</param>
void MaskScatter(ref TNumber destination, TSelf indices, TSelf mask); void MaskScatter(ref TNumber destination, TSelf indices, TSelf mask);
/// <summary>
/// Masks the lane value with the specified mask and scatters the result to the given base address and indices, where masked lanes are stored to the address computed by adding the corresponding index (multiplied by the scale) to the base address, and unmasked lanes are left unchanged in the destination.
/// </summary>
/// <param name="pDst">A pointer to the base address where the data will be scattered.</param>
/// <param name="pIndices">A pointer to the array of indices that determine the destinations of each lane.</param>
/// <param name="mask">A vector of boolean values that determine which lanes to scatter.</param>
void MaskScatter(TNumber* pDst, int* pIndices, TSelf mask); void MaskScatter(TNumber* pDst, int* pIndices, TSelf mask);
/// <summary>
/// Masks the lane value with the specified mask and scatters the result to the given base address and indices, where masked lanes are stored to the address computed by adding the corresponding index (multiplied by the scale) to the base address, and unmasked lanes are left unchanged in the destination.
/// </summary>
/// <param name="destination">A reference to the variable where the scattered data will be stored.</param>
/// <param name="pIndices">A pointer to the array of indices that determine the destinations of each lane.</param>
/// <param name="mask">A vector of boolean values that determine which lanes to scatter.</param>
void MaskScatter(ref TNumber destination, int* pIndices, TSelf mask); void MaskScatter(ref TNumber destination, int* pIndices, TSelf mask);
/// <summary> /// <summary>
@@ -469,14 +416,14 @@ public unsafe interface ISPMDLane<TSelf, TNumber> : ISPMDLane, IEquatable<TSelf>
/// <summary> /// <summary>
/// Computes a * b + c element-wise. /// Computes a * b + c element-wise.
/// </summary> /// </summary>
/// <param name="a">The first multiplier.</param> /// <param name="left">The first multiplier.</param>
/// <param name="b">The second multiplier.</param> /// <param name="right">The second multiplier.</param>
/// <param name="c">The addend.</param> /// <param name="addend">The addend.</param>
/// <returns>The result of the fused multiply-add operation.</returns> /// <returns>The result of the fused multiply-add operation.</returns>
/// <remarks> /// <remarks>
/// Float and double implementations should use fused multiply-add instructions when available for both accuracy and performance. /// Float and double implementations should use fused multiply-add instructions when available for both accuracy and performance.
/// </remarks> /// </remarks>
static abstract TSelf MultiplyAdd(TSelf a, TSelf b, TSelf c); static abstract TSelf MultiplyAdd(TSelf left, TSelf right, TSelf addend);
/// <summary> /// <summary>
/// Returns the minimum of the two lane values element-wise. /// Returns the minimum of the two lane values element-wise.
/// </summary> /// </summary>
@@ -644,6 +591,9 @@ public unsafe interface ISPMDLane<TSelf, TNumber> : ISPMDLane, IEquatable<TSelf>
/// </summary> /// </summary>
/// <param name="x">The lane value.</param> /// <param name="x">The lane value.</param>
/// <returns>The truncated lane value.</returns> /// <returns>The truncated lane value.</returns>
/// <remarks>
/// Floating-point truncation typically maps to <see cref="Vector.Truncate(Vector{TNumber})"/>.
/// </remarks>
static abstract TSelf Trunc(TSelf value); static abstract TSelf Trunc(TSelf value);
/// <summary> /// <summary>
/// Returns the sign of each lane element. /// Returns the sign of each lane element.

View File

@@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk"> <Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup> <PropertyGroup>
<TargetFramework>net10.0</TargetFramework> <TargetFramework>net10.0</TargetFramework>
@@ -7,7 +7,7 @@
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> <GeneratePackageOnBuild>true</GeneratePackageOnBuild>
<Authors>Misaki</Authors> <Authors>Misaki</Authors>
<AssemblyVersion>1.3.8</AssemblyVersion> <AssemblyVersion>1.3.7</AssemblyVersion>
<Version>$(AssemblyVersion)</Version> <Version>$(AssemblyVersion)</Version>
<PackageProjectUrl>https://git.personalnas.com/Misaki/Misaki.HighPerformance.git</PackageProjectUrl> <PackageProjectUrl>https://git.personalnas.com/Misaki/Misaki.HighPerformance.git</PackageProjectUrl>
<RepositoryUrl>https://git.personalnas.com/Misaki/Misaki.HighPerformance.git</RepositoryUrl> <RepositoryUrl>https://git.personalnas.com/Misaki/Misaki.HighPerformance.git</RepositoryUrl>
@@ -18,7 +18,7 @@
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
<None Include="../../LICENSE" Pack="true" PackagePath=""/> <None Include="../LICENSE" Pack="true" PackagePath=""/>
<None Include="README.md" Pack="true" PackagePath=""/> <None Include="README.md" Pack="true" PackagePath=""/>
</ItemGroup> </ItemGroup>
@@ -35,7 +35,7 @@
<ItemGroup> <ItemGroup>
<Content Include="**\*.cs" Exclude="obj\**;bin\**"> <Content Include="**\*.cs" Exclude="obj\**;bin\**">
<Pack>true</Pack> <Pack>true</Pack>
<PackagePath>contentFiles\cs\any\Misaki.HighPerformance.Mathematics.SPMD\</PackagePath> <PackagePath>contentFiles\cs\any\Misaki.HighPerformance.HPC\</PackagePath>
<PackageCopyToOutput>false</PackageCopyToOutput> <PackageCopyToOutput>false</PackageCopyToOutput>
<BuildAction>Compile</BuildAction> <BuildAction>Compile</BuildAction>
</Content> </Content>

View File

@@ -2,7 +2,7 @@ using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics; using System.Runtime.Intrinsics;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
internal static unsafe class SPMDUtility internal static unsafe class SPMDUtility
{ {

View File

@@ -3,7 +3,7 @@ using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
[StructLayout(LayoutKind.Sequential)] [StructLayout(LayoutKind.Sequential)]
public readonly unsafe struct ScalarLane<TNumber> : ISPMDLane<ScalarLane<TNumber>, TNumber> public readonly unsafe struct ScalarLane<TNumber> : ISPMDLane<ScalarLane<TNumber>, TNumber>

View File

@@ -1,6 +1,6 @@
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
internal static unsafe class ShuffleTableGenerator internal static unsafe class ShuffleTableGenerator
{ {

View File

@@ -1,7 +1,7 @@
using Misaki.HighPerformance.Jobs; using Misaki.HighPerformance.Jobs;
using System.Numerics; using System.Numerics;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
/// <summary> /// <summary>
/// A job interface for Single Program Multiple Data (SPMD) execution, allowing for efficient parallel processing of data across multiple lanes. /// A job interface for Single Program Multiple Data (SPMD) execution, allowing for efficient parallel processing of data across multiple lanes.
@@ -498,6 +498,41 @@ internal struct SPMDScalerJobWrapper<T, TNumber0, TNumber1, TNumber2, TNumber3,
public static class IJobParallelForSPMDExtensions public static class IJobParallelForSPMDExtensions
{ {
/// <summary>
/// Run the SPMD job with the specified total count and job execution context directly on the calling thread.
/// </summary>
/// <remarks>
/// Always use TNumber0 as the primary type for determining lane width and job scheduling, even if it's not used in the job execution.
/// </remarks>
/// <typeparam name="TNumber0">The first numeric type used in the SPMD job.</typeparam>
/// <param name="job">The SPMD job to run.</param>
/// <param name="totalIteration">The total number of iterations to execute across all lanes.</param>
/// <param name="ctx">The job execution context providing information about the current execution environment.</param>
public static void Run<T, TNumber0>(this T job, int totalIteration, ref readonly JobExecutionContext ctx)
where T : IJobSPMD<TNumber0>
where TNumber0 : unmanaged, INumber<TNumber0>, IBinaryNumber<TNumber0>, IMinMaxValue<TNumber0>, IBitwiseOperators<TNumber0, TNumber0, TNumber0>
{
if (WideLane.IsSupported)
{
var iterations = (totalIteration + WideLane<TNumber0>.LaneWidth - 1) / WideLane<TNumber0>.LaneWidth;
for (var loopIndex = 0; loopIndex < iterations; loopIndex++)
{
var baseIndex = loopIndex * WideLane<TNumber0>.LaneWidth;
var indices = WideLane<TNumber0>.Sequence(TNumber0.CreateTruncating(baseIndex), TNumber0.One);
var mask = indices < TNumber0.CreateTruncating(totalIteration);
job.Execute<WideLane<TNumber0>>(indices, mask, in ctx);
}
}
else
{
for (var loopIndex = 0; loopIndex < totalIteration; loopIndex++)
{
job.Execute<ScalarLane<TNumber0>>(TNumber0.CreateTruncating(loopIndex), ScalarLane<TNumber0>.AllBitsSet, in ctx);
}
}
}
/// <summary> /// <summary>
/// Schedule the SPMD job for parallel execution across multiple threads, with the specified total count, batch size, and job execution context. /// Schedule the SPMD job for parallel execution across multiple threads, with the specified total count, batch size, and job execution context.
/// </summary> /// </summary>
@@ -539,6 +574,43 @@ public static class IJobParallelForSPMDExtensions
} }
} }
/// <summary>
/// Run the SPMD job with the specified total count and job execution context directly on the calling thread.
/// </summary>
/// <remarks>
/// Always use TNumber0 as the primary type for determining lane width and job scheduling, even if it's not used in the job execution.
/// </remarks>
/// <typeparam name="TNumber0">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber1">The first numeric type used in the SPMD job.</typeparam>
/// <param name="job">The SPMD job to run.</param>
/// <param name="totalIteration">The total number of iterations to execute across all lanes.</param>
/// <param name="ctx">The job execution context providing information about the current execution environment.</param>
public static void Run<T, TNumber0, TNumber1>(this T job, int totalIteration, ref readonly JobExecutionContext ctx)
where T : IJobSPMD<TNumber0, TNumber1>
where TNumber0 : unmanaged, INumber<TNumber0>, IBinaryNumber<TNumber0>, IMinMaxValue<TNumber0>, IBitwiseOperators<TNumber0, TNumber0, TNumber0>
where TNumber1 : unmanaged, INumber<TNumber1>, IBinaryNumber<TNumber1>, IMinMaxValue<TNumber1>, IBitwiseOperators<TNumber1, TNumber1, TNumber1>
{
if (WideLane.IsSupported)
{
var iterations = (totalIteration + WideLane<TNumber0>.LaneWidth - 1) / WideLane<TNumber0>.LaneWidth;
for (var loopIndex = 0; loopIndex < iterations; loopIndex++)
{
var baseIndex = loopIndex * WideLane<TNumber0>.LaneWidth;
var indices = WideLane<TNumber0>.Sequence(TNumber0.CreateTruncating(baseIndex), TNumber0.One);
var mask = indices < TNumber0.CreateTruncating(totalIteration);
job.Execute<WideLane<TNumber0>, WideLane<TNumber1>>(indices, mask, in ctx);
}
}
else
{
for (var loopIndex = 0; loopIndex < totalIteration; loopIndex++)
{
job.Execute<ScalarLane<TNumber0>, ScalarLane<TNumber1>>(TNumber0.CreateTruncating(loopIndex), ScalarLane<TNumber0>.AllBitsSet, in ctx);
}
}
}
/// <summary> /// <summary>
/// Schedule the SPMD job for parallel execution across multiple threads, with the specified total count, batch size, and job execution context. /// Schedule the SPMD job for parallel execution across multiple threads, with the specified total count, batch size, and job execution context.
/// </summary> /// </summary>
@@ -582,6 +654,45 @@ public static class IJobParallelForSPMDExtensions
} }
} }
/// <summary>
/// Run the SPMD job with the specified total count and job execution context directly on the calling thread.
/// </summary>
/// <remarks>
/// Always use TNumber0 as the primary type for determining lane width and job scheduling, even if it's not used in the job execution.
/// </remarks>
/// <typeparam name="TNumber0">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber1">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber2">The first numeric type used in the SPMD job.</typeparam>
/// <param name="job">The SPMD job to run.</param>
/// <param name="totalIteration">The total number of iterations to execute across all lanes.</param>
/// <param name="ctx">The job execution context providing information about the current execution environment.</param>
public static void Run<T, TNumber0, TNumber1, TNumber2>(this T job, int totalIteration, ref readonly JobExecutionContext ctx)
where T : IJobSPMD<TNumber0, TNumber1, TNumber2>
where TNumber0 : unmanaged, INumber<TNumber0>, IBinaryNumber<TNumber0>, IMinMaxValue<TNumber0>, IBitwiseOperators<TNumber0, TNumber0, TNumber0>
where TNumber1 : unmanaged, INumber<TNumber1>, IBinaryNumber<TNumber1>, IMinMaxValue<TNumber1>, IBitwiseOperators<TNumber1, TNumber1, TNumber1>
where TNumber2 : unmanaged, INumber<TNumber2>, IBinaryNumber<TNumber2>, IMinMaxValue<TNumber2>, IBitwiseOperators<TNumber2, TNumber2, TNumber2>
{
if (WideLane.IsSupported)
{
var iterations = (totalIteration + WideLane<TNumber0>.LaneWidth - 1) / WideLane<TNumber0>.LaneWidth;
for (var loopIndex = 0; loopIndex < iterations; loopIndex++)
{
var baseIndex = loopIndex * WideLane<TNumber0>.LaneWidth;
var indices = WideLane<TNumber0>.Sequence(TNumber0.CreateTruncating(baseIndex), TNumber0.One);
var mask = indices < TNumber0.CreateTruncating(totalIteration);
job.Execute<WideLane<TNumber0>, WideLane<TNumber1>, WideLane<TNumber2>>(indices, mask, in ctx);
}
}
else
{
for (var loopIndex = 0; loopIndex < totalIteration; loopIndex++)
{
job.Execute<ScalarLane<TNumber0>, ScalarLane<TNumber1>, ScalarLane<TNumber2>>(TNumber0.CreateTruncating(loopIndex), ScalarLane<TNumber0>.AllBitsSet, in ctx);
}
}
}
/// <summary> /// <summary>
/// Schedule the SPMD job for parallel execution across multiple threads, with the specified total count, batch size, and job execution context. /// Schedule the SPMD job for parallel execution across multiple threads, with the specified total count, batch size, and job execution context.
/// </summary> /// </summary>
@@ -627,6 +738,47 @@ public static class IJobParallelForSPMDExtensions
} }
} }
/// <summary>
/// Run the SPMD job with the specified total count and job execution context directly on the calling thread.
/// </summary>
/// <remarks>
/// Always use TNumber0 as the primary type for determining lane width and job scheduling, even if it's not used in the job execution.
/// </remarks>
/// <typeparam name="TNumber0">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber1">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber2">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber3">The first numeric type used in the SPMD job.</typeparam>
/// <param name="job">The SPMD job to run.</param>
/// <param name="totalIteration">The total number of iterations to execute across all lanes.</param>
/// <param name="ctx">The job execution context providing information about the current execution environment.</param>
public static void Run<T, TNumber0, TNumber1, TNumber2, TNumber3>(this T job, int totalIteration, ref readonly JobExecutionContext ctx)
where T : IJobSPMD<TNumber0, TNumber1, TNumber2, TNumber3>
where TNumber0 : unmanaged, INumber<TNumber0>, IBinaryNumber<TNumber0>, IMinMaxValue<TNumber0>, IBitwiseOperators<TNumber0, TNumber0, TNumber0>
where TNumber1 : unmanaged, INumber<TNumber1>, IBinaryNumber<TNumber1>, IMinMaxValue<TNumber1>, IBitwiseOperators<TNumber1, TNumber1, TNumber1>
where TNumber2 : unmanaged, INumber<TNumber2>, IBinaryNumber<TNumber2>, IMinMaxValue<TNumber2>, IBitwiseOperators<TNumber2, TNumber2, TNumber2>
where TNumber3 : unmanaged, INumber<TNumber3>, IBinaryNumber<TNumber3>, IMinMaxValue<TNumber3>, IBitwiseOperators<TNumber3, TNumber3, TNumber3>
{
if (WideLane.IsSupported)
{
var iterations = (totalIteration + WideLane<TNumber0>.LaneWidth - 1) / WideLane<TNumber0>.LaneWidth;
for (var loopIndex = 0; loopIndex < iterations; loopIndex++)
{
var baseIndex = loopIndex * WideLane<TNumber0>.LaneWidth;
var indices = WideLane<TNumber0>.Sequence(TNumber0.CreateTruncating(baseIndex), TNumber0.One);
var mask = indices < TNumber0.CreateTruncating(totalIteration);
job.Execute<WideLane<TNumber0>, WideLane<TNumber1>, WideLane<TNumber2>, WideLane<TNumber3>>(indices, mask, in ctx);
}
}
else
{
for (var loopIndex = 0; loopIndex < totalIteration; loopIndex++)
{
job.Execute<ScalarLane<TNumber0>, ScalarLane<TNumber1>, ScalarLane<TNumber2>, ScalarLane<TNumber3>>(TNumber0.CreateTruncating(loopIndex), ScalarLane<TNumber0>.AllBitsSet, in ctx);
}
}
}
/// <summary> /// <summary>
/// Schedule the SPMD job for parallel execution across multiple threads, with the specified total count, batch size, and job execution context. /// Schedule the SPMD job for parallel execution across multiple threads, with the specified total count, batch size, and job execution context.
/// </summary> /// </summary>
@@ -674,6 +826,49 @@ public static class IJobParallelForSPMDExtensions
} }
} }
/// <summary>
/// Run the SPMD job with the specified total count and job execution context directly on the calling thread.
/// </summary>
/// <remarks>
/// Always use TNumber0 as the primary type for determining lane width and job scheduling, even if it's not used in the job execution.
/// </remarks>
/// <typeparam name="TNumber0">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber1">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber2">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber3">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber4">The first numeric type used in the SPMD job.</typeparam>
/// <param name="job">The SPMD job to run.</param>
/// <param name="totalIteration">The total number of iterations to execute across all lanes.</param>
/// <param name="ctx">The job execution context providing information about the current execution environment.</param>
public static void Run<T, TNumber0, TNumber1, TNumber2, TNumber3, TNumber4>(this T job, int totalIteration, ref readonly JobExecutionContext ctx)
where T : IJobSPMD<TNumber0, TNumber1, TNumber2, TNumber3, TNumber4>
where TNumber0 : unmanaged, INumber<TNumber0>, IBinaryNumber<TNumber0>, IMinMaxValue<TNumber0>, IBitwiseOperators<TNumber0, TNumber0, TNumber0>
where TNumber1 : unmanaged, INumber<TNumber1>, IBinaryNumber<TNumber1>, IMinMaxValue<TNumber1>, IBitwiseOperators<TNumber1, TNumber1, TNumber1>
where TNumber2 : unmanaged, INumber<TNumber2>, IBinaryNumber<TNumber2>, IMinMaxValue<TNumber2>, IBitwiseOperators<TNumber2, TNumber2, TNumber2>
where TNumber3 : unmanaged, INumber<TNumber3>, IBinaryNumber<TNumber3>, IMinMaxValue<TNumber3>, IBitwiseOperators<TNumber3, TNumber3, TNumber3>
where TNumber4 : unmanaged, INumber<TNumber4>, IBinaryNumber<TNumber4>, IMinMaxValue<TNumber4>, IBitwiseOperators<TNumber4, TNumber4, TNumber4>
{
if (WideLane.IsSupported)
{
var iterations = (totalIteration + WideLane<TNumber0>.LaneWidth - 1) / WideLane<TNumber0>.LaneWidth;
for (var loopIndex = 0; loopIndex < iterations; loopIndex++)
{
var baseIndex = loopIndex * WideLane<TNumber0>.LaneWidth;
var indices = WideLane<TNumber0>.Sequence(TNumber0.CreateTruncating(baseIndex), TNumber0.One);
var mask = indices < TNumber0.CreateTruncating(totalIteration);
job.Execute<WideLane<TNumber0>, WideLane<TNumber1>, WideLane<TNumber2>, WideLane<TNumber3>, WideLane<TNumber4>>(indices, mask, in ctx);
}
}
else
{
for (var loopIndex = 0; loopIndex < totalIteration; loopIndex++)
{
job.Execute<ScalarLane<TNumber0>, ScalarLane<TNumber1>, ScalarLane<TNumber2>, ScalarLane<TNumber3>, ScalarLane<TNumber4>>(TNumber0.CreateTruncating(loopIndex), ScalarLane<TNumber0>.AllBitsSet, in ctx);
}
}
}
/// <summary> /// <summary>
/// Schedule the SPMD job for parallel execution across multiple threads, with the specified total count, batch size, and job execution context. /// Schedule the SPMD job for parallel execution across multiple threads, with the specified total count, batch size, and job execution context.
/// </summary> /// </summary>
@@ -723,6 +918,51 @@ public static class IJobParallelForSPMDExtensions
} }
} }
/// <summary>
/// Run the SPMD job with the specified total count and job execution context directly on the calling thread.
/// </summary>
/// <remarks>
/// Always use TNumber0 as the primary type for determining lane width and job scheduling, even if it's not used in the job execution.
/// </remarks>
/// <typeparam name="TNumber0">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber1">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber2">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber3">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber4">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber5">The first numeric type used in the SPMD job.</typeparam>
/// <param name="job">The SPMD job to run.</param>
/// <param name="totalIteration">The total number of iterations to execute across all lanes.</param>
/// <param name="ctx">The job execution context providing information about the current execution environment.</param>
public static void Run<T, TNumber0, TNumber1, TNumber2, TNumber3, TNumber4, TNumber5>(this T job, int totalIteration, ref readonly JobExecutionContext ctx)
where T : IJobSPMD<TNumber0, TNumber1, TNumber2, TNumber3, TNumber4, TNumber5>
where TNumber0 : unmanaged, INumber<TNumber0>, IBinaryNumber<TNumber0>, IMinMaxValue<TNumber0>, IBitwiseOperators<TNumber0, TNumber0, TNumber0>
where TNumber1 : unmanaged, INumber<TNumber1>, IBinaryNumber<TNumber1>, IMinMaxValue<TNumber1>, IBitwiseOperators<TNumber1, TNumber1, TNumber1>
where TNumber2 : unmanaged, INumber<TNumber2>, IBinaryNumber<TNumber2>, IMinMaxValue<TNumber2>, IBitwiseOperators<TNumber2, TNumber2, TNumber2>
where TNumber3 : unmanaged, INumber<TNumber3>, IBinaryNumber<TNumber3>, IMinMaxValue<TNumber3>, IBitwiseOperators<TNumber3, TNumber3, TNumber3>
where TNumber4 : unmanaged, INumber<TNumber4>, IBinaryNumber<TNumber4>, IMinMaxValue<TNumber4>, IBitwiseOperators<TNumber4, TNumber4, TNumber4>
where TNumber5 : unmanaged, INumber<TNumber5>, IBinaryNumber<TNumber5>, IMinMaxValue<TNumber5>, IBitwiseOperators<TNumber5, TNumber5, TNumber5>
{
if (WideLane.IsSupported)
{
var iterations = (totalIteration + WideLane<TNumber0>.LaneWidth - 1) / WideLane<TNumber0>.LaneWidth;
for (var loopIndex = 0; loopIndex < iterations; loopIndex++)
{
var baseIndex = loopIndex * WideLane<TNumber0>.LaneWidth;
var indices = WideLane<TNumber0>.Sequence(TNumber0.CreateTruncating(baseIndex), TNumber0.One);
var mask = indices < TNumber0.CreateTruncating(totalIteration);
job.Execute<WideLane<TNumber0>, WideLane<TNumber1>, WideLane<TNumber2>, WideLane<TNumber3>, WideLane<TNumber4>, WideLane<TNumber5>>(indices, mask, in ctx);
}
}
else
{
for (var loopIndex = 0; loopIndex < totalIteration; loopIndex++)
{
job.Execute<ScalarLane<TNumber0>, ScalarLane<TNumber1>, ScalarLane<TNumber2>, ScalarLane<TNumber3>, ScalarLane<TNumber4>, ScalarLane<TNumber5>>(TNumber0.CreateTruncating(loopIndex), ScalarLane<TNumber0>.AllBitsSet, in ctx);
}
}
}
/// <summary> /// <summary>
/// Schedule the SPMD job for parallel execution across multiple threads, with the specified total count, batch size, and job execution context. /// Schedule the SPMD job for parallel execution across multiple threads, with the specified total count, batch size, and job execution context.
/// </summary> /// </summary>
@@ -774,6 +1014,53 @@ public static class IJobParallelForSPMDExtensions
} }
} }
/// <summary>
/// Run the SPMD job with the specified total count and job execution context directly on the calling thread.
/// </summary>
/// <remarks>
/// Always use TNumber0 as the primary type for determining lane width and job scheduling, even if it's not used in the job execution.
/// </remarks>
/// <typeparam name="TNumber0">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber1">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber2">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber3">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber4">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber5">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber6">The first numeric type used in the SPMD job.</typeparam>
/// <param name="job">The SPMD job to run.</param>
/// <param name="totalIteration">The total number of iterations to execute across all lanes.</param>
/// <param name="ctx">The job execution context providing information about the current execution environment.</param>
public static void Run<T, TNumber0, TNumber1, TNumber2, TNumber3, TNumber4, TNumber5, TNumber6>(this T job, int totalIteration, ref readonly JobExecutionContext ctx)
where T : IJobSPMD<TNumber0, TNumber1, TNumber2, TNumber3, TNumber4, TNumber5, TNumber6>
where TNumber0 : unmanaged, INumber<TNumber0>, IBinaryNumber<TNumber0>, IMinMaxValue<TNumber0>, IBitwiseOperators<TNumber0, TNumber0, TNumber0>
where TNumber1 : unmanaged, INumber<TNumber1>, IBinaryNumber<TNumber1>, IMinMaxValue<TNumber1>, IBitwiseOperators<TNumber1, TNumber1, TNumber1>
where TNumber2 : unmanaged, INumber<TNumber2>, IBinaryNumber<TNumber2>, IMinMaxValue<TNumber2>, IBitwiseOperators<TNumber2, TNumber2, TNumber2>
where TNumber3 : unmanaged, INumber<TNumber3>, IBinaryNumber<TNumber3>, IMinMaxValue<TNumber3>, IBitwiseOperators<TNumber3, TNumber3, TNumber3>
where TNumber4 : unmanaged, INumber<TNumber4>, IBinaryNumber<TNumber4>, IMinMaxValue<TNumber4>, IBitwiseOperators<TNumber4, TNumber4, TNumber4>
where TNumber5 : unmanaged, INumber<TNumber5>, IBinaryNumber<TNumber5>, IMinMaxValue<TNumber5>, IBitwiseOperators<TNumber5, TNumber5, TNumber5>
where TNumber6 : unmanaged, INumber<TNumber6>, IBinaryNumber<TNumber6>, IMinMaxValue<TNumber6>, IBitwiseOperators<TNumber6, TNumber6, TNumber6>
{
if (WideLane.IsSupported)
{
var iterations = (totalIteration + WideLane<TNumber0>.LaneWidth - 1) / WideLane<TNumber0>.LaneWidth;
for (var loopIndex = 0; loopIndex < iterations; loopIndex++)
{
var baseIndex = loopIndex * WideLane<TNumber0>.LaneWidth;
var indices = WideLane<TNumber0>.Sequence(TNumber0.CreateTruncating(baseIndex), TNumber0.One);
var mask = indices < TNumber0.CreateTruncating(totalIteration);
job.Execute<WideLane<TNumber0>, WideLane<TNumber1>, WideLane<TNumber2>, WideLane<TNumber3>, WideLane<TNumber4>, WideLane<TNumber5>, WideLane<TNumber6>>(indices, mask, in ctx);
}
}
else
{
for (var loopIndex = 0; loopIndex < totalIteration; loopIndex++)
{
job.Execute<ScalarLane<TNumber0>, ScalarLane<TNumber1>, ScalarLane<TNumber2>, ScalarLane<TNumber3>, ScalarLane<TNumber4>, ScalarLane<TNumber5>, ScalarLane<TNumber6>>(TNumber0.CreateTruncating(loopIndex), ScalarLane<TNumber0>.AllBitsSet, in ctx);
}
}
}
/// <summary> /// <summary>
/// Schedule the SPMD job for parallel execution across multiple threads, with the specified total count, batch size, and job execution context. /// Schedule the SPMD job for parallel execution across multiple threads, with the specified total count, batch size, and job execution context.
/// </summary> /// </summary>
@@ -827,6 +1114,55 @@ public static class IJobParallelForSPMDExtensions
} }
} }
/// <summary>
/// Run the SPMD job with the specified total count and job execution context directly on the calling thread.
/// </summary>
/// <remarks>
/// Always use TNumber0 as the primary type for determining lane width and job scheduling, even if it's not used in the job execution.
/// </remarks>
/// <typeparam name="TNumber0">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber1">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber2">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber3">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber4">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber5">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber6">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber7">The first numeric type used in the SPMD job.</typeparam>
/// <param name="job">The SPMD job to run.</param>
/// <param name="totalIteration">The total number of iterations to execute across all lanes.</param>
/// <param name="ctx">The job execution context providing information about the current execution environment.</param>
public static void Run<T, TNumber0, TNumber1, TNumber2, TNumber3, TNumber4, TNumber5, TNumber6, TNumber7>(this T job, int totalIteration, ref readonly JobExecutionContext ctx)
where T : IJobSPMD<TNumber0, TNumber1, TNumber2, TNumber3, TNumber4, TNumber5, TNumber6, TNumber7>
where TNumber0 : unmanaged, INumber<TNumber0>, IBinaryNumber<TNumber0>, IMinMaxValue<TNumber0>, IBitwiseOperators<TNumber0, TNumber0, TNumber0>
where TNumber1 : unmanaged, INumber<TNumber1>, IBinaryNumber<TNumber1>, IMinMaxValue<TNumber1>, IBitwiseOperators<TNumber1, TNumber1, TNumber1>
where TNumber2 : unmanaged, INumber<TNumber2>, IBinaryNumber<TNumber2>, IMinMaxValue<TNumber2>, IBitwiseOperators<TNumber2, TNumber2, TNumber2>
where TNumber3 : unmanaged, INumber<TNumber3>, IBinaryNumber<TNumber3>, IMinMaxValue<TNumber3>, IBitwiseOperators<TNumber3, TNumber3, TNumber3>
where TNumber4 : unmanaged, INumber<TNumber4>, IBinaryNumber<TNumber4>, IMinMaxValue<TNumber4>, IBitwiseOperators<TNumber4, TNumber4, TNumber4>
where TNumber5 : unmanaged, INumber<TNumber5>, IBinaryNumber<TNumber5>, IMinMaxValue<TNumber5>, IBitwiseOperators<TNumber5, TNumber5, TNumber5>
where TNumber6 : unmanaged, INumber<TNumber6>, IBinaryNumber<TNumber6>, IMinMaxValue<TNumber6>, IBitwiseOperators<TNumber6, TNumber6, TNumber6>
where TNumber7 : unmanaged, INumber<TNumber7>, IBinaryNumber<TNumber7>, IMinMaxValue<TNumber7>, IBitwiseOperators<TNumber7, TNumber7, TNumber7>
{
if (WideLane.IsSupported)
{
var iterations = (totalIteration + WideLane<TNumber0>.LaneWidth - 1) / WideLane<TNumber0>.LaneWidth;
for (var loopIndex = 0; loopIndex < iterations; loopIndex++)
{
var baseIndex = loopIndex * WideLane<TNumber0>.LaneWidth;
var indices = WideLane<TNumber0>.Sequence(TNumber0.CreateTruncating(baseIndex), TNumber0.One);
var mask = indices < TNumber0.CreateTruncating(totalIteration);
job.Execute<WideLane<TNumber0>, WideLane<TNumber1>, WideLane<TNumber2>, WideLane<TNumber3>, WideLane<TNumber4>, WideLane<TNumber5>, WideLane<TNumber6>, WideLane<TNumber7>>(indices, mask, in ctx);
}
}
else
{
for (var loopIndex = 0; loopIndex < totalIteration; loopIndex++)
{
job.Execute<ScalarLane<TNumber0>, ScalarLane<TNumber1>, ScalarLane<TNumber2>, ScalarLane<TNumber3>, ScalarLane<TNumber4>, ScalarLane<TNumber5>, ScalarLane<TNumber6>, ScalarLane<TNumber7>>(TNumber0.CreateTruncating(loopIndex), ScalarLane<TNumber0>.AllBitsSet, in ctx);
}
}
}
/// <summary> /// <summary>
/// Schedule the SPMD job for parallel execution across multiple threads, with the specified total count, batch size, and job execution context. /// Schedule the SPMD job for parallel execution across multiple threads, with the specified total count, batch size, and job execution context.
/// </summary> /// </summary>
@@ -884,343 +1220,3 @@ public static class IJobParallelForSPMDExtensions
} }
public static class JobSPMDUtility
{
/// <summary>
/// Run the SPMD job with the specified total count and job execution context directly on the calling thread.
/// </summary>
/// <remarks>
/// Always use TNumber0 as the primary type for determining lane width and job scheduling, even if it's not used in the job execution.
/// </remarks>
/// <typeparam name="TNumber0">The first numeric type used in the SPMD job.</typeparam>
/// <param name="job">The SPMD job to run.</param>
/// <param name="totalIteration">The total number of iterations to execute across all lanes.</param>
/// <param name="ctx">The job execution context providing information about the current execution environment.</param>
public static void Run<T, TNumber0>(ref T job, int totalIteration, ref readonly JobExecutionContext ctx)
where T : IJobSPMD<TNumber0>
where TNumber0 : unmanaged, INumber<TNumber0>, IBinaryNumber<TNumber0>, IMinMaxValue<TNumber0>, IBitwiseOperators<TNumber0, TNumber0, TNumber0>
{
if (WideLane.IsSupported)
{
var iterations = (totalIteration + WideLane<TNumber0>.LaneWidth - 1) / WideLane<TNumber0>.LaneWidth;
for (var loopIndex = 0; loopIndex < iterations; loopIndex++)
{
var baseIndex = loopIndex * WideLane<TNumber0>.LaneWidth;
var indices = WideLane<TNumber0>.Sequence(TNumber0.CreateTruncating(baseIndex), TNumber0.One);
var mask = indices < TNumber0.CreateTruncating(totalIteration);
job.Execute<WideLane<TNumber0>>(indices, mask, in ctx);
}
}
else
{
for (var loopIndex = 0; loopIndex < totalIteration; loopIndex++)
{
job.Execute<ScalarLane<TNumber0>>(TNumber0.CreateTruncating(loopIndex), ScalarLane<TNumber0>.AllBitsSet, in ctx);
}
}
}
/// <summary>
/// Run the SPMD job with the specified total count and job execution context directly on the calling thread.
/// </summary>
/// <remarks>
/// Always use TNumber0 as the primary type for determining lane width and job scheduling, even if it's not used in the job execution.
/// </remarks>
/// <typeparam name="TNumber0">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber1">The first numeric type used in the SPMD job.</typeparam>
/// <param name="job">The SPMD job to run.</param>
/// <param name="totalIteration">The total number of iterations to execute across all lanes.</param>
/// <param name="ctx">The job execution context providing information about the current execution environment.</param>
public static void Run<T, TNumber0, TNumber1>(ref T job, int totalIteration, ref readonly JobExecutionContext ctx)
where T : IJobSPMD<TNumber0, TNumber1>
where TNumber0 : unmanaged, INumber<TNumber0>, IBinaryNumber<TNumber0>, IMinMaxValue<TNumber0>, IBitwiseOperators<TNumber0, TNumber0, TNumber0>
where TNumber1 : unmanaged, INumber<TNumber1>, IBinaryNumber<TNumber1>, IMinMaxValue<TNumber1>, IBitwiseOperators<TNumber1, TNumber1, TNumber1>
{
if (WideLane.IsSupported)
{
var iterations = (totalIteration + WideLane<TNumber0>.LaneWidth - 1) / WideLane<TNumber0>.LaneWidth;
for (var loopIndex = 0; loopIndex < iterations; loopIndex++)
{
var baseIndex = loopIndex * WideLane<TNumber0>.LaneWidth;
var indices = WideLane<TNumber0>.Sequence(TNumber0.CreateTruncating(baseIndex), TNumber0.One);
var mask = indices < TNumber0.CreateTruncating(totalIteration);
job.Execute<WideLane<TNumber0>, WideLane<TNumber1>>(indices, mask, in ctx);
}
}
else
{
for (var loopIndex = 0; loopIndex < totalIteration; loopIndex++)
{
job.Execute<ScalarLane<TNumber0>, ScalarLane<TNumber1>>(TNumber0.CreateTruncating(loopIndex), ScalarLane<TNumber0>.AllBitsSet, in ctx);
}
}
}
/// <summary>
/// Run the SPMD job with the specified total count and job execution context directly on the calling thread.
/// </summary>
/// <remarks>
/// Always use TNumber0 as the primary type for determining lane width and job scheduling, even if it's not used in the job execution.
/// </remarks>
/// <typeparam name="TNumber0">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber1">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber2">The first numeric type used in the SPMD job.</typeparam>
/// <param name="job">The SPMD job to run.</param>
/// <param name="totalIteration">The total number of iterations to execute across all lanes.</param>
/// <param name="ctx">The job execution context providing information about the current execution environment.</param>
public static void Run<T, TNumber0, TNumber1, TNumber2>(ref T job, int totalIteration, ref readonly JobExecutionContext ctx)
where T : IJobSPMD<TNumber0, TNumber1, TNumber2>
where TNumber0 : unmanaged, INumber<TNumber0>, IBinaryNumber<TNumber0>, IMinMaxValue<TNumber0>, IBitwiseOperators<TNumber0, TNumber0, TNumber0>
where TNumber1 : unmanaged, INumber<TNumber1>, IBinaryNumber<TNumber1>, IMinMaxValue<TNumber1>, IBitwiseOperators<TNumber1, TNumber1, TNumber1>
where TNumber2 : unmanaged, INumber<TNumber2>, IBinaryNumber<TNumber2>, IMinMaxValue<TNumber2>, IBitwiseOperators<TNumber2, TNumber2, TNumber2>
{
if (WideLane.IsSupported)
{
var iterations = (totalIteration + WideLane<TNumber0>.LaneWidth - 1) / WideLane<TNumber0>.LaneWidth;
for (var loopIndex = 0; loopIndex < iterations; loopIndex++)
{
var baseIndex = loopIndex * WideLane<TNumber0>.LaneWidth;
var indices = WideLane<TNumber0>.Sequence(TNumber0.CreateTruncating(baseIndex), TNumber0.One);
var mask = indices < TNumber0.CreateTruncating(totalIteration);
job.Execute<WideLane<TNumber0>, WideLane<TNumber1>, WideLane<TNumber2>>(indices, mask, in ctx);
}
}
else
{
for (var loopIndex = 0; loopIndex < totalIteration; loopIndex++)
{
job.Execute<ScalarLane<TNumber0>, ScalarLane<TNumber1>, ScalarLane<TNumber2>>(TNumber0.CreateTruncating(loopIndex), ScalarLane<TNumber0>.AllBitsSet, in ctx);
}
}
}
/// <summary>
/// Run the SPMD job with the specified total count and job execution context directly on the calling thread.
/// </summary>
/// <remarks>
/// Always use TNumber0 as the primary type for determining lane width and job scheduling, even if it's not used in the job execution.
/// </remarks>
/// <typeparam name="TNumber0">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber1">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber2">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber3">The first numeric type used in the SPMD job.</typeparam>
/// <param name="job">The SPMD job to run.</param>
/// <param name="totalIteration">The total number of iterations to execute across all lanes.</param>
/// <param name="ctx">The job execution context providing information about the current execution environment.</param>
public static void Run<T, TNumber0, TNumber1, TNumber2, TNumber3>(ref T job, int totalIteration, ref readonly JobExecutionContext ctx)
where T : IJobSPMD<TNumber0, TNumber1, TNumber2, TNumber3>
where TNumber0 : unmanaged, INumber<TNumber0>, IBinaryNumber<TNumber0>, IMinMaxValue<TNumber0>, IBitwiseOperators<TNumber0, TNumber0, TNumber0>
where TNumber1 : unmanaged, INumber<TNumber1>, IBinaryNumber<TNumber1>, IMinMaxValue<TNumber1>, IBitwiseOperators<TNumber1, TNumber1, TNumber1>
where TNumber2 : unmanaged, INumber<TNumber2>, IBinaryNumber<TNumber2>, IMinMaxValue<TNumber2>, IBitwiseOperators<TNumber2, TNumber2, TNumber2>
where TNumber3 : unmanaged, INumber<TNumber3>, IBinaryNumber<TNumber3>, IMinMaxValue<TNumber3>, IBitwiseOperators<TNumber3, TNumber3, TNumber3>
{
if (WideLane.IsSupported)
{
var iterations = (totalIteration + WideLane<TNumber0>.LaneWidth - 1) / WideLane<TNumber0>.LaneWidth;
for (var loopIndex = 0; loopIndex < iterations; loopIndex++)
{
var baseIndex = loopIndex * WideLane<TNumber0>.LaneWidth;
var indices = WideLane<TNumber0>.Sequence(TNumber0.CreateTruncating(baseIndex), TNumber0.One);
var mask = indices < TNumber0.CreateTruncating(totalIteration);
job.Execute<WideLane<TNumber0>, WideLane<TNumber1>, WideLane<TNumber2>, WideLane<TNumber3>>(indices, mask, in ctx);
}
}
else
{
for (var loopIndex = 0; loopIndex < totalIteration; loopIndex++)
{
job.Execute<ScalarLane<TNumber0>, ScalarLane<TNumber1>, ScalarLane<TNumber2>, ScalarLane<TNumber3>>(TNumber0.CreateTruncating(loopIndex), ScalarLane<TNumber0>.AllBitsSet, in ctx);
}
}
}
/// <summary>
/// Run the SPMD job with the specified total count and job execution context directly on the calling thread.
/// </summary>
/// <remarks>
/// Always use TNumber0 as the primary type for determining lane width and job scheduling, even if it's not used in the job execution.
/// </remarks>
/// <typeparam name="TNumber0">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber1">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber2">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber3">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber4">The first numeric type used in the SPMD job.</typeparam>
/// <param name="job">The SPMD job to run.</param>
/// <param name="totalIteration">The total number of iterations to execute across all lanes.</param>
/// <param name="ctx">The job execution context providing information about the current execution environment.</param>
public static void Run<T, TNumber0, TNumber1, TNumber2, TNumber3, TNumber4>(ref T job, int totalIteration, ref readonly JobExecutionContext ctx)
where T : IJobSPMD<TNumber0, TNumber1, TNumber2, TNumber3, TNumber4>
where TNumber0 : unmanaged, INumber<TNumber0>, IBinaryNumber<TNumber0>, IMinMaxValue<TNumber0>, IBitwiseOperators<TNumber0, TNumber0, TNumber0>
where TNumber1 : unmanaged, INumber<TNumber1>, IBinaryNumber<TNumber1>, IMinMaxValue<TNumber1>, IBitwiseOperators<TNumber1, TNumber1, TNumber1>
where TNumber2 : unmanaged, INumber<TNumber2>, IBinaryNumber<TNumber2>, IMinMaxValue<TNumber2>, IBitwiseOperators<TNumber2, TNumber2, TNumber2>
where TNumber3 : unmanaged, INumber<TNumber3>, IBinaryNumber<TNumber3>, IMinMaxValue<TNumber3>, IBitwiseOperators<TNumber3, TNumber3, TNumber3>
where TNumber4 : unmanaged, INumber<TNumber4>, IBinaryNumber<TNumber4>, IMinMaxValue<TNumber4>, IBitwiseOperators<TNumber4, TNumber4, TNumber4>
{
if (WideLane.IsSupported)
{
var iterations = (totalIteration + WideLane<TNumber0>.LaneWidth - 1) / WideLane<TNumber0>.LaneWidth;
for (var loopIndex = 0; loopIndex < iterations; loopIndex++)
{
var baseIndex = loopIndex * WideLane<TNumber0>.LaneWidth;
var indices = WideLane<TNumber0>.Sequence(TNumber0.CreateTruncating(baseIndex), TNumber0.One);
var mask = indices < TNumber0.CreateTruncating(totalIteration);
job.Execute<WideLane<TNumber0>, WideLane<TNumber1>, WideLane<TNumber2>, WideLane<TNumber3>, WideLane<TNumber4>>(indices, mask, in ctx);
}
}
else
{
for (var loopIndex = 0; loopIndex < totalIteration; loopIndex++)
{
job.Execute<ScalarLane<TNumber0>, ScalarLane<TNumber1>, ScalarLane<TNumber2>, ScalarLane<TNumber3>, ScalarLane<TNumber4>>(TNumber0.CreateTruncating(loopIndex), ScalarLane<TNumber0>.AllBitsSet, in ctx);
}
}
}
/// <summary>
/// Run the SPMD job with the specified total count and job execution context directly on the calling thread.
/// </summary>
/// <remarks>
/// Always use TNumber0 as the primary type for determining lane width and job scheduling, even if it's not used in the job execution.
/// </remarks>
/// <typeparam name="TNumber0">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber1">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber2">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber3">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber4">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber5">The first numeric type used in the SPMD job.</typeparam>
/// <param name="job">The SPMD job to run.</param>
/// <param name="totalIteration">The total number of iterations to execute across all lanes.</param>
/// <param name="ctx">The job execution context providing information about the current execution environment.</param>
public static void Run<T, TNumber0, TNumber1, TNumber2, TNumber3, TNumber4, TNumber5>(ref T job, int totalIteration, ref readonly JobExecutionContext ctx)
where T : IJobSPMD<TNumber0, TNumber1, TNumber2, TNumber3, TNumber4, TNumber5>
where TNumber0 : unmanaged, INumber<TNumber0>, IBinaryNumber<TNumber0>, IMinMaxValue<TNumber0>, IBitwiseOperators<TNumber0, TNumber0, TNumber0>
where TNumber1 : unmanaged, INumber<TNumber1>, IBinaryNumber<TNumber1>, IMinMaxValue<TNumber1>, IBitwiseOperators<TNumber1, TNumber1, TNumber1>
where TNumber2 : unmanaged, INumber<TNumber2>, IBinaryNumber<TNumber2>, IMinMaxValue<TNumber2>, IBitwiseOperators<TNumber2, TNumber2, TNumber2>
where TNumber3 : unmanaged, INumber<TNumber3>, IBinaryNumber<TNumber3>, IMinMaxValue<TNumber3>, IBitwiseOperators<TNumber3, TNumber3, TNumber3>
where TNumber4 : unmanaged, INumber<TNumber4>, IBinaryNumber<TNumber4>, IMinMaxValue<TNumber4>, IBitwiseOperators<TNumber4, TNumber4, TNumber4>
where TNumber5 : unmanaged, INumber<TNumber5>, IBinaryNumber<TNumber5>, IMinMaxValue<TNumber5>, IBitwiseOperators<TNumber5, TNumber5, TNumber5>
{
if (WideLane.IsSupported)
{
var iterations = (totalIteration + WideLane<TNumber0>.LaneWidth - 1) / WideLane<TNumber0>.LaneWidth;
for (var loopIndex = 0; loopIndex < iterations; loopIndex++)
{
var baseIndex = loopIndex * WideLane<TNumber0>.LaneWidth;
var indices = WideLane<TNumber0>.Sequence(TNumber0.CreateTruncating(baseIndex), TNumber0.One);
var mask = indices < TNumber0.CreateTruncating(totalIteration);
job.Execute<WideLane<TNumber0>, WideLane<TNumber1>, WideLane<TNumber2>, WideLane<TNumber3>, WideLane<TNumber4>, WideLane<TNumber5>>(indices, mask, in ctx);
}
}
else
{
for (var loopIndex = 0; loopIndex < totalIteration; loopIndex++)
{
job.Execute<ScalarLane<TNumber0>, ScalarLane<TNumber1>, ScalarLane<TNumber2>, ScalarLane<TNumber3>, ScalarLane<TNumber4>, ScalarLane<TNumber5>>(TNumber0.CreateTruncating(loopIndex), ScalarLane<TNumber0>.AllBitsSet, in ctx);
}
}
}
/// <summary>
/// Run the SPMD job with the specified total count and job execution context directly on the calling thread.
/// </summary>
/// <remarks>
/// Always use TNumber0 as the primary type for determining lane width and job scheduling, even if it's not used in the job execution.
/// </remarks>
/// <typeparam name="TNumber0">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber1">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber2">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber3">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber4">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber5">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber6">The first numeric type used in the SPMD job.</typeparam>
/// <param name="job">The SPMD job to run.</param>
/// <param name="totalIteration">The total number of iterations to execute across all lanes.</param>
/// <param name="ctx">The job execution context providing information about the current execution environment.</param>
public static void Run<T, TNumber0, TNumber1, TNumber2, TNumber3, TNumber4, TNumber5, TNumber6>(ref T job, int totalIteration, ref readonly JobExecutionContext ctx)
where T : IJobSPMD<TNumber0, TNumber1, TNumber2, TNumber3, TNumber4, TNumber5, TNumber6>
where TNumber0 : unmanaged, INumber<TNumber0>, IBinaryNumber<TNumber0>, IMinMaxValue<TNumber0>, IBitwiseOperators<TNumber0, TNumber0, TNumber0>
where TNumber1 : unmanaged, INumber<TNumber1>, IBinaryNumber<TNumber1>, IMinMaxValue<TNumber1>, IBitwiseOperators<TNumber1, TNumber1, TNumber1>
where TNumber2 : unmanaged, INumber<TNumber2>, IBinaryNumber<TNumber2>, IMinMaxValue<TNumber2>, IBitwiseOperators<TNumber2, TNumber2, TNumber2>
where TNumber3 : unmanaged, INumber<TNumber3>, IBinaryNumber<TNumber3>, IMinMaxValue<TNumber3>, IBitwiseOperators<TNumber3, TNumber3, TNumber3>
where TNumber4 : unmanaged, INumber<TNumber4>, IBinaryNumber<TNumber4>, IMinMaxValue<TNumber4>, IBitwiseOperators<TNumber4, TNumber4, TNumber4>
where TNumber5 : unmanaged, INumber<TNumber5>, IBinaryNumber<TNumber5>, IMinMaxValue<TNumber5>, IBitwiseOperators<TNumber5, TNumber5, TNumber5>
where TNumber6 : unmanaged, INumber<TNumber6>, IBinaryNumber<TNumber6>, IMinMaxValue<TNumber6>, IBitwiseOperators<TNumber6, TNumber6, TNumber6>
{
if (WideLane.IsSupported)
{
var iterations = (totalIteration + WideLane<TNumber0>.LaneWidth - 1) / WideLane<TNumber0>.LaneWidth;
for (var loopIndex = 0; loopIndex < iterations; loopIndex++)
{
var baseIndex = loopIndex * WideLane<TNumber0>.LaneWidth;
var indices = WideLane<TNumber0>.Sequence(TNumber0.CreateTruncating(baseIndex), TNumber0.One);
var mask = indices < TNumber0.CreateTruncating(totalIteration);
job.Execute<WideLane<TNumber0>, WideLane<TNumber1>, WideLane<TNumber2>, WideLane<TNumber3>, WideLane<TNumber4>, WideLane<TNumber5>, WideLane<TNumber6>>(indices, mask, in ctx);
}
}
else
{
for (var loopIndex = 0; loopIndex < totalIteration; loopIndex++)
{
job.Execute<ScalarLane<TNumber0>, ScalarLane<TNumber1>, ScalarLane<TNumber2>, ScalarLane<TNumber3>, ScalarLane<TNumber4>, ScalarLane<TNumber5>, ScalarLane<TNumber6>>(TNumber0.CreateTruncating(loopIndex), ScalarLane<TNumber0>.AllBitsSet, in ctx);
}
}
}
/// <summary>
/// Run the SPMD job with the specified total count and job execution context directly on the calling thread.
/// </summary>
/// <remarks>
/// Always use TNumber0 as the primary type for determining lane width and job scheduling, even if it's not used in the job execution.
/// </remarks>
/// <typeparam name="TNumber0">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber1">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber2">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber3">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber4">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber5">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber6">The first numeric type used in the SPMD job.</typeparam>
/// <typeparam name="TNumber7">The first numeric type used in the SPMD job.</typeparam>
/// <param name="job">The SPMD job to run.</param>
/// <param name="totalIteration">The total number of iterations to execute across all lanes.</param>
/// <param name="ctx">The job execution context providing information about the current execution environment.</param>
public static void Run<T, TNumber0, TNumber1, TNumber2, TNumber3, TNumber4, TNumber5, TNumber6, TNumber7>(ref T job, int totalIteration, ref readonly JobExecutionContext ctx)
where T : IJobSPMD<TNumber0, TNumber1, TNumber2, TNumber3, TNumber4, TNumber5, TNumber6, TNumber7>
where TNumber0 : unmanaged, INumber<TNumber0>, IBinaryNumber<TNumber0>, IMinMaxValue<TNumber0>, IBitwiseOperators<TNumber0, TNumber0, TNumber0>
where TNumber1 : unmanaged, INumber<TNumber1>, IBinaryNumber<TNumber1>, IMinMaxValue<TNumber1>, IBitwiseOperators<TNumber1, TNumber1, TNumber1>
where TNumber2 : unmanaged, INumber<TNumber2>, IBinaryNumber<TNumber2>, IMinMaxValue<TNumber2>, IBitwiseOperators<TNumber2, TNumber2, TNumber2>
where TNumber3 : unmanaged, INumber<TNumber3>, IBinaryNumber<TNumber3>, IMinMaxValue<TNumber3>, IBitwiseOperators<TNumber3, TNumber3, TNumber3>
where TNumber4 : unmanaged, INumber<TNumber4>, IBinaryNumber<TNumber4>, IMinMaxValue<TNumber4>, IBitwiseOperators<TNumber4, TNumber4, TNumber4>
where TNumber5 : unmanaged, INumber<TNumber5>, IBinaryNumber<TNumber5>, IMinMaxValue<TNumber5>, IBitwiseOperators<TNumber5, TNumber5, TNumber5>
where TNumber6 : unmanaged, INumber<TNumber6>, IBinaryNumber<TNumber6>, IMinMaxValue<TNumber6>, IBitwiseOperators<TNumber6, TNumber6, TNumber6>
where TNumber7 : unmanaged, INumber<TNumber7>, IBinaryNumber<TNumber7>, IMinMaxValue<TNumber7>, IBitwiseOperators<TNumber7, TNumber7, TNumber7>
{
if (WideLane.IsSupported)
{
var iterations = (totalIteration + WideLane<TNumber0>.LaneWidth - 1) / WideLane<TNumber0>.LaneWidth;
for (var loopIndex = 0; loopIndex < iterations; loopIndex++)
{
var baseIndex = loopIndex * WideLane<TNumber0>.LaneWidth;
var indices = WideLane<TNumber0>.Sequence(TNumber0.CreateTruncating(baseIndex), TNumber0.One);
var mask = indices < TNumber0.CreateTruncating(totalIteration);
job.Execute<WideLane<TNumber0>, WideLane<TNumber1>, WideLane<TNumber2>, WideLane<TNumber3>, WideLane<TNumber4>, WideLane<TNumber5>, WideLane<TNumber6>, WideLane<TNumber7>>(indices, mask, in ctx);
}
}
else
{
for (var loopIndex = 0; loopIndex < totalIteration; loopIndex++)
{
job.Execute<ScalarLane<TNumber0>, ScalarLane<TNumber1>, ScalarLane<TNumber2>, ScalarLane<TNumber3>, ScalarLane<TNumber4>, ScalarLane<TNumber5>, ScalarLane<TNumber6>, ScalarLane<TNumber7>>(TNumber0.CreateTruncating(loopIndex), ScalarLane<TNumber0>.AllBitsSet, in ctx);
}
}
}
}

View File

@@ -7,14 +7,14 @@
using Misaki.HighPerformance.Jobs; using Misaki.HighPerformance.Jobs;
using System.Numerics; using System.Numerics;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
<# <#
const string TLane = "TLane"; const string TLane = "TLane";
const string TNumber = "TNumber"; const string TNumber = "TNumber";
const string GenericParameters = $"{TLane}, {TNumber}"; const string GenericParameters = $"{TLane}, {TNumber}";
var TLaneRestrictions = $@"where {TLane} : unmanaged, ISPMDLane<{TLane}, {TNumber}>"; var TLaneRestrictions = $@"where {TLane} : ISPMDLane<{TLane}, {TNumber}>";
var TNumberRestrictions = $@"where {TNumber} : unmanaged, INumber<{TNumber}>, IBinaryNumber<{TNumber}>, IMinMaxValue<{TNumber}>, IBitwiseOperators<{TNumber}, {TNumber}, {TNumber}>"; var TNumberRestrictions = $@"where {TNumber} : unmanaged, INumber<{TNumber}>, IBinaryNumber<{TNumber}>, IMinMaxValue<{TNumber}>, IBitwiseOperators<{TNumber}, {TNumber}, {TNumber}>";
for (var i = 0; i < 8; i++) { #> for (var i = 0; i < 8; i++) { #>
@@ -67,6 +67,41 @@ internal struct SPMDScalerJobWrapper<T, <#= ForEachDimension(i + 1, j => $"TNumb
public static class IJobParallelForSPMDExtensions public static class IJobParallelForSPMDExtensions
{ {
<# for (var i = 0; i < 8; i++) { #> <# for (var i = 0; i < 8; i++) { #>
/// <summary>
/// Run the SPMD job with the specified total count and job execution context directly on the calling thread.
/// </summary>
/// <remarks>
/// Always use TNumber0 as the primary type for determining lane width and job scheduling, even if it's not used in the job execution.
/// </remarks>
<#= ForEachDimension(i + 1, j => @$" /// <typeparam name=""TNumber{j}"">The first numeric type used in the SPMD job.</typeparam>", Environment.NewLine) #>
/// <param name="job">The SPMD job to run.</param>
/// <param name="totalIteration">The total number of iterations to execute across all lanes.</param>
/// <param name="ctx">The job execution context providing information about the current execution environment.</param>
public static void Run<T, <#= ForEachDimension(i + 1, j => $"TNumber{j}") #>>(this T job, int totalIteration, ref readonly JobExecutionContext ctx)
where T : IJobSPMD<<#= ForEachDimension(i + 1, j => $"TNumber{j}") #>>
<#= GetTNumberRestrictions(i + 1, " ") #>
{
if (WideLane.IsSupported)
{
var iterations = (totalIteration + WideLane<TNumber0>.LaneWidth - 1) / WideLane<TNumber0>.LaneWidth;
for (var loopIndex = 0; loopIndex < iterations; loopIndex++)
{
var baseIndex = loopIndex * WideLane<TNumber0>.LaneWidth;
var indices = WideLane<TNumber0>.Sequence(TNumber0.CreateTruncating(baseIndex), TNumber0.One);
var mask = indices < TNumber0.CreateTruncating(totalIteration);
job.Execute<<#= ForEachDimension(i + 1, j => $"WideLane<TNumber{j}>") #>>(indices, mask, in ctx);
}
}
else
{
for (var loopIndex = 0; loopIndex < totalIteration; loopIndex++)
{
job.Execute<<#= ForEachDimension(i + 1, j => $"ScalarLane<TNumber{j}>") #>>(TNumber0.CreateTruncating(loopIndex), ScalarLane<TNumber0>.AllBitsSet, in ctx);
}
}
}
/// <summary> /// <summary>
/// Schedule the SPMD job for parallel execution across multiple threads, with the specified total count, batch size, and job execution context. /// Schedule the SPMD job for parallel execution across multiple threads, with the specified total count, batch size, and job execution context.
/// </summary> /// </summary>
@@ -111,47 +146,6 @@ public static class IJobParallelForSPMDExtensions
<# } #> <# } #>
} }
public static class JobSPMDUtility
{
<# for (var i = 0; i < 8; i++) { #>
/// <summary>
/// Run the SPMD job with the specified total count and job execution context directly on the calling thread.
/// </summary>
/// <remarks>
/// Always use TNumber0 as the primary type for determining lane width and job scheduling, even if it's not used in the job execution.
/// </remarks>
<#= ForEachDimension(i + 1, j => @$" /// <typeparam name=""TNumber{j}"">The first numeric type used in the SPMD job.</typeparam>", Environment.NewLine) #>
/// <param name="job">The SPMD job to run.</param>
/// <param name="totalIteration">The total number of iterations to execute across all lanes.</param>
/// <param name="ctx">The job execution context providing information about the current execution environment.</param>
public static void Run<T, <#= ForEachDimension(i + 1, j => $"TNumber{j}") #>>(ref T job, int totalIteration, ref readonly JobExecutionContext ctx)
where T : IJobSPMD<<#= ForEachDimension(i + 1, j => $"TNumber{j}") #>>
<#= GetTNumberRestrictions(i + 1, " ") #>
{
if (WideLane.IsSupported)
{
var iterations = (totalIteration + WideLane<TNumber0>.LaneWidth - 1) / WideLane<TNumber0>.LaneWidth;
for (var loopIndex = 0; loopIndex < iterations; loopIndex++)
{
var baseIndex = loopIndex * WideLane<TNumber0>.LaneWidth;
var indices = WideLane<TNumber0>.Sequence(TNumber0.CreateTruncating(baseIndex), TNumber0.One);
var mask = indices < TNumber0.CreateTruncating(totalIteration);
job.Execute<<#= ForEachDimension(i + 1, j => $"WideLane<TNumber{j}>") #>>(indices, mask, in ctx);
}
}
else
{
for (var loopIndex = 0; loopIndex < totalIteration; loopIndex++)
{
job.Execute<<#= ForEachDimension(i + 1, j => $"ScalarLane<TNumber{j}>") #>>(TNumber0.CreateTruncating(loopIndex), ScalarLane<TNumber0>.AllBitsSet, in ctx);
}
}
}
<# } #>
}
<#+ <#+
public string ForEachDimension(int dimension, Func<int, string> action, string spliter = ", ") public string ForEachDimension(int dimension, Func<int, string> action, string spliter = ", ")
{ {

View File

@@ -6,12 +6,14 @@
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
public static unsafe partial class MathV public static unsafe partial class MathV
{ {
#region Vector2 # region Vector2
// Creation Functions // Creation Functions
@@ -467,9 +469,9 @@ public static unsafe partial class MathV
}; };
} }
#endregion # endregion
#region Vector3 # region Vector3
// Creation Functions // Creation Functions
@@ -950,9 +952,9 @@ public static unsafe partial class MathV
}; };
} }
#endregion # endregion
#region Vector4 # region Vector4
// Creation Functions // Creation Functions
@@ -1458,10 +1460,10 @@ public static unsafe partial class MathV
}; };
} }
#endregion # endregion
#region Vector3 Specific # region Vector3 Specific
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector3<TLane, TNumber> Cross<TLane, TNumber>(in Vector3<TLane, TNumber> a, in Vector3<TLane, TNumber> b) public static Vector3<TLane, TNumber> Cross<TLane, TNumber>(in Vector3<TLane, TNumber> a, in Vector3<TLane, TNumber> b)
@@ -1476,6 +1478,6 @@ public static unsafe partial class MathV
}; };
} }
#endregion # endregion
} }

View File

@@ -15,7 +15,7 @@ using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics; using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86; using System.Runtime.Intrinsics.X86;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
<# <#
const string TLane = "TLane"; const string TLane = "TLane";
const string TNumber = "TNumber"; const string TNumber = "TNumber";

View File

@@ -3,7 +3,7 @@ using System.Diagnostics;
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
public unsafe struct Vector2<TLane, TNumber> : IEquatable<Vector2<TLane, TNumber>> public unsafe struct Vector2<TLane, TNumber> : IEquatable<Vector2<TLane, TNumber>>
where TLane : ISPMDLane<TLane, TNumber> where TLane : ISPMDLane<TLane, TNumber>

View File

@@ -9,5 +9,5 @@ using System.Diagnostics;
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
<#= code #> <#= code #>

View File

@@ -3,7 +3,7 @@ using System.Diagnostics;
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
public unsafe struct Vector3<TLane, TNumber> : IEquatable<Vector3<TLane, TNumber>> public unsafe struct Vector3<TLane, TNumber> : IEquatable<Vector3<TLane, TNumber>>
where TLane : ISPMDLane<TLane, TNumber> where TLane : ISPMDLane<TLane, TNumber>

View File

@@ -9,5 +9,5 @@ using System.Diagnostics;
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
<#= code #> <#= code #>

View File

@@ -3,7 +3,7 @@ using System.Diagnostics;
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
public unsafe struct Vector4<TLane, TNumber> : IEquatable<Vector4<TLane, TNumber>> public unsafe struct Vector4<TLane, TNumber> : IEquatable<Vector4<TLane, TNumber>>
where TLane : ISPMDLane<TLane, TNumber> where TLane : ISPMDLane<TLane, TNumber>

View File

@@ -9,5 +9,5 @@ using System.Diagnostics;
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
<#= code #> <#= code #>

View File

@@ -1,7 +1,7 @@
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNumber>, TNumber> public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNumber>, TNumber>
where TNumber : unmanaged, INumber<TNumber>, IBinaryNumber<TNumber>, IMinMaxValue<TNumber>, IBitwiseOperators<TNumber, TNumber, TNumber> where TNumber : unmanaged, INumber<TNumber>, IBinaryNumber<TNumber>, IMinMaxValue<TNumber>, IBitwiseOperators<TNumber, TNumber, TNumber>

View File

@@ -7,7 +7,7 @@
using System.Numerics; using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
<# <#
var conversions = new CastRoute[] var conversions = new CastRoute[]
{ {

View File

@@ -5,7 +5,7 @@ using System.Runtime.InteropServices;
using System.Runtime.Intrinsics; using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86; using System.Runtime.Intrinsics.X86;
namespace Misaki.HighPerformance.Mathematics.SPMD; namespace Misaki.HighPerformance.HPC;
public static unsafe class WideLane public static unsafe class WideLane
{ {
@@ -40,8 +40,6 @@ public static unsafe class WideLane
public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNumber>, TNumber> public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNumber>, TNumber>
where TNumber : unmanaged, INumber<TNumber>, IBinaryNumber<TNumber>, IMinMaxValue<TNumber>, IBitwiseOperators<TNumber, TNumber, TNumber> where TNumber : unmanaged, INumber<TNumber>, IBinaryNumber<TNumber>, IMinMaxValue<TNumber>, IBitwiseOperators<TNumber, TNumber, TNumber>
{ {
private static readonly Vector<TNumber> s_indices;
public readonly Vector<TNumber> value; public readonly Vector<TNumber> value;
public static int LaneWidth public static int LaneWidth
@@ -53,13 +51,13 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
public static WideLane<TNumber> Zero public static WideLane<TNumber> Zero
{ {
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
get => new WideLane<TNumber>(Vector<TNumber>.Zero); get => Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector<TNumber>.Zero);
} }
public static WideLane<TNumber> One public static WideLane<TNumber> One
{ {
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
get => new WideLane<TNumber>(Vector<TNumber>.One); get => Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector<TNumber>.One);
} }
public static WideLane<TNumber> MinValue public static WideLane<TNumber> MinValue
@@ -86,17 +84,6 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
get => value[index]; get => value[index];
} }
static WideLane()
{
var pValues = stackalloc TNumber[LaneWidth];
for (var i = 0; i < LaneWidth; i++)
{
pValues[i] = TNumber.CreateTruncating(i);
}
s_indices = Vector.Load(pValues);
}
public WideLane(Vector<TNumber> value) public WideLane(Vector<TNumber> value)
{ {
this.value = value; this.value = value;
@@ -145,19 +132,19 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Create(TNumber value) public static WideLane<TNumber> Create(TNumber value)
{ {
return new WideLane<TNumber>(Vector.Create(value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.Create(value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Create(params ReadOnlySpan<TNumber> values) public static WideLane<TNumber> Create(params ReadOnlySpan<TNumber> values)
{ {
return new WideLane<TNumber>(Vector.Create(values)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.Create(values));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Create(Vector<TNumber> value) public static WideLane<TNumber> Create(Vector<TNumber> value)
{ {
return new WideLane<TNumber>(value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -185,20 +172,20 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
} }
else else
{ {
return new WideLane<TNumber>(Vector.Create(start) + (Vector.Create(step) * s_indices)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.Create(start) + (Vector.Create(step) * Vector<TNumber>.Indices));
} }
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Load(ref TNumber value) public static WideLane<TNumber> Load(ref TNumber value)
{ {
return new WideLane<TNumber>(Vector.LoadUnsafe(ref value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.LoadUnsafe(ref value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Load(TNumber* pValue) public static WideLane<TNumber> Load(TNumber* pValue)
{ {
return new WideLane<TNumber>(Vector.Load(pValue)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.Load(pValue));
} }
@@ -302,7 +289,7 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
pResult[i] = *(TNumber*)((byte*)pData + (idx * scale)); pResult[i] = *(TNumber*)((byte*)pData + (idx * scale));
} }
return new WideLane<TNumber>(result); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(result);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -352,7 +339,7 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
pResult[i] = *(TNumber*)((byte*)pData + (pIndices[i] * scale)); pResult[i] = *(TNumber*)((byte*)pData + (pIndices[i] * scale));
} }
return new WideLane<TNumber>(result); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(result);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -419,7 +406,7 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
pResult[i] = *(TNumber*)((byte*)pData + (idx * scale)); pResult[i] = *(TNumber*)((byte*)pData + (idx * scale));
} }
return new WideLane<TNumber>(result); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(result);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -473,7 +460,7 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
pResult[i] = *(TNumber*)((byte*)pData + (pIndices[i] * scale)); pResult[i] = *(TNumber*)((byte*)pData + (pIndices[i] * scale));
} }
return new WideLane<TNumber>(result); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(result);
} }
@@ -777,61 +764,61 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator +(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> operator +(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(a.value + b.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(a.value + b.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator -(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> operator -(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(a.value - b.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(a.value - b.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator *(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> operator *(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(a.value * b.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(a.value * b.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator /(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> operator /(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(a.value / b.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(a.value / b.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator %(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> operator %(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(a.value - VectorFloor(a.value / b.value) * b.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(a.value - VectorFloor(a.value / b.value) * b.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator -(WideLane<TNumber> a) public static WideLane<TNumber> operator -(WideLane<TNumber> a)
{ {
return new WideLane<TNumber>(-a.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(-a.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator &(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> operator &(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(a.value & b.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(a.value & b.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator |(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> operator |(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(a.value | b.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(a.value | b.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator ^(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> operator ^(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(a.value ^ b.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(a.value ^ b.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> operator ~(WideLane<TNumber> a) public static WideLane<TNumber> operator ~(WideLane<TNumber> a)
{ {
return new WideLane<TNumber>(~a.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(~a.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -881,7 +868,7 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Abs(WideLane<TNumber> value) public static WideLane<TNumber> Abs(WideLane<TNumber> value)
{ {
return new WideLane<TNumber>(Vector.Abs(value.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.Abs(value.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -891,13 +878,13 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value);
var floored = Vector.Floor(v); var floored = Vector.Floor(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(floored)); return Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(floored);
} }
else if (typeof(TNumber) == typeof(double)) else if (typeof(TNumber) == typeof(double))
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value);
var floored = Vector.Floor(v); var floored = Vector.Floor(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(floored)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(floored));
} }
return value; return value;
@@ -906,60 +893,60 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Frac(WideLane<TNumber> value) public static WideLane<TNumber> Frac(WideLane<TNumber> value)
{ {
return new WideLane<TNumber>(value.value - VectorFloor(value.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(value.value - VectorFloor(value.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Sqrt(WideLane<TNumber> value) public static WideLane<TNumber> Sqrt(WideLane<TNumber> value)
{ {
return new WideLane<TNumber>(Vector.SquareRoot(value.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.SquareRoot(value.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Lerp(WideLane<TNumber> a, WideLane<TNumber> b, WideLane<TNumber> t) public static WideLane<TNumber> Lerp(WideLane<TNumber> a, WideLane<TNumber> b, WideLane<TNumber> t)
{ {
return new WideLane<TNumber>(a.value + (b.value - a.value) * t.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(a.value + (b.value - a.value) * t.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> MultiplyAdd(WideLane<TNumber> a, WideLane<TNumber> b, WideLane<TNumber> c) public static WideLane<TNumber> MultiplyAdd(WideLane<TNumber> left, WideLane<TNumber> right, WideLane<TNumber> addend)
{ {
if (typeof(TNumber) == typeof(float)) if (typeof(TNumber) == typeof(float))
{ {
var va = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(a); var va = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(left);
var vb = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(b); var vb = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(right);
var vc = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(c); var vc = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(addend);
var result = Vector.FusedMultiplyAdd(va, vb, vc); var result = Vector.FusedMultiplyAdd(va, vb, vc);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(result);
} }
else if (typeof(TNumber) == typeof(double)) else if (typeof(TNumber) == typeof(double))
{ {
var va = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(a); var va = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(left);
var vb = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(b); var vb = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(right);
var vc = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(c); var vc = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(addend);
var result = Vector.FusedMultiplyAdd(va, vb, vc); var result = Vector.FusedMultiplyAdd(va, vb, vc);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<double>, WideLane<TNumber>>(result);
} }
return new WideLane<TNumber>((a.value * b.value) + c.value); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>((left.value * right.value) + addend.value);
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Min(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> Min(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(Vector.Min(a.value, b.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.Min(a.value, b.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Max(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> Max(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(Vector.Max(a.value, b.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.Max(a.value, b.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Clamp(WideLane<TNumber> value, WideLane<TNumber> min, WideLane<TNumber> max) public static WideLane<TNumber> Clamp(WideLane<TNumber> value, WideLane<TNumber> min, WideLane<TNumber> max)
{ {
return new WideLane<TNumber>(Vector.Clamp(value.value, min.value, max.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.Clamp(value.value, min.value, max.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -1004,13 +991,13 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value);
var result = Vector.Sin(v); var result = Vector.Sin(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(result));
} }
else if (typeof(TNumber) == typeof(double)) else if (typeof(TNumber) == typeof(double))
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value);
var result = Vector.Sin(v); var result = Vector.Sin(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<double>, WideLane<TNumber>>(result);
} }
return value; return value;
@@ -1054,13 +1041,13 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value);
var result = Vector.Cos(v); var result = Vector.Cos(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(result));
} }
else if (typeof(TNumber) == typeof(double)) else if (typeof(TNumber) == typeof(double))
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value);
var result = Vector.Cos(v); var result = Vector.Cos(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<double>, WideLane<TNumber>>(result);
} }
return value; return value;
@@ -1137,15 +1124,15 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value);
var (sinResult, cosResult) = Vector.SinCos(v); var (sinResult, cosResult) = Vector.SinCos(v);
sin = new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(sinResult)); sin = Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(sinResult));
cos = new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(cosResult)); cos = Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(cosResult));
} }
else if (typeof(TNumber) == typeof(double)) else if (typeof(TNumber) == typeof(double))
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value);
var (sinResult, cosResult) = Vector.SinCos(v); var (sinResult, cosResult) = Vector.SinCos(v);
sin = new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(sinResult)); sin = Unsafe.BitCast<Vector<double>, WideLane<TNumber>>(sinResult);
cos = new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(cosResult)); cos = Unsafe.BitCast<Vector<double>, WideLane<TNumber>>(cosResult);
} }
else else
{ {
@@ -1290,13 +1277,13 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value);
var result = Vector.Exp(v); var result = Vector.Exp(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(result);
} }
else if (typeof(TNumber) == typeof(double)) else if (typeof(TNumber) == typeof(double))
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value);
var result = Vector.Exp(v); var result = Vector.Exp(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<double>, WideLane<TNumber>>(result);
} }
return value; return value;
@@ -1315,13 +1302,13 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value);
var result = Vector.Log(v); var result = Vector.Log(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(result);
} }
else if (typeof(TNumber) == typeof(double)) else if (typeof(TNumber) == typeof(double))
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value);
var result = Vector.Log(v); var result = Vector.Log(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<double>, WideLane<TNumber>>(result);
} }
return value; return value;
@@ -1334,13 +1321,13 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value);
var result = Vector.Log2(v); var result = Vector.Log2(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(result);
} }
else if (typeof(TNumber) == typeof(double)) else if (typeof(TNumber) == typeof(double))
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value);
var result = Vector.Log2(v); var result = Vector.Log2(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<double>, WideLane<TNumber>>(result);
} }
return value; return value;
@@ -1353,13 +1340,13 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value);
var result = Vector.Ceiling(v); var result = Vector.Ceiling(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(result);
} }
else if (typeof(TNumber) == typeof(double)) else if (typeof(TNumber) == typeof(double))
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value);
var result = Vector.Ceiling(v); var result = Vector.Ceiling(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<double>, WideLane<TNumber>>(result);
} }
return value; return value;
@@ -1372,13 +1359,13 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value);
var result = Vector.Round(v); var result = Vector.Round(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(result);
} }
else if (typeof(TNumber) == typeof(double)) else if (typeof(TNumber) == typeof(double))
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value);
var result = Vector.Round(v); var result = Vector.Round(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<double>, WideLane<TNumber>>(result);
} }
return value; return value;
@@ -1391,13 +1378,13 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<float>>(value);
var result = Vector.Truncate(v); var result = Vector.Truncate(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<float>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<float>, WideLane<TNumber>>(result);
} }
else if (typeof(TNumber) == typeof(double)) else if (typeof(TNumber) == typeof(double))
{ {
var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value); var v = Unsafe.BitCast<WideLane<TNumber>, Vector<double>>(value);
var result = Vector.Truncate(v); var result = Vector.Truncate(v);
return new WideLane<TNumber>(Unsafe.BitCast<Vector<double>, Vector<TNumber>>(result)); return Unsafe.BitCast<Vector<double>, WideLane<TNumber>>(result);
} }
return value; return value;
@@ -1418,7 +1405,7 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> CopySign(WideLane<TNumber> magnitude, WideLane<TNumber> sign) public static WideLane<TNumber> CopySign(WideLane<TNumber> magnitude, WideLane<TNumber> sign)
{ {
return new WideLane<TNumber>(Vector.CopySign(magnitude.value, sign.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.CopySign(magnitude.value, sign.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -1538,40 +1525,46 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Select(WideLane<TNumber> conditionMask, WideLane<TNumber> ifTrue, WideLane<TNumber> ifFalse) public static WideLane<TNumber> Select(WideLane<TNumber> conditionMask, WideLane<TNumber> ifTrue, WideLane<TNumber> ifFalse)
{ {
return new WideLane<TNumber>(Vector.ConditionalSelect( return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.ConditionalSelect(
conditionMask.value, conditionMask.value,
ifTrue.value, ifTrue.value,
ifFalse.value)); ifFalse.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Select(byte conditionMask, WideLane<TNumber> ifTrue, WideLane<TNumber> ifFalse)
{
throw new NotImplementedException();
}
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> GreaterThan(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> GreaterThan(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(Vector.GreaterThan(a.value, b.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.GreaterThan(a.value, b.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> GreaterThanOrEqual(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> GreaterThanOrEqual(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(Vector.GreaterThanOrEqual(a.value, b.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.GreaterThanOrEqual(a.value, b.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> LessThan(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> LessThan(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(Vector.LessThan(a.value, b.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.LessThan(a.value, b.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> LessThanOrEqual(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> LessThanOrEqual(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(Vector.LessThanOrEqual(a.value, b.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.LessThanOrEqual(a.value, b.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public static WideLane<TNumber> Equal(WideLane<TNumber> a, WideLane<TNumber> b) public static WideLane<TNumber> Equal(WideLane<TNumber> a, WideLane<TNumber> b)
{ {
return new WideLane<TNumber>(Vector.Equals(a.value, b.value)); return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(Vector.Equals(a.value, b.value));
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -1617,4 +1610,10 @@ public readonly unsafe partial struct WideLane<TNumber> : ISPMDLane<WideLane<TNu
{ {
return value.ToString(); return value.ToString();
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static implicit operator WideLane<TNumber>(Vector<TNumber> v)
{
return Unsafe.BitCast<Vector<TNumber>, WideLane<TNumber>>(v);
}
} }

View File

@@ -66,7 +66,7 @@ internal static class JobExecutor
} }
} }
public static unsafe void ExecuteCustom<T>(int dataID, int dataGeneration, ref JobRanges jobRanges, ref readonly JobExecutionContext ctx) public unsafe static void ExecuteCustom<T>(int dataID, int dataGeneration, ref JobRanges jobRanges, ref readonly JobExecutionContext ctx)
{ {
ref var job = ref JobDataPool<T>.GetReference(dataID, dataGeneration, out var exists); ref var job = ref JobDataPool<T>.GetReference(dataID, dataGeneration, out var exists);
Debug.Assert(exists, "Job data not found in the pool."); Debug.Assert(exists, "Job data not found in the pool.");
@@ -80,7 +80,7 @@ internal static class JobExecutor
} }
} }
public static unsafe void FreeCustom<T>(ref readonly JobInfo jobInfo) public unsafe static void FreeCustom<T>(ref readonly JobInfo jobInfo)
{ {
ref var job = ref JobDataPool<T>.GetReference(jobInfo.dataID, jobInfo.dataGeneration, out var exists); ref var job = ref JobDataPool<T>.GetReference(jobInfo.dataID, jobInfo.dataGeneration, out var exists);
Debug.Assert(exists, "Job data not found in the pool."); Debug.Assert(exists, "Job data not found in the pool.");

View File

@@ -1,5 +1,3 @@
using System.Runtime.CompilerServices;
namespace Misaki.HighPerformance.Jobs; namespace Misaki.HighPerformance.Jobs;
public readonly struct JobHandle : IEquatable<JobHandle> public readonly struct JobHandle : IEquatable<JobHandle>
@@ -7,17 +5,8 @@ public readonly struct JobHandle : IEquatable<JobHandle>
private readonly int _id; private readonly int _id;
private readonly int _generation; private readonly int _generation;
public int ID public int ID => _id - 1;
{ public int Generation => _generation - 1;
[MethodImpl(MethodImplOptions.AggressiveInlining)]
get => _id;
}
public int generation
{
[MethodImpl(MethodImplOptions.AggressiveInlining)]
get => _generation;
}
public static JobHandle Invalid => default; public static JobHandle Invalid => default;
@@ -25,8 +14,8 @@ public readonly struct JobHandle : IEquatable<JobHandle>
internal JobHandle(int id, int generation) internal JobHandle(int id, int generation)
{ {
_id = id; _id = id + 1;
_generation = generation; _generation = generation + 1;
} }
public bool Equals(JobHandle other) public bool Equals(JobHandle other)

View File

@@ -1,5 +1,4 @@
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
namespace Misaki.HighPerformance.Jobs; namespace Misaki.HighPerformance.Jobs;
@@ -30,22 +29,10 @@ public enum JobState
Completed = 3 Completed = 3
} }
/// <summary>
/// The priority level of a job.
/// </summary>
public enum JobPriority public enum JobPriority
{ {
/// <summary> High = 0,
/// Normal priority. Which will have 37.5% chance to be picked when there are multiple jobs ready to run. Normal = 1,
/// </summary>
Normal = 0,
/// <summary>
/// High priority. Which will have 50.0% chance to be picked when there are multiple jobs ready to run. This is useful for jobs that are on the critical path of the execution and we want to prioritize their completion.
/// </summary>
High = 1,
/// <summary>
/// Low priority. Which will have 12.5% chance to be picked when there are multiple jobs ready to run.
/// </summary>
Low = 2 Low = 2
} }
@@ -74,7 +61,6 @@ public unsafe ref struct CustomJobDesc<T>
public JobPriority priority; public JobPriority priority;
} }
[StructLayout(LayoutKind.Sequential, Pack = 8)]
internal unsafe struct JobInfo internal unsafe struct JobInfo
{ {
public ref struct DependentIterator public ref struct DependentIterator
@@ -204,96 +190,4 @@ internal static class JobUtility
{ {
return (Interlocked.Add(ref jobState, -RC_ONE) & ~STATE_MASK) >> RC_SHIFT; return (Interlocked.Add(ref jobState, -RC_ONE) & ~STATE_MASK) >> RC_SHIFT;
} }
public static unsafe bool TryHelpExecuteJob(JobScheduler jobScheduler, JobHandle handle, int callerThreadIndex)
{
ref var jobInfo = ref jobScheduler.GetJobInfoReference(handle, out var exist);
if (!exist)
{
return false;
}
var rcSpin = new SpinWait();
var rcAcquired = false;
int rc;
while (true)
{
jobScheduler.GetJobInfoReference(handle, out var currentExist);
if (!currentExist)
{
return false;
}
var stateVal = Volatile.Read(ref jobInfo.state);
var state = GetState(stateVal);
// We can't execute it if it's not ready or already done
if (state == JobState.Created || state == JobState.Completed || state == JobState.Invalid)
{
return false;
}
// If it's single job and already running, we can't help it unless we restructure it.
// But if it's a Parallel job, multiple threads CAN safely join the `Running` state.
if (state == JobState.Running && jobInfo.jobRanges.batchSize == jobInfo.jobRanges.totalIteration)
{
// Single execution job is already running on another thread. We just return false.
return false;
}
var newState = stateVal + RC_ONE;
if (state == JobState.Scheduled)
{
newState = (newState & ~STATE_MASK) | JOBSTATE_RUNNING;
}
if (Interlocked.CompareExchange(ref jobInfo.state, newState, stateVal) == stateVal)
{
jobScheduler.GetJobInfoReference(handle, out currentExist);
if (!currentExist)
{
rc = ReleaseRC(ref jobInfo.state);
if (rc == 0)
{
jobScheduler.MarkJobComplete(handle);
}
return false;
}
rcAcquired = true;
break;
}
rcSpin.SpinOnce(-1);
}
if (!rcAcquired)
{
return false;
}
// Execute the work inline
if (jobInfo.pExecutionFunc != null)
{
var ctx = new JobExecutionContext
{
ThreadIndex = callerThreadIndex,
JobScheduler = jobScheduler,
State = jobScheduler.State,
SelfHandle = handle,
};
jobInfo.pExecutionFunc(jobInfo.dataID, jobInfo.dataGeneration, ref jobInfo.jobRanges, in ctx);
}
rc = ReleaseRC(ref jobInfo.state);
if (rc == 0)
{
jobScheduler.MarkJobComplete(handle);
}
return true;
}
} }

View File

@@ -61,7 +61,7 @@ internal sealed class WaitItem : IThreadPoolWorkItem
public void Execute() public void Execute()
{ {
_scheduler.Wait(_jobHandle, false); _scheduler.Wait(_jobHandle);
_completionSource.SetResult(); _completionSource.SetResult();
} }
} }
@@ -121,6 +121,9 @@ internal sealed class WaitAnyItem : IThreadPoolWorkItem
/// </summary> /// </summary>
public sealed unsafe partial class JobScheduler : IDisposable public sealed unsafe partial class JobScheduler : IDisposable
{ {
// Don't sleep indefinitely because that causes our 1ms job to become 15ms.
private const int _SLEEP_THRESHOLD = -1;
private readonly ConcurrentSlotMap<JobInfo> _jobInfoPool; private readonly ConcurrentSlotMap<JobInfo> _jobInfoPool;
private readonly ConcurrentQueue<JobHandle>[] _jobQueues; private readonly ConcurrentQueue<JobHandle>[] _jobQueues;
private readonly WorkerThread[] _workerThreads; private readonly WorkerThread[] _workerThreads;
@@ -132,9 +135,6 @@ public sealed unsafe partial class JobScheduler : IDisposable
private readonly SemaphoreSlim _workSignal; private readonly SemaphoreSlim _workSignal;
private readonly CancellationTokenSource _cts; private readonly CancellationTokenSource _cts;
private readonly int _workerThreadCount;
private readonly int _helperThreadCount;
private readonly object? _state; private readonly object? _state;
private bool _disposed = false; private bool _disposed = false;
@@ -145,17 +145,7 @@ public sealed unsafe partial class JobScheduler : IDisposable
/// <summary> /// <summary>
/// Gets the number of worker threads managed by the job scheduler. /// Gets the number of worker threads managed by the job scheduler.
/// </summary> /// </summary>
public int WorkerCount => _workerThreadCount; public int WorkerCount => _workerThreads.Length;
/// <summary>
/// Gets the number of external helper threads, which is the number of threads reserved for external use.
/// </summary>
public int ExternalHelperThreadCount => _helperThreadCount;
/// <summary>
/// Gets the total number of threads that is possible to execute the jobs, including worker threads and external helper threads. You can use this property to allocate thread local storage.
/// </summary>
public int ThreadLocalCount => _workerThreads.Length;
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="JobScheduler"/> class with the specified description. /// Initializes a new instance of the <see cref="JobScheduler"/> class with the specified description.
@@ -180,18 +170,57 @@ public sealed unsafe partial class JobScheduler : IDisposable
_workSignal = new SemaphoreSlim(0); _workSignal = new SemaphoreSlim(0);
_cts = new CancellationTokenSource(); _cts = new CancellationTokenSource();
_workerThreadCount = workerCount;
_helperThreadCount = 1;
_state = desc.State; _state = desc.State;
_workerThreads = new WorkerThread[workerCount + _helperThreadCount]; _workerThreads = new WorkerThread[workerCount];
for (var i = _helperThreadCount; i < _workerThreads.Length; i++) for (var i = 0; i < workerCount; i++)
{ {
_workerThreads[i] = new WorkerThread(i, this, desc.ThreadPriority); _workerThreads[i] = new WorkerThread(i, this, desc.ThreadPriority);
} }
for (var i = _helperThreadCount; i < _workerThreads.Length; i++) foreach (var worker in _workerThreads)
{
worker.Start();
}
}
/// <summary>
/// Initializes a new instance of the <see cref="JobScheduler"/> class with the specified number of worker threads.
/// </summary>
/// <param name="threadCount">The number of worker threads to create. If less than 1, at least one thread will be created.</param>
/// <param name="priority">The priority of the worker threads.</param>
/// <param name="state">The state object for the job scheduler.</param>
[Obsolete("Use JobScheduler(JobSchedulerDesc) instead.")]
public JobScheduler(int threadCount, ThreadPriority priority = ThreadPriority.Normal, object? state = null)
{
var workerCount = Math.Max(1, threadCount);
_jobInfoPool = new ConcurrentSlotMap<JobInfo>(128);
_jobQueues = new ConcurrentQueue<JobHandle>[3];
for (var i = 0; i < 3; i++)
{
_jobQueues[i] = new ConcurrentQueue<JobHandle>();
}
_jobEdges = new JobEdge[4096];
_watermark = 0;
_freeListHead = -1L;
_workSignal = new SemaphoreSlim(0);
_cts = new CancellationTokenSource();
_state = state;
_workerThreads = new WorkerThread[workerCount];
for (var i = 0; i < workerCount; i++)
{
_workerThreads[i] = new WorkerThread(i, this, priority);
}
for (var i = 0; i < workerCount; i++)
{ {
_workerThreads[i].Start(); _workerThreads[i].Start();
} }
@@ -204,7 +233,7 @@ public sealed unsafe partial class JobScheduler : IDisposable
private void EnqueueJobIfReady(JobHandle handle, bool preferLocal) private void EnqueueJobIfReady(JobHandle handle, bool preferLocal)
{ {
ref var jobInfo = ref _jobInfoPool.GetElementReferenceAt(handle.ID, handle.generation, out var exist); ref var jobInfo = ref _jobInfoPool.GetElementReferenceAt(handle.ID, handle.Generation, out var exist);
if (exist && Volatile.Read(ref jobInfo.dependencyCount) == 0) if (exist && Volatile.Read(ref jobInfo.dependencyCount) == 0)
{ {
@@ -306,7 +335,7 @@ public sealed unsafe partial class JobScheduler : IDisposable
{ {
var dependency = dependencies[i]; var dependency = dependencies[i];
ref var depJobInfo = ref _jobInfoPool.GetElementReferenceAt(dependency.ID, dependency.generation, out var exist); ref var depJobInfo = ref _jobInfoPool.GetElementReferenceAt(dependency.ID, dependency.Generation, out var exist);
if (!exist) if (!exist)
{ {
Interlocked.Decrement(ref infoInPool.dependencyCount); Interlocked.Decrement(ref infoInPool.dependencyCount);
@@ -406,20 +435,58 @@ public sealed unsafe partial class JobScheduler : IDisposable
return ref Unsafe.NullRef<JobInfo>(); return ref Unsafe.NullRef<JobInfo>();
} }
return ref _jobInfoPool.GetElementReferenceAt(handle.ID, handle.generation, out exist); return ref _jobInfoPool.GetElementReferenceAt(handle.ID, handle.Generation, out exist);
} }
internal void MarkJobComplete(JobHandle handle) internal void MarkJobComplete(JobHandle handle)
{ {
Debug.Assert(handle.IsValid); Debug.Assert(handle.IsValid);
ref var info = ref _jobInfoPool.GetElementReferenceAt(handle.ID, handle.generation, out var exist); ref var info = ref _jobInfoPool.GetElementReferenceAt(handle.ID, handle.Generation, out var exist);
if (!exist) if (!exist)
{ {
return; return;
} }
// NOTE: Because we call this on the thread that get rc = 0, not the last one to complete. So we can directly set state to Completed without caring about RC. This also means we don't need to preserve upper bits. #if false
// Lock-free Completion:
// 1. Transition State to Completed (preserving or setting upper bits?).
// Actually, we want to block new Readers. Setting state to Completed blocks new Readers.
// 2. Wait for existing Readers (RC == 0).
var spin = new SpinWait();
while (true)
{
var stateVal = Volatile.Read(ref info.state);
var state = JobUtility.GetState(stateVal);
if (state == JobState.Completed)
{
return;
}
// Preserve upper bits (RC) and set state to Completed. This blocks new Readers.
var newState = (stateVal & ~JobUtility.STATE_MASK) | (int)JobState.Completed;
if (Interlocked.CompareExchange(ref info.state, newState, stateVal) == stateVal)
{
// Successfully set State to Completed. New readers will see Completed and back off.
// Now we must wait for existing readers to finish (RC to become 0).
while (true)
{
var current = Volatile.Read(ref info.state);
if (((uint)current >> 16) == 0)
{
break; // RC is 0. Safe to proceed.
}
spin.SpinOnce(-1);
}
break;
}
spin.SpinOnce(-1);
}
#else
// NOTE: We are the last one to complete. Because we call this on the thread that get rc = 0, not the last one to complete. So we can directly set state to Completed without caring about RC. This also means we don't need to preserve upper bits.
var spin = new SpinWait(); var spin = new SpinWait();
while (Interlocked.CompareExchange(ref info.state, JobUtility.JOBSTATE_COMPLETED, JobUtility.JOBSTATE_RUNNING) != JobUtility.JOBSTATE_RUNNING) while (Interlocked.CompareExchange(ref info.state, JobUtility.JOBSTATE_COMPLETED, JobUtility.JOBSTATE_RUNNING) != JobUtility.JOBSTATE_RUNNING)
{ {
@@ -430,13 +497,14 @@ public sealed unsafe partial class JobScheduler : IDisposable
spin.SpinOnce(-1); spin.SpinOnce(-1);
} }
#endif
var it = info.GetDependentIterator(_jobEdges); var it = info.GetDependentIterator(_jobEdges);
while (it.MoveNext()) while (it.MoveNext())
{ {
var depHandle = it.Current; var depHandle = it.Current;
ref var depJobInfo = ref _jobInfoPool.GetElementReferenceAt(depHandle.ID, depHandle.generation, out var depExist); ref var depJobInfo = ref _jobInfoPool.GetElementReferenceAt(depHandle.ID, depHandle.Generation, out var depExist);
if (depExist && Interlocked.Decrement(ref depJobInfo.dependencyCount) == 0) if (depExist && Interlocked.Decrement(ref depJobInfo.dependencyCount) == 0)
{ {
EnqueueJobIfReady(depHandle, true); EnqueueJobIfReady(depHandle, true);
@@ -450,7 +518,7 @@ public sealed unsafe partial class JobScheduler : IDisposable
info.pFreeFunc(in info); info.pFreeFunc(in info);
} }
_jobInfoPool.Remove(handle.ID, handle.generation); _jobInfoPool.Remove(handle.ID, handle.Generation);
} }
/// <summary> /// <summary>
@@ -776,7 +844,7 @@ public sealed unsafe partial class JobScheduler : IDisposable
return JobState.Invalid; return JobState.Invalid;
} }
ref var jobInfo = ref _jobInfoPool.GetElementReferenceAt(handle.ID, handle.generation, out var exist); ref var jobInfo = ref _jobInfoPool.GetElementReferenceAt(handle.ID, handle.Generation, out var exist);
if (!exist) if (!exist)
{ {
return JobState.Completed; // We assume completed if not found. Invalid state is reserved for error. return JobState.Completed; // We assume completed if not found. Invalid state is reserved for error.
@@ -790,8 +858,7 @@ public sealed unsafe partial class JobScheduler : IDisposable
/// Blocks the calling thread until the specified job is completed. /// Blocks the calling thread until the specified job is completed.
/// </summary> /// </summary>
/// <param name="handle">The handle of the job to wait for.</param> /// <param name="handle">The handle of the job to wait for.</param>
/// <param name="inlineExecution">A value indicating whether to help execute the job while waiting. Defaults to true. Only ONE external thread is allowed to help execute jobs if you rely on thread local storage.</param> public void Wait(JobHandle handle)
public void Wait(JobHandle handle, bool inlineExecution = true)
{ {
if (!handle.IsValid) if (!handle.IsValid)
{ {
@@ -801,38 +868,23 @@ public sealed unsafe partial class JobScheduler : IDisposable
// TODO: Maybe we can steal a up stream or current job to execute while waiting? // TODO: Maybe we can steal a up stream or current job to execute while waiting?
// For example, if we wait on job A which depends on job B, and both are not scheduled yet, we can steal and execute job B to speed up the completion of A. // For example, if we wait on job A which depends on job B, and both are not scheduled yet, we can steal and execute job B to speed up the completion of A.
var callerThreadIndex = WorkerThread.ThreadIndex;
var spin = new SpinWait(); var spin = new SpinWait();
while (true) while (true)
{ {
ref var jobInfo = ref _jobInfoPool.GetElementReferenceAt(handle.ID, handle.generation, out var exist); ref var jobInfo = ref _jobInfoPool.GetElementReferenceAt(handle.ID, handle.Generation, out var exist);
if (!exist) if (!exist)
{ {
return; return;
} }
// Mask out RC // Mask out RC
var state = JobUtility.ReadState(ref jobInfo); if (JobUtility.ReadState(ref jobInfo) == JobState.Completed)
if (state == JobState.Completed)
{ {
return; return;
} }
var madeProgress = false; // var sleepThreshold = jobInfo.jobRanges.totalIteration * jobInfo.jobRanges.batchSize * 100;
if (inlineExecution) spin.SpinOnce(_SLEEP_THRESHOLD);
{
// Only try to help execute THIS specific job.
if (state == JobState.Scheduled || (state == JobState.Running && jobInfo.jobRanges.totalIteration > jobInfo.jobRanges.batchSize))
{
madeProgress = JobUtility.TryHelpExecuteJob(this, handle, callerThreadIndex);
}
}
if (!madeProgress)
{
spin.SpinOnce(-1); // Never sleep and yield to achieve lowest latency for single job completion.
}
} }
} }
@@ -851,7 +903,6 @@ public sealed unsafe partial class JobScheduler : IDisposable
} }
var spin = new SpinWait(); var spin = new SpinWait();
var sleepThreshold = handles.Length * 20;
var completedCount = 0; var completedCount = 0;
while (true) while (true)
@@ -859,7 +910,7 @@ public sealed unsafe partial class JobScheduler : IDisposable
for (var i = completedCount; i < handles.Length; i++) for (var i = completedCount; i < handles.Length; i++)
{ {
var handle = handles[i]; var handle = handles[i];
if (!_jobInfoPool.Contains(handle.ID, handle.generation)) if (!_jobInfoPool.Contains(handle.ID, handle.Generation))
{ {
// Move completed handle to the front (completedCount index) to avoid checking it again. // Move completed handle to the front (completedCount index) to avoid checking it again.
var temp = handles[completedCount]; var temp = handles[completedCount];
@@ -875,7 +926,7 @@ public sealed unsafe partial class JobScheduler : IDisposable
return; return;
} }
spin.SpinOnce(sleepThreshold); spin.SpinOnce(_SLEEP_THRESHOLD);
} }
} }
@@ -887,19 +938,18 @@ public sealed unsafe partial class JobScheduler : IDisposable
public JobHandle WaitAny(params ReadOnlySpan<JobHandle> handles) public JobHandle WaitAny(params ReadOnlySpan<JobHandle> handles)
{ {
var spin = new SpinWait(); var spin = new SpinWait();
var sleepThreshold = handles.Length * 10;
while (true) while (true)
{ {
foreach (var handle in handles) foreach (var handle in handles)
{ {
if (!_jobInfoPool.Contains(handle.ID, handle.generation)) if (!_jobInfoPool.Contains(handle.ID, handle.Generation))
{ {
return handle; return handle;
} }
} }
spin.SpinOnce(sleepThreshold); spin.SpinOnce(_SLEEP_THRESHOLD);
} }
} }
@@ -972,9 +1022,9 @@ public sealed unsafe partial class JobScheduler : IDisposable
_cts.Cancel(); _cts.Cancel();
for (var i = _helperThreadCount; i < _workerThreads.Length; i++) foreach (var worker in _workerThreads)
{ {
_workerThreads[i].Dispose(); worker.Dispose();
} }
_workSignal.Dispose(); _workSignal.Dispose();

View File

@@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk"> <Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup> <PropertyGroup>
<TargetFramework>net10.0</TargetFramework> <TargetFramework>net10.0</TargetFramework>
@@ -6,24 +6,17 @@
<Nullable>enable</Nullable> <Nullable>enable</Nullable>
<AllowUnsafeBlocks>True</AllowUnsafeBlocks> <AllowUnsafeBlocks>True</AllowUnsafeBlocks>
<GeneratePackageOnBuild>True</GeneratePackageOnBuild> <GeneratePackageOnBuild>True</GeneratePackageOnBuild>
<AssemblyVersion>3.1.8</AssemblyVersion> <AssemblyVersion>3.1.6</AssemblyVersion>
<Version>$(AssemblyVersion)</Version> <Version>$(AssemblyVersion)</Version>
<Authors>Misaki</Authors> <Authors>Misaki</Authors>
<PackageProjectUrl>https://git.personalnas.com/Misaki/Misaki.HighPerformance.git</PackageProjectUrl> <PackageProjectUrl>https://git.personalnas.com/Misaki/Misaki.HighPerformance.git</PackageProjectUrl>
<RepositoryUrl>https://github.com/misakieku/Misaki.HighPerformance.git</RepositoryUrl> <RepositoryUrl>https://git.personalnas.com/Misaki/Misaki.HighPerformance.git</RepositoryUrl>
<PackageLicenseFile>LICENSE</PackageLicenseFile> <PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageReadmeFile>README.md</PackageReadmeFile> <PackageReadmeFile>README.md</PackageReadmeFile>
<PackageTags>job-system;multithreading;concurrency;task-scheduler;work-stealing;zero-allocation;0gc;dag;dependency-graph;game-engine;ecs;high-performance;spmc</PackageTags>
<Description>
A high-performance, zero-allocation (0 GC), and zero-closure job system designed for custom game engines and data-oriented design (DOD).
Features a lock-free Work-Stealing scheduler (SPMC) with DAG-based multi-dependency resolution.
Includes a blazingly fast O(1) branchless priority queue (High/Normal/Low) using Cascade LUTs.
Uniquely supports both unmanaged and managed jobs seamlessly via internal pooling, offering maximum flexibility without compromising C# GC performance.
</Description>
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
<None Include="../../LICENSE" Pack="true" PackagePath=""/> <None Include="../LICENSE" Pack="true" PackagePath=""/>
<None Include="README.md" Pack="true" PackagePath=""/> <None Include="README.md" Pack="true" PackagePath=""/>
</ItemGroup> </ItemGroup>

View File

@@ -29,6 +29,7 @@ This package provides job contracts, scheduling, worker threads, and dependency
- `JobExecutionContext` - `JobExecutionContext`
- `JobState` - `JobState`
- `WorkerThread` - `WorkerThread`
- `TempJobAllocator`
## Example ## Example
@@ -49,15 +50,6 @@ public struct AddJob : IJob
} }
} }
JobSchedulerDesc desc = new JobSchedulerDesc
{
ThreadCount = Environment.ProcessorCount,
ThreadPriority = ThreadPriority.Normal,
DependencyChainCapacity = 64,
};
JobScheduler jobScheduler = new JobScheduler(in desc);
int a = 5; int a = 5;
int b = 10; int b = 10;
int result = 0; int result = 0;
@@ -71,8 +63,6 @@ AddJob job = new AddJob
JobHandle handle = jobScheduler.Schedule(job); JobHandle handle = jobScheduler.Schedule(job);
jobScheduler.Wait(handle); jobScheduler.Wait(handle);
Console.WriteLine($"Result: {result}"); // Output: Result: 15
``` ```
### IJobParallelFor example ### IJobParallelFor example

View File

@@ -1,45 +1,24 @@
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using System.Numerics; using System.Numerics;
using System.Runtime.InteropServices;
namespace Misaki.HighPerformance.Jobs; namespace Misaki.HighPerformance.Jobs;
[StructLayout(LayoutKind.Sequential)]
public class SPMCQueue<T> public class SPMCQueue<T>
{ {
private struct __padding
{
private unsafe fixed byte _padding[64];
}
private readonly T[] _queue; private readonly T[] _queue;
private readonly int _mask; private readonly int _mask;
private int _head; private int _head;
private __padding _padding;
private int _tail; private int _tail;
public bool IsEmpty => Volatile.Read(ref _tail) - Volatile.Read(ref _head) <= 0; public bool IsEmpty => Volatile.Read(ref _tail) - Volatile.Read(ref _head) <= 0;
/// <summary>
/// Initializes a new instance of the SPMCQueue class with the specified capacity.
/// </summary>
/// <remarks>
/// This queue will not resize when it reaches capacity.
/// </remarks>
/// <param name="capacity">The capacity of the queue.</param>
public SPMCQueue(int capacity) public SPMCQueue(int capacity)
{ {
var powerOfTwoCapacity = (int)BitOperations.RoundUpToPowerOf2((uint)capacity); _queue = new T[(int)BitOperations.RoundUpToPowerOf2((uint)capacity)];
_queue = new T[powerOfTwoCapacity]; _mask = capacity - 1;
_mask = powerOfTwoCapacity - 1;
} }
/// <summary>
/// Tries to push an item onto the queue.
/// </summary>
/// <param name="item">The item to push.</param>
/// <returns>True if the item was pushed successfully; otherwise, false.</returns>
public bool TryPush(T item) public bool TryPush(T item)
{ {
var tail = _tail; var tail = _tail;
@@ -56,11 +35,6 @@ public class SPMCQueue<T>
return true; return true;
} }
/// <summary>
/// Trys to pop an item from the queue.
/// </summary>
/// <param name="item">The item to pop.</param>
/// <returns>True if an item was popped successfully; otherwise, false.</returns>
public bool TryPop([MaybeNullWhen(false)] out T? item) public bool TryPop([MaybeNullWhen(false)] out T? item)
{ {
var tail = _tail - 1; var tail = _tail - 1;
@@ -96,11 +70,6 @@ public class SPMCQueue<T>
return false; return false;
} }
/// <summary>
/// Trys to steal an item from the queue.
/// </summary>
/// <param name="item">The item to steal.</param>
/// <returns>True if an item was stolen successfully; otherwise, false.</returns>
public bool TrySteal([MaybeNullWhen(false)] out T? item) public bool TrySteal([MaybeNullWhen(false)] out T? item)
{ {
var head = Volatile.Read(ref _head); var head = Volatile.Read(ref _head);

Some files were not shown because too many files have changed in this diff Show More