from typing import Callable, Dict, List, Tuple
import ase.neighborlist
import numpy as np
import torch
import vesin
from metatensor.torch import Labels, TensorBlock
from metatomic.torch import NeighborListOptions, System, register_autograd_neighbors
from .data.system_to_ase import system_to_ase
[docs]
def get_requested_neighbor_lists(
    module: torch.nn.Module,
) -> List[NeighborListOptions]:
    """Collect every neighbor list requested by a module or its submodules.

    The module tree is walked recursively. Identical requests coming from
    different submodules are merged into a single `NeighborListOptions` entry
    that records all of their requestors.

    :param module: The module for which to get the requested neighbor lists.
    :return: A list of `NeighborListOptions` objects requested by the module.
    """
    collected: List[NeighborListOptions] = []
    # the root module is registered under an empty name
    _get_requested_neighbor_lists_in_place(module, "", collected)
    return collected
def _get_requested_neighbor_lists_in_place(
    module: torch.nn.Module,
    module_name: str,
    requested: List[NeighborListOptions],
) -> None:
    """Recursively append the neighbor lists requested by ``module`` to
    ``requested``.

    Any module exposing a ``requested_neighbor_lists()`` method contributes its
    options. Each option is first tagged with ``module_name`` as a requestor.
    If an equal option is already in ``requested``, the new option's
    requestors are copied onto it. Otherwise the new option is appended.
    Children are then visited with ``module_name + "." + child_name`` as their
    name.

    :param module: The module to inspect (together with all its children).
    :param module_name: The dotted name under which ``module`` is registered.
    :param requested: The list that is extended in place.
    """
    # Adapted from
    # metatensor/python/metatensor-torch/metatensor/torch/atomistic/model.py,
    # with the length-unit handling removed.
    if hasattr(module, "requested_neighbor_lists"):
        for options in module.requested_neighbor_lists():
            options.add_requestor(module_name)

            duplicates = [known for known in requested if known == options]
            for known in duplicates:
                for name in options.requestors():
                    known.add_requestor(name)

            if not duplicates:
                requested.append(options)

    for child_name, child in module.named_children():
        _get_requested_neighbor_lists_in_place(
            module=child,
            module_name=f"{module_name}.{child_name}",
            requested=requested,
        )
[docs]
def get_system_with_neighbor_lists(
    system: System, neighbor_lists: List[NeighborListOptions]
) -> System:
    """Attach the requested neighbor lists to a `System` object.

    Lists that ``system`` already knows about are left untouched. Each missing
    list is computed from an ASE representation of the system and converted to
    the system's device and dtype. It is then passed to
    ``register_autograd_neighbors`` and added to ``system``.

    :param system: The system for which to calculate neighbor lists. It is
        modified in place.
    :param neighbor_lists: A list of `NeighborListOptions` objects,
        each of which specifies the parameters for a neighbor list.
    :return: The `System` object with the neighbor lists added (the same
        object as ``system``).
    """
    # The ASE conversion is only needed to compute missing lists, so it is
    # deferred until the first one is found (and then reused for the rest).
    atoms = None

    for options in neighbor_lists:
        # Query inside the loop: lists added in earlier iterations count as
        # known, so duplicated entries in ``neighbor_lists`` are computed once.
        if options in system.known_neighbor_lists():
            continue

        if atoms is None:
            atoms = system_to_ase(system)

        neighbors = _compute_single_neighbor_list(atoms, options).to(
            device=system.device, dtype=system.dtype
        )
        register_autograd_neighbors(system, neighbors)
        system.add_neighbor_list(options, neighbors)

    return system
def _half_neighbor_list_mask(
    first: np.ndarray, second: np.ndarray, shifts: np.ndarray
) -> np.ndarray:
    """Return a boolean mask keeping one pair of each ``(i, j, S)`` /
    ``(j, i, -S)`` couple found in a full neighbor list.

    :param first: Index of the first atom of each pair.
    :param second: Index of the second atom of each pair.
    :param shifts: Integer cell shifts of each pair, one row per pair with
        three columns.
    :return: ``True`` for the pairs that belong to the half neighbor list.
    """
    shift_sum = shifts.sum(axis=1)
    zero_shift = np.all(shifts == 0, axis=1)

    # An atom paired with one of its own periodic images appears twice, as
    # (i, i, S) and (i, i, -S). Only the copy whose shift lies in the
    # "positive" half-space is kept. That half-space is: a positive shift sum;
    # or a zero sum with c > 0; or a zero sum, c == 0 and b > 0.
    in_negative_half = (shift_sum < 0) | (
        (shift_sum == 0)
        & ((shifts[:, 2] < 0) | ((shifts[:, 2] == 0) & (shifts[:, 1] < 0)))
    )

    reject = (second < first) | (
        # pairing an atom with itself is only allowed across a non-zero shift,
        # i.e. with one of its periodic images
        (first == second) & (zero_shift | in_negative_half)
    )
    return ~reject


def _compute_single_neighbor_list(
    atoms: ase.Atoms, options: NeighborListOptions
) -> TensorBlock:
    """Compute one neighbor list for an ASE ``Atoms`` object, following the
    conventions of ``metatomic.torch``.

    :param atoms: The structure for which to compute the neighbor list.
    :param options: The requested list. ``options.cutoff`` and
        ``options.full_list`` are used here.
    :return: A ``TensorBlock`` with one sample per pair. Each sample is
        labelled by ``first_atom``, ``second_atom`` and ``cell_shift_a/b/c``.
        The values hold the pair vectors, reshaped to ``(n_pairs, 3, 1)``.
    """
    # vesin handles systems that are periodic along all cell vectors or along
    # none of them; mixed boundary conditions are not implemented in vesin,
    # so those go through ASE instead.
    uniform_pbc = np.all(atoms.pbc) or np.all(~atoms.pbc)
    neighbor_list_fn = (
        vesin.ase_neighbor_list if uniform_pbc else ase.neighborlist.neighbor_list
    )
    first, second, shifts, vectors = neighbor_list_fn(
        "ijSD",
        atoms,
        cutoff=options.cutoff,
    )

    if not options.full_list:
        # vectorized filtering, avoiding a slow Python loop over all pairs
        keep = _half_neighbor_list_mask(first, second, shifts)
        first = first[keep]
        second = second[keep]
        shifts = shifts[keep]
        vectors = vectors[keep]

    pair_labels = torch.from_numpy(
        np.concatenate(
            [first[:, None], second[:, None], shifts], axis=-1, dtype=np.int32
        )
    )
    pair_vectors = torch.from_numpy(vectors).reshape(-1, 3, 1)

    return TensorBlock(
        values=pair_vectors,
        samples=Labels(
            names=[
                "first_atom",
                "second_atom",
                "cell_shift_a",
                "cell_shift_b",
                "cell_shift_c",
            ],
            values=pair_labels,
            assume_unique=True,
        ),
        components=[Labels.range("xyz", 3)],
        properties=Labels.range("distance", 1),
    )