-
Notifications
You must be signed in to change notification settings - Fork 44
/
Copy pathtrain.py
364 lines (333 loc) · 15.8 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: [email protected]
## Copyright (c) 2020
##
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import os
import time
import argparse
import importlib
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn
from torch.nn.parallel import DistributedDataParallel
import autotorch as at
import encoding
from encoding.nn import LabelSmoothing, NLLMultiLabelSmooth
from encoding.utils import (accuracy, AverageMeter, MixUpWrapper, LR_Scheduler, torch_dist_sum)
try:
import apex
from apex import amp
except ModuleNotFoundError:
print('please install amp if using float16 training')
class Options():
    """Command-line options for the training / evaluation script.

    The configured ``argparse.ArgumentParser`` is exposed as ``self.parser``;
    call :meth:`parse` to obtain the parsed namespace.
    """

    def __init__(self):
        parser = argparse.ArgumentParser(description='Deep Encoding')
        # data settings
        parser.add_argument('--dataset', type=str, default='imagenet',
                            help='training dataset (default: imagenet)')
        parser.add_argument('--base-size', type=int, default=None,
                            help='base image size')
        parser.add_argument('--crop-size', type=int, default=224,
                            help='crop image size')
        parser.add_argument('--label-smoothing', type=float, default=0.0,
                            help='label-smoothing (default eta: 0.0)')
        parser.add_argument('--mixup', type=float, default=0.0,
                            help='mixup (default eta: 0.0)')
        parser.add_argument('--auto-policy', type=str, default=None,
                            help='path to auto augment policy')
        parser.add_argument('--data-dir', type=str,
                            default=os.path.expanduser('~/.encoding/data'),
                            help='data location for training')
        # model params
        parser.add_argument('--arch', type=str, default='regnet',
                            help='network type (default: regnet)')
        parser.add_argument('--config-file', type=str, required=True,
                            help='network node config file')
        parser.add_argument('--last-gamma', action='store_true', default=False,
                            help='whether to init gamma of the last BN layer in '
                                 'each bottleneck to 0 (default: False)')
        # training params
        parser.add_argument('--amp', action='store_true',
                            default=False, help='using amp')
        parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                            help='batch size for training (default: 128)')
        parser.add_argument('--test-batch-size', type=int, default=256, metavar='N',
                            help='batch size for testing (default: 256)')
        parser.add_argument('--epochs', type=int, default=120, metavar='N',
                            help='number of epochs to train (default: 120)')
        parser.add_argument('--start_epoch', type=int, default=0, metavar='N',
                            help='the epoch number to start (default: 0)')
        parser.add_argument('--workers', type=int, default=8, metavar='N',
                            help='dataloader threads')
        # optimizer
        parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                            help='learning rate (default: 0.1)')
        parser.add_argument('--lr-scheduler', type=str, default='cos',
                            help='learning rate scheduler (default: cos)')
        parser.add_argument('--warmup-epochs', type=int, default=0,
                            help='number of warmup epochs (default: 0)')
        parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                            help='SGD momentum (default: 0.9)')
        parser.add_argument('--wd', type=float, default=1e-4, metavar='M',
                            help='SGD weight decay (default: 1e-4)')
        parser.add_argument('--no-bn-wd', action='store_true', default=False,
                            help='do not apply weight decay to BN and bias '
                                 'parameters (default: False)')
        # seed
        parser.add_argument('--seed', type=int, default=1, metavar='S',
                            help='random seed (default: 1)')
        # checkpoint
        parser.add_argument('--resume', type=str, default=None,
                            help='put the path to resuming file if needed')
        parser.add_argument('--checkname', type=str, default='default',
                            help='set the checkpoint name')
        # distributed
        parser.add_argument('--world-size', default=1, type=int,
                            help='number of nodes for distributed training')
        parser.add_argument('--rank', default=0, type=int,
                            help='node rank for distributed training')
        parser.add_argument('--dist-url', default='tcp://localhost:23456', type=str,
                            help='url used to set up distributed training')
        parser.add_argument('--dist-backend', default='nccl', type=str,
                            help='distributed backend')
        # evaluation option
        parser.add_argument('--eval', action='store_true', default=False,
                            help='evaluating')
        parser.add_argument('--export', type=str, default=None,
                            help='path prefix for exporting the model weights; '
                                 '".pth" is appended (default: None)')
        self.parser = parser

    def parse(self, argv=None):
        """Parse the command line and return the resulting namespace.

        Args:
            argv: optional list of argument strings. When ``None`` (the
                default) argparse reads ``sys.argv[1:]``, which is what the
                script relies on.

        Returns:
            The ``argparse.Namespace`` holding all options.
        """
        return self.parser.parse_args(argv)
def main():
    """Parse the options and launch one worker process per visible GPU.

    ``args.world_size`` comes in as the number of nodes and is turned into
    the total number of processes; the learning rate is then multiplied by
    that total before the workers are spawned.
    """
    args = Options().parse()
    gpu_count = torch.cuda.device_count()
    # nodes -> total number of processes across all nodes
    args.world_size *= gpu_count
    # scale the base learning rate with the number of processes
    args.lr *= args.world_size
    mp.spawn(main_worker, args=(gpu_count, args), nprocs=gpu_count)
