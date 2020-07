import torch

from torch.utils.data import DataLoader

import torchvision

from torchvision import transforms

from torchvision.datasets import CIFAR10

import numpy as np

import matplotlib.pyplot as plt



def imshow(img):
    """Tile image tensor(s) into a single grid and display it with matplotlib.

    Args:
        img: tensor(s) accepted by ``torchvision.utils.make_grid``; presumably a
            (B, C, H, W) batch normalized with mean 0.5 / std 0.5 (values in
            [-1, 1]) — TODO confirm against the dataset transform.
    """
    grid = torchvision.utils.make_grid(img)
    # Undo the assumed Normalize(mean=0.5, std=0.5): maps [-1, 1] back to [0, 1].
    grid = grid / 2 + 0.5
    # .cpu() is a no-op for CPU tensors but is required before .numpy() on GPU tensors.
    npimg = grid.detach().cpu().numpy()
    # make_grid returns channels-first (C, H, W); matplotlib wants (H, W, C).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()



def train(net, criterion, optimizer, epochs, trainloader, input_size):
    """Train ``net`` to reconstruct its own flattened inputs.

    Each batch is flattened to shape (-1, input_size) and passed through
    ``net``. The loss is ``criterion(output, img)``, i.e. the input itself is
    the target, and labels from the loader are ignored.

    Args:
        net: callable model mapping a (N, input_size) tensor to its reconstruction.
        criterion: loss function called as ``criterion(output, target)``.
        optimizer: optimizer over ``net``'s parameters.
        epochs: number of passes over ``trainloader``.
        trainloader: iterable of ``(images, labels)`` pairs.
        input_size: number of elements per flattened sample.

    Returns:
        A tuple ``(output_and_label, losses)``:
        ``output_and_label`` holds one ``(output, img)`` pair per epoch, taken
        from that epoch's last batch, with ``output`` detached from the graph.
        ``losses`` holds the mean batch loss for each epoch.

    Raises:
        ValueError: if ``trainloader`` yields no batches.
    """
    losses = []
    output_and_label = []

    for epoch in range(1, epochs + 1):
        print(f'epoch: {epoch}, ', end='')
        running_loss = 0.0
        num_batches = 0
        output = img = None

        for img, _ in trainloader:
            optimizer.zero_grad()
            img = img.reshape(-1, input_size)
            output = net(img)
            # Autoencoder objective: the flattened input is its own target.
            loss = criterion(output, img)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            # Otherwise the average below divides by zero and output/img are unset.
            raise ValueError('trainloader yielded no batches')

        avg_loss = running_loss / num_batches
        losses.append(avg_loss)
        print('loss:', avg_loss)
        # Detach so each stored snapshot does not keep the batch's autograd graph alive.
        output_and_label.append((output.detach(), img))

    print('finished')
    return output_and_label, losses