-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
109 lines (95 loc) · 3.39 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# ------------------------------------------------------------------------------
# @file: run.py
# @brief: This script is used for running trajectory prediction trainers.
# Example usage:
# python run.py \
# --exp path/to/config.json \
# --run [ trainval | train | eval | test ]
# ------------------------------------------------------------------------------
import argparse
import json
import os
def _load_config(exp: str) -> dict:
    """Load an experiment file and merge it over the base config it names.

    Args:
        exp: path to the experiment JSON file. It must contain a
            ``"base_config"`` key holding the path of the base JSON config.

    Returns:
        The base configuration updated with every key of the experiment
        file; experiment values win on conflicts.

    Raises:
        FileNotFoundError: if ``exp`` is not an existing file.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not os.path.isfile(exp):
        raise FileNotFoundError(f"File {exp} does not exist!")
    with open(exp, encoding="utf-8") as exp_file:
        exp_config = json.load(exp_file)
    # NOTE: a relative "base_config" path is resolved against the current
    # working directory, not against the experiment file's directory.
    with open(exp_config["base_config"], encoding="utf-8") as config_file:
        config = json.load(config_file)
    config.update(exp_config)
    return config


def _apply_cli_overrides(
        config: dict, run: str, ckpt_num: int, ckpt_path: str,
        visualize: bool) -> None:
    """Write the command-line options into ``config`` in place.

    Sets ``log_file`` to ``"<run>.log"``, enables checkpoint loading when
    ``ckpt_num`` or ``ckpt_path`` is given, and stores the visualization flag
    under ``config["visualization"]["visualize"]``.
    """
    config["log_file"] = f"{run}.log"
    # ``is not None`` rather than truthiness so that checkpoint 0 is honoured.
    if ckpt_num is not None:
        config["load_ckpt"] = True
        config["ckpt_name"] = f"ckpt_{ckpt_num}.pth"
    if ckpt_path:
        config["load_ckpts_from_path"] = True
        config["ckpt_path"] = ckpt_path
    # Create the section if the config lacks it instead of raising KeyError.
    config.setdefault("visualization", {})["visualize"] = visualize


def _get_trainer_class(trainer_type: str) -> type:
    """Return the trainer class registered for ``trainer_type``.

    Each import sits inside its own branch, so only the selected trainer's
    module is loaded.

    Raises:
        NotImplementedError: if ``trainer_type`` is not a known trainer.
    """
    if trainer_type == "vrnn":
        from sprnn.trajpred_trainers.vrnn import VRNNTrainer
        return VRNNTrainer
    if trainer_type == "patternn":
        from sprnn.trajpred_trainers.patternn import PatteRNNTrainer
        return PatteRNNTrainer
    if trainer_type in ("socpatternn-mlp", "socpatternn-mha"):
        from sprnn.trajpred_trainers.socpatternn import SocialPatteRNNTrainer
        return SocialPatteRNNTrainer
    raise NotImplementedError(f"Trainer {trainer_type} not supported!")


def run(
        exp: str, run: str, ckpt_num: int, ckpt_path: str, resume: bool,
        best: bool, visualize: bool) -> None:
    """Build a trajectory-prediction trainer from a config and run it.

    Args:
        exp: path to the experiment JSON file (it names a ``base_config``).
        run: one of ``"trainval"``, ``"train"``, ``"eval"``, ``"test"``.
        ckpt_num: checkpoint number to load as ``ckpt_<num>.pth``, or None.
        ckpt_path: path to load checkpoints from, or None/empty.
        resume: forwarded to ``trainer.train`` for the training modes.
        best: forwarded as ``do_best`` to ``trainer.eval`` for eval/test.
        visualize: stored in ``config["visualization"]["visualize"]``.

    Raises:
        ValueError: if ``run`` is not one of the supported modes.
        FileNotFoundError: if ``exp`` does not exist.
        NotImplementedError: if the config's ``trainer`` is unsupported.
    """
    # Validate first so no trainer is built for a mode that would do nothing.
    if run not in ("trainval", "train", "eval", "test"):
        raise ValueError(
            f"Unknown run type {run!r}; expected one of "
            "trainval, train, eval, test")

    config = _load_config(exp)
    _apply_cli_overrides(config, run, ckpt_num, ckpt_path, visualize)
    Trainer = _get_trainer_class(config["trainer"])
    trainer = Trainer(config=config)

    if run == "trainval":
        trainer.train(do_eval=True, resume=resume)
    elif run == "train":
        trainer.train(resume=resume)
    elif run == "eval":
        trainer.eval(do_best=best)
    else:  # "test": presumably evaluates on the held-out split — confirm
        trainer.eval(do_eval=False, do_best=best)

    # Save a copy of the run's config with the checkpoint-loading flags
    # cleared (presumably so the copy can be reused without forcing a load).
    config["load_ckpt"] = False
    config["load_ckpts_from_path"] = False
    trainer.save_config(config, filename=f"config_{run}.json", log_dump=False)
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    Returns:
        A parser whose destinations (``exp``, ``run``, ``ckpt_num``,
        ``ckpt_path``, ``resume``, ``best``, ``visualize``) are passed
        straight through as keyword arguments to ``run``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--exp',
        default='./config/test.json',
        type=str,
        help='path to experiment configuration file')
    parser.add_argument(
        '--run',
        default='trainval',
        type=str,
        choices=['trainval', 'train', 'eval', 'test'],
        help='type of experiment [trainval | train | eval | test]')
    parser.add_argument(
        '--ckpt-num',
        required=False,
        type=int,
        help='checkpoint number to evaluate')
    parser.add_argument(
        '--ckpt-path',
        required=False,
        type=str,
        help='path to checkpoint to run')
    parser.add_argument(
        '--resume',
        required=False,
        action='store_true',
        help='resume training process')
    # Previously a copy-paste of the --visualize help text.
    parser.add_argument(
        '--best',
        required=False,
        action='store_true',
        help='evaluate using the best checkpoint')
    parser.add_argument(
        '--visualize',
        required=False,
        action='store_true',
        help='enable visualizations')
    return parser


def main(argv: "list[str] | None" = None) -> None:
    """Parse command-line arguments and forward them to ``run``.

    Args:
        argv: arguments to parse; ``None`` (the default) reads ``sys.argv``.
    """
    args = _build_parser().parse_args(argv)
    run(**vars(args))
if __name__ == '__main__':
    # Only invoke the CLI when executed as a script, not when imported.
    main()