diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md index b74c70411b..7e24935720 100755 --- a/com.unity.ml-agents/CHANGELOG.md +++ b/com.unity.ml-agents/CHANGELOG.md @@ -25,6 +25,9 @@ removed when training with a player. The Editor still requires it to be clamped Changed the namespace and file names of classes in com.unity.ml-agents.extensions. (#4849) #### ml-agents / ml-agents-envs / gym-unity (Python) +- Added a `--torch-device` commandline option to `mlagents-learn`, which sets the default + [`torch.device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device) used for training. (#4888) +- The `--cpu` commandline option had no effect and was removed. Use `--torch-device=cpu` to force CPU training. (#4888) ### Bug Fixes #### com.unity.ml-agents (C#) diff --git a/docs/Training-ML-Agents.md b/docs/Training-ML-Agents.md index 7919c3deda..c6811d7fd7 100644 --- a/docs/Training-ML-Agents.md +++ b/docs/Training-ML-Agents.md @@ -188,7 +188,8 @@ using the help utility: mlagents-learn --help ``` -These additional CLI arguments are grouped into environment, engine and checkpoint. The available settings and example values are shown below. +These additional CLI arguments are grouped into environment, engine, checkpoint and torch. +The available settings and example values are shown below. #### Environment settings @@ -227,6 +228,13 @@ checkpoint_settings: inference: false ``` +#### Torch settings: + +```yaml +torch_settings: + device: cpu +``` + ### Behavior Configurations The primary section of the trainer config file is a diff --git a/ml-agents/mlagents/torch_utils/__init__.py b/ml-agents/mlagents/torch_utils/__init__.py index 9ba35a3500..0acc96997d 100644 --- a/ml-agents/mlagents/torch_utils/__init__.py +++ b/ml-agents/mlagents/torch_utils/__init__.py @@ -1,3 +1,4 @@ from mlagents.torch_utils.torch import torch as torch # noqa from mlagents.torch_utils.torch import nn # noqa +from mlagents.torch_utils.torch import set_torch_config # noqa from mlagents.torch_utils.torch import default_device # noqa diff --git a/ml-agents/mlagents/torch_utils/torch.py b/ml-agents/mlagents/torch_utils/torch.py index a3cb67ddf5..ddccd7a179 100644 --- a/ml-agents/mlagents/torch_utils/torch.py +++ b/ml-agents/mlagents/torch_utils/torch.py @@ -3,6 +3,11 @@ from distutils.version import LooseVersion import pkg_resources from mlagents.torch_utils import cpu_utils +from mlagents.trainers.settings import TorchSettings +from mlagents_envs.logging_util import get_logger + + +logger = get_logger(__name__) def assert_torch_installed(): @@ -32,14 +37,32 @@ def assert_torch_installed(): torch.set_num_threads(cpu_utils.get_num_threads_to_use()) os.environ["KMP_BLOCKTIME"] = "0" -if torch.cuda.is_available(): - torch.set_default_tensor_type(torch.cuda.FloatTensor) - device = torch.device("cuda") -else: - torch.set_default_tensor_type(torch.FloatTensor) - device = torch.device("cpu") + +_device = torch.device("cpu") + + +def set_torch_config(torch_settings: TorchSettings) -> None: + global _device + + if torch_settings.device is None: + device_str = "cuda" if torch.cuda.is_available() else "cpu" + else: + device_str = torch_settings.device + + _device = torch.device(device_str) + + if _device.type == "cuda": + torch.set_default_tensor_type(torch.cuda.FloatTensor) + else: + torch.set_default_tensor_type(torch.FloatTensor) + logger.info(f"default Torch device: {_device}") + + +# Initialize to default settings +set_torch_config(TorchSettings(device=None)) + nn = torch.nn def default_device(): - return device + return _device diff --git a/ml-agents/mlagents/trainers/cli_utils.py b/ml-agents/mlagents/trainers/cli_utils.py index 5cc9d7c292..6849731600 100644 --- a/ml-agents/mlagents/trainers/cli_utils.py +++ b/ml-agents/mlagents/trainers/cli_utils.py @@ -177,12 +177,6 @@ def _create_parser() -> argparse.ArgumentParser: "passed to the executable.", action=DetectDefault, ) - argparser.add_argument( - "--cpu", - default=False, - action=DetectDefaultStoreTrue, - help="Forces training using CPU only", - ) argparser.add_argument( "--torch", default=False, @@ -252,6 +246,15 @@ def _create_parser() -> argparse.ArgumentParser: help="Whether to run the Unity executable in no-graphics mode (i.e. without initializing " "the graphics driver. Use this only if your agents don't use visual observations.", ) + + torch_conf = argparser.add_argument_group(title="Torch Configuration") + torch_conf.add_argument( + "--torch-device", + default=None, + dest="device", + action=DetectDefault, + help='Settings for the default torch.device used in training, for example, "cpu", "cuda", or "cuda:0"', + ) return argparser diff --git a/ml-agents/mlagents/trainers/learn.py b/ml-agents/mlagents/trainers/learn.py index 82b65f59b4..bdedba2d20 100644 --- a/ml-agents/mlagents/trainers/learn.py +++ b/ml-agents/mlagents/trainers/learn.py @@ -62,6 +62,7 @@ def run_training(run_seed: int, options: RunOptions) -> None: :param run_options: Command line arguments for training. """ with hierarchical_timer("run_training.setup"): + torch_utils.set_torch_config(options.torch_settings) checkpoint_settings = options.checkpoint_settings env_settings = options.env_settings engine_settings = options.engine_settings diff --git a/ml-agents/mlagents/trainers/settings.py b/ml-agents/mlagents/trainers/settings.py index 9f47bd567b..02865c96f3 100644 --- a/ml-agents/mlagents/trainers/settings.py +++ b/ml-agents/mlagents/trainers/settings.py @@ -733,6 +733,11 @@ class EngineSettings: no_graphics: bool = parser.get_default("no_graphics") +@attr.s(auto_attribs=True) +class TorchSettings: + device: Optional[str] = parser.get_default("torch_device") + + @attr.s(auto_attribs=True) class RunOptions(ExportableSettings): default_settings: Optional[TrainerSettings] = None @@ -743,6 +748,7 @@ class RunOptions(ExportableSettings): engine_settings: EngineSettings = attr.ib(factory=EngineSettings) environment_parameters: Optional[Dict[str, EnvironmentParameterSettings]] = None checkpoint_settings: CheckpointSettings = attr.ib(factory=CheckpointSettings) + torch_settings: TorchSettings = attr.ib(factory=TorchSettings) # These are options that are relevant to the run itself, and not the engine or environment. # They will be left here. @@ -784,6 +790,7 @@ def from_argparse(args: argparse.Namespace) -> "RunOptions": "checkpoint_settings": {}, "env_settings": {}, "engine_settings": {}, + "torch_settings": {}, } if config_path is not None: configured_dict.update(load_config(config_path)) @@ -808,6 +815,8 @@ def from_argparse(args: argparse.Namespace) -> "RunOptions": configured_dict["env_settings"][key] = val elif key in attr.fields_dict(EngineSettings): configured_dict["engine_settings"][key] = val + elif key in attr.fields_dict(TorchSettings): + configured_dict["torch_settings"][key] = val else: # Base options configured_dict[key] = val diff --git a/ml-agents/mlagents/trainers/tests/test_torch_utils.py b/ml-agents/mlagents/trainers/tests/test_torch_utils.py new file mode 100644 index 0000000000..7146831319 --- /dev/null +++ b/ml-agents/mlagents/trainers/tests/test_torch_utils.py @@ -0,0 +1,41 @@ +import pytest +from unittest import mock + +import torch # noqa I201 + +from mlagents.torch_utils import set_torch_config, default_device +from mlagents.trainers.settings import TorchSettings + + +@pytest.mark.parametrize( + "device_str, expected_type, expected_index, expected_tensor_type", + [ + ("cpu", "cpu", None, torch.FloatTensor), + ("cuda", "cuda", None, torch.cuda.FloatTensor), + ("cuda:42", "cuda", 42, torch.cuda.FloatTensor), + ("opengl", "opengl", None, torch.FloatTensor), + ], +) +@mock.patch.object(torch, "set_default_tensor_type") +def test_set_torch_device( + mock_set_default_tensor_type, + device_str, + expected_type, + expected_index, + expected_tensor_type, +): + try: + torch_settings = TorchSettings(device=device_str) + set_torch_config(torch_settings) + assert default_device().type == expected_type + if expected_index is None: + assert default_device().index is None + else: + assert default_device().index == expected_index + mock_set_default_tensor_type.assert_called_once_with(expected_tensor_type) + except Exception: + raise + finally: + # restore the defaults + torch_settings = TorchSettings(device=None) + set_torch_config(torch_settings)