| author | Christian C <cc@localhost> | 2024-11-11 12:29:32 -0800 |
|---|---|---|
| committer | Christian C <cc@localhost> | 2024-11-11 12:29:32 -0800 |
| commit | b85ee9d64a536937912544c7bbd5b98b635b7e8d (patch) | |
| tree | cef7bc17d7b29f40fc6b1867d0ce0a742d5583d0 /code/sunlab/sunflow/models/adversarial_autoencoder.py | |
Initial commit
Diffstat (limited to 'code/sunlab/sunflow/models/adversarial_autoencoder.py')
| -rw-r--r-- | code/sunlab/sunflow/models/adversarial_autoencoder.py | 344 |

1 file changed, 344 insertions, 0 deletions
diff --git a/code/sunlab/sunflow/models/adversarial_autoencoder.py b/code/sunlab/sunflow/models/adversarial_autoencoder.py
new file mode 100644
index 0000000..4cbb2f8
--- /dev/null
+++ b/code/sunlab/sunflow/models/adversarial_autoencoder.py
@@ -0,0 +1,344 @@

```python
from typing import Optional

from sunlab.common.data.dataset import Dataset
from sunlab.common.scaler.adversarial_scaler import AdversarialScaler
from sunlab.common.distribution.adversarial_distribution import AdversarialDistribution
from .encoder import Encoder
from .decoder import Decoder
from .discriminator import Discriminator
from .encoder_discriminator import EncoderDiscriminator
from .autoencoder import Autoencoder
from tensorflow.keras import optimizers, metrics, losses
import tensorflow as tf
from numpy import ones, zeros, float32, NaN


class AdversarialAutoencoder:
    """# Adversarial Autoencoder

    - distribution: The distribution used by the adversary to learn on"""

    def __init__(
        self,
        model_base_directory,
        distribution: Optional[AdversarialDistribution] = None,
        scaler: Optional[AdversarialScaler] = None,
    ):
        """# Adversarial Autoencoder Model Initialization

        - model_base_directory: The base folder directory where the model will
          be saved/loaded
        - distribution: The distribution the adversary will use
        - scaler: The scaling function the model will assume on the data"""
        self.model_base_directory = model_base_directory
        self.distribution = distribution
        if scaler is not None:
            self.scaler = scaler(self.model_base_directory)
        else:
            self.scaler = None

    def init(
        self,
        data=None,
        data_size=13,
        autoencoder_layer_size=16,
        adversary_layer_size=8,
        latent_size=2,
        autoencoder_depth=2,
        dropout=0.0,
        use_leaky_relu=False,
        **kwargs,
    ):
        """# Initialize AAE model parameters

        - data_size: int
        - autoencoder_layer_size: int
        - adversary_layer_size: int
        - latent_size: int
        - autoencoder_depth: int
        - dropout: float
        - use_leaky_relu: boolean"""
        self.data_size = data_size
        self.autoencoder_layer_size = autoencoder_layer_size
        self.adversary_layer_size = adversary_layer_size
        self.latent_size = latent_size
        self.autoencoder_depth = autoencoder_depth
        self.dropout = dropout
        self.use_leaky_relu = use_leaky_relu
        self.save_parameters()
        self.encoder = Encoder(self.model_base_directory).init()
        self.decoder = Decoder(self.model_base_directory).init()
        self.autoencoder = Autoencoder(self.model_base_directory).init(
            self.encoder, self.decoder
        )
        self.discriminator = Discriminator(self.model_base_directory).init()
        self.encoder_discriminator = EncoderDiscriminator(
            self.model_base_directory
        ).init(self.encoder, self.discriminator)
        if self.distribution is not None:
            self.distribution = self.distribution(self.latent_size)
        if (data is not None) and (self.scaler is not None):
            self.scaler = self.scaler.init(data)
        self.init_optimizers_and_metrics(**kwargs)
        return self

    def init_optimizers_and_metrics(
        self,
        optimizer=optimizers.Adam,
        ae_metric=metrics.MeanAbsoluteError,
        adv_metric=metrics.BinaryCrossentropy,
        ae_lr=7e-4,
        adv_lr=3e-4,
        loss_fn=losses.BinaryCrossentropy,
        **kwargs,
    ):
        """# Set the optimizer, loss function, and metrics"""
        self.ae_optimizer = optimizer(learning_rate=ae_lr)
        self.adv_optimizer = optimizer(learning_rate=adv_lr)
        self.gan_optimizer = optimizer(learning_rate=adv_lr)
        self.train_ae_metric = ae_metric()
        self.val_ae_metric = ae_metric()
        self.train_adv_metric = adv_metric()
        self.val_adv_metric = adv_metric()
        self.train_gan_metric = adv_metric()
        self.val_gan_metric = adv_metric()
        self.loss_fn = loss_fn()

    def load(self):
        """# Load the models from their respective files"""
        self.load_parameters()
        self.encoder = Encoder(self.model_base_directory).load()
        self.decoder = Decoder(self.model_base_directory).load()
        self.autoencoder = Autoencoder(self.model_base_directory).load()
        self.discriminator = Discriminator(self.model_base_directory).load()
        self.encoder_discriminator = EncoderDiscriminator(
            self.model_base_directory
        ).load()
        if self.scaler is not None:
            self.scaler = self.scaler.load()
        return self

    def save(self, overwrite=False):
        """# Save each model in the AAE"""
        self.encoder.save(overwrite=overwrite)
        self.decoder.save(overwrite=overwrite)
        self.autoencoder.save(overwrite=overwrite)
        self.discriminator.save(overwrite=overwrite)
        self.encoder_discriminator.save(overwrite=overwrite)
        if self.scaler is not None:
            self.scaler.save()

    def save_parameters(self):
        """# Save the AAE parameters in a file"""
        from pickle import dump
        from os import makedirs

        makedirs(self.model_base_directory + "/portable/", exist_ok=True)
        parameters = {
            "data_size": self.data_size,
            "autoencoder_layer_size": self.autoencoder_layer_size,
            "adversary_layer_size": self.adversary_layer_size,
            "latent_size": self.latent_size,
            "autoencoder_depth": self.autoencoder_depth,
            "dropout": self.dropout,
            "use_leaky_relu": self.use_leaky_relu,
        }
        with open(
            f"{self.model_base_directory}/portable/model_parameters.pkl", "wb"
        ) as phandle:
            dump(parameters, phandle)

    def load_parameters(self):
        """# Load the AAE parameters from a file"""
        from pickle import load

        with open(
            f"{self.model_base_directory}/portable/model_parameters.pkl", "rb"
        ) as phandle:
            parameters = load(phandle)
        self.data_size = parameters["data_size"]
        self.autoencoder_layer_size = parameters["autoencoder_layer_size"]
        self.adversary_layer_size = parameters["adversary_layer_size"]
        self.latent_size = parameters["latent_size"]
        self.autoencoder_depth = parameters["autoencoder_depth"]
        self.dropout = parameters["dropout"]
        self.use_leaky_relu = parameters["use_leaky_relu"]
        return parameters

    def summary(self):
        """# Summarize each model in the AAE"""
        self.encoder.summary()
        self.decoder.summary()
        self.autoencoder.summary()
        self.discriminator.summary()
        self.encoder_discriminator.summary()

    @tf.function
    def train_step(self, x, y):
        """# Training Step

        1. Train the Autoencoder
        2. (If distribution is given) Train the discriminator
        3. (If the distribution is given) Train the encoder_discriminator"""
        # Autoencoder Training
        with tf.GradientTape() as tape:
            decoded_vector = self.autoencoder(x, training=True)
            ae_loss_value = self.loss_fn(y, decoded_vector)
        grads = tape.gradient(ae_loss_value, self.autoencoder.model.trainable_weights)
        self.ae_optimizer.apply_gradients(
            zip(grads, self.autoencoder.model.trainable_weights)
        )
        self.train_ae_metric.update_state(y, decoded_vector)
        if self.distribution is not None:
            # Adversary Training: encoded samples are labeled 0, prior samples 1
            with tf.GradientTape() as tape:
                latent_vector = self.encoder(x)
                fakepred = self.distribution(x.shape[0])
                discbatch_x = tf.concat([latent_vector, fakepred], axis=0)
                discbatch_y = tf.concat([zeros(x.shape[0]), ones(x.shape[0])], axis=0)
                adversary_vector = self.discriminator(discbatch_x, training=True)
                adv_loss_value = self.loss_fn(discbatch_y, adversary_vector)
            grads = tape.gradient(
                adv_loss_value, self.discriminator.model.trainable_weights
            )
            self.adv_optimizer.apply_gradients(
                zip(grads, self.discriminator.model.trainable_weights)
            )
            self.train_adv_metric.update_state(discbatch_y, adversary_vector)
            # GAN Training: update only the encoder so it fools the discriminator
            with tf.GradientTape() as tape:
                gan_vector = self.encoder_discriminator(x, training=True)
                adv_vector = tf.convert_to_tensor(ones((x.shape[0], 1), dtype=float32))
                # Targets (ones) come first in the Keras loss signature
                gan_loss_value = self.loss_fn(adv_vector, gan_vector)
            grads = tape.gradient(gan_loss_value, self.encoder.model.trainable_weights)
            self.gan_optimizer.apply_gradients(
                zip(grads, self.encoder.model.trainable_weights)
            )
            self.train_gan_metric.update_state(adv_vector, gan_vector)
            return (ae_loss_value, adv_loss_value, gan_loss_value)
        return (ae_loss_value, None, None)

    @tf.function
    def test_step(self, x, y):
        """# Test Step - On validation data

        1. Evaluate the Autoencoder
        2. (If distribution is given) Evaluate the discriminator
        3. (If the distribution is given) Evaluate the encoder_discriminator"""
        val_decoded_vector = self.autoencoder(x, training=False)
        self.val_ae_metric.update_state(y, val_decoded_vector)
        if self.distribution is not None:
            latent_vector = self.encoder(x)
            fakepred = self.distribution(x.shape[0])
            discbatch_x = tf.concat([latent_vector, fakepred], axis=0)
            discbatch_y = tf.concat([zeros(x.shape[0]), ones(x.shape[0])], axis=0)
            adversary_vector = self.discriminator(discbatch_x, training=False)
            self.val_adv_metric.update_state(discbatch_y, adversary_vector)
            gan_vector = self.encoder_discriminator(x, training=False)
            self.val_gan_metric.update_state(ones(x.shape[0]), gan_vector)

    # Garbage Collect at the end of each epoch
    def on_epoch_end(self, _epoch, logs=None):
        """# Clean up the environment each epoch to prevent memory leaks"""
        import gc
        from tensorflow.keras import backend as k

        gc.collect()
        k.clear_session()

    def train(
        self,
        dataset: Dataset,
        epoch_count: int = 1,
        output=False,
        output_freq=1,
        fmt="%i[%.3f]: %.2e %.2e %.2e %.2e %.2e %.2e",
    ):
        """# Train the model on a dataset

        - dataset: Dataset = The dataset to train the model on, with the
          training and validation iterators set up
        - epoch_count: int = The number of epochs to train
        - output: boolean = Whether or not to output training information
        - output_freq: int = The number of epochs between each output"""
        from time import time
        from numpy import array as narray

        def fmtter(x):
            return x if x is not None else -1

        epoch_data = []
        dataset.reset_iterators()

        # Log the initial validation metrics before any training
        self.test_step(dataset.dataset, dataset.dataset)
        val_ae = self.val_ae_metric.result()
        val_adv = self.val_adv_metric.result()
        val_gan = self.val_gan_metric.result()
        self.val_ae_metric.reset_states()
        self.val_adv_metric.reset_states()
        self.val_gan_metric.reset_states()
        print(
            fmt
            % (
                0,
                NaN,
                val_ae,
                fmtter(val_adv),
                fmtter(val_gan),
                NaN,
                NaN,
                NaN,
            )
        )
        for epoch in range(epoch_count):
            start_time = time()

            for step, (x_batch_train, y_batch_train) in enumerate(dataset.training):
                ae_lv, adv_lv, gan_lv = self.train_step(x_batch_train, x_batch_train)

            train_ae = self.train_ae_metric.result()
            train_adv = self.train_adv_metric.result()
            train_gan = self.train_gan_metric.result()
            self.train_ae_metric.reset_states()
            self.train_adv_metric.reset_states()
            self.train_gan_metric.reset_states()

            for step, (x_batch_val, y_batch_val) in enumerate(dataset.validation):
                self.test_step(x_batch_val, x_batch_val)

            val_ae = self.val_ae_metric.result()
            val_adv = self.val_adv_metric.result()
            val_gan = self.val_gan_metric.result()
            self.val_ae_metric.reset_states()
            self.val_adv_metric.reset_states()
            self.val_gan_metric.reset_states()

            epoch_data.append(
                (
                    epoch,
                    train_ae,
                    val_ae,
                    fmtter(train_adv),
                    fmtter(val_adv),
                    fmtter(train_gan),
                    fmtter(val_gan),
                )
            )
            if output and (epoch + 1) % output_freq == 0:
                print(
                    fmt
                    % (
                        epoch + 1,
                        time() - start_time,
                        train_ae,
                        fmtter(train_adv),
                        fmtter(train_gan),
                        val_ae,
                        fmtter(val_adv),
                        fmtter(val_gan),
                    )
                )
            self.on_epoch_end(epoch)
            dataset.reset_iterators()
        return narray(epoch_data)
```
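The `distribution` argument is passed as a class: `init()` instantiates it with `latent_size`, and `train_step`/`test_step` then call the instance with a batch size to draw prior samples. The sketch below shows a minimal standard-normal prior satisfying that call protocol; it is a hypothetical example, not part of this commit, and it deliberately avoids assuming anything about `AdversarialDistribution`'s constructor beyond what this file exercises (duck typing suffices here).

```python
# Hypothetical prior for the adversary; only the interface used by
# AdversarialAutoencoder is implemented: construct with latent_size,
# call with a batch size, get back a (batch_size, latent_size) array.
from numpy import float32
from numpy.random import normal


class GaussianPrior:
    """Standard-normal prior over the latent space (illustrative sketch)."""

    def __init__(self, latent_size):
        self.latent_size = latent_size

    def __call__(self, batch_size):
        # Draw batch_size samples from N(0, 1) in the latent space
        return normal(0.0, 1.0, size=(batch_size, self.latent_size)).astype(float32)
```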
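For orientation, here is a hedged end-to-end usage sketch under stated assumptions: `my_dataset` stands in for a `Dataset` providing the members this file relies on (`.dataset`, `.training`, `.validation`, `reset_iterators()`), the model directory path is arbitrary, and `GaussianPrior` refers to the illustrative prior sketched in the previous example.

```python
# Illustrative usage only; my_dataset and the directory path are assumptions.
from sunlab.sunflow.models.adversarial_autoencoder import AdversarialAutoencoder

aae = AdversarialAutoencoder(
    "models/aae-example/",       # where weights and parameters are stored
    distribution=GaussianPrior,  # pass the class; init() calls it with latent_size
)
aae.init(data=None, data_size=13, latent_size=2)  # builds all sub-models
history = aae.train(my_dataset, epoch_count=10, output=True)
aae.save(overwrite=True)

# Later, restore everything from the same model_base_directory:
aae_restored = AdversarialAutoencoder("models/aae-example/").load()
```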