不知道為什麼知乎的程式碼塊不能對markdown的程式碼塊渲染,和客服提了需求也不知道啥時候能加上

吐槽吐完了,進入正題。在看開原始碼的時候,經常會遇到pytorch的各種各樣的高階(奇怪)操作,每次去百度太浪費時間,還不一定能看懂,而且有的部落格給的例子不太合適。所以在理解後整理成筆記也方便以後複習使用。各位如果發現哪裡不對,也請嚴厲指出,感激不盡。

index_select

index_select(tensor, dim, index)

對於tensor物件的第dim維度進行索引,在該維度上取index

例如,張量a是shape為

4\times 3 \times 28 \times 28

的四張三通道影象

# 選擇第一張和第三張圖

print(a。index_select(0, torch。tensor([0, 2]))。shape)

# 選擇R通道和B通道

print(a。index_select(1, torch。tensor([0, 2]))。shape)

mask_select

透過mask選取,返回一個一維張量

x = torch。randn(3, 4)

x

1。2045 2。4084 0。4001 1。1372

0。5596 1。5677 0。6219 -0。7954

1。3635 -1。2313 -0。5414 -1。8478

[torch。FloatTensor of size 3x4]

mask = x。ge(0。5)

mask

1 1 0 1

1 1 1 0

1 0 0 0

[torch。ByteTensor of size 3x4]

torch。masked_select(x, mask)

1。2045

2。4084

1。1372

0。5596

1。5677

0。6219

1。3635

[torch。FloatTensor of size 7]

non_zero

索引出張量中不為零的位置

>>> torch。nonzero(torch。Tensor([[0。6, 0。0, 0。0, 0。0],

。。。 [0。0, 0。4, 0。0, 0。0],

。。。 [0。0, 0。0, 1。2, 0。0],

。。。 [0。0, 0。0, 0。0,-0。4]]))

0 0

1 1

2 2

3 3

[torch。LongTensor of size 4x2]

gather

t = torch。Tensor([[1,2,3],[4,5,6]])

index_a = torch。LongTensor([[0,0,2],[0,1,0]])

index_b = torch。LongTensor([[0,1,1],[1,0,0]])

>>t

tensor([[1。, 2。, 3。],

[4。, 5。, 6。]])

>>> index_a

tensor([[0, 0, 2],

[0, 1, 0]])

>>> index_b

tensor([[0, 1, 1],

[1, 0, 0]])

>>> torch。gather(t,dim=1,index=index_a)

tensor([[1。, 1。, 3。],

[4。, 5。, 4。]])

>>> torch。gather(t,dim=0,index=index_b))

tensor([[1。, 5。, 6。],

[4。, 2。, 3。]])

dim=1時,索引為列,等價於

concat([t[0, [0,0,2]], t[1, [0,1,0]]], axis=0)

dim=0時,索引為行,等價於

concat([t[[0,1], 0], t[[1,0], 1], t[[1,0], 2], axis=1)

scatter

>>> x = torch。rand(2, 5)

>>> x

0。4319 0。6500 0。4080 0。8760 0。2355

0。2609 0。4711 0。8486 0。8573 0。1029

[torch。FloatTensor of size 2x5]

>>> torch。zeros(3, 5)。scatter_(

0, torch。LongTensor([[0, 1, 2, 0, 0],

[2, 0, 0, 1, 2]]), x)

0。4319 0。4711 0。8486 0。8760 0。2355

0。0000 0。6500 0。0000 0。8573 0。0000

0。2609 0。0000 0。4080 0。0000 0。1029

[torch。FloatTensor of size 3x5]

dim=0, x第一行,分別插入到結果的第[0,1,2,0,0]行中,

如x中的0。4080應插入到結果的第2列第2行(從0開始)。

repeat

>> x = torch。tensor([1, 2, 3])

>> x。repeat(3, 2)

tensor([[1, 2, 3, 1, 2, 3],

[1, 2, 3, 1, 2, 3],

[1, 2, 3, 1, 2, 3]])

unbind

刪除指定維度,返回元組形式

>>> x = torch。randn(3, 3)

>>> x

tensor([[ 0。4775, 0。0161, -0。9403],

[ 1。6109, 2。1144, 1。1833],

[-0。2656, 0。7772, 0。5989]])

>>> torch。unbind(x, dim=1)

(tensor([ 0。4775, 1。6109, -0。2656]),

tensor([0。0161, 2。1144, 0。7772]),

tensor([-0。9403, 1。1833, 0。5989]))

narrow

tensor。narrow(dim, start, num)

對dim維進行索引,從第start個開始取,取num個

data=torch。tensor([[1,2],[3,4],[5,6]])

#tensor([[1, 2],

[3, 4],

[5, 6]])

In [46]: data。narrow(1,1,1)

Out[46]:

tensor([[2],

[4],

[6]])

文章持續更新