forked from echen4096/CT-Disease-Detection
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path train_consts.py
69 lines (54 loc) · 1.38 KB
/
train_consts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""
import os
import warnings

import torch

# Training configuration for this experiment. Every value below is computed
# once, at import time. Some entries can be overridden through environment
# variables; those are noted next to the entry.

EXPERIMENT_NAME = "test"
SEED = 43

# --- Paths -----------------------------------------------------------------
# Root of the training service: $TRAIN_SERVICE_PATH if it is set (used
# verbatim), otherwise the absolute current working directory at import time.
TRAIN_SERVICE_PATH = os.environ.get(
    'TRAIN_SERVICE_PATH', os.path.abspath(os.getcwd())
)
# Data directory: $DATA_FILES_PATH if set, else <TRAIN_SERVICE_PATH>/data_files.
DATA_FILES_PATH = os.environ.get(
    'DATA_FILES_PATH', os.path.join(TRAIN_SERVICE_PATH, 'data_files')
)
# Relative directory name; presumably resolved by the trainer against its
# working directory -- confirm.
CHECKPOINT_DIR = "checkpoints"
# Presumably a checkpoint of pretrained weights to start from -- confirm.
# Overridable via $PRETRAIN_PATH; the default is the original hard-coded path.
PRETRAIN = os.environ.get('PRETRAIN_PATH', "/storage1/data/model_best.pth")

# --- Hardware --------------------------------------------------------------
# Requested GPU index; overridable via $CUDA_DEVICE (default 1, as before).
CUDA_DEVICE = int(os.environ.get('CUDA_DEVICE', 1))
NUM_WORKERS = 8  # presumably DataLoader worker processes -- confirm
DATA_PARALLEL = False

if torch.cuda.is_available():
    _visible_gpus = torch.cuda.device_count()
    if 0 <= CUDA_DEVICE < _visible_gpus:
        DEVICE = torch.device(f'cuda:{CUDA_DEVICE}')
    else:
        # torch.device('cuda:N') does not validate N; an out-of-range index
        # only fails later ("invalid device ordinal") on first tensor move.
        # Fall back to the first visible GPU and say so.
        warnings.warn(
            f"CUDA_DEVICE={CUDA_DEVICE} is not a valid index for "
            f"{_visible_gpus} visible GPU(s); using cuda:0 instead",
            RuntimeWarning,
        )
        DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')

# --- Data ------------------------------------------------------------------
CACHE_DATASET = True  # presumably keeps loaded samples in memory -- confirm
TWO_VIEWS = True  # presumably feeds two image views per sample -- confirm
SIZE = 256  # presumably the input image side length in pixels -- confirm
# Presumably divisors used to scale the AGE / RAF regression targets -- confirm.
AGE_NORM = 100
RAF_NORM = 10
# Column names, presumably in the tabular index of images -- confirm.
IMAGE_PATH_COL = 'img_path'
ORIENT_COL = 'img_orient'

# --- Tasks -----------------------------------------------------------------
CLASS_TASK = 'class'
REG_TASK = 'reg'
HCC_GROUPS = ['HCC18', 'HCC22', 'HCC85', 'HCC96', 'HCC108', 'HCC111']
# Maps each condition name to its task type (CLASS_TASK or REG_TASK). The HCC
# groups come first, in HCC_GROUPS order. Commented-out entries are disabled
# conditions; uncommenting one adds it to the mapping.
CONDITIONS = {
    # 'GENDER': CLASS_TASK,
    **{hcc: CLASS_TASK for hcc in HCC_GROUPS},
    'AGE': REG_TASK,
    'RAF': REG_TASK,
    'bmi': REG_TASK,
    # 'sdi': REG_TASK,
    'a1c': REG_TASK,
    # 'RACE_WHITE': CLASS_TASK,
    # 'RACE_BLACK': CLASS_TASK,
    # 'RACE_ASIAN': CLASS_TASK,
    # 'LANG_ENG': CLASS_TASK
}
# Optional per-condition weights. Every entry is currently commented out, so
# this dict is empty. Presumably conditions missing from it get a default
# weight in the trainer -- confirm.
COND_WEIGHTS = {
    # 'bmi': 0.3,
    # 'a1c': 0.3,
    # 'sdi': 0.3,
    # 'RACE_WHITE': 0.25,
    # 'RACE_BLACK': 0.25,
    # 'RACE_ASIAN': 0.25,
    # 'LANG_ENG': 0.25,
}
# Conditions whose metrics are aggregated. This is a copy, not an alias, so
# mutating one list cannot silently change HCC_GROUPS (or vice versa).
CONDITIONS_FOR_METRIC_AGG = list(HCC_GROUPS)
THRESHOLD = 0.5  # presumably the decision threshold for CLASS_TASK outputs

# --- Training schedule -----------------------------------------------------
EPOCHS = 100
TRAIN_BATCH_SIZE = 64
TEST_BATCH_SIZE = 64
# Date string in the form "2021.01.01"; presumably the train/test cutoff and
# parsed by the data-split code -- confirm the expected format there.
SPLIT_DATE = "2021.01.01"
"""