Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement local caching for WMTS requests #2316

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 71 additions & 15 deletions lib/cartopy/io/ogc_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@
import warnings
import weakref
from xml.etree import ElementTree
import os

import numpy as np
from PIL import Image
from pathlib import Path
import shapely.geometry as sgeom

import cartopy


try:
import owslib.util
Expand Down Expand Up @@ -357,7 +361,7 @@ class WMTSRasterSource(RasterSource):

"""

def __init__(self, wmts, layer_name, gettile_extra_kwargs=None):
def __init__(self, wmts, layer_name, gettile_extra_kwargs=None, cache=False):
dnowacki-usgs marked this conversation as resolved.
Show resolved Hide resolved
"""
Parameters
----------
Expand All @@ -368,6 +372,9 @@ def __init__(self, wmts, layer_name, gettile_extra_kwargs=None):
gettile_extra_kwargs: dict, optional
Extra keywords (e.g. time) to pass through to the
service's gettile method.
cache : bool or str, optional
If True, the default cache directory is used. If False, no cache is
used. If a string, the string is used as the path to the cache.

"""
if WebMapService is None:
Expand Down Expand Up @@ -397,6 +404,18 @@ def __init__(self, wmts, layer_name, gettile_extra_kwargs=None):

self._matrix_set_name_map = {}

# Enable a cache mechanism when cache is equal to True or to a path.
self._default_cache = False
if cache is True:
self._default_cache = True
self.cache_path = Path(cartopy.config["cache_dir"])
elif cache is False:
self.cache_path = None
else:
self.cache_path = Path(cache)
self.cache = set({})
self._load_cache()

def _matrix_set_name(self, target_projection):
key = id(target_projection)
matrix_set_name = self._matrix_set_name_map.get(key)
Expand Down Expand Up @@ -510,6 +529,23 @@ def fetch_raster(self, projection, extent, target_resolution):

return located_images

@property
def _cache_dir(self):
"""Return the name of the cache directory"""
return self.cache_path / self.__class__.__name__

def _load_cache(self):
"""Load the cache"""
if self.cache_path is not None:
cache_dir = self._cache_dir
if not cache_dir.exists():
os.makedirs(cache_dir)
if self._default_cache:
warnings.warn(
'Cartopy created the following directory to cache '
'WMTSRasterSource tiles: {}'.format(cache_dir))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
'WMTSRasterSource tiles: {}'.format(cache_dir))
f'WMTSRasterSource tiles: {cache_dir}')

self.cache = self.cache.union(set(os.listdir(cache_dir)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.cache = self.cache.union(set(os.listdir(cache_dir)))
self.cache = self.cache.union(set(cache_dir.iterdir()))


def _choose_matrix(self, tile_matrices, meters_per_unit, max_pixel_span):
# Get the tile matrices in order of increasing resolution.
tile_matrices = sorted(tile_matrices,
Expand Down Expand Up @@ -642,21 +678,41 @@ def _wmts_images(self, wmts, layer, matrix_set_name, extent,
# Get the tile's Image from the cache if possible.
img_key = (row, col)
img = image_cache.get(img_key)

if img is None:
try:
tile = wmts.gettile(
layer=layer.id,
tilematrixset=matrix_set_name,
tilematrix=str(tile_matrix_id),
row=str(row), column=str(col),
**self.gettile_extra_kwargs)
except owslib.util.ServiceException as exception:
if ('TileOutOfRange' in exception.message and
ignore_out_of_range):
continue
raise exception
img = Image.open(io.BytesIO(tile.read()))
image_cache[img_key] = img
# Try it from disk cache
if self.cache_path is not None:
filename = f"{img_key[0]}_{img_key[1]}.npy"
cached_file = os.path.join(
self._cache_dir,
filename
)
Comment on lines +686 to +689
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
cached_file = os.path.join(
self._cache_dir,
filename
)
cached_file = self._cache_dir / filename

else:
filename = None
cached_file = None

if filename in self.cache:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is some bouncing back and forth between pathlib and str-like paths. So if you change the previous around, I think this would also need to be the Path version.

Suggested change
if filename in self.cache:
if cached_file in self.cache:

img = Image.fromarray(np.load(cached_file, allow_pickle=False))
else:
try:
tile = wmts.gettile(
layer=layer.id,
tilematrixset=matrix_set_name,
tilematrix=str(tile_matrix_id),
row=str(row), column=str(col),
**self.gettile_extra_kwargs)
except owslib.util.ServiceException as exception:
if ('TileOutOfRange' in exception.message and
ignore_out_of_range):
continue
raise exception
img = Image.open(io.BytesIO(tile.read()))
image_cache[img_key] = img
# save image to local cache
if self.cache_path is not None:
np.save(cached_file, img, allow_pickle=False)
self.cache.add(filename)

if big_img is None:
size = (img.size[0] * n_cols, img.size[1] * n_rows)
big_img = Image.new('RGBA', size, (255, 255, 255, 255))
Expand Down
4 changes: 2 additions & 2 deletions lib/cartopy/mpl/geoaxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2224,7 +2224,7 @@ def streamplot(self, x, y, u, v, **kwargs):
sp = super().streamplot(x, y, u, v, **kwargs)
return sp

def add_wmts(self, wmts, layer_name, wmts_kwargs=None, **kwargs):
def add_wmts(self, wmts, layer_name, wmts_kwargs=None, cache=False, **kwargs):
"""
Add the specified WMTS layer to the axes.

Expand All @@ -2249,7 +2249,7 @@ def add_wmts(self, wmts, layer_name, wmts_kwargs=None, **kwargs):
"""
from cartopy.io.ogc_clients import WMTSRasterSource
wmts = WMTSRasterSource(wmts, layer_name,
gettile_extra_kwargs=wmts_kwargs)
gettile_extra_kwargs=wmts_kwargs, cache=cache)
return self.add_raster(wmts, **kwargs)

def add_wms(self, wms, layers, wms_kwargs=None, **kwargs):
Expand Down