benchmark.py
import argparse
import csv
import multiprocessing as mp
from dataclasses import dataclass, fields
from itertools import repeat

import pandas as pd

from tartarus import pce, tadf, docking, reactivity


@dataclass
class ResultBase:
    smile: str


@dataclass
class PCEResult(ResultBase):
    # Power conversion efficiency objectives (PCBM / PCDTBT acceptors, SAS-penalised)
    pce_pcbm_sas: float
    pce_pcdtbt_sas: float


@dataclass
class TADFResult(ResultBase):
    # TADF objectives: singlet-triplet gap, oscillator strength, combined score
    st: float
    osc: float
    combined: float


@dataclass
class DockingResult(ResultBase):
    # Docking scores against the 1syh, 6y2f and 4lde protein targets
    score_1syh: float
    score_6y2f: float
    score_4lde: float


@dataclass
class ReactivityResult(ResultBase):
    # Reactivity objectives: activation energy, reaction energy, their sum and difference
    Ea: float
    Er: float
    sum_Ea_Er: float
    diff_Ea_Er: float


class BenchmarkResults:
    """Benchmark results for a single mode, with CSV export."""

    def __init__(self, mode, results):
        self.mode = mode
        self.results = results

    def save(self, output_filename):
        """Write all results to /data/<output_filename> as CSV, with a header row."""
        with open(f'/data/{output_filename}', 'w', newline='') as output_file:
            wr = csv.writer(output_file)
            for idx, result in enumerate(self.results):
                if idx == 0:
                    wr.writerow([field.name for field in fields(result)])
                wr.writerow([getattr(result, field.name) for field in fields(result)])
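
# For reference, a sketch of the CSV layout that save() produces, assuming the
# 'tadf' mode (the header comes straight from the TADFResult fields, with the
# inherited 'smile' column first); the SMILES shown is only an illustrative example:
#
#   smile,st,osc,combined
#   CC1=CC=CC=C1,<st>,<osc>,<combined>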


def benchmark_smile(smile, mode, verbose):
    """Benchmark a single SMILES string.

    Args:
        smile (str): SMILES string
        mode (str): Benchmark mode (pce, tadf, docking, reactivity)
        verbose (bool): Print progress from the underlying calculators

    Returns:
        ResultBase: Result dataclass for the requested mode
    """
    if mode == 'pce':
        result = PCEResult(smile, *pce.get_properties(smile, verbose=verbose))
    elif mode == 'tadf':
        result = TADFResult(smile, *tadf.get_properties(smile, verbose=verbose))
    elif mode == 'docking':
        result = DockingResult(
            smile,
            docking.perform_calc_single(smile, '1syh', docking_program='qvina'),
            docking.perform_calc_single(smile, '6y2f', docking_program='qvina'),
            docking.perform_calc_single(smile, '4lde', docking_program='qvina'),
        )
    elif mode == 'reactivity':
        result = ReactivityResult(smile, *reactivity.get_properties(smile, verbose=verbose))
    else:
        raise ValueError('Invalid mode')
    return result


def benchmark_smiles(smiles, mode, parallel=True, verbose=False):
    """Benchmark a list of SMILES strings.

    Args:
        smiles (list): List of SMILES strings
        mode (str): Benchmark mode (pce, tadf, docking, reactivity)
        parallel (bool): Evaluate molecules in parallel across all available CPUs
        verbose (bool): Print progress from the underlying calculators

    Returns:
        BenchmarkResults: Collected results for all molecules
    """
    if parallel:
        n_procs = mp.cpu_count()
        with mp.Pool(n_procs) as p:
            results = p.starmap(benchmark_smile, zip(smiles, repeat(mode), repeat(verbose)))
    else:
        results = [benchmark_smile(smile, mode, verbose) for smile in smiles]
    return BenchmarkResults(mode, results)
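
# A minimal sketch of programmatic use (hypothetical SMILES and output name; the
# tartarus dependencies for the chosen mode must be installed, and save() writes
# under the hard-coded /data directory):
#
#   smiles = ['CC1=CC=CC=C1', 'c1ccc2ccccc2c1']
#   results = benchmark_smiles(smiles, mode='tadf', parallel=False, verbose=True)
#   results.save('tadf_results.csv')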


def get_args():
    """Get command line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_filename', type=str, required=True, help='Input filename')
    parser.add_argument('--output_filename', type=str, default='output.csv', help='Output filename')
    parser.add_argument('--mode', type=str, required=True, help='Benchmark mode (pce, tadf, docking, reactivity)')
    parser.add_argument('--verbose', action='store_true', help='Verbose mode (default: False)')
    parser.add_argument('--parallel', action='store_true', help='Parallel evaluation (default: False)')
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = get_args()
    df = pd.read_csv(f'/data/{args.input_filename}')
    smiles = df['smiles'].tolist()
    results = benchmark_smiles(smiles, args.mode, verbose=args.verbose, parallel=args.parallel)
    results.save(args.output_filename)
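
# Example invocation (hypothetical file names; the hard-coded /data paths suggest the
# script is meant to run inside the project's container with a mounted data directory):
#
#   python benchmark.py --input_filename smiles.csv --mode docking --parallel
#
# The input CSV must contain a 'smiles' column; results are written to
# /data/output.csv unless --output_filename is given.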