"""
Python code for a convolutional neural network (property-prediction solver)
ver. 20230310
coded by Xiaoyang Zheng
Copyright (C) Xiaoyang Zheng and Ikumu Watanabe
Email: ZHENG.Xiaoyang@nims.go.jp; WATANABE.Ikumu@nims.go.jp
"""

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
import glob, datetime
import matplotlib.pyplot as plt
import numpy as np


# Enable GPU memory growth so TensorFlow does not grab all GPU memory up front.
# Guard the access: on a CPU-only machine list_physical_devices('GPU') returns
# an empty list and the original unconditional [0] raised IndexError.
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)

# One TensorBoard log directory per run, keyed by launch timestamp.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = 'logs/' + current_time
summary_writer = tf.summary.create_file_writer(log_dir)


## draw figures for comparing the input (read) and output (predicted) values
def new_bi_plot():
    """Create a side-by-side pair of parity (true-vs-predicted) axes.

    Each axis spans [0, 1] on both dimensions and shows a gray y=x reference
    line, so perfect predictions fall on the diagonal.

    Returns:
        (ax1, ax2): matplotlib axes for Property1 and Property2 respectively.
    """
    limits = [0, 1]
    _, axes = plt.subplots(1, 2, figsize=[8, 4], constrained_layout=True)

    # Both panels are configured identically except for their titles.
    for axis, title in zip(axes, ("Property1", "Property2")):
        axis.plot(limits, limits, 'gray')  # y = x reference diagonal
        axis.set_xlim(limits)
        axis.set_ylim(limits)
        axis.set_aspect(1)
        axis.set_xlabel("True values")
        axis.set_ylabel("Predictions")
        axis.set_title(title)

    return axes[0], axes[1]


def bi_plot(y, pred, ax1, ax2):
    """Scatter true vs. predicted values for the two target properties.

    Column 0 of `y`/`pred` goes to ax1 (Property1), column 1 to ax2
    (Property2). Low alpha lets point density show through overlaps.
    """
    ax1.scatter(y[:, 0], pred[:, 0], c='C0', alpha=0.2)
    ax2.scatter(y[:, 1], pred[:, 1], c='C1', alpha=0.2)


def pad_dim(x, n=1):
    """Circularly pad a [b, H, W, c] tensor by `n` pixels on each spatial edge.

    Wraps the last `n` columns/rows around to the opposite side (periodic
    boundary), so a subsequent 'valid' 3x3 convolution behaves like a
    periodic 'same' convolution.
    """
    # Pad width (axis -2): prepend the last n columns, append the first n.
    x = tf.concat((x[:, :, -n:, :], x, x[:, :, :n, :]), axis=-2)
    # Pad height (axis 1): prepend the last n rows, append the first n.
    x = tf.concat((x[:, -n:, :, :], x, x[:, :n, :, :]), axis=1)
    return x


class Solver(keras.Model):
    """CNN regressor: [b, 256, 256, 1] geometry image -> 2 scalar properties.

    Seven stages of (circular-pad -> 3x3 conv -> circular-pad -> 3x3 conv ->
    2x2 max-pool) halve the spatial size from 256 down to 2, followed by
    three fully-connected layers. An eighth conv stage is defined but is not
    used in the forward pass.
    """

    def __init__(self):
        super(Solver, self).__init__()

        # unit 1, [256,256,1] => [128,128,16]
        self.conv1a = layers.Conv2D(16, kernel_size=[3, 3], padding='valid', activation=tf.nn.relu)
        self.conv1b = layers.Conv2D(16, kernel_size=[3, 3], padding='valid', activation=tf.nn.relu)
        self.max1 = layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='valid')

        # unit 2, => [64,64,32]
        self.conv2a = layers.Conv2D(32, kernel_size=[3, 3], padding='valid', activation=tf.nn.relu)
        self.conv2b = layers.Conv2D(32, kernel_size=[3, 3], padding='valid', activation=tf.nn.relu)
        self.max2 = layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='valid')

        # unit 3, => [32,32,64]
        self.conv3a = layers.Conv2D(64, kernel_size=[3, 3], padding='valid', activation=tf.nn.relu)
        self.conv3b = layers.Conv2D(64, kernel_size=[3, 3], padding='valid', activation=tf.nn.relu)
        self.max3 = layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='valid')

        # unit 4, => [16,16,128]
        self.conv4a = layers.Conv2D(128, kernel_size=[3, 3], padding='valid', activation=tf.nn.relu)
        self.conv4b = layers.Conv2D(128, kernel_size=[3, 3], padding='valid', activation=tf.nn.relu)
        self.max4 = layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='valid')

        # unit 5, => [8,8,256]
        self.conv5a = layers.Conv2D(256, kernel_size=[3, 3], padding='valid', activation=tf.nn.relu)
        self.conv5b = layers.Conv2D(256, kernel_size=[3, 3], padding='valid', activation=tf.nn.relu)
        self.max5 = layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='valid')

        # unit 6, => [4,4,384]
        self.conv6a = layers.Conv2D(384, kernel_size=[3, 3], padding='valid', activation=tf.nn.relu)
        self.conv6b = layers.Conv2D(384, kernel_size=[3, 3], padding='valid', activation=tf.nn.relu)
        self.max6 = layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='valid')

        # unit 7, => [2,2,512]
        self.conv7a = layers.Conv2D(512, kernel_size=[3, 3], padding='valid', activation=tf.nn.relu)
        self.conv7b = layers.Conv2D(512, kernel_size=[3, 3], padding='valid', activation=tf.nn.relu)
        self.max7 = layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='valid')

        # unit 8, => [1,1,512] — defined but currently skipped in call()
        self.conv8a = layers.Conv2D(512, kernel_size=[3, 3], padding='valid', activation=tf.nn.relu)
        self.conv8b = layers.Conv2D(512, kernel_size=[3, 3], padding='valid', activation=tf.nn.relu)
        self.max8 = layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='valid')

        # regression head, => [2]
        self.fc1 = layers.Dense(256, activation=tf.nn.relu)
        self.fc2 = layers.Dense(128, activation=tf.nn.relu)
        self.fc3 = layers.Dense(2, activation=None)

    def call(self, x):
        """Forward pass. `x` is expected to be [b, 256, 256, 1]."""
        # Stages in the exact order the layers were declared; unit 8 is
        # intentionally left out of the forward pass.
        stages = (
            (self.conv1a, self.conv1b, self.max1),
            (self.conv2a, self.conv2b, self.max2),
            (self.conv3a, self.conv3b, self.max3),
            (self.conv4a, self.conv4b, self.max4),
            (self.conv5a, self.conv5b, self.max5),
            (self.conv6a, self.conv6b, self.max6),
            (self.conv7a, self.conv7b, self.max7),
        )
        for first_conv, second_conv, pool in stages:
            # Circular padding before each 'valid' conv keeps spatial size,
            # so only the pooling halves the resolution.
            x = first_conv(pad_dim(x))
            x = second_conv(pad_dim(x))
            x = pool(x)

        x = tf.keras.layers.Flatten()(x)

        x = self.fc1(x)
        x = self.fc2(x)
        return self.fc3(x)


