# Source code for metatrain.utils.output_gradient
import warnings
from typing import List, Optional
import torch
# [docs]
def compute_gradient(
    target: torch.Tensor, inputs: List[torch.Tensor], is_training: bool
) -> List[torch.Tensor]:
    """
    Calculates the gradient of a target tensor with respect to a list of input tensors.

    ``target`` must be a single torch.Tensor object. If target contains multiple values,
    the gradient will be calculated with respect to the sum of all values.

    :param target: The tensor to differentiate.
    :param inputs: The tensors with respect to which the gradient is computed.
    :param is_training: Forwarded as both ``retain_graph`` and ``create_graph`` to
        :py:func:`torch.autograd.grad`. When ``True`` the graph is kept and the
        returned gradients are themselves differentiable.
    :return: A list with one gradient tensor per entry of ``inputs``. If ``target``
        does not require grad, a ``RuntimeWarning`` is emitted and the list contains
        zero tensors shaped like the corresponding inputs.
    :raises RuntimeError: Any autograd error other than "target does not require
        grad" is re-raised unchanged.
    :raises ValueError: If the computed gradient (or any of its entries) is ``None``.
    """
    # Seeding with ones makes the result the gradient of ``target.sum()``.
    grad_outputs: Optional[List[Optional[torch.Tensor]]] = [torch.ones_like(target)]

    try:
        gradient = torch.autograd.grad(
            outputs=[target],
            inputs=inputs,
            grad_outputs=grad_outputs,
            retain_graph=is_training,
            create_graph=is_training,
        )
    except RuntimeError as e:
        # Torch raises an error if the target tensor does not require grad,
        # but this could just mean that the target is a constant tensor, like in
        # the case of composition models. In this case, we can safely ignore the
        # error and we raise a warning instead. The warning can be caught and
        # silenced in the appropriate places.
        if (
            "element 0 of tensors does not require grad and does not have a grad_fn"
            not in str(e)
        ):
            # Re-raise the error if it's not the one above
            raise
        warnings.warn(f"GRADIENT WARNING: {e}", RuntimeWarning, stacklevel=2)
        return [torch.zeros_like(i) for i in inputs]

    # Check the individual entries as well as the container: a missing gradient
    # would show up as a ``None`` element, not as ``gradient`` itself being ``None``.
    if gradient is None or any(g is None for g in gradient):
        raise ValueError(
            "Unexpected None value for computed gradient. "
            "One or more operations inside the model might "
            "not have a gradient implementation."
        )

    # ``torch.autograd.grad`` returns a tuple in eager mode; convert it so that both
    # code paths honor the declared ``List[torch.Tensor]`` return type.
    return list(gradient)