Source code for metatrain.utils.transfer

from typing import Dict, List, Optional

import torch
from metatensor.torch import TensorMap
from metatomic.torch import System

from . import torch_jit_script_unless_coverage


[docs] @torch_jit_script_unless_coverage def batch_to( systems: List[System], targets: Dict[str, TensorMap], extra_data: Optional[Dict[str, TensorMap]] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, ): """ Changes the systems and targets to the specified floating point data type. :param systems: List of systems. :param targets: Dictionary of targets. :param dtype: Desired floating point data type. """ systems = [system.to(dtype=dtype, device=device) for system in systems] targets = { key: value.to(dtype=dtype, device=device) for key, value in targets.items() } if extra_data is not None: new_dtypes: List[Optional[int]] = [] for key in extra_data.keys(): if key.endswith("_mask"): # masks should always be boolean new_dtypes.append(torch.bool) else: new_dtypes.append(dtype) extra_data = { key: value.to(dtype=_dtype, device=device) for (key, value), _dtype in zip(extra_data.items(), new_dtypes) } return systems, targets, extra_data