## convert input shape [b,256,256] into [b,256,256,1]
def preprocess(x, y):
    """Cast a (geometry, label) pair to float32 and add a channel axis to x."""
    features = tf.expand_dims(tf.cast(x, dtype=tf.float32), -1)
    targets = tf.cast(y, dtype=tf.float32)
    return features, targets


def main():
    """Train the Solver CNN on a geometry -> two-property regression dataset.

    Loads the dataset from .npy files, does an 80/20 train/test split, trains
    for 200 epochs with Adam + MSE, saves parity plots and checkpoints each
    epoch, and logs train loss / test MSE to TensorBoard.
    """

    # tf.random.set_seed(2345)

    dataset_matrixes = np.load("yourdataset_geometries.npy")  ## shape [b,256,256]
    dataset_labels = np.load("yourdataset_labels.npy")  ## shape [b,2]
    print(dataset_labels.shape)

    # 80/20 split. FIX: the original sliced the test set with [split:-1],
    # which silently dropped the last sample; [split:] keeps it.
    split = int(0.8 * len(dataset_matrixes))
    x = dataset_matrixes[:split]
    y = dataset_labels[:split]
    x_test = dataset_matrixes[split:]
    y_test = dataset_labels[split:]

    train_db = tf.data.Dataset.from_tensor_slices((x, y))
    train_db = train_db.shuffle(10000).map(preprocess).batch(32)

    test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    test_db = test_db.map(preprocess).batch(32)

    solver = Solver()
    # solver.load_weights('ckpt/ckpt8/solver_20.ckpt')
    solver.build(input_shape=[None, 256, 256, 1])

    optimizer = optimizers.Adam(learning_rate=2e-4, beta_1=0.8)

    # Create output directories up front so savefig/save_weights cannot fail
    # with a missing-directory error mid-training.
    for out_dir in ('results/train', 'results/test', 'ckpt'):
        os.makedirs(out_dir, exist_ok=True)

    for epoch in range(200):

        # plot train predictions against true values
        ax1, ax2 = new_bi_plot()

        for step, (x, y) in enumerate(train_db):

            with tf.GradientTape() as tape:
                logits = solver(x)
                loss = tf.losses.mean_squared_error(y, logits)
                loss = tf.reduce_mean(loss)

            grads = tape.gradient(loss, solver.trainable_variables)
            optimizer.apply_gradients(zip(grads, solver.trainable_variables))

            # Only scatter the first 300 batches to keep the figure light.
            if step < 300:
                bi_plot(y, logits, ax1, ax2)

        plt.savefig('results/train/%d_train.png' % epoch)
        plt.close()
        print(epoch, "loss: ", loss.numpy())

        # evaluate on the held-out test set
        ax1, ax2 = new_bi_plot()
        total_sum = 0
        total_error = 0
        for x, y in test_db:

            pred = solver(x)
            loss_test = tf.losses.mean_squared_error(y, pred)
            loss_test = tf.reduce_mean(loss_test)

            total_sum += 1
            total_error += loss_test

            bi_plot(y, pred, ax1, ax2)

        # Mean over batches (batches may be unequal in size, so this is an
        # approximation of the per-sample MSE — kept from the original).
        mse = total_error / total_sum
        print(epoch, "mse: ", mse.numpy())
        plt.savefig('results/test/%d_test.png' % epoch)
        plt.close()

        with summary_writer.as_default():
            # FIX: tag was 'loss:' (stray colon), inconsistent with 'mse'.
            tf.summary.scalar('loss', float(loss.numpy()), step=epoch)
            tf.summary.scalar('mse', float(mse.numpy()), step=epoch)

        solver.save_weights('ckpt/solver_%d.ckpt' % epoch)  ## save weights


# Run training only when executed as a script (not when imported).
if __name__ == "__main__":
    main()
