class Net(nn.Module):
    """Embedding -> ``nn.RNN`` -> linear projection back to vocabulary size.

    The module keeps its recurrent state in ``self.hidden_state`` between
    ``forward`` calls. Each call starts from the state left by the previous
    one, until ``init_hidden`` resets it.

    Args:
        vocab_size: Number of token ids; also the width of the output layer.
        embedding_dim: Size of each token embedding vector.
        hidden_size: Number of features in the RNN hidden state.
        batch_size: Default batch size used by ``init_hidden``.
        num_layers: Number of stacked RNN layers.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size,
                 batch_size=25, num_layers=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.num_layers = num_layers
        # Device selected at construction time; the parameters are moved
        # there at the end of __init__.
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # Index 0 is reserved for padding: its embedding starts at zero and
        # receives no gradient updates.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.rnn = nn.RNN(embedding_dim, hidden_size,
                          batch_first=True, num_layers=self.num_layers)
        self.fc = nn.Linear(hidden_size, vocab_size)

        # Recurrent state. It is created lazily by ``forward`` or explicitly
        # by ``init_hidden``, so ``forward`` works without a prior reset.
        self.hidden_state = None

        # Module.to() moves parameters in place and returns the same module,
        # so there is nothing to rebind.
        self.to(self.device)

    def init_hidden(self, batch_size=None):
        """Reset ``self.hidden_state`` to a zero tensor.

        Args:
            batch_size: Batch size of the new state. Defaults to
                ``self.batch_size`` when omitted.

        The state has shape ``(num_layers, batch_size, hidden_size)``. It is
        allocated on the same device and with the same dtype as the RNN
        weights, so it stays compatible after the model is moved or cast.
        """
        if batch_size is None:
            batch_size = self.batch_size
        # Follow the weights rather than ``self.device``, which is fixed at
        # construction and goes stale after a later ``.to()``.
        weight = next(self.rnn.parameters())
        self.hidden_state = torch.zeros(self.num_layers, batch_size,
                                        self.hidden_size,
                                        device=weight.device,
                                        dtype=weight.dtype)

    def forward(self, x):
        """Process a batch of token-index sequences.

        Args:
            x: Integer tensor of token ids shaped ``(batch, seq_len)``
                (the RNN is built with ``batch_first=True``).

        Returns:
            Tensor of shape ``(batch, seq_len, vocab_size)`` holding the
            unnormalized output of the final linear layer at every position.

        The final RNN state is stored in ``self.hidden_state`` and used as
        the initial state on the next call. It is still attached to the
        autograd graph, so call ``init_hidden`` between independent training
        steps. If no state exists yet, or its batch dimension does not match
        ``x`` (e.g. a smaller final batch), a fresh zero state is created.
        """
        batch = x.size(0)  # dim 0 is the batch because batch_first=True
        state = getattr(self, 'hidden_state', None)
        if state is None or state.size(1) != batch:
            self.init_hidden(batch)
        embedded = self.embedding(x)
        output, self.hidden_state = self.rnn(embedded, self.hidden_state)
        return self.fc(output)