Skip to content

Commit

Permalink
Add prompt_for_slabs_to_keep filter
Browse files Browse the repository at this point in the history
Changes the default behavior of find_adsorbate_binding_sites() to prompt
users for the set of slabs that they want to submit.
  • Loading branch information
kjmichel committed Oct 12, 2023
1 parent a9f13fc commit bae23bb
Show file tree
Hide file tree
Showing 7 changed files with 224 additions and 32 deletions.
48 changes: 29 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ results = await find_adsorbate_binding_sites(
)
```

Users will be prompted to select one or more surfaces that should be relaxed.

Input to this function includes:

* The SMILES string of the adsorbate to place
Expand All @@ -57,7 +59,7 @@ In addition, this handles:
* Retrying failed calls to the Open Catalyst Demo API
* Retrying submission of relaxations when they are rate limited

This should take 5-10 minutes to finish while hundreds of individual adsorbate placements are relaxed over six unique surfaces of Pt. Each of the objects in the returned list includes (among other details):
This should take 2-10 minutes to finish while tens to hundreds (depending on the number of surfaces that are selected) of individual adsorbate placements are relaxed on unique surfaces of Pt. Each of the objects in the returned list includes (among other details):

* Information about the surface being searched, including its structure and Miller indices
* The initial positions of the adsorbate before relaxation
Expand All @@ -66,24 +68,6 @@ This should take 5-10 minutes to finish while hundreds of individual adsorbate p
* The predicted force on each atom in the final structure


### Search over a subset of Miller indices

```python
from ocpapi import (
find_adsorbate_binding_sites,
keep_slabs_with_miller_indices,
)

results = await find_adsorbate_binding_sites(
adsorbate="*OH",
bulk="mp-126",
adslab_filter=keep_slabs_with_miller_indices([(1, 1, 0), (1, 1, 1)])
)
```

This example adds the `adslab_filter` field, which takes a function that selects out generated surfaces that meet some criteria; in this case, keeping only the surfaces that have Miller indices of (1, 1, 0) or (1, 1, 1).


### Persisting results

**Results should be saved whenever possible in order to avoid expensive recomputation.**
Expand Down Expand Up @@ -135,6 +119,32 @@ results = await find_adsorbate_binding_sites(
)
```

### Skip relaxation approval prompts

Calls to `find_adsorbate_binding_sites()` will, by default, show the user all pending relaxations and ask for approval before they are submitted. In order to run the relaxations automatically without manual approval, `adslab_filter` can be set to a function that automatically approves any or all adslabs.

Run relaxations for all slabs that are generated:
```python
from ocpapi import find_adsorbate_binding_sites, keep_all_slabs

results = await find_adsorbate_binding_sites(
adsorbate="*OH",
bulk="mp-126",
adslab_filter=keep_all_slabs(),
)
```

Run relaxations only for slabs with Miller Indices in the input set:
```python
from ocpapi import find_adsorbate_binding_sites, keep_slabs_with_miller_indices

results = await find_adsorbate_binding_sites(
adsorbate="*OH",
bulk="mp-126",
adslab_filter=keep_slabs_with_miller_indices([(1, 0, 0), (1, 1, 1)]),
)
```

### Converting to [ase.Atoms](https://wiki.fysik.dtu.dk/ase/ase/atoms.html) objects

**Important! The `to_ase_atoms()` method described below will fail with an import error if [ase](https://wiki.fysik.dtu.dk/ase) is not installed.**
Expand Down
6 changes: 5 additions & 1 deletion ocpapi/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
get_adsorbate_slab_relaxation_results,
wait_for_adsorbate_slab_relaxations,
)
from .filter import keep_all_slabs, keep_slabs_with_miller_indices # noqa
from .filter import ( # noqa
keep_all_slabs,
keep_slabs_with_miller_indices,
prompt_for_slabs_to_keep,
)
from .retry import ( # noqa
NO_LIMIT,
NoLimitType,
Expand Down
11 changes: 5 additions & 6 deletions ocpapi/workflows/adsorbates.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
)

from .context import set_context_var
from .filter import keep_all_slabs
from .filter import prompt_for_slabs_to_keep
from .log import log
from .retry import NO_LIMIT, RateLimitLogging, retry_api_calls

