JDK错误用法—TimSort

时间:2022-07-23
本文章向大家介绍JDK错误用法—TimSort,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

TimSort介绍

Tim Peters在2002年设计了该算法并在Python中使用(TimSort 是Python中list.sort的默认实现),后被引入java。TimSort算法是一种归并排序和插入排序的混合排序算法,设计初衷是为了在真实世界中的各种数据中可以有较好的性能。基本工作过程是:

  1. 扫描数组,确定其中的单调上升段和严格单调下降段,将严格下降段反转;
  2. 定义最小基本片段长度,短于此的单调片段通过插入排序集中为长于此的段;
  3. 反复归并一些相邻片段,过程中避免归并长度相差很大的片段,直至整个排序完成,所用分段选择策略可以保证O(nlogn)时间复杂性。

背景

年前写了一篇《JDK错误用法——throwable.getCause》,后来家里出了些事情,就一直耽搁了。那篇文章说明了在Throwable类中cause字段所代表的具体意义,而这篇文章所阐述的就是曾经掉在那个坑里、未被抛出来异常。

错误代码示例

private void compositeRank(List<ADXEntity> list) {
   if (CollectionUtils.isEmpty(list)) {
       return;
    }
    list.sort((o1, o2) -> o1.getQ() > o2.getQ() ? -1 : 1);
}

乍一看上去,貌似没有问题,逻辑上也是对的,但是却抛出了异常。比较器用错了吗?我们仔细看下:

这个方法属于函数式接口Comparator中一个纯函数。其中我省略的长篇注释大概意思是:比较器模拟了数学中的一种运算方式signum,符号表示为sng(expression):

expression = compare(x,y);

实现该比较器需要遵循4条规则:

  1. 对于所有x,y必须保证sgn(compare(x, y)) == -sgn(compare(y, x)),暗指如果compare(x, y)抛出异常,compare(y, x)也会抛出异常;
  2. 可传递性:如果((compare(x, y)>0) && (compare(y, z)>0)) 则compare(x, z)>0;
  3. 如果有compare(x, y)==0 则对于所有的z有sgn(compare(x, z))==sgn(compare(y, z));
  4. (compare(x, y)==0) == (x.equals(y)),一般来说是这样的,但并不严格要求,一般来说,任何违反这个条件的比较器都需要加以说明。

这样看,我的代码中的三目运算,确实违背了第一条。当o1.getQ()==o2.getQ()时,比较器返回1,反过来,还是返回1。三目运算忽略了0,造成了使用不规范。

但为啥会报错呢?继续看代码。
@Override
@SuppressWarnings("unchecked")
public void sort(Comparator<? super E> c) {
    final int expectedModCount = modCount;
    Arrays.sort((E[]) elementData, 0, size, c);
    if (modCount != expectedModCount) {
     throw new ConcurrentModificationException();
    }
    modCount++;
    }
}

这段代码,是list.sort()的具体实现,除去一些Fast-fail的处理外,主要是第四行,其实看到这我还有一些疑问?为什么这个方法不设置一些Fast-fail的校验呢?以下为进一步实现:

public static <T> void sort(T[] a, int fromIndex, int toIndex,Comparator<? super T> c) {
    if (c == null) {
        sort(a, fromIndex, toIndex);
    } else {
        rangeCheck(a.length, fromIndex, toIndex);
        if (LegacyMergeSort.userRequested)
            legacyMergeSort(a, fromIndex, toIndex, c);
        else
            TimSort.sort(a, fromIndex, toIndex, c, null, 0, 0);
    }
}

整体实现就是两种方法,如果设置了变量:java.util.Arrays.useLegacyMergeSort为true,则使用旧的排序方法legacyMergeSort,否则使用TimSort排序。我们主要看TimSort。

