path: root/code/sunlab/suntorch
Diffstat (limited to 'code/sunlab/suntorch')
-rw-r--r--  code/sunlab/suntorch/__init__.py                                        3
-rw-r--r--  code/sunlab/suntorch/data/__init__.py                                   1
-rw-r--r--  code/sunlab/suntorch/data/utilities.py                                 55
-rw-r--r--  code/sunlab/suntorch/models/__init__.py                                 3
-rw-r--r--  code/sunlab/suntorch/models/adversarial_autoencoder.py                165
-rw-r--r--  code/sunlab/suntorch/models/autoencoder.py                            114
-rw-r--r--  code/sunlab/suntorch/models/common.py                                  12
-rw-r--r--  code/sunlab/suntorch/models/convolutional/variational/autoencoder.py 190
-rw-r--r--  code/sunlab/suntorch/models/decoder.py                                 33
-rw-r--r--  code/sunlab/suntorch/models/discriminator.py                           32
-rw-r--r--  code/sunlab/suntorch/models/encoder.py                                 32
-rw-r--r--  code/sunlab/suntorch/models/variational/autoencoder.py                128
-rw-r--r--  code/sunlab/suntorch/models/variational/common.py                      12
-rw-r--r--  code/sunlab/suntorch/models/variational/decoder.py                     33
-rw-r--r--  code/sunlab/suntorch/models/variational/encoder.py                     34
-rw-r--r--  code/sunlab/suntorch/plotting/__init__.py                               1
-rw-r--r--  code/sunlab/suntorch/plotting/model_extensions.py                      34
17 files changed, 882 insertions, 0 deletions
diff --git a/code/sunlab/suntorch/__init__.py b/code/sunlab/suntorch/__init__.py
new file mode 100644
index 0000000..d394e27
--- /dev/null
+++ b/code/sunlab/suntorch/__init__.py
@@ -0,0 +1,3 @@
+from ..common import *
+from .models import *
+from .plotting import *
diff --git a/code/sunlab/suntorch/data/__init__.py b/code/sunlab/suntorch/data/__init__.py
new file mode 100644
index 0000000..b9a32c0
--- /dev/null
+++ b/code/sunlab/suntorch/data/__init__.py
@@ -0,0 +1 @@
+from .utilities import *
diff --git a/code/sunlab/suntorch/data/utilities.py b/code/sunlab/suntorch/data/utilities.py
new file mode 100644
index 0000000..05318f5
--- /dev/null
+++ b/code/sunlab/suntorch/data/utilities.py
@@ -0,0 +1,55 @@
+from sunlab.common import ShapeDataset
+from sunlab.common import MaxAbsScaler
+
+
+def process_and_load_dataset(
+ dataset_file, model_folder, magnification=10, scaler=MaxAbsScaler
+):
+ """# Load a dataset and process a models' Latent Space on the Dataset"""
+ raise NotImplemented("Still Implementing for PyTorch")
+ from ..models import load_aae
+ from sunlab.common import import_full_dataset
+
+ model = load_aae(model_folder, normalization_scaler=scaler)
+ dataset = import_full_dataset(
+ dataset_file, magnification=magnification, scaler=model.scaler
+ )
+ latent = model.encoder(dataset.dataset).numpy()
+ assert len(latent.shape) == 2, "Only 1D Latent Vectors Supported"
+ for dim in range(latent.shape[1]):
+ dataset.dataframe[f"Latent-{dim}"] = latent[:, dim]
+ return dataset
+
+
+def process_and_load_datasets(
+ dataset_file_list, model_folder, magnification=10, scaler=MaxAbsScaler
+):
+    """# Load Multiple Datasets and Process a Model's Latent Space on Each"""
+    from pandas import concat
+    from ..models import load_aae
+
+    raise NotImplementedError("Still Implementing for PyTorch")
+ dataframes = []
+ datasets = []
+ for dataset_file in dataset_file_list:
+ dataset = process_and_load_dataset(
+ dataset_file, model_folder, magnification, scaler
+ )
+ model = load_aae(model_folder, normalization_scaler=scaler)
+ dataframe = dataset.dataframe
+ for label in ["ActinEdge", "Filopodia", "Bleb", "Lamellipodia"]:
+ if label in dataframe.columns:
+ dataframe[label.lower()] = dataframe[label]
+ if label.lower() not in dataframe.columns:
+ dataframe[label.lower()] = 0
+ latent_columns = [f"Latent-{dim}" for dim in range(model.latent_size)]
+ datasets.append(dataset)
+ dataframes.append(
+ dataframe[
+ dataset.data_columns
+ + dataset.label_columns
+ + latent_columns
+ + ["Frames", "CellNum"]
+ + ["actinedge", "filopodia", "bleb", "lamellipodia"]
+ ]
+ )
+ return datasets, concat(dataframes)
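
Both loaders still raise NotImplementedError; the following is a minimal sketch of the intended call pattern once the PyTorch port is complete. The file names and model folder are hypothetical placeholders.

    # Hypothetical usage once the PyTorch port is implemented; paths are placeholders
    datasets, combined = process_and_load_datasets(
        ["dataset_a.csv", "dataset_b.csv"],  # assumed dataset file list
        "models/current_model/",             # assumed model folder
        magnification=10,
    )
    print(combined[["Latent-0", "Latent-1"]].head())  # assumes latent_size >= 2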
diff --git a/code/sunlab/suntorch/models/__init__.py b/code/sunlab/suntorch/models/__init__.py
new file mode 100644
index 0000000..03d6a45
--- /dev/null
+++ b/code/sunlab/suntorch/models/__init__.py
@@ -0,0 +1,3 @@
+from .adversarial_autoencoder import AdversarialAutoencoder
+from .autoencoder import Autoencoder
+from .variational.autoencoder import VariationalAutoencoder
diff --git a/code/sunlab/suntorch/models/adversarial_autoencoder.py b/code/sunlab/suntorch/models/adversarial_autoencoder.py
new file mode 100644
index 0000000..e3e8a06
--- /dev/null
+++ b/code/sunlab/suntorch/models/adversarial_autoencoder.py
@@ -0,0 +1,165 @@
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+
+from .encoder import Encoder
+from .decoder import Decoder
+from .discriminator import Discriminator
+from .common import *
+
+
+def dummy_distribution(*args):
+ raise NotImplementedError("Give a distribution")
+
+
+class AdversarialAutoencoder:
+ """# Adversarial Autoencoder Model"""
+
+ def __init__(
+ self,
+ data_dim,
+ latent_dim,
+ enc_dec_size,
+ disc_size,
+ negative_slope=0.3,
+ dropout=0.0,
+ distribution=dummy_distribution,
+ ):
+ self.encoder = Encoder(
+ data_dim,
+ enc_dec_size,
+ latent_dim,
+ negative_slope=negative_slope,
+ dropout=dropout,
+ )
+ self.decoder = Decoder(
+ data_dim,
+ enc_dec_size,
+ latent_dim,
+ negative_slope=negative_slope,
+ dropout=dropout,
+ )
+ self.discriminator = Discriminator(
+ disc_size, latent_dim, negative_slope=negative_slope, dropout=dropout
+ )
+ self.data_dim = data_dim
+ self.latent_dim = latent_dim
+ self.p = dropout
+ self.negative_slope = negative_slope
+ self.distribution = distribution
+ return
+
+ def parameters(self):
+ return (
+ *self.encoder.parameters(),
+ *self.decoder.parameters(),
+ *self.discriminator.parameters(),
+ )
+
+ def train(self):
+ self.encoder.train(True)
+ self.decoder.train(True)
+ self.discriminator.train(True)
+ return self
+
+ def eval(self):
+ self.encoder.train(False)
+ self.decoder.train(False)
+ self.discriminator.train(False)
+ return self
+
+ def encode(self, data):
+ return self.encoder(data)
+
+ def decode(self, data):
+ return self.decoder(data)
+
+ def __call__(self, data):
+ return self.decode(self.encode(data))
+
+ def save(self, base="./"):
+ torch.save(self.encoder.state_dict(), base + "weights_encoder.pt")
+ torch.save(self.decoder.state_dict(), base + "weights_decoder.pt")
+ torch.save(self.discriminator.state_dict(), base + "weights_discriminator.pt")
+ return self
+
+ def load(self, base="./"):
+ self.encoder.load_state_dict(torch.load(base + "weights_encoder.pt"))
+ self.encoder.eval()
+ self.decoder.load_state_dict(torch.load(base + "weights_decoder.pt"))
+ self.decoder.eval()
+ self.discriminator.load_state_dict(
+ torch.load(base + "weights_discriminator.pt")
+ )
+ self.discriminator.eval()
+ return self
+
+ def to(self, device):
+ self.encoder.to(device)
+ self.decoder.to(device)
+ self.discriminator.to(device)
+ return self
+
+ def cuda(self):
+ if torch.cuda.is_available():
+ self.encoder = self.encoder.cuda()
+ self.decoder = self.decoder.cuda()
+ self.discriminator = self.discriminator.cuda()
+ return self
+
+ def cpu(self):
+ self.encoder = self.encoder.cpu()
+ self.decoder = self.decoder.cpu()
+ self.discriminator = self.discriminator.cpu()
+ return self
+
+ def init_optimizers(self, recon_lr=1e-4, adv_lr=5e-5):
+ self.optim_E_gen = torch.optim.Adam(self.encoder.parameters(), lr=adv_lr)
+ self.optim_E_enc = torch.optim.Adam(self.encoder.parameters(), lr=recon_lr)
+ self.optim_D_dec = torch.optim.Adam(self.decoder.parameters(), lr=recon_lr)
+ self.optim_D_dis = torch.optim.Adam(self.discriminator.parameters(), lr=adv_lr)
+ return self
+
+ def init_losses(self, recon_loss_fn=F.binary_cross_entropy):
+ self.recon_loss_fn = recon_loss_fn
+ return self
+
+ def train_step(self, raw_data, scale=1.0):
+ data = to_var(raw_data.view(raw_data.size(0), -1))
+
+ self.encoder.zero_grad()
+ self.decoder.zero_grad()
+ self.discriminator.zero_grad()
+
+ z = self.encoder(data)
+ X = self.decoder(z)
+ self.recon_loss = self.recon_loss_fn(X + EPS, data + EPS)
+ self.recon_loss.backward()
+ self.optim_E_enc.step()
+ self.optim_D_dec.step()
+
+ self.encoder.eval()
+ z_gaussian = to_var(self.distribution(data.size(0), self.latent_dim) * scale)
+        # Detach so the discriminator loss does not also update the encoder
+        z_gaussian_fake = self.encoder(data).detach()
+ y_gaussian = self.discriminator(z_gaussian)
+ y_gaussian_fake = self.discriminator(z_gaussian_fake)
+ self.D_loss = -torch.mean(
+ torch.log(y_gaussian + EPS) + torch.log(1 - y_gaussian_fake + EPS)
+ )
+ self.D_loss.backward()
+ self.optim_D_dis.step()
+
+ self.encoder.train()
+ z_gaussian = self.encoder(data)
+ y_gaussian = self.discriminator(z_gaussian)
+ self.G_loss = -torch.mean(torch.log(y_gaussian + EPS))
+ self.G_loss.backward()
+ self.optim_E_gen.step()
+ return
+
+    def losses(self):
+        """# Return the last training step's losses, if any"""
+        try:
+            return self.recon_loss, self.D_loss, self.G_loss
+        except AttributeError:
+            # No training step has been run yet
+            return None
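
A minimal training-loop sketch for the class above, under assumed illustrative dimensions and a stand-in for a real DataLoader; inputs must lie in [0, 1] for the default binary cross-entropy reconstruction loss.

    import torch
    from sunlab.suntorch.models import AdversarialAutoencoder

    model = (
        AdversarialAutoencoder(
            data_dim=13, latent_dim=2, enc_dec_size=64, disc_size=64,
            distribution=torch.randn,  # called as distribution(batch_size, latent_dim)
        )
        .cuda()                        # falls back to CPU if CUDA is unavailable
        .init_optimizers(recon_lr=1e-4, adv_lr=5e-5)
        .init_losses()                 # defaults to binary cross-entropy
    )

    loader = [torch.rand(32, 13) for _ in range(100)]  # stand-in for a DataLoader
    for epoch in range(10):
        for batch in loader:
            model.train_step(batch)
        print(epoch, model.losses())   # (reconstruction, discriminator, generator)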
diff --git a/code/sunlab/suntorch/models/autoencoder.py b/code/sunlab/suntorch/models/autoencoder.py
new file mode 100644
index 0000000..232f180
--- /dev/null
+++ b/code/sunlab/suntorch/models/autoencoder.py
@@ -0,0 +1,114 @@
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+
+from .encoder import Encoder
+from .decoder import Decoder
+from .common import *
+
+
+class Autoencoder:
+ """# Autoencoder Model"""
+
+ def __init__(
+ self, data_dim, latent_dim, enc_dec_size, negative_slope=0.3, dropout=0.0
+ ):
+ self.encoder = Encoder(
+ data_dim,
+ enc_dec_size,
+ latent_dim,
+ negative_slope=negative_slope,
+ dropout=dropout,
+ )
+ self.decoder = Decoder(
+ data_dim,
+ enc_dec_size,
+ latent_dim,
+ negative_slope=negative_slope,
+ dropout=dropout,
+ )
+ self.data_dim = data_dim
+ self.latent_dim = latent_dim
+ self.p = dropout
+ self.negative_slope = negative_slope
+ return
+
+ def parameters(self):
+ return (*self.encoder.parameters(), *self.decoder.parameters())
+
+ def train(self):
+ self.encoder.train(True)
+ self.decoder.train(True)
+ return self
+
+ def eval(self):
+ self.encoder.train(False)
+ self.decoder.train(False)
+ return self
+
+ def encode(self, data):
+ return self.encoder(data)
+
+ def decode(self, data):
+ return self.decoder(data)
+
+ def __call__(self, data):
+ return self.decode(self.encode(data))
+
+ def save(self, base="./"):
+ torch.save(self.encoder.state_dict(), base + "weights_encoder.pt")
+ torch.save(self.decoder.state_dict(), base + "weights_decoder.pt")
+ return self
+
+ def load(self, base="./"):
+ self.encoder.load_state_dict(torch.load(base + "weights_encoder.pt"))
+ self.encoder.eval()
+ self.decoder.load_state_dict(torch.load(base + "weights_decoder.pt"))
+ self.decoder.eval()
+ return self
+
+    def to(self, device):
+        self.encoder.to(device)
+        self.decoder.to(device)
+        return self
+
+ def cuda(self):
+ if torch.cuda.is_available():
+ self.encoder = self.encoder.cuda()
+ self.decoder = self.decoder.cuda()
+ return self
+
+ def cpu(self):
+ self.encoder = self.encoder.cpu()
+ self.decoder = self.decoder.cpu()
+ return self
+
+ def init_optimizers(self, recon_lr=1e-4):
+ self.optim_E_enc = torch.optim.Adam(self.encoder.parameters(), lr=recon_lr)
+ self.optim_D = torch.optim.Adam(self.decoder.parameters(), lr=recon_lr)
+ return self
+
+ def init_losses(self, recon_loss_fn=F.binary_cross_entropy):
+ self.recon_loss_fn = recon_loss_fn
+ return self
+
+ def train_step(self, raw_data):
+ data = to_var(raw_data.view(raw_data.size(0), -1))
+
+ self.encoder.zero_grad()
+ self.decoder.zero_grad()
+ z = self.encoder(data)
+ X = self.decoder(z)
+ self.recon_loss = self.recon_loss_fn(X + EPS, data + EPS)
+ self.recon_loss.backward()
+ self.optim_E_enc.step()
+ self.optim_D.step()
+ return
+
+    def losses(self):
+        """# Return the last training step's reconstruction loss, if any"""
+        try:
+            return self.recon_loss
+        except AttributeError:
+            # No training step has been run yet
+            return None
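
A corresponding usage sketch for the plain autoencoder, with illustrative dimensions and random stand-in data in [0, 1].

    import torch
    from sunlab.suntorch.models import Autoencoder
    from sunlab.suntorch.models.common import to_var

    model = Autoencoder(data_dim=13, latent_dim=2, enc_dec_size=64).cuda()
    model.init_optimizers(recon_lr=1e-4).init_losses()

    data = torch.rand(256, 13)                  # stand-in for a real dataset
    for epoch in range(10):
        model.train_step(data)                  # reconstruction-only step
    latent = model.eval().encode(to_var(data))  # -> (256, 2) latent coordinates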
diff --git a/code/sunlab/suntorch/models/common.py b/code/sunlab/suntorch/models/common.py
new file mode 100644
index 0000000..f10930e
--- /dev/null
+++ b/code/sunlab/suntorch/models/common.py
@@ -0,0 +1,12 @@
+from torch.autograd import Variable
+
+EPS = 1e-15
+
+
+def to_var(x):
+ """# Convert to variable"""
+ import torch
+
+ if torch.cuda.is_available():
+ x = x.cuda()
+ return Variable(x)
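
Variable has been a no-op wrapper since PyTorch 0.4, so in practice to_var only performs the device transfer; a minimal modern equivalent is sketched below (not a drop-in from this commit).

    import torch

    def to_var_modern(x: torch.Tensor) -> torch.Tensor:
        # Variable(x) simply returns x on PyTorch >= 0.4, so only the
        # CUDA transfer in to_var still has an effect
        return x.cuda() if torch.cuda.is_available() else x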
diff --git a/code/sunlab/suntorch/models/convolutional/variational/autoencoder.py b/code/sunlab/suntorch/models/convolutional/variational/autoencoder.py
new file mode 100644
index 0000000..970f717
--- /dev/null
+++ b/code/sunlab/suntorch/models/convolutional/variational/autoencoder.py
@@ -0,0 +1,190 @@
+import torch
+from torch import nn
+
+
+class ConvolutionalVariationalAutoencoder(nn.Module):
+ def __init__(self, latent_dims, hidden_dims, image_shape, dropout=0.0):
+ super(ConvolutionalVariationalAutoencoder, self).__init__()
+
+        self.latent_dims = latent_dims  # Size of the latent space layer
+        self.hidden_dims = hidden_dims  # Filter/channel counts of the hidden layers
+        self.image_shape = image_shape  # Input image shape
+
+ self.last_channels = self.hidden_dims[-1]
+ self.in_channels = self.image_shape[0]
+        # Number of neurons once the last convolution layer is flattened
+        # (assumes square inputs, halved by each stride-2 block)
+        self.flattened_channels = int(
+            self.last_channels
+            * (self.image_shape[1] / (2 ** len(self.hidden_dims))) ** 2
+        )
+
+ # For each hidden layer we will create a Convolution Block
+ modules = []
+ for h_dim in self.hidden_dims:
+ modules.append(
+ nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=h_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ nn.BatchNorm2d(h_dim),
+ nn.LeakyReLU(),
+ nn.Dropout(p=dropout),
+ )
+ )
+
+ self.in_channels = h_dim
+
+ self.encoder = nn.Sequential(*modules)
+
+ # Here are our layers for our latent space distribution
+ self.fc_mu = nn.Linear(self.flattened_channels, latent_dims)
+ self.fc_var = nn.Linear(self.flattened_channels, latent_dims)
+
+ # Decoder input layer
+ self.decoder_input = nn.Linear(latent_dims, self.flattened_channels)
+
+        # For each convolution block in the encoder, build a symmetric decoder
+        # block that uses ConvTranspose2d (note: reverse() below mutates the
+        # caller's hidden_dims list in place)
+ self.hidden_dims.reverse()
+ modules = []
+ for h_dim in self.hidden_dims:
+ modules.append(
+ nn.Sequential(
+ nn.ConvTranspose2d(
+ in_channels=self.in_channels,
+ out_channels=h_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1,
+ ),
+ nn.BatchNorm2d(h_dim),
+ nn.LeakyReLU(),
+ nn.Dropout(p=dropout),
+ )
+ )
+
+ self.in_channels = h_dim
+
+ self.decoder = nn.Sequential(*modules)
+
+        # Final layer: the reconstructed image has the same dimensions as the input image
+ self.final_layer = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=self.image_shape[0],
+ kernel_size=3,
+ padding=1,
+ ),
+ nn.Sigmoid(),
+ )
+
+    def get_latent_dims(self):
+        return self.latent_dims
+
+ def encode(self, input):
+ """
+ Encodes the input by passing through the encoder network
+ and returns the latent codes.
+ """
+ result = self.encoder(input)
+ result = torch.flatten(result, start_dim=1)
+        # Split the result into the mu and var components of the latent Gaussian distribution
+ mu = self.fc_mu(result)
+ log_var = self.fc_var(result)
+
+ return [mu, log_var]
+
+ def decode(self, z):
+ """
+ Maps the given latent codes onto the image space.
+ """
+ result = self.decoder_input(z)
+ result = result.view(
+ -1,
+ self.last_channels,
+ int(self.image_shape[1] / (2 ** len(self.hidden_dims))),
+ int(self.image_shape[1] / (2 ** len(self.hidden_dims))),
+ )
+ result = self.decoder(result)
+ result = self.final_layer(result)
+
+ return result
+
+ def reparameterize(self, mu, log_var):
+ """
+        Reparameterization trick: sample from N(mu, var) using N(0, 1).
+ """
+ std = torch.exp(0.5 * log_var)
+ eps = torch.randn_like(std)
+
+ return mu + eps * std
+
+ def forward(self, input):
+ """
+ Forward method which will encode and decode our image.
+ """
+ mu, log_var = self.encode(input)
+ z = self.reparameterize(mu, log_var)
+
+ return [self.decode(z), input, mu, log_var, z]
+
+ def loss_function(self, recons, input, mu, log_var):
+ """
+        Computes the VAE loss: reconstruction (BCE) plus KL divergence.
+ """
+ recons_loss = nn.functional.binary_cross_entropy(
+ recons.reshape(recons.shape[0], -1),
+ input.reshape(input.shape[0], -1),
+ reduction="none",
+ ).sum(dim=-1)
+
+ kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1)
+
+ loss = (recons_loss + kld_loss).mean(dim=0)
+
+ return loss
+
+ def sample(self, num_samples, device):
+ """
+        Samples from the latent space and returns the corresponding
+ image space map.
+ """
+ z = torch.randn(num_samples, self.latent_dims)
+ z = z.to(device)
+ samples = self.decode(z)
+
+ return samples
+
+ def generate(self, x):
+ """
+        Given an input image x, returns the reconstructed image.
+ """
+ return self.forward(x)[0]
+
+ def interpolate(self, starting_inputs, ending_inputs, device, granularity=10):
+ """This function performs a linear interpolation in the latent space of the autoencoder
+ from starting inputs to ending inputs. It returns the interpolation trajectories.
+ """
+ mu, log_var = self.encode(starting_inputs.to(device))
+ starting_z = self.reparameterize(mu, log_var)
+
+ mu, log_var = self.encode(ending_inputs.to(device))
+ ending_z = self.reparameterize(mu, log_var)
+
+ t = torch.linspace(0, 1, granularity).to(device)
+
+        interp_line = torch.kron(
+            starting_z.reshape(starting_z.shape[0], -1), (1 - t).unsqueeze(-1)
+        ) + torch.kron(ending_z.reshape(ending_z.shape[0], -1), t.unsqueeze(-1))
+
+        decoded_line = self.decode(interp_line).reshape(
+            (starting_inputs.shape[0], t.shape[0]) + (starting_inputs.shape[1:])
+        )
+ return decoded_line
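
A shape sketch for the class above under assumed 1x64x64 inputs: four stride-2 blocks shrink the spatial size 64 -> 32 -> 16 -> 8 -> 4, so the flattened width is 128 * 4^2 = 2048. All dimensions are illustrative.

    import torch
    from sunlab.suntorch.models.convolutional.variational.autoencoder import (
        ConvolutionalVariationalAutoencoder,
    )

    vae = ConvolutionalVariationalAutoencoder(
        latent_dims=8,
        hidden_dims=[16, 32, 64, 128],  # four stride-2 convolution blocks
        image_shape=(1, 64, 64),
        dropout=0.1,
    )
    x = torch.rand(5, 1, 64, 64)        # batch of 5 single-channel images in [0, 1]
    recons, inputs, mu, log_var, z = vae(x)
    loss = vae.loss_function(recons, inputs, mu, log_var)
    loss.backward()                     # ready for an optimizer step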
diff --git a/code/sunlab/suntorch/models/decoder.py b/code/sunlab/suntorch/models/decoder.py
new file mode 100644
index 0000000..2eeb7a4
--- /dev/null
+++ b/code/sunlab/suntorch/models/decoder.py
@@ -0,0 +1,33 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import sigmoid
+
+
+class Decoder(nn.Module):
+ """# Decoder Neural Network
+ X_dim: Output dimension shape
+ N: Inner neuronal layer size
+ z_dim: Input dimension shape
+ """
+
+ def __init__(self, X_dim, N, z_dim, dropout=0.0, negative_slope=0.3):
+ super(Decoder, self).__init__()
+ self.lin1 = nn.Linear(z_dim, N)
+ self.lin2 = nn.Linear(N, N)
+ self.lin3 = nn.Linear(N, X_dim)
+ self.p = dropout
+ self.negative_slope = negative_slope
+
+ def forward(self, x):
+ x = self.lin1(x)
+ if self.p > 0.0:
+ x = F.dropout(x, p=self.p, training=self.training)
+ x = F.leaky_relu(x, negative_slope=self.negative_slope)
+
+ x = self.lin2(x)
+ if self.p > 0.0:
+ x = F.dropout(x, p=self.p, training=self.training)
+ x = F.leaky_relu(x, negative_slope=self.negative_slope)
+
+ x = self.lin3(x)
+ return sigmoid(x)
diff --git a/code/sunlab/suntorch/models/discriminator.py b/code/sunlab/suntorch/models/discriminator.py
new file mode 100644
index 0000000..9249095
--- /dev/null
+++ b/code/sunlab/suntorch/models/discriminator.py
@@ -0,0 +1,32 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import sigmoid
+
+
+class Discriminator(nn.Module):
+ """# Discriminator Neural Network
+ N: Inner neuronal layer size
+ z_dim: Input dimension shape
+ """
+
+ def __init__(self, N, z_dim, dropout=0.0, negative_slope=0.3):
+ super(Discriminator, self).__init__()
+ self.lin1 = nn.Linear(z_dim, N)
+ self.lin2 = nn.Linear(N, N)
+ self.lin3 = nn.Linear(N, 1)
+ self.p = dropout
+ self.negative_slope = negative_slope
+
+ def forward(self, x):
+ x = self.lin1(x)
+ if self.p > 0.0:
+ x = F.dropout(x, p=self.p, training=self.training)
+ x = F.leaky_relu(x, negative_slope=self.negative_slope)
+
+ x = self.lin2(x)
+ if self.p > 0.0:
+ x = F.dropout(x, p=self.p, training=self.training)
+ x = F.leaky_relu(x, negative_slope=self.negative_slope)
+
+ x = self.lin3(x)
+ return sigmoid(x)
diff --git a/code/sunlab/suntorch/models/encoder.py b/code/sunlab/suntorch/models/encoder.py
new file mode 100644
index 0000000..e6f88c7
--- /dev/null
+++ b/code/sunlab/suntorch/models/encoder.py
@@ -0,0 +1,32 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Encoder(nn.Module):
+ """# Encoder Neural Network
+ X_dim: Input dimension shape
+ N: Inner neuronal layer size
+ z_dim: Output dimension shape
+ """
+
+ def __init__(self, X_dim, N, z_dim, dropout=0.0, negative_slope=0.3):
+ super(Encoder, self).__init__()
+ self.lin1 = nn.Linear(X_dim, N)
+ self.lin2 = nn.Linear(N, N)
+ self.lin3gauss = nn.Linear(N, z_dim)
+ self.p = dropout
+ self.negative_slope = negative_slope
+
+ def forward(self, x):
+ x = self.lin1(x)
+ if self.p > 0.0:
+ x = F.dropout(x, p=self.p, training=self.training)
+ x = F.leaky_relu(x, negative_slope=self.negative_slope)
+
+ x = self.lin2(x)
+ if self.p > 0.0:
+ x = F.dropout(x, p=self.p, training=self.training)
+ x = F.leaky_relu(x, negative_slope=self.negative_slope)
+
+ xgauss = self.lin3gauss(x)
+ return xgauss
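
A quick shape check wiring the three MLP blocks together, using the module paths introduced in this commit; dimensions are illustrative.

    import torch
    from sunlab.suntorch.models.encoder import Encoder
    from sunlab.suntorch.models.decoder import Decoder
    from sunlab.suntorch.models.discriminator import Discriminator

    enc = Encoder(X_dim=13, N=64, z_dim=2)
    dec = Decoder(X_dim=13, N=64, z_dim=2)
    disc = Discriminator(N=64, z_dim=2)

    x = torch.rand(8, 13)   # batch of 8 feature vectors
    z = enc(x)              # -> (8, 2) latent codes
    x_hat = dec(z)          # -> (8, 13) reconstructions squashed to (0, 1)
    p_real = disc(z)        # -> (8, 1) discriminator probabilities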
diff --git a/code/sunlab/suntorch/models/variational/autoencoder.py b/code/sunlab/suntorch/models/variational/autoencoder.py
new file mode 100644
index 0000000..e335704
--- /dev/null
+++ b/code/sunlab/suntorch/models/variational/autoencoder.py
@@ -0,0 +1,128 @@
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+
+from .encoder import Encoder
+from .decoder import Decoder
+from .common import *
+
+
+class VariationalAutoencoder:
+ """# Variational Autoencoder Model"""
+
+ def __init__(
+ self, data_dim, latent_dim, enc_dec_size, negative_slope=0.3, dropout=0.0
+ ):
+ self.encoder = Encoder(
+ data_dim,
+ enc_dec_size,
+ latent_dim,
+ negative_slope=negative_slope,
+ dropout=dropout,
+ )
+ self.decoder = Decoder(
+ data_dim,
+ enc_dec_size,
+ latent_dim,
+ negative_slope=negative_slope,
+ dropout=dropout,
+ )
+ self.data_dim = data_dim
+ self.latent_dim = latent_dim
+ self.p = dropout
+ self.negative_slope = negative_slope
+ return
+
+ def parameters(self):
+ return (*self.encoder.parameters(), *self.decoder.parameters())
+
+ def train(self):
+ self.encoder.train(True)
+ self.decoder.train(True)
+ return self
+
+ def eval(self):
+ self.encoder.train(False)
+ self.decoder.train(False)
+ return self
+
+ def encode(self, data):
+ return self.encoder(data)
+
+ def decode(self, data):
+ return self.decoder(data)
+
+ def reparameterization(self, mean, var):
+ epsilon = torch.randn_like(var)
+ if torch.cuda.is_available():
+ epsilon = epsilon.cuda()
+ z = mean + var * epsilon
+ return z
+
+ def forward(self, x):
+ mean, log_var = self.encoder(x)
+ z = self.reparameterization(mean, torch.exp(0.5 * log_var))
+ X = self.decoder(z)
+ return X, mean, log_var
+
+ def __call__(self, data):
+ return self.forward(data)
+
+ def save(self, base="./"):
+ torch.save(self.encoder.state_dict(), base + "weights_encoder.pt")
+ torch.save(self.decoder.state_dict(), base + "weights_decoder.pt")
+ return self
+
+ def load(self, base="./"):
+ self.encoder.load_state_dict(torch.load(base + "weights_encoder.pt"))
+ self.encoder.eval()
+ self.decoder.load_state_dict(torch.load(base + "weights_decoder.pt"))
+ self.decoder.eval()
+ return self
+
+ def to(self, device):
+ self.encoder.to(device)
+ self.decoder.to(device)
+ return self
+
+ def cuda(self):
+ if torch.cuda.is_available():
+ self.encoder = self.encoder.cuda()
+ self.decoder = self.decoder.cuda()
+ return self
+
+ def cpu(self):
+ self.encoder = self.encoder.cpu()
+ self.decoder = self.decoder.cpu()
+ return self
+
+ def init_optimizers(self, recon_lr=1e-4):
+ self.optim_E_enc = torch.optim.Adam(self.encoder.parameters(), lr=recon_lr)
+ self.optim_D = torch.optim.Adam(self.decoder.parameters(), lr=recon_lr)
+ return self
+
+ def init_losses(self, recon_loss_fn=F.binary_cross_entropy):
+ self.recon_loss_fn = recon_loss_fn
+ return self
+
+ def train_step(self, raw_data):
+ data = to_var(raw_data.view(raw_data.size(0), -1))
+
+ self.encoder.zero_grad()
+ self.decoder.zero_grad()
+ X, _, _ = self.forward(data)
+ # mean, log_var = self.encoder(data)
+ # z = self.reparameterization(mean, torch.exp(0.5 * log_var))
+ # X = self.decoder(z)
+ self.recon_loss = self.recon_loss_fn(X + EPS, data + EPS)
+ self.recon_loss.backward()
+ self.optim_E_enc.step()
+ self.optim_D.step()
+ return
+
+    def losses(self):
+        """# Return the last training step's reconstruction loss, if any"""
+        try:
+            return self.recon_loss
+        except AttributeError:
+            # No training step has been run yet
+            return None
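
Note that train_step above optimizes only the reconstruction term; the commented-out lines sketch the plumbing, but no KL penalty is applied, so the model currently trains like a plain autoencoder. Below is a hedged sketch of a step that adds the standard KL term; beta is an assumed weight, and EPS/to_var come from the sibling common module.

    import torch
    from sunlab.suntorch.models.variational.common import EPS, to_var

    def vae_train_step_with_kl(model, raw_data, beta=1.0):
        # Reconstruction plus KL divergence for a diagonal Gaussian posterior;
        # a sketch against VariationalAutoencoder, not part of the original commit
        data = to_var(raw_data.view(raw_data.size(0), -1))
        model.encoder.zero_grad()
        model.decoder.zero_grad()
        X, mean, log_var = model.forward(data)
        recon = model.recon_loss_fn(X + EPS, data + EPS)
        kld = -0.5 * torch.mean(1 + log_var - mean.pow(2) - log_var.exp())
        (recon + beta * kld).backward()
        model.optim_E_enc.step()
        model.optim_D.step()
        return recon, kld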
diff --git a/code/sunlab/suntorch/models/variational/common.py b/code/sunlab/suntorch/models/variational/common.py
new file mode 100644
index 0000000..f10930e
--- /dev/null
+++ b/code/sunlab/suntorch/models/variational/common.py
@@ -0,0 +1,12 @@
+from torch.autograd import Variable
+
+EPS = 1e-15
+
+
+def to_var(x):
+ """# Convert to variable"""
+ import torch
+
+ if torch.cuda.is_available():
+ x = x.cuda()
+ return Variable(x)
diff --git a/code/sunlab/suntorch/models/variational/decoder.py b/code/sunlab/suntorch/models/variational/decoder.py
new file mode 100644
index 0000000..2eeb7a4
--- /dev/null
+++ b/code/sunlab/suntorch/models/variational/decoder.py
@@ -0,0 +1,33 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import sigmoid
+
+
+class Decoder(nn.Module):
+ """# Decoder Neural Network
+ X_dim: Output dimension shape
+ N: Inner neuronal layer size
+ z_dim: Input dimension shape
+ """
+
+ def __init__(self, X_dim, N, z_dim, dropout=0.0, negative_slope=0.3):
+ super(Decoder, self).__init__()
+ self.lin1 = nn.Linear(z_dim, N)
+ self.lin2 = nn.Linear(N, N)
+ self.lin3 = nn.Linear(N, X_dim)
+ self.p = dropout
+ self.negative_slope = negative_slope
+
+ def forward(self, x):
+ x = self.lin1(x)
+ if self.p > 0.0:
+ x = F.dropout(x, p=self.p, training=self.training)
+ x = F.leaky_relu(x, negative_slope=self.negative_slope)
+
+ x = self.lin2(x)
+ if self.p > 0.0:
+ x = F.dropout(x, p=self.p, training=self.training)
+ x = F.leaky_relu(x, negative_slope=self.negative_slope)
+
+ x = self.lin3(x)
+ return sigmoid(x)
diff --git a/code/sunlab/suntorch/models/variational/encoder.py b/code/sunlab/suntorch/models/variational/encoder.py
new file mode 100644
index 0000000..b08202f
--- /dev/null
+++ b/code/sunlab/suntorch/models/variational/encoder.py
@@ -0,0 +1,34 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Encoder(nn.Module):
+ """# Encoder Neural Network
+ X_dim: Input dimension shape
+ N: Inner neuronal layer size
+ z_dim: Output dimension shape
+ """
+
+ def __init__(self, X_dim, N, z_dim, dropout=0.0, negative_slope=0.3):
+ super(Encoder, self).__init__()
+ self.lin1 = nn.Linear(X_dim, N)
+ self.lin2 = nn.Linear(N, N)
+ self.lin3mu = nn.Linear(N, z_dim)
+ self.lin3sigma = nn.Linear(N, z_dim)
+ self.p = dropout
+ self.negative_slope = negative_slope
+
+ def forward(self, x):
+ x = self.lin1(x)
+ if self.p > 0.0:
+ x = F.dropout(x, p=self.p, training=self.training)
+ x = F.leaky_relu(x, negative_slope=self.negative_slope)
+
+ x = self.lin2(x)
+ if self.p > 0.0:
+ x = F.dropout(x, p=self.p, training=self.training)
+ x = F.leaky_relu(x, negative_slope=self.negative_slope)
+
+        mu = self.lin3mu(x)
+        # The second head is treated as log-variance by VariationalAutoencoder.forward
+        log_var = self.lin3sigma(x)
+        return mu, log_var
diff --git a/code/sunlab/suntorch/plotting/__init__.py b/code/sunlab/suntorch/plotting/__init__.py
new file mode 100644
index 0000000..36e00e6
--- /dev/null
+++ b/code/sunlab/suntorch/plotting/__init__.py
@@ -0,0 +1 @@
+from .model_extensions import *
diff --git a/code/sunlab/suntorch/plotting/model_extensions.py b/code/sunlab/suntorch/plotting/model_extensions.py
new file mode 100644
index 0000000..33f0191
--- /dev/null
+++ b/code/sunlab/suntorch/plotting/model_extensions.py
@@ -0,0 +1,34 @@
+from matplotlib import pyplot as plt
+from sunlab.common.data.shape_dataset import ShapeDataset
+from sunlab.globals import DIR_ROOT
+
+
+def apply_boundary(
+ model_loc=DIR_ROOT + "models/current_model/",
+ border_thickness=3,
+ include_transition_regions=False,
+ threshold=0.7,
+ alpha=1,
+ _plt=None,
+):
+ """# Apply Boundary to Plot
+
+ Use Pregenerated Boundary by Default for Speed"""
+ from sunlab.common.scaler import MaxAbsScaler
+ import numpy as np
+
+ if _plt is None:
+ _plt = plt
+    # Fast path: use the pregenerated outline for the default model and settings
+    if (
+        model_loc == DIR_ROOT + "models/current_model/"
+        and border_thickness == 3
+        and threshold == 0.7
+    ):
+ XYM = np.load(DIR_ROOT + "extra_data/OutlineXYM.npy")
+ XY = XYM[:2, :, :]
+ if include_transition_regions:
+ outline = XYM[3, :, :]
+ else:
+ outline = XYM[2, :, :]
+ _plt.pcolor(XY[0, :, :], XY[1, :, :], outline, cmap="gray", alpha=alpha)
+ return
+    raise NotImplementedError("Not Yet Implemented for PyTorch!")
+
+
+plt.apply_boundary = apply_boundary
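
Since apply_boundary is monkey-patched onto pyplot at import time, a typical overlay looks like the sketch below; the latent scatter data is an illustrative placeholder.

    import numpy as np
    from matplotlib import pyplot as plt
    import sunlab.suntorch.plotting  # noqa: F401 -- registers plt.apply_boundary

    latent = np.random.randn(500, 2)  # stand-in for model-encoded latent points
    plt.scatter(latent[:, 0], latent[:, 1], s=2)
    plt.apply_boundary(include_transition_regions=False, alpha=0.8)
    plt.show()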