1. MNIST dataset

The MNIST dataset consists of handwritten digits like those in the figure above: small grayscale images (1 channel, 28×28 pixels).

2. GAN principle (generative adversarial network)

A GAN consists of two parts: a generator (G) and a discriminator (D).

G takes noise drawn from some distribution (for example a Gaussian) and turns it into images. The generated images are passed to D, which judges whether each one is real or fake. Training continues until D judges even G's images to be real. In each round, D sees both the fake images produced by G and real images from the dataset. D's weights are updated from the loss computed on both kinds of images. G's weights are updated from how D judged G's images, so G and D keep updating their weights in turn. The following figure shows an example:

In V1, G produces nothing but noise. A D that has already seen real images can easily tell that G's output is fake. G also learns that D judged its output to be fake, and updates its weights accordingly. In V2, G generates more realistic images for D to judge; again, D first sees real images and then judges the images generated by G. Repeating this cycle over and over is the core idea of a GAN.

3. Training code

import argparse
import os
import numpy as np
import math
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch
# Output directory for the sample grids written during training.
os.makedirs("images", exist_ok=True)

# Command-line hyperparameters.
# NOTE: this runs at import time, so any script that imports this module also
# has its own command-line arguments parsed here.
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image samples")
opt = parser.parse_args()

# Shape of one image as (C, H, W): (1, 28, 28) with the defaults above, since
# MNIST images are grayscale (a single channel).
img_shape = (opt.channels, opt.img_size, opt.img_size)

# Train on the GPU when one is available.
cuda = torch.cuda.is_available()
class Generator(nn.Module):
    """Map a batch of latent noise vectors to a batch of images.

    Four Linear + LeakyReLU stages widen the noise (latent_dim -> 128 -> 256 ->
    512 -> 1024; BatchNorm on every stage except the first). A final Linear
    layer then produces C*H*W pixel values, and Tanh squashes them into [-1, 1].
    That is the same range the training images are scaled to by
    ``transforms.Normalize([0.5], [0.5])``.

    Args:
        latent_dim: length of each input noise vector. Defaults to
            ``opt.latent_dim`` (the ``--latent_dim`` command-line option).
        image_shape: ``(C, H, W)`` of the generated images. Defaults to the
            module-level ``img_shape``.
    """

    def __init__(self, latent_dim=None, image_shape=None):
        super(Generator, self).__init__()
        # Fall back to the command-line globals only when no explicit value is
        # given, so the class can also be built without parsing argv.
        latent_dim = opt.latent_dim if latent_dim is None else latent_dim
        self.img_shape = tuple(img_shape if image_shape is None else image_shape)

        def block(in_feat, out_feat, normalize=True):
            """Return the layers of one Linear -> [BatchNorm] -> LeakyReLU stage."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE(review): BatchNorm1d's second positional argument is `eps`,
                # so this sets eps=0.8 (possibly meant as momentum). It is kept
                # unchanged so training behaves as in the original listing.
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate images from noise ``z`` of shape (batch, latent_dim).

        Returns a tensor of shape (batch, C, H, W) with values in [-1, 1].
        """
        img = self.model(z)
        return img.view(img.size(0), *self.img_shape)
class Discriminator(nn.Module):
    """Score a batch of images with the probability that each one is real.

    Images are flattened and passed through Linear(C*H*W, 512) -> LeakyReLU ->
    Linear(512, 256) -> LeakyReLU -> Linear(256, 1) -> Sigmoid. The Sigmoid is
    required because the ``nn.BCELoss`` used for training expects
    probabilities in [0, 1].

    Args:
        image_shape: ``(C, H, W)`` of the input images. Defaults to the
            module-level ``img_shape``.
    """

    def __init__(self, image_shape=None):
        super(Discriminator, self).__init__()
        # Fall back to the module-level shape only when none is given.
        n_pixels = int(np.prod(img_shape if image_shape is None else image_shape))
        self.model = nn.Sequential(
            nn.Linear(n_pixels, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Return a (batch, 1) tensor of real-vs-fake probabilities for ``img``."""
        img_flat = img.view(img.size(0), -1)
        return self.model(img_flat)
# Loss function: binary cross-entropy between the discriminator's probability
# output and the 1 (real) / 0 (fake) targets used in the training loop.
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

# Move the models (and the loss module) to the GPU when one is available.
if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Configure data loader.
# NOTE: this runs at import time, so importing this module also creates the
# data directory and downloads MNIST into it if it is missing.
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        # ToTensor gives pixel values in [0, 1]; Normalize(mean=0.5, std=0.5)
        # then maps them to [-1, 1].
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,  # honor the --n_cpu option, which was otherwise unused
)

# Optimizers: one Adam optimizer per network, configured from the CLI options.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Legacy tensor type used to build tensors on the training device.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# ----------
# Training
# ----------
if __name__ == '__main__':
    for epoch in range(opt.n_epochs):
        for i, (imgs, _) in enumerate(dataloader):
            # Adversarial ground truths: one target per image, 1 = real, 0 = fake.
            valid = Tensor(imgs.size(0), 1).fill_(1.0)
            fake = Tensor(imgs.size(0), 1).fill_(0.0)

            # Configure input: convert the real batch to the training tensor type.
            real_imgs = imgs.type(Tensor)

            # -----------------
            # Train Generator
            # -----------------
            optimizer_G.zero_grad()  # clear G's gradients from the previous batch

            # Sample noise as generator input: mean 0, std 1,
            # shape (batch size, latent_dim), e.g. (64, 100) by default.
            z = Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))

            # Generate a batch of images
            gen_imgs = generator(z)

            # Loss measures generator's ability to fool the discriminator:
            # G is scored against the "real" targets, so it improves by making
            # D classify its images as real.
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)
            g_loss.backward()
            optimizer_G.step()  # without this step G's weights never change

            # ---------------------
            # Train Discriminator
            # ---------------------
            # Also discards the gradients that g_loss.backward() left on D's weights.
            optimizer_D.zero_grad()

            # Measure discriminator's ability to classify real from generated samples.
            # detach() keeps D's loss from back-propagating into G.
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()  # without this step D's weights never change

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )

            batches_done = epoch * len(dataloader) + i
            if batches_done % opt.sample_interval == 0:
                # Save the first 25 generated images of this batch as a 5x5 grid.
                save_image(gen_imgs.detach()[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

        # Checkpoint the whole models once at the end of every second epoch.
        # This was previously inside the batch loop and rewrote the files on
        # every batch of those epochs.
        if (epoch + 1) % 2 == 0:
            torch.save(generator, 'g%d.pth' % epoch)
            torch.save(discriminator, 'd%d.pth' % epoch)

Training results:

At the beginning, everything G produces is noise:

Then the rough shapes of digits gradually appear:

Final generated result:

4. Test code:

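Load the last generator model saved by the training script. The training script is saved as gan.py here. Because the whole model object was pickled, its Generator class must be importable from gan.py when loading.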

from gan import Generator,Discriminator
import torch
import matplotlib.pyplot as plt
from torch.autograd import Variable
import numpy as np
from torchvision.utils import save_image
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the whole pickled generator saved by the training script for epoch
# index 199. map_location lets a checkpoint written on a GPU machine be
# loaded on a CPU-only one.
# NOTE: on PyTorch >= 2.6 torch.load defaults to weights_only=True, which
# refuses to unpickle a full nn.Module; pass weights_only=False there (only
# for checkpoints you trust).
g = torch.load('g199.pth', map_location=device)
g = g.to(device)
g.eval()  # use BatchNorm running statistics instead of per-batch statistics

# 64 noise vectors of length 100 (must match the --latent_dim used for
# training), drawn from N(0, 1) as in training, placed on the model's device.
z = torch.from_numpy(np.random.normal(0, 1, (64, 100))).float().to(device)

with torch.no_grad():  # inference only: no autograd graph is needed
    gen_imgs = g(z)

# Save the first 25 samples as a 5x5 grid; normalize=True rescales the pixel
# values to [0, 1] using the tensor's min/max.
save_image(gen_imgs[:25], "images.png", nrow=5, normalize=True)

Generated result:

That is everything in this walkthrough of using a PyTorch GAN to generate fake handwritten MNIST digits. I hope it serves as a useful reference.