Crates.io | candle-birnn |
lib.rs | candle-birnn |
version | 0.2.3 |
source | src |
created_at | 2024-08-11 11:25:22.664468 |
updated_at | 2024-09-25 06:53:57.080859 |
description | implement Pytorch LSTM and BiDirectional LSTM with Candle |
repository | https://github.com/kigichang/candle-birnn |
id | 1333223 |
size | 61,895 |
Implements PyTorch LSTM inference with Candle, including bidirectional LSTM inference.
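lstm_test.pt: Reference results for the unidirectional LSTM, generated with the following PyTorch demo program. The file holds the LSTM weights together with the sample input and the resulting output, hn, and cn: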
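import torch
import torch.nn as nn
# Single-layer unidirectional LSTM: input_size=10, hidden_size=20, num_layers=1
rnn = nn.LSTM(10, 20, 1)
# Random input of shape (seq_len=5, batch=3, input_size=10); batch_first defaults to False
input = torch.randn(5, 3, 10)
output, (hn, cn) = rnn(input)
# Save the weights plus the input and the expected results in a single file
state_dict = rnn.state_dict()
state_dict['input'] = input
state_dict['output'] = output
state_dict['hn'] = hn
state_dict['cn'] = cn
torch.save(state_dict, "lstm_test.pt")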
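For reference, the computation that an LSTM inference implementation has to reproduce can be sketched directly from the weights in the state dict. This is a minimal sketch, not the crate's code. It assumes PyTorch's documented gate order (input, forget, cell, output) in the stacked weight matrices and zero initial states. It builds a fresh model with a fixed seed instead of loading lstm_test.pt, and the names x, manual, and matches are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed so the sketch is repeatable (not in the original demo)
rnn = nn.LSTM(10, 20, 1)
x = torch.randn(5, 3, 10)  # (seq_len, batch, input_size)
with torch.no_grad():
    output, (hn, cn) = rnn(x)

sd = rnn.state_dict()
w_ih, w_hh = sd["weight_ih_l0"], sd["weight_hh_l0"]  # shapes (80, 10) and (80, 20)
b_ih, b_hh = sd["bias_ih_l0"], sd["bias_hh_l0"]      # shape (80,) each

# Zero initial hidden and cell states, matching PyTorch's default
h = torch.zeros(3, 20)
c = torch.zeros(3, 20)
steps = []
for t in range(x.shape[0]):
    # All four gates are computed in one matmul, then split in the order i, f, g, o
    gates = x[t] @ w_ih.T + b_ih + h @ w_hh.T + b_hh
    i, f, g, o = gates.chunk(4, dim=1)
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)
    c = f * c + i * g
    h = o * torch.tanh(c)
    steps.append(h)

manual = torch.stack(steps)  # (seq_len, batch, hidden_size)
matches = torch.allclose(manual, output, atol=1e-5)
```

An implementation in another framework can be validated the same way, by comparing its result against the stored output, hn, and cn tensors.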
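bi_lstm_test.pt: Reference results for the bidirectional LSTM, generated with the following PyTorch demo program. The file holds the weights for both directions together with the sample input and the resulting output, hn, and cn: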
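import torch
import torch.nn as nn
# Single-layer bidirectional LSTM: input_size=10, hidden_size=20, num_layers=1
rnn = nn.LSTM(10, 20, 1, bidirectional=True)
# Random input of shape (seq_len=5, batch=3, input_size=10); batch_first defaults to False
input = torch.randn(5, 3, 10)
output, (hn, cn) = rnn(input)
# Save the weights plus the input and the expected results in a single file
state_dict = rnn.state_dict()
state_dict['input'] = input
state_dict['output'] = output
state_dict['hn'] = hn
state_dict['cn'] = cn
torch.save(state_dict, "bi_lstm_test.pt")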
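The layout of the bidirectional results matters when comparing them with another implementation. The sketch below illustrates it under the same settings as the demo program. It is not part of the crate: it uses a fixed seed rather than loading bi_lstm_test.pt, and the names x, forward_out, and backward_out are illustrative. PyTorch concatenates the two directions along the last dimension of output and stores the reverse-direction weights under keys with a _reverse suffix.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed so the sketch is repeatable (not in the original demo)
hidden = 20
rnn = nn.LSTM(10, hidden, 1, bidirectional=True)
x = torch.randn(5, 3, 10)  # (seq_len, batch, input_size)
with torch.no_grad():
    output, (hn, cn) = rnn(x)

# Four tensors per direction: weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0,
# plus the same four names with the "_reverse" suffix
keys = sorted(rnn.state_dict().keys())

# output has shape (seq_len, batch, 2 * hidden): forward half first, backward half second
forward_out = output[:, :, :hidden]
backward_out = output[:, :, hidden:]

# The forward direction finishes at the last time step,
# the backward direction finishes at the first time step
forward_final_ok = torch.allclose(forward_out[-1], hn[0])
backward_final_ok = torch.allclose(backward_out[0], hn[1])
```

Here hn and cn have shape (2, 3, 20), with index 0 holding the forward direction and index 1 the backward direction.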