Delving into Wasserstein GANs: A Comprehensive Guide with Training Example
Rajesh Epili • June 4, 2024
Dive into the world of Wasserstein Generative Adversarial Networks (WGANs) – a game-changer in deep learning. Explore its stability, realistic data generation, and creative applications with a comprehensive TensorFlow code example. Uncover the power of WGANs in reshaping the landscape of generative models.
Generative Adversarial Networks (GANs) and Wasserstein GANs (WGANs)
In the realm of deep learning, generative adversarial networks (GANs) have emerged as a powerful tool for data generation and representation learning. Among the various GAN architectures, the Wasserstein GAN (WGAN) stands out for its unique training approach and ability to produce high-quality, realistic data samples.
What is WGAN?
Wasserstein GANs, introduced by Arjovsky et al. in 2017, aim to address the instability and mode collapse issues often encountered in traditional GANs. They do this by replacing the Jensen-Shannon (JS) divergence, which the original GAN objective implicitly minimizes, with the Wasserstein distance as the measure of difference between the real and generated data distributions.
The Wasserstein distance, also known as the Earth Mover's Distance (EMD), is a metric that quantifies the minimum cost of transporting mass from one distribution to another. This conceptual approach provides several advantages over the JS divergence, including:
Stability
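The Wasserstein distance changes smoothly as the generated distribution moves toward the real one, even when the two do not overlap. Because of this, the critic keeps supplying useful gradients to the generator. This mitigates the vanishing gradient problem and makes training more stable and less sensitive to architecture and hyperparameter choices.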
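A toy calculation illustrates the gradient argument. Put all real probability mass at position 0 and all generated mass at position θ. For every θ ≠ 0, the JS divergence stays at log 2, so it gives no signal about which way to move θ. The Wasserstein-1 distance, by contrast, equals θ and shrinks as the generated mass moves closer. The sketch below uses plain NumPy, and the helper names are illustrative. Note that `wasserstein_1d` only covers equal-size 1-D samples; SciPy's `scipy.stats.wasserstein_distance` handles the general 1-D case.

```python
import numpy as np

def wasserstein_1d(x, y):
    """W1 between two equal-size 1-D samples: the optimal plan pairs sorted samples."""
    x = np.sort(np.asarray(x, dtype=float))
    y = np.sort(np.asarray(y, dtype=float))
    return float(np.mean(np.abs(x - y)))

def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) between two discrete distributions."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    m = 0.5 * (p + q)

    def kl(a, b):
        mask = a > 0  # bins where a == 0 contribute nothing to KL(a || b)
        return float(np.sum(a[mask] * np.log(a[mask] / b[mask])))

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

grid_size = 10
real = np.zeros(grid_size)
real[0] = 1.0  # all real mass at position 0

for theta in (1, 3, 6):
    fake = np.zeros(grid_size)
    fake[theta] = 1.0  # all generated mass at position theta
    w1 = wasserstein_1d(np.zeros(100), np.full(100, float(theta)))
    print(f"theta={theta}: W1={w1:.2f}, JS={js_divergence(real, fake):.4f}")

# theta=1: W1=1.00, JS=0.6931
# theta=3: W1=3.00, JS=0.6931
# theta=6: W1=6.00, JS=0.6931
```

The JS value is the same at every distance, so its gradient with respect to θ is zero. W1 grows linearly with θ, which is the kind of signal a generator can follow.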
Mode Collapse Prevention
WGANs are less prone to mode collapse, where the generator produces only a limited subset of the possible data variations.
Meaningful Learning Curves
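The critic's loss approximates the Wasserstein distance between the real and generated distributions, and in practice it tends to fall as sample quality improves. This makes the loss curve a useful indicator of training progress, which the standard GAN losses generally are not.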
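Formally, let $\mathbb{P}_r$ be the real distribution and $\mathbb{P}_g$ the generator's distribution. The Wasserstein-1 (Earth Mover's) distance and its Kantorovich-Rubinstein dual form are:

$$
W(\mathbb{P}_r, \mathbb{P}_g) = \inf_{\gamma \in \Pi(\mathbb{P}_r, \mathbb{P}_g)} \mathbb{E}_{(x, y) \sim \gamma}\big[\lVert x - y \rVert\big] = \sup_{\lVert f \rVert_L \le 1} \Big( \mathbb{E}_{x \sim \mathbb{P}_r}[f(x)] - \mathbb{E}_{x \sim \mathbb{P}_g}[f(x)] \Big)
$$

In this formula:

- $\Pi(\mathbb{P}_r, \mathbb{P}_g)$ is the set of joint distributions (transport plans) whose marginals are $\mathbb{P}_r$ and $\mathbb{P}_g$.
- The supremum runs over 1-Lipschitz functions $f$.

The WGAN critic plays the role of $f$. Maximizing the gap between its average score on real samples and its average score on generated samples yields an estimate of $W$, up to a constant factor set by the Lipschitz bound. This is why the critic loss can serve as a learning curve.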
Uses of WGAN
Wasserstein GANs have found applications in various domains, including:
Image Generation
WGANs can be used to generate realistic images of various objects and scenes, including faces, animals, and landscapes.
Data Augmentation
By generating new data points that closely resemble the real data distribution, WGANs can supplement limited training data in various machine learning tasks.
Data Representation Learning
WGANs can be employed to learn the underlying manifold structure of complex datasets, enabling more effective dimensionality reduction and feature extraction.
Creative Applications
WGANs have been used to generate art, music, and even interactive experiences, demonstrating their potential for creative expression.
Disadvantages of WGAN
Despite their advantages, WGANs also have some limitations:
Gradient Penalty
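WGANs require the critic to be Lipschitz continuous, and enforcing this constraint has a cost. The original WGAN clips the critic's weights, which can restrict the critic's capacity. The widely used WGAN-GP variant (Gulrajani et al., 2017) instead adds a gradient penalty term, which needs an extra gradient computation on every critic update and is therefore computationally expensive. In addition, the critic is usually updated several times per generator update.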
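The extra cost comes from an additional gradient computation, and a minimal TensorFlow sketch of the WGAN-GP penalty shows where it happens. The sketch makes these assumptions:

- Image tensors are shaped (batch, height, width, channels).
- The critic is a callable that accepts such tensors and a `training` argument.
- The coefficient is λ = 10, the value used in the WGAN-GP paper.

The function name and signature are illustrative, not part of any library. This is the WGAN-GP variant; the training example later in this article uses weight clipping instead.

```python
import tensorflow as tf

def gradient_penalty(critic, real_images, fake_images, gp_weight=10.0):
    """WGAN-GP penalty: gp_weight * E[(||grad_x critic(x_hat)||_2 - 1)^2]."""
    batch_size = tf.shape(real_images)[0]
    # One random interpolation coefficient per sample, broadcast over H, W, C
    alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
    interpolated = real_images + alpha * (fake_images - real_images)
    with tf.GradientTape() as tape:
        tape.watch(interpolated)  # interpolated is a plain tensor, so watch it explicitly
        scores = critic(interpolated, training=True)
    # This extra gradient computation is the main source of the added cost
    grads = tape.gradient(scores, interpolated)
    # Per-sample L2 norm of the gradient (small epsilon keeps sqrt well-behaved at 0)
    norms = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
    return gp_weight * tf.reduce_mean(tf.square(norms - 1.0))
```

In a WGAN-GP training step, this term would be added to the critic loss, e.g. `tf.reduce_mean(fake_scores) - tf.reduce_mean(real_scores) + gradient_penalty(critic, real, fake)`, and weight clipping would be dropped.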
Training Sensitivity
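WGANs are still sensitive to hyperparameter choices. These include the weight-clipping threshold in the original WGAN or the penalty coefficient λ in WGAN-GP, the learning rate, and the number of critic updates per generator update.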
Rare Events Generation
WGANs may struggle to generate rare or unusual data points, as the Wasserstein distance may not be well-suited for such cases.
Training WGAN Model
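To illustrate the training process of WGANs, consider the following code, which uses TensorFlow and the MNIST dataset. It follows the original WGAN recipe of weight clipping with RMSprop. For a more detailed explanation, please check our GitHub repository for trained WGAN models and source code.

The code assumes the tf.keras 2.x API (TensorFlow 2.15 or earlier). Some details, such as freezing the critic inside the combined model and saving to `.h5` files, may behave differently under Keras 3.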
Download and Import the Required Libraries
```bash
pip install numpy matplotlib tensorflow
```
```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, LeakyReLU, BatchNormalization, Reshape, UpSampling2D, Conv2D, Flatten
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.datasets import mnist
```
Load the MNIST Dataset
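```python
# Load the MNIST dataset (only the training images are used)
(x_train, y_train), (_, _) = mnist.load_data()
x_train = x_train.reshape(60000, 28, 28, 1).astype('float32')
# Scale pixels to [-1, 1] to match the generator's tanh output range
x_train = (x_train - 127.5) / 127.5
```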
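The generator's last layer uses tanh, so its outputs lie in [-1, 1]. Real images shown to the critic should use the same range; otherwise the critic could separate real from fake by pixel range alone. A pair of hypothetical helpers (the names are illustrative) makes both mappings explicit: one for training and one back to [0, 1] for plotting.

```python
import numpy as np

def to_tanh_range(images_uint8):
    """Map uint8 pixels in [0, 255] to float32 values in [-1, 1]."""
    return images_uint8.astype("float32") / 127.5 - 1.0

def to_display_range(images):
    """Map values in [-1, 1] back to [0, 1] for plotting."""
    return 0.5 * images + 0.5

pixels = np.array([0, 127, 255], dtype=np.uint8)
scaled = to_tanh_range(pixels)
print(scaled.min(), scaled.max())  # -1.0 1.0
print(to_display_range(scaled))
```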
Define and Create the Generator and Discriminator Models
```python
# Define the generator model
def create_generator():
    model = Sequential()
    model.add(Dense(256, use_bias=False, input_shape=(100,)))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(512, use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1024, use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(784, activation='tanh', use_bias=False))
    model.add(Reshape((28, 28, 1)))
    return model

# Define the discriminator (critic) model; it outputs an unbounded score, with no sigmoid
def create_discriminator():
    model = Sequential()
    model.add(Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=(28, 28, 1)))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Flatten())
    model.add(Dense(1))
    return model

# Create the generator and discriminator models
generator = create_generator()
discriminator = create_discriminator()
```
Define and Create the WGAN Model
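```python
# Wasserstein loss: with labels +1 (real) and -1 (fake), minimizing
# -mean(y_true * y_pred) pushes the critic to score real images above fakes
def wasserstein_loss(y_true, y_pred):
    return -tf.reduce_mean(y_true * y_pred)

# Combined model that trains the generator through a frozen critic
def create_wgan(generator, discriminator):
    discriminator.trainable = False  # frozen only in models compiled after this point
    noise_input = Input(shape=(100,))
    score = discriminator(generator(noise_input))
    return Model(noise_input, score)

# Compile the discriminator first, so it stays trainable when trained on its own
discriminator.compile(optimizer=RMSprop(learning_rate=0.00005), loss=wasserstein_loss)

# Create and compile the Wasserstein GAN model
wgan_model = create_wgan(generator, discriminator)
wgan_model.compile(optimizer=RMSprop(learning_rate=0.00005), loss=wasserstein_loss)
```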
Define the Training Loop and Save Copies of Generator & Discriminator
```python
# Train the WGAN
epochs = 200
batch_size = 128
clip_value = 0.01  # Clip weights to enforce Lipschitz continuity
n_critic = 5       # Critic updates per generator update, as in the original WGAN paper

real_labels = np.ones((batch_size, 1), dtype='float32')   # +1 for real images
fake_labels = -np.ones((batch_size, 1), dtype='float32')  # -1 for generated images
g_loss = None

for epoch in range(epochs):
    np.random.shuffle(x_train)  # New batch order every epoch
    for step, i in enumerate(range(0, len(x_train) - batch_size + 1, batch_size)):
        real_images = x_train[i:i + batch_size]
        noise = np.random.normal(0, 1, (batch_size, 100))
        generated_images = generator.predict(noise, verbose=0)

        # Train the discriminator (critic); the summed loss equals
        # -(mean score on real - mean score on fake)
        d_loss_real = discriminator.train_on_batch(real_images, real_labels)
        d_loss_fake = discriminator.train_on_batch(generated_images, fake_labels)
        d_loss = d_loss_real + d_loss_fake

        # Clip discriminator weights after every critic update
        for layer in discriminator.layers:
            weights = layer.get_weights()
            layer.set_weights([np.clip(w, -clip_value, clip_value) for w in weights])

        # Train the generator once every n_critic critic updates
        if (step + 1) % n_critic == 0:
            noise = np.random.normal(0, 1, (batch_size, 100))
            g_loss = wgan_model.train_on_batch(noise, real_labels)

    # Print losses at the end of each epoch
    print(f'Epoch {epoch + 1}/{epochs}, Discriminator Loss: {d_loss}, Generator Loss: {g_loss}')

    # Save models every 25th epoch
    if (epoch + 1) % 25 == 0:
        generator.save(f'wgan_generator_epoch_{epoch + 1}.h5')
        discriminator.save(f'wgan_discriminator_epoch_{epoch + 1}.h5')

# Save the final generator model
generator.save('wgan_generator_final.h5')
# Save the final discriminator model (optional for evaluation purposes)
discriminator.save('wgan_discriminator_final.h5')
```
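The ±1 labels in the training loop only make sense with a Wasserstein-style loss. The NumPy sketch below assumes the loss -mean(y_true · y_pred) and uses made-up critic scores in place of real model outputs. It checks two things:

- Summing the real-batch and fake-batch losses gives exactly the negated gap between the critic's mean scores, i.e. the negated Wasserstein estimate.
- The generator's loss with +1 labels is -mean D(G(z)).

```python
import numpy as np

def wasserstein_loss_np(y_true, y_pred):
    # Assumed Wasserstein-style loss: -mean(y_true * y_pred)
    return float(-np.mean(y_true * y_pred))

rng = np.random.default_rng(0)
batch_size = 128
# Made-up critic scores standing in for D(real) and D(G(z))
scores_real = rng.normal(1.0, 0.1, size=(batch_size, 1))
scores_fake = rng.normal(-0.5, 0.1, size=(batch_size, 1))
real_labels = np.ones((batch_size, 1))
fake_labels = -np.ones((batch_size, 1))

# Critic loss: loss on the real batch plus loss on the fake batch
critic_loss = wasserstein_loss_np(real_labels, scores_real) + wasserstein_loss_np(fake_labels, scores_fake)
# Critic's Wasserstein estimate: gap between mean scores
w_estimate = float(scores_real.mean() - scores_fake.mean())
# Generator loss with +1 labels; minimizing it raises the critic's scores on fakes
generator_loss = wasserstein_loss_np(real_labels, scores_fake)

print(f"critic loss: {critic_loss:.4f}, negated W estimate: {-w_estimate:.4f}")
print(f"generator loss: {generator_loss:.4f}, -mean D(G(z)): {-scores_fake.mean():.4f}")
```

Because the critic loss is the negated score gap, a critic loss that moves toward zero during training suggests the critic finds real and generated images harder to tell apart.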
Generating Images and Rating Them Using Discriminator
```python
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model

# Load the generator and discriminator models; compile=False skips restoring
# the custom Wasserstein loss, which is not needed for inference
generator = load_model('wgan_generator_epoch_200.h5', compile=False)
discriminator = load_model('wgan_discriminator_epoch_200.h5', compile=False)

# Generate a batch of random noise
batch_size = 144  # Adjust as needed (must equal rows * cols below)
noise = np.random.normal(0, 1, (batch_size, 100))

# Generate images using the generator (values in [-1, 1] from the tanh layer)
generated_images = generator.predict(noise, verbose=0)

# Rescale a copy of the generated images to the range [0, 1] for display
display_images = 0.5 * generated_images + 0.5

# Display the generated images
rows, cols = 12, 12  # Adjust as needed
fig, axs = plt.subplots(rows, cols, figsize=(12, 12))
fig.suptitle('Generated Images')
idx = 0
for i in range(rows):
    for j in range(cols):
        axs[i, j].imshow(display_images[idx].reshape(28, 28), cmap='gray')
        axs[i, j].axis('off')
        idx += 1
plt.show()

# Evaluate generated images with the discriminator, using the [-1, 1] range it was
# trained on. Scores are unbounded, not probabilities; higher means "more real".
discriminator_predictions = discriminator.predict(generated_images, verbose=0)

# Print the critic score for each generated image
for i in range(batch_size):
    print(f"Image {i + 1} - Critic score: {discriminator_predictions[i][0]:.4f}")
```
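A weight-clipped critic's scores have no fixed scale, so individual scores are hard to interpret on their own. Two hypothetical helpers (the names are illustrative) offer more useful views:

- Comparing the mean score on a generated batch with the mean score on a real batch gives the critic's Wasserstein estimate.
- Sorting the scores picks out the samples the critic rates highest.

The sketch assumes the +1-for-real label convention used in training, so higher scores mean "more real".

```python
import numpy as np

def critic_gap(real_scores, fake_scores):
    """Mean critic score on real images minus mean score on generated images."""
    return float(np.mean(real_scores) - np.mean(fake_scores))

def top_k_by_score(scores, k):
    """Indices of the k highest-scoring samples, best first."""
    scores = np.asarray(scores).reshape(-1)
    return np.argsort(scores)[::-1][:k]

# Made-up scores shaped like a Keras predict() output, i.e. (N, 1)
fake_scores = np.array([[0.02], [0.09], [-0.01], [0.05]])
real_scores = np.array([[0.10], [0.12], [0.08], [0.11]])
print(critic_gap(real_scores, fake_scores))
print(top_k_by_score(fake_scores, 2))  # [1 3]
```

With the trained models, `fake_scores` could be the critic scores computed above. `real_scores` could come from scoring a batch of real MNIST images preprocessed exactly as during training. A smaller gap suggests the generated batch is closer to the real data, as judged by this particular critic.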
Conclusion
In summary, Wasserstein Generative Adversarial Networks (WGANs) offer a more stable way to train generative models for data generation and representation learning. By optimizing an estimate of the Wasserstein distance, WGANs mitigate the instability and mode collapse often seen in traditional GANs.
Key advantages include more stable training, reduced mode collapse, and meaningful learning curves. WGANs find applications in diverse areas such as image generation, data augmentation, and creative pursuits like art generation. Despite these strengths, there are trade-offs. Enforcing the Lipschitz constraint has a cost, especially with a gradient penalty, and WGANs remain sensitive to hyperparameters.
The provided TensorFlow code demonstrates WGAN training on the MNIST dataset, offering a practical implementation for those diving into this domain. The final section demonstrates image generation and evaluation using the trained models, showcasing the tangible outcomes of WGANs in practice.