aboutsummaryrefslogtreecommitdiff
path: root/code/sunlab/sunflow/models/utilities.py
blob: ab0c2a6fbfefdead337cbdb558c638f150841764 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# Higher-level functions

from typing import Optional

from sunlab.common.distribution.adversarial_distribution import AdversarialDistribution
from sunlab.common.scaler.adversarial_scaler import AdversarialScaler
from sunlab.common.data.utilities import import_dataset
from .adversarial_autoencoder import AdversarialAutoencoder


def create_aae(
    dataset_file_name,
    model_directory,
    normalization_scaler: AdversarialScaler,
    distribution: Optional[AdversarialDistribution],
    magnification=10,
    latent_size=2,
):
    """# Create Adversarial Autoencoder

    Imports the dataset, constructs an `AdversarialAutoencoder` and calls its
    `init` with the imported data.

    - dataset_file_name: str = Path to the dataset file
    - model_directory: str = Path to save the model in
    - normalization_scaler: AdversarialScaler = Data normalization Scaler Model
    - distribution: AdversarialDistribution or None = Distribution for the Adversary
    - magnification: int = The Magnification of the Dataset
    - latent_size: int = Forwarded to `AdversarialAutoencoder.init` as
      `latent_size` (presumably the latent space dimension -- confirm there)

    Returns the value returned by `AdversarialAutoencoder.init` (callers in
    this module treat it as the model)."""
    # The dataset is imported without a scaler here; only its `.dataset`
    # attribute is handed to the model's `init`.
    imported = import_dataset(dataset_file_name, magnification)
    autoencoder = AdversarialAutoencoder(
        model_directory, distribution, normalization_scaler
    )
    return autoencoder.init(imported.dataset, latent_size=latent_size)


def create_aae_and_dataset(
    dataset_file_name,
    model_directory,
    normalization_scaler: AdversarialScaler,
    distribution: Optional[AdversarialDistribution],
    magnification=10,
    batch_size=1024,
    shuffle=True,
    val_split=0.1,
    latent_size=2,
):
    """# Create Adversarial Autoencoder and Load the Dataset

    - dataset_file_name: str = Path to the dataset file
    - model_directory: str = Path to save the model in
    - normalization_scaler: AdversarialScaler = Data normalization Scaler Model
    - distribution: AdversarialDistribution or None = Distribution for the Adversary
    - magnification: int = The Magnification of the Dataset
    - batch_size: int = Forwarded to `import_dataset` as `batch_size`
    - shuffle: bool = Forwarded to `import_dataset` as `shuffle`
    - val_split: float = Forwarded to `import_dataset` as `val_split`
    - latent_size: int = Forwarded to `create_aae` as `latent_size`

    Returns a `(model, dataset)` tuple."""
    model = create_aae(
        dataset_file_name,
        model_directory,
        normalization_scaler,
        distribution,
        magnification=magnification,
        latent_size=latent_size,
    )
    # `create_aae` imports the dataset on its own to initialize the model; it
    # is imported again here so that the model's scaler and the batching
    # options are applied to the dataset handed back to the caller.
    dataset = import_dataset(
        dataset_file_name,
        magnification,
        batch_size=batch_size,
        shuffle=shuffle,
        val_split=val_split,
        scaler=model.scaler,
    )
    return model, dataset


def load_aae(model_directory, normalization_scaler: AdversarialScaler):
    """# Load Adversarial Autoencoder

    - model_directory: str = Path of the directory the model is loaded from
    - normalization_scaler: AdversarialScaler = Data normalization Scaler Model

    Returns the value returned by `AdversarialAutoencoder.load`.
    """
    # No adversary distribution is supplied when loading, hence the `None`.
    autoencoder = AdversarialAutoencoder(
        model_directory,
        None,
        normalization_scaler,
    )
    return autoencoder.load()


def load_aae_and_dataset(
    dataset_file_name,
    model_directory,
    normalization_scaler: AdversarialScaler,
    magnification=10,
):
    """# Load Adversarial Autoencoder and Import the Dataset

    - dataset_file_name: str = Path to the dataset file
    - model_directory: str = Path of the directory the model is loaded from
    - normalization_scaler: AdversarialScaler = Data normalization Scaler Model
    - magnification: int = The Magnification of the Dataset

    Returns a `(model, dataset)` tuple."""
    autoencoder = load_aae(model_directory, normalization_scaler)
    # The dataset is imported with the loaded model's scaler.
    import_options = {
        "magnification": magnification,
        "scaler": autoencoder.scaler,
    }
    loaded_dataset = import_dataset(dataset_file_name, **import_options)
    return autoencoder, loaded_dataset