"""
Python code for conditional generative adversarial network
ver. 20230310
coded by Xiaoyang Zheng
Copyright (C) Xiaoyang Zheng and Ikumu Watanabe
Email: ZHENG.Xiaoyang@nims.go.jp; WATANABE.Ikumu@nims.go.jp
"""

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np
import tensorflow as tf
from PIL import Image
import glob, datetime
import matplotlib.pyplot as plt
from solver import Solver
from tensorflow import keras
from tensorflow.keras import layers


## for circular padding
def pad_dim(x, n=1):
    """Circularly pad an NHWC tensor by ``n`` pixels on every spatial edge.

    Opposite borders are wrapped around (width first, then height) so a
    following 'valid' convolution behaves as if the domain were periodic.
    """
    left_wrap = x[:, :, -n:, :]
    right_wrap = x[:, :, :n, :]
    x = tf.concat([left_wrap, x, right_wrap], axis=2)
    top_wrap = x[:, -n:, :, :]
    bottom_wrap = x[:, :n, :, :]
    return tf.concat([top_wrap, x, bottom_wrap], axis=1)


def delete_dim(x, n=3):
    """Crop ``n`` pixels from every spatial edge of an NHWC tensor.

    Counterpart of ``pad_dim``: a Conv2DTranspose(kernel=4, stride=2,
    padding='valid') applied to a circularly padded input grows a 3-pixel
    border that must be trimmed, so ``n`` defaults to 3 to preserve the
    original hard-coded behavior while allowing other crop widths.

    Args:
        x: tensor/array of shape (batch, height, width, channels).
        n: number of pixels to remove from each side (default 3).

    Returns:
        The centrally cropped tensor, shape (batch, height-2n, width-2n, channels).
    """
    return x[:, n:-n, n:-n, :]


## architecture of generator
class Generator(keras.Model):
    """DCGAN-style conditional generator.

    Maps a latent noise vector concatenated with a 2-value property
    condition to a single-channel 256x256 image in [-1, 1] (tanh output).
    Each upsampling stage circularly pads the feature map (pad_dim),
    applies a stride-2 transposed convolution with 'valid' padding, then
    trims the 3-pixel border (delete_dim), so the spatial size exactly
    doubles per stage under periodic boundary conditions:
    4 -> 8 -> 16 -> 32 -> 64 -> 128 -> 256.

    NOTE: attribute names (fc, bn0, conv1, ...) are part of the checkpoint
    layout restored by load_weights; do not rename them.
    """

    def __init__(self):
        super(Generator, self).__init__()

        # Project the concatenated (noise, condition) vector to a 4x4x512 feature map.
        self.fc = layers.Dense(4*4*512)
        self.bn0 = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

        # Upsampling stages: channel count halves (roughly) while spatial size doubles.
        self.conv1 = layers.Conv2DTranspose(384, kernel_size=[4,4], strides=[2,2], padding='valid')
        self.bn1 = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

        self.conv2 = layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding='valid')
        self.bn2 = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

        self.conv3 = layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding='valid')
        self.bn3 = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

        self.conv4 = layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding='valid')
        self.bn4 = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

        self.conv5 = layers.Conv2DTranspose(32, kernel_size=4, strides=2, padding='valid')
        self.bn5 = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

        # Final stage maps to a single output channel (no batch norm before tanh).
        self.conv6 = layers.Conv2DTranspose(1, kernel_size=4, strides=2, padding='valid')


    def call(self, inputs_noise, input_condition, training=None):
        """Generate a batch of images from noise and condition vectors.

        Args:
            inputs_noise: (b, z_dim) latent vectors (z_dim = 128 per main()).
            input_condition: (b, 2) target property values.
            training: forwarded to the BatchNormalization layers.

        Returns:
            (b, 256, 256, 1) tensor of tanh activations in [-1, 1].
        """
        # inputs_noise: (b, 128), inputs_condition: (b, 2)
        inputs_noise = tf.cast(inputs_noise, dtype=tf.float32)
        input_condition = tf.cast(input_condition, dtype=tf.float32)
        # Condition the generator by simple concatenation with the noise vector.
        net = tf.concat((inputs_noise, input_condition), axis=-1)
        net = self.fc(net)  # (b, 4*4*512)
        net = self.bn0(net, training=training)

        net = tf.reshape(net, [-1, 4, 4, 512])  # (b, 4, 4, 512)

        # Each stage: circular pad by 1 -> transposed conv -> crop 3 per side,
        # which doubles height/width with periodic boundaries.
        net = pad_dim(net, n=1)
        net = tf.nn.leaky_relu(self.bn1(self.conv1(net), training=training))
        net = delete_dim(net)

        net = pad_dim(net)
        net = tf.nn.leaky_relu(self.bn2(self.conv2(net), training=training))
        net = delete_dim(net)

        net = pad_dim(net)
        net = tf.nn.leaky_relu(self.bn3(self.conv3(net), training=training))
        net = delete_dim(net)

        net = pad_dim(net)
        net = tf.nn.leaky_relu(self.bn4(self.conv4(net), training=training))
        net = delete_dim(net)

        net = pad_dim(net)
        net = tf.nn.leaky_relu(self.bn5(self.conv5(net), training=training))
        net = delete_dim(net)

        net = pad_dim(net)
        net = self.conv6(net)
        net = delete_dim(net)

        net = tf.tanh(net)  # (b, 256, 256, 1)

        return net


