C# Dictionary源码学习

最近打算阅读和学习一下c#的源码,主要涉及到常用的几个集合类,了解其中包含的一些算法原理和一些其它的细节。第一篇是关于平常经常会用到的Dictionary<,>。源码的版本是.NET 4.8。

Dictionary

Dictionary<,>是最常用到的集合类之一,其定义实现的接口如下:

1
2
3
4
5
6
7
8
9
10
[DebuggerTypeProxy(typeof(Mscorlib_DictionaryDebugView<,>))]
[DebuggerDisplay("Count = {Count}")]
[Serializable]
[System.Runtime.InteropServices.ComVisible(false)]
public class Dictionary<TKey,TValue>: IDictionary<TKey,TValue>, IDictionary, IReadOnlyDictionary<TKey, TValue>, ISerializable, IDeserializationCallback
{

// ...

}

其中除了最后两个和序列化相关的之外,其它的接口都是关于字典的,包括泛型与非泛型版本,字典需要实现增删改查的方法,并且实现IEnumerable要求可遍历。在开始Dictionary<,>之前,先看一眼其中的几个内嵌类(结构):

内嵌类和结构

Entry

1
2
3
4
5
6
private struct Entry {
public int hashCode; // Lower 31 bits of hash code, -1 if unused
public int next; // Index of next entry, -1 if last
public TKey key; // Key of entry
public TValue value; // Value of entry
}

Entry是字典内部数据存储结构,也是字典的最核心的功能所在,其中存储一对keyvaluehashcode的后31位用于存储由key计算出来的原始哈希值,如果hashCode是-1,表示是空的Entry(该Entry)未被使用,next指向数组内下一个Entry的下标,后边会详细说到,在字典中使用一个Entry的数组来模拟一个链表,在实现链表的功能的同时减少了碎片化的内存分配。

Enumerator

Enumerator是用于遍历字典的迭代器,它实现了泛型和非泛型的接口方法。

1
2
3
4
5
6
[Serializable]
public struct Enumerator: IEnumerator<KeyValuePair<TKey,TValue>>,
IDictionaryEnumerator
{
// ...
}

KeyCollection

只读的字典中所有的key的集合,任何写操作都会抛出异常,它包含了一个内嵌类KeyCollection.Enumerator,用于实现遍历的相关接口。

ValueCollection

KeyCollection类似,对应字典里所有的value。

构造函数

包含多个构造函数,可以提供两个参数(重载版本使用默认值),字典的容量capacity会用于初始化桶buckets和数组entries,这两者会影响到key的原始哈希值的取模计算,进而影响到出现哈希值碰撞的机会,以及当插入更多的数据之后需要扩容的次数。comparer用于计算key的原始哈希值以及判断两个key是否相等。

1
2
3
4
5
6
7
8
public Dictionary(int capacity, IEqualityComparer<TKey> comparer) {
if (capacity < 0) ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.capacity);
if (capacity > 0) Initialize(capacity);
this.comparer = comparer ?? EqualityComparer<TKey>.Default;

// ...

}

增删改查

字典中索引数据使用的是哈希算法,哈希函数使用取模计算,首先是计算key的原始哈希值,得到一个int之后,对bucket的长度取模。当出现哈希碰撞时,使用链表法来解决。

在一个字典中有以下两个字段:

1
2
private int[] buckets;
private Entry[] entries;

其中entries用于存储数据,buckets用于存储数据在entries数组中的下标(相当于链表的首地址)。在源码中可以反复多次看到哈希函数如下:

1
2
int hashCode = comparer.GetHashCode(key) & 0x7FFFFFFF;
int targetBucket = hashCode % buckets.Length;

当得到哈希值之后,到entries中对应下标位置去查找,以该Entry为链表的首节点,直到找到该key或者没有更多的Entry,在源码中也可以反复多次看到下边的逻辑:

