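"""Architecture helpers of ``metatrain.utils.architectures``.

Utilities to locate, validate and import metatrain architectures and to read their
default hyperparameters.
"""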
import difflib
import importlib
import json
import logging
from importlib.util import find_spec
from pathlib import Path
from typing import Dict, List, Union
from omegaconf import OmegaConf
from .. import PACKAGE_ROOT
from .jsonschema import validate
[docs]
def check_architecture_name(name: str) -> None:
    """Verify that ``name`` refers to an available architecture.

    Stable architectures are resolved as ``metatrain.<name>``. When no stable
    architecture exists but an experimental or deprecated one with the same name
    does, the raised error tells the user which fully qualified name to set in the
    options file. When nothing matches at all, the closest known architecture name
    (if any) is suggested to help spot typos.

    :param name: name of the architecture
    :raises ValueError: if the architecture is not found
    """
    # "llpr" is always accepted without any module lookup.
    if name == "llpr":
        return

    try:
        if find_spec(f"metatrain.{name}") is not None:
            return

        # Look for a same-named architecture in the non-stable namespaces, in
        # this order, and point the user to it.
        for flavor, article in (("experimental", "An"), ("deprecated", "A")):
            if find_spec(f"metatrain.{flavor}.{name}") is not None:
                raise ValueError(
                    f"Architecture {name!r} is not a stable architecture. "
                    f"{article} {flavor} architecture with the same name was "
                    f"found. Set `name: {flavor}.{name}` in your options file "
                    f"to use this {flavor} architecture."
                )

        # Not found anywhere: share the "unknown architecture" handling below.
        raise ModuleNotFoundError
    except ModuleNotFoundError:
        message = f"Architecture {name!r} is not a valid architecture."
        suggestions = difflib.get_close_matches(
            word=name, possibilities=find_all_architectures()
        )
        if suggestions:
            message += f" Do you mean '{suggestions[0]}'?"

    raise ValueError(message)
[docs]
def check_architecture_options(
    name: str,
    options: Dict,
) -> None:
    """Verify that an options instance only contains valid keys.

    The options are validated against the ``schema-hypers.json`` file stored in the
    architecture directory. If the architecture developer does not provide such a
    schema, the ``options`` will not be checked.

    :param name: name of the architecture
    :param options: architecture options to check
    :raises ValueError: if ``name`` is not a valid architecture (raised while
        resolving the architecture directory)

    Errors raised by :func:`validate` for invalid options propagate to the caller.
    """
    schema_path = get_architecture_path(name) / "schema-hypers.json"
    if not schema_path.exists():
        # The f-prefix is required so the architecture name is interpolated into
        # the log message instead of the literal text "{name!r}".
        logging.debug(
            f"No schema found for {name!r} architecture. Skipping validation."
        )
        return

    # JSON is UTF-8 by specification; don't depend on the platform's locale.
    with open(schema_path, "r", encoding="utf-8") as f:
        schema = json.load(f)
    validate(instance=options, schema=schema)
[docs]
def get_architecture_name(path: Union[str, Path]) -> str:
    """Name of an architecture based on a path pointing inside an architecture.

    The function should be used to determine the ``ARCHITECTURE_NAME`` based on the
    name of the folder. ``path`` may point either to the architecture directory
    itself or to a file inside it, in which case the file's parent directory is
    used.

    :param path: path to an architecture directory or to a file inside it
    :returns: architecture name (dotted, e.g. ``experimental.<arch>``)
    :raises ValueError: if ``path`` does not exist, is not located below the
        metatrain package root, or does not point to a valid architecture
        directory.

    .. seealso::
        :py:func:`get_architecture_path` to get the relative path within the
        metatrain project of an architecture name.
    """
    path = Path(path)
    if path.is_dir():
        directory = path
    elif path.is_file():
        directory = path.parent
    else:
        raise ValueError(f"`path` {str(path)!r} does not exist")

    invalid_msg = (
        f"`path` {str(path)!r} does not point to a valid architecture folder"
    )

    try:
        architecture_path = directory.relative_to(PACKAGE_ROOT)
    except ValueError as err:
        # Paths outside the package cannot be architectures; report that with the
        # same message as any other invalid folder instead of pathlib's message.
        raise ValueError(invalid_msg) from err

    # Join the path components explicitly: replacing "/" in the string form would
    # leave the OS-specific separator (e.g. "\" on Windows) inside the name.
    name = ".".join(architecture_path.parts)

    try:
        check_architecture_name(name)
    except ValueError as err:
        raise ValueError(invalid_msg) from err

    return name
[docs]
def import_architecture(name: str):
    """Import the Python module that implements an architecture.

    :param name: name of the architecture
    :returns: the imported architecture module
    :raises ValueError: if ``name`` is not a valid architecture
    :raises ImportError: if the architecture dependencies are not met
    """
    check_architecture_name(name)

    # "llpr" is imported from ``metatrain.utils`` rather than ``metatrain.<name>``.
    module_name = "metatrain.utils.llpr" if name == "llpr" else f"metatrain.{name}"

    try:
        return importlib.import_module(module_name)
    except ImportError as err:
        # Derive an extra's name consistent with pyproject.toml's
        # `optional-dependencies` section: drop the leading namespace component
        # for experimental/deprecated names and use dashes instead of underscores.
        markers = ("experimental.", "deprecated.")
        if any(marker in name for marker in markers):
            extra = ".".join(name.split(".")[1:])
        else:
            extra = name
        extra = extra.replace("_", "-")

        raise ImportError(
            f"Trying to import '{name}' but architecture dependencies "
            f"seem not be installed. \n"
            f"Try to install them with `pip install metatrain[{extra}]`"
        ) from err
[docs]
def get_architecture_path(name: str) -> Path:
    """Return the directory of an architecture inside the metatrain package.

    The dotted architecture ``name`` is mapped onto nested directories below
    ``PACKAGE_ROOT``, e.g. ``experimental.<arch>`` becomes
    ``PACKAGE_ROOT/experimental/<arch>``.

    :param name: name of the architecture
    :returns: path to the architecture directory
    :raises ValueError: if ``name`` is not a valid architecture

    .. seealso::
        :py:func:`get_architecture_name` to get the name based on an absolute path
        of an architecture.
    """
    check_architecture_name(name)

    # NOTE(review): ``import_architecture`` loads "llpr" from metatrain.utils.llpr,
    # but this maps it to PACKAGE_ROOT / "llpr" — confirm this is intended.
    subdirectory = Path(name.replace(".", "/"))
    return PACKAGE_ROOT / subdirectory
[docs]
def find_all_architectures() -> List[str]:
    """Find all currently available architectures.

    Every directory below ``PACKAGE_ROOT`` containing the mandatory
    ``default-hypers.yaml`` file is reported as an architecture, in the order the
    recursive glob yields them; ``"llpr"`` is always appended as the last entry.

    :returns: list of architecture names
    :raises ValueError: propagated from :py:func:`get_architecture_name` if a
        ``default-hypers.yaml`` file lies in a folder that is not a valid
        architecture
    """
    hypers_files = PACKAGE_ROOT.rglob("default-hypers.yaml")
    names = [get_architecture_name(hypers_file) for hypers_file in hypers_files]
    return names + ["llpr"]
[docs]
def get_default_hypers(name: str) -> Dict:
    """Return the default hyperparameters of an architecture.

    :param name: name of the architecture
    :returns: default hyperparameters of the architecture
    :raises ValueError: if ``name`` is not a valid architecture
    """
    check_architecture_name(name)
    hypers_file = get_architecture_path(name) / "default-hypers.yaml"

    # The YAML files are shown in the documentation, so their content is nested
    # under an extra top-level ``architecture`` key for readability. Strip that
    # level here to get the raw default hyperparameters.
    wrapped = OmegaConf.to_container(OmegaConf.load(hypers_file))
    return wrapped["architecture"]