Source code for metatrain.utils.data.writers.ase
from pathlib import Path
from typing import Dict, List, Optional, Union
import ase
import ase.io
import torch
from metatensor.torch import TensorMap
from metatomic.torch import ModelCapabilities, System
from metatrain.utils.external_naming import to_external_name
from .writers import Writer, _split_tensormaps
[docs]
class ASEWriter(Writer):
"""Write systems and predictions to an ASE-compatible XYZ file."""
def __init__(
self,
filename: Union[str, Path],
capabilities: Optional[
ModelCapabilities
] = None, # unused, but matches base signature
append: Optional[bool] = False, # unused, but matches base signature
):
super().__init__(filename, capabilities, append)
self._first = True
self._systems: List[System] = []
self._preds: List[Dict[str, TensorMap]] = []
[docs]
def write(self, systems: List[System], predictions: Dict[str, TensorMap]):
"""
Accumulate systems and predictions to write them all at once in ``finish``.
"""
self._systems.extend([system.to("cpu").to(torch.float64) for system in systems])
self._preds.extend(_split_tensormaps(systems, predictions))
[docs]
def finish(self):
"""
Write all accumulated systems and predictions to the XYZ file.
"""
if not self._systems:
return
systems = self._systems
predictions_by_structure = self._preds
frames = []
for system, system_predictions in zip(systems, predictions_by_structure):
info = {}
arrays = {}
for target_name, target_map in system_predictions.items():
if len(target_map.keys) != 1:
raise ValueError(
"Only single-block `TensorMap`s can be "
"written to xyz files for the moment."
)
block = target_map.block()
if "atom" in block.samples.names:
# save inside arrays
values = block.values.detach().cpu().numpy()
arrays[target_name] = values.reshape(values.shape[0], -1)
# reshaping reshaping because `arrays` only accepts 2D arrays
else:
# save inside info
if block.values.numel() == 1:
info[target_name] = block.values.item()
else:
info[target_name] = (
block.values.detach().cpu().numpy().squeeze(0)
)
# squeeze the sample dimension, which corresponds to the system
for gradient_name, gradient_block in block.gradients():
# we assume that gradients are always an array, never a scalar
internal_name = f"{target_name}_{gradient_name}_gradients"
external_name = to_external_name(
internal_name, self.capabilities.outputs
)
if "forces" in external_name:
arrays[external_name] = (
# squeeze the property dimension
-gradient_block.values.detach().cpu().squeeze(-1).numpy()
)
elif "virial" in external_name:
# in this case, we write both the virial and the stress
external_name_virial = external_name
external_name_stress = external_name.replace("virial", "stress")
strain_derivatives = (
# squeeze the property dimension
gradient_block.values.detach().cpu().squeeze(-1).numpy()
)
if not torch.any(system.cell != 0):
raise ValueError(
"stresses cannot be written for non-periodic systems."
)
cell_volume = torch.det(system.cell).item()
if cell_volume == 0:
raise ValueError(
(
"stresses cannot be written for "
"systems with zero volume."
)
)
info[external_name_virial] = -strain_derivatives
info[external_name_stress] = strain_derivatives / cell_volume
else:
info[external_name] = (
# squeeze the property dimension
gradient_block.values.detach().cpu().squeeze(-1).numpy()
)
atoms = ase.Atoms(
symbols=system.types.numpy(),
positions=system.positions.detach().numpy(),
info=info,
)
# assign cell and pbcs
if torch.any(system.cell != 0):
atoms.pbc = True
atoms.cell = system.cell.detach().cpu().numpy()
# assign arrays
for array_name, array in arrays.items():
atoms.arrays[array_name] = array
frames.append(atoms)
ase.io.write(self.filename, frames)