-
Notifications
You must be signed in to change notification settings - Fork 116
/
Copy pathclassifier_utils.py
28 lines (24 loc) · 1.06 KB
/
classifier_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from fastai.learner import *
from fastai.rnn_reg import *
from fastai.rnn_train import *
from fastai.nlp import *
from fastai.lm_rnn import *
from torchtext import vocab, data
from torchtext.datasets import language_modeling
class ComposerDataset(data.Dataset):
    """Labelled text dataset read from a directory with one sub-directory per label.

    Every entry of ``path`` is used as a label name. Each ``*.txt`` file found
    directly inside ``path/<label>`` contributes one example whose ``text`` is
    the first line of that file and whose ``label`` is the entry name.
    """

    def __init__(self, path, text_field, label_field, **kwargs):
        """Collect one example per ``*.txt`` file under ``path``.

        Args:
            path: Directory whose entries are treated as label names.
            text_field: Object registered for the ``'text'`` column.
            label_field: Object registered for the ``'label'`` column.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        fields = [('text', text_field), ('label', label_field)]
        examples = []
        # os.listdir and glob give no ordering guarantee; sort both so the
        # example order is reproducible across machines and runs.
        for label in sorted(os.listdir(path)):
            # Entries that are not directories simply match no files here.
            for fname in sorted(glob(os.path.join(path, label, '*.txt'))):
                # Only the first line of each file is used as the text.
                # NOTE(review): presumably each file holds one document on a
                # single line -- confirm against the data layout.
                with open(fname, 'r') as f:
                    text = f.readline()
                examples.append(data.Example.fromlist([text, label], fields))
        super().__init__(examples, fields, **kwargs)

    @staticmethod
    def sort_key(ex):
        """Return ``len(ex.text)`` as the ordering key for an example."""
        return len(ex.text)

    @classmethod
    def splits(cls, text_field, label_field, root='.data',
               train='train', test='test', **kwargs):
        """Delegate to the parent ``splits`` for a train/test pair.

        ``root`` is passed as the first positional argument of the parent
        ``splits`` and ``validation`` is always ``None``, so no validation
        split is requested.

        Args:
            text_field: Passed through as ``text_field``.
            label_field: Passed through as ``label_field``.
            root: Location handed to the parent ``splits`` (default ``'.data'``).
            train: Name of the training split (default ``'train'``).
            test: Name of the test split (default ``'test'``).
            **kwargs: Forwarded unchanged to the parent ``splits``.

        Returns:
            Whatever the parent ``splits`` returns for the requested splits.
        """
        return super().splits(
            root, text_field=text_field, label_field=label_field,
            train=train, validation=None, test=test, **kwargs)