-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun_w2d.py
41 lines (32 loc) · 1.09 KB
/
run_w2d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from itertools import product as it_product
import wandb
from gene.utils import fail_if_not_device, validate_json, load_config
from gene.core.distances import get_df
from gene.experiment import Experiment
# Input/output widths used for every swept walker2d policy network: each
# network's layer list is [W2D_INPUT_DIM, *hidden widths, W2D_OUTPUT_DIM].
# NOTE(review): presumably the Brax walker2d observation (17) and action (6)
# sizes — confirm against the environment spec before reusing elsewhere.
W2D_INPUT_DIM = 17
W2D_OUTPUT_DIM = 6


def w2d_layer_dimensions(hidden_dims):
    """Return the full layer-width list for a walker2d policy network.

    Args:
        hidden_dims: Iterable of hidden layer widths, e.g. ``[32, 32]``.

    Returns:
        A new list ``[W2D_INPUT_DIM, *hidden_dims, W2D_OUTPUT_DIM]``. The
        argument itself is never mutated.
    """
    return [W2D_INPUT_DIM, *hidden_dims, W2D_OUTPUT_DIM]


if __name__ == "__main__":
    # Sweep every (architecture, hidden widths) combination on walker2d,
    # launching one wandb run per combination.
    fail_if_not_device()

    policy_layer_dimensions = (
        [16, 16],
        [32, 32],
        [64, 64],
        [128, 128],
        [32, 32, 32, 32],
    )
    policy_architecture = ("relu_tanh_linear", "tanh_linear")

    for arch, l_dimensions in it_product(
        policy_architecture,
        policy_layer_dimensions,
    ):
        # The config is reloaded each iteration (as before), presumably so the
        # per-run edits below never carry over into the next run.
        config = load_config("config/brax.json")

        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, and the hard-coded widths only fit walker2d.
        # NOTE: the key is spelled "environnment" (sic); it must match the
        # project's config schema, so do not "fix" the spelling here alone.
        environment = config["task"]["environnment"]
        if environment != "walker2d":
            raise ValueError(
                f"run_w2d expects the 'walker2d' task, got {environment!r}"
            )

        config["net"]["architecture"] = arch
        config["net"]["layer_dimensions"] = w2d_layer_dimensions(l_dimensions)
        validate_json(config)

        with wandb.init(
            project="devnull w2d", name="direct-w2d", config=config
        ) as wandb_run:
            df = get_df(config)()
            Experiment(config, wandb_run, df).run()