forked from gmum/ProtoPool
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdatasets.py
71 lines (53 loc) · 2.2 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from abc import ABC
from typing import Any, Tuple
import pandas as pd
import torch
from torch.utils.data import Dataset
class Connect4(Dataset, ABC):
convert_data = dict(zip(*pd.factorize(['b', 'x', 'o'], sort=True)))
classes = dict(zip(*pd.factorize(['win', 'loss', 'draw'], sort=True)))
def __init__(self, root: str, train: bool = True) -> None:
super(Connect4, self).__init__()
self.num_classes = len(self.classes)
self.train = train
data = pd.read_csv(f'{root}/Connect-4/connect-4.data', sep=',', header=None)
data = data.apply(lambda col: pd.factorize(col, sort=True)[0]).to_numpy()
size = data.shape[0]
cut = int(0.9 * size)
if self.train:
self.data = data[:cut, :-1]
self.targets = data[:cut, -1]
else:
self.data = data[cut:, :-1]
self.targets = data[cut:, -1]
def __getitem__(self, index: int) -> Tuple[Any, Any]:
vector = torch.tensor(self.data[index])
target = torch.tensor(self.targets[index])
return vector.float(), target
def __len__(self):
return self.targets.size
class Letter(Dataset, ABC):
    """Letter classification dataset read from a comma-separated file.

    Reads ``{root}/Letter/letter.data`` (no header row). Every column except the
    last is a numeric feature, kept at its original value. The last column is
    an upper-case letter encoded with the fixed ``classes`` vocabulary
    (``'A'`` -> 0, ..., ``'Z'`` -> 25), so codes do not depend on which letters
    the file happens to contain.

    The first 90% of rows (in file order) form the training split; the
    remaining rows form the test split.

    Args:
        root: Directory containing the ``Letter`` sub-directory.
        train: If True, expose the training split; otherwise the test split.

    Raises:
        ValueError: If the last column holds anything other than ``'A'``-``'Z'``
            (e.g. the label is not the last column), or if a feature value is
            missing or non-numeric.
    """

    # Code -> letter, in sorted order: {0: 'A', 1: 'B', ..., 25: 'Z'}.
    classes = dict(zip(*pd.factorize(
        ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N',
         'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'],
        sort=True)))

    def __init__(self, root: str, train: bool = True) -> None:
        super().__init__()
        self.num_classes = len(self.classes)
        self.train = train

        raw = pd.read_csv(f'{root}/Letter/letter.data', sep=',', header=None)

        # Map labels through the fixed A-Z vocabulary so codes always agree
        # with ``classes`` and ``num_classes``. Anything else maps to NaN.
        label_codes = {letter: code for code, letter in self.classes.items()}
        labels = raw.iloc[:, -1].map(label_codes)
        if labels.isna().any():
            raise ValueError('letter.data: the last column must hold letters '
                             'A-Z; check that the label is the last column')
        labels = labels.to_numpy(dtype='int64')

        # Keep features numeric instead of factorizing them. Factorization
        # replaces each value by its rank among the column's distinct values,
        # which distorts magnitudes when a column has gaps
        # (e.g. {0, 5, 15} would become {0, 1, 2}).
        features = raw.iloc[:, :-1].apply(pd.to_numeric)
        if features.isna().to_numpy().any():
            raise ValueError('letter.data contains missing feature values')
        features = features.to_numpy()

        # NOTE(review): the split follows file order with no shuffling; if the
        # file is ordered, the test split may not be representative — confirm.
        cut = int(0.9 * len(raw))
        rows = slice(None, cut) if train else slice(cut, None)
        self.data = features[rows]
        self.targets = labels[rows]

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(features as a float tensor, integer label tensor)`` for row ``index``."""
        vector = torch.tensor(self.data[index])
        target = torch.tensor(self.targets[index])
        return vector.float(), target

    def __len__(self) -> int:
        """Return the number of rows in the selected split."""
        return self.targets.size