diff --git a/VAE-unet.py b/VAE-unet.py index 81491159831901ed548e264f1ebcf660941fce13..594a900a4f5a04d79e7ba135e04ee7acca2d9027 100644 --- a/VAE-unet.py +++ b/VAE-unet.py @@ -2,6 +2,7 @@ import tensorflow as tf import random import os import shutil +from tensorflow.keras.utils import plot_model from tensorflow.keras import layers, Model from tensorflow.keras.optimizers import Adam from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TensorBoard @@ -51,24 +52,24 @@ def decoder_block(inputs, skip_features, filters): def variational_unet(input_shape1, input_shape2): # Encoder 1 inputs1 = layers.Input(shape=input_shape1, name="legend_patch") - x1, p1 = encoder_block(inputs1, 64, attention=False) - x1, p1 = encoder_block(p1, 128) - x1, p1 = encoder_block(p1, 256) - x1, p1 = encoder_block(p1, 512) - x1 = encoder_block(p1, 1024, pool=False) + x1_1, p1_1 = encoder_block(inputs1, 64) + x1_2, p1_2 = encoder_block(p1_1, 128) + x1_3, p1_3 = encoder_block(p1_2, 256) + x1_4, p1_4 = encoder_block(p1_3, 512) + x1_5 = encoder_block(p1_4, 1024, pool=False) # Latent Space for Encoder 1 - z_mean = layers.Conv2D(1024, (1, 1))(x1) - z_log_var = layers.Conv2D(1024, (1, 1))(x1) + z_mean = layers.Conv2D(1024, (1, 1))(x1_5) + z_log_var = layers.Conv2D(1024, (1, 1))(x1_5) z = layers.Lambda(sampling)([z_mean, z_log_var]) # Encoder 2 with attention inputs2 = layers.Input(shape=input_shape2, name="map_patch") - x2_1, p2 = encoder_block(inputs2, 64, attention=True) - x2_2, p2 = encoder_block(p2, 128, attention=True) - x2_3, p2 = encoder_block(p2, 256, attention=True) - x2_4, p2 = encoder_block(p2, 512, attention=True) - x2 = encoder_block(p2, 1024, attention=True, pool=False) + x2_1, p2_1 = encoder_block(inputs2, 64, attention=True) + x2_2, p2_2 = encoder_block(p2_1, 128, attention=True) + x2_3, p2_3 = encoder_block(p2_2, 256, attention=True) + x2_4, p2_4 = encoder_block(p2_3, 512, attention=True) + x2 = encoder_block(p2_4, 1024, attention=True, pool=False) # Concatenate at the bottleneck bottleneck = layers.Concatenate()([z, x2]) @@ -85,10 +86,18 @@ def variational_unet(input_shape1, input_shape2): return Model(inputs=[inputs1, inputs2], outputs=[outputs, [outputs, z_mean, z_log_var]]) +def f1_score(y_true, y_pred): # Dice coefficient + smooth = 1. + y_true_f = K.flatten(y_true) + y_pred_f = K.flatten(y_pred) + intersection = K.sum(y_true_f * y_pred_f) + return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) + + def vae_loss(y_true, output, beta=1.0): y_pred, z_mean, z_log_var = output[0], output[1], output[2] # Reconstruction loss - recon_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred) + recon_loss = f1_score(y_true, y_pred) recon_loss = tf.reduce_mean(recon_loss) # KL divergence loss @@ -103,99 +112,102 @@ model.summary() model.compile(optimizer='adam', loss=[None, vae_loss]) -def load_img(filename, map_dir, legend_dir, seg_dir): - mapName = tf.strings.join([map_dir, filename[0]], separator='/') - legendName = tf.strings.join([legend_dir, filename[1]], separator='/') +# Assuming 'model' is your instantiated model +plot_model(model, to_file='model_plot.png', show_shapes=True, show_layer_names=True) + +# def load_img(filename, map_dir, legend_dir, seg_dir): +# mapName = tf.strings.join([map_dir, filename[0]], separator='/') +# legendName = tf.strings.join([legend_dir, filename[1]], separator='/') - # Load and preprocess map_img - map_img = tf.io.read_file(mapName) - map_img = tf.cast(tf.io.decode_png(map_img), dtype=tf.float32) / 255.0 - map_img = tf.image.resize(map_img, [256, 256]) +# # Load and preprocess map_img +# map_img = tf.io.read_file(mapName) +# map_img = tf.cast(tf.io.decode_png(map_img), dtype=tf.float32) / 255.0 +# map_img = tf.image.resize(map_img, [256, 256]) - # Load and preprocess legend_img - legend_img = tf.io.read_file(legendName) - legend_img = tf.cast(tf.io.decode_png(legend_img), dtype=tf.float32) / 255.0 - legend_img = tf.image.resize(legend_img, [256, 256]) - - # Load and preprocess seg_img - segName = tf.strings.join([seg_dir, filename[0]], separator='/') - seg_img = tf.io.read_file(segName) - seg_img = tf.io.decode_png(seg_img) - seg_img = tf.image.resize(seg_img, [256, 256]) - - return (legend_img, map_img), seg_img - -def load_train_img(filename): - return load_img(filename, - '/projects/bbym/shared/all_patched_data/training/poly/map_patches', - '/projects/bbym/shared/all_patched_data/training/poly/legend', - '/projects/bbym/shared/all_patched_data/training/poly/seg_patches') - -def load_validation_img(filename): - return load_img(filename, - '/projects/bbym/shared/all_patched_data/validation/poly/map_patches', - '/projects/bbym/shared/all_patched_data/validation/poly/legend', - '/projects/bbym/shared/all_patched_data/validation/poly/seg_patches') +# # Load and preprocess legend_img +# legend_img = tf.io.read_file(legendName) +# legend_img = tf.cast(tf.io.decode_png(legend_img), dtype=tf.float32) / 255.0 +# legend_img = tf.image.resize(legend_img, [256, 256]) + +# # Load and preprocess seg_img +# segName = tf.strings.join([seg_dir, filename[0]], separator='/') +# seg_img = tf.io.read_file(segName) +# seg_img = tf.io.decode_png(seg_img) +# seg_img = tf.image.resize(seg_img, [256, 256]) + +# return (legend_img, map_img), seg_img + +# def load_train_img(filename): +# return load_img(filename, +# '/projects/bbym/shared/all_patched_data/training/poly/map_patches', +# '/projects/bbym/shared/all_patched_data/training/poly/legend', +# '/projects/bbym/shared/all_patched_data/training/poly/seg_patches') + +# def load_validation_img(filename): +# return load_img(filename, +# '/projects/bbym/shared/all_patched_data/validation/poly/map_patches', +# '/projects/bbym/shared/all_patched_data/validation/poly/legend', +# '/projects/bbym/shared/all_patched_data/validation/poly/seg_patches') -train_map_file = os.listdir('/projects/bbym/shared/all_patched_data/training/poly/map_patches') -random.shuffle(train_map_file) - -# Pre-filter map files based on existence of corresponding legend files -train_map_legend_names = [(x, '_'.join(x.split('_')[0:-2])+'.png') - for x in train_map_file - if os.path.exists(os.path.join('/projects/bbym/shared/all_patched_data/training/poly/legend', - '_'.join(x.split('_')[0:-2])+'.png'))] - -train_dataset = tf.data.Dataset.from_tensor_slices(train_map_legend_names) -train_dataset = train_dataset.map(load_train_img) -train_dataset = train_dataset.shuffle(5000, reshuffle_each_iteration=False).batch(128) - -validate_map_file = os.listdir('/projects/bbym/shared/all_patched_data/validation/poly/map_patches') - -# Pre-filter map files based on existence of corresponding legend files -validate_map_legend_names = [(x, '_'.join(x.split('_')[0:-2])+'.png') - for x in validate_map_file - if os.path.exists(os.path.join('/projects/bbym/shared/all_patched_data/validation/poly/legend', - '_'.join(x.split('_')[0:-2])+'.png'))] - -validate_dataset = tf.data.Dataset.from_tensor_slices(validate_map_legend_names) -validate_dataset = validate_dataset.map(load_validation_img) -validate_dataset = validate_dataset.batch(50) - -################################################################ -##### Prepare the model configurations ######################### -################################################################ -#You can change the id for each run so that all models and stats are saved separately. -name_id = "VAE-unet" -prediction_path = './predicts_'+name_id+'/' -log_path = './logs_'+name_id+'/' -model_path = './models_'+name_id+'/' -save_model_path = './models_'+name_id+'/' - -# Create the folder if it does not exist -os.makedirs(model_path, exist_ok=True) -os.makedirs(prediction_path, exist_ok=True) - -name = 'VAE-unet' - -logdir = log_path + name - -if(os.path.isdir(logdir)): - shutil.rmtree(logdir) - -os.makedirs(logdir, exist_ok=True) -tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir) - -print('model location: '+ model_path+name+'.h5') -# define hyperparameters and callback modules -patience = 10 -maxepoch = 500 -callbacks = [ReduceLROnPlateau(monitor='val_loss', factor=0.7, patience=patience, min_lr=1e-9, verbose=1, mode='min'), - EarlyStopping(monitor='val_loss', patience=patience, verbose=0), - ModelCheckpoint(model_path+name+'.h5', monitor='val_loss', save_best_only=True, verbose=0), - TensorBoard(log_dir=logdir)] - -train_history = model.fit(train_dataset, validation_data = validate_dataset, - batch_size = 16, epochs = maxepoch, verbose=1, - callbacks = callbacks) \ No newline at end of file +# train_map_file = os.listdir('/projects/bbym/shared/all_patched_data/training/poly/map_patches') +# random.shuffle(train_map_file) + +# # Pre-filter map files based on existence of corresponding legend files +# train_map_legend_names = [(x, '_'.join(x.split('_')[0:-2])+'.png') +# for x in train_map_file +# if os.path.exists(os.path.join('/projects/bbym/shared/all_patched_data/training/poly/legend', +# '_'.join(x.split('_')[0:-2])+'.png'))] + +# train_dataset = tf.data.Dataset.from_tensor_slices(train_map_legend_names) +# train_dataset = train_dataset.map(load_train_img) +# train_dataset = train_dataset.shuffle(5000, reshuffle_each_iteration=False).batch(128) + +# validate_map_file = os.listdir('/projects/bbym/shared/all_patched_data/validation/poly/map_patches') + +# # Pre-filter map files based on existence of corresponding legend files +# validate_map_legend_names = [(x, '_'.join(x.split('_')[0:-2])+'.png') +# for x in validate_map_file +# if os.path.exists(os.path.join('/projects/bbym/shared/all_patched_data/validation/poly/legend', +# '_'.join(x.split('_')[0:-2])+'.png'))] + +# validate_dataset = tf.data.Dataset.from_tensor_slices(validate_map_legend_names) +# validate_dataset = validate_dataset.map(load_validation_img) +# validate_dataset = validate_dataset.batch(50) + +# ################################################################ +# ##### Prepare the model configurations ######################### +# ################################################################ +# #You can change the id for each run so that all models and stats are saved separately. +# name_id = "VAE-unet" +# prediction_path = './predicts_'+name_id+'/' +# log_path = './logs_'+name_id+'/' +# model_path = './models_'+name_id+'/' +# save_model_path = './models_'+name_id+'/' + +# # Create the folder if it does not exist +# os.makedirs(model_path, exist_ok=True) +# os.makedirs(prediction_path, exist_ok=True) + +# name = 'VAE-unet' + +# logdir = log_path + name + +# if(os.path.isdir(logdir)): +# shutil.rmtree(logdir) + +# os.makedirs(logdir, exist_ok=True) +# tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir) + +# print('model location: '+ model_path+name+'.h5') +# # define hyperparameters and callback modules +# patience = 10 +# maxepoch = 500 +# callbacks = [ReduceLROnPlateau(monitor='val_loss', factor=0.7, patience=patience, min_lr=1e-9, verbose=1, mode='min'), +# EarlyStopping(monitor='val_loss', patience=patience, verbose=0), +# ModelCheckpoint(model_path+name+'.h5', monitor='val_loss', save_best_only=True, verbose=0), +# TensorBoard(log_dir=logdir)] + +# train_history = model.fit(train_dataset, validation_data = validate_dataset, +# batch_size = 16, epochs = maxepoch, verbose=1, +# callbacks = callbacks) \ No newline at end of file diff --git a/eval_gan.py b/eval_gan.py index e10578914df1bc568869564e7c76f2701e3ce63d..371f60e2343534af71a598164600ff194d108e9f 100644 --- a/eval_gan.py +++ b/eval_gan.py @@ -12,7 +12,6 @@ import matplotlib.pyplot as plt import matplotlib.image as mpimg import segmentation_models as sm from tensorflow.keras import layers -from keras.models import load_model from tensorflow.keras.models import Model, load_model from tensorflow.keras.layers import Input, Conv2D, RandomFlip, RandomRotation from tensorflow.keras.optimizers import Adam, SGD, RMSprop diff --git a/eval_vauner.py b/eval_vauner.py new file mode 100644 index 0000000000000000000000000000000000000000..659cffd632113428aabc90ed7b613b040280a25f --- /dev/null +++ b/eval_vauner.py @@ -0,0 +1,180 @@ +import tensorflow as tf +import random +import os +import shutil +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.image as mpimg +from keras import backend as K +from tensorflow.keras import layers, Model +from tensorflow.keras.models import load_model +from tensorflow.keras.optimizers import Adam +from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TensorBoard + +def vae_loss(y_true, output, beta=1.0): + y_pred, z_mean, z_log_var = output[0], output[1], output[2] + # Reconstruction loss + recon_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred) + recon_loss = tf.reduce_mean(recon_loss) + + # KL divergence loss + kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + kl_loss = -0.5 * tf.reduce_sum(kl_loss, axis=-1) + + # Total loss + return recon_loss + beta * kl_loss + +def f1_score(y_true, y_pred): # Dice coefficient + smooth = 1. + y_true_f = K.flatten(y_true) + y_pred_f = K.flatten(y_pred) + intersection = K.sum(y_true_f * y_pred_f) + return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) + +def load_img(filename, map_dir, legend_dir, seg_dir): + mapName = tf.strings.join([map_dir, filename[0]], separator='/') + legendName = tf.strings.join([legend_dir, filename[1]], separator='/') + + # Load and preprocess map_img + map_img = tf.io.read_file(mapName) + map_img = tf.cast(tf.io.decode_png(map_img), dtype=tf.float32) / 255.0 + map_img = tf.image.resize(map_img, [256, 256]) + + # Load and preprocess legend_img + legend_img = tf.io.read_file(legendName) + legend_img = tf.cast(tf.io.decode_png(legend_img), dtype=tf.float32) / 255.0 + legend_img = tf.image.resize(legend_img, [256, 256]) + + # Load and preprocess seg_img + segName = tf.strings.join([seg_dir, filename[0]], separator='/') + seg_img = tf.io.read_file(segName) + seg_img = tf.io.decode_png(seg_img) + seg_img = tf.image.resize(seg_img, [256, 256]) + + return (legend_img, map_img), seg_img + +def load_validation_img(filename): + return load_img(filename, + '/projects/bbym/shared/all_patched_data/validation/poly/map_patches', + '/projects/bbym/shared/all_patched_data/validation/poly/legend', + '/projects/bbym/shared/all_patched_data/validation/poly/seg_patches') + +validate_map_file = os.listdir('/projects/bbym/shared/all_patched_data/validation/poly/map_patches') + +# Pre-filter map files based on existence of corresponding legend files +validate_map_legend_names = [(x, '_'.join(x.split('_')[0:-2])+'.png') + for x in validate_map_file + if os.path.exists(os.path.join('/projects/bbym/shared/all_patched_data/validation/poly/legend', + '_'.join(x.split('_')[0:-2])+'.png'))] + +validate_dataset = tf.data.Dataset.from_tensor_slices(validate_map_legend_names) +validate_dataset = validate_dataset.map(load_validation_img) +validate_dataset = validate_dataset.batch(50) + +################################################################ +##### Prepare the model configurations ######################### +################################################################ +#You can change the id for each run so that all models and stats are saved separately. +name_id = "VAE-unet" +prediction_path = './predicts_'+name_id+'/' +log_path = './logs_'+name_id+'/' +model_path = './models_'+name_id+'/' +save_model_path = './models_'+name_id+'/' + +# Create the folder if it does not exist +os.makedirs(model_path, exist_ok=True) +os.makedirs(prediction_path, exist_ok=True) + +print('model location: '+ model_path + name_id + '.h5') + +model = load_model(os.path.join(model_path, name_id + '.h5'), + custom_objects={'vae_loss': vae_loss}) + +model.compile(optimizer = Adam(), + loss = vae_loss, + metrics=[f1_score, vae_loss]) + + +# If validate_dataset is a tf.data.Dataset instance +def plotResult(fileName, save_dir): + + test_dataset = tf.data.Dataset.from_tensor_slices([fileName]) + test_dataset = test_dataset.map(load_validation_img) + test_dataset = test_dataset.batch(1) + + predicted = model.predict(test_dataset) + print(predicted[0].shape) + + # Thresholding the predicted result to get binary values + threshold = 0.5 # you can adjust this value based on your requirement + predicted_binary = (predicted[0] > threshold).astype(np.uint8) # convert boolean to integer (1 or 0) + + mapName = '/projects/bbym/shared/all_patched_data/validation/poly/map_patches/' + fileName[0] + segName = '/projects/bbym/shared/all_patched_data/validation/poly/seg_patches/' + fileName[0] + legendName = '/projects/bbym/shared/all_patched_data/validation/poly/legend/' + fileName[1] + # legendName = '/projects/bbym/nathanj/validation/legend/' + fileName[1] + + map_img = mpimg.imread(mapName) + seg_img = mpimg.imread(segName) + label_img = mpimg.imread(legendName) + + # Set the figure size + plt.figure(figsize=(10, 2)) + + # Plot map image + plt.subplot(1, 5, 1) + plt.title("map") + plt.imshow(map_img) + # Plot legend image + plt.subplot(1, 5, 2) + plt.title("legend") + plt.imshow(label_img) + + # Plot true segmentation image + plt.subplot(1, 5, 3) + plt.title("label") + plt.imshow(seg_img, cmap='gray') + + # Plot predicted segmentation image + plt.subplot(1, 5, 4) + plt.title("prediction") + plt.imshow(predicted_binary[0, :, :, 0]*255, cmap='gray') + + # Plot error image + plt.subplot(1, 5, 5) + plt.title("error") + + # Normalize both images to the range [0, 1] if they aren't already + seg_img_normalized = seg_img / 255.0 if seg_img.max() > 1 else seg_img + predicted_normalized = predicted_binary[0, :, :, 0] if predicted_binary.max() <= 1 else predicted_binary[0, :, :, 0] / 255.0 + + # Calculate the error image + # error_img = seg_img_normalized - predicted_normalized # simple difference + error_img = np.logical_xor(predicted_binary[0, :, :, 0], seg_img) + + # Alternatively, for absolute difference: + # error_img = np.abs(seg_img_normalized - predicted_normalized) + + # Display the error image + cax = plt.imshow(error_img, cmap='gray') + + # Set the color scale limits if necessary + # cax.set_clim(vmin=-1, vmax=1) # adjust as needed + + # # Add color bar to help interpret the error image + # cbar = plt.colorbar(cax, orientation='vertical', shrink=0.75) + # cbar.set_label('Error Magnitude', rotation=270, labelpad=15) + + # Save the entire figure + plt.savefig(prediction_path + fileName[0] + '.png') + + # Close the figure to release resources + plt.close() + + +n=20 + +for fileName in random.sample(validate_map_legend_names, n): + print(fileName) + plotResult(fileName, prediction_path) +print("Save Images Done!") \ No newline at end of file