static <T> void sort(T[] a,int lo,int hi, Comparator<? super T> c,T[] work, int workBase, int workLen) {
    assert c != null && a != null && lo >= 0 && lo <= hi && hi <= a.length;
    
    int nRemaining  = hi - lo;
    if (nRemaining < 2)
        return;  // 长度小于2,必然是有序的

// 如果数组长度小于MIN_MERGE(默认为32),则使用没有merge的二分排序法。
if (nRemaining < MIN_MERGE) {
    int initRunLen = countRunAndMakeAscending(a, lo, hi, c);
    binarySort(a, lo, hi, lo + initRunLen, c);
    return;
}
TimSort<T> ts = new TimSort<>(a, c, work, workBase, workLen);
//获取最短的run
int minRun = minRunLength(nRemaining);
do {
    // 找出连续升序的最大个数
    int runLen = countRunAndMakeAscending(a, lo, hi, c);

    // 如果太短,就通过二分插入法扩展
    if (runLen < minRun) {
        int force = nRemaining <= minRun ? nRemaining : minRun;
        binarySort(a, lo, lo + force, lo + runLen, c);
        runLen = force;
    }
    ts.pushRun(lo, runLen);// 将run入栈
    ts.mergeCollapse();// 合并

    // 继续寻找下一个run
    lo += runLen;
    nRemaining -= runLen;
} while (nRemaining != 0);

// 合并剩下的run
assert lo == hi;
ts.mergeForceCollapse();
assert ts.stackSize == 1;
}

当栈内存在3个或者3个以上的run的时候,并且满足以下

  • runLen[i - 3] > runLen[i - 2] + runLen[i - 1]
  • runLen[i - 2] > runLen[i - 1]

两个条件其一的时候,合并2个run,合并的主要逻辑是:

  1. 合并必须是相邻的2个run;
  2. 合并的2个run中,第一个run的长度小于第二个的长度;
private void mergeCollapse() {
    while (stackSize > 1) {
        int n = stackSize - 2;
        if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) {
            if (runLen[n - 1] < runLen[n + 1])
                n--;
            mergeAt(n);
        } else if (runLen[n] <= runLen[n + 1]) {
            mergeAt(n);
        } else {
            break; // Invariant is established
        }
    }
}

再看下mergeAt逻辑

private void mergeAt(int i) {、
        assert stackSize >= 2;
        assert i >= 0;
        assert i == stackSize - 2 || i == stackSize - 3;

        int base1 = runBase[i];
        int len1 = runLen[i];
        int base2 = runBase[i + 1];
        int len2 = runLen[i + 1];
        assert len1 > 0 && len2 > 0;
        assert base1 + len1 == base2;

        /*
         * 记录归并这些run的长度;如果i是最后一个run的三
         * 个run中第一个,就滑过最后一个run(最后一个
         *run不会被归并),而i+1在任何情况都会消失。
         * 其实这是与mergeCollapse的逻辑相关联的,传入          *的参数i,已经明确了它必须是被归并的,而与它相关          *联的就只有i+1
         */
        runLen[i] = len1 + len2;
        if (i == stackSize - 3) {//如果传入的是3个数中的第一个
            runBase[i + 1] = runBase[i + 2];
            runLen[i + 1] = runLen[i + 2];
        }
        stackSize--;

        /*
         * 查找run2的第一个元素在run1中的位置。可以忽略
         *run1中的先前元素(因为它们已经就绪)。
         */
        int k = gallopRight(a[base2], a, base1, len1, 0, c);
        assert k >= 0;
        base1 += k;
        len1 -= k;
        if (len1 == 0)
            return;

        /*
         * 查找run1的最后一个元素在run2中的位置。
         * run2后续的元素可以被忽略了,因为他们
         * 本来就是有序的。
         */
        len2 = gallopLeft(a[base1 + len1 - 1], a, base2, len2, len2 - 1, c);
        assert len2 >= 0;
        if (len2 == 0)
            return;

        // 归并排序,使用min(len1,len2)作为临时数组
        if (len1 <= len2)
            mergeLo(base1, len1, base2, len2);
        else
            mergeHi(base1, len1, base2, len2);
    }

mergeLo和mergeHi类似,只看下mergeLo 以稳定的方式合并两个相邻的运行。第一个run的第一个元素必须大于第二个run的第一个元素(a[base1] > a[base2]),第一次运行的最后一个元素(a[base1 + len1-1])必须大于第二次运行的所有元素。这也是由于前边gallopRight和gallopLeft所致。

