-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdata_loader.py
30 lines (22 loc) · 938 Bytes
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os
import numpy
import torch
from torch.utils import data
class GTZANDataset(data.Dataset):
    """PyTorch dataset over the GTZAN music-genre collection.

    So far only the song list is loaded. It is read line-by-line from
    ``features_30_sec.csv`` inside ``data_path``. ``__getitem__`` is not
    implemented yet.

    Args:
        data_path: Directory containing ``features_30_sec.csv``. A falsy value
            (e.g. ``None`` or ``""``) falls back to ``DEFAULT_DATA_PATH``.
        split: Name of the split. Stored, but not yet used to filter songs.
        num_samples: Stored on the instance; not used by the code here yet.
        num_chunks: Stored on the instance; not used by the code here yet.
        is_augumentation: Augmentation flag. Stored on the instance only. The
            misspelled name is kept so keyword callers keep working.
    """

    # Machine-specific fallback directory, kept from the original code.
    DEFAULT_DATA_PATH = '/home/parker/data/Data'
    # Feature file for the 30-second clips. It has no trailing space, unlike the
    # original literal, which made the file impossible to open.
    FEATURE_FILE_NAME = 'features_30_sec.csv'

    def __init__(self, data_path: str, split: str, num_samples: int, num_chunks: int, is_augumentation: bool):
        super().__init__()
        self.data_path = data_path if data_path else self.DEFAULT_DATA_PATH
        self.split = split
        self.num_samples = num_samples
        self.num_chunks = num_chunks
        # Previously accepted and silently discarded; keep it for __getitem__.
        self.is_augumentation = is_augumentation
        self.genres = ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']
        self.song_list = self._get_song_list()

    def _get_song_list(self):
        """Return every line of the feature CSV with surrounding whitespace stripped.

        Raises:
            OSError: if the feature file cannot be opened (e.g. FileNotFoundError).

        NOTE(review): every line is returned, including a header row if the
        file has one, and nothing is filtered by ``self.split``. Confirm this
        is intended before using the list for indexing.
        """
        # Train model on 30-second music clips.
        feature_file_name = os.path.join(self.data_path, self.FEATURE_FILE_NAME)
        with open(feature_file_name, encoding='utf-8') as f:
            return [line.strip() for line in f]

    def __getitem__(self, index):
        """Return the sample at ``index``. Not implemented yet."""
        # TODO: load audio / features for self.song_list[index]. The original
        # stub lacked `self`, so any indexing already failed; this makes the
        # gap explicit instead of raising a confusing TypeError.
        raise NotImplementedError('GTZANDataset.__getitem__ is not implemented yet')

    def __len__(self):
        """Return the number of entries in ``song_list``."""
        return len(self.song_list)