From cb09e252af6d3da9787a788ae4ea6aba050c4ae4 Mon Sep 17 00:00:00 2001 From: Nattapon J <nj7@illinois.edu> Date: Thu, 12 Oct 2023 15:00:18 -0500 Subject: [PATCH] V-Unet and GAN model --- GAN_model.py | 234 +++++++++++++++++++++++++++++++++++++++++++++++++ VAE-unet.py | 201 ++++++++++++++++++++++++++++++++++++++++++ data_util.py | 1 + eval_gan.py | 162 ++++++++++++++++++++++++++++++++++ two-h-model.py | 18 ++-- 5 files changed, 610 insertions(+), 6 deletions(-) create mode 100644 GAN_model.py create mode 100644 VAE-unet.py create mode 100644 eval_gan.py diff --git a/GAN_model.py b/GAN_model.py new file mode 100644 index 0000000..b4eff11 --- /dev/null +++ b/GAN_model.py @@ -0,0 +1,234 @@ +import tensorflow as tf +import random +import os +import numpy as np +from tensorflow.keras import layers +from tqdm import tqdm +from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint, TensorBoard + + +#You can change the id for each run so that all models and stats are saved separately. +name_id = "GAN_test" +prediction_path = './predicts_'+name_id+'/' +log_path = './logs_'+name_id+'/' +model_path = './models_'+name_id+'/' +save_model_path = './models_'+name_id+'/' + +# Create the folder if it does not exist +os.makedirs(model_path, exist_ok=True) +os.makedirs(prediction_path, exist_ok=True) + +# Hyperparameters +z_dim = 100 +input_shape = (256, 256, 6) + +def build_generator(): + input_image = layers.Input(shape=input_shape, name='input_image') + z = layers.Input(shape=(z_dim,), name='z') + + # Transform the noise + z_transformed = layers.Dense(256*256*6, activation='relu')(z) + z_transformed = layers.Reshape((256, 256, 6))(z_transformed) + + # Concatenate the transformed noise and input image + combined = layers.Concatenate(axis=-1)([input_image, z_transformed]) + + x = layers.Conv2D(64, (3, 3), strides=(1, 1), padding='same', activation='relu')(combined) + x = layers.BatchNormalization()(x) + + x = layers.Conv2D(32, (3, 3), strides=(1, 1), padding='same', activation='relu')(x) + x = layers.BatchNormalization()(x) + + # Generate a segmentation map + output = layers.Conv2D(1, (3, 3), activation='sigmoid', padding='same', name='output')(x) + + model = tf.keras.Model(inputs=[input_image, z], outputs=output) + return model + +# Discriminator +def build_discriminator(): + input_image = layers.Input(shape=input_shape, name='input_image') + segmentation_map = layers.Input(shape=(256, 256, 1), name='segmentation_map') + + combined = layers.Concatenate()([input_image, segmentation_map]) + + x = layers.Conv2D(32, (3, 3), strides=(2, 2), padding='same', activation='relu')(combined) + x = layers.BatchNormalization()(x) + + x = layers.Conv2D(64, (3, 3), strides=(2, 2), padding='same', activation='relu')(x) + x = layers.BatchNormalization()(x) + + x = layers.Flatten()(x) + output = layers.Dense(1, activation='sigmoid')(x) + + model = tf.keras.Model(inputs=[input_image, segmentation_map], outputs=output) + return model + +generator = build_generator() +discriminator = build_discriminator() + +# Losses and optimizers +bce_loss = tf.keras.losses.BinaryCrossentropy() +optimizer_g = tf.keras.optimizers.Adam(1e-4) +optimizer_d = tf.keras.optimizers.Adam(1e-4) + +# Define training steps +@tf.function +def train_step(images, masks, batch_size): + # Generate noise for generator + noise = tf.random.normal([tf.shape(images)[0], z_dim]) + + with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: + # Generate fake segmentation maps using generator + generated_masks = generator([images, noise], training=True) + + # Discriminator output for real and fake images + real_output = discriminator([images, masks], training=True) + fake_output = discriminator([images, generated_masks], training=True) + + # Generator loss: Adversarial loss + L1 loss for generated mask + gen_loss = bce_loss(tf.ones_like(fake_output), fake_output) + l1_loss = tf.reduce_mean(tf.abs(masks - generated_masks)) + total_gen_loss = gen_loss + (100.0 * l1_loss) + + # Discriminator loss + real_loss = bce_loss(tf.ones_like(real_output), real_output) + fake_loss = bce_loss(tf.zeros_like(fake_output), fake_output) + disc_loss = (real_loss + fake_loss) * 0.5 + + # Calculate gradients and apply updates + gradients_of_generator = gen_tape.gradient(total_gen_loss, generator.trainable_variables) + gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables) + + optimizer_g.apply_gradients(zip(gradients_of_generator, generator.trainable_variables)) + optimizer_d.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables)) + + return total_gen_loss, disc_loss + +# Helper function to compute F1 score +def compute_f1_score(true_masks, pred_masks): + pred_masks_bin = tf.cast(pred_masks > 0.5, tf.float32) + + TP = tf.reduce_sum(true_masks * pred_masks_bin) + FP = tf.reduce_sum((1 - true_masks) * pred_masks_bin) + FN = tf.reduce_sum(true_masks * (1 - pred_masks_bin)) + F1 = (2 * TP) / (2 * TP + FP + FN + 1e-7) + return F1.numpy() + +def dice_coef(y_true, y_pred, smooth=1.0): + intersection = tf.reduce_sum(y_true * y_pred, axis=[1,2,3]) + union = tf.reduce_sum(y_true + y_pred, axis=[1,2,3]) + return tf.reduce_mean((2. * intersection + smooth) / (union + smooth), axis=0) + +def dice_coef_loss(y_true, y_pred): + return 1 - dice_coef(y_true, y_pred) + + +generator.compile(optimizer=optimizer_g, loss=dice_coef_loss) +discriminator.compile(optimizer=optimizer_d, loss=bce_loss) + +def train(dataset, val_dataset, epochs, batch_size): + history = {'train_loss': [], 'val_loss': [], 'f1_score': []} + + # Initialize the best_val_loss with a high value + best_val_loss = float('inf') + + for epoch in tqdm(range(epochs), desc="Training"): + train_losses = [] # store training losses for this epoch + + # Training + for image_batch, mask_batch in dataset: + gen_loss, disc_loss = train_step(image_batch, mask_batch, batch_size) # Assuming you already have the `train_step` function + train_losses.append(gen_loss) + + # Validation + total_f1_score = 0 + val_losses = [] # store validation losses for this epoch + + for val_image_batch, val_mask_batch in val_dataset: + # Generate segmentation masks using the generator + noise = tf.random.normal([val_image_batch.shape[0], z_dim]) + val_pred_masks = generator([val_image_batch, noise], training=False) + + # Compute validation loss and F1 score + val_loss = dice_coef_loss(val_mask_batch, val_pred_masks) # Implement your own loss_function or use a suitable one from TF + total_f1_score += compute_f1_score(val_mask_batch, val_pred_masks) + + val_losses.append(val_loss.numpy()) + + # Averaging metrics + avg_train_loss = np.mean(train_losses) + avg_val_loss = np.mean(val_losses) + avg_f1_score = total_f1_score / len(val_dataset) + + history['train_loss'].append(avg_train_loss) + history['val_loss'].append(avg_val_loss) + history['f1_score'].append(avg_f1_score) + + # Check if this epoch's val_loss is better than the best so far + if avg_val_loss < best_val_loss: + best_val_loss = avg_val_loss + generator.save(os.path.join(model_path, name_id + '_generator.h5')) + + return history + + +def load_img(filename, map_dir, legend_dir, seg_dir): + mapName = tf.strings.join([map_dir, filename[0]], separator='/') + legendName = tf.strings.join([legend_dir, filename[1]], separator='/') + + map_img = tf.io.read_file(mapName) + map_img = tf.cast(tf.io.decode_png(map_img), dtype=tf.float32) / 255.0 + + legend_img = tf.io.read_file(legendName) + legend_img = tf.cast(tf.io.decode_png(legend_img), dtype=tf.float32) / 255.0 + + map_img = tf.concat(axis=2, values=[map_img, legend_img]) + map_img = map_img*2.0 - 1.0 + map_img = tf.image.resize(map_img, [256, 256]) + + segName = tf.strings.join([seg_dir, filename[0]], separator='/') + seg_img = tf.io.read_file(segName) + seg_img = tf.io.decode_png(seg_img) + seg_img = tf.image.resize(seg_img, [256, 256]) + + return map_img, seg_img + +def load_train_img(filename): + return load_img(filename, + '/projects/bbym/shared/all_patched_data/training/poly/map_patches', + '/projects/bbym/shared/all_patched_data/training/poly/legend', + '/projects/bbym/shared/all_patched_data/training/poly/seg_patches') + +def load_validation_img(filename): + return load_img(filename, + '/projects/bbym/shared/all_patched_data/validation/poly/map_patches', + '/projects/bbym/shared/all_patched_data/validation/poly/legend', + '/projects/bbym/shared/all_patched_data/validation/poly/seg_patches') + +train_map_file = os.listdir('/projects/bbym/shared/all_patched_data/training/poly/map_patches') +random.shuffle(train_map_file) + +# Pre-filter map files based on existence of corresponding legend files +train_map_legend_names = [(x, '_'.join(x.split('_')[0:-2])+'.png') + for x in train_map_file + if os.path.exists(os.path.join('/projects/bbym/shared/all_patched_data/training/poly/legend', + '_'.join(x.split('_')[0:-2])+'.png'))] + +train_dataset = tf.data.Dataset.from_tensor_slices(train_map_legend_names) +train_dataset = train_dataset.map(load_train_img) +train_dataset = train_dataset.shuffle(5000, reshuffle_each_iteration=False).batch(128) + +validate_map_file = os.listdir('/projects/bbym/shared/all_patched_data/validation/poly/map_patches') + +# Pre-filter map files based on existence of corresponding legend files +validate_map_legend_names = [(x, '_'.join(x.split('_')[0:-2])+'.png') + for x in validate_map_file + if os.path.exists(os.path.join('/projects/bbym/shared/all_patched_data/validation/poly/legend', + '_'.join(x.split('_')[0:-2])+'.png'))] + +validate_dataset = tf.data.Dataset.from_tensor_slices(validate_map_legend_names) +validate_dataset = validate_dataset.map(load_validation_img) +validate_dataset = validate_dataset.batch(50) + +train(train_dataset, validate_dataset, epochs=100, batch_size=128) diff --git a/VAE-unet.py b/VAE-unet.py new file mode 100644 index 0000000..8149115 --- /dev/null +++ b/VAE-unet.py @@ -0,0 +1,201 @@ +import tensorflow as tf +import random +import os +import shutil +from tensorflow.keras import layers, Model +from tensorflow.keras.optimizers import Adam +from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TensorBoard + +def sampling(args): + """Reparameterization trick.""" + z_mean, z_log_var = args + batch, height, width, channels = tf.shape(z_mean)[0], tf.shape(z_mean)[1], tf.shape(z_mean)[2], tf.shape(z_mean)[3] + epsilon = tf.keras.backend.random_normal(shape=(batch, height, width, channels)) + return z_mean + tf.exp(0.5 * z_log_var) * epsilon + +def attention_block(x, g, inter_channel): + """Attention block. `x` is the local feature and `g` is the wider context.""" + theta_x = layers.Conv2D(inter_channel, (1, 1), strides=(1, 1))(x) + + phi_g = layers.Conv2D(inter_channel, (1, 1), strides=(1, 1))(g) + phi_g = layers.UpSampling2D(size=(2, 2))(phi_g) + + f = layers.Add()([theta_x, phi_g]) + f = layers.Activation('relu')(f) + + psi_f = layers.Conv2D(1, (1, 1), strides=(1, 1))(f) + psi_f = layers.Activation('sigmoid')(psi_f) + + return layers.Multiply()([x, psi_f]) + +def encoder_block(inputs, filters, attention=False, pool=True): + x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(inputs) + x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x) + if pool: + if attention: + g = layers.MaxPooling2D(pool_size=(2, 2))(x) + x = attention_block(x, g, filters//2) + return x, layers.MaxPooling2D(pool_size=(2, 2))(x) + else: + return x, layers.MaxPooling2D(pool_size=(2, 2))(x) + else: + return x + +def decoder_block(inputs, skip_features, filters): + x = layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(inputs) + x = layers.Concatenate()([x, skip_features]) + x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x) + x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x) + return x + +def variational_unet(input_shape1, input_shape2): + # Encoder 1 + inputs1 = layers.Input(shape=input_shape1, name="legend_patch") + x1, p1 = encoder_block(inputs1, 64, attention=False) + x1, p1 = encoder_block(p1, 128) + x1, p1 = encoder_block(p1, 256) + x1, p1 = encoder_block(p1, 512) + x1 = encoder_block(p1, 1024, pool=False) + + # Latent Space for Encoder 1 + z_mean = layers.Conv2D(1024, (1, 1))(x1) + z_log_var = layers.Conv2D(1024, (1, 1))(x1) + z = layers.Lambda(sampling)([z_mean, z_log_var]) + + # Encoder 2 with attention + inputs2 = layers.Input(shape=input_shape2, name="map_patch") + x2_1, p2 = encoder_block(inputs2, 64, attention=True) + x2_2, p2 = encoder_block(p2, 128, attention=True) + x2_3, p2 = encoder_block(p2, 256, attention=True) + x2_4, p2 = encoder_block(p2, 512, attention=True) + x2 = encoder_block(p2, 1024, attention=True, pool=False) + + # Concatenate at the bottleneck + bottleneck = layers.Concatenate()([z, x2]) + + print(bottleneck.shape, z.shape, x2.shape) + + # Decoder + x = decoder_block(bottleneck, x2_4, 512) + x = decoder_block(x, x2_3, 256) + x = decoder_block(x, x2_2, 128) + x = decoder_block(x, x2_1, 64) + + outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(x) + + return Model(inputs=[inputs1, inputs2], outputs=[outputs, [outputs, z_mean, z_log_var]]) + +def vae_loss(y_true, output, beta=1.0): + y_pred, z_mean, z_log_var = output[0], output[1], output[2] + # Reconstruction loss + recon_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred) + recon_loss = tf.reduce_mean(recon_loss) + + # KL divergence loss + kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + kl_loss = -0.5 * tf.reduce_sum(kl_loss, axis=-1) + + # Total loss + return recon_loss + beta * kl_loss + +model = variational_unet((256, 256, 3), (256, 256, 3)) +model.summary() +model.compile(optimizer='adam', loss=[None, vae_loss]) + + +def load_img(filename, map_dir, legend_dir, seg_dir): + mapName = tf.strings.join([map_dir, filename[0]], separator='/') + legendName = tf.strings.join([legend_dir, filename[1]], separator='/') + + # Load and preprocess map_img + map_img = tf.io.read_file(mapName) + map_img = tf.cast(tf.io.decode_png(map_img), dtype=tf.float32) / 255.0 + map_img = tf.image.resize(map_img, [256, 256]) + + # Load and preprocess legend_img + legend_img = tf.io.read_file(legendName) + legend_img = tf.cast(tf.io.decode_png(legend_img), dtype=tf.float32) / 255.0 + legend_img = tf.image.resize(legend_img, [256, 256]) + + # Load and preprocess seg_img + segName = tf.strings.join([seg_dir, filename[0]], separator='/') + seg_img = tf.io.read_file(segName) + seg_img = tf.io.decode_png(seg_img) + seg_img = tf.image.resize(seg_img, [256, 256]) + + return (legend_img, map_img), seg_img + +def load_train_img(filename): + return load_img(filename, + '/projects/bbym/shared/all_patched_data/training/poly/map_patches', + '/projects/bbym/shared/all_patched_data/training/poly/legend', + '/projects/bbym/shared/all_patched_data/training/poly/seg_patches') + +def load_validation_img(filename): + return load_img(filename, + '/projects/bbym/shared/all_patched_data/validation/poly/map_patches', + '/projects/bbym/shared/all_patched_data/validation/poly/legend', + '/projects/bbym/shared/all_patched_data/validation/poly/seg_patches') + + +train_map_file = os.listdir('/projects/bbym/shared/all_patched_data/training/poly/map_patches') +random.shuffle(train_map_file) + +# Pre-filter map files based on existence of corresponding legend files +train_map_legend_names = [(x, '_'.join(x.split('_')[0:-2])+'.png') + for x in train_map_file + if os.path.exists(os.path.join('/projects/bbym/shared/all_patched_data/training/poly/legend', + '_'.join(x.split('_')[0:-2])+'.png'))] + +train_dataset = tf.data.Dataset.from_tensor_slices(train_map_legend_names) +train_dataset = train_dataset.map(load_train_img) +train_dataset = train_dataset.shuffle(5000, reshuffle_each_iteration=False).batch(128) + +validate_map_file = os.listdir('/projects/bbym/shared/all_patched_data/validation/poly/map_patches') + +# Pre-filter map files based on existence of corresponding legend files +validate_map_legend_names = [(x, '_'.join(x.split('_')[0:-2])+'.png') + for x in validate_map_file + if os.path.exists(os.path.join('/projects/bbym/shared/all_patched_data/validation/poly/legend', + '_'.join(x.split('_')[0:-2])+'.png'))] + +validate_dataset = tf.data.Dataset.from_tensor_slices(validate_map_legend_names) +validate_dataset = validate_dataset.map(load_validation_img) +validate_dataset = validate_dataset.batch(50) + +################################################################ +##### Prepare the model configurations ######################### +################################################################ +#You can change the id for each run so that all models and stats are saved separately. +name_id = "VAE-unet" +prediction_path = './predicts_'+name_id+'/' +log_path = './logs_'+name_id+'/' +model_path = './models_'+name_id+'/' +save_model_path = './models_'+name_id+'/' + +# Create the folder if it does not exist +os.makedirs(model_path, exist_ok=True) +os.makedirs(prediction_path, exist_ok=True) + +name = 'VAE-unet' + +logdir = log_path + name + +if(os.path.isdir(logdir)): + shutil.rmtree(logdir) + +os.makedirs(logdir, exist_ok=True) +tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir) + +print('model location: '+ model_path+name+'.h5') +# define hyperparameters and callback modules +patience = 10 +maxepoch = 500 +callbacks = [ReduceLROnPlateau(monitor='val_loss', factor=0.7, patience=patience, min_lr=1e-9, verbose=1, mode='min'), + EarlyStopping(monitor='val_loss', patience=patience, verbose=0), + ModelCheckpoint(model_path+name+'.h5', monitor='val_loss', save_best_only=True, verbose=0), + TensorBoard(log_dir=logdir)] + +train_history = model.fit(train_dataset, validation_data = validate_dataset, + batch_size = 16, epochs = maxepoch, verbose=1, + callbacks = callbacks) \ No newline at end of file diff --git a/data_util.py b/data_util.py index bcc7243..5adf736 100644 --- a/data_util.py +++ b/data_util.py @@ -31,6 +31,7 @@ class DataLoader: process_legends(label_suffix, resize_to=(256, 256)): Processes legends based on label suffix. process_data(): Extracts and processes data from the loaded TIFF and JSON. get_processed_data(): Returns the processed data. + reconstruct_data(self, patches): unpatchify the prediction data. """ def __init__(self, tiff_path, json_path, patch_size=(256, 256, 3), overlap=30): """ diff --git a/eval_gan.py b/eval_gan.py new file mode 100644 index 0000000..e105789 --- /dev/null +++ b/eval_gan.py @@ -0,0 +1,162 @@ +import matplotlib.pyplot as plt +import matplotlib.image as mpimg + +import os +import shutil +import random +import numpy as np +import tensorflow as tf +from datetime import datetime +from keras import backend as K +import matplotlib.pyplot as plt +import matplotlib.image as mpimg +import segmentation_models as sm +from tensorflow.keras import layers +from keras.models import load_model +from tensorflow.keras.models import Model, load_model +from tensorflow.keras.layers import Input, Conv2D, RandomFlip, RandomRotation +from tensorflow.keras.optimizers import Adam, SGD, RMSprop +from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TensorBoard +from unet_util import dice_coef_loss, dice_coef, jacard_coef, dice_coef_loss, Residual_CNN_block, multiplication, attention_up_and_concatenate, multiplication2, attention_up_and_concatenate2, UNET_224, evaluate_prediction_result + +def f1_score(y_true, y_pred): # Dice coefficient + smooth = 1. + y_true_f = K.flatten(y_true) + y_pred_f = K.flatten(y_pred) + intersection = K.sum(y_true_f * y_pred_f) + return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) + +################################################################ +##### Prepare the model configurations ######################### +################################################################ +#You can change the id for each run so that all models and stats are saved separately. +name_id = "GAN_test" +prediction_path = './predicts_'+name_id+'/' +model_path = './models_'+name_id+'/' +save_model_path = './models_'+name_id+'/' + +os.makedirs(model_path, exist_ok=True) +os.makedirs(prediction_path, exist_ok=True) + +print('model location: '+ model_path + name_id + '_generator.h5') + +generator = load_model(os.path.join(model_path, name_id + '_generator.h5'), + custom_objects={'dice_coef_loss': dice_coef_loss}) + +generator.compile(optimizer = Adam(), + loss = dice_coef_loss, + metrics=[dice_coef,'accuracy', f1_score]) + +z_dim = 100 # global variable + +def load_img_with_noise(filename, map_dir, legend_dir): + mapName = tf.strings.join([map_dir, filename[0]], separator='/') + legendName = tf.strings.join([legend_dir, filename[1]], separator='/') + + map_img = tf.io.read_file(mapName) + map_img = tf.cast(tf.io.decode_png(map_img), dtype=tf.float32) / 255.0 + + legend_img = tf.io.read_file(legendName) + legend_img = tf.cast(tf.io.decode_png(legend_img), dtype=tf.float32) / 255.0 + + map_img = tf.concat(axis=2, values=[map_img, legend_img]) + map_img = map_img*2.0 - 1.0 + map_img = tf.image.resize(map_img, [256, 256]) + + noise = tf.random.normal([1, z_dim]) # generating noise for one image + + return map_img, noise + +def generate_image(filename, generator, map_dir, legend_dir): + map_img, noise = load_img_with_noise(filename, map_dir, legend_dir) + + # Ensure that map_img has batch dimension + map_img = tf.expand_dims(map_img, 0) + + generated_img = generator([map_img, noise], training=False) + + # Removing the batch dimension for visualization + generated_img = tf.squeeze(generated_img) + + return generated_img.numpy() + +# def load_validation_img(filename): +# return generate_image(filename, +# generator, # Add the generator here +# '/projects/bbym/shared/all_patched_data/validation/poly/map_patches', +# '/projects/bbym/shared/all_patched_data/validation/poly/legend') + +validate_map_file = os.listdir('/projects/bbym/shared/all_patched_data/validation/poly/map_patches') +validate_map_names = [(x, '_'.join(x.split('_')[0:-2])+'.png') for x in validate_map_file] +# validate_dataset = tf.data.Dataset.from_tensor_slices(validate_map_names) +# validate_dataset = validate_dataset.map(load_validation_img) +# validate_dataset = validate_dataset.batch(50) + +print("Load Data Done!") + + +print("Load Model Done!") + +def dice_coef(y_true, y_pred, smooth=1.0): + intersection = tf.reduce_sum(y_true * y_pred, axis=[1,2,3]) + union = tf.reduce_sum(y_true + y_pred, axis=[1,2,3]) + return tf.reduce_mean((2. * intersection + smooth) / (union + smooth), axis=0) + +def dice_coef_loss(y_true, y_pred): + return 1 - dice_coef(y_true, y_pred) + +# model.summary() +# eval_results = model.evaluate(validate_dataset, verbose=1) +# print(eval_results) +# print(f'Validation F1 score: {f1}') + +# If validate_dataset is a tf.data.Dataset instance +def plotResult(fileName, save_dir, generator): + map_img, noise = load_img_with_noise(fileName, + '/projects/bbym/shared/all_patched_data/validation/poly/map_patches', + '/projects/bbym/shared/all_patched_data/validation/poly/legend') + generated_img = generator([tf.expand_dims(map_img, 0), noise], training=False) + predicted_binary = (generated_img > 0.5).numpy().astype(np.uint8) # thresholding + + mapName = '/projects/bbym/shared/all_patched_data/validation/poly/map_patches/' + fileName[0] + segName = '/projects/bbym/shared/all_patched_data/validation/poly/seg_patches/' + fileName[0] + legendName = '/projects/bbym/shared/all_patched_data/validation/poly/legend/' + fileName[1] + + map_img = mpimg.imread(mapName) + seg_img = mpimg.imread(segName) + legend_img = mpimg.imread(legendName) + + # Visualization + plt.figure(figsize=(10, 2)) + + plt.subplot(1, 5, 1) + plt.title("map") + plt.imshow(map_img) + + plt.subplot(1, 5, 2) + plt.title("legend") + plt.imshow(legend_img) + + plt.subplot(1, 5, 3) + plt.title("true segmentation") + plt.imshow(seg_img, cmap='gray') + + plt.subplot(1, 5, 4) + plt.title("predicted segmentation") + plt.imshow(predicted_binary[0, :, :, 0]*255, cmap='gray') + + plt.subplot(1, 5, 5) + plt.title("error") + error_img = np.logical_xor(predicted_binary[0, :, :, 0], seg_img/255.0) # assuming seg_img is in [0, 255] + plt.imshow(error_img, cmap='gray') + + plt.savefig(save_dir + fileName[0] + '.png') + plt.close() + +n=20 + +for fileName in random.sample(validate_map_names, n): + print(fileName) + plotResult(fileName, prediction_path, generator) # Assuming generator is available in this context +print("Save Images Done!") + diff --git a/two-h-model.py b/two-h-model.py index e5d727d..e9a3271 100644 --- a/two-h-model.py +++ b/two-h-model.py @@ -73,8 +73,8 @@ def latent_sample(p): return mean + stddev * eps def variational_unet(input_shape): - input1 = tf.keras.Input(shape=input_shape) # legend patch + map patch - input2 = tf.keras.Input(shape=input_shape) # map patch + input1 = tf.keras.Input(shape=input_shape, name="legend_patch") # legend patch + map patch + input2 = tf.keras.Input(shape=input_shape, name="map_patch") # map patch # Encoder for input1 x1 = layers.Conv2D(32, (1, 1), padding="same", activation="relu")(input1) # Adjusting input depth to 3 @@ -107,11 +107,13 @@ def variational_unet(input_shape): bottleneck = decoder_block(bottleneck, skip, filters) # Final convolution to get the segmentation result - output = layers.Conv2D(1, (1, 1), activation="sigmoid")(bottleneck) + output = layers.Conv2D(1, (1, 1), activation="sigmoid", name="output")(bottleneck) - print("output-shape:", output.shape, qs.shape, ps.shape) + # combined_outputs = concatenate([output, qs, ps], name="combined_output") - return models.Model(inputs=[input1, input2], outputs=[output, qs, ps]) + # print("output-shape:", output.shape, qs.shape, ps.shape) + + return models.Model(inputs=[input1, input2], outputs=[output, [output, qs, ps]]) # Instantiate the model with the input shape model = variational_unet((256, 256, 3)) @@ -294,8 +296,12 @@ tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir) print('model location: '+ model_path+name+'.h5') +def dumm_loss(*args, **kwargs): + return 0.0 + # Compile the model with 'Adam' optimizer (0.001 is the default learning rate) and define the loss and metrics -model.compile(optimizer = Adam(), loss=CombinedLoss(), metrics=[DiceCoefficientMetric(), SegmentationAccuracy()]) +model.compile(optimizer = Adam(), + loss=[dice_coef_loss, CombinedLoss()]) # define hyperparameters and callback modules patience = 10 -- GitLab