private void mergeLo(int base1, int len1, int base2, int len2) {
        assert len1 > 0 && len2 > 0 && base1 + len1 == base2;

        // Copy first run into temp array
        T[] a = this.a; // For performance
        T[] tmp = ensureCapacity(len1);
        int cursor1 = tmpBase; // Indexes into tmp array
        int cursor2 = base2;   // Indexes int a
        int dest = base1;      // Indexes int a
        System.arraycopy(a, base1, tmp, cursor1, len1);

       //将第二个run中的第一个num移动到,整个第一个run的第一个num的位置上,因为从前面的gallopLeft结果也可知道,a[cusor2]相对于a[dest]是递增的;
        a[dest++] = a[cursor2++];
        if (--len2 == 0) {
            System.arraycopy(tmp, cursor1, a, dest, len1);
            return;
        }
        if (len1 == 1) {
            System.arraycopy(a, cursor2, a, dest, len2);
            a[dest + len2] = tmp[cursor1]; // Last elt of run 1 to end of merge
            return;
        }

        Comparator<? super T> c = this.c;  // Use local variable for performance
        int minGallop = this.minGallop;
        
    outer:
        while (true) {
            int count1 = 0; // Number of times in a row that first run won
            int count2 = 0; // Number of times in a row that second run won

            /*
             * Do the straightforward thing until (if ever) one run starts
             * winning consistently.
             */
            do {
                assert len1 > 1 && len2 > 0;
                if (c.compare(a[cursor2], tmp[cursor1]) < 0) {
                    a[dest++] = a[cursor2++];
                    count2++;
                    count1 = 0;
                    if (--len2 == 0)
                        break outer;
                } else {
                    a[dest++] = tmp[cursor1++];
                    count1++;
                    count2 = 0;
                    if (--len1 == 1)
                        break outer;
                }
            } while ((count1 | count2) < minGallop);

            /*
             * One run is winning so consistently that galloping may be a
             * huge win. So try that, and continue galloping until (if ever)
             * neither run appears to be winning consistently anymore.
             */
            do {
                assert len1 > 1 && len2 > 0;
                count1 = gallopRight(a[cursor2], tmp, cursor1, len1, 0, c);
                if (count1 != 0) {
                    System.arraycopy(tmp, cursor1, a, dest, count1);
                    dest += count1;
                    cursor1 += count1;
                    len1 -= count1;
                    if (len1 <= 1) // len1 == 1 || len1 == 0
                        break outer;
                }
                a[dest++] = a[cursor2++];
                if (--len2 == 0)
                    break outer;

                count2 = gallopLeft(tmp[cursor1], a, cursor2, len2, 0, c);
                if (count2 != 0) {
                    System.arraycopy(a, cursor2, a, dest, count2);
                    dest += count2;
                    cursor2 += count2;
                    len2 -= count2;
                    if (len2 == 0)
                        break outer;
                }
                a[dest++] = tmp[cursor1++];
                if (--len1 == 1)
                    break outer;
                minGallop--;
            } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP);
            if (minGallop < 0)
                minGallop = 0;
            minGallop += 2;  // Penalize for leaving gallop mode
        }  // End of "outer" loop
        this.minGallop = minGallop < 1 ? 1 : minGallop;  // Write back to field

        if (len1 == 1) {
            assert len2 > 0;
            System.arraycopy(a, cursor2, a, dest, len2);
            a[dest + len2] = tmp[cursor1]; //  Last elt of run 1 to end of merge
        } else if (len1 == 0) {
            throw new IllegalArgumentException(
                "Comparison method violates its general contract!");
        } else {
            assert len2 == 0;
            assert len1 > 1;
            System.arraycopy(tmp, cursor1, a, dest, len1);
        }
    }

详细介绍下这个合并的步骤:

  1. 首先申请一个临时数组tmp,尽量小,能够放下第一个run(简称run1)就可以了;
  2. tmp数组和run2进行merge操作,且在整个merge的过程中,还放了一个变量minGallop,用于检测tmp中数相对于run2中数有序的数目,如果数目>7,则重新执行gallopRight和gallopLeft方法进行处理;
  3. 直到tmp和run2合并完毕;

问题就出在这里了

  • 首先对于使用比较器排序要转换一个思路:这里不存在大于或者小于,只存在升序或者降序,当比较器返回-1视为降序,会被做换位置处理,反之0或者1则并不会;
  • 三目运算符:o1, o2) -> o1.getQ() > o2.getQ() ? -1 : 1 将1和0放在了一起,首先说不规范,就是这样做会破坏排序的稳定性,然后就是这样使代码在第一次gallop操作的时候,未被发现,在后边超过了限制7次后,重新执行gallop,导致len1被减少到0,这个时候触发了异常问题。因为从正常思路上来讲,经过gallop操作后:

compare(run1[first],run2[first])<0

compare(run1[last],run2[last])<0;

如此的话,最先合并结束的肯定是run2,所以当len1等于0的时候,抛出了异常,我想这也许也是Fail-Fast的一种方式吧!

给出一个测试数据

1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,11,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1

在jdk 1.8的情况下,该用例会触发异常。