<#@ template debug="false" hostspecific="false" language="C#" #>
<#@ assembly name="System.Core" #>
<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Collections.Generic" #>
<#@ output extension="gen.cs" #>
using Misaki.HighPerformance.Jobs;
using System.Numerics;
namespace Misaki.HighPerformance.Mathematics.SPMD;
<#
// Emit one IJobSPMD interface plus its two wrapper structs per arity.
// i is the zero-based arity index, so each pass generates i + 1 numeric
// type parameters (TNumber0 .. TNumber{i}), from 1 up to 8.
for (var i = 0; i < 8; i++) { #>
/// <summary>
/// A job interface for Single Program Multiple Data (SPMD) execution, allowing for efficient parallel processing of data across multiple lanes.
/// </summary>
/// <remarks>
/// Always use TNumber0 as the primary type for determining lane width and job scheduling, even if it's not used in the job execution.
/// </remarks>
<#= ForEachDimension(i + 1, j => @$"/// <typeparam name=""TNumber{j}"">The numeric type with index {j} used in the SPMD job.</typeparam>", Environment.NewLine) #>
public interface IJobSPMD<<#= ForEachDimension(i + 1, j => $"TNumber{j}") #>>
<#= GetTNumberRestrictions(i + 1) #>
{
    /// <summary>
    /// Executes the job body for the lane type(s) supplied as type arguments.
    /// </summary>
    /// <param name="baseIndex">The iteration index passed in by the caller for this invocation.</param>
    /// <param name="ctx">The job execution context forwarded by the caller.</param>
    void Execute<<#= ForEachDimension(i + 1, j => $"TLane{j}") #>>(int baseIndex, ref readonly JobExecutionContext ctx)
<#= GetTLaneRestrictions(i + 1, " ") #>;
}
/// <summary>
/// Adapts an SPMD job to <c>IJobParallelFor</c>: each loop index covers one
/// wide-lane batch, and a trailing partial batch is executed one element at a
/// time through the scalar lane.
/// </summary>
internal struct SPMDJobWrapper<T, <#= ForEachDimension(i + 1, j => $"TNumber{j}") #>> : IJobParallelFor
    where T : unmanaged, IJobSPMD<<#= ForEachDimension(i + 1, j => $"TNumber{j}") #>>
<#= GetTNumberRestrictions(i + 1) #>
{
    /// <summary>The wrapped SPMD job.</summary>
    public T innerJob;

    /// <summary>Total number of elements to process; used to detect the partial last batch.</summary>
    public int totalIteration;

    /// <summary>
    /// Runs the batch identified by <paramref name="loopIndex"/>.
    /// </summary>
    /// <param name="loopIndex">Zero-based batch index; multiplied by the TNumber0 lane width to obtain the first element index.</param>
    /// <param name="ctx">The job execution context forwarded to the wrapped job.</param>
    public void Execute(int loopIndex, ref readonly JobExecutionContext ctx)
    {
        // TNumber0 is the primary type that determines the lane width.
        var laneWidth = WideLane<TNumber0>.LaneWidth;
        var baseIndex = loopIndex * laneWidth;
        var remaining = totalIteration - baseIndex;

        if (remaining >= laneWidth)
        {
            // A full batch is available: process it with the wide lanes.
            innerJob.Execute<<#= ForEachDimension(i + 1, j => $"WideLane<TNumber{j}>") #>>(baseIndex, in ctx);
        }
        else
        {
            // Partial last batch: fall back to one scalar invocation per element.
            for (var offset = 0; offset < remaining; offset++)
            {
                innerJob.Execute<<#= ForEachDimension(i + 1, j => $"ScalarLane<TNumber{j}>") #>>(baseIndex + offset, in ctx);
            }
        }
    }
}
/// <summary>
/// Adapts an SPMD job to <c>IJobParallelFor</c> using only the scalar lane:
/// every loop index maps to exactly one element.
/// </summary>
internal struct SPMDScalerJobWrapper<T, <#= ForEachDimension(i + 1, j => $"TNumber{j}") #>> : IJobParallelFor
    where T : unmanaged, IJobSPMD<<#= ForEachDimension(i + 1, j => $"TNumber{j}") #>>
<#= GetTNumberRestrictions(i + 1) #>
{
    /// <summary>The wrapped SPMD job.</summary>
    public T innerJob;

    /// <summary>Total number of elements to process. Not read by <see cref="Execute"/>; set by the code that creates the wrapper.</summary>
    public int totalIteration;

    /// <summary>
    /// Runs the wrapped job for the single element at <paramref name="loopIndex"/>.
    /// </summary>
    /// <param name="loopIndex">Index of the element to process.</param>
    /// <param name="ctx">The job execution context forwarded to the wrapped job.</param>
    public void Execute(int loopIndex, ref readonly JobExecutionContext ctx)
    {
        innerJob.Execute<<#= ForEachDimension(i + 1, j => $"ScalarLane<TNumber{j}>") #>>(loopIndex, in ctx);
    }
}
<# } #>
/// <summary>
/// Extension methods that run an SPMD job on the calling thread or schedule it
/// on a <c>JobScheduler</c>. One overload pair is generated per arity (1 to 8
/// numeric type parameters).
/// </summary>
public static class IJobParallelForSPMDExtensions
{
<# for (var i = 0; i < 8; i++) { #>
    /// <summary>
    /// Run the SPMD job with the specified total count and job execution context directly on the calling thread.
    /// </summary>
    /// <remarks>
    /// Always use TNumber0 as the primary type for determining lane width and job scheduling, even if it's not used in the job execution.
    /// </remarks>
    /// <typeparam name="T">The SPMD job type.</typeparam>
<#= ForEachDimension(i + 1, j => @$"    /// <typeparam name=""TNumber{j}"">The numeric type with index {j} used in the SPMD job.</typeparam>", Environment.NewLine) #>
    /// <param name="job">The SPMD job to run.</param>
    /// <param name="totalIteration">The total number of iterations to execute across all lanes.</param>
    /// <param name="ctx">The job execution context providing information about the current execution environment.</param>
    public static void Run<T, <#= ForEachDimension(i + 1, j => $"TNumber{j}") #>>(this ref T job, int totalIteration, ref readonly JobExecutionContext ctx)
        where T : struct, IJobSPMD<<#= ForEachDimension(i + 1, j => $"TNumber{j}") #>>
<#= GetTNumberRestrictions(i + 1) #>
    {
        var index = 0;

        if (WideLane<TNumber0>.IsSupported)
        {
            var laneWidth = WideLane<TNumber0>.LaneWidth;

            // Full-width batches. The subtraction form cannot overflow, unlike
            // rounding up with (totalIteration + laneWidth - 1) near int.MaxValue.
            for (; totalIteration - index >= laneWidth; index += laneWidth)
            {
                job.Execute<<#= ForEachDimension(i + 1, j => $"WideLane<TNumber{j}>") #>>(index, in ctx);
            }
        }

        // Scalar tail after the wide batches, or the whole range when wide lanes are unsupported.
        for (; index < totalIteration; index++)
        {
            job.Execute<<#= ForEachDimension(i + 1, j => $"ScalarLane<TNumber{j}>") #>>(index, in ctx);
        }
    }

    /// <summary>
    /// Schedule the SPMD job for parallel execution across multiple threads, with the specified total count, batch size, and job execution context.
    /// </summary>
    /// <remarks>
    /// Always use TNumber0 as the primary type for determining lane width and job scheduling, even if it's not used in the job execution.
    /// </remarks>
    /// <typeparam name="T">The SPMD job type.</typeparam>
<#= ForEachDimension(i + 1, j => @$"    /// <typeparam name=""TNumber{j}"">The numeric type with index {j} used in the SPMD job.</typeparam>", Environment.NewLine) #>
    /// <param name="jobScheduler">The job scheduler to use for scheduling the job.</param>
    /// <param name="job">The SPMD job to schedule.</param>
    /// <param name="totalIteration">The total number of iterations to execute across all lanes.</param>
    /// <param name="batchSize">The number of iterations to execute in each batch for parallel execution.</param>
    /// <param name="preferLocal">Whether to prefer scheduling the job on the local thread for better cache locality.</param>
    /// <param name="priority">The priority of the job.</param>
    /// <param name="dependencies">Any job handles that this job depends on, which must complete before this job can start.</param>
    /// <returns>The handle returned by the scheduler's <c>ScheduleParallelFor</c> call.</returns>
    public static JobHandle ScheduleParallelSPDM<T, <#= ForEachDimension(i + 1, j => $"TNumber{j}") #>>(this JobScheduler jobScheduler, ref T job, int totalIteration, int batchSize, bool preferLocal, JobPriority priority, params ReadOnlySpan<JobHandle> dependencies)
        where T : unmanaged, IJobSPMD<<#= ForEachDimension(i + 1, j => $"TNumber{j}") #>>
<#= GetTNumberRestrictions(i + 1) #>
    {
        if (WideLane<TNumber0>.IsSupported)
        {
            var wideWrapper = new SPMDJobWrapper<T, <#= ForEachDimension(i + 1, j => $"TNumber{j}") #>>
            {
                innerJob = job,
                totalIteration = totalIteration,
            };

            // One parallel-for iteration per wide batch, rounded up. The rounding is
            // done in 64-bit so a totalIteration close to int.MaxValue cannot overflow.
            var laneWidth = WideLane<TNumber0>.LaneWidth;
            var iterations = (int)(((long)totalIteration + laneWidth - 1) / laneWidth);
            return jobScheduler.ScheduleParallelFor(ref wideWrapper, iterations, batchSize, preferLocal, priority, dependencies);
        }

        // Wide lanes unavailable: one parallel-for iteration per element.
        var scalarWrapper = new SPMDScalerJobWrapper<T, <#= ForEachDimension(i + 1, j => $"TNumber{j}") #>>
        {
            innerJob = job,
            totalIteration = totalIteration,
        };
        return jobScheduler.ScheduleParallelFor(ref scalarWrapper, totalIteration, batchSize, preferLocal, priority, dependencies);
    }
<# } #>
}
<#+
/// <summary>
/// Produces one text fragment per index in <c>0 .. dimension - 1</c> and joins them.
/// </summary>
/// <param name="dimension">Number of fragments to generate.</param>
/// <param name="action">Maps a zero-based index to its text fragment.</param>
/// <param name="spliter">Separator inserted between fragments; defaults to <c>", "</c>.</param>
/// <returns>The joined fragments, or an empty string when <paramref name="dimension"/> is 0.</returns>
public string ForEachDimension(int dimension, Func<int, string> action, string spliter = ", ")
{
    return string.Join(spliter, Enumerable.Range(0, dimension).Select(action));
}
/// <summary>
/// Builds the generic constraint clauses for <c>TNumber0 .. TNumber{dimension - 1}</c>,
/// one clause per line.
/// </summary>
/// <param name="dimension">Number of numeric type parameters to constrain.</param>
/// <param name="space">Text prepended to every clause (used as indentation).</param>
/// <returns>
/// The clauses separated by <see cref="Environment.NewLine"/> with no trailing
/// newline, or an empty string when <paramref name="dimension"/> is 0.
/// </returns>
public string GetTNumberRestrictions(int dimension, string space = " ")
{
    // Each numeric interface is constrained on its own type parameter.
    var clauses = Enumerable.Range(0, dimension).Select(i =>
        $"{space}where TNumber{i} : unmanaged, INumber<TNumber{i}>, IBinaryNumber<TNumber{i}>, IMinMaxValue<TNumber{i}>, IBitwiseOperators<TNumber{i}, TNumber{i}, TNumber{i}>");

    return string.Join(Environment.NewLine, clauses);
}
/// <summary>
/// Builds the generic constraint clauses for <c>TLane0 .. TLane{dimension - 1}</c>,
/// one clause per line, pairing each lane type with the numeric type of the same index.
/// </summary>
/// <param name="dimension">Number of lane type parameters to constrain.</param>
/// <param name="space">Text prepended to every clause (used as indentation).</param>
/// <returns>
/// The clauses separated by <see cref="Environment.NewLine"/> with no trailing
/// newline, or an empty string when <paramref name="dimension"/> is 0.
/// </returns>
public string GetTLaneRestrictions(int dimension, string space = " ")
{
    var clauses = Enumerable.Range(0, dimension).Select(i =>
        $"{space}where TLane{i} : unmanaged, ISPMDLane<TLane{i}, TNumber{i}>");

    return string.Join(Environment.NewLine, clauses);
}
#>