spoco_predict.py
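"""Prediction entry point for SPOCO.

Loads a trained embedding network from a checkpoint, runs it over the test set
of the selected dataset and saves the predicted embeddings to the output
directory.
"""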
import argparse
import os

import torch
from torch import nn

from spoco.datasets.utils import create_test_loader
from spoco.model import create_model
from spoco.predictor import EmbeddingsPredictor
from spoco.utils import SUPPORTED_DATASETS, load_checkpoint

parser = argparse.ArgumentParser(description='SPOCO predict')
# dataset config
parser.add_argument('--spoco', action='store_true', default=False, help="Indicate SPOCO prediction to the loaders")
parser.add_argument('--ds-name', type=str, default='cvppp', choices=SUPPORTED_DATASETS,
                    help=f'Name of the dataset from: {SUPPORTED_DATASETS}')
parser.add_argument('--ds-path', type=str, required=True, help='Path to the dataset root directory')
parser.add_argument('--batch-size', type=int, default=4)
parser.add_argument('--num-workers', type=int, default=4)
parser.add_argument('--output-dir', type=str, default='.', help='Directory where predictions are to be saved')
# model config
parser.add_argument('--model-name', type=str, default="UNet2D", help="UNet2D or UNet3D")
parser.add_argument('--model-path', type=str, required=True, help="Path to the model's checkpoint")
parser.add_argument('--model-in-channels', type=int, default=3)
parser.add_argument('--model-out-channels', type=int, default=16, help="Embedding space dimension")
parser.add_argument('--model-feature-maps', type=int, nargs="+", default=[16, 32, 64, 128, 256, 512],
                    help="Number of feature maps at each level of the encoder path")
parser.add_argument('--model-layer-order', type=str, default="bcr",
                    help="Determines the order of operations for SingleConv layer; 'bcr' means BatchNorm+Conv+ReLU")

def main():
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise RuntimeError('Only GPU mode is supported')

    # create the model; weights are loaded from the checkpoint below
    model = create_model(args)
    # use DataParallel so that all available GPUs are used for prediction
    model = nn.DataParallel(model)
    model.cuda()

    print(f'Using {torch.cuda.device_count()} GPUs for prediction')
    if torch.cuda.device_count() > 1:
        args.batch_size = args.batch_size * torch.cuda.device_count()

    print(f'Loading model from {args.model_path}')
    load_checkpoint(args.model_path, model)

    # create output dir if necessary
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    print(f'Saving predictions to: {output_dir}')

    # create test loader
    test_loader = create_test_loader(args)

    # create predictor
    predictor = EmbeddingsPredictor(model, test_loader, output_dir, args.spoco)

    print(f'Running inference on {len(test_loader)} batches')
    # run inference
    predictor.predict()


if __name__ == '__main__':
    main()
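
# Example invocation (a sketch; the dataset and checkpoint paths below are
# placeholders, not files shipped with the repository):
#
#   python spoco_predict.py \
#       --spoco \
#       --ds-name cvppp \
#       --ds-path /path/to/dataset \
#       --model-path /path/to/checkpoint.pytorch \
#       --output-dir ./predictions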