fix(collection) Fixed the race condition problem in ConcurrentSlotMap.

This commit is contained in:
2026-04-04 20:01:09 +09:00
parent b57b7adc77
commit d8be2c7b2a
2 changed files with 153 additions and 78 deletions

View File

@@ -14,6 +14,10 @@ public class ConcurrentSlotMap<T> : IEnumerable<T>
public int isValid; public int isValid;
} }
private const int CHUNK_SHIFT = 8;
private const int CHUNK_SIZE = 1 << CHUNK_SHIFT;
private const int CHUNK_MASK = CHUNK_SIZE - 1;
public struct Enumerator : IEnumerator<T> public struct Enumerator : IEnumerator<T>
{ {
private readonly ConcurrentSlotMap<T> _slotMap; private readonly ConcurrentSlotMap<T> _slotMap;
@@ -25,15 +29,30 @@ public class ConcurrentSlotMap<T> : IEnumerable<T>
_currentIndex = -1; _currentIndex = -1;
} }
public readonly T Current => _slotMap._data[_currentIndex].value!; public readonly T Current
{
get
{
var chunks = _slotMap._chunks;
int chunkIdx = _currentIndex >> CHUNK_SHIFT;
int localIdx = _currentIndex & CHUNK_MASK;
return chunks[chunkIdx][localIdx].value!;
}
}
readonly object? IEnumerator.Current => Current; readonly object? IEnumerator.Current => Current;
public bool MoveNext() public bool MoveNext()
{ {
var capacity = Volatile.Read(ref _slotMap._capacity); var maxIndex = Volatile.Read(ref _slotMap._nextSlotIndex);
while (++_currentIndex < capacity) var chunks = _slotMap._chunks;
while (++_currentIndex < maxIndex)
{ {
if (Volatile.Read(ref _slotMap._data[_currentIndex].isValid) == 1) int chunkIdx = _currentIndex >> CHUNK_SHIFT;
int localIdx = _currentIndex & CHUNK_MASK;
if (chunkIdx < chunks.Length && Volatile.Read(ref chunks[chunkIdx][localIdx].isValid) == 1)
{ {
return true; return true;
} }
@@ -52,7 +71,7 @@ public class ConcurrentSlotMap<T> : IEnumerable<T>
} }
} }
private volatile SlotEntry[] _data; private volatile SlotEntry[][] _chunks;
private readonly ConcurrentQueue<int> _freeSlots; private readonly ConcurrentQueue<int> _freeSlots;
private int _count; private int _count;
@@ -69,20 +88,30 @@ public class ConcurrentSlotMap<T> : IEnumerable<T>
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
public ConcurrentSlotMap(int initialCapacity = 16) public ConcurrentSlotMap(int initialCapacity = 256)
{ {
_capacity = initialCapacity;
_count = 0; _count = 0;
_nextSlotIndex = 0; _nextSlotIndex = 0;
_isResizing = 0; _isResizing = 0;
_data = new SlotEntry[initialCapacity]; int initialChunks = (initialCapacity + CHUNK_MASK) / CHUNK_SIZE;
if (initialChunks == 0) initialChunks = 1;
_capacity = initialChunks * CHUNK_SIZE;
_chunks = new SlotEntry[initialChunks][];
for (int i = 0; i < initialChunks; i++)
{
_chunks[i] = new SlotEntry[CHUNK_SIZE];
}
_freeSlots = new(); _freeSlots = new();
} }
[MethodImpl(MethodImplOptions.NoInlining)] [MethodImpl(MethodImplOptions.NoInlining)]
private void TryResize(int requiredCapacity) private void EnsureChunkExists(int requiredChunkIndex)
{ {
if (requiredChunkIndex < _chunks.Length) return;
// Use CAS to ensure only one thread does the resize // Use CAS to ensure only one thread does the resize
if (Interlocked.CompareExchange(ref _isResizing, 1, 0) != 0) if (Interlocked.CompareExchange(ref _isResizing, 1, 0) != 0)
{ {
@@ -97,33 +126,30 @@ public class ConcurrentSlotMap<T> : IEnumerable<T>
try try
{ {
var currentCapacity = Volatile.Read(ref _capacity); var oldChunks = _chunks;
if (currentCapacity >= requiredCapacity) if (requiredChunkIndex < oldChunks.Length)
{ {
return; // Another thread already resized return; // Another thread already resized
} }
var newCapacity = currentCapacity; int newChunkCount = oldChunks.Length;
while (newCapacity < requiredCapacity) while (newChunkCount <= requiredChunkIndex)
{ {
newCapacity *= 2; newChunkCount *= 2;
} }
var newData = new SlotEntry[newCapacity]; var newChunks = new SlotEntry[newChunkCount][];
var oldData = _data; Array.Copy(oldChunks, newChunks, oldChunks.Length);
// Copy existing data // Initialize new chunks
Array.Copy(oldData, newData, currentCapacity); for (var i = oldChunks.Length; i < newChunkCount; i++)
// Initialize new slots
for (var i = currentCapacity; i < newCapacity; i++)
{ {
newData[i] = new SlotEntry(); newChunks[i] = new SlotEntry[CHUNK_SIZE];
} }
// Atomically update the array reference and capacity // Atomically update the array reference and capacity
_data = newData; _chunks = newChunks;
Volatile.Write(ref _capacity, newCapacity); Volatile.Write(ref _capacity, newChunkCount * CHUNK_SIZE);
} }
finally finally
{ {
@@ -134,49 +160,65 @@ public class ConcurrentSlotMap<T> : IEnumerable<T>
public int Add(T item, out int generation) public int Add(T item, out int generation)
{ {
// Try to get a free slot first while (true)
if (_freeSlots.TryDequeue(out var slotIndex))
{ {
ref var slot = ref _data[slotIndex]; // Try to get a free slot first
if (_freeSlots.TryDequeue(out var slotIndex))
// Atomically mark as valid and get the current generation
var currentGeneration = Volatile.Read(ref slot.generation);
slot.value = item;
// Use CAS to mark as valid atomically
if (Interlocked.CompareExchange(ref slot.isValid, 1, 0) == 0)
{ {
generation = currentGeneration; var chunks = _chunks;
Interlocked.Increment(ref _count); int chunkIdx = slotIndex >> CHUNK_SHIFT;
return slotIndex; int localIdx = slotIndex & CHUNK_MASK;
if (chunkIdx < chunks.Length)
{
ref var slot = ref chunks[chunkIdx][localIdx];
// Atomically mark as valid and get the current generation
var currentGeneration = Volatile.Read(ref slot.generation);
slot.value = item;
// Use CAS to mark as valid atomically
if (Interlocked.CompareExchange(ref slot.isValid, 1, 0) == 0)
{
generation = currentGeneration;
Interlocked.Increment(ref _count);
return slotIndex;
}
else
{
// Slot was somehow already valid, don't put it back in free pool
// Just loop and try again
continue;
}
}
else
{
continue;
}
} }
else
// Need a new slot
int newSlotIndex = Interlocked.Increment(ref _nextSlotIndex) - 1;
int newChunkIdx = newSlotIndex >> CHUNK_SHIFT;
int newLocalIdx = newSlotIndex & CHUNK_MASK;
var currentChunks = _chunks;
if (newChunkIdx >= currentChunks.Length)
{ {
// Slot was somehow already valid, put it back and try again EnsureChunkExists(newChunkIdx);
_freeSlots.Enqueue(slotIndex); currentChunks = _chunks; // Re-read after resize
return Add(item, out generation);
} }
// Initialize the new slot
ref var newSlot = ref currentChunks[newChunkIdx][newLocalIdx];
newSlot.value = item;
newSlot.generation = 0;
Volatile.Write(ref newSlot.isValid, 1);
generation = 0;
Interlocked.Increment(ref _count);
return newSlotIndex;
} }
// Need a new slot
slotIndex = Interlocked.Increment(ref _nextSlotIndex) - 1;
// Check if we need to resize
var currentCapacity = Volatile.Read(ref _capacity);
if (slotIndex >= currentCapacity)
{
TryResize(slotIndex + 1);
}
// Initialize the new slot
ref var newSlot = ref _data[slotIndex];
newSlot.value = item;
newSlot.generation = 0;
Volatile.Write(ref newSlot.isValid, 1);
generation = 0;
Interlocked.Increment(ref _count);
return slotIndex;
} }
public bool Remove(int slotIndex, int generation) public bool Remove(int slotIndex, int generation)
@@ -186,15 +228,23 @@ public class ConcurrentSlotMap<T> : IEnumerable<T>
public bool Remove(int slotIndex, int generation, [MaybeNullWhen(false)] out T value) public bool Remove(int slotIndex, int generation, [MaybeNullWhen(false)] out T value)
{ {
var capacity = Volatile.Read(ref _capacity); if (slotIndex < 0)
if (slotIndex < 0 || slotIndex >= capacity)
{ {
value = default; value = default;
return false; return false;
} }
ref var slot = ref _data[slotIndex]; var chunks = _chunks;
int chunkIdx = slotIndex >> CHUNK_SHIFT;
int localIdx = slotIndex & CHUNK_MASK;
if (chunkIdx >= chunks.Length)
{
value = default;
return false;
}
ref var slot = ref chunks[chunkIdx][localIdx];
// Check if slot is valid and generation matches // Check if slot is valid and generation matches
if (Volatile.Read(ref slot.isValid) == 0 || Volatile.Read(ref slot.generation) != generation) if (Volatile.Read(ref slot.isValid) == 0 || Volatile.Read(ref slot.generation) != generation)
@@ -221,12 +271,21 @@ public class ConcurrentSlotMap<T> : IEnumerable<T>
public bool Contains(int slotIndex, int generation) public bool Contains(int slotIndex, int generation)
{ {
if (slotIndex < 0 || slotIndex >= Volatile.Read(ref _capacity)) if (slotIndex < 0)
{ {
return false; return false;
} }
ref var slot = ref _data[slotIndex]; var chunks = _chunks;
int chunkIdx = slotIndex >> CHUNK_SHIFT;
int localIdx = slotIndex & CHUNK_MASK;
if (chunkIdx >= chunks.Length)
{
return false;
}
ref var slot = ref chunks[chunkIdx][localIdx];
var currentGeneration = Volatile.Read(ref slot.generation); var currentGeneration = Volatile.Read(ref slot.generation);
var isValid = Volatile.Read(ref slot.isValid) == 1; var isValid = Volatile.Read(ref slot.isValid) == 1;
@@ -250,7 +309,11 @@ public class ConcurrentSlotMap<T> : IEnumerable<T>
return false; return false;
} }
value = _data[slotIndex].value!; var chunks = _chunks;
int chunkIdx = slotIndex >> CHUNK_SHIFT;
int localIdx = slotIndex & CHUNK_MASK;
value = chunks[chunkIdx][localIdx].value!;
return true; return true;
} }
@@ -272,8 +335,12 @@ public class ConcurrentSlotMap<T> : IEnumerable<T>
return ref Unsafe.NullRef<T>(); return ref Unsafe.NullRef<T>();
} }
var chunks = _chunks;
int chunkIdx = slotIndex >> CHUNK_SHIFT;
int localIdx = slotIndex & CHUNK_MASK;
exist = true; exist = true;
return ref _data[slotIndex].value!; return ref chunks[chunkIdx][localIdx].value!;
} }
public bool UpdateElement(int slotIndex, int generation, T newValue) public bool UpdateElement(int slotIndex, int generation, T newValue)
@@ -283,7 +350,11 @@ public class ConcurrentSlotMap<T> : IEnumerable<T>
return false; return false;
} }
_data[slotIndex].value = newValue; var chunks = _chunks;
int chunkIdx = slotIndex >> CHUNK_SHIFT;
int localIdx = slotIndex & CHUNK_MASK;
chunks[chunkIdx][localIdx].value = newValue;
return true; return true;
} }
@@ -294,13 +365,17 @@ public class ConcurrentSlotMap<T> : IEnumerable<T>
Volatile.Write(ref _nextSlotIndex, 0); Volatile.Write(ref _nextSlotIndex, 0);
// Clear all slots // Clear all slots
var capacity = Volatile.Read(ref _capacity); var chunks = _chunks;
for (var i = 0; i < capacity; i++) for (var c = 0; c < chunks.Length; c++)
{ {
ref var slot = ref _data[i]; var chunk = chunks[c];
Volatile.Write(ref slot.isValid, 0); for (var i = 0; i < CHUNK_SIZE; i++)
slot.generation = 0; {
slot.value = default!; ref var slot = ref chunk[i];
Volatile.Write(ref slot.isValid, 0);
slot.generation = 0;
slot.value = default!;
}
} }
_freeSlots.Clear(); _freeSlots.Clear();

View File

@@ -7,7 +7,7 @@
<AllowUnsafeBlocks>True</AllowUnsafeBlocks> <AllowUnsafeBlocks>True</AllowUnsafeBlocks>
<GeneratePackageOnBuild>True</GeneratePackageOnBuild> <GeneratePackageOnBuild>True</GeneratePackageOnBuild>
<Authors>Misaki</Authors> <Authors>Misaki</Authors>
<AssemblyVersion>1.0.6</AssemblyVersion> <AssemblyVersion>1.0.7</AssemblyVersion>
<Version>$(AssemblyVersion)</Version> <Version>$(AssemblyVersion)</Version>
<PackageProjectUrl>https://git.personalnas.com/Misaki/Misaki.HighPerformance.git</PackageProjectUrl> <PackageProjectUrl>https://git.personalnas.com/Misaki/Misaki.HighPerformance.git</PackageProjectUrl>
<RepositoryUrl>https://git.personalnas.com/Misaki/Misaki.HighPerformance.git</RepositoryUrl> <RepositoryUrl>https://git.personalnas.com/Misaki/Misaki.HighPerformance.git</RepositoryUrl>