aboutsummaryrefslogtreecommitdiff
path: root/code/sunlab/common/data/shape_dataset.py
blob: 5a687361ee25e55c731ed3889f23c450635da372 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from .dataset import Dataset


class ShapeDataset(Dataset):
    """# Shape Dataset

    A `Dataset` preconfigured with a fixed set of shape-descriptor feature
    columns (see `DEFAULT_DATA_COLUMNS`) and `"Class"` as the label column.
    All loading/splitting/scaling behavior is delegated to `Dataset`.
    """

    # Default feature columns. Kept as an immutable tuple (rather than a list
    # literal in the signature) so no list object is shared between instances:
    # a mutable default argument is created once and reused on every call.
    DEFAULT_DATA_COLUMNS = (
        "Area",
        "MjrAxisLength",
        "MnrAxisLength",
        "Eccentricity",
        "ConvexArea",
        "EquivDiameter",
        "Solidity",
        "Extent",
        "Perimeter",
        "ConvexPerim",
        "FibLen",
        "InscribeR",
        "BlebLen",
    )
    # Default label column(s); immutable for the same reason as above.
    DEFAULT_LABEL_COLUMNS = ("Class",)

    def __init__(
        self,
        dataset_filename,
        data_columns=None,
        label_columns=None,
        batch_size=None,
        shuffle=False,
        val_split=0.0,
        scaler=None,
        sort_columns=None,
        random_seed=4332,
        pre_scale=10,
        **kwargs
    ):
        """# Initialize Dataset

        Arguments:
        - dataset_filename: Passed through to `Dataset.__init__` as the first
        positional argument (presumably the path of the file to load).
        - data_columns: Feature column names. `None` (the default) selects a
        fresh copy of `DEFAULT_DATA_COLUMNS` as a list.
        - label_columns: Label column names. `None` (the default) selects a
        fresh copy of `DEFAULT_LABEL_COLUMNS` as a list, i.e. `["Class"]`.
        - batch_size, shuffle, val_split, scaler, sort_columns, random_seed,
        pre_scale: Forwarded unchanged to `Dataset.__init__`; see that class
        for their semantics. (`sort_columns` presumably gives the columns to
        sort the data by initially — confirm against `Dataset`.)
        - **kwargs: Any extra keyword arguments are forwarded unchanged to
        `Dataset.__init__`. NOTE(review): options such as `prescale_function`
        or `equal_split` were previously documented here; if supported, they
        are accepted by `Dataset` through these kwargs — confirm there.
        """
        # Build new lists per call so instances never share (and can never
        # accidentally mutate) a common default list.
        if data_columns is None:
            data_columns = list(self.DEFAULT_DATA_COLUMNS)
        if label_columns is None:
            label_columns = list(self.DEFAULT_LABEL_COLUMNS)
        super().__init__(
            dataset_filename,
            data_columns=data_columns,
            label_columns=label_columns,
            batch_size=batch_size,
            shuffle=shuffle,
            val_split=val_split,
            scaler=scaler,
            sort_columns=sort_columns,
            random_seed=random_seed,
            pre_scale=pre_scale,
            **kwargs
        )