diff options
Diffstat (limited to 'code/sunlab/suntorch/models')
-rw-r--r-- | code/sunlab/suntorch/models/__init__.py | 3 | ||||
-rw-r--r-- | code/sunlab/suntorch/models/adversarial_autoencoder.py | 165 | ||||
-rw-r--r-- | code/sunlab/suntorch/models/autoencoder.py | 114 | ||||
-rw-r--r-- | code/sunlab/suntorch/models/common.py | 12 | ||||
-rw-r--r-- | code/sunlab/suntorch/models/convolutional/variational/autoencoder.py | 190 | ||||
-rw-r--r-- | code/sunlab/suntorch/models/decoder.py | 33 | ||||
-rw-r--r-- | code/sunlab/suntorch/models/discriminator.py | 32 | ||||
-rw-r--r-- | code/sunlab/suntorch/models/encoder.py | 32 | ||||
-rw-r--r-- | code/sunlab/suntorch/models/variational/autoencoder.py | 128 | ||||
-rw-r--r-- | code/sunlab/suntorch/models/variational/common.py | 12 | ||||
-rw-r--r-- | code/sunlab/suntorch/models/variational/decoder.py | 33 | ||||
-rw-r--r-- | code/sunlab/suntorch/models/variational/encoder.py | 34 |
12 files changed, 788 insertions, 0 deletions
diff --git a/code/sunlab/suntorch/models/__init__.py b/code/sunlab/suntorch/models/__init__.py new file mode 100644 index 0000000..03d6a45 --- /dev/null +++ b/code/sunlab/suntorch/models/__init__.py @@ -0,0 +1,3 @@ +from .adversarial_autoencoder import AdversarialAutoencoder +from .autoencoder import Autoencoder +from .variational.autoencoder import VariationalAutoencoder diff --git a/code/sunlab/suntorch/models/adversarial_autoencoder.py b/code/sunlab/suntorch/models/adversarial_autoencoder.py new file mode 100644 index 0000000..e3e8a06 --- /dev/null +++ b/code/sunlab/suntorch/models/adversarial_autoencoder.py @@ -0,0 +1,165 @@ +import torch +import torch.nn.functional as F +from torch.autograd import Variable + +from .encoder import Encoder +from .decoder import Decoder +from .discriminator import Discriminator +from .common import * + + +def dummy_distribution(*args): + raise NotImplementedError("Give a distribution") + + +class AdversarialAutoencoder: + """# Adversarial Autoencoder Model""" + + def __init__( + self, + data_dim, + latent_dim, + enc_dec_size, + disc_size, + negative_slope=0.3, + dropout=0.0, + distribution=dummy_distribution, + ): + self.encoder = Encoder( + data_dim, + enc_dec_size, + latent_dim, + negative_slope=negative_slope, + dropout=dropout, + ) + self.decoder = Decoder( + data_dim, + enc_dec_size, + latent_dim, + negative_slope=negative_slope, + dropout=dropout, + ) + self.discriminator = Discriminator( + disc_size, latent_dim, negative_slope=negative_slope, dropout=dropout + ) + self.data_dim = data_dim + self.latent_dim = latent_dim + self.p = dropout + self.negative_slope = negative_slope + self.distribution = distribution + return + + def parameters(self): + return ( + *self.encoder.parameters(), + *self.decoder.parameters(), + *self.discriminator.parameters(), + ) + + def train(self): + self.encoder.train(True) + self.decoder.train(True) + self.discriminator.train(True) + return self + + def eval(self): + self.encoder.train(False) + self.decoder.train(False) + self.discriminator.train(False) + return self + + def encode(self, data): + return self.encoder(data) + + def decode(self, data): + return self.decoder(data) + + def __call__(self, data): + return self.decode(self.encode(data)) + + def save(self, base="./"): + torch.save(self.encoder.state_dict(), base + "weights_encoder.pt") + torch.save(self.decoder.state_dict(), base + "weights_decoder.pt") + torch.save(self.discriminator.state_dict(), base + "weights_discriminator.pt") + return self + + def load(self, base="./"): + self.encoder.load_state_dict(torch.load(base + "weights_encoder.pt")) + self.encoder.eval() + self.decoder.load_state_dict(torch.load(base + "weights_decoder.pt")) + self.decoder.eval() + self.discriminator.load_state_dict( + torch.load(base + "weights_discriminator.pt") + ) + self.discriminator.eval() + return self + + def to(self, device): + self.encoder.to(device) + self.decoder.to(device) + self.discriminator.to(device) + return self + + def cuda(self): + if torch.cuda.is_available(): + self.encoder = self.encoder.cuda() + self.decoder = self.decoder.cuda() + self.discriminator = self.discriminator.cuda() + return self + + def cpu(self): + self.encoder = self.encoder.cpu() + self.decoder = self.decoder.cpu() + self.discriminator = self.discriminator.cpu() + return self + + def init_optimizers(self, recon_lr=1e-4, adv_lr=5e-5): + self.optim_E_gen = torch.optim.Adam(self.encoder.parameters(), lr=adv_lr) + self.optim_E_enc = torch.optim.Adam(self.encoder.parameters(), lr=recon_lr) + self.optim_D_dec = torch.optim.Adam(self.decoder.parameters(), lr=recon_lr) + self.optim_D_dis = torch.optim.Adam(self.discriminator.parameters(), lr=adv_lr) + return self + + def init_losses(self, recon_loss_fn=F.binary_cross_entropy): + self.recon_loss_fn = recon_loss_fn + return self + + def train_step(self, raw_data, scale=1.0): + data = to_var(raw_data.view(raw_data.size(0), -1)) + + self.encoder.zero_grad() + self.decoder.zero_grad() + self.discriminator.zero_grad() + + z = self.encoder(data) + X = self.decoder(z) + self.recon_loss = self.recon_loss_fn(X + EPS, data + EPS) + self.recon_loss.backward() + self.optim_E_enc.step() + self.optim_D_dec.step() + + self.encoder.eval() + z_gaussian = to_var(self.distribution(data.size(0), self.latent_dim) * scale) + z_gaussian_fake = self.encoder(data) + y_gaussian = self.discriminator(z_gaussian) + y_gaussian_fake = self.discriminator(z_gaussian_fake) + self.D_loss = -torch.mean( + torch.log(y_gaussian + EPS) + torch.log(1 - y_gaussian_fake + EPS) + ) + self.D_loss.backward() + self.optim_D_dis.step() + + self.encoder.train() + z_gaussian = self.encoder(data) + y_gaussian = self.discriminator(z_gaussian) + self.G_loss = -torch.mean(torch.log(y_gaussian + EPS)) + self.G_loss.backward() + self.optim_E_gen.step() + return + + def losses(self): + try: + return self.recon_loss, self.D_loss, self.G_loss + except: + ... + return diff --git a/code/sunlab/suntorch/models/autoencoder.py b/code/sunlab/suntorch/models/autoencoder.py new file mode 100644 index 0000000..232f180 --- /dev/null +++ b/code/sunlab/suntorch/models/autoencoder.py @@ -0,0 +1,114 @@ +import torch +import torch.nn.functional as F +from torch.autograd import Variable + +from .encoder import Encoder +from .decoder import Decoder +from .common import * + + +class Autoencoder: + """# Autoencoder Model""" + + def __init__( + self, data_dim, latent_dim, enc_dec_size, negative_slope=0.3, dropout=0.0 + ): + self.encoder = Encoder( + data_dim, + enc_dec_size, + latent_dim, + negative_slope=negative_slope, + dropout=dropout, + ) + self.decoder = Decoder( + data_dim, + enc_dec_size, + latent_dim, + negative_slope=negative_slope, + dropout=dropout, + ) + self.data_dim = data_dim + self.latent_dim = latent_dim + self.p = dropout + self.negative_slope = negative_slope + return + + def parameters(self): + return (*self.encoder.parameters(), *self.decoder.parameters()) + + def train(self): + self.encoder.train(True) + self.decoder.train(True) + return self + + def eval(self): + self.encoder.train(False) + self.decoder.train(False) + return self + + def encode(self, data): + return self.encoder(data) + + def decode(self, data): + return self.decoder(data) + + def __call__(self, data): + return self.decode(self.encode(data)) + + def save(self, base="./"): + torch.save(self.encoder.state_dict(), base + "weights_encoder.pt") + torch.save(self.decoder.state_dict(), base + "weights_decoder.pt") + return self + + def load(self, base="./"): + self.encoder.load_state_dict(torch.load(base + "weights_encoder.pt")) + self.encoder.eval() + self.decoder.load_state_dict(torch.load(base + "weights_decoder.pt")) + self.decoder.eval() + return self + + def to(self, device): + self.encoder.to(device) + self.decoder.to(device) + self.discriminator.to(device) + return self + + def cuda(self): + if torch.cuda.is_available(): + self.encoder = self.encoder.cuda() + self.decoder = self.decoder.cuda() + return self + + def cpu(self): + self.encoder = self.encoder.cpu() + self.decoder = self.decoder.cpu() + return self + + def init_optimizers(self, recon_lr=1e-4): + self.optim_E_enc = torch.optim.Adam(self.encoder.parameters(), lr=recon_lr) + self.optim_D = torch.optim.Adam(self.decoder.parameters(), lr=recon_lr) + return self + + def init_losses(self, recon_loss_fn=F.binary_cross_entropy): + self.recon_loss_fn = recon_loss_fn + return self + + def train_step(self, raw_data): + data = to_var(raw_data.view(raw_data.size(0), -1)) + + self.encoder.zero_grad() + self.decoder.zero_grad() + z = self.encoder(data) + X = self.decoder(z) + self.recon_loss = self.recon_loss_fn(X + EPS, data + EPS) + self.recon_loss.backward() + self.optim_E_enc.step() + self.optim_D.step() + return + + def losses(self): + try: + return self.recon_loss + except: + ... + return diff --git a/code/sunlab/suntorch/models/common.py b/code/sunlab/suntorch/models/common.py new file mode 100644 index 0000000..f10930e --- /dev/null +++ b/code/sunlab/suntorch/models/common.py @@ -0,0 +1,12 @@ +from torch.autograd import Variable + +EPS = 1e-15 + + +def to_var(x): + """# Convert to variable""" + import torch + + if torch.cuda.is_available(): + x = x.cuda() + return Variable(x) diff --git a/code/sunlab/suntorch/models/convolutional/variational/autoencoder.py b/code/sunlab/suntorch/models/convolutional/variational/autoencoder.py new file mode 100644 index 0000000..970f717 --- /dev/null +++ b/code/sunlab/suntorch/models/convolutional/variational/autoencoder.py @@ -0,0 +1,190 @@ +import torch +from torch import nn + + +class ConvolutionalVariationalAutoencoder(nn.Module): + def __init__(self, latent_dims, hidden_dims, image_shape, dropout=0.0): + super(ConvolutionalVariationalAutoencoder, self).__init__() + + self.latent_dims = latent_dims # Size of the latent space layer + self.hidden_dims = ( + hidden_dims # List of hidden layers number of filters/channels + ) + self.image_shape = image_shape # Input image shape + + self.last_channels = self.hidden_dims[-1] + self.in_channels = self.image_shape[0] + # Simple formula to get the number of neurons after the last convolution layer is flattened + self.flattened_channels = int( + self.last_channels + * (self.image_shape[1] / (2 ** len(self.hidden_dims))) ** 2 + ) + + # For each hidden layer we will create a Convolution Block + modules = [] + for h_dim in self.hidden_dims: + modules.append( + nn.Sequential( + nn.Conv2d( + in_channels=self.in_channels, + out_channels=h_dim, + kernel_size=3, + stride=2, + padding=1, + ), + nn.BatchNorm2d(h_dim), + nn.LeakyReLU(), + nn.Dropout(p=dropout), + ) + ) + + self.in_channels = h_dim + + self.encoder = nn.Sequential(*modules) + + # Here are our layers for our latent space distribution + self.fc_mu = nn.Linear(self.flattened_channels, latent_dims) + self.fc_var = nn.Linear(self.flattened_channels, latent_dims) + + # Decoder input layer + self.decoder_input = nn.Linear(latent_dims, self.flattened_channels) + + # For each Convolution Block created on the Encoder we will do a symmetric Decoder with the same Blocks, but using ConvTranspose + self.hidden_dims.reverse() + modules = [] + for h_dim in self.hidden_dims: + modules.append( + nn.Sequential( + nn.ConvTranspose2d( + in_channels=self.in_channels, + out_channels=h_dim, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + ), + nn.BatchNorm2d(h_dim), + nn.LeakyReLU(), + nn.Dropout(p=dropout), + ) + ) + + self.in_channels = h_dim + + self.decoder = nn.Sequential(*modules) + + # The final layer the reconstructed image have the same dimensions as the input image + self.final_layer = nn.Sequential( + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.image_shape[0], + kernel_size=3, + padding=1, + ), + nn.Sigmoid(), + ) + + def get_latent_dims(self): + + return self.latent_dims + + def encode(self, input): + """ + Encodes the input by passing through the encoder network + and returns the latent codes. + """ + result = self.encoder(input) + result = torch.flatten(result, start_dim=1) + # Split the result into mu and var componentsbof the latent Gaussian distribution + mu = self.fc_mu(result) + log_var = self.fc_var(result) + + return [mu, log_var] + + def decode(self, z): + """ + Maps the given latent codes onto the image space. + """ + result = self.decoder_input(z) + result = result.view( + -1, + self.last_channels, + int(self.image_shape[1] / (2 ** len(self.hidden_dims))), + int(self.image_shape[1] / (2 ** len(self.hidden_dims))), + ) + result = self.decoder(result) + result = self.final_layer(result) + + return result + + def reparameterize(self, mu, log_var): + """ + Reparameterization trick to sample from N(mu, var) from N(0,1). + """ + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + + return mu + eps * std + + def forward(self, input): + """ + Forward method which will encode and decode our image. + """ + mu, log_var = self.encode(input) + z = self.reparameterize(mu, log_var) + + return [self.decode(z), input, mu, log_var, z] + + def loss_function(self, recons, input, mu, log_var): + """ + Computes VAE loss function + """ + recons_loss = nn.functional.binary_cross_entropy( + recons.reshape(recons.shape[0], -1), + input.reshape(input.shape[0], -1), + reduction="none", + ).sum(dim=-1) + + kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1) + + loss = (recons_loss + kld_loss).mean(dim=0) + + return loss + + def sample(self, num_samples, device): + """ + Samples from the latent space and return the corresponding + image space map. + """ + z = torch.randn(num_samples, self.latent_dims) + z = z.to(device) + samples = self.decode(z) + + return samples + + def generate(self, x): + """ + Given an input image x, returns the reconstructed image + """ + return self.forward(x)[0] + + def interpolate(self, starting_inputs, ending_inputs, device, granularity=10): + """This function performs a linear interpolation in the latent space of the autoencoder + from starting inputs to ending inputs. It returns the interpolation trajectories. + """ + mu, log_var = self.encode(starting_inputs.to(device)) + starting_z = self.reparameterize(mu, log_var) + + mu, log_var = self.encode(ending_inputs.to(device)) + ending_z = self.reparameterize(mu, log_var) + + t = torch.linspace(0, 1, granularity).to(device) + + intep_line = torch.kron( + starting_z.reshape(starting_z.shape[0], -1), (1 - t).unsqueeze(-1) + ) + torch.kron(ending_z.reshape(ending_z.shape[0], -1), t.unsqueeze(-1)) + + decoded_line = self.decode(intep_line).reshape( + (starting_inputs.shape[0], t.shape[0]) + (starting_inputs.shape[1:]) + ) + return decoded_line diff --git a/code/sunlab/suntorch/models/decoder.py b/code/sunlab/suntorch/models/decoder.py new file mode 100644 index 0000000..2eeb7a4 --- /dev/null +++ b/code/sunlab/suntorch/models/decoder.py @@ -0,0 +1,33 @@ +import torch.nn as nn +import torch.nn.functional as F +from torch import sigmoid + + +class Decoder(nn.Module): + """# Decoder Neural Network + X_dim: Output dimension shape + N: Inner neuronal layer size + z_dim: Input dimension shape + """ + + def __init__(self, X_dim, N, z_dim, dropout=0.0, negative_slope=0.3): + super(Decoder, self).__init__() + self.lin1 = nn.Linear(z_dim, N) + self.lin2 = nn.Linear(N, N) + self.lin3 = nn.Linear(N, X_dim) + self.p = dropout + self.negative_slope = negative_slope + + def forward(self, x): + x = self.lin1(x) + if self.p > 0.0: + x = F.dropout(x, p=self.p, training=self.training) + x = F.leaky_relu(x, negative_slope=self.negative_slope) + + x = self.lin2(x) + if self.p > 0.0: + x = F.dropout(x, p=self.p, training=self.training) + x = F.leaky_relu(x, negative_slope=self.negative_slope) + + x = self.lin3(x) + return sigmoid(x) diff --git a/code/sunlab/suntorch/models/discriminator.py b/code/sunlab/suntorch/models/discriminator.py new file mode 100644 index 0000000..9249095 --- /dev/null +++ b/code/sunlab/suntorch/models/discriminator.py @@ -0,0 +1,32 @@ +import torch.nn as nn +import torch.nn.functional as F +from torch import sigmoid + + +class Discriminator(nn.Module): + """# Discriminator Neural Network + N: Inner neuronal layer size + z_dim: Input dimension shape + """ + + def __init__(self, N, z_dim, dropout=0.0, negative_slope=0.3): + super(Discriminator, self).__init__() + self.lin1 = nn.Linear(z_dim, N) + self.lin2 = nn.Linear(N, N) + self.lin3 = nn.Linear(N, 1) + self.p = dropout + self.negative_slope = negative_slope + + def forward(self, x): + x = self.lin1(x) + if self.p > 0.0: + x = F.dropout(x, p=self.p, training=self.training) + x = F.leaky_relu(x, negative_slope=self.negative_slope) + + x = self.lin2(x) + if self.p > 0.0: + x = F.dropout(x, p=self.p, training=self.training) + x = F.leaky_relu(x, negative_slope=self.negative_slope) + + x = self.lin3(x) + return sigmoid(x) diff --git a/code/sunlab/suntorch/models/encoder.py b/code/sunlab/suntorch/models/encoder.py new file mode 100644 index 0000000..e6f88c7 --- /dev/null +++ b/code/sunlab/suntorch/models/encoder.py @@ -0,0 +1,32 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class Encoder(nn.Module): + """# Encoder Neural Network + X_dim: Input dimension shape + N: Inner neuronal layer size + z_dim: Output dimension shape + """ + + def __init__(self, X_dim, N, z_dim, dropout=0.0, negative_slope=0.3): + super(Encoder, self).__init__() + self.lin1 = nn.Linear(X_dim, N) + self.lin2 = nn.Linear(N, N) + self.lin3gauss = nn.Linear(N, z_dim) + self.p = dropout + self.negative_slope = negative_slope + + def forward(self, x): + x = self.lin1(x) + if self.p > 0.0: + x = F.dropout(x, p=self.p, training=self.training) + x = F.leaky_relu(x, negative_slope=self.negative_slope) + + x = self.lin2(x) + if self.p > 0.0: + x = F.dropout(x, p=self.p, training=self.training) + x = F.leaky_relu(x, negative_slope=self.negative_slope) + + xgauss = self.lin3gauss(x) + return xgauss diff --git a/code/sunlab/suntorch/models/variational/autoencoder.py b/code/sunlab/suntorch/models/variational/autoencoder.py new file mode 100644 index 0000000..e335704 --- /dev/null +++ b/code/sunlab/suntorch/models/variational/autoencoder.py @@ -0,0 +1,128 @@ +import torch +import torch.nn.functional as F +from torch.autograd import Variable + +from .encoder import Encoder +from .decoder import Decoder +from .common import * + + +class VariationalAutoencoder: + """# Variational Autoencoder Model""" + + def __init__( + self, data_dim, latent_dim, enc_dec_size, negative_slope=0.3, dropout=0.0 + ): + self.encoder = Encoder( + data_dim, + enc_dec_size, + latent_dim, + negative_slope=negative_slope, + dropout=dropout, + ) + self.decoder = Decoder( + data_dim, + enc_dec_size, + latent_dim, + negative_slope=negative_slope, + dropout=dropout, + ) + self.data_dim = data_dim + self.latent_dim = latent_dim + self.p = dropout + self.negative_slope = negative_slope + return + + def parameters(self): + return (*self.encoder.parameters(), *self.decoder.parameters()) + + def train(self): + self.encoder.train(True) + self.decoder.train(True) + return self + + def eval(self): + self.encoder.train(False) + self.decoder.train(False) + return self + + def encode(self, data): + return self.encoder(data) + + def decode(self, data): + return self.decoder(data) + + def reparameterization(self, mean, var): + epsilon = torch.randn_like(var) + if torch.cuda.is_available(): + epsilon = epsilon.cuda() + z = mean + var * epsilon + return z + + def forward(self, x): + mean, log_var = self.encoder(x) + z = self.reparameterization(mean, torch.exp(0.5 * log_var)) + X = self.decoder(z) + return X, mean, log_var + + def __call__(self, data): + return self.forward(data) + + def save(self, base="./"): + torch.save(self.encoder.state_dict(), base + "weights_encoder.pt") + torch.save(self.decoder.state_dict(), base + "weights_decoder.pt") + return self + + def load(self, base="./"): + self.encoder.load_state_dict(torch.load(base + "weights_encoder.pt")) + self.encoder.eval() + self.decoder.load_state_dict(torch.load(base + "weights_decoder.pt")) + self.decoder.eval() + return self + + def to(self, device): + self.encoder.to(device) + self.decoder.to(device) + return self + + def cuda(self): + if torch.cuda.is_available(): + self.encoder = self.encoder.cuda() + self.decoder = self.decoder.cuda() + return self + + def cpu(self): + self.encoder = self.encoder.cpu() + self.decoder = self.decoder.cpu() + return self + + def init_optimizers(self, recon_lr=1e-4): + self.optim_E_enc = torch.optim.Adam(self.encoder.parameters(), lr=recon_lr) + self.optim_D = torch.optim.Adam(self.decoder.parameters(), lr=recon_lr) + return self + + def init_losses(self, recon_loss_fn=F.binary_cross_entropy): + self.recon_loss_fn = recon_loss_fn + return self + + def train_step(self, raw_data): + data = to_var(raw_data.view(raw_data.size(0), -1)) + + self.encoder.zero_grad() + self.decoder.zero_grad() + X, _, _ = self.forward(data) + # mean, log_var = self.encoder(data) + # z = self.reparameterization(mean, torch.exp(0.5 * log_var)) + # X = self.decoder(z) + self.recon_loss = self.recon_loss_fn(X + EPS, data + EPS) + self.recon_loss.backward() + self.optim_E_enc.step() + self.optim_D.step() + return + + def losses(self): + try: + return self.recon_loss + except: + ... + return diff --git a/code/sunlab/suntorch/models/variational/common.py b/code/sunlab/suntorch/models/variational/common.py new file mode 100644 index 0000000..f10930e --- /dev/null +++ b/code/sunlab/suntorch/models/variational/common.py @@ -0,0 +1,12 @@ +from torch.autograd import Variable + +EPS = 1e-15 + + +def to_var(x): + """# Convert to variable""" + import torch + + if torch.cuda.is_available(): + x = x.cuda() + return Variable(x) diff --git a/code/sunlab/suntorch/models/variational/decoder.py b/code/sunlab/suntorch/models/variational/decoder.py new file mode 100644 index 0000000..2eeb7a4 --- /dev/null +++ b/code/sunlab/suntorch/models/variational/decoder.py @@ -0,0 +1,33 @@ +import torch.nn as nn +import torch.nn.functional as F +from torch import sigmoid + + +class Decoder(nn.Module): + """# Decoder Neural Network + X_dim: Output dimension shape + N: Inner neuronal layer size + z_dim: Input dimension shape + """ + + def __init__(self, X_dim, N, z_dim, dropout=0.0, negative_slope=0.3): + super(Decoder, self).__init__() + self.lin1 = nn.Linear(z_dim, N) + self.lin2 = nn.Linear(N, N) + self.lin3 = nn.Linear(N, X_dim) + self.p = dropout + self.negative_slope = negative_slope + + def forward(self, x): + x = self.lin1(x) + if self.p > 0.0: + x = F.dropout(x, p=self.p, training=self.training) + x = F.leaky_relu(x, negative_slope=self.negative_slope) + + x = self.lin2(x) + if self.p > 0.0: + x = F.dropout(x, p=self.p, training=self.training) + x = F.leaky_relu(x, negative_slope=self.negative_slope) + + x = self.lin3(x) + return sigmoid(x) diff --git a/code/sunlab/suntorch/models/variational/encoder.py b/code/sunlab/suntorch/models/variational/encoder.py new file mode 100644 index 0000000..b08202f --- /dev/null +++ b/code/sunlab/suntorch/models/variational/encoder.py @@ -0,0 +1,34 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class Encoder(nn.Module): + """# Encoder Neural Network + X_dim: Input dimension shape + N: Inner neuronal layer size + z_dim: Output dimension shape + """ + + def __init__(self, X_dim, N, z_dim, dropout=0.0, negative_slope=0.3): + super(Encoder, self).__init__() + self.lin1 = nn.Linear(X_dim, N) + self.lin2 = nn.Linear(N, N) + self.lin3mu = nn.Linear(N, z_dim) + self.lin3sigma = nn.Linear(N, z_dim) + self.p = dropout + self.negative_slope = negative_slope + + def forward(self, x): + x = self.lin1(x) + if self.p > 0.0: + x = F.dropout(x, p=self.p, training=self.training) + x = F.leaky_relu(x, negative_slope=self.negative_slope) + + x = self.lin2(x) + if self.p > 0.0: + x = F.dropout(x, p=self.p, training=self.training) + x = F.leaky_relu(x, negative_slope=self.negative_slope) + + mu = self.lin3mu(x) + sigma = self.lin3sigma(x) + return mu, sigma |