概要
本文在C# Linq源码分析之Take(四)的基础上继续从源码角度分析Take的优化方法,主要分析Where.Select.Take的使用案例。
Where.Select.Take的案例分析
该场景模拟现实开发中的常见做法:先对EF中与数据库关联的对象进行过滤,然后转换成Web前端需要的对象,并进行分页。
// Keep students whose math score is at least 90, project each one to an
// anonymous object with just Name and MathResult, keep only the first three,
// materialize them into a List and print each one.
studentList
    .Where(student => student.MathResult >= 90)
    .Select(student => new { student.Name, student.MathResult })
    .Take(3)
    .ToList()
    .ForEach(row => Console.WriteLine(row.Name + row.MathResult));
找到数学成绩在90分及以上的学生,获取其姓名和数学成绩,只取前三个学生,并将学生信息打印出来。Student类的代码请见附录。
源码流程分析
第一步进入Where方法,返回WhereListIterator对象;
第二步进入Select方法,将Where和Select两个操作合并,返回WhereSelectListIterator对象;
第三步进入Take方法,调用takeIterator方法;由于WhereSelectListIterator既没有实现IPartition接口,也没有实现IList接口,所以无法再进行操作合并,只能返回EnumerablePartition对象。
/// <summary>
/// Builds the iterator that yields at most <paramref name="count"/> elements of <paramref name="source"/>.
/// </summary>
/// <param name="source">The sequence to take from.</param>
/// <param name="count">Number of elements to take; the caller must guarantee it is positive.</param>
/// <returns>
/// <paramref name="source"/>'s own <c>Take</c> result when it is an <c>IPartition</c>;
/// otherwise a partition covering indices <c>0 .. count - 1</c>.
/// </returns>
private static IEnumerable<TSource> takeIterator<TSource>(IEnumerable<TSource> source, int count)
{
    Debug.Assert(count > 0);

    // Try the cheapest representation first:
    //  1. an existing partition can fold the Take into itself (no extra wrapper);
    //  2. an IList can be sliced by index;
    //  3. anything else is wrapped in an EnumerablePartition bounded to indices 0..count-1.
    return source is IPartition<TSource> partition ? partition.Take(count) :
        source is IList<TSource> sourceList ? new ListPartition<TSource>(sourceList, 0, count - 1) :
        new EnumerablePartition<TSource>(source, 0, count - 1);
}
第四步进入ToList方法
/// <summary>
/// Materializes <paramref name="source"/> into a new <see cref="List{T}"/>.
/// </summary>
/// <param name="source">The sequence to copy; must not be null.</param>
/// <returns>A list holding the elements of <paramref name="source"/>.</returns>
public static List<TSource> ToList<TSource>(this IEnumerable<TSource> source)
{
    if (source is null)
    {
        ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
    }

    // Sources that know how to build a list themselves (e.g. partitions and
    // merged Where/Select iterators) get to use their specialized implementation.
    if (source is IIListProvider<TSource> provider)
    {
        return provider.ToList();
    }

    // Fallback: let the List<T> constructor enumerate the sequence.
    return new List<TSource>(source);
}
此时的source是EnumerablePartition对象,它实现了IPartition接口,而IPartition接口继承了IIListProvider接口,所以可以调用自己的ToList方法;
/// <summary>
/// Materializes this partition into a new <see cref="List{T}"/>.
/// </summary>
/// <remarks>
/// The first <c>_minIndexInclusive</c> elements of <c>_source</c> are skipped. Elements are then copied
/// until <c>Limit</c> of them have been added (when <c>HasLimit</c> is true) or <c>_source</c> is exhausted.
/// </remarks>
/// <returns>A list containing the partition's elements (possibly empty).</returns>
public List<TSource> ToList()
{
    var list = new List<TSource>();

    using (IEnumerator<TSource> en = _source.GetEnumerator())
    {
        if (SkipBeforeFirst(en) && en.MoveNext())
        {
            int remaining = Limit - 1; // Max number of items left, not counting the current element.
            int comparand = HasLimit ? 0 : int.MinValue; // If we don't have an upper bound, have the comparison always return true.

            do
            {
                remaining--;
                list.Add(en.Current);
            }
            // The bound is checked BEFORE MoveNext, so once Limit elements have been
            // collected the upstream iterator is not asked for one more element.
            while (remaining >= comparand && en.MoveNext());
        }
    }

    return list;
}
- 定义迭代器en,此时的_source是WhereSelectListIterator对象;
- 由于EnumerablePartition同样用于实现Skip,该ToList方法会先调用SkipBeforeFirst跳过起始位置之前的元素(本例中由takeIterator创建时_minIndexInclusive为0,因此不跳过任何元素);
- 每次迭代,首先从WhereSelectListIterator迭代器中返回一个符合过滤条件,并完成Selector操作的元素,存入list,直到list中包含三个元素,返回执行结果。
虽然WhereSelectListIterator没有实现IPartition接口,无法把Take也合并进同一个迭代器、在一次迭代中完成全部操作,但是现有流程的性能并不差:WhereSelectListIterator迭代器本身已经合并了过滤和投影操作,而且并不需要遍历所有元素,只要找到3个符合条件的元素即可。
我认为如果代码需要用到Take方法,应尽量把它放在Linq链的末尾(在ToList等立即执行的操作之前)。由于Linq是延迟执行的,这样做的好处是前面的Linq操作并不需要遍历全部的序列元素,只要得到Take方法中需要的元素个数即可;但如果Take之前存在OrderBy这类必须读取全部元素的操作,这一优势就不存在了。
本文中涉及的源码请见附录,关于WhereSelectListIterator的合并优化操作,更多详细内容,请参考C# LINQ源码分析之Select
附录
Student类
/// <summary>
/// Sample entity used by the Where/Select/Take example.
/// </summary>
public class Student
{
    /// <summary>Student identifier.</summary>
    public string Id { get; set; }

