aboutsummaryrefslogtreecommitdiff
path: root/code/sunlab/suntorch/models/autoencoder.py
blob: 232f1803b32971190045259706c0158797c62a80 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import torch
import torch.nn.functional as F
from torch.autograd import Variable

from .encoder import Encoder
from .decoder import Decoder
from .common import *


class Autoencoder:
    """# Autoencoder Model

    Thin wrapper pairing an `Encoder` with a `Decoder` and bundling the
    bookkeeping needed to train them together: mode switching, device
    placement, weight (de)serialization, optimizers and a single
    reconstruction-loss training step.

    Most mutating methods return `self` so calls can be chained, e.g.
    `model.init_optimizers().init_losses()`.
    """

    def __init__(
        self, data_dim, latent_dim, enc_dec_size, negative_slope=0.3, dropout=0.0
    ):
        """Build the encoder/decoder pair and remember the hyper-parameters.

        All arguments are forwarded unchanged to both `Encoder` and
        `Decoder` as `(data_dim, enc_dec_size, latent_dim)` plus the
        `negative_slope` / `dropout` keywords.
        NOTE(review): their exact meaning is defined by those classes
        (presumably input width, hidden width, latent width, LeakyReLU slope
        and dropout probability) -- confirm against encoder.py / decoder.py.
        """
        self.encoder = Encoder(
            data_dim,
            enc_dec_size,
            latent_dim,
            negative_slope=negative_slope,
            dropout=dropout,
        )
        self.decoder = Decoder(
            data_dim,
            enc_dec_size,
            latent_dim,
            negative_slope=negative_slope,
            dropout=dropout,
        )
        self.data_dim = data_dim
        self.latent_dim = latent_dim
        self.p = dropout  # dropout value, stored under the short name `p`
        self.negative_slope = negative_slope

    def parameters(self):
        """Return a tuple of the encoder parameters followed by the decoder's."""
        return (*self.encoder.parameters(), *self.decoder.parameters())

    def train(self):
        """Put both sub-networks in training mode and return `self`."""
        self.encoder.train(True)
        self.decoder.train(True)
        return self

    def eval(self):
        """Put both sub-networks in evaluation mode and return `self`."""
        self.encoder.train(False)
        self.decoder.train(False)
        return self

    def encode(self, data):
        """Run `data` through the encoder and return its output."""
        return self.encoder(data)

    def decode(self, data):
        """Run `data` through the decoder and return its output."""
        return self.decoder(data)

    def __call__(self, data):
        """Return the reconstruction `decode(encode(data))`."""
        return self.decode(self.encode(data))

    def save(self, base="./"):
        """Write both state dicts to disk and return `self`.

        `base` is used as a plain string prefix (not joined as a path), so
        the files are `base + "weights_encoder.pt"` and
        `base + "weights_decoder.pt"`.
        """
        torch.save(self.encoder.state_dict(), base + "weights_encoder.pt")
        torch.save(self.decoder.state_dict(), base + "weights_decoder.pt")
        return self

    def load(self, base="./"):
        """Load both state dicts written by `save` and return `self`.

        `base` is the same string prefix that was passed to `save`. Each
        sub-network is switched to evaluation mode right after its weights
        are loaded.
        """
        self.encoder.load_state_dict(torch.load(base + "weights_encoder.pt"))
        self.encoder.eval()
        self.decoder.load_state_dict(torch.load(base + "weights_decoder.pt"))
        self.decoder.eval()
        return self

    def to(self, device):
        """Move the encoder and decoder to `device` and return `self`.

        This class defines no discriminator, so only the two sub-networks
        it actually owns are moved.
        """
        self.encoder.to(device)
        self.decoder.to(device)
        return self

    def cuda(self):
        """Move both sub-networks to CUDA when it is available.

        When CUDA is not available this is a no-op; `self` is returned in
        either case.
        """
        if torch.cuda.is_available():
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()
        return self

    def cpu(self):
        """Move both sub-networks to the CPU and return `self`."""
        self.encoder = self.encoder.cpu()
        self.decoder = self.decoder.cpu()
        return self

    def init_optimizers(self, recon_lr=1e-4):
        """Create one Adam optimizer per sub-network and return `self`.

        Both optimizers use `recon_lr` as learning rate; they are stored as
        `optim_E_enc` (encoder) and `optim_D` (decoder).
        """
        self.optim_E_enc = torch.optim.Adam(self.encoder.parameters(), lr=recon_lr)
        self.optim_D = torch.optim.Adam(self.decoder.parameters(), lr=recon_lr)
        return self

    def init_losses(self, recon_loss_fn=F.binary_cross_entropy):
        """Store the reconstruction loss function and return `self`.

        `recon_loss_fn` is later called as `recon_loss_fn(output, target)`
        by `train_step`.
        """
        self.recon_loss_fn = recon_loss_fn
        return self

    def train_step(self, raw_data):
        """Run one reconstruction update on a single batch.

        `raw_data` is reshaped to `(raw_data.size(0), -1)` and passed through
        `to_var` before use. `init_optimizers` and `init_losses` must have
        been called beforehand. The resulting loss is kept in
        `self.recon_loss` (see `losses`); nothing is returned.
        """
        # NOTE(review): `to_var` and `EPS` come from `.common`; presumably a
        # tensor/device wrapper and a small numerical-stability constant --
        # confirm there.
        data = to_var(raw_data.view(raw_data.size(0), -1))

        self.encoder.zero_grad()
        self.decoder.zero_grad()
        z = self.encoder(data)
        X = self.decoder(z)
        self.recon_loss = self.recon_loss_fn(X + EPS, data + EPS)
        self.recon_loss.backward()
        self.optim_E_enc.step()
        self.optim_D.step()

    def losses(self):
        """Return the loss stored by the last `train_step`, or `None`.

        `None` is returned when `train_step` has not run yet, i.e. when
        `self.recon_loss` has never been set.
        """
        return getattr(self, "recon_loss", None)