-
Notifications
You must be signed in to change notification settings - Fork 44
/
Copy pathverify.py
139 lines (121 loc) · 5.6 KB
/
verify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: [email protected]
## Copyright (c) 2020
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import os
import argparse
import importlib
from tqdm import tqdm
import torch
import torch.nn as nn
import encoding
from encoding.utils import (accuracy, AverageMeter, MixUpWrapper, LR_Scheduler)
class Options:
    """Command-line options for this evaluation script.

    The ``argparse`` parser is built once in ``__init__``; call :meth:`parse`
    to obtain the parsed namespace.
    """

    def __init__(self):
        parser = argparse.ArgumentParser(description='Deep Encoding')
        # data settings
        parser.add_argument('--dataset', type=str, default='imagenet',
                            help='dataset to evaluate on (default: imagenet)')
        parser.add_argument('--base-size', type=int, default=None,
                            help='base image size')
        parser.add_argument('--crop-size', type=int, default=224,
                            help='crop image size')
        # model params
        parser.add_argument('--arch', type=str, default='regnet',
                            help='network type (default: regnet)')
        parser.add_argument('--config-file', type=str, required=True,
                            help='network node config file')
        parser.add_argument('--rectify', action='store_true',
                            default=False, help='rectify convolution')
        parser.add_argument('--rectify-avg', action='store_true',
                            default=False,
                            help='rectify_avg option for rectified convolution '
                                 '(only used together with --rectify)')
        # evaluation hyper params
        parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                            help='batch size for evaluation (default: 128)')
        parser.add_argument('--workers', type=int, default=32, metavar='N',
                            help='dataloader worker processes (default: 32)')
        # cuda, seed and logging
        parser.add_argument('--no-cuda', action='store_true',
                            default=False, help='disables CUDA')
        parser.add_argument('--seed', type=int, default=1, metavar='S',
                            help='random seed (default: 1)')
        parser.add_argument('--data-dir', type=str,
                            default=os.path.expanduser('~/.encoding/data'),
                            help='dataset root directory')
        # checkpoints
        parser.add_argument('--resume', type=str, default=None,
                            help="path to a training checkpoint (a dict with a "
                                 "'state_dict' entry) to evaluate")
        parser.add_argument('--verify', type=str, default=None,
                            help='path to a bare state_dict file to evaluate '
                                 '(takes precedence over --resume)')
        parser.add_argument('--export', type=str, default=None,
                            help="path prefix to save the loaded weights to "
                                 "('.pth' is appended); exits without evaluating")
        self.parser = parser

    def parse(self, argv=None):
        """Parse ``argv`` and return the resulting namespace.

        Args:
            argv: list of argument strings; ``None`` (the default) means
                ``sys.argv[1:]``, matching the previous no-argument behavior.
        """
        return self.parser.parse_args(argv)
def _build_val_loader(args):
    """Return a non-shuffled DataLoader over the validation split of ``args.dataset``."""
    _, transform_val = encoding.transforms.get_transform(
        args.dataset, args.base_size, args.crop_size)
    valset = encoding.datasets.get_dataset(
        args.dataset, root=args.data_dir, transform=transform_val,
        train=False, download=True)
    return torch.utils.data.DataLoader(
        valset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=args.cuda)


def _load_weights(net, args):
    """Load weights into the unwrapped network ``net`` from the CLI checkpoint.

    ``--verify`` takes precedence and must point at a file that is itself a
    state dict. ``--resume`` must point at a dict holding a ``'state_dict'``
    entry. If neither is given, ``net`` is left untouched.

    Raises:
        RuntimeError: if the selected checkpoint path is not an existing file.
    """
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
    # load_state_dict then copies each tensor onto its parameter's device.
    if args.verify:
        if not os.path.isfile(args.verify):
            raise RuntimeError("=> no verify checkpoint found at '{}'".format(args.verify))
        print("=> loading checkpoint '{}'".format(args.verify))
        net.load_state_dict(torch.load(args.verify, map_location='cpu'))
    elif args.resume is not None:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no resume checkpoint found at '{}'".format(args.resume))
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume, map_location='cpu')
        net.load_state_dict(checkpoint['state_dict'])


def _evaluate(model, loader, use_cuda):
    """Run ``model`` over ``loader`` without gradients.

    Returns:
        ``(top1.avg, top5.avg)`` -- the running averages of the top-1 and
        top-5 values produced by ``accuracy``.
    """
    model.eval()
    top1 = AverageMeter()
    top5 = AverageMeter()
    tbar = tqdm(loader, desc='\r')
    with torch.no_grad():
        for data, target in tbar:
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(acc1[0], data.size(0))
            top5.update(acc5[0], data.size(0))
            tbar.set_description('Top1: %.3f | Top5: %.3f' % (top1.avg, top5.avg))
    return top1.avg, top5.avg


def main():
    """Build the network from ``--config-file``, load weights, then export or evaluate.

    With ``--export`` the loaded weights are saved to ``<export>.pth`` and the
    function returns without touching the dataset. Otherwise the validation
    split is evaluated and the top-1/top-5 averages are printed.
    """
    args = Options().parse()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    print(args)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # Previously referenced without being defined, so --rectify raised NameError.
    model_kwargs = {}
    if args.rectify:
        model_kwargs['rectified_conv'] = True
        model_kwargs['rectify_avg'] = args.rectify_avg
    arch = importlib.import_module('arch.' + args.arch)
    # NOTE(review): confirm arch.config_network accepts these keyword args.
    # With default flags the dict is empty and the call is unchanged.
    net = arch.config_network(args.config_file, **model_kwargs)
    print(net)

    # Keep `net` as the unwrapped module: `model.module` only exists after the
    # DataParallel wrap, which previously broke CPU runs that loaded weights.
    model = net
    if args.cuda:
        net.cuda()
        # Please use CUDA_VISIBLE_DEVICES to control the number of gpus
        model = nn.DataParallel(net)

    _load_weights(net, args)
    if args.export:
        torch.save(net.state_dict(), args.export + '.pth')
        return

    # Build the loader only when evaluating, so exporting never downloads data.
    val_loader = _build_val_loader(args)
    top1, top5 = _evaluate(model, val_loader, args.cuda)
    print('Top1 Acc: %.3f | Top5 Acc: %.3f ' % (top1, top5))
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()