"""
Python code for conditional generative adversarial network
ver. 20230310
coded by Xiaoyang Zheng
Copyright (C) Xiaoyang Zheng and Ikumu Watanabe
Email: ZHENG.Xiaoyang@nims.go.jp; WATANABE.Ikumu@nims.go.jp
"""

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np
import tensorflow as tf
from PIL import Image
import glob, datetime
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from solver import Solver
import math


## TensorBoard logging: one timestamped run directory per execution.
# NOTE: the writer is created at import time (side effect); `summary_writer`
# is consumed by the per-epoch scalar logging inside main().
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = 'logs/' + current_time
summary_writer = tf.summary.create_file_writer(log_dir)


## for circular padding
def pad_dim(x, n=1):
    """Circularly (wrap-around) pad both spatial axes of an NHWC tensor by n.

    The last n columns/rows are copied to the front and the first n to the
    back, so a following 'valid' convolution behaves periodically.
    """
    wrapped_w = tf.concat((x[:, :, -n:, :], x, x[:, :, :n, :]), axis=-2)
    wrapped_hw = tf.concat((wrapped_w[:, -n:, :, :], wrapped_w, wrapped_w[:, :n, :, :]), axis=1)
    return wrapped_hw


def delete_dim(x):
    """Crop 3 rows/columns from every spatial border of an NHWC tensor.

    Equivalent to x[:, 3:-3, 3:-3, :]; used to trim the excess border left
    by the padded transposed convolutions in the generator.
    """
    interior = slice(3, -3)
    return x[:, interior, interior, :]


## architecture of generator
class Generator(keras.Model):
    """Transposed-CNN generator: (noise, condition) -> (b, 256, 256, 1) in [-1, 1]."""

    def __init__(self):
        super(Generator, self).__init__()

        # Dense stem projecting the concatenated (noise, condition) vector
        # to a 4x4x512 feature map.
        self.fc = layers.Dense(4*4*512)
        self.bn0 = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

        # Six stride-2 transposed convolutions; combined with the circular
        # pad/crop in call(), each stage exactly doubles the spatial size:
        # 4 -> 8 -> 16 -> 32 -> 64 -> 128 -> 256.
        self.conv1 = layers.Conv2DTranspose(384, kernel_size=[4,4], strides=[2,2], padding='valid')
        self.bn1 = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

        self.conv2 = layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding='valid')
        self.bn2 = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

        self.conv3 = layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding='valid')
        self.bn3 = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

        self.conv4 = layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding='valid')
        self.bn4 = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

        self.conv5 = layers.Conv2DTranspose(32, kernel_size=4, strides=2, padding='valid')
        self.bn5 = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

        # Final single-channel stage; no batch norm before the tanh output.
        self.conv6 = layers.Conv2DTranspose(1, kernel_size=4, strides=2, padding='valid')


    def call(self, inputs_noise, input_condition, training=None):
        """Map latent noise (b, 128) and a condition vector (b, 2) to an image."""
        noise = tf.cast(inputs_noise, dtype=tf.float32)
        cond = tf.cast(input_condition, dtype=tf.float32)

        net = self.fc(tf.concat((noise, cond), axis=-1))  # (b, 4*4*512)
        net = self.bn0(net, training=training)
        net = tf.reshape(net, [-1, 4, 4, 512])            # (b, 4, 4, 512)

        # Circular pad -> deconv -> BN -> leaky ReLU -> crop, five times.
        up_stages = ((self.conv1, self.bn1), (self.conv2, self.bn2),
                     (self.conv3, self.bn3), (self.conv4, self.bn4),
                     (self.conv5, self.bn5))
        for deconv, bn in up_stages:
            net = pad_dim(net, n=1)
            net = tf.nn.leaky_relu(bn(deconv(net), training=training))
            net = delete_dim(net)

        # Last stage without batch norm, squashed into [-1, 1].
        net = delete_dim(self.conv6(pad_dim(net)))

        return tf.tanh(net)  # (b, 256, 256, 1)


