aboutsummaryrefslogtreecommitdiff
path: root/code/sunlab/suntorch/models/convolutional/variational/autoencoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'code/sunlab/suntorch/models/convolutional/variational/autoencoder.py')
-rw-r--r--code/sunlab/suntorch/models/convolutional/variational/autoencoder.py190
1 files changed, 190 insertions, 0 deletions
diff --git a/code/sunlab/suntorch/models/convolutional/variational/autoencoder.py b/code/sunlab/suntorch/models/convolutional/variational/autoencoder.py
new file mode 100644
index 0000000..970f717
--- /dev/null
+++ b/code/sunlab/suntorch/models/convolutional/variational/autoencoder.py
@@ -0,0 +1,190 @@
+import torch
+from torch import nn
+
+
+class ConvolutionalVariationalAutoencoder(nn.Module):
+ def __init__(self, latent_dims, hidden_dims, image_shape, dropout=0.0):
+ super(ConvolutionalVariationalAutoencoder, self).__init__()
+
+ self.latent_dims = latent_dims # Size of the latent space layer
+ self.hidden_dims = (
+ hidden_dims # List of hidden layers number of filters/channels
+ )
+ self.image_shape = image_shape # Input image shape
+
+ self.last_channels = self.hidden_dims[-1]
+ self.in_channels = self.image_shape[0]
+ # Simple formula to get the number of neurons after the last convolution layer is flattened
+ self.flattened_channels = int(
+ self.last_channels
+ * (self.image_shape[1] / (2 ** len(self.hidden_dims))) ** 2
+ )
+
+ # For each hidden layer we will create a Convolution Block
+ modules = []
+ for h_dim in self.hidden_dims:
+ modules.append(
+ nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=h_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ nn.BatchNorm2d(h_dim),
+ nn.LeakyReLU(),
+ nn.Dropout(p=dropout),
+ )
+ )
+
+ self.in_channels = h_dim
+
+ self.encoder = nn.Sequential(*modules)
+
+ # Here are our layers for our latent space distribution
+ self.fc_mu = nn.Linear(self.flattened_channels, latent_dims)
+ self.fc_var = nn.Linear(self.flattened_channels, latent_dims)
+
+ # Decoder input layer
+ self.decoder_input = nn.Linear(latent_dims, self.flattened_channels)
+
+ # For each Convolution Block created on the Encoder we will do a symmetric Decoder with the same Blocks, but using ConvTranspose
+ self.hidden_dims.reverse()
+ modules = []
+ for h_dim in self.hidden_dims:
+ modules.append(
+ nn.Sequential(
+ nn.ConvTranspose2d(
+ in_channels=self.in_channels,
+ out_channels=h_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1,
+ ),
+ nn.BatchNorm2d(h_dim),
+ nn.LeakyReLU(),
+ nn.Dropout(p=dropout),
+ )
+ )
+
+ self.in_channels = h_dim
+
+ self.decoder = nn.Sequential(*modules)
+
+ # The final layer the reconstructed image have the same dimensions as the input image
+ self.final_layer = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=self.image_shape[0],
+ kernel_size=3,
+ padding=1,
+ ),
+ nn.Sigmoid(),
+ )
+
+ def get_latent_dims(self):
+
+ return self.latent_dims
+
+ def encode(self, input):
+ """
+ Encodes the input by passing through the encoder network
+ and returns the latent codes.
+ """
+ result = self.encoder(input)
+ result = torch.flatten(result, start_dim=1)
+ # Split the result into mu and var componentsbof the latent Gaussian distribution
+ mu = self.fc_mu(result)
+ log_var = self.fc_var(result)
+
+ return [mu, log_var]
+
+ def decode(self, z):
+ """
+ Maps the given latent codes onto the image space.
+ """
+ result = self.decoder_input(z)
+ result = result.view(
+ -1,
+ self.last_channels,
+ int(self.image_shape[1] / (2 ** len(self.hidden_dims))),
+ int(self.image_shape[1] / (2 ** len(self.hidden_dims))),
+ )
+ result = self.decoder(result)
+ result = self.final_layer(result)
+
+ return result
+
+ def reparameterize(self, mu, log_var):
+ """
+ Reparameterization trick to sample from N(mu, var) from N(0,1).
+ """
+ std = torch.exp(0.5 * log_var)
+ eps = torch.randn_like(std)
+
+ return mu + eps * std
+
+ def forward(self, input):
+ """
+ Forward method which will encode and decode our image.
+ """
+ mu, log_var = self.encode(input)
+ z = self.reparameterize(mu, log_var)
+
+ return [self.decode(z), input, mu, log_var, z]
+
+ def loss_function(self, recons, input, mu, log_var):
+ """
+ Computes VAE loss function
+ """
+ recons_loss = nn.functional.binary_cross_entropy(
+ recons.reshape(recons.shape[0], -1),
+ input.reshape(input.shape[0], -1),
+ reduction="none",
+ ).sum(dim=-1)
+
+ kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1)
+
+ loss = (recons_loss + kld_loss).mean(dim=0)
+
+ return loss
+
+ def sample(self, num_samples, device):
+ """
+ Samples from the latent space and return the corresponding
+ image space map.
+ """
+ z = torch.randn(num_samples, self.latent_dims)
+ z = z.to(device)
+ samples = self.decode(z)
+
+ return samples
+
+ def generate(self, x):
+ """
+ Given an input image x, returns the reconstructed image
+ """
+ return self.forward(x)[0]
+
+ def interpolate(self, starting_inputs, ending_inputs, device, granularity=10):
+ """This function performs a linear interpolation in the latent space of the autoencoder
+ from starting inputs to ending inputs. It returns the interpolation trajectories.
+ """
+ mu, log_var = self.encode(starting_inputs.to(device))
+ starting_z = self.reparameterize(mu, log_var)
+
+ mu, log_var = self.encode(ending_inputs.to(device))
+ ending_z = self.reparameterize(mu, log_var)
+
+ t = torch.linspace(0, 1, granularity).to(device)
+
+ intep_line = torch.kron(
+ starting_z.reshape(starting_z.shape[0], -1), (1 - t).unsqueeze(-1)
+ ) + torch.kron(ending_z.reshape(ending_z.shape[0], -1), t.unsqueeze(-1))
+
+ decoded_line = self.decode(intep_line).reshape(
+ (starting_inputs.shape[0], t.shape[0]) + (starting_inputs.shape[1:])
+ )
+ return decoded_line