# Module-level training state. main_worker declares these names ``global``,
# restores them from a checkpoint when resuming, and updates them while
# training and validating.
acclist_train, acclist_val = [], []  # per-epoch train / validation top-1 history
best_pred = 0.0                      # best validation top-1 seen so far
def main_worker(gpu, ngpus_per_node, args):
    """Per-process entry point: build everything, then train, evaluate or export.

    Args:
        gpu: index of the GPU this process uses (first argument supplied by
            ``mp.spawn``); also used to decide which process prints and saves.
        ngpus_per_node: number of processes started on this node.
        args: parsed options. ``args.rank`` comes in as the node rank and is
            rewritten to the global process rank; ``args.gpu`` and
            ``args.model`` are added.

    Raises:
        RuntimeError: if ``args.resume`` is set but is not an existing file.
    """
    global best_pred, acclist_train, acclist_val
    args.gpu = gpu
    # node rank -> global process rank
    args.rank = args.rank * ngpus_per_node + gpu
    # model name for checkpoint
    args.model = "{}-{}".format(args.arch, os.path.splitext(os.path.basename(args.config_file))[0])
    if args.gpu == 0:
        print('model:', args.model)
    print('rank: {} / {}'.format(args.rank, args.world_size))
    dist.init_process_group(backend=args.dist_backend,
                            init_method=args.dist_url,
                            world_size=args.world_size,
                            rank=args.rank)
    torch.cuda.set_device(args.gpu)
    if args.gpu == 0:
        print(args)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True

    # init dataloader
    transform_train, transform_val = encoding.transforms.get_transform(
        args.dataset, args.base_size, args.crop_size)
    if args.auto_policy is not None:
        print(f'Using auto_policy: {args.auto_policy}')
        from augment import Augmentation
        auto_policy = Augmentation(at.load(args.auto_policy))
        # the learned policy is applied before the standard transforms
        transform_train.transforms.insert(0, auto_policy)
    trainset = encoding.datasets.get_dataset(args.dataset, root=args.data_dir,
                                             transform=transform_train, train=True, download=True)
    valset = encoding.datasets.get_dataset(args.dataset, root=args.data_dir,
                                           transform=transform_val, train=False, download=True)

    train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True,
        sampler=train_sampler)

    val_sampler = torch.utils.data.distributed.DistributedSampler(valset, shuffle=False)
    val_loader = torch.utils.data.DataLoader(
        valset, batch_size=args.test_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True,
        sampler=val_sampler)

    # init the model
    arch = importlib.import_module('arch.' + args.arch)
    model = arch.config_network(args.config_file)
    if args.gpu == 0:
        print(model)

    # criterion
    if args.mixup > 0:
        # NOTE(review): 1000 looks like a hard-coded class count for the
        # default imagenet dataset -- confirm before using another dataset.
        train_loader = MixUpWrapper(args.mixup, 1000, train_loader, args.gpu)
        criterion = NLLMultiLabelSmooth(args.label_smoothing)
    elif args.label_smoothing > 0.0:
        criterion = LabelSmoothing(args.label_smoothing)
    else:
        criterion = nn.CrossEntropyLoss()

    model.cuda(args.gpu)
    criterion.cuda(args.gpu)

    # optimizer
    if args.no_bn_wd:
        param_dict = dict(model.named_parameters())
        # parameters whose name contains 'bn' or 'bias' get no weight decay
        bn_params = [v for n, v in param_dict.items() if ('bn' in n or 'bias' in n)]
        rest_params = [v for n, v in param_dict.items() if not ('bn' in n or 'bias' in n)]
        if args.gpu == 0:
            print(" Weight decay NOT applied to BN parameters ")
            print(f'len(parameters): {len(list(model.parameters()))} = {len(bn_params)} + {len(rest_params)}')
        optimizer = torch.optim.SGD([{'params': bn_params, 'weight_decay': 0 },
                                     {'params': rest_params, 'weight_decay': args.wd}],
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.wd)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.wd)

    if args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
        DDP = apex.parallel.DistributedDataParallel
        model = DDP(model, delay_allreduce=True)
    else:
        DDP = DistributedDataParallel
        model = DDP(model, device_ids=[args.gpu])

    # check point
    if args.resume is not None:
        if os.path.isfile(args.resume):
            if args.gpu == 0:
                print("=> loading checkpoint '{}'".format(args.resume))
            # Map the stored tensors onto this process's own GPU; without
            # map_location they are restored onto the device they were saved
            # from, for every rank.
            checkpoint = torch.load(args.resume, map_location=f'cuda:{args.gpu}')
            # an explicit --start_epoch wins over the epoch stored in the file
            args.start_epoch = checkpoint['epoch'] + 1 if args.start_epoch == 0 else args.start_epoch
            best_pred = checkpoint['best_pred']
            acclist_train = checkpoint['acclist_train']
            acclist_val = checkpoint['acclist_val']
            model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Checkpoints written without the 'amp' entry must stay loadable.
            if args.amp and 'amp' in checkpoint:
                amp.load_state_dict(checkpoint['amp'])
            if args.gpu == 0:
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(args.resume, checkpoint['epoch']))
        else:
            raise RuntimeError ("=> no resume checkpoint found at '{}'".\
                format(args.resume))

    scheduler = LR_Scheduler(args.lr_scheduler,
                             base_lr=args.lr,
                             num_epochs=args.epochs,
                             iters_per_epoch=len(train_loader),
                             warmup_epochs=args.warmup_epochs)

    def checkpoint_state(epoch):
        """Collect everything the resume path above reads back, for ``epoch``."""
        state = {
            'epoch': epoch,
            'state_dict': model.module.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_pred': best_pred,
            'acclist_train':acclist_train,
            'acclist_val':acclist_val,
        }
        if args.amp:
            state['amp'] = amp.state_dict()
        return state

    def train(epoch):
        """Run one training epoch and append its top-1 average to acclist_train."""
        global best_pred, acclist_train
        train_sampler.set_epoch(epoch)
        model.train()
        losses = AverageMeter()
        top1 = AverageMeter()
        tic = time.time()
        for batch_idx, (data, target) in enumerate(train_loader):
            scheduler(optimizer, batch_idx, epoch, best_pred)
            # with mixup the batch is passed through as delivered by the wrapper
            if not args.mixup:
                data, target = data.cuda(args.gpu), target.cuda(args.gpu)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            if args.amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()

            # top-1 is only tracked without mixup
            if not args.mixup:
                acc1 = accuracy(output, target, topk=(1,))
                top1.update(acc1[0], data.size(0))

            losses.update(loss.item(), data.size(0))
            if batch_idx % 100 == 0 and args.gpu == 0:
                # the very first report covers 1 iteration, later ones 100
                iter_per_sec = 100.0 / (time.time() - tic) if batch_idx != 0 else 1.0 / (time.time() - tic)
                tic = time.time()
                # float() accepts both plain numbers and one-element tensors;
                # `losses` is fed Python floats, so `.item()` is not reliable.
                if args.mixup:
                    print('Epoch: {}, Iter: {}, Speed: {:.3f} iter/sec, Train loss: {:.3f}'. \
                          format(epoch, batch_idx, iter_per_sec, float(losses.avg)))
                else:
                    print('Epoch: {}, Iter: {}, Speed: {:.3f} iter/sec, Top1: {:.3f}'. \
                          format(epoch, batch_idx, iter_per_sec, float(top1.avg)))

        acclist_train += [top1.avg]

    def validate(epoch):
        """Evaluate on the validation set; outside --eval mode also checkpoint."""
        global best_pred, acclist_train, acclist_val
        model.eval()
        top1 = AverageMeter()
        top5 = AverageMeter()
        is_best = False
        for batch_idx, (data, target) in enumerate(val_loader):
            data, target = data.cuda(args.gpu), target.cuda(args.gpu)
            with torch.no_grad():
                output = model(data)
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                top1.update(acc1[0], data.size(0))
                top5.update(acc5[0], data.size(0))

        # sum all (called by every process)
        sum1, cnt1, sum5, cnt5 = torch_dist_sum(args.gpu, top1.sum, top1.count, top5.sum, top5.count)

        # only the gpu-0 process reports and saves
        if args.gpu != 0:
            return
        top1_acc = sum(sum1) / sum(cnt1)
        top5_acc = sum(sum5) / sum(cnt5)
        print('Validation: Top1: %.3f | Top5: %.3f'%(top1_acc, top5_acc))
        if args.eval:
            return

        # save checkpoint
        acclist_val += [top1_acc]
        if top1_acc > best_pred:
            best_pred = top1_acc
            is_best = True
        encoding.utils.save_checkpoint(checkpoint_state(epoch), args=args, is_best=is_best)

    if args.export:
        if args.gpu == 0:
            torch.save(model.module.state_dict(), args.export + '.pth')
        return

    if args.eval:
        validate(args.start_epoch)
        return

    for epoch in range(args.start_epoch, args.epochs):
        tic = time.time()
        train(epoch)
        if epoch % 10 == 0:# or epoch == args.epochs-1:
            validate(epoch)
        elapsed = time.time() - tic
        if args.gpu == 0:
            print(f'Epoch: {epoch}, Time cost: {elapsed}')

    if args.gpu == 0:
        # Final snapshot; built by the same helper as the periodic ones so it
        # also carries the 'amp' entry when --amp is used.
        encoding.utils.save_checkpoint(checkpoint_state(args.epochs-1),
                                       args=args, is_best=False)
if __name__ == "__main__":
    # Set the warnings filter in the environment before main() spawns the
    # worker processes, so it is part of the environment they start with.
    os.environ.update(PYTHONWARNINGS='ignore:semaphore_tracker:UserWarning')
    main()