"""
Python code for conditional generative adversarial network
ver. 20230310
coded by Xiaoyang Zheng
Copyright (C) Xiaoyang Zheng and Ikumu Watanabe
Email: ZHENG.Xiaoyang@nims.go.jp; WATANABE.Ikumu@nims.go.jp
"""
import os

# Must be set before TensorFlow is imported to take effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import numpy as np
import tensorflow as tf
from PIL import Image
import glob
import datetime
import matplotlib.pyplot as plt
from solver import Solver
from tensorflow import keras
from tensorflow.keras import layers

# Value ranges of the two conditioning properties. They are shared by the
# plot axes and by the random sampling of test conditions in ``main``.
LIMS_PROPERTY1 = (0, 1)
LIMS_PROPERTY2 = (0, 1)

# generate_and_save_images draws a fixed 4 x 4 grid.
_GRID_ROWS = 4
_GRID_COLS = 4


## for circular padding
def pad_dim(x, n=1):
    """Circularly pad axes 1 and 2 of a 4-D tensor by ``n`` on each side.

    The last ``n`` slices are prepended and the first ``n`` slices are
    appended, first along axis 2 and then along axis 1, so both axes grow
    by ``2 * n``.
    """
    x = tf.concat((x[:, :, -n:, :], x, x[:, :, :n, :]), axis=-2)
    x = tf.concat((x[:, -n:, :, :], x, x[:, :n, :, :]), axis=1)
    return x


def delete_dim(x):
    """Crop three entries from both ends of axes 1 and 2 of a 4-D tensor."""
    return x[:, 3:-3, 3:-3, :]


## architecture of generator
class Generator(keras.Model):
    """Generator that maps (noise, condition) to a single-channel image.

    A dense layer produces a 4 x 4 x 512 feature map that is then passed
    through six transposed convolutions. Each stage is circular padding,
    a stride-2 / kernel-4 'valid' transposed convolution, and a crop of
    three pixels per side, which doubles the spatial size
    (4 -> 8 -> ... -> 256).

    NOTE: layer attribute names and creation order are kept as in the
    original code so that previously saved checkpoints keep matching.
    """

    def __init__(self):
        super().__init__()
        self.fc = layers.Dense(4*4*512)
        self.bn0 = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

        self.conv1 = layers.Conv2DTranspose(384, kernel_size=[4, 4], strides=[2, 2], padding='valid')
        self.bn1 = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

        self.conv2 = layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding='valid')
        self.bn2 = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

        self.conv3 = layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding='valid')
        self.bn3 = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

        self.conv4 = layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding='valid')
        self.bn4 = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

        self.conv5 = layers.Conv2DTranspose(32, kernel_size=4, strides=2, padding='valid')
        self.bn5 = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

        self.conv6 = layers.Conv2DTranspose(1, kernel_size=4, strides=2, padding='valid')

    def call(self, inputs_noise, input_condition, training=None):
        """Generate images from noise vectors and condition vectors.

        Args:
            inputs_noise: noise batch, (b, 128) per the original comment.
            input_condition: condition batch, (b, 2) per the original
                comment.
            training: forwarded to the batch-normalization layers.

        Returns:
            Tensor with values in [-1, 1] (tanh output), (b, 256, 256, 1).
        """
        inputs_noise = tf.cast(inputs_noise, dtype=tf.float32)
        input_condition = tf.cast(input_condition, dtype=tf.float32)
        net = tf.concat((inputs_noise, input_condition), axis=-1)

        net = self.fc(net)  # (b, 4*4*512)
        net = self.bn0(net, training=training)
        net = tf.reshape(net, [-1, 4, 4, 512])  # (b, 4, 4, 512)

        # Five identical upsampling stages: pad -> deconv -> BN -> LeakyReLU -> crop.
        hidden_stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
        )
        for conv, bn in hidden_stages:
            net = pad_dim(net, n=1)
            net = tf.nn.leaky_relu(bn(conv(net), training=training))
            net = delete_dim(net)

        # Output stage: no batch normalization; tanh squashes to [-1, 1].
        net = pad_dim(net)
        net = self.conv6(net)
        net = delete_dim(net)
        net = tf.tanh(net)  # (b, 256, 256, 1)
        return net


