
So I am currently trying to implement an LSTM in PyTorch, but for some reason the loss is not decreasing. Here is my network:

```
class MyNN(nn.Module):
    def __init__(self, input_size=3, seq_len=107, pred_len=68, hidden_size=50, num_layers=1, dropout=0.2):
        super().__init__()
        self.pred_len = pred_len
        self.rnn = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=True,
            batch_first=True
        )
        self.linear = nn.Linear(hidden_size * 2, 5)

    def forward(self, X):
        lstm_output, (hidden_state, cell_state) = self.rnn(X)
        labels = self.linear(lstm_output[:, :self.pred_len, :])
        return lstm_output, labels
```
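For what it's worth, a quick shape sanity check (a self-contained sketch that just re-pastes the class above and pushes a dummy batch through it) suggests the wiring itself is fine: the bidirectional LSTM output is `(batch, seq_len, hidden_size*2)` and the labels slice is `(batch, pred_len, 5)`:

```python
import torch
import torch.nn as nn

class MyNN(nn.Module):
    def __init__(self, input_size=3, seq_len=107, pred_len=68, hidden_size=50, num_layers=1, dropout=0.2):
        super().__init__()
        self.pred_len = pred_len
        self.rnn = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=True,
            batch_first=True
        )
        self.linear = nn.Linear(hidden_size * 2, 5)

    def forward(self, X):
        lstm_output, (hidden_state, cell_state) = self.rnn(X)
        labels = self.linear(lstm_output[:, :self.pred_len, :])
        return lstm_output, labels

# Dummy batch: 4 sequences of length 107 with 3 features each.
net = MyNN(num_layers=1, dropout=0)
X = torch.randn(4, 107, 3)
lstm_output, labels = net(X)
print(lstm_output.shape)  # torch.Size([4, 107, 100]) -- 2 directions * hidden_size=50
print(labels.shape)       # torch.Size([4, 68, 5])    -- first pred_len steps, 5 outputs
```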

And my training loop:

```
LEARNING_RATE = 1e-2
net = MyNN(num_layers=1, dropout=0)
compute_loss = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)

all_loss = []
for data in tqdm(list(train_loader)):
    X, y = data
    optimizer.zero_grad()
    lstm_output, output = net(X.float())
    # Computing the loss
    loss = compute_loss(y, output)
    all_loss.append(loss)
    loss.backward()
    optimizer.step()

# Plot
plt.plot(all_loss, marker=".")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()
```
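One incidental thing: `all_loss.append(loss)` stores the whole tensor (autograd graph included) rather than a plain number, which wastes memory and can trip up plotting. The usual pattern is to append `loss.item()`. A minimal sketch of that bookkeeping, with a toy linear model standing in for `MyNN` and random tensors standing in for `train_loader` (both hypothetical, just for illustration):

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Toy stand-in for the real model and data, just to show the loop shape.
net = nn.Linear(3, 5)
compute_loss = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=1e-2)

all_loss = []
for _ in range(10):              # dummy "batches"
    X = torch.randn(8, 3)
    y = torch.randn(8, 5)
    optimizer.zero_grad()
    output = net(X)
    loss = compute_loss(output, y)
    all_loss.append(loss.item())  # store a plain float, not the graph-carrying tensor
    loss.backward()
    optimizer.step()

print(all_loss)                   # list of 10 Python floats, ready for plt.plot
```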

I have been trying to work out what the hell I am doing wrong, but I have no idea. Also, I previously used a Keras LSTM on this dataset and it worked well.

Any help? Thanks!

Hmm, I had already tried lowering the LR. What I didn't try was clipping the gradients, because I didn't think this could be exploding gradients. I will try it and see if it solves the problem. Thanks! – David Marques – 2020-09-22T13:18:00.483

If a lower LR didn't help, gradient clipping shouldn't work either, because both do nothing more than scale/clip the gradients. Take a look at my edit, though. – YuseqYaseq – 2020-09-22T13:23:12.167

Omg, that makes total sense... I knew I was making some idiotic mistake. I spent an afternoon yesterday thinking I was calling the loss function the wrong way or something like that. I will check it later when I have the chance and then accept the answer. Thank you! – David Marques – 2020-09-22T13:28:46.970