1、 MNIST dataset

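The MNIST dataset consists of grayscale images of handwritten digits (0 to 9), each 28×28 pixels, like those in the figure above. It contains 60,000 training images and 10,000 test images.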
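MNIST pixels are stored as 8-bit gray levels from 0 (background) to 255 (stroke). In the training code below, `transforms.ToTensor()` scales them to [0, 1] and `transforms.Normalize([0.5], [0.5])` then computes (x - 0.5) / 0.5, so real images end up in [-1, 1], the same range as the generator's final `Tanh`. The pure-Python sketch below mirrors those two steps on single pixel values; the helper name `preprocess_pixel` is hypothetical and not part of torchvision.

```python
def preprocess_pixel(p: int) -> float:
    """Mimic ToTensor (divide by 255) followed by Normalize(mean=0.5, std=0.5)."""
    x = p / 255.0            # ToTensor: uint8 [0, 255] -> float [0, 1]
    return (x - 0.5) / 0.5   # Normalize: [0, 1] -> [-1, 1]

# Background, mid-gray, and full-intensity stroke pixels
for p in (0, 128, 255):
    print(p, round(preprocess_pixel(p), 4))
```

Matching the data range to the `Tanh` output range means D cannot tell real from fake images simply by their value range.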

2、 GAN principle (generative adversarial network)

A GAN consists of two parts: a generator (G) and a discriminator (D).

At the start, G receives noise drawn from a fixed distribution (such as a Gaussian), and its outputs are still just noise. The generated images are sent to D, which judges whether each one is real or fake; training continues until D accepts the images generated by G as real. In each round, D looks not only at the fake images produced by G but also at real images from the dataset, and D's weights are updated from the loss values on both. G's weights are in turn updated according to how D judged its images, so G and D keep updating their weights in alternation. The following figure shows an example:

In V1, G outputs little more than noise. Having seen real images, D can easily tell that G's output is fake. G also receives D's verdict that its images are fake, so it updates its weights. In V2, G generates more realistic images for D to judge; again, D first sees real images and then judges the images generated by G. Repeating this loop over and over is the core idea of a GAN.
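The alternating game described above is commonly written as a minimax objective, where $D(x)$ is the probability D assigns to $x$ being real, $G(z)$ is the image G produces from noise $z$, $p_{\text{data}}$ is the data distribution and $p_z$ is the noise prior (a standard normal in the training code):

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
$$

D maximizes $V$ by pushing $D(x)$ toward 1 and $D(G(z))$ toward 0. In practice G is usually trained to maximize $\log D(G(z))$ instead of minimizing $\log(1 - D(G(z)))$, because the latter gives weak gradients early in training, when D rejects G's samples easily; maximizing $\log D(G(z))$ is equivalent to a binary cross-entropy loss on G's images against "real" labels.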

3、 Training code

```python
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torchvision import datasets

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
# Adam settings used by optimizer_G / optimizer_D below (common GAN defaults)
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image samples")
opt = parser.parse_args()
print(opt)

# Image shape (C, H, W) = (1, 28, 28); MNIST is grayscale, so there is 1 channel
img_shape = (opt.channels, opt.img_size, opt.img_size)
cuda = True if torch.cuda.is_available() else False


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # Note: the second positional argument of BatchNorm1d is eps (here 0.8)
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # latent_dim (100) -> 128 -> 256 -> 512 -> 1024 -> 784 (= 1 * 28 * 28)
        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()  # outputs in [-1, 1], matching the normalized MNIST images
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)  # (batch, 784) -> (batch, 1, 28, 28)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        # 784 -> 512 -> 256 -> 1
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # probability that the input image is real
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)  # flatten (batch, 1, 28, 28) -> (batch, 784)
        validity = self.model(img_flat)
        return validity


# Loss function: binary cross-entropy on D's "real" probability
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# ----------
#  Training
# ----------
if __name__ == '__main__':
    for epoch in range(opt.n_epochs):
        for i, (imgs, _) in enumerate(dataloader):
            # Adversarial ground truths (plain tensors, no gradients needed)
            valid = Tensor(imgs.size(0), 1).fill_(1.0)  # all 1: "real"
            fake = Tensor(imgs.size(0), 1).fill_(0.0)  # all 0: "fake"

            # Configure input
            real_imgs = imgs.type(Tensor)

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()  # clear G's gradients from the previous batch

            # Sample noise as generator input: mean 0, std 1, shape (batch_size, latent_dim) = (64, 100)
            z = Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))

            # Generate a batch of images
            gen_imgs = generator(z)

            # Loss measures generator's ability to fool the discriminator:
            # G wants D to label its images as real (1)
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)

            g_loss.backward()  # g_loss depends on D's verdict and is used to update G's weights
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()  # clear D's gradients (including those left by g_loss.backward())

            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            # detach() so that this loss does not backpropagate into G
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()  # d_loss is used to update D's weights
            optimizer_D.step()

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )

            batches_done = epoch * len(dataloader) + i
            if batches_done % opt.sample_interval == 0:
                # Save the first 25 images of the batch as a 5x5 grid
                save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

        # Save the whole models once every 2 epochs (after epochs 1, 3, ..., 199)
        if (epoch + 1) % 2 == 0:
            print('save..')
            torch.save(generator, 'g%d.pth' % epoch)
            torch.save(discriminator, 'd%d.pth' % epoch)
```
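Assuming the loss is binary cross-entropy (as with `torch.nn.BCELoss`, which fits the discriminator's `Sigmoid` output and the 0/1 `valid`/`fake` labels), the pure-Python sketch below recomputes `real_loss`, `fake_loss`, `d_loss` and `g_loss` by hand for a toy batch of two images. The probabilities are made-up values for illustration, and the helper `bce` is hypothetical.

```python
import math

def bce(probs, target):
    """Mean binary cross-entropy of predicted probabilities against one shared 0/1 target."""
    return -sum(target * math.log(p) + (1 - target) * math.log(1 - p) for p in probs) / len(probs)

# Hypothetical discriminator outputs: probability that each image is real
d_on_real = [0.9, 0.8]  # D(real_imgs)
d_on_fake = [0.1, 0.3]  # D(gen_imgs)

# Discriminator: real images should score 1 ("valid"), generated ones 0 ("fake")
real_loss = bce(d_on_real, 1.0)
fake_loss = bce(d_on_fake, 0.0)
d_loss = (real_loss + fake_loss) / 2

# Generator: wants D to call its images real, so its target is 1
g_loss = bce(d_on_fake, 1.0)

print(f"d_loss={d_loss:.4f} g_loss={g_loss:.4f}")
```

A low `d_loss` together with a high `g_loss`, as in this toy batch, means D is currently winning; during training the two losses pull against each other rather than both falling to zero.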

Training results:

At the beginning, G produces nothing but noise:

Then the rough shapes of digits gradually emerge:

The final generated result:
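The images in this section come from the PNG files that the training loop writes with `save_image`, each named after `batches_done`. Assuming the default arguments (`batch_size=64`, `n_epochs=200`, `sample_interval=400`), the 60,000 MNIST training images, and the DataLoader default `drop_last=False` (the last partial batch is kept), a quick calculation shows how many batches run and which snapshot is the last one:

```python
import math

n_train = 60_000       # MNIST training images
batch_size = 64        # --batch_size default
n_epochs = 200         # --n_epochs default
sample_interval = 400  # --sample_interval default

batches_per_epoch = math.ceil(n_train / batch_size)  # last partial batch is kept
total_batches = n_epochs * batches_per_epoch

# batches_done runs from 0 to total_batches - 1; an image is saved at every multiple of sample_interval
saved = [b for b in range(total_batches) if b % sample_interval == 0]

print(batches_per_epoch, total_batches, len(saved), saved[-1])  # prints: 938 187600 469 187200
```

So a default run leaves 469 snapshots, the last one being `images/187200.png`, and the last generator checkpoint is `g199.pth`, because the models are saved whenever `(epoch + 1) % 2 == 0`.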

4、 Test code:

Load the last saved generator model:

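```python
from gan import Generator, Discriminator  # the class definitions must be importable to unpickle the saved models
import torch
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import save_image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
# Load the last saved generator (epoch 199); map_location lets a GPU-trained model load on a CPU-only machine.
# Newer PyTorch releases default to weights_only=True in torch.load; there, pass weights_only=False to load a whole pickled module.
g = torch.load('g199.pth', map_location=device)  # import generator model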