用 ConcurrentDictionary 替换 Dictionary 是否安全,应该进行哪些修改?

Is it safe to replace Dictionary with ConcurrentDictionary and what modifications should be made?

我想知道将 Dictionary 替换为 ConcurrentDictionary 是否安全,我应该对 ex 做哪些修改。 TryAdd、TryGetValue、移除锁等?

protected class SubscriptionManager
{
    private readonly DeribitV2Client _client;
    private readonly Dictionary<string, SubscriptionEntry> _subscriptionMap;

    public SubscriptionManager(DeribitV2Client client)
    {
        _client = client;
        _subscriptionMap = new Dictionary<string, SubscriptionEntry>();
    }

    public async Task<SubscriptionToken> Subscribe(ISubscriptionChannel channel, Action<Notification> callback)
    {
        if (callback == null)
        {
            return SubscriptionToken.Invalid;
        }

        var channelName = channel.ToChannelName();
        TaskCompletionSource<SubscriptionToken> taskSource = null;
        SubscriptionEntry entry;

        lock (_subscriptionMap)
        {
            if (!_subscriptionMap.TryGetValue(channelName, out entry))
            {
                entry = new SubscriptionEntry();
                if (!_subscriptionMap.TryAdd(channelName, entry))
                {
                    _client.Logger?.Error("Subscribe: Could not add internal item for channel {Channel}", channelName);
                    return SubscriptionToken.Invalid;
                }

                taskSource = new TaskCompletionSource<SubscriptionToken>();
                entry.State = SubscriptionState.Subscribing;
                entry.SubscribeTask = taskSource.Task;
            }

            // Entry already exists but is completely unsubscribed
            if (entry.State == SubscriptionState.Unsubscribed)
            {
                taskSource = new TaskCompletionSource<SubscriptionToken>();
                entry.State = SubscriptionState.Subscribing;
                entry.SubscribeTask = taskSource.Task;
            }

            // Already subscribed - Put the callback in there and let's go
            if (entry.State == SubscriptionState.Subscribed)
            {
                _client.Logger?.Debug("Subscribe: Subscription for channel already exists. Adding callback to list (Channel: {Channel})", channelName);
                var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
                entry.Callbacks.Add(callbackEntry);
                return callbackEntry.Token;
            }

            // We are in the middle of unsubscribing from the channel
            if (entry.State == SubscriptionState.Unsubscribing)
            {
                _client.Logger?.Debug("Subscribe: Channel is unsubscribing. Abort subscribe (Channel: {Channel})", channelName);
                return SubscriptionToken.Invalid;
            }
        }

        // Only one state left: Subscribing

        // We are already subscribing
        if (taskSource == null && entry.State == SubscriptionState.Subscribing)
        {
            _client.Logger?.Debug("Subscribe: Channel is already subscribing. Waiting for the task to complete ({Channel})", channelName);

            var subscribeResult = entry.SubscribeTask != null && await entry.SubscribeTask != SubscriptionToken.Invalid;

            if (!subscribeResult && entry.State != SubscriptionState.Subscribed)
            {
                _client.Logger?.Debug("Subscribe: Subscription has failed. Abort subscribe (Channel: {Channel})", channelName);
                return SubscriptionToken.Invalid;
            }

            _client.Logger?.Debug("Subscribe: Subscription was successful. Adding callback (Channel: {Channel}", channelName);
            var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
            entry.Callbacks.Add(callbackEntry);
            return callbackEntry.Token;
        }

        if (taskSource == null)
        {
            _client.Logger?.Error("Subscribe: Invalid execution state. Missing TaskCompletionSource (Channel: {Channel}", channelName);
            return SubscriptionToken.Invalid;
        }

        try
        {
            var subscribeResponse = await _client.Send(
            IsPrivateChannel(channelName) ? "private/subscribe" : "public/subscribe",
            new { channels = new[] { channelName } },
            new ListJsonConverter<string>()).ConfigureAwait(false);

            var response = subscribeResponse.ResultData;

            if (response.Count != 1 || response[0] != channelName)
            {
                _client.Logger?.Debug("Subscribe: Invalid result (Channel: {Channel}): {@Response}", channelName, response);
                entry.State = SubscriptionState.Unsubscribed;
                entry.SubscribeTask = null;
                Debug.Assert(taskSource != null, nameof(taskSource) + " != null");
                taskSource.SetResult(SubscriptionToken.Invalid);
            }
            else
            {
                _client.Logger?.Debug("Subscribe: Successfully subscribed. Adding callback (Channel: {Channel})", channelName);

                var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
                entry.Callbacks.Add(callbackEntry);
                entry.State = SubscriptionState.Subscribed;
                entry.SubscribeTask = null;
                Debug.Assert(taskSource != null, nameof(taskSource) + " != null");
                taskSource.SetResult(callbackEntry.Token);
            }
        }
        catch (Exception e)
        {
            entry.State = SubscriptionState.Unsubscribed;
            entry.SubscribeTask = null;
            Debug.Assert(taskSource != null, nameof(taskSource) + " != null");
            taskSource.SetException(e);
        }

        return await taskSource.Task;
    }

