用 ConcurrentDictionary 替换 Dictionary 是否安全,应该进行哪些修改?
Is it safe to replace Dictionary with ConcurrentDictionary and what modifications should be made?
我想知道将 Dictionary
替换为 ConcurrentDictionary
是否安全,我应该对 ex 做哪些修改。 TryAdd、TryGetValue、移除锁等?
protected class SubscriptionManager
{
private readonly DeribitV2Client _client;
private readonly Dictionary<string, SubscriptionEntry> _subscriptionMap;
public SubscriptionManager(DeribitV2Client client)
{
_client = client;
_subscriptionMap = new Dictionary<string, SubscriptionEntry>();
}
public async Task<SubscriptionToken> Subscribe(ISubscriptionChannel channel, Action<Notification> callback)
{
if (callback == null)
{
return SubscriptionToken.Invalid;
}
var channelName = channel.ToChannelName();
TaskCompletionSource<SubscriptionToken> taskSource = null;
SubscriptionEntry entry;
lock (_subscriptionMap)
{
if (!_subscriptionMap.TryGetValue(channelName, out entry))
{
entry = new SubscriptionEntry();
if (!_subscriptionMap.TryAdd(channelName, entry))
{
_client.Logger?.Error("Subscribe: Could not add internal item for channel {Channel}", channelName);
return SubscriptionToken.Invalid;
}
taskSource = new TaskCompletionSource<SubscriptionToken>();
entry.State = SubscriptionState.Subscribing;
entry.SubscribeTask = taskSource.Task;
}
// Entry already exists but is completely unsubscribed
if (entry.State == SubscriptionState.Unsubscribed)
{
taskSource = new TaskCompletionSource<SubscriptionToken>();
entry.State = SubscriptionState.Subscribing;
entry.SubscribeTask = taskSource.Task;
}
// Already subscribed - Put the callback in there and let's go
if (entry.State == SubscriptionState.Subscribed)
{
_client.Logger?.Debug("Subscribe: Subscription for channel already exists. Adding callback to list (Channel: {Channel})", channelName);
var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
entry.Callbacks.Add(callbackEntry);
return callbackEntry.Token;
}
// We are in the middle of unsubscribing from the channel
if (entry.State == SubscriptionState.Unsubscribing)
{
_client.Logger?.Debug("Subscribe: Channel is unsubscribing. Abort subscribe (Channel: {Channel})", channelName);
return SubscriptionToken.Invalid;
}
}
// Only one state left: Subscribing
// We are already subscribing
if (taskSource == null && entry.State == SubscriptionState.Subscribing)
{
_client.Logger?.Debug("Subscribe: Channel is already subscribing. Waiting for the task to complete ({Channel})", channelName);
var subscribeResult = entry.SubscribeTask != null && await entry.SubscribeTask != SubscriptionToken.Invalid;
if (!subscribeResult && entry.State != SubscriptionState.Subscribed)
{
_client.Logger?.Debug("Subscribe: Subscription has failed. Abort subscribe (Channel: {Channel})", channelName);
return SubscriptionToken.Invalid;
}
_client.Logger?.Debug("Subscribe: Subscription was successful. Adding callback (Channel: {Channel}", channelName);
var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
entry.Callbacks.Add(callbackEntry);
return callbackEntry.Token;
}
if (taskSource == null)
{
_client.Logger?.Error("Subscribe: Invalid execution state. Missing TaskCompletionSource (Channel: {Channel}", channelName);
return SubscriptionToken.Invalid;
}
try
{
var subscribeResponse = await _client.Send(
IsPrivateChannel(channelName) ? "private/subscribe" : "public/subscribe",
new { channels = new[] { channelName } },
new ListJsonConverter<string>()).ConfigureAwait(false);
var response = subscribeResponse.ResultData;
if (response.Count != 1 || response[0] != channelName)
{
_client.Logger?.Debug("Subscribe: Invalid result (Channel: {Channel}): {@Response}", channelName, response);
entry.State = SubscriptionState.Unsubscribed;
entry.SubscribeTask = null;
Debug.Assert(taskSource != null, nameof(taskSource) + " != null");
taskSource.SetResult(SubscriptionToken.Invalid);
}
else
{
_client.Logger?.Debug("Subscribe: Successfully subscribed. Adding callback (Channel: {Channel})", channelName);
var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
entry.Callbacks.Add(callbackEntry);
entry.State = SubscriptionState.Subscribed;
entry.SubscribeTask = null;
Debug.Assert(taskSource != null, nameof(taskSource) + " != null");
taskSource.SetResult(callbackEntry.Token);
}
}
catch (Exception e)
{
entry.State = SubscriptionState.Unsubscribed;
entry.SubscribeTask = null;
Debug.Assert(taskSource != null, nameof(taskSource) + " != null");
taskSource.SetException(e);
}
return await taskSource.Task;
}
public async Task<bool> Unsubscribe(SubscriptionToken token)
{
string channelName;
SubscriptionEntry entry;
SubscriptionCallback callbackEntry;
TaskCompletionSource<bool> taskSource;
lock (_subscriptionMap)
{
(channelName, entry, callbackEntry) = GetEntryByToken(token);
if (string.IsNullOrEmpty(channelName) || entry == null || callbackEntry == null)
{
_client.Logger?.Warning("Unsubscribe: Could not find token {token}", token.Token);
return false;
}
switch (entry.State)
{
case SubscriptionState.Subscribing:
_client.Logger?.Debug("Unsubscribe: Channel is currently subscribing. Abort unsubscribe (Channel: {Channel})", channelName);
return false;
case SubscriptionState.Unsubscribed:
case SubscriptionState.Unsubscribing:
_client.Logger?.Debug("Unsubscribe: Channel is unsubscribed or unsubscribing. Remove callback (Channel: {Channel})", channelName);
entry.Callbacks.Remove(callbackEntry);
return true;
case SubscriptionState.Subscribed:
if (entry.Callbacks.Count > 1)
{
_client.Logger?.Debug("Unsubscribe: There are still callbacks left. Remove callback but don't unsubscribe (Channel: {Channel})", channelName);
entry.Callbacks.Remove(callbackEntry);
return true;
}
_client.Logger?.Debug("Unsubscribe: No callbacks left. Unsubscribe and remove callback (Channel: {Channel})", channelName);
break;
default:
return false;
}
// At this point it's only possible that the entry-State is Subscribed
// and the callback list is empty after removing this callback.
// Hence we unsubscribe at the server now
entry.State = SubscriptionState.Unsubscribing;
taskSource = new TaskCompletionSource<bool>();
entry.UnsubscribeTask = taskSource.Task;
}
try
{
var unsubscribeResponse = await _client.Send(
IsPrivateChannel(channelName) ? "private/unsubscribe" : "public/unsubscribe",
new { channels = new[] { channelName } },
new ListJsonConverter<string>()).ConfigureAwait(false);
var response = unsubscribeResponse.ResultData;
if (response.Count != 1 || response[0] != channelName)
{
entry.State = SubscriptionState.Subscribed;
entry.UnsubscribeTask = null;
taskSource.SetResult(false);
}
else
{
entry.Callbacks.Remove(callbackEntry);
entry.State = SubscriptionState.Unsubscribed;
entry.UnsubscribeTask = null;
taskSource.SetResult(true);
}
}
catch (Exception e)
{
entry.State = SubscriptionState.Subscribed;
entry.UnsubscribeTask = null;
taskSource.SetException(e);
}
return await taskSource.Task;
}
public IEnumerable<Action<Notification>> GetCallbacks(string channel)
{
if (_subscriptionMap.TryGetValue(channel, out var entry))
{
foreach (var callbackEntry in entry.Callbacks)
{
yield return callbackEntry.Action;
}
}
}
public void Reset()
{
_subscriptionMap.Clear();
}
private static bool IsPrivateChannel(string channel)
{
return channel.StartsWith("user.");
}
private (string channelName, SubscriptionEntry entry, SubscriptionCallback callbackEntry) GetEntryByToken(SubscriptionToken token)
{
lock (_subscriptionMap)
{
foreach (var kvp in _subscriptionMap)
{
foreach (var callbackEntry in kvp.Value.Callbacks)
{
if (callbackEntry.Token == token)
{
return (kvp.Key, kvp.Value, callbackEntry);
}
}
}
}
return (null, null, null);
}
}
GitHub
我的尝试
public class SubscriptionToken
{
public static readonly SubscriptionToken Invalid = new(Guid.Empty);
public SubscriptionToken(Guid token)
{
Token = token;
}
public Guid Token { get; }
}
public class SubscriptionCallback
{
public SubscriptionCallback(SubscriptionToken token, Action<Notification> action)
{
Token = token;
Action = action;
}
public Action<Notification> Action { get; }
public SubscriptionToken Token { get; }
}
public class SubscriptionEntry
{
public List<SubscriptionCallback> Callbacks { get; } = new();
public Task<SubscriptionToken>? SubscribeTask { get; set; }
public Task<bool>? UnsubscribeTask { get; set; }
public SubscriptionState State { get; set; } = SubscriptionState.Unsubscribed;
}
public class SubscriptionManager
{
private readonly DeribitClient _client;
private readonly ConcurrentDictionary<string, SubscriptionEntry> _subscriptions = new();
public SubscriptionManager(DeribitClient client)
{
_client = client ?? throw new ArgumentNullException(nameof(client));
}
public async Task<SubscriptionToken> SubscribeAsync(string channel, Action<Notification>? callback)
{
if (callback == null)
{
throw new ArgumentNullException(nameof(callback));
}
TaskCompletionSource<SubscriptionToken>? tcs = null;
if (_subscriptions.TryGetValue(channel, out var entry))
{
if (entry.State == SubscriptionState.Subscribed)
{
Log.Debug("Subscribe: Subscription for channel already exists. Adding callback to list (Channel: {Channel})", channel);
var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
entry.Callbacks.Add(callbackEntry);
return callbackEntry.Token;
}
if (entry.State == SubscriptionState.Unsubscribing)
{
Log.Debug("Subscribe: Channel is unsubscribing. Abort subscribe (Channel: {Channel})", channel);
return SubscriptionToken.Invalid;
}
if (entry.State == SubscriptionState.Unsubscribed)
{
Log.Debug("Subscribe: Entry already exists but is completely unsubscribed (Channel: {Channel})", channel);
tcs = new TaskCompletionSource<SubscriptionToken>();
entry.State = SubscriptionState.Subscribing;
entry.SubscribeTask = tcs.Task;
}
}
else
{
tcs = new TaskCompletionSource<SubscriptionToken>();
entry = new SubscriptionEntry
{
State = SubscriptionState.Subscribing,
SubscribeTask = tcs.Task
};
if (!_subscriptions.TryAdd(channel, entry))
{
Log.Error("Subscribe: Could not add internal item for channel {Channel}", channel);
return SubscriptionToken.Invalid;
}
}
if (tcs == null && entry.State == SubscriptionState.Subscribing)
{
Log.Debug("Subscribe: Channel is already subscribing. Waiting for the task to complete ({Channel})", channel);
var subscribeResult = entry.SubscribeTask != null && await entry.SubscribeTask.ConfigureAwait(false) != SubscriptionToken.Invalid;
if (!subscribeResult && entry.State != SubscriptionState.Subscribed)
{
Log.Debug("Subscribe: Subscription has failed. Abort subscribe (Channel: {Channel})", channel);
return SubscriptionToken.Invalid;
}
Log.Debug("Subscribe: Subscription was successful. Adding callback (Channel: {Channel}", channel);
var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
entry.Callbacks.Add(callbackEntry);
return callbackEntry.Token;
}
if (tcs == null)
{
Log.Error("Subscribe: Invalid execution state. Missing TaskCompletionSource (Channel: {Channel}", channel);
return SubscriptionToken.Invalid;
}
try
{
var method = IsPrivateChannel(channel) ? "private/subscribe" : "public/subscribe";
var @params = new Dictionary<string, string[]>
{
{ "channels", new[] { channel } }
};
var subscribeResponse = await _client.SendAsync<Notification>(method, @params).ConfigureAwait(false);
if (subscribeResponse == null)
{
Log.Debug("Subscribe: Invalid result (Channel: {Channel}): {@Response}", channel, subscribeResponse);
entry.State = SubscriptionState.Unsubscribed;
entry.SubscribeTask = null;
Debug.Assert(tcs != null);
tcs.SetResult(SubscriptionToken.Invalid);
}
else
{
Log.Debug("Subscribe: Successfully subscribed. Adding callback (Channel: {Channel})", channel);
var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
entry.Callbacks.Add(callbackEntry);
entry.State = SubscriptionState.Subscribed;
entry.SubscribeTask = null;
Debug.Assert(tcs != null);
tcs.SetResult(callbackEntry.Token);
}
}
catch (Exception ex)
{
entry.State = SubscriptionState.Unsubscribed;
entry.SubscribeTask = null;
Debug.Assert(tcs != null);
tcs.SetException(ex);
}
return await tcs.Task.ConfigureAwait(false);
}
public async Task<bool> UnsubscribeAsync(SubscriptionToken token)
{
TaskCompletionSource<bool> tcs;
var (channel, entry, callbackEntry) = GetEntryByToken(token);
if (string.IsNullOrEmpty(channel) || entry == null || callbackEntry == null)
{
Log.Warning("UnsubscribeAsync: Could not find token {token}", token.Token);
return false;
}
switch (entry.State)
{
case SubscriptionState.Subscribing:
Log.Debug("UnsubscribeAsync: Channel is currently subscribing. Abort unsubscribe (Channel: {Channel})", channel);
return false;
case SubscriptionState.Unsubscribed:
case SubscriptionState.Unsubscribing:
Log.Debug("UnsubscribeAsync: Channel is unsubscribed or unsubscribing. Remove callback (Channel: {Channel})", channel);
entry.Callbacks.Remove(callbackEntry);
return true;
case SubscriptionState.Subscribed when entry.Callbacks.Count > 1:
Log.Debug("UnsubscribeAsync: There are still callbacks left. Remove callback but don't unsubscribe (Channel: {Channel})", channel);
entry.Callbacks.Remove(callbackEntry);
return true;
case SubscriptionState.Subscribed:
Log.Debug("UnsubscribeAsync: No callbacks left. UnsubscribeAsync and remove callback (Channel: {Channel})", channel);
tcs = new TaskCompletionSource<bool>();
entry.State = SubscriptionState.Unsubscribing;
entry.UnsubscribeTask = tcs.Task;
break;
default:
return false;
}
try
{
var method = IsPrivateChannel(channel) ? "private/unsubscribe" : "public/unsubscribe";
var @params = new Dictionary<string, string[]>
{
{ "channels", new[] { channel } }
};
var unsubscribeResponse = await _client.SendAsync<Notification>(method, @params).ConfigureAwait(false);
if (unsubscribeResponse == null)
{
entry.State = SubscriptionState.Subscribed;
entry.UnsubscribeTask = null;
tcs.SetResult(false);
}
else
{
entry.Callbacks.Remove(callbackEntry);
entry.State = SubscriptionState.Unsubscribed;
entry.UnsubscribeTask = null;
tcs.SetResult(true);
}
}
catch (Exception ex)
{
entry.State = SubscriptionState.Subscribed;
entry.UnsubscribeTask = null;
tcs.SetException(ex);
}
return await tcs.Task.ConfigureAwait(false);
}
private (string? channelName, SubscriptionEntry? entry, SubscriptionCallback? callbackEntry) GetEntryByToken(SubscriptionToken token)
{
foreach (var (key, value) in _subscriptions)
{
foreach (var callbackEntry in value.Callbacks.Where(callbackEntry => callbackEntry.Token == token))
{
return (key, value, callbackEntry);
}
}
return (null, null, null);
}
public IEnumerable<Action<Notification>> GetCallbacks(string channel)
{
if (_subscriptions.TryGetValue(channel, out var entry))
{
foreach (var callbackEntry in entry.Callbacks)
{
yield return callbackEntry.Action;
}
}
}
private static bool IsPrivateChannel(string channel)
{
return channel.StartsWith("user.");
}
}
A ConcurrentDictionary<K,V>
是 thread-safe,因为它保护其内部状态免受损坏。它不保护它包含的键和值,以防它们是可变对象。
在您的例子中,存储在字典 (SubscriptionEntry
) 中的值是可变对象。他们有 public setter,并且公开了类型 List<SubscriptionCallback>
的 public 属性。 List<T>
classis not thread-safe。所以,不,你不能用问题中显示的方式 ConcurrentDictionary
替换 Dictionary
(My attempt 部分)。以下是一些选项:
- 确保
SubscriptionCallback
类型是不可变的。如果要更改它,请创建一个新的 SubscriptionCallback
实例并丢弃之前的实例。
- 保持
SubscriptionCallback
可变,但使其成为 thread-safe。
- 只保留
Dictionary
,忘记切换到 ConcurrentDictionary
。 lock
的开销是微不足道的,前提是您在持有锁时没有做任何不重要的事情。如果您只执行基本操作 (Add
/TryGetValue
/Remove
),您不太可能注意到任何可测量的争用,除非您每秒执行 100,000 次或更多操作。
我想知道将 Dictionary
替换为 ConcurrentDictionary
是否安全,我应该对 ex 做哪些修改。 TryAdd、TryGetValue、移除锁等?
protected class SubscriptionManager
{
private readonly DeribitV2Client _client;
private readonly Dictionary<string, SubscriptionEntry> _subscriptionMap;
public SubscriptionManager(DeribitV2Client client)
{
_client = client;
_subscriptionMap = new Dictionary<string, SubscriptionEntry>();
}
public async Task<SubscriptionToken> Subscribe(ISubscriptionChannel channel, Action<Notification> callback)
{
if (callback == null)
{
return SubscriptionToken.Invalid;
}
var channelName = channel.ToChannelName();
TaskCompletionSource<SubscriptionToken> taskSource = null;
SubscriptionEntry entry;
lock (_subscriptionMap)
{
if (!_subscriptionMap.TryGetValue(channelName, out entry))
{
entry = new SubscriptionEntry();
if (!_subscriptionMap.TryAdd(channelName, entry))
{
_client.Logger?.Error("Subscribe: Could not add internal item for channel {Channel}", channelName);
return SubscriptionToken.Invalid;
}
taskSource = new TaskCompletionSource<SubscriptionToken>();
entry.State = SubscriptionState.Subscribing;
entry.SubscribeTask = taskSource.Task;
}
// Entry already exists but is completely unsubscribed
if (entry.State == SubscriptionState.Unsubscribed)
{
taskSource = new TaskCompletionSource<SubscriptionToken>();
entry.State = SubscriptionState.Subscribing;
entry.SubscribeTask = taskSource.Task;
}
// Already subscribed - Put the callback in there and let's go
if (entry.State == SubscriptionState.Subscribed)
{
_client.Logger?.Debug("Subscribe: Subscription for channel already exists. Adding callback to list (Channel: {Channel})", channelName);
var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
entry.Callbacks.Add(callbackEntry);
return callbackEntry.Token;
}
// We are in the middle of unsubscribing from the channel
if (entry.State == SubscriptionState.Unsubscribing)
{
_client.Logger?.Debug("Subscribe: Channel is unsubscribing. Abort subscribe (Channel: {Channel})", channelName);
return SubscriptionToken.Invalid;
}
}
// Only one state left: Subscribing
// We are already subscribing
if (taskSource == null && entry.State == SubscriptionState.Subscribing)
{
_client.Logger?.Debug("Subscribe: Channel is already subscribing. Waiting for the task to complete ({Channel})", channelName);
var subscribeResult = entry.SubscribeTask != null && await entry.SubscribeTask != SubscriptionToken.Invalid;
if (!subscribeResult && entry.State != SubscriptionState.Subscribed)
{
_client.Logger?.Debug("Subscribe: Subscription has failed. Abort subscribe (Channel: {Channel})", channelName);
return SubscriptionToken.Invalid;
}
_client.Logger?.Debug("Subscribe: Subscription was successful. Adding callback (Channel: {Channel}", channelName);
var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
entry.Callbacks.Add(callbackEntry);
return callbackEntry.Token;
}
if (taskSource == null)
{
_client.Logger?.Error("Subscribe: Invalid execution state. Missing TaskCompletionSource (Channel: {Channel}", channelName);
return SubscriptionToken.Invalid;
}
try
{
var subscribeResponse = await _client.Send(
IsPrivateChannel(channelName) ? "private/subscribe" : "public/subscribe",
new { channels = new[] { channelName } },
new ListJsonConverter<string>()).ConfigureAwait(false);
var response = subscribeResponse.ResultData;
if (response.Count != 1 || response[0] != channelName)
{
_client.Logger?.Debug("Subscribe: Invalid result (Channel: {Channel}): {@Response}", channelName, response);
entry.State = SubscriptionState.Unsubscribed;
entry.SubscribeTask = null;
Debug.Assert(taskSource != null, nameof(taskSource) + " != null");
taskSource.SetResult(SubscriptionToken.Invalid);
}
else
{
_client.Logger?.Debug("Subscribe: Successfully subscribed. Adding callback (Channel: {Channel})", channelName);
var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
entry.Callbacks.Add(callbackEntry);
entry.State = SubscriptionState.Subscribed;
entry.SubscribeTask = null;
Debug.Assert(taskSource != null, nameof(taskSource) + " != null");
taskSource.SetResult(callbackEntry.Token);
}
}
catch (Exception e)
{
entry.State = SubscriptionState.Unsubscribed;
entry.SubscribeTask = null;
Debug.Assert(taskSource != null, nameof(taskSource) + " != null");
taskSource.SetException(e);
}
return await taskSource.Task;
}
public async Task<bool> Unsubscribe(SubscriptionToken token)
{
string channelName;
SubscriptionEntry entry;
SubscriptionCallback callbackEntry;
TaskCompletionSource<bool> taskSource;
lock (_subscriptionMap)
{
(channelName, entry, callbackEntry) = GetEntryByToken(token);
if (string.IsNullOrEmpty(channelName) || entry == null || callbackEntry == null)
{
_client.Logger?.Warning("Unsubscribe: Could not find token {token}", token.Token);
return false;
}
switch (entry.State)
{
case SubscriptionState.Subscribing:
_client.Logger?.Debug("Unsubscribe: Channel is currently subscribing. Abort unsubscribe (Channel: {Channel})", channelName);
return false;
case SubscriptionState.Unsubscribed:
case SubscriptionState.Unsubscribing:
_client.Logger?.Debug("Unsubscribe: Channel is unsubscribed or unsubscribing. Remove callback (Channel: {Channel})", channelName);
entry.Callbacks.Remove(callbackEntry);
return true;
case SubscriptionState.Subscribed:
if (entry.Callbacks.Count > 1)
{
_client.Logger?.Debug("Unsubscribe: There are still callbacks left. Remove callback but don't unsubscribe (Channel: {Channel})", channelName);
entry.Callbacks.Remove(callbackEntry);
return true;
}
_client.Logger?.Debug("Unsubscribe: No callbacks left. Unsubscribe and remove callback (Channel: {Channel})", channelName);
break;
default:
return false;
}
// At this point it's only possible that the entry-State is Subscribed
// and the callback list is empty after removing this callback.
// Hence we unsubscribe at the server now
entry.State = SubscriptionState.Unsubscribing;
taskSource = new TaskCompletionSource<bool>();
entry.UnsubscribeTask = taskSource.Task;
}
try
{
var unsubscribeResponse = await _client.Send(
IsPrivateChannel(channelName) ? "private/unsubscribe" : "public/unsubscribe",
new { channels = new[] { channelName } },
new ListJsonConverter<string>()).ConfigureAwait(false);
var response = unsubscribeResponse.ResultData;
if (response.Count != 1 || response[0] != channelName)
{
entry.State = SubscriptionState.Subscribed;
entry.UnsubscribeTask = null;
taskSource.SetResult(false);
}
else
{
entry.Callbacks.Remove(callbackEntry);
entry.State = SubscriptionState.Unsubscribed;
entry.UnsubscribeTask = null;
taskSource.SetResult(true);
}
}
catch (Exception e)
{
entry.State = SubscriptionState.Subscribed;
entry.UnsubscribeTask = null;
taskSource.SetException(e);
}
return await taskSource.Task;
}
public IEnumerable<Action<Notification>> GetCallbacks(string channel)
{
if (_subscriptionMap.TryGetValue(channel, out var entry))
{
foreach (var callbackEntry in entry.Callbacks)
{
yield return callbackEntry.Action;
}
}
}
public void Reset()
{
_subscriptionMap.Clear();
}
private static bool IsPrivateChannel(string channel)
{
return channel.StartsWith("user.");
}
private (string channelName, SubscriptionEntry entry, SubscriptionCallback callbackEntry) GetEntryByToken(SubscriptionToken token)
{
lock (_subscriptionMap)
{
foreach (var kvp in _subscriptionMap)
{
foreach (var callbackEntry in kvp.Value.Callbacks)
{
if (callbackEntry.Token == token)
{
return (kvp.Key, kvp.Value, callbackEntry);
}
}
}
}
return (null, null, null);
}
}
GitHub
我的尝试
public class SubscriptionToken
{
public static readonly SubscriptionToken Invalid = new(Guid.Empty);
public SubscriptionToken(Guid token)
{
Token = token;
}
public Guid Token { get; }
}
public class SubscriptionCallback
{
public SubscriptionCallback(SubscriptionToken token, Action<Notification> action)
{
Token = token;
Action = action;
}
public Action<Notification> Action { get; }
public SubscriptionToken Token { get; }
}
public class SubscriptionEntry
{
public List<SubscriptionCallback> Callbacks { get; } = new();
public Task<SubscriptionToken>? SubscribeTask { get; set; }
public Task<bool>? UnsubscribeTask { get; set; }
public SubscriptionState State { get; set; } = SubscriptionState.Unsubscribed;
}
public class SubscriptionManager
{
private readonly DeribitClient _client;
private readonly ConcurrentDictionary<string, SubscriptionEntry> _subscriptions = new();
public SubscriptionManager(DeribitClient client)
{
_client = client ?? throw new ArgumentNullException(nameof(client));
}
public async Task<SubscriptionToken> SubscribeAsync(string channel, Action<Notification>? callback)
{
if (callback == null)
{
throw new ArgumentNullException(nameof(callback));
}
TaskCompletionSource<SubscriptionToken>? tcs = null;
if (_subscriptions.TryGetValue(channel, out var entry))
{
if (entry.State == SubscriptionState.Subscribed)
{
Log.Debug("Subscribe: Subscription for channel already exists. Adding callback to list (Channel: {Channel})", channel);
var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
entry.Callbacks.Add(callbackEntry);
return callbackEntry.Token;
}
if (entry.State == SubscriptionState.Unsubscribing)
{
Log.Debug("Subscribe: Channel is unsubscribing. Abort subscribe (Channel: {Channel})", channel);
return SubscriptionToken.Invalid;
}
if (entry.State == SubscriptionState.Unsubscribed)
{
Log.Debug("Subscribe: Entry already exists but is completely unsubscribed (Channel: {Channel})", channel);
tcs = new TaskCompletionSource<SubscriptionToken>();
entry.State = SubscriptionState.Subscribing;
entry.SubscribeTask = tcs.Task;
}
}
else
{
tcs = new TaskCompletionSource<SubscriptionToken>();
entry = new SubscriptionEntry
{
State = SubscriptionState.Subscribing,
SubscribeTask = tcs.Task
};
if (!_subscriptions.TryAdd(channel, entry))
{
Log.Error("Subscribe: Could not add internal item for channel {Channel}", channel);
return SubscriptionToken.Invalid;
}
}
if (tcs == null && entry.State == SubscriptionState.Subscribing)
{
Log.Debug("Subscribe: Channel is already subscribing. Waiting for the task to complete ({Channel})", channel);
var subscribeResult = entry.SubscribeTask != null && await entry.SubscribeTask.ConfigureAwait(false) != SubscriptionToken.Invalid;
if (!subscribeResult && entry.State != SubscriptionState.Subscribed)
{
Log.Debug("Subscribe: Subscription has failed. Abort subscribe (Channel: {Channel})", channel);
return SubscriptionToken.Invalid;
}
Log.Debug("Subscribe: Subscription was successful. Adding callback (Channel: {Channel}", channel);
var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
entry.Callbacks.Add(callbackEntry);
return callbackEntry.Token;
}
if (tcs == null)
{
Log.Error("Subscribe: Invalid execution state. Missing TaskCompletionSource (Channel: {Channel}", channel);
return SubscriptionToken.Invalid;
}
try
{
var method = IsPrivateChannel(channel) ? "private/subscribe" : "public/subscribe";
var @params = new Dictionary<string, string[]>
{
{ "channels", new[] { channel } }
};
var subscribeResponse = await _client.SendAsync<Notification>(method, @params).ConfigureAwait(false);
if (subscribeResponse == null)
{
Log.Debug("Subscribe: Invalid result (Channel: {Channel}): {@Response}", channel, subscribeResponse);
entry.State = SubscriptionState.Unsubscribed;
entry.SubscribeTask = null;
Debug.Assert(tcs != null);
tcs.SetResult(SubscriptionToken.Invalid);
}
else
{
Log.Debug("Subscribe: Successfully subscribed. Adding callback (Channel: {Channel})", channel);
var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
entry.Callbacks.Add(callbackEntry);
entry.State = SubscriptionState.Subscribed;
entry.SubscribeTask = null;
Debug.Assert(tcs != null);
tcs.SetResult(callbackEntry.Token);
}
}
catch (Exception ex)
{
entry.State = SubscriptionState.Unsubscribed;
entry.SubscribeTask = null;
Debug.Assert(tcs != null);
tcs.SetException(ex);
}
return await tcs.Task.ConfigureAwait(false);
}
public async Task<bool> UnsubscribeAsync(SubscriptionToken token)
{
TaskCompletionSource<bool> tcs;
var (channel, entry, callbackEntry) = GetEntryByToken(token);
if (string.IsNullOrEmpty(channel) || entry == null || callbackEntry == null)
{
Log.Warning("UnsubscribeAsync: Could not find token {token}", token.Token);
return false;
}
switch (entry.State)
{
case SubscriptionState.Subscribing:
Log.Debug("UnsubscribeAsync: Channel is currently subscribing. Abort unsubscribe (Channel: {Channel})", channel);
return false;
case SubscriptionState.Unsubscribed:
case SubscriptionState.Unsubscribing:
Log.Debug("UnsubscribeAsync: Channel is unsubscribed or unsubscribing. Remove callback (Channel: {Channel})", channel);
entry.Callbacks.Remove(callbackEntry);
return true;
case SubscriptionState.Subscribed when entry.Callbacks.Count > 1:
Log.Debug("UnsubscribeAsync: There are still callbacks left. Remove callback but don't unsubscribe (Channel: {Channel})", channel);
entry.Callbacks.Remove(callbackEntry);
return true;
case SubscriptionState.Subscribed:
Log.Debug("UnsubscribeAsync: No callbacks left. UnsubscribeAsync and remove callback (Channel: {Channel})", channel);
tcs = new TaskCompletionSource<bool>();
entry.State = SubscriptionState.Unsubscribing;
entry.UnsubscribeTask = tcs.Task;
break;
default:
return false;
}
try
{
var method = IsPrivateChannel(channel) ? "private/unsubscribe" : "public/unsubscribe";
var @params = new Dictionary<string, string[]>
{
{ "channels", new[] { channel } }
};
var unsubscribeResponse = await _client.SendAsync<Notification>(method, @params).ConfigureAwait(false);
if (unsubscribeResponse == null)
{
entry.State = SubscriptionState.Subscribed;
entry.UnsubscribeTask = null;
tcs.SetResult(false);
}
else
{
entry.Callbacks.Remove(callbackEntry);
entry.State = SubscriptionState.Unsubscribed;
entry.UnsubscribeTask = null;
tcs.SetResult(true);
}
}
catch (Exception ex)
{
entry.State = SubscriptionState.Subscribed;
entry.UnsubscribeTask = null;
tcs.SetException(ex);
}
return await tcs.Task.ConfigureAwait(false);
}
private (string? channelName, SubscriptionEntry? entry, SubscriptionCallback? callbackEntry) GetEntryByToken(SubscriptionToken token)
{
foreach (var (key, value) in _subscriptions)
{
foreach (var callbackEntry in value.Callbacks.Where(callbackEntry => callbackEntry.Token == token))
{
return (key, value, callbackEntry);
}
}
return (null, null, null);
}
public IEnumerable<Action<Notification>> GetCallbacks(string channel)
{
if (_subscriptions.TryGetValue(channel, out var entry))
{
foreach (var callbackEntry in entry.Callbacks)
{
yield return callbackEntry.Action;
}
}
}
private static bool IsPrivateChannel(string channel)
{
return channel.StartsWith("user.");
}
}
A ConcurrentDictionary<K,V>
是 thread-safe,因为它保护其内部状态免受损坏。它不保护它包含的键和值,以防它们是可变对象。
在您的例子中,存储在字典 (SubscriptionEntry
) 中的值是可变对象。他们有 public setter,并且公开了类型 List<SubscriptionCallback>
的 public 属性。 List<T>
classis not thread-safe。所以,不,你不能用问题中显示的方式 ConcurrentDictionary
替换 Dictionary
(My attempt 部分)。以下是一些选项:
- 确保
SubscriptionCallback
类型是不可变的。如果要更改它,请创建一个新的SubscriptionCallback
实例并丢弃之前的实例。 - 保持
SubscriptionCallback
可变,但使其成为 thread-safe。 - 只保留
Dictionary
,忘记切换到ConcurrentDictionary
。lock
的开销是微不足道的,前提是您在持有锁时没有做任何不重要的事情。如果您只执行基本操作 (Add
/TryGetValue
/Remove
),您不太可能注意到任何可测量的争用,除非您每秒执行 100,000 次或更多操作。