pytorch 笔记:PAD_PACKED_SEQUENCE 和PACK_PADDED_SEQUENCE-CSDN博客
- 当使用
pack_padded_sequence
得到一个PackedSequence
对象并将其送入RNN(如LSTM或GRU)时,RNN内部会进行特定的操作来处理这种特殊的输入形式。 -
使用
PackedSequence
的主要好处是提高效率和计算速度。因为通过跳过填充部分,RNN不需要在这些部分进行无用的计算。这特别对于处理长度差异很大的批量序列时很有帮助。文章来源:https://www.toymoban.com/news/detail-739580.html
1 PackedSequence对象
-
PackedSequence
是一个命名元组,其中主要的两个属性是data
和batch_sizes
。-
data
是一个1D张量,包含所有非零长度序列的元素,按照其在批次中的顺序排列。 -
batch_sizes
是一个1D张量,表示每个时间步的批次大小
-
-
PackedSequence(data=tensor([6, 5, 1, 8, 7, 9]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None)
2 处理PackedSequence
- 当RNN遇到
PackedSequence
作为输入时,它会按照batch_sizes
中指定的方式对data
进行迭代 - 举例来说,上面例子中
batch_sizes
是[3,2,1]
,那么RNN首先处理前3个元素,然后是接下来的2个元素,最后是最后一个元素。 - 这允许RNN仅处理有效的序列部分,而跳过填充
3 输出
- 当RNN完成对
PackedSequence
的处理后,它的输出同样是一个PackedSequence
对象 - 可以使用
pad_packed_sequence
将其转换回常规的填充张量格式,以进行后续操作或损失计算 - 隐藏状态和单元状态(对于LSTM)也会被返回,这些状态与未打包的序列的处理方式相同
4 举例
- 假设我们有以下3个句子,我们想要用RNN进行处理:
I love AI
Hello
PyTorch is great
- 为了送入RNN,我们首先需要将这些句子转换为整数形式,并进行填充以保证它们在同一个批次中有相同的长度。
{
'PAD': 0,
'I': 1,
'love': 2,
'AI': 3,
'Hello': 4,
'PyTorch': 5,
'is': 6,
'great': 7
}
- 句子转换为整数后(id):
-
I love AI
->[1, 2, 3]
-
Hello
->[4]
-
PyTorch is great
->[5, 6, 7]
- 为了将它们放入同一个批次,我们进行填充:
[1, 2, 3]
[4, 0, 0]
[5, 6, 7]
- 假设每个单词的id 对应的embedding就是自己:
[[1], [2], [3]]
[[4], [0], [0]]
[[5], [6], [7]]
- 使用pack_padded_sequence进行处理
import torch
from torch.nn.utils.rnn import pack_padded_sequence
# 输入序列
input_seq = torch.tensor([[1,2,3], [4, 0, 0], [5,6,7]])
input_seq=input_seq.reshape(data.shape[0],input_seq.shape[1],1)
#每个单词id的embedding就是他自己
input_seq=input_seq.float()
#变成float是为了喂入RNN所需
# 序列的实际长度
lengths = [3, 1, 3]
# 使用pack_padded_sequence
packed = pack_padded_sequence(input_seq, lengths, batch_first=True,enforce_sorted=False)
packed
'''
PackedSequence(data=tensor([[1.],
[5.],
[4.],
[2.],
[6.],
[3.],
[7.]]), batch_sizes=tensor([3, 2, 2]), sorted_indices=tensor([0, 2, 1]), unsorted_indices=tensor([0, 2, 1]))
'''
- 现在,当我们将此
PackedSequence
送入RNN时,RNN首先处理前3个元素,因为batch_sizes
的第一个元素是3。然后,它处理接下来的2个元素,最后处理剩下的2个元素。-
具体来说,RNN会如下处理:文章来源地址https://www.toymoban.com/news/detail-739580.html
-
时间步1:根据
batch_sizes[0] = 3
,RNN同时处理三个句子的第一个元素。具体地说,它处理句子1的"I",句子2的"PyTorch",和句子3的"Hello"。 -
时间步2:根据
batch_sizes[1] = 2
,RNN处理接下来两个句子的第二个元素,即句子1的"love"和句子2的"is"。 -
时间步3:根据
batch_sizes[2] = 2
,RNN处理接下来两个句子的第三个元素,即句子1的"AI"和句子2的"great"。
-
时间步1:根据
-
- 喂入RNN
import torch.nn as nn
class SimpleRNN(nn.Module):
def __init__(self,input_size,hidden_size,num_layer=1):
super(SimpleRNN,self).__init__()
self.rnn=nn.RNN(input_size,
hidden_size,
num_layer,
batch_first=True)
def forward(self,x,hidden=None):
packed_output,h_n=self.rnn(x,hidden)
return packed_output,h_n
#单层的RNN
Srnn=SimpleRNN(1,3)
Srnn(packed_data)
'''
(PackedSequence(data=tensor([[-0.1207, -0.0247, 0.4188],
[-0.3173, -0.0499, 0.6838],
[-0.4900, -0.0751, 0.8415],
[-0.7051, -0.1611, 0.9610],
[-0.7497, -0.2117, 0.9829],
[-0.3361, -0.1660, 0.9329],
[ 0.4608, -0.0492, 0.1138]], grad_fn=<CatBackward0>), batch_sizes=tensor([3, 2, 2]), sorted_indices=None, unsorted_indices=None),
tensor([[[-0.3361, -0.1660, 0.9329],
[ 0.4608, -0.0492, 0.1138],
[-0.4900, -0.0751, 0.8415]]], grad_fn=<StackBackward0>))
'''
- 得到的RNN输出是pack的,hidden state没有变化
-
Srnn=SimpleRNN(1,3) Srnn(packed_data) ''' (PackedSequence(data=tensor([[-0.1207, -0.0247, 0.4188], [-0.3173, -0.0499, 0.6838], [-0.4900, -0.0751, 0.8415], [-0.7051, -0.1611, 0.9610], [-0.7497, -0.2117, 0.9829], [-0.3361, -0.1660, 0.9329], [ 0.4608, -0.0492, 0.1138]], grad_fn=<CatBackward0>), batch_sizes=tensor([3, 2, 2]), sorted_indices=None, unsorted_indices=None), tensor([[[-0.3361, -0.1660, 0.9329], [ 0.4608, -0.0492, 0.1138], [-0.4900, -0.0751, 0.8415]]], grad_fn=<StackBackward0>)) ''' pad_packed_sequence(Srnn(packed_data)[0],batch_first=True) ''' (tensor([[[-0.1207, -0.0247, 0.4188], [-0.7051, -0.1611, 0.9610], [-0.3361, -0.1660, 0.9329]], [[-0.3173, -0.0499, 0.6838], [-0.7497, -0.2117, 0.9829], [ 0.4608, -0.0492, 0.1138]], [[-0.4900, -0.0751, 0.8415], [ 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000]]], grad_fn=<TransposeBackward0>), tensor([3, 3, 1])) '''
-
到了这里,关于pytorch笔记:PackedSequence对象送入RNN的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!