    public async Task<bool> Unsubscribe(SubscriptionToken token)
    {
        string channelName;
        SubscriptionEntry entry;
        SubscriptionCallback callbackEntry;
        TaskCompletionSource<bool> taskSource;

        lock (_subscriptionMap)
        {
            (channelName, entry, callbackEntry) = GetEntryByToken(token);

            if (string.IsNullOrEmpty(channelName) || entry == null || callbackEntry == null)
            {
                _client.Logger?.Warning("Unsubscribe: Could not find token {token}", token.Token);
                return false;
            }

            switch (entry.State)
            {
                case SubscriptionState.Subscribing:
                    _client.Logger?.Debug("Unsubscribe: Channel is currently subscribing. Abort unsubscribe (Channel: {Channel})", channelName);
                    return false;
                case SubscriptionState.Unsubscribed:
                case SubscriptionState.Unsubscribing:
                    _client.Logger?.Debug("Unsubscribe: Channel is unsubscribed or unsubscribing. Remove callback (Channel: {Channel})", channelName);
                    entry.Callbacks.Remove(callbackEntry);
                    return true;
                case SubscriptionState.Subscribed:
                    if (entry.Callbacks.Count > 1)
                    {
                        _client.Logger?.Debug("Unsubscribe: There are still callbacks left. Remove callback but don't unsubscribe (Channel: {Channel})", channelName);
                        entry.Callbacks.Remove(callbackEntry);
                        return true;
                    }

                    _client.Logger?.Debug("Unsubscribe: No callbacks left. Unsubscribe and remove callback (Channel: {Channel})", channelName);
                    break;
                default:
                    return false;
            }

            // At this point it's only possible that the entry-State is Subscribed
            // and the callback list is empty after removing this callback.
            // Hence we unsubscribe at the server now
            entry.State = SubscriptionState.Unsubscribing;
            taskSource = new TaskCompletionSource<bool>();
            entry.UnsubscribeTask = taskSource.Task;
        }

        try
        {
            var unsubscribeResponse = await _client.Send(
            IsPrivateChannel(channelName) ? "private/unsubscribe" : "public/unsubscribe",
            new { channels = new[] { channelName } },
            new ListJsonConverter<string>()).ConfigureAwait(false);

            var response = unsubscribeResponse.ResultData;

            if (response.Count != 1 || response[0] != channelName)
            {
                entry.State = SubscriptionState.Subscribed;
                entry.UnsubscribeTask = null;
                taskSource.SetResult(false);
            }
            else
            {
                entry.Callbacks.Remove(callbackEntry);
                entry.State = SubscriptionState.Unsubscribed;
                entry.UnsubscribeTask = null;
                taskSource.SetResult(true);
            }
        }
        catch (Exception e)
        {
            entry.State = SubscriptionState.Subscribed;
            entry.UnsubscribeTask = null;
            taskSource.SetException(e);
        }

        return await taskSource.Task;
    }

    public IEnumerable<Action<Notification>> GetCallbacks(string channel)
    {
        if (_subscriptionMap.TryGetValue(channel, out var entry))
        {
            foreach (var callbackEntry in entry.Callbacks)
            {
                yield return callbackEntry.Action;
            }
        }
    }

    public void Reset()
    {
        _subscriptionMap.Clear();
    }

    private static bool IsPrivateChannel(string channel)
    {
        return channel.StartsWith("user.");
    }

    private (string channelName, SubscriptionEntry entry, SubscriptionCallback callbackEntry) GetEntryByToken(SubscriptionToken token)
    {
        lock (_subscriptionMap)
        {
            foreach (var kvp in _subscriptionMap)
            {
                foreach (var callbackEntry in kvp.Value.Callbacks)
                {
                    if (callbackEntry.Token == token)
                    {
                        return (kvp.Key, kvp.Value, callbackEntry);
                    }
                }
            }
        }

        return (null, null, null);
    }
}

GitHub

我的尝试

public class SubscriptionToken
{
    public static readonly SubscriptionToken Invalid = new(Guid.Empty);

    public SubscriptionToken(Guid token)
    {
        Token = token;
    }

    public Guid Token { get; }
}

