上個月發現,我一直在使用的盜版 Matlab 2010a 就要過期了。於是我從學校下載了正版的 2017b。升級後,我發現統計工具包裡的 lsqisotonic 函式跟原來一樣,仍然使用了一種低效的實現,而我在 2012 年就發現了這個問題,並寫了一個更高效的版本。這篇專欄就來分享一下我的實現。

本文的第一部分介紹 lsqisotonic 這個函式的背景 —— multi-dimensional scaling,第二部分討論這個函式本身的實現。對背景不感興趣的讀者,可以直接跳到第二部分,因為這個函式本身的功能,可以歸納成一道普通的演算法題。

一、背景:Multi-dimensional Scaling

。Multi-dimensional scaling 是一種資料視覺化的方法。它不太容易翻譯成中文,主要是因為 scaling 這個詞的用法比較奇怪。維基百科給出的中文翻譯是「多維標度」,其實挺不知所云的,甚至都不能體現出這是一個動名詞。而日文翻譯叫「多次元尺度構成法」,我覺得可以把二者融合一下,譯作「多維尺度構成法」。在本文中,我把這種方法簡稱為 MDS。Matlab 中統計工具包裡的 mdscale 函式,就是用來做 MDS 的。

MDS 並不是目前最流行的資料視覺化方法,最流行的應該是機器學習大佬 Hinton 開創的 t-SNE。在這個回答中,我用 MDS 和 t-SNE 兩種方法對我 2012 年的人人好友關係進行了視覺化,並比較了它們的效果:王贇 Maigo:有沒有那種方式可以將高維資料進行視覺化?比如保持資料結構不變將高維資料對映到低維空間?下圖就是用 MDS 視覺化的結果:

10782 Matlab 中 lsqisotonic 函式的高效實現

10782 Matlab 中 lsqisotonic 函式的高效實現

MDS 的輸入,是

n

個物件兩兩之間的

差異度

(dissimilarities),共

n(n-1)/2

個數值。如果已知的是相似度(similarities),則可以透過一個單調遞減的函式轉換成差異度。記第

i,j

個物件之間的差異度為

\delta_{ij}

。MDS 要做的事情,是在一個給定維數(通常為二維或三維)的空間中找一組點

X = \{X_1, \ldots, X_n\}

來代表這些物件,使得第

i,j

兩個點之間的距離

d_{ij}(X) = ||X_i - X_j||

儘可能接近給定的差異度

\delta_{ij}

。具體來說,是要最小化如下的目標函式,這個函式稱為

壓力

(stress):

\sigma(X) = \sum_{i<j} w_{ij}(d_{ij}(X) - \delta_{ij})^2

其中

w_{ij}

是點對

(i,j)

的權重。一般來說,所有權重都取為 1;如果輸入資料不全,某一組差異度

\delta_{ij}

沒有測量到,那麼可以透過設定

w_{ij} = 0

來把這個點對排除掉。當然,如果認為某些點對的差異度比另一些點對更重要,也可以給每一個點對賦予不同的權重。

求使得壓力最小化的點集

 X

的方法有很多。比如梯度下降法就可以使用;Matlab 中 mdscale 函式使用的是一種共軛梯度法。除此之外,Modern Multidimensional Scaling 一書的第 8 章還介紹了一種稱為 SMACOF 的迭代演算法,它與機器學習中常見的 EM 演算法有相似之處,二者都是 MM 演算法的特例。不過求點集

 X

的演算法不是本文討論的重點,所以我就不繼續展開了。

在實際問題中,差異度不一定是定比資料,而有可能只是定序的(參見:王贇 Maigo:華裔數學家陶哲軒IQ230,是智商100聰明程度的幾倍?),即差異度的數值並無意義,而只有它們之間的大小關係有意義。這種情形的 MDS,稱為 non-metric MDS。上文中的壓力函式依賴於差異度的具體數值,在 non-metric MDS 中再使用這種壓力函式,就顯得不合理了。於是就有了下面這種新的壓力函式:

\sigma(X, \hat{d}) = \sum_{i<j} w_{ij}(d_{ij}(X) - \hat{d}_{ij})^2

其中

\hat{d}_{ij}

是差異度經過變換

\hat{d}_{ij} = f(\delta_{ij})

的結果。變換

f

僅需要滿足單調性,稱為

單調回歸變換

(isotonic regression);它的作用就是說明差異度的數值並不重要,重要的只有大小關係。要最小化這個壓力函式,一方面需要求出一組點的座標

X

,另一方面還要求出一個變換

f

,形成了一個「雞生蛋,蛋生雞」的問題。這種問題一般也是透過迭代演算法來解決的,即不停重複下面的步驟:

固定

\hat{d}_{ij}

,求使得壓力最小化的

X

。事實上這一步並不需要使得壓力「最小化」,只要能讓它減小就行了。這一步可以使用梯度下降法、共軛梯度法、SMACOF 等任意一種方法,且只需迭代一次。

固定

X

,求使得壓力最小化的單調回歸變換

