LSTM-Pytorch实现

绫波波 发布于 2023-07-27 2666 次阅读


概述

长短期记忆网络 LSTM(long short-term memory)是 RNN 的一种变体,其核心概念在于细胞状态以及“门”结构。细胞状态相当于信息传输的路径,让信息能在序列连中传递下去。你可以将其看作网络的“记忆”。理论上讲,细胞状态能够将序列处理过程中的相关信息一直传递下去。因此,即使是较早时间步长的信息也能携带到较后时间步长的细胞中来,这克服了短时记忆的影响。信息的添加和移除我们通过“门”结构来实现,“门”结构在训练过程中会去学习该保存或遗忘哪些信息。

LSTM门结构

nn.LSTM

__init__

  • input_size:输入特征x的维数,即每一行输入元素的个数。
  • hidden_size:隐藏层状态h的维数,即隐藏层节点的个数。隐藏层的维数值是自定义的,和输入的维度没有关系。
  • num_layers:LSTM堆叠的层数,默认值是1层,如果设置为2,将被处理为一个stacked LSTM,即第二个LSTM接收第一个LSTM的计算结果,如,第一层输入[X_0,X_1,X_2,...,X_t],计算出[h_0,h_1,h_2,...,h_t],第二层将 [h_0,h_1,h_2,...,h_t],作为[X_0,X_1,X_2,...,X_t]输入再次计算,输出最后的[h_0,h_1,h_2,...,h_t]。默认值为1。

LSTM.forward()

不同于RNN,LSTM在前向传播时返回两个中间变量h_t,c_t

out,(ht,ct)=lstm(x,[ht_1,ct_1])

#案例:建立一个LSTM,输入维数为100,隐藏层为20,LSTM层数为4
lstm=nn.LSTM(input_size=100,hidden_size=20,num_layers=4)
print(lstm)
x=torch.randn(10,3,100) # x:[seq,b,vec],可以理解为3个句子,每个句子有十个单词,每个单词encoding成为长度为100的vector
out,(h,c)=lstm(x) #[h_0,c_0]
print(out.shape,h.shape,c.shape)

#torch.Size([10,3,20]) out:[seq,b,h]
#torch.Size([4,3,20]) h/c:[num_layer,b,h]
#torch.Size([4,3,20]) h/c:[num_layer,b,h]

nn.LSTMCell

__init__

  • input_size:输入维度x
  • hidden_size:隐藏层h
  • num_layers:LSTM堆叠的层数

LSTMCell.forward()

ht,ct=lstmcell(xt,[ht_1,ct_1])

#Single layer
print('one layer lstm')
cell=nn.LSTMCell(input_size=100,hidden_size=20)
h=torch.zeros(3,20)
c=torch.zeros(3,20) #1层,3句话,hidden=20
for xt in x:
    h,c=cell(xt,[h,c]) #xt:[b,vec]
print(h.shape,c.shape)

torch.Size([3,20])
torch.Size([3,20])#ht/ct:[b,h]

#Two Layers
print('two layer lstm')
cell1=nn.LSTMCell(input_size=100,hidden_size=30)
cell2=nn.LSTMCell(input_size=30,hidden_size=20)
h1=torch.zeros(3,30)
c1=torch.zeros(3,30)
h2=torch.zeros(3,20)
c2=torch.zeros(3,20)
for xt in x:
    h1,c1=cell1(xt,[ht,c1])
    h2,c2=cell2(h1,[h2,c2])

print(h2.shape,c2.shape)
torch.Size(3,20)
torch.Size(3,20)

案例分析

背景介绍

本文使用一个交通流量预测案例来说明Pytorch中LSTM模型的建模方法。该数据来源于英国高速(国道)各个监测点的每日交通流量数据,其中数据标准格式如下:

英国高速公路交通检测流量数据基本格式

标准数据的采样频率为15分钟,字段包括监测点id、日期、时间、时间间隔、车辆长度(用于区别不同的车辆类型)、平均速度(英里/小时)和监测流量。案例中的预测模型仅使用车流量和平均速度作为模型输入的特征数据。

Talk is cheap, show me the code.
最后更新于 2023-07-28