Collections.sort原始碼分析
序
在平時的開發中總是會用到排序,每次直接寫一個Comparator,然後呼叫Collections。sort(),但是忽略了具體實現的細節。
今天在使用的時候,報了下面的一個錯誤,原因是比較兩個元素時沒有明確返回1,-1,0,導致違背了比較的傳遞性。
開始
Collections中sort呼叫List中的sort方法:
public
static
<
T
>
void
sort
(
List
<
T
>
list
,
Comparator
<?
super
T
>
c
)
{
list
。
sort
(
c
);
}
// ——> List
default
void
sort
(
Comparator
<?
super
E
>
c
)
{
Object
[]
a
=
this
。
toArray
();
Arrays
。
sort
(
a
,
(
Comparator
)
c
);
ListIterator
<
E
>
i
=
this
。
listIterator
();
for
(
Object
e
:
a
)
{
i
。
next
();
i
。
set
((
E
)
e
);
}
}
Arrays中具體處理,傳統的歸併,或者timsort:
public static
if (c == null) {
sort(a);
} else { // 因為TimSort是jdk1。7引入的,使用了效能更好的Timsort,但是也是可以使用遺留的merge sort
if (LegacyMergeSort。userRequested)
legacyMergeSort(a, c);
else
TimSort。sort(a, 0, a。length, c, null, 0, 0);
}
}
TimSort實現
排序這個陣列的給定區間,後面三個引數先忽略。待排序的個數如果小於32(MIN_MERGE),比較簡單。
static
<
T
>
void
sort
(
T
[]
a
,
int
lo
,
int
hi
,
Comparator
<?
super
T
>
c
,
T
[]
work
,
int
workBase
,
int
workLen
)
{
assert
c
!=
null
&&
a
!=
null
&&
lo
>=
0
&&
lo
<=
hi
&&
hi
<=
a
。
length
;
int
nRemaining
=
hi
-
lo
;
if
(
nRemaining
<
2
)
return
;
// Arrays of size 0 and 1 are always sorted
// If array is small, do a “mini-TimSort” with no merges 情況(1)
if
(
nRemaining
<
MIN_MERGE
)
{
// 32
int
initRunLen
=
countRunAndMakeAscending
(
a
,
lo
,
hi
,
c
);
binarySort
(
a
,
lo
,
hi
,
lo
+
initRunLen
,
c
);
return
;
}
/**
* March over the array once, left to right, finding natural runs,
* extending short natural runs to minRun elements, and merging runs
* to maintain stack invariant。 情況(2)
*/
TimSort
<
T
>
ts
=
new
TimSort
<>(
a
,
c
,
work
,
workBase
,
workLen
);
int
minRun
=
minRunLength
(
nRemaining
);
do
{
// Identify next run
int
runLen
=
countRunAndMakeAscending
(
a
,
lo
,
hi
,
c
);
// If run is short, extend to min(minRun, nRemaining)
if
(
runLen
<
minRun
)
{
// 儘可能的做一次
int
force
=
nRemaining
<=
minRun
?
nRemaining
:
minRun
;
binarySort
(
a
,
lo
,
lo
+
force
,
lo
+
runLen
,
c
);
// 對[lo,lo+force]拍好序了,當然下次的 run length 長度是force
runLen
=
force
;
}
// Push run onto pending-run stack, and maybe merge
// 把這次run的基點位置和長度存入棧中,必要時合併
ts
。
pushRun
(
lo
,
runLen
);
ts
。
mergeCollapse
();
// TimSort持有陣列a,根據區間來合併,從而達到排序
// Advance to find next run 準備下一輪的部分排序
lo
+=
runLen
;
nRemaining
-=
runLen
;
}
while
(
nRemaining
!=
0
);
// Merge all remaining runs to complete sort
assert
lo
==
hi
;
ts
。
mergeForceCollapse
();
assert
ts
。
stackSize
==
1
;
}
先看看run的定義,翻譯成趨向?一個run是從陣列給定位置開始的最長遞增活遞減序列的長度,為了得到穩定的歸併排序,這裡的降序中使用的“>”,不包含“=”,保證stability。程式碼中的原註釋是:
*
Returns
the
length
of
the
run
beginning
at
the
specified
position
in
*
the
specified
array
and
reverses
the
run
if
it
is
descending
(
ensuring
*
that
the
run
will
always
be
ascending
when
the
method
returns
)。
*
*
A
run
is
the
longest
ascending
sequence
with
:
*
*
a
[
lo
]
<=
a
[
lo
+
1
]
<=
a
[
lo
+
2
]
<=
。。。
*
*
or
the
longest
descending
sequence
with
:
*
*
a
[
lo
]
>
a
[
lo
+
1
]
>
a
[
lo
+
2
]
>
。。。
*
*
For
its
intended
use
in
a
stable
mergesort
,
the
strictness
of
the
*
definition
of
“descending”
is
needed
so
that
the
call
can
safely
*
reverse
a
descending
sequence
without
violating
stability
。
具體計算最長run長度:
private
static
<
T
>
int
countRunAndMakeAscending
(
T
[]
a
,
int
lo
,
int
hi
,
Comparator
<?
super
T
>
c
)
{
assert
lo
<
hi
;
int
runHi
=
lo
+
1
;
if
(
runHi
==
hi
)
return
1
;
// Find end of run, and reverse range if descending
if
(
c
。
compare
(
a
[
runHi
++],
a
[
lo
])
<
0
)
{
// Descending
while
(
runHi
<
hi
&&
c
。
compare
(
a
[
runHi
],
a
[
runHi
-
1
])
<
0
)
runHi
++;
// 如果是遞減序列,那麼就得到最長的,然後逆序
reverseRange
(
a
,
lo
,
runHi
);
}
else
{
// Ascending
while
(
runHi
<
hi
&&
c
。
compare
(
a
[
runHi
],
a
[
runHi
-
1
])
>=
0
)
runHi
++;
}
return
runHi
-
lo
;
// 這個run的最大長度
}
舉個例子吧,如下圖:
排序小陣列
獲得初始的run長度後,呼叫 binarySort(a, lo, hi, lo + initRunLen, c),binarySort 當然不會浪費時間再去排序在求run長度時已經排好序的頭部(lo->start),然後進行二分插入排序。
binarySort要做的就是把後續的元素依次插入到屬於他們的位置,基點就是已經排好序的子陣列(如果沒有的子陣列就是首元素),把當前操作的元素稱為pivot,透過二分查詢,找到自己應該插入的位置(達到的狀態是left==right),找到位置後,就要為pivot的插入騰出空間,所以需要元素的移動,程式碼中如果移動少於兩個元素就直接操作,否則呼叫System。arraycopy(),最後插入我們的pivot到正確的位置。
這樣我想到了之前在學習排序的時候的幾個演算法,其中有個說法是,對於小陣列的排序使用插入排序,大陣列的時候使用快排,歸併排序之類的。
/**
* Sorts the specified portion of the specified array using a binary
* insertion sort。 This is the best method for sorting small numbers
* of elements。 It requires O(n log n) compares, but O(n^2) data
* movement (worst case)。
*
*/
private
static
<
T
>
void
binarySort
(
T
[]
a
,
int
lo
,
int
hi
,
int
start
,
Comparator
<?
super
T
>
c
)
{
assert
lo
<=
start
&&
start
<=
hi
;
if
(
start
==
lo
)
start
++;
for
(
;
start
<
hi
;
start
++)
{
T
pivot
=
a
[
start
];
// Set left (and right) to the index where a[start] (pivot) belongs
int
left
=
lo
;
int
right
=
start
;
assert
left
<=
right
;
/*
* Invariants: 排序過程的不變數
* pivot >= all in [lo, left)。
* pivot < all in [right, start)。
*/
while
(
left
<
right
)
{
int
mid
=
(
left
+
right
)
>>>
1
;
// 二分查詢找到屬於pivot的位置
if
(
c
。
compare
(
pivot
,
a
[
mid
])
<
0
)
right
=
mid
;
else
left
=
mid
+
1
;
}
assert
left
==
right
;
/*
* The invariants still hold: pivot >= all in [lo, left) and
* pivot < all in [left, start), so pivot belongs at left。 Note
* that if there are elements equal to pivot, left points to the
* first slot after them —— that‘s why this sort is stable。
* Slide elements over to make room for pivot。
*/
int
n
=
start
-
left
;
// The number of elements to move
// Switch is just an optimization for arraycopy in default case
switch
(
n
)
{
// 移動元素
case
2
:
a
[
left
+
2
]
=
a
[
left
+
1
];
case
1
:
a
[
left
+
1
]
=
a
[
left
];
break
;
default
:
System
。
arraycopy
(
a
,
left
,
a
,
left
+
1
,
n
);
}
// 屬於自己的位置
a
[
left
]
=
pivot
;
}
}
排序大陣列
接下來看如果待排序的個數>=32時的過程,首先弄明白minRunLength得到的是什麼。註釋很清楚,雖然理論基礎不理解。
*
Roughly
speaking
,
the
computation
is
:
*
*
If
n
<
MIN_MERGE
,
return
n
(
it
’
s
too
small
to
bother
with
fancy
stuff
)。
*
Else
if
n
is
an
exact
power
of
2
,
return
MIN_MERGE
/
2
。
*
Else
return
an
int
k
,
MIN_MERGE
/
2
<=
k
<=
MIN_MERGE
,
such
that
n
/
k
*
is
close
to
,
but
strictly
less
than
,
an
exact
power
of
2
。
如果還是很抽象的話,從32到100得到的min run length如下,可以直觀的體會下:
16,17,17,18,18,19,19,20,20,21,21,22,22,23,23,24,24,25,25,26,26,27,27,28,28,29,29,30,30,31,31,32,16,17,17,17,17,18,18,18,18,19,19,19,19,20,20,20,20,21,21,21,21,22,22,22,22,23,23,23,23,24,24,24,24,25,25,25
得到 minRun 之後,取 minRun 和 nRemaining 的最小值作為這次要排序的序列,初始的有序陣列和前面情況(1)的獲取方式一樣,然後做一次二分插入排序,現在有序序列的長度是force,這一部分排好序之後,把本次run的起始位置和長度存入一個stack中(兩個陣列),後續就是根據這些區間完成排序的。每次push之後就是要進行合併檢查,也就是說相鄰的區間能合併的就合併,具體的:
/**
* Examines the stack of runs waiting to be merged and merges adjacent runs
* until the stack invariants are reestablished:
*
* 1。 runLen[i - 3] > runLen[i - 2] + runLen[i - 1]
* 2。 runLen[i - 2] > runLen[i - 1]
*
* This method is called each time a new run is pushed onto the stack,
* so the invariants are guaranteed to hold for i < stackSize upon
* entry to the method。
*/
private
void
mergeCollapse
()
{
while
(
stackSize
>
1
)
{
int
n
=
stackSize
-
2
;
if
(
n
>
0
&&
runLen
[
n
-
1
]
<=
runLen
[
n
]
+
runLen
[
n
+
1
])
{
if
(
runLen
[
n
-
1
]
<
runLen
[
n
+
1
])
n
——;
mergeAt
(
n
);
}
else
if
(
runLen
[
n
]
<=
runLen
[
n
+
1
])
{
mergeAt
(
n
);
}
else
{
break
;
// Invariant is established
}
}
}
我的理解下,雖然每次run之後都能進行合併,但是為了減少合併帶來的開銷,找到了某種規則,可以在某些條件下避免合併。接下來看看具體合併時的動作。
合併有序區間
有一種情況是:如果前一個區間的長度小於當前區間長度,就進行merge,每個區間是一個排好序的陣列,現在要合併第i和i+1個區間。
首先把 run length更新到 ruLen[i] 中,刪掉 i+1 的run資訊;接下來定位區間2的最小元素在有序區間1的插入位置,更新區間1的 run base 和 run length,稱更新後的為區間1‘; 然後查詢區間1’的最大元素在區間2的正確定位;此時此刻這個陣列已經得到了有效的劃分,如下圖,只需要合併[base1,len1]和[base2,len2]就可以了,其他段已經在正確位置。
private
void
mergeAt
(
int
i
)
{
assert
stackSize
>=
2
;
assert
i
>=
0
;
assert
i
==
stackSize
-
2
||
i
==
stackSize
-
3
;
int
base1
=
runBase
[
i
];
int
len1
=
runLen
[
i
];
int
base2
=
runBase
[
i
+
1
];
int
len2
=
runLen
[
i
+
1
];
assert
len1
>
0
&&
len2
>
0
;
assert
base1
+
len1
==
base2
;
/*
* (1) 合併了 i,i+1, 把i+2的資訊移動到之前i+1的位置,就是刪除i+1
* Record the length of the combined runs; if i is the 3rd-last
* run now, also slide over the last run (which isn‘t involved
* in this merge)。 The current run (i+1) goes away in any case。
*/
runLen
[
i
]
=
len1
+
len2
;
if
(
i
==
stackSize
-
3
)
{
runBase
[
i
+
1
]
=
runBase
[
i
+
2
];
runLen
[
i
+
1
]
=
runLen
[
i
+
2
];
}
stackSize
——;
/*
*(2)找到區間2的最小元素若插入到區間的話,正確索引位置
* Find where the first element of run2 goes in run1。 Prior elements
* in run1 can be ignored (because they’re already in place)。
*/
int
k
=
gallopRight
(
a
[
base2
],
a
,
base1
,
len1
,
0
,
c
);
assert
k
>=
0
;
base1
+=
k
;
len1
-=
k
;
// 說明區間2的最小元素在區間1的末尾,所以完成兩個區間的合併排序
if
(
len1
==
0
)
return
;
/*
* (3)查詢區間1‘的最大元素在區間2的正確定位
* Find where the last element of run1 goes in run2。 Subsequent elements
* in run2 can be ignored (because they’re already in place)。
*/
len2
=
gallopLeft
(
a
[
base1
+
len1
-
1
],
a
,
base2
,
len2
,
len2
-
1
,
c
);
assert
len2
>=
0
;
// 說明區間1‘的最大元素小於區間2的最小元素,所以完成排序
if
(
len2
==
0
)
return
;
// Merge remaining runs, using tmp array with min(len1, len2) elements
if
(
len1
<=
len2
)
mergeLo
(
base1
,
len1
,
base2
,
len2
);
else
mergeHi
(
base1
,
len1
,
base2
,
len2
);
}
為了效能,在len1<=len2的時候使用mergeLo,len1>=len2的時候使用mergeHi,透過前面的定位,到這裡的時候,有a[base1]>a[base2],a[base1+len1] 參考 http:// bugs。java。com/bugdataba se/view_bug。do?bug_id=6804124