## draw figures for comparing the input and output values
def new_bi_plot():
    """Create a two-panel figure with a gray identity line in each panel.

    Returns:
        (ax1, ax2): the axes for property 1 and property 2. The figure
        stays the current pyplot figure, so a later ``plt.savefig`` saves it.
    """
    _, (ax1, ax2) = plt.subplots(1, 2, figsize=[8, 4], constrained_layout=True)
    panels = (
        (ax1, LIMS_PROPERTY1, "Input property 1", "Output property 1"),
        (ax2, LIMS_PROPERTY2, "Input property 2", "Output property 2"),
    )
    for ax, lims, xlabel, ylabel in panels:
        ax.plot(lims, lims, 'gray')
        ax.set_xlim(lims)
        ax.set_ylim(lims)
        ax.set_aspect(1)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
    return ax1, ax2


def bi_plot(pre, ref, ax1, ax2):
    """Scatter column 0 of ``pre`` vs ``ref`` on ax1 and column 1 on ax2.

    ``pre`` supplies the x values and ``ref`` the y values; both are
    indexed as 2-D arrays with at least two columns.
    """
    ax1.scatter(pre[:, 0], ref[:, 0], c='C0', alpha=0.2)
    ax2.scatter(pre[:, 1], ref[:, 1], c='C1', alpha=0.2)


## generate the images of generated 2D geometries
def generate_and_save_images(fake_image, epoch):
    """Save up to 16 generated images as a 4 x 4 grayscale grid.

    Args:
        fake_image: image batch indexed as (b, h, w, c); only channel 0 is
            drawn, rescaled with ``* 127.5 + 127.5``.
        epoch: integer used in the output file name.

    The output directory ``fake_images`` is created if it does not exist.
    Images beyond the 16 grid cells are ignored rather than raising.
    """
    os.makedirs('fake_images', exist_ok=True)
    plt.figure(figsize=(4, 4), dpi=300)
    n_shown = min(fake_image.shape[0], _GRID_ROWS * _GRID_COLS)
    for i in range(n_shown):
        plt.subplot(_GRID_ROWS, _GRID_COLS, i + 1)
        plt.imshow(fake_image[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')
    plt.savefig(r'fake_images/image_at_epoch_{:04d}.png'.format(epoch))
    plt.close()


def main():
    """Evaluate a trained generator against the solver and save the plots.

    Loads generator and solver weights, generates 64 batches of images
    for random conditions, accumulates the mean squared error between each
    condition and the solver output for the generated image, prints the
    average, and saves the comparison plot and the last image batch.

    Raises:
        RuntimeError: if the installed TensorFlow is not a 2.x release.
    """
    # Check the version first: tf.random.set_seed is a TF2 API.
    if not tf.__version__.startswith('2.'):
        raise RuntimeError(
            'TensorFlow 2.x is required, found %s' % tf.__version__)

    tf.random.set_seed(222)
    np.random.seed(222)

    # hyper parameters
    z_dim = 128
    # 64 batches x 16 samples = 1024 samples; 16 also fills the 4 x 4 image
    # grid. The original script used batch_size without defining it.
    batch_size = 16
    n_batches = 64

    generator = Generator()
    generator.load_weights(r'ckpt/generator_199.ckpt')  ## replace the source file
    solver = Solver()
    # NOTE(review): this path is absolute while the generator path above is
    # relative -- possibly a typo for 'ckpt/solver_199.ckpt'; confirm.
    solver.load_weights(r"/ckpt/solver_199.ckpt")  ## replace the source file

    # accumulated MSE over all batches
    mse_total = 0

    # plot test
    ax1, ax2 = new_bi_plot()
    for _ in range(n_batches):
        noise_test = tf.random.normal([batch_size, z_dim])
        ## you can replace the condition using your target values
        # get a condition from available data space
        condition = np.stack(
            [
                np.random.uniform(LIMS_PROPERTY1[0], LIMS_PROPERTY1[1], batch_size),
                np.random.uniform(LIMS_PROPERTY2[0], LIMS_PROPERTY2[1], batch_size),
            ],
            axis=1,
        )
        fake_image = generator(noise_test, condition, training=False)
        reference = solver(fake_image)

        # cal mse
        mse_total += tf.reduce_mean(
            tf.losses.mean_squared_error(condition, reference))
        bi_plot(condition, reference, ax1, ax2)

    print(mse_total / n_batches)

    ## save the comparison between input and output values
    os.makedirs('results', exist_ok=True)
    plt.savefig('results/%d_test.png' % 0)
    plt.close()

    generate_and_save_images(fake_image, 0)  ## save generated images


if __name__ == '__main__':
    main()