f

。這一步同樣只要讓壓力減小就行了,不過讓壓力最小化也不困難。這一步,就是由本文的主角 —— lsqisotonic 函式來實現的。

lsqisotonic 這個函式的名字中,lsq 是 least squares(最小二乘)的意思,指的是壓力函式的形式;isotonic 則說明函式用來求解最優的單調回歸變換。為了下文討論方便,我把函式的功能再提煉一下。我們把所有的差異度

\delta_{ij}

從小到大排序,得到

x_1, \ldots, x_N

,其中

N = n(n-1)/2

是點對的數目。與

x_k = \delta_{ij}

對應的那個點對在空間中的距離

d_{ij}(X)

記作

y_k

,其權重記作

w_k

。lsqisotonic 函式要求的是變換後的差異度

f(x_1), \ldots, f(x_N)

,把它們記作

\hat{y}_1, \ldots, \hat{y}_N

。它們需要滿足單調性:

\hat{y}_1 \le \hat{y}_2 \le \ldots \le \hat{y}_N

,並且最小化目標函式

\sigma = \sum_{k=1}^N w_k(y_k - \hat{y}_k)^2

。注意,提煉後的函式的輸入其實只有

y_1, \ldots, y_N

w_1, \ldots, w_N

x_1, \ldots, x_N

的具體數值是沒有用的,它們唯一的作用就是指定了

y_1, \ldots, y_N

w_1, \ldots, w_N

的順序。

二、lsqisotonic 函式的實現

如果你跳過了第一部分,那麼從此往下看,也不會有任何問題 ^_^

lsqisotonic 函式的作用是求解最優的「單調回歸變換」。它的輸入是一個序列

y = \{y_1, \ldots, y_N\}

和一組權重

w = \{w_1, \ldots, w_N\}

。序列

y

不一定是單調遞增的,現在我們要求一個單調遞增的序列

\hat{y}_1 \le \hat{y}_2 \le \ldots \le \hat{y}_N

,讓它跟輸入序列

y

儘可能接近。「接近」的具體標準是最小化壓力函式

\sigma = \sum_{k=1}^N w_k(y_k - \hat{y}_k)^2

。為討論簡潔,認為所有權重均為正。

先來看一個具體的例子。設輸入序列為

y = \{1, 4, 3, 5, 3, 1, 7, 5\}

,所有元素權重均為 1。輸入序列不單調遞增,我們需要用最小的力把它「掰」成單調遞增的,即把所有的下降段「掰平」。

首先看「4, 3」這個下降段。現在要把它掰平,那麼把兩個數都掰成多少能使得壓力最小呢?不難發現,答案應該是 4 和 3 的平均數,即 3。5。如果這兩個數有不同的權重,那麼使得壓力最小的,就應該是它們的加權平均數(證明留給讀者)。

按照這種思路,可以把輸入序列中三個下降段「4, 3」「5, 3, 1」「7, 5」分別掰成 3。5、3、6,得到

1, 3.5, 3.5, 3, 3, 3, 6, 6

。這就結束了嗎?並沒有 —— 因為 3。5 和 3 這兩個段落又違反單調性了。此時,就要把 3。5 和 3 這兩個段落整體掰平。3。5 段落的總權重為 2,3 段落的總權重為 3,所以掰平的結果應該是加權平均數 3。2。此時得到的序列為

1, 3.2, 3.2, 3.2, 3.2, 3.2, 6, 6

,滿足單調性,所以這就是要求的

\hat{y}_1, \ldots, \hat{y}_8

上面逐漸把下降段「掰平」的過程用圖象表示如下,藍點為輸入序列,紅點及紅線為所求的單調序列。

10782 Matlab 中 lsqisotonic 函式的高效實現

10782 Matlab 中 lsqisotonic 函式的高效實現

Matlab 自帶的 lsqisotonic 函式,就是這樣求解最優單調回歸變換的。它不斷地在序列中尋找下降段,並把下降段掰成整體的加權平均數,直到序列單調遞增為止。其程式碼的核心部分如下:

yhat

=

y

% 用輸入序列初始化輸出序列

block

=

1

length

y

);

% block(i) 表示第 i 個元素屬於第幾個段落

% 初始時每個元素獨立成段

while

true

diffs

=

diff

yhat

);

% 求所有相鄰元素之差

if

all

diffs

>

=

0

),

break

end

% 若已滿足單調性,退出

idx

=

cumsum

([

1

diffs

>

0

)]);

% 找出序列中所有的下降段,並依次編號

% 例如,若輸入為 1,4,3,5,3,1,7,5

% 則編號結果為 1,2,2,3,3,3,4,4

sumyhat

=

accumarray

idx

w

。*

yhat

);

% 計算每段元素的加權和

w

=

accumarray

idx

w

);

% 計算每段元素的總權重

yhat

=

sumyhat

。/

w

% 求出每段元素的加權平均數

block

=

idx

block

);

% 更新每個元素所屬的段落編號

end

yhat

