# Source code for metatrain.utils.data.writers
from pathlib import Path
from typing import Dict, Optional, Protocol, Type, Union
from metatomic.torch import ModelCapabilities
from .ase import ASEWriter
from .diskdataset import DiskDatasetWriter
from .metatensor import MetatensorWriter
from .writers import (
Writer,
)
from .writers import (
_split_tensormaps as _split_tensormaps,
)
class WriterFactory(Protocol):
    """Structural type for callables that build a :py:class:`Writer`.

    Any callable accepting ``(filename, capabilities, append)`` and returning a
    :py:class:`Writer` matches this protocol.
    """

    def __call__(
        self,
        filename: Union[str, Path],
        capabilities: Optional[ModelCapabilities] = None,
        append: Optional[bool] = None,
    ) -> Writer:
        """Build a writer.

        :param filename: name of the file to write
        :param capabilities: capabilities of the model, or :py:obj:`None`
        :param append: append flag forwarded to the writer, or :py:obj:`None`
        :return: the writer instance
        """
def _make_factory(cls: Type[Writer]) -> WriterFactory:
    """Wrap a writer class in a plain function matching :py:class:`WriterFactory`.

    :param cls: writer class to instantiate
    :return: function that forwards ``(filename, capabilities, append)`` to ``cls``
    """

    def build_writer(
        filename: Union[str, Path],
        capabilities: Optional[ModelCapabilities] = None,
        append: Optional[bool] = None,
    ) -> Writer:
        # Arguments are forwarded positionally, in this exact order.
        writer = cls(filename, capabilities, append)
        return writer

    return build_writer
# Registry of supported output formats: each file suffix is paired with the
# writer class handling it, and every class is wrapped by ``_make_factory``.
PREDICTIONS_WRITERS: Dict[str, WriterFactory] = {
    suffix: _make_factory(writer_cls)
    for suffix, writer_cls in (
        (".xyz", ASEWriter),
        (".mts", MetatensorWriter),
        (".zip", DiskDatasetWriter),
    )
}
""":py:class:`dict`: dictionary mapping file suffixes to a prediction writer"""

# Separate ASEWriter factory, built independently of the ".xyz" registry entry.
DEFAULT_WRITER: WriterFactory = _make_factory(ASEWriter)
# [docs]
def get_writer(
    filename: Union[str, Path],
    capabilities: Optional[ModelCapabilities] = None,
    append: Optional[bool] = None,
    fileformat: Optional[str] = None,
) -> Writer:
    """Selects the appropriate writer based on the file extension.

    For certain file suffixes, the systems will also be written (e.g. ``xyz``).

    The capabilities of the model are used to infer the type (physical quantity) of
    the predictions. In this way, for example, position gradients of energies can be
    saved as forces.

    For the moment, strain gradients of the energy are saved as stresses
    (and not as virials).

    :param filename: name of the file to write
    :param capabilities: capabilities of the model
    :param append: if :py:obj:`True`, the data will be appended to the file, if it
        exists. If :py:obj:`False`, the file will be overwritten. If :py:obj:`None`,
        the default behavior of the writer is used.
    :param fileformat: format of the target value file. If :py:obj:`None` the format is
        determined from the file extension.
    :return: writer for ``filename``, with its suffix replaced by ``fileformat``
    :raises ValueError: if ``fileformat`` is not a key of ``PREDICTIONS_WRITERS``
    """
    path = Path(filename)
    if fileformat is None:
        fileformat = path.suffix

    try:
        writer_factory = PREDICTIONS_WRITERS[fileformat]
    except KeyError as err:
        raise ValueError(f"fileformat '{fileformat}' is not supported") from err

    # Replace the suffix but keep the parent directory: using only
    # ``stem + fileformat`` would silently redirect "out/preds.xyz" to
    # "preds.xyz" in the current working directory. For bare file names the
    # resulting string is the same as ``stem + fileformat``.
    target = path.parent / (path.stem + fileformat)
    return writer_factory(str(target), capabilities, append)