From 73093976b0772f75128d38c89174869b7ce4271b Mon Sep 17 00:00:00 2001 From: Mikael Dallaire Cote <110583667+0mdc@users.noreply.github.com> Date: Mon, 16 Dec 2024 18:55:35 -0500 Subject: [PATCH 1/3] Move BaselinesController to its own file. --- .../controllers/baselines_controller.py | 217 ++++++++++++++++-- .../environment/controllers/controller_abc.py | 191 --------------- 2 files changed, 197 insertions(+), 211 deletions(-) diff --git a/habitat-hitl/habitat_hitl/environment/controllers/baselines_controller.py b/habitat-hitl/habitat_hitl/environment/controllers/baselines_controller.py index 5f740fb5c3..1b911af9e5 100644 --- a/habitat-hitl/habitat_hitl/environment/controllers/baselines_controller.py +++ b/habitat-hitl/habitat_hitl/environment/controllers/baselines_controller.py @@ -4,30 +4,41 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import TYPE_CHECKING - -import gym.spaces as spaces - -import habitat.gym.gym_wrapper as gym_wrapper -from habitat_baselines.common.env_spec import EnvironmentSpec -from habitat_baselines.common.obs_transformers import get_active_obs_transforms -from habitat_baselines.rl.multi_agent.multi_agent_access_mgr import ( - MultiAgentAccessMgr, -) -from habitat_baselines.rl.multi_agent.utils import ( - update_dict_with_agent_prefix, -) -from habitat_baselines.rl.ppo.single_agent_access_mgr import ( - SingleAgentAccessMgr, -) -from habitat_hitl.environment.controllers.controller_abc import ( - BaselinesController, -) +from abc import abstractmethod +from typing import TYPE_CHECKING, Dict, List, Tuple if TYPE_CHECKING: + import gym.spaces as spaces + import numpy as np + import torch from omegaconf import DictConfig + import habitat + import habitat.gym.gym_wrapper as gym_wrapper from habitat.core.environments import GymHabitatEnv + from habitat_baselines.common.env_spec import EnvironmentSpec + from habitat_baselines.common.obs_transformers import ( + ObservationTransformer, + apply_obs_transforms_batch, + apply_obs_transforms_obs_space, + get_active_obs_transforms, + ) + from habitat_baselines.rl.multi_agent.multi_agent_access_mgr import ( + MultiAgentAccessMgr, + ) + from habitat_baselines.rl.multi_agent.utils import ( + update_dict_with_agent_prefix, + ) + from habitat_baselines.rl.ppo.agent_access_mgr import AgentAccessMgr + from habitat_baselines.rl.ppo.single_agent_access_mgr import ( + SingleAgentAccessMgr, + ) + from habitat_baselines.utils.common import ( + batch_obs, + get_action_space_info, + is_continuous_action_space, + ) + from habitat_hitl.environment.controllers.controller_abc import Controller def clean_dict(d, remove_prefix): @@ -44,6 +55,172 @@ def clean_dict(d, remove_prefix): return spaces.Dict(ret_d) +class BaselinesController(Controller): + """Abstract controller for baselines agents.""" + + def __init__( + self, + is_multi_agent: bool, + config: "DictConfig", + gym_habitat_env: "GymHabitatEnv", + ): + super().__init__(is_multi_agent) + self._config: DictConfig = config + + self._gym_habitat_env: GymHabitatEnv = gym_habitat_env + self._habitat_env: habitat.Env = gym_habitat_env.unwrapped.habitat_env + self._num_envs: int = 1 + + self.device: torch.device = ( + torch.device("cuda", config.habitat_baselines.torch_gpu_id) + if torch.cuda.is_available() + else torch.device("cpu") + ) + + # create env spec + self._env_spec: EnvironmentSpec = self._create_env_spec() + + # create observations transforms + self._obs_transforms: List[ + "ObservationTransformer" + ] = self._get_active_obs_transforms() + + # apply observations transforms + self._env_spec.observation_space = apply_obs_transforms_obs_space( + self._env_spec.observation_space, self._obs_transforms + ) + + # create agent + self._agent: "AgentAccessMgr" = self._create_agent() + if ( + self._agent.actor_critic.should_load_agent_state + and self._config.habitat_baselines.eval.should_load_ckpt + ): + self._load_agent_checkpoint() + + self._agent.eval() + + self._action_shape: Tuple[int] + self._discrete_actions: bool + self._action_shape, self._discrete_actions = get_action_space_info( + self._agent.actor_critic.policy_action_space + ) + + hidden_state_lens = self._agent.actor_critic.hidden_state_shape_lens + action_space_lens = ( + self._agent.actor_critic.policy_action_space_shape_lens + ) + + self._space_lengths: Dict = {} + n_agents = len(self._config.habitat.simulator.agents) + if n_agents > 1: + self._space_lengths = { + "index_len_recurrent_hidden_states": hidden_state_lens, + "index_len_prev_actions": action_space_lens, + } + + # these attributes are used for inference + # and will be set in on_environment_reset + self._test_recurrent_hidden_states = None + self._prev_actions = None + self._not_done_masks = None + + @abstractmethod + def _create_env_spec(self): + pass + + @abstractmethod + def _get_active_obs_transforms(self): + pass + + @abstractmethod + def _create_agent(self): + pass + + @abstractmethod + def _load_agent_state_dict(self, checkpoint): + pass + + def _load_agent_checkpoint(self): + checkpoint = torch.load( + self._config.habitat_baselines.eval_ckpt_path_dir, + map_location="cpu", + ) + self._load_agent_state_dict(checkpoint) + + def _batch_and_apply_transforms(self, obs): + batch = batch_obs(obs, device=self.device) + batch = apply_obs_transforms_batch(batch, self._obs_transforms) + + return batch + + def on_environment_reset(self): + self._test_recurrent_hidden_states = torch.zeros( + ( + self._num_envs, + *self._agent.actor_critic.hidden_state_shape, + ), + device=self.device, + ) + + self._prev_actions = torch.zeros( + self._num_envs, + *self._action_shape, + device=self.device, + dtype=torch.long if self._discrete_actions else torch.float, + ) + + self._not_done_masks = torch.zeros( + ( + self._num_envs, + *self._agent.masks_shape, + ), + device=self.device, + dtype=torch.bool, + ) + + def act(self, obs, env): + batch = self._batch_and_apply_transforms([obs]) + + with torch.no_grad(): + action_data = self._agent.actor_critic.act( + batch, + self._test_recurrent_hidden_states, + self._prev_actions, + self._not_done_masks, + deterministic=False, + **self._space_lengths, + ) + if action_data.should_inserts is None: + self._test_recurrent_hidden_states = ( + action_data.rnn_hidden_states + ) + self._prev_actions.copy_(action_data.actions) # type: ignore + else: + self._agent.actor_critic.update_hidden_state( + self._test_recurrent_hidden_states, + self._prev_actions, + action_data, + ) + + assert len(action_data.env_actions) == 1 + if is_continuous_action_space(self._env_spec.action_space): + # Clipping actions to the specified limits + action = np.clip( + action_data.env_actions[0].cpu().numpy(), + self._env_spec.action_space.low, + self._env_spec.action_space.high, + ) + else: + action = action_data.env_actions[0].cpu().item() + + # _not_done_masks serves as en indicator of whether the episode is done + # it is reset to False in on_environment_reset + self._not_done_masks.fill_(True) # type: ignore [attr-defined] + + return action + + class SingleAgentBaselinesController(BaselinesController): """Controller for single baseline agent.""" @@ -72,7 +249,7 @@ def __init__( ) def _create_env_spec(self): - # udjust the observation and action space to be agent specific (remove other agents) + # djust the observation and action space to be agent specific (remove other agents) original_action_space = clean_dict( self._gym_habitat_env.original_action_space, self._agent_k ) diff --git a/habitat-hitl/habitat_hitl/environment/controllers/controller_abc.py b/habitat-hitl/habitat_hitl/environment/controllers/controller_abc.py index 7e1d5acf91..1d46b5438e 100644 --- a/habitat-hitl/habitat_hitl/environment/controllers/controller_abc.py +++ b/habitat-hitl/habitat_hitl/environment/controllers/controller_abc.py @@ -5,31 +5,6 @@ # LICENSE file in the root directory of this source tree. from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Dict, List, Tuple - -import numpy as np -import torch - -import habitat -from habitat_baselines.common.obs_transformers import ( - apply_obs_transforms_batch, - apply_obs_transforms_obs_space, -) -from habitat_baselines.utils.common import ( - batch_obs, - get_action_space_info, - is_continuous_action_space, -) - -if TYPE_CHECKING: - from omegaconf import DictConfig - - from habitat.core.environments import GymHabitatEnv - from habitat_baselines.common.env_spec import EnvironmentSpec - from habitat_baselines.common.obs_transformers import ( - ObservationTransformer, - ) - from habitat_baselines.rl.ppo.agent_access_mgr import AgentAccessMgr class Controller(ABC): @@ -53,169 +28,3 @@ def __init__(self, agent_idx, is_multi_agent, gui_input): super().__init__(is_multi_agent) self._agent_idx = agent_idx self._gui_input = gui_input - - -class BaselinesController(Controller): - """Abstract controller for baselines agents.""" - - def __init__( - self, - is_multi_agent: bool, - config: "DictConfig", - gym_habitat_env: "GymHabitatEnv", - ): - super().__init__(is_multi_agent) - self._config: DictConfig = config - - self._gym_habitat_env: GymHabitatEnv = gym_habitat_env - self._habitat_env: habitat.Env = gym_habitat_env.unwrapped.habitat_env - self._num_envs: int = 1 - - self.device: torch.device = ( - torch.device("cuda", config.habitat_baselines.torch_gpu_id) - if torch.cuda.is_available() - else torch.device("cpu") - ) - - # create env spec - self._env_spec: EnvironmentSpec = self._create_env_spec() - - # create observations transforms - self._obs_transforms: List[ - "ObservationTransformer" - ] = self._get_active_obs_transforms() - - # apply observations transforms - self._env_spec.observation_space = apply_obs_transforms_obs_space( - self._env_spec.observation_space, self._obs_transforms - ) - - # create agent - self._agent: "AgentAccessMgr" = self._create_agent() - if ( - self._agent.actor_critic.should_load_agent_state - and self._config.habitat_baselines.eval.should_load_ckpt - ): - self._load_agent_checkpoint() - - self._agent.eval() - - self._action_shape: Tuple[int] - self._discrete_actions: bool - self._action_shape, self._discrete_actions = get_action_space_info( - self._agent.actor_critic.policy_action_space - ) - - hidden_state_lens = self._agent.actor_critic.hidden_state_shape_lens - action_space_lens = ( - self._agent.actor_critic.policy_action_space_shape_lens - ) - - self._space_lengths: Dict = {} - n_agents = len(self._config.habitat.simulator.agents) - if n_agents > 1: - self._space_lengths = { - "index_len_recurrent_hidden_states": hidden_state_lens, - "index_len_prev_actions": action_space_lens, - } - - # these attributes are used for inference - # and will be set in on_environment_reset - self._test_recurrent_hidden_states = None - self._prev_actions = None - self._not_done_masks = None - - @abstractmethod - def _create_env_spec(self): - pass - - @abstractmethod - def _get_active_obs_transforms(self): - pass - - @abstractmethod - def _create_agent(self): - pass - - @abstractmethod - def _load_agent_state_dict(self, checkpoint): - pass - - def _load_agent_checkpoint(self): - checkpoint = torch.load( - self._config.habitat_baselines.eval_ckpt_path_dir, - map_location="cpu", - ) - self._load_agent_state_dict(checkpoint) - - def _batch_and_apply_transforms(self, obs): - batch = batch_obs(obs, device=self.device) - batch = apply_obs_transforms_batch(batch, self._obs_transforms) - - return batch - - def on_environment_reset(self): - self._test_recurrent_hidden_states = torch.zeros( - ( - self._num_envs, - *self._agent.actor_critic.hidden_state_shape, - ), - device=self.device, - ) - - self._prev_actions = torch.zeros( - self._num_envs, - *self._action_shape, - device=self.device, - dtype=torch.long if self._discrete_actions else torch.float, - ) - - self._not_done_masks = torch.zeros( - ( - self._num_envs, - *self._agent.masks_shape, - ), - device=self.device, - dtype=torch.bool, - ) - - def act(self, obs, env): - batch = self._batch_and_apply_transforms([obs]) - - with torch.no_grad(): - action_data = self._agent.actor_critic.act( - batch, - self._test_recurrent_hidden_states, - self._prev_actions, - self._not_done_masks, - deterministic=False, - **self._space_lengths, - ) - if action_data.should_inserts is None: - self._test_recurrent_hidden_states = ( - action_data.rnn_hidden_states - ) - self._prev_actions.copy_(action_data.actions) # type: ignore - else: - self._agent.actor_critic.update_hidden_state( - self._test_recurrent_hidden_states, - self._prev_actions, - action_data, - ) - - assert len(action_data.env_actions) == 1 - if is_continuous_action_space(self._env_spec.action_space): - # Clipping actions to the specified limits - action = np.clip( - action_data.env_actions[0].cpu().numpy(), - self._env_spec.action_space.low, - self._env_spec.action_space.high, - ) - else: - action = action_data.env_actions[0].cpu().item() - - # _not_done_masks serves as en indicator of whether the episode is done - # it is reset to False in on_environment_reset - self._not_done_masks.fill_(True) # type: ignore [attr-defined] - - return action From 9759720bf41a9179f852ed077a225fa2a4ed3a9f Mon Sep 17 00:00:00 2001 From: Mikael Dallaire Cote <110583667+0mdc@users.noreply.github.com> Date: Mon, 20 Jan 2025 17:29:58 -0500 Subject: [PATCH 2/3] Fix typo. --- .../environment/controllers/baselines_controller.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/habitat-hitl/habitat_hitl/environment/controllers/baselines_controller.py b/habitat-hitl/habitat_hitl/environment/controllers/baselines_controller.py index 1b911af9e5..9dd6851373 100644 --- a/habitat-hitl/habitat_hitl/environment/controllers/baselines_controller.py +++ b/habitat-hitl/habitat_hitl/environment/controllers/baselines_controller.py @@ -249,7 +249,7 @@ def __init__( ) def _create_env_spec(self): - # djust the observation and action space to be agent specific (remove other agents) + # Adjust the observation and action space to be agent specific (remove other agents) original_action_space = clean_dict( self._gym_habitat_env.original_action_space, self._agent_k ) From ddbd084e687018fe9daaea52ddeaf42cd29ce762 Mon Sep 17 00:00:00 2001 From: Mikael Dallaire Cote <110583667+0mdc@users.noreply.github.com> Date: Mon, 20 Jan 2025 17:45:26 -0500 Subject: [PATCH 3/3] Fix import statements. --- .../controllers/baselines_controller.py | 58 ++++++++++--------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/habitat-hitl/habitat_hitl/environment/controllers/baselines_controller.py b/habitat-hitl/habitat_hitl/environment/controllers/baselines_controller.py index 9dd6851373..7c6d8651c0 100644 --- a/habitat-hitl/habitat_hitl/environment/controllers/baselines_controller.py +++ b/habitat-hitl/habitat_hitl/environment/controllers/baselines_controller.py @@ -7,38 +7,40 @@ from abc import abstractmethod from typing import TYPE_CHECKING, Dict, List, Tuple +import gym.spaces as spaces +import numpy as np +import torch + +import habitat +import habitat.gym.gym_wrapper as gym_wrapper +from habitat_baselines.common.env_spec import EnvironmentSpec +from habitat_baselines.common.obs_transformers import ( + ObservationTransformer, + apply_obs_transforms_batch, + apply_obs_transforms_obs_space, + get_active_obs_transforms, +) +from habitat_baselines.rl.multi_agent.multi_agent_access_mgr import ( + MultiAgentAccessMgr, +) +from habitat_baselines.rl.multi_agent.utils import ( + update_dict_with_agent_prefix, +) +from habitat_baselines.rl.ppo.agent_access_mgr import AgentAccessMgr +from habitat_baselines.rl.ppo.single_agent_access_mgr import ( + SingleAgentAccessMgr, +) +from habitat_baselines.utils.common import ( + batch_obs, + get_action_space_info, + is_continuous_action_space, +) +from habitat_hitl.environment.controllers.controller_abc import Controller + if TYPE_CHECKING: - import gym.spaces as spaces - import numpy as np - import torch from omegaconf import DictConfig - import habitat - import habitat.gym.gym_wrapper as gym_wrapper from habitat.core.environments import GymHabitatEnv - from habitat_baselines.common.env_spec import EnvironmentSpec - from habitat_baselines.common.obs_transformers import ( - ObservationTransformer, - apply_obs_transforms_batch, - apply_obs_transforms_obs_space, - get_active_obs_transforms, - ) - from habitat_baselines.rl.multi_agent.multi_agent_access_mgr import ( - MultiAgentAccessMgr, - ) - from habitat_baselines.rl.multi_agent.utils import ( - update_dict_with_agent_prefix, - ) - from habitat_baselines.rl.ppo.agent_access_mgr import AgentAccessMgr - from habitat_baselines.rl.ppo.single_agent_access_mgr import ( - SingleAgentAccessMgr, - ) - from habitat_baselines.utils.common import ( - batch_obs, - get_action_space_info, - is_continuous_action_space, - ) - from habitat_hitl.environment.controllers.controller_abc import Controller def clean_dict(d, remove_prefix):