    /// <summary>Student name.</summary>
    public string Name { get; set; }

    /// <summary>Classroom the student is assigned to.</summary>
    public string Classroom { get; set; }

    /// <summary>Math score.</summary>
    public int MathResult { get; set; }
}
IIListProvider接口
/// <summary>
/// A sequence that can materialize itself or report its size more efficiently
/// than generic enumeration would.
/// </summary>
/// <typeparam name="TElement">Element type of the sequence.</typeparam>
internal interface IIListProvider<TElement> : IEnumerable<TElement>
{
    /// <summary>Copies the sequence's elements into a new array.</summary>
    TElement[] ToArray();

    /// <summary>Copies the sequence's elements into a new <see cref="List{T}"/>.</summary>
    List<TElement> ToList();

    /// <summary>
    /// Returns the number of elements in the sequence.
    /// </summary>
    /// <param name="onlyIfCheap">
    /// When true, the implementation may decline to count. For example, EnumerablePartition
    /// returns -1 in that case rather than enumerating.
    /// </param>
    int GetCount(bool onlyIfCheap);
}
IPartition接口
/// <summary>
/// A list provider that also supports cheap slicing (Skip/Take) and
/// positional element lookup without materializing the whole sequence.
/// </summary>
/// <typeparam name="TElement">Element type of the sequence.</typeparam>
internal interface IPartition<TElement> : IIListProvider<TElement>
{
    /// <summary>Returns a partition that skips <paramref name="count"/> further elements.</summary>
    IPartition<TElement> Skip(int count);

    /// <summary>Returns a partition limited to at most <paramref name="count"/> elements.</summary>
    IPartition<TElement> Take(int count);

    /// <summary>Looks up the element at <paramref name="index"/>; <paramref name="found"/> reports success.</summary>
    TElement? TryGetElementAt(int index, out bool found);

    /// <summary>Looks up the first element; <paramref name="found"/> reports success.</summary>
    TElement? TryGetFirst(out bool found);

    /// <summary>Looks up the last element; <paramref name="found"/> reports success.</summary>
    TElement? TryGetLast(out bool found);
}
EnumerablePartition类
/// <summary>
/// A Skip/Take window over an arbitrary (non-IList) enumerable. The window covers
/// source indices <c>_minIndexInclusive .. _maxIndexInclusive</c>; a
/// <c>_maxIndexInclusive</c> of -1 means "everything from <c>_minIndexInclusive</c> on".
/// </summary>
/// <typeparam name="TSource">Element type of the underlying sequence.</typeparam>
private sealed class EnumerablePartition<TSource> : Iterator<TSource>, IPartition<TSource>
{
    private readonly IEnumerable<TSource> _source;
    private readonly int _minIndexInclusive;
    private readonly int _maxIndexInclusive; // -1 if we want everything past _minIndexInclusive.
                                             // If this is -1, it's impossible to set a limit on the count.
    private IEnumerator<TSource>? _enumerator;

    internal EnumerablePartition(IEnumerable<TSource> source, int minIndexInclusive, int maxIndexInclusive)
    {
        Debug.Assert(source != null);
        Debug.Assert(!(source is IList<TSource>), $"The caller needs to check for {nameof(IList<TSource>)}.");
        Debug.Assert(minIndexInclusive >= 0);
        Debug.Assert(maxIndexInclusive >= -1);
        // Note that although maxIndexInclusive can't grow, it can still be int.MaxValue.
        // We support partitioning enumerables with > 2B elements. For example, e.Skip(1).Take(int.MaxValue) should work.
        // But if it is int.MaxValue, then minIndexInclusive must != 0. Otherwise, our count may overflow.
        Debug.Assert(maxIndexInclusive == -1 || (maxIndexInclusive - minIndexInclusive < int.MaxValue), $"{nameof(Limit)} will overflow!");
        Debug.Assert(maxIndexInclusive == -1 || minIndexInclusive <= maxIndexInclusive);

        _source = source;
        _minIndexInclusive = minIndexInclusive;
        _maxIndexInclusive = maxIndexInclusive;
    }