Expand Down Expand Up @@ -783,7 +783,7 @@ async def _relax_binding_sites_on_slabs(

_DEFAULT_ADSLAB_FILTER: Callable[
[List[AdsorbateSlabConfigs]], Awaitable[List[AdsorbateSlabConfigs]]
] = keep_all_slabs()
] = prompt_for_slabs_to_keep()


async def find_adsorbate_binding_sites(
Expand All @@ -802,12 +802,11 @@ async def find_adsorbate_binding_sites(
1. Ensure that both the adsorbate and bulk are supported in the
OCP API.
2. Enumerate unique surfaces from the bulk material. If a
slab_filter function is provided, only those surfaces for
which the filter returns True will be kept.
2. Enumerate unique surfaces from the bulk material.
3. Enumerate likely binding sites for the input adsorbate on each
of the generated surfaces.
4. Relax each generated surface+adsorbate structure by refining
4. Filter the list of generated adslabs using the input adslab_filter.
5. Relax each generated surface+adsorbate structure by refining
atomic positions to minimize forces generated by the input model.
Args:
Expand Down
65 changes: 64 additions & 1 deletion ocpapi/workflows/filter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Iterable, List, Set, Tuple

from ocpapi.client import AdsorbateSlabConfigs
from ocpapi.client import AdsorbateSlabConfigs, SlabMetadata


class keep_all_slabs:
Expand Down Expand Up @@ -39,3 +39,66 @@ async def __call__(
for adslab in adslabs
if adslab.slab.metadata.millers in self._unique_millers
]


class prompt_for_slabs_to_keep:
"""
Adslab filter than presents the user with an interactive prompt to choose
which of the input slabs to keep.
"""

@staticmethod
def _sort_key(
adslab: AdsorbateSlabConfigs,
) -> Tuple[Tuple[int, int, int], float, str]:
"""
Generates a sort key from the input adslab. Returns the miller indices,
shift, and top/bottom label so that they will be sorted by those values
in that order.
"""
metadata: SlabMetadata = adslab.slab.metadata
return (metadata.millers, metadata.shift, metadata.top)

async def __call__(
self,
adslabs: List[AdsorbateSlabConfigs],
) -> List[AdsorbateSlabConfigs]:
from inquirer import Checkbox, prompt

# Break early if no adslabs were provided
if not adslabs:
return adslabs

# Sort the input list so the options are grouped in a sensible way
adslabs = sorted(adslabs, key=self._sort_key)

# List of options to present to the user. The first item in each tuple
# will be presented to the user in the prompt. The second item in each
# tuple (indices from the input list of adslabs) will be returned from
# the prompt.
choices: List[Tuple[str, int]] = [
(
(
f"{adslab.slab.metadata.millers} "
f"{'top' if adslab.slab.metadata.top else 'bottom'} "
"surface shifted by "
f"{round(adslab.slab.metadata.shift, 3)}; "
f"{len(adslab.adsorbate_configs)} unique adsorbate "
"placements to relax"
),
idx,
)
for idx, adslab in enumerate(adslabs)
]
checkbox: Checkbox = Checkbox(
"adslabs",
message=(
"Choose surfaces to relax (up/down arrows to move, "
"space to select, enter when finished)"
),
choices=choices,
)
selected_indices: List[int] = prompt([checkbox])["adslabs"]

# Return the adslabs that were chosen
return [adslabs[i] for i in selected_indices]
5 changes: 4 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@ install_requires =
responses == 0.23.2
tenacity == 8.2.3
tqdm == 4.66.1
inquirer == 3.1.3
dataclasses-json == 0.6.0

[options.extras_require]
dev = ase == 3.22.1
dev =
ase == 3.22.1
readchar == 4.0.5

[options.packages.find]
exclude =
Expand Down
9 changes: 8 additions & 1 deletion tests/unit/workflows/test_adsorbates.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
UnsupportedModelException,
find_adsorbate_binding_sites,
get_adsorbate_slab_relaxation_results,
keep_all_slabs,
keep_slabs_with_miller_indices,
wait_for_adsorbate_slab_relaxations,
)
Expand Down Expand Up @@ -1040,8 +1041,14 @@ class TestCase:

# Coroutine that will fetch results
other = case.non_default_args if case.non_default_args else {}
if "adslab_filter" not in other:
# Override default that will prompt for input
other["adslab_filter"] = keep_all_slabs()
coro = find_adsorbate_binding_sites(
adsorbate=case.adsorbate, bulk=case.bulk, client=client, **other
adsorbate=case.adsorbate,
bulk=case.bulk,
client=client,
**other,
)

# Make sure an exception is raised if expected
Expand Down
112 changes: 109 additions & 3 deletions tests/unit/workflows/test_filter.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
import functools
import sys
from contextlib import ExitStack
from dataclasses import dataclass
from typing import List, Optional, Tuple
from unittest import IsolatedAsyncioTestCase
from io import StringIO
from typing import Any, List, Optional, Tuple
from unittest import IsolatedAsyncioTestCase, mock

from inquirer import prompt
from inquirer.events import KeyEventGenerator
from inquirer.render import ConsoleRender
from readchar import key

from ocpapi.client import AdsorbateSlabConfigs, Atoms, Slab, SlabMetadata
from ocpapi.workflows import keep_all_slabs, keep_slabs_with_miller_indices
from ocpapi.workflows import (
keep_all_slabs,
keep_slabs_with_miller_indices,
prompt_for_slabs_to_keep,
)


# Function used to generate a new adslab instance. This filles the minimum
Expand Down Expand Up @@ -122,3 +135,96 @@ class TestCase:
with self.subTest(msg=case.message):
actual = await case.adslab_filter(case.input)
self.assertEqual(case.expected, actual)

async def test_prompt_for_slabs_to_keep(self) -> None:
@dataclass
class TestCase:
message: str
input: List[AdsorbateSlabConfigs]
key_events: List[Any]
expected: List[AdsorbateSlabConfigs]

test_cases: List[TestCase] = [
# If no adslabs are provided then none should be returned
TestCase(
message="no slabs provided",
input=[],
key_events=[],
expected=[],
),
# If adslabs are provided but none are selected then none
# should be returned
TestCase(
message="no slabs selected",
input=[
_new_adslab(miller_indices=(1, 0, 0)),
_new_adslab(miller_indices=(2, 0, 0)),
_new_adslab(miller_indices=(3, 0, 0)),
],
key_events=[key.ENTER],
expected=[],
),
# If adslabs are provided and some are selected then those
# should be returned
TestCase(
message="some slabs selected",
input=[
_new_adslab(miller_indices=(1, 0, 0)),
_new_adslab(miller_indices=(2, 0, 0)),
_new_adslab(miller_indices=(3, 0, 0)),
],
key_events=[
key.SPACE, # Select first slab
key.DOWN, # Move to second slab
key.DOWN, # Move to third slab
key.SPACE, # Select third slab
key.ENTER, # Finish
],
expected=[
_new_adslab(miller_indices=(1, 0, 0)),
_new_adslab(miller_indices=(3, 0, 0)),
],
),
]

for case in test_cases:
with ExitStack() as es:
es.enter_context(self.subTest(msg=case.message))

# prompt_for_slabs_to_keep() creates an interactive prompt
# that the user can select from. Here we inject key presses
# to simulate a user interacting with the prompt. First we
# need to direct stdin and stdout to our own io objects.
orig_stdin = sys.stdin
orig_stdout = sys.stdout
try:
sys.stdin = StringIO()
sys.stdout = StringIO()

# Now we create a inquirer.ConsoleRender instance that
# uses the key_events (key presses) in the current test
# case.
it = iter(case.key_events)
renderer = ConsoleRender(
event_generator=KeyEventGenerator(lambda: next(it))
)

# Now inject our renderer into the prompt
es.enter_context(
mock.patch(
"inquirer.prompt",
side_effect=functools.partial(
prompt,
render=renderer,
),
)
)

# Finally run the filter
adslab_filter = prompt_for_slabs_to_keep()
actual = await adslab_filter(case.input)
self.assertEqual(case.expected, actual)

finally:
sys.stdin = orig_stdin
sys.stdout = orig_stdout

0 comments on commit bae23bb

Please sign in to comment.