public class SubscriptionCallback
{
    public SubscriptionCallback(SubscriptionToken token, Action<Notification> action)
    {
        Token = token;
        Action = action;
    }

    public Action<Notification> Action { get; }
    public SubscriptionToken Token { get; }
}

public class SubscriptionEntry
{
    public List<SubscriptionCallback> Callbacks { get; } = new();
    public Task<SubscriptionToken>? SubscribeTask { get; set; }
    public Task<bool>? UnsubscribeTask { get; set; }
    public SubscriptionState State { get; set; } = SubscriptionState.Unsubscribed;
}

public class SubscriptionManager
{
    private readonly DeribitClient _client;
    private readonly ConcurrentDictionary<string, SubscriptionEntry> _subscriptions = new();

    public SubscriptionManager(DeribitClient client)
    {
        _client = client ?? throw new ArgumentNullException(nameof(client));
    }

    public async Task<SubscriptionToken> SubscribeAsync(string channel, Action<Notification>? callback)
    {
        if (callback == null)
        {
            throw new ArgumentNullException(nameof(callback));
        }

        TaskCompletionSource<SubscriptionToken>? tcs = null;

        if (_subscriptions.TryGetValue(channel, out var entry))
        {
            if (entry.State == SubscriptionState.Subscribed)
            {
                Log.Debug("Subscribe: Subscription for channel already exists. Adding callback to list (Channel: {Channel})", channel);

                var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
                entry.Callbacks.Add(callbackEntry);
                return callbackEntry.Token;
            }

            if (entry.State == SubscriptionState.Unsubscribing)
            {
                Log.Debug("Subscribe: Channel is unsubscribing. Abort subscribe (Channel: {Channel})", channel);
                return SubscriptionToken.Invalid;
            }

            if (entry.State == SubscriptionState.Unsubscribed)
            {
                Log.Debug("Subscribe: Entry already exists but is completely unsubscribed (Channel: {Channel})", channel);

                tcs = new TaskCompletionSource<SubscriptionToken>();
                entry.State = SubscriptionState.Subscribing;
                entry.SubscribeTask = tcs.Task;
            }
        }
        else
        {
            tcs = new TaskCompletionSource<SubscriptionToken>();
            entry = new SubscriptionEntry
            {
                State = SubscriptionState.Subscribing,
                SubscribeTask = tcs.Task
            };

            if (!_subscriptions.TryAdd(channel, entry))
            {
                Log.Error("Subscribe: Could not add internal item for channel {Channel}", channel);
                return SubscriptionToken.Invalid;
            }
        }

        if (tcs == null && entry.State == SubscriptionState.Subscribing)
        {
            Log.Debug("Subscribe: Channel is already subscribing. Waiting for the task to complete ({Channel})", channel);

            var subscribeResult = entry.SubscribeTask != null && await entry.SubscribeTask.ConfigureAwait(false) != SubscriptionToken.Invalid;

            if (!subscribeResult && entry.State != SubscriptionState.Subscribed)
            {
                Log.Debug("Subscribe: Subscription has failed. Abort subscribe (Channel: {Channel})", channel);
                return SubscriptionToken.Invalid;
            }

            Log.Debug("Subscribe: Subscription was successful. Adding callback (Channel: {Channel}", channel);

            var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
            entry.Callbacks.Add(callbackEntry);
            return callbackEntry.Token;
        }

        if (tcs == null)
        {
            Log.Error("Subscribe: Invalid execution state. Missing TaskCompletionSource (Channel: {Channel}", channel);
            return SubscriptionToken.Invalid;
        }

        try
        {
            var method = IsPrivateChannel(channel) ? "private/subscribe" : "public/subscribe";
            var @params = new Dictionary<string, string[]>
            {
                { "channels", new[] { channel } }
            };
            var subscribeResponse = await _client.SendAsync<Notification>(method, @params).ConfigureAwait(false);

            if (subscribeResponse == null)
            {
                Log.Debug("Subscribe: Invalid result (Channel: {Channel}): {@Response}", channel, subscribeResponse);

                entry.State = SubscriptionState.Unsubscribed;
                entry.SubscribeTask = null;

                Debug.Assert(tcs != null);

                tcs.SetResult(SubscriptionToken.Invalid);
            }
            else
            {
                Log.Debug("Subscribe: Successfully subscribed. Adding callback (Channel: {Channel})", channel);

                var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
                entry.Callbacks.Add(callbackEntry);

                entry.State = SubscriptionState.Subscribed;
                entry.SubscribeTask = null;

                Debug.Assert(tcs != null);

                tcs.SetResult(callbackEntry.Token);
            }
        }
        catch (Exception ex)
        {
            entry.State = SubscriptionState.Unsubscribed;
            entry.SubscribeTask = null;

            Debug.Assert(tcs != null);

            tcs.SetException(ex);
        }

        return await tcs.Task.ConfigureAwait(false);
    }