    // If this is true (e.g. at least one Take call was made), then we have an upper bound
    // on how many elements we can have.
    private bool HasLimit => _maxIndexInclusive != -1;

    private int Limit => unchecked((_maxIndexInclusive + 1) - _minIndexInclusive); // This is that upper bound.

    public override Iterator<TSource> Clone() =>
        new EnumerablePartition<TSource>(_source, _minIndexInclusive, _maxIndexInclusive);

    public override void Dispose()
    {
        if (_enumerator != null)
        {
            _enumerator.Dispose();
            _enumerator = null;
        }

        base.Dispose();
    }

    /// <summary>
    /// Returns the number of elements in the window, or -1 when <paramref name="onlyIfCheap"/> is true
    /// (counting always requires enumerating the source).
    /// </summary>
    public int GetCount(bool onlyIfCheap)
    {
        if (onlyIfCheap)
        {
            return -1;
        }

        if (!HasLimit)
        {
            // If HasLimit is false, we contain everything past _minIndexInclusive.
            // Therefore, we have to iterate the whole enumerable. (Returning a constant
            // here would misreport the size of every unbounded partition.)
            int total = 0;
            using (IEnumerator<TSource> all = _source.GetEnumerator())
            {
                while (all.MoveNext())
                {
                    // checked: throw OverflowException instead of silently wrapping past int.MaxValue.
                    total = checked(total + 1);
                }
            }

            return Math.Max(total - _minIndexInclusive, 0);
        }

        using (IEnumerator<TSource> en = _source.GetEnumerator())
        {
            // We only want to iterate up to _maxIndexInclusive + 1.
            // Past that, we know the enumerable will be able to fit this partition,
            // so the count will just be _maxIndexInclusive + 1 - _minIndexInclusive.

            // Note that it is possible for _maxIndexInclusive to be int.MaxValue here,
            // so + 1 may result in signed integer overflow. We need to handle this.
            // At the same time, however, we are guaranteed that our max count can fit
            // in an int because if that is true, then _minIndexInclusive must > 0.

            uint count = SkipAndCount((uint)_maxIndexInclusive + 1, en);
            Debug.Assert(count != (uint)int.MaxValue + 1 || _minIndexInclusive > 0, "Our return value will be incorrect.");
            return Math.Max((int)count - _minIndexInclusive, 0);
        }
    }

    public override bool MoveNext()
    {
        // Cases where GetEnumerator has not been called or Dispose has already
        // been called need to be handled explicitly, due to the default: clause.
        // (taken < -2 is equivalent to _state < 1.)
        int taken = _state - 3;
        if (taken < -2)
        {
            Dispose();
            return false;
        }

        switch (_state)
        {
            case 1:
                _enumerator = _source.GetEnumerator();
                _state = 2;
                goto case 2;
            case 2:
                Debug.Assert(_enumerator != null);
                if (!SkipBeforeFirst(_enumerator))
                {
                    // Reached the end before we finished skipping.
                    break;
                }

                _state = 3;
                goto default;
            default:
                Debug.Assert(_enumerator != null);
                if ((!HasLimit || taken < Limit) && _enumerator.MoveNext())
                {
                    if (HasLimit)
                    {
                        // If we are taking an unknown number of elements, it's important not to increment _state.
                        // _state - 3 may eventually end up overflowing & we'll hit the Dispose branch even though
                        // we haven't finished enumerating.
                        _state++;
                    }

                    _current = _enumerator.Current;
                    return true;
                }

                break;
        }

        Dispose();
        return false;
    }

    public override IEnumerable<TResult> Select<TResult>(Func<TSource, TResult> selector) =>
        new SelectIPartitionIterator<TSource, TResult>(this, selector);

    public IPartition<TSource> Skip(int count)
    {
        int minIndex = unchecked(_minIndexInclusive + count);
        if (!HasLimit)
        {
            if (minIndex < 0)
            {
                // If we don't know our max count and minIndex can no longer fit in a positive int,
                // then we will need to wrap ourselves in another iterator.
                // This can happen, for example, during e.Skip(int.MaxValue).Skip(int.MaxValue).
                return new EnumerablePartition<TSource>(this, count, -1);
            }
        }
        else if ((uint)minIndex > (uint)_maxIndexInclusive)
        {
            // If minIndex overflows and we have an upper bound, we will go down this branch.
            // We know our upper bound must be smaller than minIndex, since our upper bound fits in an int.
            // This branch should not be taken if we don't have a bound.
            return EmptyPartition<TSource>.Instance;
        }

        Debug.Assert(minIndex >= 0, $"We should have taken care of all cases when {nameof(minIndex)} overflows.");
        return new EnumerablePartition<TSource>(_source, minIndex, _maxIndexInclusive);
    }

