Skip to content

Commit

Permalink
feat(ml-models): download models from neptune.ai
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasschaub committed Nov 14, 2023
1 parent 8ff28ff commit b21affc
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 1 deletion.
4 changes: 4 additions & 0 deletions config/sample.config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ wms-url = "https://maps.heigit.org/osm-carto/service?SERVICE=WMS&VERSION=1.1.1"
wms-layers = "heigit:osm-carto@2xx"
wms-read-timeout = 600
max-nr-simultaneous-uploads = 100
neptune_api_token = "h0dHBzOi8aHR06E0Z...jMifQ"
neptune_project = "HeiGIT/SketchMapTool"
neptune_model_id_yolo = "SMT-OSM-1"
neptune_model_id_sam = "SMT-SAM-1"
4 changes: 4 additions & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ To create a new configuration file simply copy the sample configuration file and
```
cp sample.config.toml config.toml
```

## Required Configuration

Except of the API token (`SMT-NEPTUNE-API-TOKEN`) for neptune.ai all configuration values come with defaults for development purposes. Please make sure to configure the API token for your environment.
6 changes: 6 additions & 0 deletions docs/development-setup.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ npm run build # build/bundle JS and CSS

Please refer to the [configuration documentation](/docs/configuration.md).

> TL;DR: Except of the API token (`SMT-NEPTUNE-API-TOKEN`) for neptune.ai all configuration values come with defaults for development purposes. Please make sure to configure the API token for your environment.
## Usage

### 1. Start Celery (Task Queue)
Expand Down Expand Up @@ -104,3 +106,7 @@ If you setup sketch-map-tool in an IDE like PyCharm please make sure that your I
Go thought the setup steps above in the terminal and change interpreter settings in the IDE to point to the mamba/conda environment.

Also make sure the environment variable `PROJ_LIB` to point to the `proj` directory of the mamba/conda environment.

## Troubleshooting

Make sure that Poetry does not try to manage the virtual environment. Check with `poetry env list`. If any environment are listed remove them: `poetry env remove ...`
10 changes: 9 additions & 1 deletion sketch_map_tool/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def get_config_path() -> str:
return os.getenv("SMT-CONFIG", default=default)


def load_config_default() -> Dict[str, str]:
def load_config_default() -> Dict[str, str | int | float]:
return {
"data-dir": get_default_data_dir(),
"user-agent": "sketch-map-tool",
Expand All @@ -29,6 +29,10 @@ def load_config_default() -> Dict[str, str]:
"wms-read-timeout": 600,
"max-nr-simultaneous-uploads": 100,
"max_pixel_per_image": 10e8, # 10.000*10.000
"neptune_project": "HeiGIT/SketchMapTool",
"neptune_api_token": "",
"neptune_model_id_yolo": "SMT-OSM-1",
"neptune_model_id_sam": "SMT-SAM-1",
}


Expand All @@ -53,6 +57,10 @@ def load_config_from_env() -> Dict[str, str]:
"wms-read-timeout": os.getenv("SMT-WMS-READ-TIMEOUT"),
"max-nr-simultaneous-uploads": os.getenv("SMT-MAX-NR-SIM-UPLOADS"),
"max_pixel_per_image": os.getenv("MAX-PIXEL-PER-IMAGE"),
"neptune_project": os.getenv("SMT-NEPTUNE-PROJECT"),
"neptune_api_token": os.getenv("SMT-NEPTUNE-API-TOKEN"),
"neptune_model_id_yolo": os.getenv("SMT-NEPTUNE-MODEL-ID-YOLO"),
"neptune_model_id_sam": os.getenv("SMT-NEPTUNE-MODEL-ID-SAM"),
}
return {k: v for k, v in cfg.items() if v is not None}

Expand Down
54 changes: 54 additions & 0 deletions sketch_map_tool/upload_processing/ml_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import logging
from pathlib import Path

import neptune

from sketch_map_tool.config import get_config_value

PROJECT = get_config_value("neptune_project")
API_TOKEN = get_config_value("neptune_api_token")


def init_model(id: str) -> Path:
"""Initilaze model. Download model to data dir if not present."""
# TODO:
# _check_id(id)

data_dir = Path(get_config_value("data-dir"))
model = neptune.init_model_version(
with_id=id,
project=PROJECT,
api_token=API_TOKEN,
mode="read-only",
)

raw = data_dir / id
path = raw.with_suffix(_get_file_suffix(id))
if not path.is_file():
logging.info(f"Downloading model {id} from neptune.ai to {path}.")
model["model"].download(str(path))

# TODO: check if model is valid/working
logging.info("Model available model from neptune.ai: " + id)
return path


def _check_id(id: str):
# TODO:
project = neptune.init_project(
project=PROJECT,
api_token=API_TOKEN,
mode="read-only",
)

if not project.exists("models/" + id):
raise ValueError("Invalid model ID: " + id)


def _get_file_suffix(id: str) -> str:
if "SAM" in id:
return ".pth"
elif "OSM" in id:
return ".pt"
else:
raise ValueError("Unexpected model ID: " + id)
4 changes: 4 additions & 0 deletions tests/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def config_keys():
"wms-read-timeout",
"max-nr-simultaneous-uploads",
"max_pixel_per_image",
"neptune_project",
"neptune_api_token",
"neptune_model_id_yolo",
"neptune_model_id_sam",
)


Expand Down
30 changes: 30 additions & 0 deletions tests/unit/test_ml_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest
from hypothesis import example, given
from hypothesis.strategies import text

from sketch_map_tool.config import get_config_value
from sketch_map_tool.upload_processing import ml_models
from tests import vcr_app as vcr


@pytest.mark.parametrize(
"id",
(
get_config_value("neptune_model_id_yolo"),
get_config_value("neptune_model_id_sam"),
),
)
@pytest.mark.skip("longrunning tests. downloads ml-models from neptunge.ai")
def test_init_model(id, monkeypatch, tmpdir):
monkeypatch.setenv("SMT-DATA-DIR", tmpdir)
path = ml_models.init_model(id)
assert path.is_file()


@given(text())
@example("")
@pytest.mark.skip(reason="not implemented yet")
@vcr.use_cassette
def test_init_model_unexpected_id(id):
with pytest.raises(ValueError):
ml_models.init_model(id)

0 comments on commit b21affc

Please sign in to comment.