## draw figures for comparing the input and output values
def new_bi_plot():
    """Create a 1x2 figure of parity plots for the two material properties.

    Each panel draws the y = x reference line over [0, 1] with equal aspect,
    ready for scatter points of input (target) vs. output (achieved) values.

    Returns:
        (ax1, ax2): the axes for property 1 and property 2 respectively.
    """
    unit = [0, 1]
    _, (ax1, ax2) = plt.subplots(1, 2, figsize=[8, 4], constrained_layout=True)
    panels = (
        (ax1, "Input property 1", "Output property 1"),
        (ax2, "Input property 2", "Output property 2"),
    )
    for ax, xlabel, ylabel in panels:
        ax.plot(unit, unit, 'gray')
        ax.set_xlim(unit)
        ax.set_ylim(unit)
        ax.set_aspect(1)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
    return ax1, ax2


def bi_plot(pre, ref, ax1, ax2):
    """Scatter predicted-vs-reference values onto the two parity axes.

    Column 0 of ``pre``/``ref`` (property 1) goes to ``ax1`` in blue,
    column 1 (property 2) to ``ax2`` in orange; ``ax1`` is drawn first.
    """
    ax1.scatter(pre[:, 0], ref[:, 0], c='C0', alpha=0.2)
    ax2.scatter(pre[:, 1], ref[:, 1], c='C1', alpha=0.2)


## generate the images of generated 2D geometries
def generate_and_save_images(fake_image, epoch, out_dir='fake_images'):
    """Save up to 16 generated geometries as a 4x4 grayscale grid PNG.

    Args:
        fake_image: batch of generated images, shape (b, H, W, 1); values
            presumably in [-1, 1] (the generator ends in tanh).
        epoch: integer embedded in the output file name.
        out_dir: destination directory, created if missing (default keeps
            the original 'fake_images' location).
    """
    os.makedirs(out_dir, exist_ok=True)
    plt.figure(figsize=(4, 4), dpi=300)
    # A 4x4 grid holds at most 16 panels; previously a larger batch raised
    # an error from plt.subplot. Extra images are simply skipped.
    for i in range(min(fake_image.shape[0], 16)):
        plt.subplot(4, 4, i + 1)
        # Map tanh output [-1, 1] back to the [0, 255] pixel range.
        plt.imshow(fake_image[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    plt.savefig(os.path.join(out_dir, 'image_at_epoch_{:04d}.png'.format(epoch)))
    plt.close()


def main():
    """Generate geometries with the trained cGAN and compare the target
    (input) properties against the solver-evaluated (output) properties.

    Loads the trained generator and solver checkpoints, samples 64 batches
    of random conditions, reports the mean squared error between requested
    and achieved properties, and saves a parity plot plus a grid of the
    last batch of generated images.
    """
    tf.random.set_seed(222)
    np.random.seed(222)
    assert tf.__version__.startswith('2.')

    # hyper parameters
    z_dim = 128
    # BUG FIX: batch_size and the condition sampling ranges were referenced
    # below but never defined (NameError). batch_size=16 matches the 4x4
    # image grid and makes 64 batches x 16 = 1024 samples (cf. mse_1024);
    # the [0, 1] ranges match the axes created by new_bi_plot().
    batch_size = 16
    lims_property1 = [0, 1]
    lims_property2 = [0, 1]

    generator = Generator()
    generator.load_weights(r'ckpt/generator_199.ckpt')  ## replace the source file
    solver = Solver()
    # BUG FIX: the leading '/' pointed at the filesystem root; use the same
    # relative ckpt/ directory as the generator checkpoint above.
    solver.load_weights(r'ckpt/solver_199.ckpt')  ## replace the source file

    # make sure the output directories exist before saving figures
    os.makedirs('results', exist_ok=True)

    # calculate MSE
    mse_1024 = 0
    # plot test
    ax1, ax2 = new_bi_plot()
    for i in range(64):
        noise_test = tf.random.normal([batch_size, z_dim])
        ## you can replace the condition using your target values
        condition = np.stack([np.random.uniform(lims_property1[0], lims_property1[1], batch_size),
                              np.random.uniform(lims_property2[0], lims_property2[1], batch_size)],
                             axis=1)  # get a condition from available data space

        fake_image = generator(noise_test, condition, training=False)
        reference = solver(fake_image)
        # accumulate the per-batch MSE between requested and achieved properties
        mse = tf.reduce_mean(tf.losses.mean_squared_error(condition, reference))
        mse_1024 += mse
        bi_plot(condition, reference, ax1, ax2)
    mse_1024 = mse_1024/64
    print(mse_1024)
    plt.savefig('results/%d_test.png' % 0)  ## save the comparison between input and output values
    plt.close()
    generate_and_save_images(fake_image, 0)  ## save generated images (last batch)


if __name__ == '__main__':
    main()
