10782 Matlab 中 lsqisotonic 函式的高效實現
上個月發現,我一直在使用的盜版 Matlab 2010a 就要過期了。於是我從學校下載了正版的 2017b。升級後,我發現統計工具包裡的 lsqisotonic 函式跟原來一樣,仍然使用了一種低效的實現,而我在 2012 年就發現了這個問題,並寫了一個更高效的版本。這篇專欄就來分享一下我的實現。
本文的第一部分介紹 lsqisotonic 這個函式的背景 —— multi-dimensional scaling,第二部分討論這個函式本身的實現。對背景不感興趣的讀者,可以直接跳到第二部分,因為這個函式本身的功能,可以歸納成一道普通的演算法題。
一、背景:Multi-dimensional Scaling
。Multi-dimensional scaling 是一種資料視覺化的方法。它不太容易翻譯成中文,主要是因為 scaling 這個詞的用法比較奇怪。維基百科給出的中文翻譯是「多維標度」,其實挺不知所云的,甚至都不能體現出這是一個動名詞。而日文翻譯叫「多次元尺度構成法」,我覺得可以把二者融合一下,譯作「多維尺度構成法」。在本文中,我把這種方法簡稱為 MDS。Matlab 中統計工具包裡的 mdscale 函式,就是用來做 MDS 的。
MDS 並不是目前最流行的資料視覺化方法,最流行的應該是機器學習大佬 Hinton 開創的 t-SNE。在這個回答中,我用 MDS 和 t-SNE 兩種方法對我 2012 年的人人好友關係進行了視覺化,並比較了它們的效果:王贇 Maigo:有沒有那種方式可以將高維資料進行視覺化?比如保持資料結構不變將高維資料對映到低維空間?下圖就是用 MDS 視覺化的結果:
MDS 的輸入,是
個物件兩兩之間的
差異度
(dissimilarities),共
個數值。如果已知的是相似度(similarities),則可以透過一個單調遞減的函式轉換成差異度。記第
個物件之間的差異度為
。MDS 要做的事情,是在一個給定維數(通常為二維或三維)的空間中找一組點
來代表這些物件,使得第
兩個點之間的距離
儘可能接近給定的差異度
。具體來說,是要最小化如下的目標函式,這個函式稱為
壓力
(stress):
其中
是點對
的權重。一般來說,所有權重都取為 1;如果輸入資料不全,某一組差異度
沒有測量到,那麼可以透過設定
來把這個點對排除掉。當然,如果認為某些點對的差異度比另一些點對更重要,也可以給每一個點對賦予不同的權重。
求使得壓力最小化的點集
的方法有很多。比如梯度下降法就可以使用;Matlab 中 mdscale 函式使用的是一種共軛梯度法。除此之外,Modern Multidimensional Scaling 一書的第 8 章還介紹了一種稱為 SMACOF 的迭代演算法,它與機器學習中常見的 EM 演算法有相似之處,二者都是 MM 演算法的特例。不過求點集
的演算法不是本文討論的重點,所以我就不繼續展開了。
在實際問題中,差異度不一定是定比資料,而有可能只是定序的(參見:王贇 Maigo:華裔數學家陶哲軒IQ230,是智商100聰明程度的幾倍?),即差異度的數值並無意義,而只有它們之間的大小關係有意義。這種情形的 MDS,稱為 non-metric MDS。上文中的壓力函式依賴於差異度的具體數值,在 non-metric MDS 中再使用這種壓力函式,就顯得不合理了。於是就有了下面這種新的壓力函式:
其中
是差異度經過變換
的結果。變換
僅需要滿足單調性,稱為
單調回歸變換
(isotonic regression);它的作用就是說明差異度的數值並不重要,重要的只有大小關係。要最小化這個壓力函式,一方面需要求出一組點的座標
,另一方面還要求出一個變換
,形成了一個「雞生蛋,蛋生雞」的問題。這種問題一般也是透過迭代演算法來解決的,即不停重複下面的步驟:
固定
,求使得壓力最小化的
。事實上這一步並不需要使得壓力「最小化」,只要能讓它減小就行了。這一步可以使用梯度下降法、共軛梯度法、SMACOF 等任意一種方法,且只需迭代一次。
固定
,求使得壓力最小化的單調回歸變換
。這一步同樣只要讓壓力減小就行了,不過讓壓力最小化也不困難。這一步,就是由本文的主角 —— lsqisotonic 函式來實現的。
lsqisotonic 這個函式的名字中,lsq 是 least squares(最小二乘)的意思,指的是壓力函式的形式;isotonic 則說明函式用來求解最優的單調回歸變換。為了下文討論方便,我把函式的功能再提煉一下。我們把所有的差異度
從小到大排序,得到
,其中
是點對的數目。與
對應的那個點對在空間中的距離
記作
,其權重記作
。lsqisotonic 函式要求的是變換後的差異度
,把它們記作
。它們需要滿足單調性:
,並且最小化目標函式
。注意,提煉後的函式的輸入其實只有
和
;
的具體數值是沒有用的,它們唯一的作用就是指定了
和
的順序。
二、lsqisotonic 函式的實現
如果你跳過了第一部分,那麼從此往下看,也不會有任何問題 ^_^
lsqisotonic 函式的作用是求解最優的「單調回歸變換」。它的輸入是一個序列
和一組權重
。序列
不一定是單調遞增的,現在我們要求一個單調遞增的序列
,讓它跟輸入序列
儘可能接近。「接近」的具體標準是最小化壓力函式
。為討論簡潔,認為所有權重均為正。
先來看一個具體的例子。設輸入序列為
,所有元素權重均為 1。輸入序列不單調遞增,我們需要用最小的力把它「掰」成單調遞增的,即把所有的下降段「掰平」。
首先看「4, 3」這個下降段。現在要把它掰平,那麼把兩個數都掰成多少能使得壓力最小呢?不難發現,答案應該是 4 和 3 的平均數,即 3。5。如果這兩個數有不同的權重,那麼使得壓力最小的,就應該是它們的加權平均數(證明留給讀者)。
按照這種思路,可以把輸入序列中三個下降段「4, 3」「5, 3, 1」「7, 5」分別掰成 3。5、3、6,得到
。這就結束了嗎?並沒有 —— 因為 3。5 和 3 這兩個段落又違反單調性了。此時,就要把 3。5 和 3 這兩個段落整體掰平。3。5 段落的總權重為 2,3 段落的總權重為 3,所以掰平的結果應該是加權平均數 3。2。此時得到的序列為
,滿足單調性,所以這就是要求的
。
上面逐漸把下降段「掰平」的過程用圖象表示如下,藍點為輸入序列,紅點及紅線為所求的單調序列。
Matlab 自帶的 lsqisotonic 函式,就是這樣求解最優單調回歸變換的。它不斷地在序列中尋找下降段,並把下降段掰成整體的加權平均數,直到序列單調遞增為止。其程式碼的核心部分如下:
yhat
=
y
;
% 用輸入序列初始化輸出序列
block
=
1
:
length
(
y
);
% block(i) 表示第 i 個元素屬於第幾個段落
% 初始時每個元素獨立成段
while
true
diffs
=
diff
(
yhat
);
% 求所有相鄰元素之差
if
all
(
diffs
>
=
0
),
break
;
end
% 若已滿足單調性,退出
idx
=
cumsum
([
1
;
(
diffs
>
0
)]);
% 找出序列中所有的下降段,並依次編號
% 例如,若輸入為 1,4,3,5,3,1,7,5
% 則編號結果為 1,2,2,3,3,3,4,4
sumyhat
=
accumarray
(
idx
,
w
。*
yhat
);
% 計算每段元素的加權和
w
=
accumarray
(
idx
,
w
);
% 計算每段元素的總權重
yhat
=
sumyhat
。/
w
;
% 求出每段元素的加權平均數
block
=
idx
(
block
);
% 更新每個元素所屬的段落編號
end
yhat
=
yhat
(
block
);
% 構建輸出序列
這段程式碼使用了一些 Matlab 特有的操作(比如 cumsum、accumarray),可能比較難理解。理解的關鍵在於,在迭代過程中,yhat 並不是記錄了完整的序列,而是對序列中每一個水平段落,只記錄一個值。上文所舉例子的執行過程如下表所示,它會幫助你理解。
上面的實現方式有什麼問題呢?當然是複雜度啦!不難看出,每次迭代的時間複雜度為
,而迭代次數的上限也是
,所以總複雜度為
。下面這個例子可以達到複雜度的上限:輸入序列
,權重
。這個例子的精髓在於,序列有且僅有前兩個元素組成下降段,並且因為第一個元素 10000 的權重很大,把前兩個元素取加權平均合併後,序列第一段的值依然會很大。這個巨大的值會在每次迭代中吃且僅吃掉後面的一個元素,導致迭代次數達到
。
下面這段話是寫給看過第一部分的讀者的:
注意在 MDS 的背景下,
是點對的個數,它與資料點數目
的關係是
。也就是說,Matlab 自帶的 lsqisotonic 函式的時間複雜度,是嚇人的
!我用來視覺化的人人好友有 1000 多名,難怪 lsqisotonic 會卡死了。
事實上,「反覆合併下降段」這個過程,完全可以用
的時間複雜度來實現。我們從左到右掃描序列的每一個元素,並用一個棧來維護已經掃描的部分「掰平」後的各個水平段落。當掃描到一個新的元素的時候,先把它作為一個單獨的段落壓入棧頂,然後反覆檢視棧頂的兩個段落,如果它們違反了單調性,就把它們合併。這種實現的程式碼如下:
yhat
=
y
;
N
=
length
(
y
);
% 用輸入序列初始化輸出序列
bstart
=
zeros
(
1
,
N
);
bend
=
zeros
(
1
,
N
);
% 棧:bstart(i), bend(i) 記錄第 i 段的起止位置
% 此外 yhat 和 w 也兼用作棧,
% yhat(i) 與 w(i) 表示第 i 段的值和總權重
b
=
0
;
% 棧頂指標
for
i
=
1
:
N
% 依次掃描每個元素
b
=
b
+
1
;
% 由此往下三行:新元素作為單獨的段落入棧
yhat
(
b
)
=
yhat
(
i
);
w
(
b
)
=
w
(
i
);
bstart
(
b
)
=
i
;
bend
(
b
)
=
i
;
while
b
>
1
&&
yhat
(
b
)
<
yhat
(
b
-
1
)
% 棧頂兩個段落違反單調性
yhat
(
b
-
1
)
=
(
yhat
(
b
-
1
)
*
w
(
b
-
1
)
+
yhat
(
b
)
*
w
(
b
))
/
(
w
(
b
-
1
)
+
w
(
b
));
w
(
b
-
1
)
=
w
(
b
-
1
)
+
w
(
b
);
bend
(
b
-
1
)
=
bend
(
b
);
b
=
b
-
1
;
% 由此往上四行:棧頂兩個段落取加權平均合併
end
end
block
=
zeros
(
1
,
N
);
for
i
=
1
:
b
block
(
bstart
(
i
)
:
bend
(
i
))
=
i
;
% 由棧中資訊反推出輸出序列的每個元素位於第幾段
end
yhat
=
yhat
(
block
);
% 構建輸出序列
這段程式碼的主體迴圈沒有用到 Matlab 的黑科技,比較好懂,所以樣例資料的執行過程我就不寫了。
我的實現的時間複雜度為
。雖然有時一個元素入棧會引發連鎖式的段落合併,但考慮演算法的整個執行過程,一共會有
個元素入棧,最多有
次段落合併,所以複雜度為
。那麼 Matlab 自帶實現慢在哪兒了呢?仍然考慮極端輸入
,
。在迭代過程中,序列的尾部始終是單調遞增的,但 Matlab 的實現在每次迭代中都徒勞無功地在序列的尾部檢查是否有下降段。這就是它慢的原因。
三、附記
我實現的 lsqisotonic 函式,可以從 Mathworks File Exchange 上下載。這個函式位於 Matlab 安裝目錄下的 toolbox\stats\stats\private 子目錄,可以用我的版本替代原有版本。
對 MDS 感興趣的讀者,推薦閱讀 Modern Multidimensional Scaling 一書。其中第 8、9 章介紹的就是本文討論的 non-metric MDS。第 12 章介紹了 metric MDS 的另一種情形 classical MDS,它最小化的目標函式並不是 stress,而是另一種稱為 strain 的目標函式;其優點是求解過程不是迭代的,而是可以一步到位。Classical MDS 在 Matlab 中由 cmdscale 函式實現。