diff options
Diffstat (limited to 'code/sunlab/suntorch/models/variational')
-rw-r--r-- | code/sunlab/suntorch/models/variational/autoencoder.py | 128 | ||||
-rw-r--r-- | code/sunlab/suntorch/models/variational/common.py | 12 | ||||
-rw-r--r-- | code/sunlab/suntorch/models/variational/decoder.py | 33 | ||||
-rw-r--r-- | code/sunlab/suntorch/models/variational/encoder.py | 34 |
4 files changed, 207 insertions, 0 deletions
diff --git a/code/sunlab/suntorch/models/variational/autoencoder.py b/code/sunlab/suntorch/models/variational/autoencoder.py new file mode 100644 index 0000000..e335704 --- /dev/null +++ b/code/sunlab/suntorch/models/variational/autoencoder.py @@ -0,0 +1,128 @@ +import torch +import torch.nn.functional as F +from torch.autograd import Variable + +from .encoder import Encoder +from .decoder import Decoder +from .common import * + + +class VariationalAutoencoder: + """# Variational Autoencoder Model""" + + def __init__( + self, data_dim, latent_dim, enc_dec_size, negative_slope=0.3, dropout=0.0 + ): + self.encoder = Encoder( + data_dim, + enc_dec_size, + latent_dim, + negative_slope=negative_slope, + dropout=dropout, + ) + self.decoder = Decoder( + data_dim, + enc_dec_size, + latent_dim, + negative_slope=negative_slope, + dropout=dropout, + ) + self.data_dim = data_dim + self.latent_dim = latent_dim + self.p = dropout + self.negative_slope = negative_slope + return + + def parameters(self): + return (*self.encoder.parameters(), *self.decoder.parameters()) + + def train(self): + self.encoder.train(True) + self.decoder.train(True) + return self + + def eval(self): + self.encoder.train(False) + self.decoder.train(False) + return self + + def encode(self, data): + return self.encoder(data) + + def decode(self, data): + return self.decoder(data) + + def reparameterization(self, mean, var): + epsilon = torch.randn_like(var) + if torch.cuda.is_available(): + epsilon = epsilon.cuda() + z = mean + var * epsilon + return z + + def forward(self, x): + mean, log_var = self.encoder(x) + z = self.reparameterization(mean, torch.exp(0.5 * log_var)) + X = self.decoder(z) + return X, mean, log_var + + def __call__(self, data): + return self.forward(data) + + def save(self, base="./"): + torch.save(self.encoder.state_dict(), base + "weights_encoder.pt") + torch.save(self.decoder.state_dict(), base + "weights_decoder.pt") + return self + + def load(self, base="./"): + self.encoder.load_state_dict(torch.load(base + "weights_encoder.pt")) + self.encoder.eval() + self.decoder.load_state_dict(torch.load(base + "weights_decoder.pt")) + self.decoder.eval() + return self + + def to(self, device): + self.encoder.to(device) + self.decoder.to(device) + return self + + def cuda(self): + if torch.cuda.is_available(): + self.encoder = self.encoder.cuda() + self.decoder = self.decoder.cuda() + return self + + def cpu(self): + self.encoder = self.encoder.cpu() + self.decoder = self.decoder.cpu() + return self + + def init_optimizers(self, recon_lr=1e-4): + self.optim_E_enc = torch.optim.Adam(self.encoder.parameters(), lr=recon_lr) + self.optim_D = torch.optim.Adam(self.decoder.parameters(), lr=recon_lr) + return self + + def init_losses(self, recon_loss_fn=F.binary_cross_entropy): + self.recon_loss_fn = recon_loss_fn + return self + + def train_step(self, raw_data): + data = to_var(raw_data.view(raw_data.size(0), -1)) + + self.encoder.zero_grad() + self.decoder.zero_grad() + X, _, _ = self.forward(data) + # mean, log_var = self.encoder(data) + # z = self.reparameterization(mean, torch.exp(0.5 * log_var)) + # X = self.decoder(z) + self.recon_loss = self.recon_loss_fn(X + EPS, data + EPS) + self.recon_loss.backward() + self.optim_E_enc.step() + self.optim_D.step() + return + + def losses(self): + try: + return self.recon_loss + except: + ... + return diff --git a/code/sunlab/suntorch/models/variational/common.py b/code/sunlab/suntorch/models/variational/common.py new file mode 100644 index 0000000..f10930e --- /dev/null +++ b/code/sunlab/suntorch/models/variational/common.py @@ -0,0 +1,12 @@ +from torch.autograd import Variable + +EPS = 1e-15 + + +def to_var(x): + """# Convert to variable""" + import torch + + if torch.cuda.is_available(): + x = x.cuda() + return Variable(x) diff --git a/code/sunlab/suntorch/models/variational/decoder.py b/code/sunlab/suntorch/models/variational/decoder.py new file mode 100644 index 0000000..2eeb7a4 --- /dev/null +++ b/code/sunlab/suntorch/models/variational/decoder.py @@ -0,0 +1,33 @@ +import torch.nn as nn +import torch.nn.functional as F +from torch import sigmoid + + +class Decoder(nn.Module): + """# Decoder Neural Network + X_dim: Output dimension shape + N: Inner neuronal layer size + z_dim: Input dimension shape + """ + + def __init__(self, X_dim, N, z_dim, dropout=0.0, negative_slope=0.3): + super(Decoder, self).__init__() + self.lin1 = nn.Linear(z_dim, N) + self.lin2 = nn.Linear(N, N) + self.lin3 = nn.Linear(N, X_dim) + self.p = dropout + self.negative_slope = negative_slope + + def forward(self, x): + x = self.lin1(x) + if self.p > 0.0: + x = F.dropout(x, p=self.p, training=self.training) + x = F.leaky_relu(x, negative_slope=self.negative_slope) + + x = self.lin2(x) + if self.p > 0.0: + x = F.dropout(x, p=self.p, training=self.training) + x = F.leaky_relu(x, negative_slope=self.negative_slope) + + x = self.lin3(x) + return sigmoid(x) diff --git a/code/sunlab/suntorch/models/variational/encoder.py b/code/sunlab/suntorch/models/variational/encoder.py new file mode 100644 index 0000000..b08202f --- /dev/null +++ b/code/sunlab/suntorch/models/variational/encoder.py @@ -0,0 +1,34 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class Encoder(nn.Module): + """# Encoder Neural Network + X_dim: Input dimension shape + N: Inner neuronal layer size + z_dim: Output dimension shape + """ + + def __init__(self, X_dim, N, z_dim, dropout=0.0, negative_slope=0.3): + super(Encoder, self).__init__() + self.lin1 = nn.Linear(X_dim, N) + self.lin2 = nn.Linear(N, N) + self.lin3mu = nn.Linear(N, z_dim) + self.lin3sigma = nn.Linear(N, z_dim) + self.p = dropout + self.negative_slope = negative_slope + + def forward(self, x): + x = self.lin1(x) + if self.p > 0.0: + x = F.dropout(x, p=self.p, training=self.training) + x = F.leaky_relu(x, negative_slope=self.negative_slope) + + x = self.lin2(x) + if self.p > 0.0: + x = F.dropout(x, p=self.p, training=self.training) + x = F.leaky_relu(x, negative_slope=self.negative_slope) + + mu = self.lin3mu(x) + sigma = self.lin3sigma(x) + return mu, sigma |