diff --git a/ICD.Common.Utils/ICD.Common.Utils_SimplSharp.csproj b/ICD.Common.Utils/ICD.Common.Utils_SimplSharp.csproj index ef2a1ce..562e95b 100644 --- a/ICD.Common.Utils/ICD.Common.Utils_SimplSharp.csproj +++ b/ICD.Common.Utils/ICD.Common.Utils_SimplSharp.csproj @@ -123,6 +123,7 @@ + diff --git a/ICD.Common.Utils/Threading/KeyedLock.cs b/ICD.Common.Utils/Threading/KeyedLock.cs new file mode 100644 index 0000000..1f4b030 --- /dev/null +++ b/ICD.Common.Utils/Threading/KeyedLock.cs @@ -0,0 +1,139 @@ +#if !SIMPLSHARP +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using ICD.Common.Properties; + +namespace ICD.Common.Utils.Threading +{ + public sealed class KeyedLock + { + private readonly Dictionary m_PerKey; + private readonly Queue m_Pool; + private readonly int m_PoolCapacity; + + public KeyedLock(IEqualityComparer keyComparer = null, int poolCapacity = 10) + { + m_PerKey = new Dictionary(keyComparer); + m_Pool = new Queue(poolCapacity); + m_PoolCapacity = poolCapacity; + } + + #region Methods + + [PublicAPI] + public async Task WaitAsync(TKey key, int millisecondsTimeout, + CancellationToken cancellationToken = default) + { + var semaphore = GetSemaphore(key); + bool entered = false; + try + { + entered = await semaphore.WaitAsync(millisecondsTimeout, + cancellationToken).ConfigureAwait(false); + } + finally + { + if (!entered) ReleaseSemaphore(key, entered: false); + } + + return entered; + } + + [PublicAPI] + public Task WaitAsync(TKey key, CancellationToken cancellationToken = default) + => WaitAsync(key, Timeout.Infinite, cancellationToken); + + [PublicAPI] + public bool Wait(TKey key, int millisecondsTimeout, + CancellationToken cancellationToken = default) + { + var semaphore = GetSemaphore(key); + bool entered = false; + try + { + entered = semaphore.Wait(millisecondsTimeout, cancellationToken); + } + finally + { + if (!entered) + ReleaseSemaphore(key, entered: false); + } + + return entered; + } + + [PublicAPI] + public void Wait(TKey key, CancellationToken cancellationToken = default) + => Wait(key, Timeout.Infinite, cancellationToken); + + [PublicAPI] + public void Release(TKey key) => ReleaseSemaphore(key, entered: true); + + #endregion + + #region Private Methods + + private SemaphoreSlim GetSemaphore(TKey key) + { + SemaphoreSlim semaphore; + lock (m_PerKey) + { + if (m_PerKey.TryGetValue(key, out var entry)) + { + int counter; + (semaphore, counter) = entry; + m_PerKey[key] = (semaphore, ++counter); + } + else + { + lock (m_Pool) semaphore = m_Pool.Count > 0 ? m_Pool.Dequeue() : null; + if (semaphore == null) semaphore = new SemaphoreSlim(1, 1); + m_PerKey[key] = (semaphore, 1); + } + } + + return semaphore; + } + + private void ReleaseSemaphore(TKey key, bool entered) + { + SemaphoreSlim semaphore; + int counter; + lock (m_PerKey) + { + if (m_PerKey.TryGetValue(key, out var entry)) + { + (semaphore, counter) = entry; + counter--; + if (counter == 0) + m_PerKey.Remove(key); + else + m_PerKey[key] = (semaphore, counter); + } + else + { + throw new InvalidOperationException("Key not found."); + } + } + + if (entered) + semaphore.Release(); + + if (counter == 0) + { + Debug.Assert(semaphore.CurrentCount == 1); + lock (m_Pool) + { + if (m_Pool.Count < m_PoolCapacity) + m_Pool.Enqueue(semaphore); + } + } + } + + #endregion + } +} +#endif