    public IPartition<TSource> Take(int count)
    {
        int maxIndex = unchecked(_minIndexInclusive + count - 1);
        if (!HasLimit)
        {
            if (maxIndex < 0)
            {
                // If we don't know our max count and maxIndex can no longer fit in a positive int,
                // then we will need to wrap ourselves in another iterator.
                // Note that although maxIndex may be too large, the difference between it and
                // _minIndexInclusive (which is count - 1) must fit in an int.
                // Example: e.Skip(50).Take(int.MaxValue).
                return new EnumerablePartition<TSource>(this, 0, count - 1);
            }
        }
        else if (unchecked((uint)maxIndex >= (uint)_maxIndexInclusive))
        {
            // If we don't know our max count, we can't go down this branch.
            // It's always possible for us to contain more than count items, as the rest
            // of the enumerable past _minIndexInclusive can be arbitrarily long.
            return this;
        }

        Debug.Assert(maxIndex >= 0, $"We should have taken care of all cases when {nameof(maxIndex)} overflows.");
        return new EnumerablePartition<TSource>(_source, _minIndexInclusive, maxIndex);
    }

    public TSource? TryGetElementAt(int index, out bool found)
    {
        // If the index is negative or >= our max count, return early.
        if (index >= 0 && (!HasLimit || index < Limit))
        {
            using (IEnumerator<TSource> en = _source.GetEnumerator())
            {
                Debug.Assert(_minIndexInclusive + index >= 0, $"Adding {nameof(index)} caused {nameof(_minIndexInclusive)} to overflow.");

                if (SkipBefore(_minIndexInclusive + index, en) && en.MoveNext())
                {
                    found = true;
                    return en.Current;
                }
            }
        }

        found = false;
        return default;
    }

    public TSource? TryGetFirst(out bool found)
    {
        using (IEnumerator<TSource> en = _source.GetEnumerator())
        {
            if (SkipBeforeFirst(en) && en.MoveNext())
            {
                found = true;
                return en.Current;
            }
        }

        found = false;
        return default;
    }

    public TSource? TryGetLast(out bool found)
    {
        using (IEnumerator<TSource> en = _source.GetEnumerator())
        {
            if (SkipBeforeFirst(en) && en.MoveNext())
            {
                int remaining = Limit - 1; // Max number of items left, not counting the current element.
                int comparand = HasLimit ? 0 : int.MinValue; // If we don't have an upper bound, have the comparison always return true.
                TSource result;
                do
                {
                    remaining--;
                    result = en.Current;
                }
                while (remaining >= comparand && en.MoveNext());

                found = true;
                return result;
            }
        }

        found = false;
        return default;
    }

    public List<TSource> ToList()
    {
        var list = new List<TSource>();

        using (IEnumerator<TSource> en = _source.GetEnumerator())
        {
            if (SkipBeforeFirst(en) && en.MoveNext())
            {
                int remaining = Limit - 1; // Max number of items left, not counting the current element.
                int comparand = HasLimit ? 0 : int.MinValue; // If we don't have an upper bound, have the comparison always return true.

                do
                {
                    remaining--;
                    list.Add(en.Current);
                }
                // Bound is checked before MoveNext, so no extra element is pulled from _source.
                while (remaining >= comparand && en.MoveNext());
            }
        }

        return list;
    }

    /// <summary>
    /// Materializes the partition into an array. Reuses <see cref="ToList"/> so the
    /// skip/limit logic lives in one place.
    /// </summary>
    public TSource[] ToArray() => ToList().ToArray();

    private bool SkipBeforeFirst(IEnumerator<TSource> en) => SkipBefore(_minIndexInclusive, en);

    private static bool SkipBefore(int index, IEnumerator<TSource> en) => SkipAndCount(index, en) == index;

    private static int SkipAndCount(int index, IEnumerator<TSource> en)
    {
        Debug.Assert(index >= 0);
        return (int)SkipAndCount((uint)index, en);
    }

    // Advances en up to `index` times; returns how many MoveNext calls succeeded.
    private static uint SkipAndCount(uint index, IEnumerator<TSource> en)
    {
        Debug.Assert(en != null);

        for (uint i = 0; i < index; i++)
        {
            if (!en.MoveNext())
            {
                return i;
            }
        }

        return index;
    }
}