## architecture of discriminator
class Discriminator(keras.Model):
    """CNN discriminator mapping a (b, 256, 256, 1) image to one real/fake logit.

    Each conv runs with 'valid' padding on a circularly pre-padded input
    (pad_dim), so every stride-2 stage exactly halves the spatial size:
    256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        self.conv1 = layers.Conv2D(16, kernel_size=4, strides=2, padding='valid')   # => (b, 128, 128, 16)
        self.conv2 = layers.Conv2D(32, kernel_size=4, strides=2, padding='valid')   # => (b, 64, 64, 32)
        self.conv3 = layers.Conv2D(64, kernel_size=4, strides=2, padding='valid')   # => (b, 32, 32, 64)
        self.conv4 = layers.Conv2D(128, kernel_size=4, strides=2, padding='valid')  # => (b, 16, 16, 128)
        self.conv5 = layers.Conv2D(256, kernel_size=4, strides=2, padding='valid')  # => (b, 8, 8, 256)
        self.conv6 = layers.Conv2D(512, kernel_size=4, strides=2, padding='valid')  # => (b, 4, 4, 512)

        # BUGFIX: build Dropout once here instead of constructing a new
        # layers.Dropout(0.3) on every forward pass inside call(). Layers
        # created inside call() are rebuilt on each invocation and the
        # `training` flag was never passed to them explicitly. Dropout holds
        # no weights, so checkpoints saved by the old code still load.
        self.dropout = layers.Dropout(0.3)

        self.flatten = layers.Flatten()
        self.fc = layers.Dense(1)


    def call(self, inputs_img, training=None):
        """Return (b, 1) logits for a batch of images.

        `training` gates both batch behavior of dropout (active only while
        training) — it is now forwarded explicitly to every dropout call.
        """
        x = tf.cast(inputs_img, dtype=tf.float32)

        # pad -> conv -> leaky ReLU -> dropout, six times: (b,256,256,1) -> (b,4,4,512)
        for conv in (self.conv1, self.conv2, self.conv3,
                     self.conv4, self.conv5, self.conv6):
            x = pad_dim(x)
            x = self.dropout(tf.nn.leaky_relu(conv(x)), training=training)

        net = self.flatten(x)  # (b, 4*4*512)
        net = self.fc(net)

        return net


## draw figures for comparing the input and output values
def new_bi_plot():
    """Create two square prediction-vs-reference panels and return their axes.

    Left panel: Young's modulus, limits [1, 15]; right panel: Poisson's
    ratio, limits [-0.4, 0.5]. Each panel carries a gray y = x reference
    line for reading prediction accuracy at a glance.
    """
    fig, axes = plt.subplots(1, 2, figsize=[8, 4], constrained_layout=True)
    panels = (([1, 15], "Young's modulus [Pa]"),
              ([-0.4, 0.5], "Poisson's ratio"))

    for ax, (lims, title) in zip(axes, panels):
        ax.plot(lims, lims, 'gray')  # identity line
        ax.set_xlim(lims)
        ax.set_ylim(lims)
        ax.set_aspect(1)
        ax.set_xlabel("Prediction")
        ax.set_ylabel("Reference")
        ax.set_title(title)

    return axes[0], axes[1]


def bi_plot(pre, ref, ax1, ax2):
    """Scatter predicted vs. reference properties onto the bi-plot axes.

    Column 0 is rescaled by x*12+2 (Young's modulus axis) and column 1 by
    x*0.7-0.3 (Poisson's ratio axis), matching the limits in new_bi_plot.
    """
    young_pre = pre[:, 0] * 12 + 2
    young_ref = ref[:, 0] * 12 + 2
    poi_pre = pre[:, 1] * 0.7 - 0.3
    poi_ref = ref[:, 1] * 0.7 - 0.3
    ax1.scatter(young_pre, young_ref, c='C0', alpha=0.2)
    ax2.scatter(poi_pre, poi_ref, c='C1', alpha=0.2)


## generate the images of generated 2D geometries
def generate_and_save_images(fake_image, epoch):
    """Tile the generated geometries into a 4x4 grid and save them as a PNG.

    Pixel values in [-1, 1] are mapped back to [0, 255] via x*127.5+127.5;
    the file goes to fake_images/image_at_epoch_XXXX.png.
    """
    plt.figure(figsize=(4, 4), dpi=300)

    for idx in range(fake_image.shape[0]):
        plt.subplot(4, 4, idx + 1)
        plt.imshow(fake_image[idx, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    plt.savefig(r'fake_images/image_at_epoch_{:04d}.png'.format(epoch))
    plt.close()


def celoss_ones(logits):
    """Sigmoid cross-entropy pushing `logits` toward the "one" label.

    Label smoothing: targets are drawn uniformly from [0.7, 1.2]
    (1 - 0.3 + U[0,1]*0.5) instead of exactly 1.0.

    NOTE(review): np.random.uniform(size=logits.shape) requires a fully
    known static shape; this would fail under @tf.function with a dynamic
    batch dimension — confirm eager-only usage.
    """
    labels = tf.ones_like(logits) - 0.3 + np.random.uniform(size=logits.shape) * 0.5
    # The reduce_mean already yields a scalar; the original wrapped it in a
    # second redundant reduce_mean (identical result), removed here along
    # with dead commented-out code.
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))


def celoss_zeros(logits):
    """Sigmoid cross-entropy pushing `logits` toward the "zero" label.

    Label smoothing: targets are drawn uniformly from [0.0, 0.3] instead of
    exactly 0.0, mirroring celoss_ones.

    NOTE(review): np.random.uniform(size=logits.shape) requires a fully
    known static shape; this would fail under @tf.function with a dynamic
    batch dimension — confirm eager-only usage.
    """
    labels = tf.zeros_like(logits) + np.random.uniform(size=logits.shape) * 0.3
    # The reduce_mean already yields a scalar; the original wrapped it in a
    # second redundant reduce_mean (identical result), removed here along
    # with dead commented-out code.
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))


def gradient_penalty(discriminator, real_seeds, fake_seeds):
    """WGAN-GP style penalty on images interpolated between real and fake.

    Args:
        discriminator: model mapping an image batch to logits.
        real_seeds: batch of real images (static shape required by get_shape()).
        fake_seeds: generated batch of the same shape.
    Returns:
        Scalar tensor E[(||grad_x D(x_hat)||_2 - 1)^2].
    """
    # NOTE(review): alpha is sampled element-wise (full image shape), not one
    # scalar per sample as in the original WGAN-GP formulation — confirm intended.
    alpha = tf.random.uniform(shape=real_seeds.get_shape(), minval=0., maxval=1.)
    differences = fake_seeds - real_seeds  # This is different from MAGAN
    interpolates = real_seeds + (alpha * differences)
    with tf.GradientTape() as tape:
        # interpolates is a plain tensor (not a Variable), so it must be
        # watched explicitly for the tape to record gradients w.r.t. it.
        tape.watch([interpolates])
        d_interplote_logits = discriminator(interpolates, training=True)
    grads = tape.gradient(d_interplote_logits, interpolates)

    # flatten per-sample gradients: (b, H, W, C) => (b, -1)
    grads = tf.reshape(grads, [grads.shape[0], -1])
    gp = tf.norm(grads, axis=1)  # per-sample L2 norm, shape (b,)
    gp = tf.reduce_mean((gp - 1) ** 2)
    return gp


def d_loss_fn(generator, discriminator, batch_z, batch_x, condition, is_training):
    """Discriminator loss: smoothed cross-entropy on real and fake logits plus
    a gradient penalty (weight 1).

    NOTE(review): the label convention is inverted relative to the usual DCGAN
    setup — real logits are pushed toward the smoothed "zeros" band and fake
    logits toward the smoothed "ones" band. g_loss_fn uses the same inverted
    convention (celoss_zeros on fakes), so training is self-consistent, but
    confirm this was intentional.

    Returns:
        (total discriminator loss, gradient-penalty term) as scalar tensors.
    """
    # 1. treat real image as real
    # 2. treat generated image as fake
    fake_image = generator(batch_z, condition, is_training)
    d_fake_logits = discriminator(fake_image, is_training)
    d_real_logits = discriminator(batch_x, is_training)

    d_loss_real = celoss_zeros(d_real_logits)
    d_loss_fake = celoss_ones(d_fake_logits)
    gp = gradient_penalty(discriminator, batch_x, fake_image)

    loss = d_loss_real + d_loss_fake + 1 * gp

    return loss, gp


def g_loss_fn(generator, discriminator, solver, batch_z, condition, is_training):
    """Generator loss: adversarial term plus 0.1 * property-matching MSE.

    The solver network predicts material properties of the generated
    geometry; the MSE term penalizes deviation from the requested condition.

    Returns:
        (adversarial loss + 0.1 * mse, mse) as scalar tensors.
    """
    fake_image = generator(batch_z, condition, is_training)
    d_fake_logits = discriminator(fake_image, is_training)
    # celoss_zeros matches the inverted label convention used in d_loss_fn:
    # "zeros" is the real-side target band there, so the generator pushes
    # its fakes toward it.
    loss = celoss_zeros(d_fake_logits)

    reference = solver(fake_image)
    mse = tf.reduce_mean(tf.losses.mean_squared_error(condition, reference))

    return loss+0.1*mse, mse


## convert input shape [b,256,256] into [b,256,256,1]
def preprocess(x):
    """Cast to float32 and append a trailing channel axis: (b,256,256) -> (b,256,256,1)."""
    return tf.expand_dims(tf.cast(x, dtype=tf.float32), axis=-1)


def main():
    """Train the conditional GAN.

    Per batch: one discriminator update, one adversarial generator update,
    then a number of solver-supervised generator updates that grows with the
    epoch. Per epoch: evaluate property MSE on random conditions, save
    diagnostic plots/images, log scalars to TensorBoard, and checkpoint.
    """
    tf.random.set_seed(222)
    np.random.seed(222)

    assert tf.__version__.startswith('2.')

    # hyper parameters
    z_dim = 128            # latent noise dimension
    epochs = 200
    batch_size = 32
    learning_rate = 0.0001
    is_training = True

    # Property ranges of the training data. If your dataset's properties are
    # not normalized into [0, 1] x [0, 1], change these limits.
    lims_property1 = [0, 1]
    lims_property2 = [0, 1]

    def sample_conditions(n):
        # Draw n (property1, property2) pairs uniformly from the property space.
        return np.stack([np.random.uniform(lims_property1[0], lims_property1[1], n),
                         np.random.uniform(lims_property2[0], lims_property2[1], n)], axis=1)

    dataset_matrixes = np.load("yourdataset_geometries.npy")  ## replace the source file
    print(dataset_matrixes.shape)  # (b, 256, 256)
    dataset = tf.data.Dataset.from_tensor_slices(dataset_matrixes)
    dataset = dataset.map(preprocess)
    dataset = dataset.batch(batch_size)

    generator = Generator()
    discriminator = Discriminator()
    solver = Solver()
    # NOTE(review): absolute path '/ckpt/...' — most likely the relative
    # 'ckpt/...' directory used by save_weights below was intended; confirm.
    solver.load_weights(r"/ckpt/solver_199.ckpt")  ## replace the source file
    g_optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    d_optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

    for epoch in range(epochs):

        for step, real_images in enumerate(dataset):
            # Actual batch length; the last batch may be smaller than batch_size.
            batch = len(real_images)

            # --- discriminator update(s) ---
            for _ in range(1):
                noise = tf.random.normal([batch, z_dim])  # (b, 128)
                # BUGFIX: sample `batch` conditions instead of `batch_size`;
                # on a partial final batch the old code produced a
                # noise/condition batch-dimension mismatch inside the generator.
                condition = sample_conditions(batch)

                with tf.GradientTape() as tape:
                    d_loss, gp = d_loss_fn(generator, discriminator, batch_z=noise, batch_x=real_images,
                                           condition=condition, is_training=is_training)
                grads = tape.gradient(d_loss, discriminator.trainable_variables)
                d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

            # --- adversarial generator update(s) ---
            for _ in range(1):
                condition = sample_conditions(batch)  # BUGFIX: was batch_size
                noise = tf.random.normal([batch, z_dim])
                with tf.GradientTape() as tape:
                    g_loss, mse = g_loss_fn(generator, discriminator, solver, batch_z=noise, condition=condition,
                                            is_training=is_training)
                grads = tape.gradient(g_loss, generator.trainable_variables)
                g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

            # --- solver-supervised generator updates ---
            # Update count steps up with the epoch: 3 while
            # floor(log10(0.9*epoch + 1)) == 0, then 12 once it reaches 1, etc.
            for _ in range(int(math.log10(epoch*0.9+1)//1.*9+3)):
                condition = sample_conditions(batch)  # BUGFIX: was batch_size
                noise = tf.random.normal([batch, z_dim])
                with tf.GradientTape() as tape:
                    fake_image = generator(noise, condition, is_training)
                    reference = solver(fake_image)
                    mse = tf.reduce_mean(tf.losses.mean_squared_error(condition, reference))
                grads = tape.gradient(mse, generator.trainable_variables)
                g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        # Evaluation: average the property MSE over 32 batches of 32 random
        # conditions and scatter predictions vs. references.
        mse_1024 = 0
        ax1, ax2 = new_bi_plot()
        for i in range(32):
            noise_test = tf.random.normal([32, z_dim])
            condition = sample_conditions(32)

            fake_image = generator(noise_test, condition, training=False)
            reference = solver(fake_image)
            mse = tf.reduce_mean(tf.losses.mean_squared_error(condition, reference))
            mse_1024 += mse
            bi_plot(condition, reference, ax1, ax2)
        mse_1024 = mse_1024/32
        plt.savefig('results/%d_test.png' % epoch)
        plt.close()
        generate_and_save_images(fake_image[0:16], epoch)

        print(epoch, 'd-loss: ', float(d_loss), 'gp:', float(gp), 'g-loss', float(g_loss), 'mse', float(mse_1024))

        # visualize on tensorboard: http://localhost:6006/
        with summary_writer.as_default():

            tf.summary.scalar('d-loss: ', float(d_loss), step=epoch)
            tf.summary.scalar('gp: ', float(gp), step=epoch)
            tf.summary.scalar('g-loss: ', float(g_loss), step=epoch)
            tf.summary.scalar('mse: ', float(mse_1024), step=epoch)

        # save weights
        generator.save_weights('ckpt/generator_%d.ckpt' % (epoch))
        discriminator.save_weights('ckpt/discriminator_%d.ckpt' % epoch)


# Run training only when executed as a script (not on import).
if __name__ == '__main__':
    main()