=

yhat

block

);

% 構建輸出序列

這段程式碼使用了一些 Matlab 特有的操作(比如 cumsum、accumarray),可能比較難理解。理解的關鍵在於,在迭代過程中,yhat 並不是記錄了完整的序列,而是對序列中每一個水平段落,只記錄一個值。上文所舉例子的執行過程如下表所示,它會幫助你理解。

10782 Matlab 中 lsqisotonic 函式的高效實現

10782 Matlab 中 lsqisotonic 函式的高效實現

上面的實現方式有什麼問題呢?當然是複雜度啦!不難看出,每次迭代的時間複雜度為

O(N)

,而迭代次數的上限也是

O(N)

,所以總複雜度為

O(N^2)

。下面這個例子可以達到複雜度的上限:輸入序列

y = \{10000, 1, 2, 3, 4, 5\}

,權重

w = \{10000, 1, 1, 1, 1, 1\}

。這個例子的精髓在於,序列有且僅有前兩個元素組成下降段,並且因為第一個元素 10000 的權重很大,把前兩個元素取加權平均合併後,序列第一段的值依然會很大。這個巨大的值會在每次迭代中吃且僅吃掉後面的一個元素,導致迭代次數達到

N

下面這段話是寫給看過第一部分的讀者的:

注意在 MDS 的背景下,

N

是點對的個數,它與資料點數目

n

的關係是

N = n(n-1)/2

。也就是說,Matlab 自帶的 lsqisotonic 函式的時間複雜度,是嚇人的

O(n^4)

!我用來視覺化的人人好友有 1000 多名,難怪 lsqisotonic 會卡死了。

事實上,「反覆合併下降段」這個過程,完全可以用

O(N)

的時間複雜度來實現。我們從左到右掃描序列的每一個元素,並用一個棧來維護已經掃描的部分「掰平」後的各個水平段落。當掃描到一個新的元素的時候,先把它作為一個單獨的段落壓入棧頂,然後反覆檢視棧頂的兩個段落,如果它們違反了單調性,就把它們合併。這種實現的程式碼如下:

yhat

=

y

N

=

length

y

);

% 用輸入序列初始化輸出序列

bstart

=

zeros

1

N

);

bend

=

zeros

1

N

);

% 棧:bstart(i), bend(i) 記錄第 i 段的起止位置

% 此外 yhat 和 w 也兼用作棧,

% yhat(i) 與 w(i) 表示第 i 段的值和總權重

b

=

0

% 棧頂指標

for

i

=

1

N

% 依次掃描每個元素

b

=

b

+

1

% 由此往下三行:新元素作為單獨的段落入棧

yhat

b

=

yhat

i

);

w

b

=

w

i

);

bstart

b

=

i

bend

b

=

i

while

b

>

1

&&

yhat

b

<

yhat

b

-

1

% 棧頂兩個段落違反單調性

yhat

b

-

1

=

yhat

b

-

1

*

w

b

-

1

+

yhat

b

*

w

b

))

/

w

b

-

1

+

w

b

));

w

b

-

1

=

w

b

-

1

+

w

b

);

bend

b

-

1

=

bend

b

);

b

=

b

-

1

% 由此往上四行:棧頂兩個段落取加權平均合併

end

end

block

=

zeros

1

N

);

for

i

=

1

b

block

bstart

i

bend

i

))

=

i

% 由棧中資訊反推出輸出序列的每個元素位於第幾段

end

yhat

=

yhat

block

);

% 構建輸出序列

這段程式碼的主體迴圈沒有用到 Matlab 的黑科技,比較好懂,所以樣例資料的執行過程我就不寫了。

我的實現的時間複雜度為

O(N)

。雖然有時一個元素入棧會引發連鎖式的段落合併,但考慮演算法的整個執行過程,一共會有

N

個元素入棧,最多有

N-1

次段落合併,所以複雜度為

O(N)

。那麼 Matlab 自帶實現慢在哪兒了呢?仍然考慮極端輸入

y = \{10000, 1, 2, 3, 4, 5\}

w = \{10000, 1, 1, 1, 1, 1\}

。在迭代過程中,序列的尾部始終是單調遞增的,但 Matlab 的實現在每次迭代中都徒勞無功地在序列的尾部檢查是否有下降段。這就是它慢的原因。

三、附記

我實現的 lsqisotonic 函式,可以從 Mathworks File Exchange 上下載。這個函式位於 Matlab 安裝目錄下的 toolbox\stats\stats\private 子目錄,可以用我的版本替代原有版本。

對 MDS 感興趣的讀者,推薦閱讀 Modern Multidimensional Scaling 一書。其中第 8、9 章介紹的就是本文討論的 non-metric MDS。第 12 章介紹了 metric MDS 的另一種情形 classical MDS,它最小化的目標函式並不是 stress,而是另一種稱為 strain 的目標函式;其優點是求解過程不是迭代的,而是可以一步到位。Classical MDS 在 Matlab 中由 cmdscale 函式實現。