diff --git a/GAN_model.py b/GAN_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4eff11c8c86c2d3063de38e99875879cd63e05f
--- /dev/null
+++ b/GAN_model.py
@@ -0,0 +1,234 @@
+import tensorflow as tf
+import random
+import os
+import numpy as np
+from tensorflow.keras import layers
+from tqdm import tqdm
+from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint, TensorBoard
+
+
+#You can change the id for each run so that all models and stats are saved separately.
+name_id = "GAN_test"
+prediction_path = './predicts_'+name_id+'/'
+log_path = './logs_'+name_id+'/'
+model_path = './models_'+name_id+'/'
+save_model_path = './models_'+name_id+'/'
+
+# Create the folder if it does not exist
+os.makedirs(model_path, exist_ok=True)
+os.makedirs(prediction_path, exist_ok=True)
+
+# Hyperparameters
+z_dim = 100
+input_shape = (256, 256, 6)
+
+def build_generator():
+    input_image = layers.Input(shape=input_shape, name='input_image')
+    z = layers.Input(shape=(z_dim,), name='z')
+    
+    # Transform the noise
+    z_transformed = layers.Dense(256*256*6, activation='relu')(z)
+    z_transformed = layers.Reshape((256, 256, 6))(z_transformed)
+
+    # Concatenate the transformed noise and input image
+    combined = layers.Concatenate(axis=-1)([input_image, z_transformed])
+    
+    x = layers.Conv2D(64, (3, 3), strides=(1, 1), padding='same', activation='relu')(combined)
+    x = layers.BatchNormalization()(x)
+    
+    x = layers.Conv2D(32, (3, 3), strides=(1, 1), padding='same', activation='relu')(x)
+    x = layers.BatchNormalization()(x)
+
+    # Generate a segmentation map
+    output = layers.Conv2D(1, (3, 3), activation='sigmoid', padding='same', name='output')(x)
+    
+    model = tf.keras.Model(inputs=[input_image, z], outputs=output)
+    return model
+
+# Discriminator
+def build_discriminator():
+    input_image = layers.Input(shape=input_shape, name='input_image')
+    segmentation_map = layers.Input(shape=(256, 256, 1), name='segmentation_map')
+
+    combined = layers.Concatenate()([input_image, segmentation_map])
+
+    x = layers.Conv2D(32, (3, 3), strides=(2, 2), padding='same', activation='relu')(combined)
+    x = layers.BatchNormalization()(x)
+    
+    x = layers.Conv2D(64, (3, 3), strides=(2, 2), padding='same', activation='relu')(x)
+    x = layers.BatchNormalization()(x)
+
+    x = layers.Flatten()(x)
+    output = layers.Dense(1, activation='sigmoid')(x)
+    
+    model = tf.keras.Model(inputs=[input_image, segmentation_map], outputs=output)
+    return model
+
+generator = build_generator()
+discriminator = build_discriminator()
+
+# Losses and optimizers
+bce_loss = tf.keras.losses.BinaryCrossentropy()
+optimizer_g = tf.keras.optimizers.Adam(1e-4)
+optimizer_d = tf.keras.optimizers.Adam(1e-4)
+
+# Define training steps
+@tf.function
+def train_step(images, masks, batch_size):
+    # Generate noise for generator
+    noise = tf.random.normal([tf.shape(images)[0], z_dim])
+    
+    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
+        # Generate fake segmentation maps using generator
+        generated_masks = generator([images, noise], training=True)
+        
+        # Discriminator output for real and fake images
+        real_output = discriminator([images, masks], training=True)
+        fake_output = discriminator([images, generated_masks], training=True)
+        
+        # Generator loss: Adversarial loss + L1 loss for generated mask
+        gen_loss = bce_loss(tf.ones_like(fake_output), fake_output)
+        l1_loss = tf.reduce_mean(tf.abs(masks - generated_masks))
+        total_gen_loss = gen_loss + (100.0 * l1_loss)
+        
+        # Discriminator loss
+        real_loss = bce_loss(tf.ones_like(real_output), real_output)
+        fake_loss = bce_loss(tf.zeros_like(fake_output), fake_output)
+        disc_loss = (real_loss + fake_loss) * 0.5
+        
+    # Calculate gradients and apply updates
+    gradients_of_generator = gen_tape.gradient(total_gen_loss, generator.trainable_variables)
+    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
+    
+    optimizer_g.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
+    optimizer_d.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
+    
+    return total_gen_loss, disc_loss
+
+# Helper function to compute F1 score
+def compute_f1_score(true_masks, pred_masks):
+    pred_masks_bin = tf.cast(pred_masks > 0.5, tf.float32)
+    
+    TP = tf.reduce_sum(true_masks * pred_masks_bin)
+    FP = tf.reduce_sum((1 - true_masks) * pred_masks_bin)
+    FN = tf.reduce_sum(true_masks * (1 - pred_masks_bin))
+    F1 = (2 * TP) / (2 * TP + FP + FN + 1e-7) 
+    return F1.numpy()
+
+def dice_coef(y_true, y_pred, smooth=1.0):
+    intersection = tf.reduce_sum(y_true * y_pred, axis=[1,2,3])
+    union = tf.reduce_sum(y_true + y_pred, axis=[1,2,3])
+    return tf.reduce_mean((2. * intersection + smooth) / (union + smooth), axis=0)
+
+def dice_coef_loss(y_true, y_pred):
+    return 1 - dice_coef(y_true, y_pred)
+
+
+generator.compile(optimizer=optimizer_g, loss=dice_coef_loss)
+discriminator.compile(optimizer=optimizer_d, loss=bce_loss)
+
+def train(dataset, val_dataset, epochs, batch_size):
+    history = {'train_loss': [], 'val_loss': [], 'f1_score': []}
+    
+    # Initialize the best_val_loss with a high value
+    best_val_loss = float('inf')
+    
+    for epoch in tqdm(range(epochs), desc="Training"):
+        train_losses = []  # store training losses for this epoch
+        
+        # Training
+        for image_batch, mask_batch in dataset:
+            gen_loss, disc_loss = train_step(image_batch, mask_batch, batch_size)  # Assuming you already have the `train_step` function
+            train_losses.append(gen_loss)
+            
+        # Validation
+        total_f1_score = 0
+        val_losses = []  # store validation losses for this epoch
+        
+        for val_image_batch, val_mask_batch in val_dataset:
+            # Generate segmentation masks using the generator
+            noise = tf.random.normal([val_image_batch.shape[0], z_dim])
+            val_pred_masks = generator([val_image_batch, noise], training=False)
+            
+            # Compute validation loss and F1 score
+            val_loss = dice_coef_loss(val_mask_batch, val_pred_masks)  # Implement your own loss_function or use a suitable one from TF
+            total_f1_score += compute_f1_score(val_mask_batch, val_pred_masks)
+            
+            val_losses.append(val_loss.numpy())
+
+        # Averaging metrics
+        avg_train_loss = np.mean(train_losses)
+        avg_val_loss = np.mean(val_losses)
+        avg_f1_score = total_f1_score / len(val_dataset)
+        
+        history['train_loss'].append(avg_train_loss)
+        history['val_loss'].append(avg_val_loss)
+        history['f1_score'].append(avg_f1_score)
+
+        # Check if this epoch's val_loss is better than the best so far
+        if avg_val_loss < best_val_loss:
+            best_val_loss = avg_val_loss
+            generator.save(os.path.join(model_path, name_id + '_generator.h5'))
+
+    return history
+
+
+def load_img(filename, map_dir, legend_dir, seg_dir):
+    mapName = tf.strings.join([map_dir, filename[0]], separator='/')
+    legendName = tf.strings.join([legend_dir, filename[1]], separator='/')
+    
+    map_img = tf.io.read_file(mapName)
+    map_img = tf.cast(tf.io.decode_png(map_img), dtype=tf.float32) / 255.0
+
+    legend_img = tf.io.read_file(legendName)
+    legend_img = tf.cast(tf.io.decode_png(legend_img), dtype=tf.float32) / 255.0
+
+    map_img = tf.concat(axis=2, values=[map_img, legend_img])
+    map_img = map_img*2.0 - 1.0
+    map_img = tf.image.resize(map_img, [256, 256])
+
+    segName = tf.strings.join([seg_dir, filename[0]], separator='/')
+    seg_img = tf.io.read_file(segName)
+    seg_img = tf.io.decode_png(seg_img)
+    seg_img = tf.image.resize(seg_img, [256, 256])
+
+    return map_img, seg_img
+
+def load_train_img(filename):
+    return load_img(filename, 
+                    '/projects/bbym/shared/all_patched_data/training/poly/map_patches',
+                    '/projects/bbym/shared/all_patched_data/training/poly/legend',
+                    '/projects/bbym/shared/all_patched_data/training/poly/seg_patches')
+
+def load_validation_img(filename):
+    return load_img(filename, 
+                    '/projects/bbym/shared/all_patched_data/validation/poly/map_patches',
+                    '/projects/bbym/shared/all_patched_data/validation/poly/legend',
+                    '/projects/bbym/shared/all_patched_data/validation/poly/seg_patches')
+
+train_map_file = os.listdir('/projects/bbym/shared/all_patched_data/training/poly/map_patches')
+random.shuffle(train_map_file)
+
+# Pre-filter map files based on existence of corresponding legend files
+train_map_legend_names = [(x, '_'.join(x.split('_')[0:-2])+'.png') 
+                            for x in train_map_file 
+                            if os.path.exists(os.path.join('/projects/bbym/shared/all_patched_data/training/poly/legend', 
+                                                        '_'.join(x.split('_')[0:-2])+'.png'))]
+
+train_dataset = tf.data.Dataset.from_tensor_slices(train_map_legend_names)
+train_dataset = train_dataset.map(load_train_img)
+train_dataset = train_dataset.shuffle(5000, reshuffle_each_iteration=False).batch(128)
+
+validate_map_file = os.listdir('/projects/bbym/shared/all_patched_data/validation/poly/map_patches')
+
+# Pre-filter map files based on existence of corresponding legend files
+validate_map_legend_names = [(x, '_'.join(x.split('_')[0:-2])+'.png') 
+                            for x in validate_map_file 
+                            if os.path.exists(os.path.join('/projects/bbym/shared/all_patched_data/validation/poly/legend', 
+                                                        '_'.join(x.split('_')[0:-2])+'.png'))]
+
+validate_dataset = tf.data.Dataset.from_tensor_slices(validate_map_legend_names)
+validate_dataset = validate_dataset.map(load_validation_img)
+validate_dataset = validate_dataset.batch(50)
+
+train(train_dataset, validate_dataset, epochs=100, batch_size=128)
diff --git a/VAE-unet.py b/VAE-unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..81491159831901ed548e264f1ebcf660941fce13
--- /dev/null
+++ b/VAE-unet.py
@@ -0,0 +1,201 @@
+import tensorflow as tf
+import random
+import os
+import shutil
+from tensorflow.keras import layers, Model
+from tensorflow.keras.optimizers import Adam
+from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TensorBoard
+
+def sampling(args):
+    """Reparameterization trick."""
+    z_mean, z_log_var = args
+    batch, height, width, channels = tf.shape(z_mean)[0], tf.shape(z_mean)[1], tf.shape(z_mean)[2], tf.shape(z_mean)[3]
+    epsilon = tf.keras.backend.random_normal(shape=(batch, height, width, channels))
+    return z_mean + tf.exp(0.5 * z_log_var) * epsilon
+
+def attention_block(x, g, inter_channel):
+    """Attention block. `x` is the local feature and `g` is the wider context."""
+    theta_x = layers.Conv2D(inter_channel, (1, 1), strides=(1, 1))(x)
+    
+    phi_g = layers.Conv2D(inter_channel, (1, 1), strides=(1, 1))(g)
+    phi_g = layers.UpSampling2D(size=(2, 2))(phi_g)
+    
+    f = layers.Add()([theta_x, phi_g])
+    f = layers.Activation('relu')(f)
+    
+    psi_f = layers.Conv2D(1, (1, 1), strides=(1, 1))(f)
+    psi_f = layers.Activation('sigmoid')(psi_f)
+    
+    return layers.Multiply()([x, psi_f])
+
+def encoder_block(inputs, filters, attention=False, pool=True):
+    x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(inputs)
+    x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
+    if pool:
+        if attention:
+            g = layers.MaxPooling2D(pool_size=(2, 2))(x)
+            x = attention_block(x, g, filters//2)
+            return x, layers.MaxPooling2D(pool_size=(2, 2))(x)
+        else:
+            return x, layers.MaxPooling2D(pool_size=(2, 2))(x)
+    else:
+        return x
+
+def decoder_block(inputs, skip_features, filters):
+    x = layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(inputs)
+    x = layers.Concatenate()([x, skip_features])
+    x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
+    x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
+    return x
+
+def variational_unet(input_shape1, input_shape2):
+    # Encoder 1
+    inputs1 = layers.Input(shape=input_shape1, name="legend_patch")
+    x1, p1 = encoder_block(inputs1, 64, attention=False)
+    x1, p1 = encoder_block(p1, 128)
+    x1, p1 = encoder_block(p1, 256)
+    x1, p1 = encoder_block(p1, 512)
+    x1 = encoder_block(p1, 1024, pool=False)
+    
+    # Latent Space for Encoder 1
+    z_mean = layers.Conv2D(1024, (1, 1))(x1)
+    z_log_var = layers.Conv2D(1024, (1, 1))(x1)
+    z = layers.Lambda(sampling)([z_mean, z_log_var])
+    
+    # Encoder 2 with attention
+    inputs2 = layers.Input(shape=input_shape2, name="map_patch")
+    x2_1, p2 = encoder_block(inputs2, 64, attention=True)
+    x2_2, p2 = encoder_block(p2, 128, attention=True)
+    x2_3, p2 = encoder_block(p2, 256, attention=True)
+    x2_4, p2 = encoder_block(p2, 512, attention=True)
+    x2 = encoder_block(p2, 1024, attention=True, pool=False)
+    
+    # Concatenate at the bottleneck
+    bottleneck = layers.Concatenate()([z, x2])
+
+    print(bottleneck.shape, z.shape, x2.shape)
+
+    # Decoder
+    x = decoder_block(bottleneck, x2_4, 512)
+    x = decoder_block(x, x2_3, 256)
+    x = decoder_block(x, x2_2, 128)
+    x = decoder_block(x, x2_1, 64)
+    
+    outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(x)
+    
+    return Model(inputs=[inputs1, inputs2], outputs=[outputs, [outputs, z_mean, z_log_var]])
+
+def vae_loss(y_true, output, beta=1.0):
+    y_pred, z_mean, z_log_var = output[0], output[1], output[2]
+    # Reconstruction loss
+    recon_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
+    recon_loss = tf.reduce_mean(recon_loss)
+
+    # KL divergence loss
+    kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
+    kl_loss = -0.5 * tf.reduce_sum(kl_loss, axis=-1)
+
+    # Total loss
+    return recon_loss + beta * kl_loss
+
+model = variational_unet((256, 256, 3), (256, 256, 3))
+model.summary()
+model.compile(optimizer='adam', loss=[None, vae_loss])
+
+
+def load_img(filename, map_dir, legend_dir, seg_dir):
+    mapName = tf.strings.join([map_dir, filename[0]], separator='/')
+    legendName = tf.strings.join([legend_dir, filename[1]], separator='/')
+    
+    # Load and preprocess map_img
+    map_img = tf.io.read_file(mapName)
+    map_img = tf.cast(tf.io.decode_png(map_img), dtype=tf.float32) / 255.0
+    map_img = tf.image.resize(map_img, [256, 256])
+    
+    # Load and preprocess legend_img
+    legend_img = tf.io.read_file(legendName)
+    legend_img = tf.cast(tf.io.decode_png(legend_img), dtype=tf.float32) / 255.0
+    legend_img = tf.image.resize(legend_img, [256, 256])
+
+    # Load and preprocess seg_img
+    segName = tf.strings.join([seg_dir, filename[0]], separator='/')
+    seg_img = tf.io.read_file(segName)
+    seg_img = tf.io.decode_png(seg_img)
+    seg_img = tf.image.resize(seg_img, [256, 256])
+
+    return (legend_img, map_img), seg_img
+
+def load_train_img(filename):
+    return load_img(filename, 
+                    '/projects/bbym/shared/all_patched_data/training/poly/map_patches',
+                    '/projects/bbym/shared/all_patched_data/training/poly/legend',
+                    '/projects/bbym/shared/all_patched_data/training/poly/seg_patches')
+
+def load_validation_img(filename):
+    return load_img(filename, 
+                    '/projects/bbym/shared/all_patched_data/validation/poly/map_patches',
+                    '/projects/bbym/shared/all_patched_data/validation/poly/legend',
+                    '/projects/bbym/shared/all_patched_data/validation/poly/seg_patches')
+
+                    
+train_map_file = os.listdir('/projects/bbym/shared/all_patched_data/training/poly/map_patches')
+random.shuffle(train_map_file)
+
+# Pre-filter map files based on existence of corresponding legend files
+train_map_legend_names = [(x, '_'.join(x.split('_')[0:-2])+'.png') 
+                            for x in train_map_file 
+                            if os.path.exists(os.path.join('/projects/bbym/shared/all_patched_data/training/poly/legend', 
+                                                        '_'.join(x.split('_')[0:-2])+'.png'))]
+
+train_dataset = tf.data.Dataset.from_tensor_slices(train_map_legend_names)
+train_dataset = train_dataset.map(load_train_img)
+train_dataset = train_dataset.shuffle(5000, reshuffle_each_iteration=False).batch(128)
+
+validate_map_file = os.listdir('/projects/bbym/shared/all_patched_data/validation/poly/map_patches')
+
+# Pre-filter map files based on existence of corresponding legend files
+validate_map_legend_names = [(x, '_'.join(x.split('_')[0:-2])+'.png') 
+                            for x in validate_map_file 
+                            if os.path.exists(os.path.join('/projects/bbym/shared/all_patched_data/validation/poly/legend', 
+                                                        '_'.join(x.split('_')[0:-2])+'.png'))]
+
+validate_dataset = tf.data.Dataset.from_tensor_slices(validate_map_legend_names)
+validate_dataset = validate_dataset.map(load_validation_img)
+validate_dataset = validate_dataset.batch(50)
+
+################################################################
+##### Prepare the model configurations #########################
+################################################################
+#You can change the id for each run so that all models and stats are saved separately.
+name_id = "VAE-unet"
+prediction_path = './predicts_'+name_id+'/'
+log_path = './logs_'+name_id+'/'
+model_path = './models_'+name_id+'/'
+save_model_path = './models_'+name_id+'/'
+
+# Create the folder if it does not exist
+os.makedirs(model_path, exist_ok=True)
+os.makedirs(prediction_path, exist_ok=True)
+
+name = 'VAE-unet'
+
+logdir = log_path + name
+
+if(os.path.isdir(logdir)):
+    shutil.rmtree(logdir)
+
+os.makedirs(logdir, exist_ok=True)
+tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
+
+print('model location: '+ model_path+name+'.h5')
+# define hyperparameters and callback modules
+patience = 10
+maxepoch = 500
+callbacks = [ReduceLROnPlateau(monitor='val_loss', factor=0.7, patience=patience, min_lr=1e-9, verbose=1, mode='min'),
+            EarlyStopping(monitor='val_loss', patience=patience, verbose=0),
+            ModelCheckpoint(model_path+name+'.h5', monitor='val_loss', save_best_only=True, verbose=0),
+            TensorBoard(log_dir=logdir)]
+
+train_history = model.fit(train_dataset, validation_data = validate_dataset, 
+                            batch_size = 16, epochs = maxepoch, verbose=1, 
+                            callbacks = callbacks)
\ No newline at end of file
diff --git a/data_util.py b/data_util.py
index bcc7243721ab780772ff807556c763e90b175c54..5adf7366a567d4a1caa834497752a26238ec2a4e 100644
--- a/data_util.py
+++ b/data_util.py
@@ -31,6 +31,7 @@ class DataLoader:
     process_legends(label_suffix, resize_to=(256, 256)): Processes legends based on label suffix.
     process_data(): Extracts and processes data from the loaded TIFF and JSON.
     get_processed_data(): Returns the processed data.
+    reconstruct_data(self, patches): unpatchify the prediction data. 
     """
     def __init__(self, tiff_path, json_path, patch_size=(256, 256, 3), overlap=30):
         """
diff --git a/eval_gan.py b/eval_gan.py
new file mode 100644
index 0000000000000000000000000000000000000000..e10578914df1bc568869564e7c76f2701e3ce63d
--- /dev/null
+++ b/eval_gan.py
@@ -0,0 +1,162 @@
+import matplotlib.pyplot as plt
+import matplotlib.image as mpimg
+
+import os
+import shutil
+import random
+import numpy as np 
+import tensorflow as tf
+from datetime import datetime
+from keras import backend as K
+import matplotlib.pyplot as plt
+import matplotlib.image as mpimg
+import segmentation_models as sm
+from tensorflow.keras import layers
+from keras.models import load_model
+from tensorflow.keras.models import Model, load_model
+from tensorflow.keras.layers import Input, Conv2D, RandomFlip, RandomRotation
+from tensorflow.keras.optimizers import Adam, SGD, RMSprop
+from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TensorBoard
+from unet_util import dice_coef_loss, dice_coef, jacard_coef, dice_coef_loss, Residual_CNN_block, multiplication, attention_up_and_concatenate, multiplication2, attention_up_and_concatenate2, UNET_224, evaluate_prediction_result
+
+def f1_score(y_true, y_pred): # Dice coefficient
+    smooth = 1.
+    y_true_f = K.flatten(y_true)
+    y_pred_f = K.flatten(y_pred)
+    intersection = K.sum(y_true_f * y_pred_f)
+    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
+
+################################################################
+##### Prepare the model configurations #########################
+################################################################
+#You can change the id for each run so that all models and stats are saved separately.
+name_id = "GAN_test"
+prediction_path = './predicts_'+name_id+'/'
+model_path = './models_'+name_id+'/'
+save_model_path = './models_'+name_id+'/'
+
+os.makedirs(model_path, exist_ok=True)
+os.makedirs(prediction_path, exist_ok=True)
+
+print('model location: '+ model_path + name_id + '_generator.h5')
+
+generator = load_model(os.path.join(model_path, name_id + '_generator.h5'),  
+                    custom_objects={'dice_coef_loss': dice_coef_loss})
+
+generator.compile(optimizer = Adam(),
+                    loss = dice_coef_loss,
+                    metrics=[dice_coef,'accuracy', f1_score])
+
+z_dim = 100  # global variable
+
+def load_img_with_noise(filename, map_dir, legend_dir):
+    mapName = tf.strings.join([map_dir, filename[0]], separator='/')
+    legendName = tf.strings.join([legend_dir, filename[1]], separator='/')
+    
+    map_img = tf.io.read_file(mapName)
+    map_img = tf.cast(tf.io.decode_png(map_img), dtype=tf.float32) / 255.0
+
+    legend_img = tf.io.read_file(legendName)
+    legend_img = tf.cast(tf.io.decode_png(legend_img), dtype=tf.float32) / 255.0
+
+    map_img = tf.concat(axis=2, values=[map_img, legend_img])
+    map_img = map_img*2.0 - 1.0
+    map_img = tf.image.resize(map_img, [256, 256])
+
+    noise = tf.random.normal([1, z_dim])  # generating noise for one image
+
+    return map_img, noise
+
+def generate_image(filename, generator, map_dir, legend_dir):
+    map_img, noise = load_img_with_noise(filename, map_dir, legend_dir)
+    
+    # Ensure that map_img has batch dimension
+    map_img = tf.expand_dims(map_img, 0)
+    
+    generated_img = generator([map_img, noise], training=False)
+
+    # Removing the batch dimension for visualization
+    generated_img = tf.squeeze(generated_img)
+
+    return generated_img.numpy()
+
+# def load_validation_img(filename):
+#     return generate_image(filename, 
+#                             generator,  # Add the generator here
+#                             '/projects/bbym/shared/all_patched_data/validation/poly/map_patches',
+#                             '/projects/bbym/shared/all_patched_data/validation/poly/legend')
+
+validate_map_file = os.listdir('/projects/bbym/shared/all_patched_data/validation/poly/map_patches')
+validate_map_names = [(x, '_'.join(x.split('_')[0:-2])+'.png') for x in validate_map_file]
+# validate_dataset = tf.data.Dataset.from_tensor_slices(validate_map_names)
+# validate_dataset = validate_dataset.map(load_validation_img)
+# validate_dataset = validate_dataset.batch(50)
+
+print("Load Data Done!")
+
+
+print("Load Model Done!")
+
+def dice_coef(y_true, y_pred, smooth=1.0):
+    intersection = tf.reduce_sum(y_true * y_pred, axis=[1,2,3])
+    union = tf.reduce_sum(y_true + y_pred, axis=[1,2,3])
+    return tf.reduce_mean((2. * intersection + smooth) / (union + smooth), axis=0)
+
+def dice_coef_loss(y_true, y_pred):
+    return 1 - dice_coef(y_true, y_pred)
+
+# model.summary()
+# eval_results = model.evaluate(validate_dataset, verbose=1)
+# print(eval_results)
+# print(f'Validation F1 score: {f1}')
+
+# If validate_dataset is a tf.data.Dataset instance
+def plotResult(fileName, save_dir, generator):
+    map_img, noise = load_img_with_noise(fileName, 
+                                            '/projects/bbym/shared/all_patched_data/validation/poly/map_patches',
+                                            '/projects/bbym/shared/all_patched_data/validation/poly/legend')
+    generated_img = generator([tf.expand_dims(map_img, 0), noise], training=False)
+    predicted_binary = (generated_img > 0.5).numpy().astype(np.uint8)  # thresholding
+    
+    mapName = '/projects/bbym/shared/all_patched_data/validation/poly/map_patches/' + fileName[0]
+    segName = '/projects/bbym/shared/all_patched_data/validation/poly/seg_patches/' + fileName[0]
+    legendName = '/projects/bbym/shared/all_patched_data/validation/poly/legend/' + fileName[1]
+
+    map_img = mpimg.imread(mapName)
+    seg_img = mpimg.imread(segName)
+    legend_img = mpimg.imread(legendName)
+    
+    # Visualization
+    plt.figure(figsize=(10, 2))
+
+    plt.subplot(1, 5, 1)
+    plt.title("map")
+    plt.imshow(map_img)
+
+    plt.subplot(1, 5, 2)
+    plt.title("legend")
+    plt.imshow(legend_img)
+
+    plt.subplot(1, 5, 3)
+    plt.title("true segmentation")
+    plt.imshow(seg_img, cmap='gray')
+
+    plt.subplot(1, 5, 4)
+    plt.title("predicted segmentation")
+    plt.imshow(predicted_binary[0, :, :, 0]*255, cmap='gray') 
+
+    plt.subplot(1, 5, 5)
+    plt.title("error")
+    error_img = np.logical_xor(predicted_binary[0, :, :, 0], seg_img/255.0)  # assuming seg_img is in [0, 255]
+    plt.imshow(error_img, cmap='gray')
+
+    plt.savefig(save_dir + fileName[0] + '.png')
+    plt.close()
+
+n=20
+
+for fileName in random.sample(validate_map_names, n):
+    print(fileName)
+    plotResult(fileName, prediction_path, generator)  # Assuming generator is available in this context
+print("Save Images Done!")
+
diff --git a/two-h-model.py b/two-h-model.py
index e5d727d2d8545883bfcf8da7aad2939faee9c98a..e9a327191384f6e355067629a7601a853d416f56 100644
--- a/two-h-model.py
+++ b/two-h-model.py
@@ -73,8 +73,8 @@ def latent_sample(p):
     return mean + stddev * eps
 
 def variational_unet(input_shape):
-    input1 = tf.keras.Input(shape=input_shape) # legend patch + map patch
-    input2 = tf.keras.Input(shape=input_shape) # map patch
+    input1 = tf.keras.Input(shape=input_shape, name="legend_patch") # legend patch + map patch
+    input2 = tf.keras.Input(shape=input_shape, name="map_patch") # map patch
 
     # Encoder for input1
     x1 = layers.Conv2D(32, (1, 1), padding="same", activation="relu")(input1)  # Adjusting input depth to 3
@@ -107,11 +107,13 @@ def variational_unet(input_shape):
         bottleneck = decoder_block(bottleneck, skip, filters)
 
     # Final convolution to get the segmentation result
-    output = layers.Conv2D(1, (1, 1), activation="sigmoid")(bottleneck)
+    output = layers.Conv2D(1, (1, 1), activation="sigmoid", name="output")(bottleneck)
 
-    print("output-shape:", output.shape, qs.shape, ps.shape)
+    # combined_outputs = concatenate([output, qs, ps], name="combined_output")
 
-    return models.Model(inputs=[input1, input2], outputs=[output, qs, ps])
+    # print("output-shape:", output.shape, qs.shape, ps.shape)
+
+    return models.Model(inputs=[input1, input2], outputs=[output, [output, qs, ps]])
 
 # Instantiate the model with the input shape
 model = variational_unet((256, 256, 3))
@@ -294,8 +296,12 @@ tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
 
 print('model location: '+ model_path+name+'.h5')
 
+def dumm_loss(*args, **kwargs):
+    return 0.0
+
 # Compile the model with 'Adam' optimizer (0.001 is the default learning rate) and define the loss and metrics
-model.compile(optimizer = Adam(),  loss=CombinedLoss(),  metrics=[DiceCoefficientMetric(), SegmentationAccuracy()])
+model.compile(optimizer = Adam(),  
+            loss=[dice_coef_loss, CombinedLoss()])
 
 # define hyperparameters and callback modules
 patience = 10