diff --git a/openeo/rest/connection.py b/openeo/rest/connection.py index 48414bb36..ddd1c2798 100644 --- a/openeo/rest/connection.py +++ b/openeo/rest/connection.py @@ -174,6 +174,7 @@ def __init__(self, url, auth: AuthBase = None, session: requests.Session = None) """ super().__init__(root_url=url, auth=auth, session=session) self._cached_capabilities = None + self._process_registry = None # Initial API version check. if self._api_version.below(self._MINIMUM_API_VERSION): @@ -267,6 +268,15 @@ def capabilities(self) -> 'Capabilities': return self._cached_capabilities + def process_registry(self) -> 'ProcessRegistry': + """ + Load all processes supported by the backend (lazy/cached) + :return: ProcessRegistry + """ + if self._process_registry is None: + self._process_registry = ProcessRegistry.from_connection(connection=self) + return self._process_registry + @deprecated("Use 'list_output_formats' instead") def list_file_types(self) -> dict: return self.list_output_formats() @@ -554,3 +564,23 @@ def session(userid=None, endpoint: str = "https://openeo.org/openeo") -> Connect """ return connect(url=endpoint) + +class ProcessRegistry: + """ + Registry of process specs (e.g. the processes supported by a backend) + """ + def __init__(self, processes: dict): + self._reg = processes + + @classmethod + def from_connection(cls, connection=Connection): + """Factory to load process registry from given backend connection.""" + # Get as list from API + processes = connection.get('/processes').json()['processes'] + # Make it a dictionary for more efficient retrieval + processes = {p['id']: p for p in processes} + return cls(processes=processes) + + def get_parameters(self, process_id: str) -> List[dict]: + """Get parameters for given process_id.""" + return self._reg[process_id]["parameters"] diff --git a/openeo/rest/imagecollectionclient.py b/openeo/rest/imagecollectionclient.py index 2bbc375d7..381b6f79a 100644 --- a/openeo/rest/imagecollectionclient.py +++ b/openeo/rest/imagecollectionclient.py @@ -31,6 +31,7 @@ def __init__(self, node_id: str, builder: GraphBuilder, session: 'Connection', m self.session = session self.graph = builder.processes self.metadata = metadata + self.dynamic = DynamicCubeMethodDelegator(cube=self) def __str__(self): return "ImageCollection: %s" % self.node_id @@ -1070,3 +1071,59 @@ def to_graphviz(self): # TODO: add subgraph for "callback" arguments? return graph + + +class DynamicProcessException(Exception): + pass + + +class _DynamicCubeMethod: + """ + A dynamically detected process bound to a raster cube. + The process should have a single "raster-cube" parameter. + """ + + def __init__(self, cube: ImageCollectionClient, process_id: str, parameters: List[dict]): + self.cube = cube + self.process_id = process_id + self.parameters = parameters + + # Find raster-cube parameter. + expected_schema = {"type": "object", "subtype": "raster-cube"} + names = [p["name"] for p in self.parameters if p["schema"] == expected_schema] + if len(names) != 1: + raise DynamicProcessException("Need one raster-cube parameter but found {c}".format(c=len(names))) + self.cube_parameter = names[0] + + def __call__(self, *args, **kwargs): + """Call the "cube method": pass cube and other arguments to the process.""" + arguments = { + self.cube_parameter: {"from_node": self.cube.node_id} + } + # TODO: more advanced parameter checking (required vs optional), normalization based on type, ... + for i, arg in enumerate(args): + arguments[self.parameters[i]["name"]] = arg + for key, value in kwargs.items(): + assert any(p["name"] == key for p in self.parameters) + assert key not in arguments + arguments[key] = value + + return self.cube.graph_add_process( + process_id=self.process_id, + args=arguments, + ) + + +class DynamicCubeMethodDelegator: + """ + Wrapper for a DataCube to group and delegate to dynamically detected processes + (depending on a particular backend or API spec) + """ + + def __init__(self, cube: ImageCollectionClient): + self.cube = cube + + def __getattr__(self, process_id): + self.process_registry = self.cube.session.process_registry() + parameters = self.process_registry.get_parameters(process_id) + return _DynamicCubeMethod(self.cube, process_id=process_id, parameters=parameters) diff --git a/tests/rest/test_imagecollectionclient.py b/tests/rest/test_imagecollectionclient.py index 3281b71d9..25fd7b8d9 100644 --- a/tests/rest/test_imagecollectionclient.py +++ b/tests/rest/test_imagecollectionclient.py @@ -14,6 +14,8 @@ def session040(requests_mock): requests_mock.get(API_URL + "/", json={"api_version": "0.4.0"}) session = openeo.connect(API_URL) + # Reset graph builder + GraphBuilder.id_counter = {} return session @@ -82,3 +84,29 @@ def result_callback(request, context): path = tmpdir.join("tmp.tiff") session040.load_collection("SENTINEL2").download(str(path), format="GTIFF") assert path.read() == "tiffdata" + + +def test_dynamic_cube_method(session040, requests_mock): + processes = [ + { + "id": "make_larger", + "description": "multiply a raster cube with a factor", + "parameters": [ + {"name": "data", "schema": {"type": "object", "subtype": "raster-cube"}}, + {"name": "factor", "schema": {"type": "float"}}, + ]} + ] + requests_mock.get(API_URL + '/processes', json={"processes": processes}) + requests_mock.get(API_URL + "/collections/SENTINEL2", json={"foo": "bar"}) + + cube = session040.load_collection("SENTINEL2") + evi = cube.dynamic.make_larger(factor=42) + assert set(evi.graph.keys()) == {"loadcollection1", "makelarger1"} + assert evi.graph["makelarger1"] == { + "process_id": "make_larger", + "arguments": { + "data": {"from_node": "loadcollection1"}, + "factor": 42, + }, + "result": False + }