Transfer Learning with GAN (CycleGAN) from scratch

This article is a tutorial on building a CycleGAN with transfer learning from scratch, by yourself.

Seachaos
tree.rocks


figure: Transfer Learning GAN ( TFGAN )

We know transfer learning is powerful: it can reduce training time, give models that generalize better, reduce training memory, and so on.

So in this article, I will demonstrate how to build a GAN (CycleGAN) with transfer learning from scratch.

With transfer learning in CycleGAN:

  • You only need one generator.
  • No input labels are needed: the model learns by itself to output the image of the opposite set.
  • VRAM usage is reduced.
  • Training takes only a few hours.
  • GAN results are much more stable.

You will need a GPU with at least 16GB of VRAM for the best results.
However, you can reduce the batch size or the number of CNN filters to lower memory usage and fit GPUs with 12GB or 8GB of VRAM.

You can find the full code on my GitHub:
https://github.com/Seachaos/Tree.Rocks/blob/main/TransferLearningGAN/TransferLearningCycleGAN.ipynb

The ideas

CycleGAN idea

It uses two sets of images, which we call A and B. For example, set A could be many images of gray cats and set B many images of tabby cats.

The deep learning model learns the style and features of each image set and translates between them, so it can convert images from A to B and convert them back from B to A.

For more detail, you can read the CycleGAN tutorial on the TensorFlow website.

Figure: Concept of CycleGAN
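The "translate, then translate back" idea can be sketched in a few lines of NumPy. This is only a toy: the hypothetical functions `a_to_b` and `b_to_a` stand in for neural networks. The cycle-consistency loss is the mean squared error between an image and its round trip, and it is (near) zero only when the two translations undo each other.

```python
import numpy as np

def a_to_b(x):
    # Hypothetical A -> B "translator": shift brightness up
    return np.clip(x + 0.2, -1.0, 1.0)

def b_to_a(x):
    # Hypothetical B -> A "translator": shift brightness back down
    return np.clip(x - 0.2, -1.0, 1.0)

def cycle_loss(x, forward, backward):
    # Mean squared error between x and its round trip backward(forward(x))
    x_revert = backward(forward(x))
    return float(np.mean((x - x_revert) ** 2))

rng = np.random.default_rng(0)
# Toy batch of "A" images: 2 images, 8x8 pixels, RGB, values in [-0.5, 0.5)
x_a = rng.uniform(-0.5, 0.5, size=(2, 8, 8, 3))

good = cycle_loss(x_a, a_to_b, b_to_a)        # round trip restores x_a
bad = cycle_loss(x_a, a_to_b, lambda x: x)    # never translated back
print(good, bad)
```

Because no value reaches the clipping bounds here, the first loss is zero up to floating-point error and the second is the squared shift, 0.2² = 0.04.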

Transfer Learning Idea

Transfer learning uses an existing deep learning model (a pre-trained model) for another task. It is like extracting the features we need from an existing model and then using them in another model.

With this idea, we can use an image classifier model (such as VGG16) to build GAN models. Thanks to the image classifier model, the GAN model can easily understand the input and output images.

Figure: The concept of transfer learning for a GAN model: we use the input image layers of a pre-trained model in our GAN.

Combine together: Transfer Learning with CycleGAN

Now we leverage transfer learning by using the pre-trained model as the input layers of the GAN, so the model can easily recognize the input image and output the corresponding image without labels.

Figure: Concept of Transfer Learning with CycleGAN

Let’s take a closer look at our model.

Figure: more detail about transfer learning with GAN.

It feeds the input image to the classifier model to get “x_cmd”, which contains information about the source image, letting the model see the entire image at once. It then combines this information with another convolutional neural network (CNN), similar to U-Net, to produce the output.

Implement Transfer Learning with GAN

We use TensorFlow, an image input size of 128x128, and horse-to-zebra images as an example.
You need to install TensorFlow Datasets for the training data (skip this if it is already installed):

pip install tensorflow-datasets

First things first, let’s import what we need:

```python
import tensorflow as tf
from tensorflow.keras import layers
import tensorflow_datasets as tfds

import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import trange, tqdm

import random
```

Prepare Datasets

We use horse2zebra as our training dataset.
This section includes a lot of code for data preparation and augmentation.

Let’s get image sets A and B (horses and zebras):

```python
dataset, dataset_info = tfds.load('cycle_gan/horse2zebra', with_info=True, as_supervised=True)

train_a, train_b = dataset['trainA'], dataset['trainB']
test_a, test_b = dataset['testA'], dataset['testB']
```

Now, let’s set up some variables we will use.
If your GPU has less than 16GB of VRAM, try reducing the batch size to fit your GPU (you may also need to train for more epochs).

```python
batch_size = 32  # set to 16 or less if you don't have enough VRAM

img_size = 128
big_img_size = 192

LR = 0.00012
```

  • big_img_size: images are first resized to this size, then randomly cropped and resized to img_size for augmentation.
  • LR: learning rate

Then we extract the training and test datasets:

```python
def _process_img(image, label):
    image = tf.image.resize(image, (big_img_size, big_img_size))
    image = (image / 127.5) - 1.0
    return image, label

def prepare_data(data, b=batch_size):
    return data \
        .cache() \
        .map(_process_img, num_parallel_calls=tf.data.AUTOTUNE) \
        .shuffle(b) \
        .batch(b)

ds_train_a, ds_train_b = prepare_data(train_a), prepare_data(train_b)
ds_test_a, ds_test_b = prepare_data(test_a), prepare_data(test_b)


x_train_sets = [
    tf.concat([a[0] for a in ds_train_a], axis=0),
    tf.concat([b[0] for b in ds_train_b], axis=0),
]

x_test_sets = [
    tf.concat([a[0] for a in ds_test_a], axis=0),
    tf.concat([b[0] for b in ds_test_b], axis=0),
]

print('x_train_all: ', sum([s.shape[0] for s in x_train_sets]), x_train_sets[0].numpy().min(), x_train_sets[0].numpy().max())
print('x_test_all: ', sum([s.shape[0] for s in x_test_sets]), x_test_sets[0].numpy().min(), x_test_sets[0].numpy().max())
```

You should see output like this:

```
x_train_all:  2401 -1.0 1.0
x_test_all:  260 -1.0 1.0
```
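The -1.0 / 1.0 range comes from `(image / 127.5) - 1.0` in `_process_img`, which maps 8-bit pixels in [0, 255] to [-1, 1], the same range as the generator’s `tanh` output. `cvtImg` maps values back to [0, 1] for `plt.imshow`. A small NumPy check of both directions, using the same formulas with a few hand-picked pixel values:

```python
import numpy as np

def to_model_range(img_uint8):
    # Same scaling as _process_img: [0, 255] -> [-1, 1]
    return img_uint8.astype(np.float32) / 127.5 - 1.0

def cvt_img(x):
    # Same formula as cvtImg: [-1, 1] -> [0, 1] for plt.imshow
    return (x + 1.0) / 2.0

pixels = np.array([0, 64, 128, 255], dtype=np.uint8)
x = to_model_range(pixels)
# Undo both steps and round back to 8-bit pixel values
back = np.rint(cvt_img(x) * 255).astype(np.uint8)
print(x.min(), x.max())
print(back)
```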

Next, we need image augmentation for training, so we will use two functions: “get_x_train” and “get_x_test”.

These functions give us a batch from image set A and a batch from image set B (“get_x_test” does no augmentation).

```python
def _rand_pick(data, augment=True):
    idx = np.random.choice(range(len(data)), size=batch_size, replace=False)
    x = tf.gather(data, idx, axis=0)
    if augment:
        cx = random.uniform(1.0, 1.5)
        cy = random.uniform(1.0, 1.5)
        x = tf.image.random_crop(x, size=(batch_size, int(img_size * cx), int(img_size * cy), 3))
        x = tf.image.random_flip_left_right(x)
    x = tf.image.resize(x, (img_size, img_size))
    return x

def get_x_train():
    xa = _rand_pick(x_train_sets[0])
    xb = _rand_pick(x_train_sets[1])
    return xa, xb

def get_x_test():
    xa = _rand_pick(x_test_sets[0], augment=False)
    xb = _rand_pick(x_test_sets[1], augment=False)
    return xa, xb
```
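In `_rand_pick`, each crop side is between `img_size * 1.0` and `img_size * 1.5`, i.e. 128 to 192 pixels. Since `big_img_size` is 192, the crop always fits inside the resized image. The crop is then flipped horizontally at random and resized back to 128x128. A NumPy sketch of the same steps for a single image, assuming nearest-neighbor resizing as a simple stand-in for `tf.image.resize` (which is bilinear by default):

```python
import random
import numpy as np

img_size, big_img_size = 128, 192

def nearest_resize(x, size):
    # Nearest-neighbor resize of an (H, W, C) array to (size, size, C)
    rows = np.arange(size) * x.shape[0] // size
    cols = np.arange(size) * x.shape[1] // size
    return x[rows][:, cols]

def augment(x):
    # Random crop: each side between img_size and 1.5 * img_size (= big_img_size)
    ch = int(img_size * random.uniform(1.0, 1.5))
    cw = int(img_size * random.uniform(1.0, 1.5))
    top = random.randint(0, x.shape[0] - ch)
    left = random.randint(0, x.shape[1] - cw)
    x = x[top:top + ch, left:left + cw]
    # Random horizontal flip with probability 0.5
    if random.random() < 0.5:
        x = x[:, ::-1]
    # Back to the model input size
    return nearest_resize(x, img_size)

image = np.random.default_rng(0).uniform(-1, 1, size=(big_img_size, big_img_size, 3))
out = augment(image)
print(out.shape)  # (128, 128, 3)
```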

Let’s verify that “get_x_train” works.

```python
# Verify "get_x_train" output
def cvtImg(x):
    return (x + 1.0) / 2.0

def show(x, S=12):
    x = cvtImg(x)
    plt.figure(figsize=(15, 3))
    for i in range(min(len(x), S)):
        plt.subplot(1, S, i + 1)
        plt.imshow(x[i])
        plt.axis('off')
    plt.show()

for _ in range(1):
    xa, xb = get_x_train()
    xa = xa.numpy()
    print(xa.min(), xa.max(), xa.shape)
    show(xa)
    show(xb.numpy())
```

You should see image output like this:

figure: testing “get_x_train”

1. Build Model — Transfer Learning

Now we start building the GAN model. But first, we need the image input layers, so let’s extract them from the VGG16 model.

We will use the layers “block2_conv2”, “block3_conv3”, “block4_conv3”… from the VGG16 model.
Their outputs are 64x64, 32x32, 16x16 … in size.

You can use “base_model.summary()” to see more detail about the VGG16 model.

```python
base_model = tf.keras.applications.VGG16(input_shape=(img_size, img_size, 3), include_top=False)

x = x_input = base_model.input

outputs = [
    'block2_conv2',
    'block3_conv3',
    'block4_conv3',
    'block5_conv1',
    'block5_pool',
]

x_output = [base_model.get_layer(n).output for n in outputs]
base_model = tf.keras.models.Model(x_input, x_output)

base_model.trainable = False

# base_model.summary()  # if you want to see more detail about VGG16
```
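With a 128x128 input, every VGG16 block ends in a 2x2 max-pool, so a layer’s spatial size is 128 / 2^(number of pools before it). A small pure-Python check of the shapes of the five layers used above; the channel counts are those of the standard VGG16 architecture:

```python
img_size = 128

# (layer name, max-pools applied before this layer's output, channels) in VGG16
vgg_layers = [
    ("block2_conv2", 1, 128),
    ("block3_conv3", 2, 256),
    ("block4_conv3", 3, 512),
    ("block5_conv1", 4, 512),
    ("block5_pool", 5, 512),
]

# Output shape (H, W, C) of each layer for a square img_size input
shapes = {name: (img_size // 2 ** pools, img_size // 2 ** pools, ch)
          for name, pools, ch in vgg_layers}

for name, shape in shapes.items():
    print(f"{name:13s} {shape}")
```

These five shapes are what the generator below unpacks as `[x64, x32, x16, x8, x4]`.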

2. Build Model — Generator

We use GELU as our activation function. For convenience, we define the “act” function, which applies normalization and then the activation function.

```python
act_name = 'gelu'

def act(x):
    x = layers.LayerNormalization()(x)
    x = layers.Activation(act_name)(x)
    return x
```

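This is the layer function for the generator model.
It takes “x_cmd”, the information the model observed from the whole input image, turns it into a per-channel sigmoid gate, and multiplies the convolution features with it to decide what should be output.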

```python
def conv_with_cmd(x_img_input, x_cmd, f=64, sp=4):
    x = layers.Dense(128)(x_cmd)
    x = layers.BatchNormalization()(x)
    x = layers.Activation(act_name)(x)

    x = layers.Dense(f)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('sigmoid')(x)

    x_g = layers.Reshape((1, 1, f))(x)

    # ---

    x = layers.Conv2D(f, kernel_size=3, padding='same')(x_img_input)
    x = layers.BatchNormalization()(x)
    x = layers.Activation(act_name)(x)
    x = x * x_g

    return x
```
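Because `x_g` has shape (1, 1, f) and the convolution output has shape (H, W, f), the multiplication broadcasts: every channel of the feature map is scaled by one value in (0, 1) chosen from the whole-image summary. A NumPy sketch of that path, assuming random weights, a single Dense layer instead of the two in `conv_with_cmd`, and no normalization layers:

```python
import numpy as np

rng = np.random.default_rng(0)
f = 64  # number of conv filters, as in conv_with_cmd(f=64)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Toy 4x4x512 feature map standing in for the VGG16 "x4" output
x4 = rng.normal(size=(4, 4, 512))
# GlobalMaxPool2D: max over height and width -> one 512-d summary vector
x_cmd = x4.max(axis=(0, 1))

# Dense(f) + sigmoid with random weights: one gate value in (0, 1) per channel
w = rng.normal(scale=0.05, size=(512, f))
gate = sigmoid(x_cmd @ w).reshape(1, 1, f)  # like layers.Reshape((1, 1, f))

# Conv2D output on some (H, W, f) feature map; broadcasting scales each channel
conv_out = rng.normal(size=(16, 16, f))
gated = conv_out * gate

print(gated.shape)  # (16, 16, 64)
```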

Now we build the generator model.

  • The input image x_input goes through base_model (VGG16), which outputs [x64, x32, x16, x8, x4].
  • “x_cmd” comes from the last VGG16 output (x4, a 4x4 feature map); we use GlobalMaxPool2D and Dense to extract information from it.
  • The rest of the model is like U-Net: it uses “UpSampling2D” and “Concatenate” with the base_model outputs.
    It goes from x4 up to x8, x8 up to x16, and so on until the output size, each step with x_cmd information.
  • If you don’t have enough VRAM, you can try reducing the number of CNN filters, but the results may be worse.
```python
def create_gen_model():
    # img input
    x_input = layers.Input(shape=(img_size, img_size, 3))

    # load base model
    x_base_out = base_model(x_input)
    [x64, x32, x16, x8, x4] = x_base_out

    # x_cmd
    x = x4
    x = layers.Conv2D(256, kernel_size=3, padding='same')(x)
    x = act(x)

    x = layers.GlobalMaxPool2D()(x)

    x = layers.Dense(128)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation(act_name)(x)
    x_cmd = x

    # GAN up
    x = conv_with_cmd(x4, x_cmd, f=512)

    # if you don't have enough VRAM, try reducing the filters
    for i, (x_cat, f) in enumerate([
        (x8, 512),
        (x16, 384),
        (x32, 256),
        (x64, 256),
        (x_input, 256),
    ]):
        # minimal loop body following the description above:
        # upsample 2x, concatenate with the same-size tensor, mix in x_cmd
        x = layers.UpSampling2D()(x)
        x = layers.Concatenate()([x, x_cat])
        x = conv_with_cmd(x, x_cmd, f=f)

    # final output
    x = layers.Conv2D(3, kernel_size=3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('tanh')(x)

    return tf.keras.models.Model(x_input, x)

gen = create_gen_model()
# gen.summary()  # if you want to see more detail about the model
```
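To see how the U-Net-like decoder grows, here is a pure-Python trace of the tensor shapes, assuming each stage is UpSampling2D (2x), Concatenate with the same-size tensor, then `conv_with_cmd` with the listed filter count. The skip-tensor channel counts are those of VGG16 (and 3 for the RGB input).

```python
# (skip tensor name, skip shape (H, W, C), filters f passed to conv_with_cmd)
stages = [
    ("x8", (8, 8, 512), 512),
    ("x16", (16, 16, 512), 384),
    ("x32", (32, 32, 256), 256),
    ("x64", (64, 64, 128), 256),
    ("x_input", (128, 128, 3), 256),
]

h, w, c = 4, 4, 512  # after conv_with_cmd(x4, x_cmd, f=512)
trace = []
for name, (sh, sw, sc), f in stages:
    h, w = h * 2, w * 2               # UpSampling2D doubles height and width
    assert (h, w) == (sh, sw), name   # sizes must match to concatenate
    concat_shape = (h, w, c + sc)     # Concatenate along the channel axis
    c = f                             # conv_with_cmd outputs f channels
    trace.append((name, concat_shape, (h, w, c)))
    print(f"{name:8s} concat {concat_shape} -> {(h, w, c)}")

final_shape = (h, w, 3)  # final Conv2D(3) + tanh
print("output", final_shape)
```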

3. Build Model — Discriminator

Let’s make the Discriminator model.

  • It uses the same base_model (VGG16).
  • It uses x8 and x4 from base_model.
  • It has a softmax output, like a classifier.

We will explain the output later.

```python
def create_dis_model():
    x = x_input = layers.Input(shape=(img_size, img_size, 3))

    [x64, x32, x16, x8, x4] = base_model(x_input)

    x = x8
    x = layers.Conv2D(512, kernel_size=3, padding='same')(x)
    x = act(x)
    x = layers.MaxPool2D()(x)

    x = layers.Concatenate()([x, x4])
    x = layers.Conv2D(512, kernel_size=3, padding='same')(x)
    x = act(x)

    x = layers.GlobalMaxPool2D()(x)

    x = layers.Dense(384)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation(act_name)(x)

    x = layers.Dense(128)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation(act_name)(x)

    x = layers.Dense(4)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('softmax')(x)

    return tf.keras.models.Model(x_input, x)

dis = create_dis_model()
# dis.summary()  # if you want to see more detail about the model
```

Here we define the target labels for training.
The Discriminator has 4 output classes:

  • false A (y_false_a): all 0 (fake image A)
  • false B (y_false_b): all 1 (fake image B)
  • true A (y_true_a): all 2 (real image A)
  • true B (y_true_b): all 3 (real image B)

Each label array has length batch_size.

```python
y_false_a = np.zeros(batch_size)
y_false_b = np.full_like(y_false_a, 1)
y_true_a = np.full_like(y_false_a, 2)
y_true_b = np.full_like(y_false_a, 3)
```
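Because the discriminator ends in a 4-way softmax and the labels are class indices, it is trained with sparse categorical cross-entropy: the loss for one sample is -log of the probability assigned to its label. A NumPy sketch with made-up softmax outputs (not real model predictions):

```python
import numpy as np

# Class indices, matching y_false_a, y_false_b, y_true_a, y_true_b
FAKE_A, FAKE_B, REAL_A, REAL_B = 0, 1, 2, 3

def sparse_categorical_crossentropy(labels, probs):
    # Per-sample loss: -log(probability given to the correct class)
    return -np.log(probs[np.arange(len(labels)), labels])

y_true_a = np.full(4, REAL_A)

# Hypothetical softmax outputs for 4 real "A" images
probs = np.array([
    [0.10, 0.10, 0.70, 0.10],  # fairly confident and correct
    [0.25, 0.25, 0.25, 0.25],  # undecided
    [0.05, 0.05, 0.10, 0.80],  # thinks it is a real B
    [0.00, 0.00, 1.00, 0.00],  # perfectly correct
])
losses = sparse_categorical_crossentropy(y_true_a, probs)
print(losses.round(3))
print(losses.mean().round(3))
```

Wrong, confident predictions (third row) cost the most; a perfect prediction costs nothing.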

And here are the optimizers for the dis and gen models.
We use AdamW as the optimizer.

```python
opt_gen = tf.keras.optimizers.AdamW(learning_rate=LR)
opt_dis = tf.keras.optimizers.AdamW(learning_rate=LR)
```

4. Training Model — Discriminator

Now that we have the generator model (gen) and the discriminator model (dis), we write the `train_dis` and `train_gen` functions first, and then a `train` function that runs them both.

Let’s look at the discriminator training code first:

```python
@tf.function
def _train_dis(x, y):
    # Teach the discriminator that every image in x belongs to class y
    # (written to follow the steps explained below)
    with tf.GradientTape() as tape:
        y_p = dis(x)
        loss = tf.reduce_mean(tf.losses.sparse_categorical_crossentropy(y, y_p))

    g = tape.gradient(loss, dis.trainable_variables)
    opt_dis.apply_gradients(zip(g, dis.trainable_variables))

    return loss

def train_dis():
    dis.trainable = True
    gen.trainable = False
    base_model.trainable = False

    xa, xb = get_x_train()

    # fake images from the generator: B -> fake A, A -> fake B
    xa_fake = gen.predict(xb, verbose=False)
    xb_fake = gen.predict(xa, verbose=False)

    loss_a = _train_dis(xa, y_true_a) + _train_dis(xa_fake, y_false_a)
    loss_b = _train_dis(xb, y_true_b) + _train_dis(xb_fake, y_false_b)

    return float(loss_a), float(loss_b)

train_dis()
```

Explanation:

  • We need to disable training for every model except the discriminator, so we use:
    dis.trainable = True
    gen.trainable = False
    base_model.trainable = False
  • Get xa and xb as training images (image set A and image set B).
  • Use the generator to get `xa_fake` with `gen.predict(xb, verbose=False)`.
  • Call `_train_dis(xa, y_true_a)` to teach the discriminator that source image A is a real image (y_true_a).
  • Call `_train_dis(xa_fake, y_false_a)` to teach the discriminator that generated image A is a fake image (y_false_a).
  • Do the same for xb and xb_fake.
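The label bookkeeping is easy to mix up, so here it is as a small pure-Python table using the label values defined earlier. The discriminator is taught the true class of every image, while the generator (next section) is rewarded when its fakes are classified as real images of the other set.

```python
FAKE_A, FAKE_B, REAL_A, REAL_B = 0, 1, 2, 3

# Image fed to dis -> label used in train_dis
dis_targets = {
    "xa (real A)": REAL_A,          # _train_dis(xa, y_true_a)
    "gen(xb) (fake A)": FAKE_A,     # _train_dis(xa_fake, y_false_a)
    "xb (real B)": REAL_B,          # _train_dis(xb, y_true_b)
    "gen(xa) (fake B)": FAKE_B,     # _train_dis(xb_fake, y_false_b)
}

# Image fed to dis -> label used in train_gen: fakes should pass as real
gen_targets = {
    "gen(xa) (fake B)": REAL_B,     # _train_gen_cycle(xa, y_true_b, ...)
    "gen(xb) (fake A)": REAL_A,     # _train_gen_cycle(xb, y_true_a, ...)
}

for name, label in gen_targets.items():
    print(f"{name}: dis is trained toward {dis_targets[name]}, gen toward {label}")
```

The two models pull the same images toward different labels, which is the adversarial game.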

5. Training Model — Generator

Now the code for training the Generator:

```python
@tf.function
def _train_gen_cycle(x_real, y_t, y_f):
    with tf.GradientTape(persistent=True) as tape:
        x_fake = gen(x_real)  # forward

        # discriminator
        y_p = dis(x_fake)
        loss_dis = tf.losses.sparse_categorical_crossentropy(y_t, y_p)

        # revert
        x_revert = gen(x_fake)
        loss_revert = tf.losses.mse(x_real, x_revert)

        loss = tf.reduce_mean(loss_dis) + tf.reduce_mean(loss_revert)

    g = tape.gradient(loss, gen.trainable_variables)
    g = zip(g, gen.trainable_variables)
    opt_gen.apply_gradients(g)

    # return the tensor; train_gen converts it with float() outside tf.function
    return loss

def train_gen():
    gen.trainable = True
    dis.trainable = False
    base_model.trainable = False

    xa, xb = get_x_train()

    loss_a = \
        _train_gen_cycle(xa, y_true_b, y_true_a)

    loss_b = \
        _train_gen_cycle(xb, y_true_a, y_true_b)

    return float(loss_a), float(loss_b)

train_gen()
```

Explanation:

  • As in discriminator training, we disable training for every model except the generator, so we use:
    gen.trainable = True
    dis.trainable = False
    base_model.trainable = False
  • Get xa and xb as training images (image set A and image set B).
  • The function _train_gen_cycle takes (source image, y true, y false); y false is not used in the loss.

In the `_train_gen_cycle` function:

  • We use the gen model to generate a fake image, feed it to dis (the discriminator) to get the output (y_p), and backpropagate to the gen model as if the fake were a real image (cross-entropy loss between y_t and y_p).
  • We use gen to translate the fake image back, e.g. A->B and then B->A; the result should be the same as the input (MSE loss).
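Putting both terms together, the generator loss for one direction is the mean cross-entropy (the fakes should be classified as the real target class) plus the mean squared reconstruction error, weighted 1:1 as in `_train_gen_cycle`. A NumPy sketch with toy arrays standing in for `dis(x_fake)` and `gen(x_fake)`; all values are made up:

```python
import numpy as np

rng = np.random.default_rng(0)

def generator_loss(y_target, dis_probs, x_real, x_revert):
    # Adversarial term: -log p(target class), averaged over the batch
    adv = -np.log(dis_probs[np.arange(len(y_target)), y_target]).mean()
    # Cycle term: MSE between the input and its round trip
    rec = np.mean((x_real - x_revert) ** 2)
    return adv + rec, adv, rec

batch = 2
y_true_b = np.full(batch, 3)                      # fake B should look like real B
dis_probs = np.array([[0.1, 0.2, 0.1, 0.6],
                      [0.3, 0.3, 0.2, 0.2]])      # toy discriminator softmax output
x_real = rng.uniform(-1, 1, size=(batch, 8, 8, 3))
x_revert = x_real + 0.1                           # toy imperfect reconstruction

total, adv, rec = generator_loss(y_true_b, dis_probs, x_real, x_revert)
print(f"adv={adv:.3f} rec={rec:.3f} total={total:.3f}")
```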

6. Preview — Before total training

Let’s preview the output before training starts:

```python
def _preview(x_real, title=None):
    x_fake = gen.predict(x_real, verbose=0)
    x_real = cvtImg(x_real.numpy())
    x_fake = cvtImg(x_fake)

    plt.figure(figsize=(25, 5))
    if title:
        plt.suptitle(title)
    s = min(batch_size, 9)
    for i in range(s):
        plt.subplot(2, s, i + 1)
        plt.axis('off')
        plt.imshow(x_real[i])
        plt.subplot(2, s, i + 1 + s)
        plt.axis('off')
        plt.imshow(x_fake[i])
    plt.show()

def preview(useTest=True):
    if useTest:
        xa, xb = get_x_test()
    else:
        xa, xb = get_x_train()
    _preview(xa[:9], 'A -> B')
    _preview(xb[:9], 'B -> A')

preview()
```

You should see output like this (results may differ):

figure: before total training.

7. Total Training

Here we run the full training:

```python
def train():
    bar = trange(200)
    for _ in bar:
        lda, ldb = train_dis()
        lga, lgb = train_gen()
        msg = f'gen: {lga:.5f}, {lgb:.5f} | dis: {lda:.5f}, {ldb:.5f}'
        bar.set_description(msg)

def go():
    for i in trange(50):
        train()
        if i % 5 == 0:
            preview()

        opt_dis.learning_rate = opt_dis.learning_rate * 0.98
        opt_gen.learning_rate = opt_gen.learning_rate * 0.98
        lg = opt_gen.learning_rate.numpy()
        ld = opt_dis.learning_rate.numpy()
        print(f'run: {i}')
        print(f'LR gen: {lg:.7f}')
        print(f'LR dis: {ld:.7f}')


go()
preview()
```

We reduce the learning rate during training.
Training may take a few hours, depending on your GPU.
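Since `go()` multiplies both learning rates by 0.98 after each of its 50 rounds, the final rate is LR × 0.98^50, roughly 36% of the starting value. A quick check of that schedule, and of the total number of training iterations (each round calls `train()`, which runs 200 discriminator + generator steps):

```python
LR = 0.00012
rounds = 50
decay = 0.98
steps_per_round = 200  # trange(200) in train()

# Learning rate at the start, and after each of the 50 decays
schedule = [LR * decay ** i for i in range(rounds + 1)]
ratio = schedule[-1] / LR
total_steps = rounds * steps_per_round

print(f"start: {schedule[0]:.7f}")
print(f"end:   {schedule[-1]:.7f} ({ratio:.1%} of LR)")
print(f"train_dis + train_gen iterations: {total_steps}")
```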

During training, you will see preview output every 5 rounds.

You may get results like:

That’s all :D