    public async Task<bool> UnsubscribeAsync(SubscriptionToken token)
    {
        TaskCompletionSource<bool> tcs;

        var (channel, entry, callbackEntry) = GetEntryByToken(token);

        if (string.IsNullOrEmpty(channel) || entry == null || callbackEntry == null)
        {
            Log.Warning("UnsubscribeAsync: Could not find token {token}", token.Token);
            return false;
        }

        switch (entry.State)
        {
            case SubscriptionState.Subscribing:
                Log.Debug("UnsubscribeAsync: Channel is currently subscribing. Abort unsubscribe (Channel: {Channel})", channel);
                return false;
            case SubscriptionState.Unsubscribed:
            case SubscriptionState.Unsubscribing:
                Log.Debug("UnsubscribeAsync: Channel is unsubscribed or unsubscribing. Remove callback (Channel: {Channel})", channel);
                entry.Callbacks.Remove(callbackEntry);
                return true;
            case SubscriptionState.Subscribed when entry.Callbacks.Count > 1:
                Log.Debug("UnsubscribeAsync: There are still callbacks left. Remove callback but don't unsubscribe (Channel: {Channel})", channel);
                entry.Callbacks.Remove(callbackEntry);
                return true;
            case SubscriptionState.Subscribed:
                Log.Debug("UnsubscribeAsync: No callbacks left. UnsubscribeAsync and remove callback (Channel: {Channel})", channel);
                tcs = new TaskCompletionSource<bool>();
                entry.State = SubscriptionState.Unsubscribing;
                entry.UnsubscribeTask = tcs.Task;
                break;
            default:
                return false;
        }

        try
        {
            var method = IsPrivateChannel(channel) ? "private/unsubscribe" : "public/unsubscribe";
            var @params = new Dictionary<string, string[]>
            {
                { "channels", new[] { channel } }
            };
            var unsubscribeResponse = await _client.SendAsync<Notification>(method, @params).ConfigureAwait(false);

            if (unsubscribeResponse == null)
            {
                entry.State = SubscriptionState.Subscribed;
                entry.UnsubscribeTask = null;

                tcs.SetResult(false);
            }
            else
            {
                entry.Callbacks.Remove(callbackEntry);

                entry.State = SubscriptionState.Unsubscribed;
                entry.UnsubscribeTask = null;

                tcs.SetResult(true);
            }
        }
        catch (Exception ex)
        {
            entry.State = SubscriptionState.Subscribed;
            entry.UnsubscribeTask = null;

            tcs.SetException(ex);
        }

        return await tcs.Task.ConfigureAwait(false);
    }

    private (string? channelName, SubscriptionEntry? entry, SubscriptionCallback? callbackEntry) GetEntryByToken(SubscriptionToken token)
    {
        foreach (var (key, value) in _subscriptions)
        {
            foreach (var callbackEntry in value.Callbacks.Where(callbackEntry => callbackEntry.Token == token))
            {
                return (key, value, callbackEntry);
            }
        }

        return (null, null, null);
    }

    public IEnumerable<Action<Notification>> GetCallbacks(string channel)
    {
        if (_subscriptions.TryGetValue(channel, out var entry))
        {
            foreach (var callbackEntry in entry.Callbacks)
            {
                yield return callbackEntry.Action;
            }
        }
    }

    private static bool IsPrivateChannel(string channel)
    {
        return channel.StartsWith("user.");
    }
}

A ConcurrentDictionary<K,V> 是 thread-safe,因为它保护其内部状态免受损坏。它不保护它包含的键和值,以防它们是可变对象。

在您的例子中,存储在字典 (SubscriptionEntry) 中的值是可变对象。他们有 public setter,并且公开了类型 List<SubscriptionCallback> 的 public 属性。 List<T>classis not thread-safe。所以,不,你不能用问题中显示的方式 ConcurrentDictionary 替换 DictionaryMy attempt 部分)。以下是一些选项:

  1. 确保 SubscriptionCallback 类型是不可变的。如果要更改它,请创建一个新的 SubscriptionCallback 实例并丢弃之前的实例。
  2. 保持 SubscriptionCallback 可变,但使其成为 thread-safe。
  3. 只保留 Dictionary,忘记切换到 ConcurrentDictionarylock 的开销是微不足道的,前提是您在持有锁时没有做任何不重要的事情。如果您只执行基本操作 (Add/TryGetValue/Remove),您不太可能注意到任何可测量的争用,除非您每秒执行 100,000 次或更多操作。