1
2
for (int i = buckets[hashCode % buckets.Length]; i >= 0; i = entries[i].next) {
if (entries[i].hashCode == hashCode && comparer.Equals(entries[i].key, key)) return i;

以上是字典哈希算法的最基本原理,后边会展开讨论增删改查的过程。注意在字典中,每当有数据发生变化时(增、删、改),会使字典的version加1。这种机制保证在对字典遍历的过程中其中的数据是无法修改的。

Add

增加新的键值对时,内部调用的是Insert方法,并用一个参数add表示当前是在插入新的值还是尝试修改值:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
private void Insert(TKey key, TValue value, bool add) {

if( key == null ) {
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key);
}

if (buckets == null) Initialize(0);
int hashCode = comparer.GetHashCode(key) & 0x7FFFFFFF;
int targetBucket = hashCode % buckets.Length;

// ...

for (int i = buckets[targetBucket]; i >= 0; i = entries[i].next) {
if (entries[i].hashCode == hashCode && comparer.Equals(entries[i].key, key)) {
if (add) {
ThrowHelper.ThrowArgumentException(ExceptionResource.Argument_AddingDuplicate);
}
entries[i].value = value;
version++;
return;
}

// ...

}
int index;
if (freeCount > 0) {
index = freeList;
freeList = entries[index].next;
freeCount--;
}
else {
if (count == entries.Length)
{
Resize();
targetBucket = hashCode % buckets.Length;
}
index = count;
count++;
}

entries[index].hashCode = hashCode;
entries[index].next = buckets[targetBucket];
entries[index].key = key;
entries[index].value = value;
buckets[targetBucket] = index;
version++;

// ...

}
  1. 首先是查找是否已经存在了对应的key,如果存在,对于add的情况,是要抛出异常的,键已存在;
  2. 如果没有找到,就会到后续的步骤,需要插入新的数据,这时候要确定这组新数据存储的位置即index
  3. 如果有freeCount,直接将index指定到freeList的头部,然后将freeList指向原freeListnext的位置;
  4. 如果没有freeCount了,判断是否还有空的Entry,如果count == entries.Length即所有的Entry都被使用了,就需要扩容Resize(后边待会再看扩容的逻辑),如果有空的Entry,就使用下标为countEntry,并将count加1;
  5. 找到了index之后,将keyvaluehashCode保存在对应的Entry中,将next指向原来的链表头,然后把链表头的值改为index

Resize的过程如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
private void Resize() {
Resize(HashHelpers.ExpandPrime(count), false);
}

private void Resize(int newSize, bool forceNewHashCodes) {
Contract.Assert(newSize >= entries.Length);
int[] newBuckets = new int[newSize];
for (int i = 0; i < newBuckets.Length; i++) newBuckets[i] = -1;
Entry[] newEntries = new Entry[newSize];
Array.Copy(entries, 0, newEntries, 0, count);
if(forceNewHashCodes) {
for (int i = 0; i < count; i++) {
if(newEntries[i].hashCode != -1) {
newEntries[i].hashCode = (comparer.GetHashCode(newEntries[i].key) & 0x7FFFFFFF);
}
}
}
for (int i = 0; i < count; i++) {
if (newEntries[i].hashCode >= 0) {
int bucket = newEntries[i].hashCode % newSize;
newEntries[i].next = newBuckets[bucket];
newBuckets[bucket] = i;
}
}
buckets = newBuckets;
entries = newEntries;
}

首先取一个更大的质数作为buckets的长度,构造对应尺寸的数组bucketsentries,然后将原来的Entry复制到新的数组中,然后根据新的newSize重新构建链表的连接关系。

Remove

删除数据时,会清除所在Entry中的keyvalue,同时把这个Entry放到freeList的头部。

在一个字典中保存的有:

1
2
private int freeList;
private int freeCount;

使用一个链表来记录被清空的Entry,以便在复用时可以快速找到。删除数据的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
public bool Remove(TKey key) {
if(key == null) {
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key);
}

if (buckets != null) {
int hashCode = comparer.GetHashCode(key) & 0x7FFFFFFF;
int bucket = hashCode % buckets.Length;
int last = -1;
for (int i = buckets[bucket]; i >= 0; last = i, i = entries[i].next) {
if (entries[i].hashCode == hashCode && comparer.Equals(entries[i].key, key)) {
if (last < 0) {
buckets[bucket] = entries[i].next;
}
else {
entries[last].next = entries[i].next;
}
entries[i].hashCode = -1;
entries[i].next = freeList;
entries[i].key = default(TKey);
entries[i].value = default(TValue);
freeList = i;
freeCount++;
version++;
return true;
}
}
}
return false;
}

其中last表示要删除的数据所在的Entry的前置节点,如果没有前置节点,就把buckets中保存的首节点的位置改为被删除节点的next

set

1
2
3
4
5
6
7
8
9
10
11
public TValue this[TKey key] {
get {
int i = FindEntry(key);
if (i >= 0) return entries[i].value;
ThrowHelper.ThrowKeyNotFoundException();
return default(TValue);
}
set {
Insert(key, value, false);
}
}

修改数据的时候,同样调用的是Insert,如果发现已存在key,直接替换value,其它的逻辑同Add

get

首先是常用的ContainsKey方法:

1
2
3
public bool ContainsKey(TKey key) {
return FindEntry(key) >= 0;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
private int FindEntry(TKey key) {
if( key == null) {
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key);
}

if (buckets != null) {
int hashCode = comparer.GetHashCode(key) & 0x7FFFFFFF;
for (int i = buckets[hashCode % buckets.Length]; i >= 0; i = entries[i].next) {
if (entries[i].hashCode == hashCode && comparer.Equals(entries[i].key, key)) return i;
}
}
return -1;
}

由此可见,在尝试取一个key对应的value时,先使用ContainsKey再索引,会发生两次查询过程,更好的做法是使用TryGetValue

1
2
3
4
5
6
7
8
9
public bool TryGetValue(TKey key, out TValue value) {
int i = FindEntry(key);
if (i >= 0) {
value = entries[i].value;
return true;
}
value = default(TValue);
return false;
}

