Create a GAN model using the TensorFlow TFGAN library
GANs were originally introduced in 2014 by Ian Goodfellow. A GAN is basically a setup where we have two separate models fighting against each other. We want to be able to produce images that are similar to the training images but not exactly the same, so we build it using generative networks.
If you don't have enough data, you can manufacture more from the data you already have, and this is what a GAN does.
Normally with a classifier or another deep learning model, we're just trying to predict a class and don't really care about the actual data distribution, whereas with a GAN we're much more interested in being able to reproduce that distribution.
The generator's job is to come up with new versions of images, and the discriminator checks images and says whether each one is real or fake.
The whole reason this works is that as the generator gets better at making fake images, the discriminator has to get better at detecting them. The whole thing with a GAN is that we want to keep the two in balance, and that balance is basically what we're trying to optimize.
It uses a very simple concept: we put some latent noise, which we call Z, into a generator. We also take some real data, and then randomly present real and generated samples to a discriminator, which has to decide whether they are real or fake. We then use that decision to score the loss and update the weights, so the model gets better at taking that noise and turning it into realistic images.
We basically take the generator loss and the discriminator loss and add them together, and that's the total loss of the network. We flip the sign of the generator's loss because the generator is actually trying to push the discriminator's loss up, and then we use the resulting total loss for training.
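To make the loss bookkeeping a bit more concrete, here is a minimal NumPy sketch. It assumes the discriminator outputs raw logits and uses the standard cross-entropy GAN losses. The function names and the sample logits are made up for illustration, and this is not TFGAN's own implementation.

```python
import numpy as np

def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)) = -log(1 + exp(-x)).
    return -np.logaddexp(0.0, -x)

def discriminator_loss(real_logits, fake_logits):
    # Penalize real samples scored as fake and fake samples scored as real.
    real_term = -np.mean(log_sigmoid(real_logits))
    # log(1 - sigmoid(x)) is the same as log_sigmoid(-x).
    fake_term = -np.mean(log_sigmoid(-fake_logits))
    return real_term + fake_term

def minimax_generator_loss(real_logits, fake_logits):
    # The "flipped" loss: the generator wants the discriminator's loss to go up.
    return -discriminator_loss(real_logits, fake_logits)

def modified_generator_loss(fake_logits):
    # Non-saturating variant: the generator maximizes log D(G(z)) instead.
    return -np.mean(log_sigmoid(fake_logits))

# Illustrative logits only (assumed values, not from a trained model).
real_logits = np.array([2.0, 1.5, 3.0])
fake_logits = np.array([-2.0, -1.0, 0.5])

d_loss = discriminator_loss(real_logits, fake_logits)
g_minimax = minimax_generator_loss(real_logits, fake_logits)
g_modified = modified_generator_loss(fake_logits)
```

The non-saturating variant is, as far as I understand it, the idea behind the "modified" losses: it gives the generator a stronger gradient early on, when the discriminator easily rejects its samples.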
TFGAN is a lightweight library for GANs in TensorFlow. It has a set of pre-made losses and GAN components. With TFGAN you can basically take these off-the-shelf losses and building blocks and put them into a model, which is a much simpler way to make a GAN. You can also build a GAN as a GANEstimator.
We're trying to do the same sort of thing with MNIST. First, we do our imports and load our data.
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.layers import (UpSampling2D, Conv2D, BatchNormalization,
                                     Reshape, Activation, Dense, Flatten,
                                     MaxPooling2D)
from tensorflow.keras.models import Sequential
import matplotlib.pyplot as plt

# tf.contrib only exists in TensorFlow 1.x.
tfgan = tf.contrib.gan
Here's my input pipeline for pulling the data into a format that is estimator-friendly for the model.
def train_input_fn(batch_size, num_epochs, noise_dim):
    def resize_image(features):
        # Scale pixels to [0, 1], then shift them to [-1, 1].
        image = tf.image.convert_image_dtype(features["image"], dtype=tf.float32)
        image = (image - 0.5) * 2
        image = tf.image.resize(image, size=(28, 28))
        # Each example is paired with its own latent noise vector.
        noise = tf.random_normal([noise_dim], name="train_noise")
        return noise, image

    def _input_fn():
        dataset = tfds.load("mnist", split=tfds.Split.TRAIN)
        dataset = dataset.map(resize_image)
        dataset = dataset.batch(batch_size, drop_remainder=True).repeat(num_epochs)
        return dataset

    return _input_fn
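If the pixel scaling in that pipeline is not obvious, here is a small NumPy sketch of the same arithmetic. It assumes the raw MNIST pixels are uint8 values in [0, 255]. The helper names are made up for illustration.

```python
import numpy as np

def scale_to_unit_range(image_uint8):
    # uint8 values in [0, 255] become float32 values in [0, 1].
    return image_uint8.astype(np.float32) / 255.0

def center(image01):
    # Shift [0, 1] to [-1, 1], the same arithmetic as (image - 0.5) * 2.
    return (image01 - 0.5) * 2

# Three sample pixel values: black, mid-gray, white.
pixels = np.array([[0, 128, 255]], dtype=np.uint8)
normalized = center(scale_to_unit_range(pixels))

# One latent noise vector per example, here with noise_dim = 64.
noise = np.random.default_rng(0).standard_normal(64)
```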
Generator and Discriminator
Here's our generator, and you can see it's an unconditional generator. We're basically passing in some noise, the latent vector that we had before.
def generator():
    model = Sequential()
    model.add(Dense(input_dim=64, units=512))
    model.add(Activation('relu'))
    model.add(Dense(64*7*7))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Reshape((7, 7, 64), input_shape=(64*7*7,)))  # 7x7 image
    model.add(UpSampling2D(size=(2, 2)))  # 14x14 image
    model.add(Conv2D(64, (5, 5), padding='same'))
    model.add(Activation('relu'))
    model.add(UpSampling2D(size=(2, 2)))  # 28x28 image
    model.add(Conv2D(1, (5, 5), padding='same'))
    # Note: relu outputs are >= 0, while the training images are scaled to [-1, 1].
    model.add(Activation('relu'))
    return model


def generator_fn(inputs, mode):
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    model = generator()
    return model(inputs, training=is_training)
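To see how the generator gets from a flat vector to a 28x28 image, here is a rough NumPy trace of the shapes. It assumes nearest-neighbor upsampling, which is the Keras `UpSampling2D` default, and it skips the convolutions because `padding='same'` leaves the height and width unchanged. The `upsample_2x` helper is just for illustration.

```python
import numpy as np

def upsample_2x(feature_map):
    # Nearest-neighbor upsampling: repeat every row and every column twice.
    return feature_map.repeat(2, axis=0).repeat(2, axis=1)

# Stand-in for the output of Dense(64*7*7) for a single example.
flat = np.arange(64 * 7 * 7, dtype=np.float32)

grid = flat.reshape(7, 7, 64)   # Reshape((7, 7, 64))
up1 = upsample_2x(grid)         # first UpSampling2D  -> 14x14
up2 = upsample_2x(up1)          # second UpSampling2D -> 28x28
```

The final `Conv2D(1, ...)` layer then collapses the 64 channels into one, giving a 28x28x1 image.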
The discriminator's whole job is to detect which images are not real, and you can see here that the model is much simpler.
def discriminator():
    model = Sequential()
    model.add(Conv2D(32, (5, 5), padding='same', input_shape=(28, 28, 1)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(64, (5, 5)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(512))
    model.add(Activation('relu'))
    model.add(Dense(1))  # a single real/fake logit
    return model


def discriminator_fn(inputs, conditioning, mode):
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    model = discriminator()
    return model(inputs, training=is_training)
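As a sanity check on the discriminator, here is a short sketch of the spatial sizes as an image passes through it. It assumes stride-1 convolutions, Keras' default 'valid' padding for the second convolution, and pooling with stride equal to the pool size. The two helper functions are made up for illustration.

```python
def conv_output_size(size, kernel, padding):
    # Output width/height of a stride-1 convolution.
    if padding == "same":
        return size
    return size - kernel + 1  # "valid" padding

def pool_output_size(size, pool):
    # Max pooling with stride equal to the pool size.
    return size // pool

size = 28
size = conv_output_size(size, 5, "same")    # 28
size = pool_output_size(size, 2)            # 14
size = conv_output_size(size, 5, "valid")   # 10
size = pool_output_size(size, 2)            # 5
flattened_units = size * size * 64          # 5 * 5 * 64 = 1600
```

So `Flatten()` hands 1600 values to `Dense(512)`, and the last `Dense(1)` produces a single real/fake score.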
GAN Loss Functions
One of the cool things with TFGAN is that it has all the loss functions made for you, so you don't have to go through and code them yourself, and they're also optimized.
One of the biggest things that has changed in GANs over time, and one of the things that has improved them, is the different loss functions and different ways of dealing with training, and TFGAN has a lot of these built in.
This is a vanilla GAN, but it's built as a GANEstimator. An Estimator has a few key pieces, like the model function, the input functions and some sort of evaluation function.
def gan():
    # Hyperparameters
    model_dir = "../logs-2/"
    batch_size = 64
    num_epochs = 10
    noise_dim = 64

    # Run configuration
    run_config = tf.estimator.RunConfig(
        model_dir=model_dir,
        save_summary_steps=100,
        save_checkpoints_steps=1000)

    gan_estimator = tfgan.estimator.GANEstimator(
        config=run_config,
        generator_fn=generator_fn,
        discriminator_fn=discriminator_fn,
        generator_loss_fn=tfgan.losses.modified_generator_loss,
        discriminator_loss_fn=tfgan.losses.modified_discriminator_loss,
        generator_optimizer=tf.train.AdamOptimizer(0.0002, 0.5),
        discriminator_optimizer=tf.train.AdamOptimizer(0.0002, 0.5),
        add_summaries=tfgan.estimator.SummaryType.IMAGES)

    input_fn = train_input_fn(batch_size, num_epochs, noise_dim)
    # train() returns the estimator itself, so it can be used for predict() later.
    model = gan_estimator.train(input_fn, max_steps=None)
    return model
For the model, you just build this estimator and say: for the generator function use this, for the discriminator function use this, for the generator loss function use this.
gan_model = gan()
Then you literally train it just like any other estimator. You can run the evaluation and then print some samples out.
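predict_batch = 32
# predict_input_fn is not defined in this post; it is assumed to supply
# noise vectors of dimension 64 in batches of predict_batch.
input_fn = predict_input_fn(predict_batch, 64)
predict = gan_model.predict(input_fn)
result = [next(predict) for _ in range(predict_batch)]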
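To look at the samples, one option is to tile them into a single grid image. Here is a minimal NumPy sketch. It assumes each prediction is a 28x28x1 array, and it uses random values as a stand-in for `result` so that it runs on its own. The `tile_images` helper is made up for illustration.

```python
import numpy as np

def tile_images(images, rows, cols):
    # images: array-like of shape (rows * cols, height, width, 1).
    images = np.asarray(images)
    n, h, w = images.shape[0], images.shape[1], images.shape[2]
    if n != rows * cols:
        raise ValueError("expected rows * cols images")
    grid = images.reshape(rows, cols, h, w)   # drop the channel axis
    grid = grid.transpose(0, 2, 1, 3)         # (rows, h, cols, w)
    return grid.reshape(rows * h, cols * w)   # one big 2-D image

# Stand-in for the 32 generated images (random values, not real output).
fake_result = np.random.default_rng(0).uniform(-1, 1, size=(32, 28, 28, 1))
grid = tile_images(fake_result, rows=4, cols=8)
```

With the real `result` list in place of `fake_result`, something like `plt.imshow(grid, cmap="gray")` should display all 32 digits at once.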
By this point we're actually able to produce realistic MNIST digits that are not digits that were given to the model.