除了泛型版本的增删改查之外,Dictionary<,>还显式实现非泛型版本的IDictionary的接口。

遍历

Dictionary<,>实现泛型和非泛型版本的IEnumerable,这里显式实现非泛型版的:

1
2
3
4
5
6
7
public Enumerator GetEnumerator() {
return new Enumerator(this, Enumerator.KeyValuePair);
}

IEnumerator<KeyValuePair<TKey, TValue>> IEnumerable<KeyValuePair<TKey, TValue>>.GetEnumerator() {
return new Enumerator(this, Enumerator.KeyValuePair);
}

返回的就是内嵌类Enumerator的新实例,Enumerator的迭代对象类型可以为DictEntryKeyValuePairEnumerator实现IEnumerator的各接口:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
public bool MoveNext() {
if (version != dictionary.version) {
ThrowHelper.ThrowInvalidOperationException(ExceptionResource.InvalidOperation_EnumFailedVersion);
}

// Use unsigned comparison since we set index to dictionary.count+1 when the enumeration ends.
// dictionary.count+1 could be negative if dictionary.count is Int32.MaxValue
while ((uint)index < (uint)dictionary.count) {
if (dictionary.entries[index].hashCode >= 0) {
current = new KeyValuePair<TKey, TValue>(dictionary.entries[index].key, dictionary.entries[index].value);
index++;
return true;
}
index++;
}

index = dictionary.count + 1;
current = new KeyValuePair<TKey, TValue>();
return false;
}

实际上是遍历字典的entries数组,判断并返回其中的有效值的过程。当version不一致时,表示数据发生了变化,遍历失败会抛出异常。

1
2
3
public KeyValuePair<TKey,TValue> Current {
get { return current; }
}
1
2
3
4
5
6
7
8
void IEnumerator.Reset() {
if (version != dictionary.version) {
ThrowHelper.ThrowInvalidOperationException(ExceptionResource.InvalidOperation_EnumFailedVersion);
}

index = 0;
current = new KeyValuePair<TKey, TValue>();
}

序列化

C#的字典类支持序列化,它实现了ISerializableIDeserializationCallback接口。ISerializable允许自定义序列化过程:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
[System.Security.SecurityCritical]  // auto-generated_required
public virtual void GetObjectData(SerializationInfo info, StreamingContext context) {
if (info==null) {
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.info);
}
info.AddValue(VersionName, version);

// ...

info.AddValue(HashSizeName, buckets == null ? 0 : buckets.Length); //This is the length of the bucket array.
if( buckets != null) {
KeyValuePair<TKey, TValue>[] array = new KeyValuePair<TKey, TValue>[Count];
CopyTo(array, 0);
info.AddValue(KeyValuePairsName, array, typeof(KeyValuePair<TKey, TValue>[]));
}
}

将自身的各个字段和数据塞入SerializationInfo

字典的反序列化,会涉及到依赖类型的反序列化。对象反序列化的顺序是无法保证的,为避免在反序列化一个对象时它所依赖的对象还未完成序列化,可以使用IDeserializationCallback接口实现OnDeserialization方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
public virtual void OnDeserialization(Object sender) {
SerializationInfo siInfo;
HashHelpers.SerializationInfoTable.TryGetValue(this, out siInfo);

if (siInfo==null) {
// It might be necessary to call OnDeserialization from a container if the container object also implements
// OnDeserialization. However, remoting will call OnDeserialization again.
// We can return immediately if this function is called twice.
// Note we set remove the serialization info from the table at the end of this method.
return;
}

int realVersion = siInfo.GetInt32(VersionName);
int hashsize = siInfo.GetInt32(HashSizeName);
comparer = (IEqualityComparer<TKey>)siInfo.GetValue(ComparerName, typeof(IEqualityComparer<TKey>));

if( hashsize != 0) {
buckets = new int[hashsize];
for (int i = 0; i < buckets.Length; i++) buckets[i] = -1;
entries = new Entry[hashsize];
freeList = -1;

KeyValuePair<TKey, TValue>[] array = (KeyValuePair<TKey, TValue>[])
siInfo.GetValue(KeyValuePairsName, typeof(KeyValuePair<TKey, TValue>[]));

if (array==null) {
ThrowHelper.ThrowSerializationException(ExceptionResource.Serialization_MissingKeys);
}

for (int i=0; i<array.Length; i++) {
if ( array[i].Key == null) {
ThrowHelper.ThrowSerializationException(ExceptionResource.Serialization_NullKey);
}
Insert(array[i].Key, array[i].Value, true);
}
}
else {
buckets = null;
}

version = realVersion;
HashHelpers.SerializationInfoTable.Remove(this);
}

与序列化的过程相反,从SerializationInfo中取出各个字段数据,赋值给自己。

REFERENCE

https://referencesource.microsoft.com/#mscorlib/system/collections/generic/dictionary.cs