diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
index ca80efd57..1830d5975 100644
--- a/.github/workflows/docker.yml
+++ b/.github/workflows/docker.yml
@@ -28,26 +28,13 @@ jobs:
MONAI_ZOO_AUTH_TOKEN: ${{ github.token }}
steps:
- uses: actions/checkout@v4
- - name: Set up Python 3.9
- uses: actions/setup-python@v5
- with:
- python-version: 3.9
- - name: clean up
- run: |
- sudo rm -rf /usr/share/dotnet
- sudo rm -rf "$AGENT_TOOLSDIRECTORY"
- - name: Install dependencies
- run: |
- python -m pip install --upgrade pip wheel
- name: Build
run: |
rm -rf /opt/hostedtoolcache
- ./runtests.sh --clean
docker system prune -f
DOCKER_BUILDKIT=1 docker build -t projectmonai/monailabel:${{ github.event.inputs.tag || 'latest' }} -f Dockerfile .
- name: Verify
run: |
- ./runtests.sh --clean
docker run --rm -i --ipc=host --net=host -v $(pwd):/workspace projectmonai/monailabel:${{ github.event.inputs.tag || 'latest' }} /workspace/runtests.sh --net
- name: Publish
run: |
diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml
index c1f45112b..0ffe72520 100644
--- a/.github/workflows/pythonapp.yml
+++ b/.github/workflows/pythonapp.yml
@@ -39,6 +39,10 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
+ - if: runner.os == 'Linux'
+ name: Cleanup (Linux only)
+ run: |
+ rm -rf /opt/hostedtoolcache
- name: Install dependencies
run: |
python -m pip install --upgrade pip wheel
@@ -59,6 +63,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
+ rm -rf /opt/hostedtoolcache
sudo apt-get install openslide-tools -y
python -m pip install --upgrade pip wheel
pip install -r requirements-dev.txt
@@ -107,9 +112,10 @@ jobs:
key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }}
- name: Install dependencies
run: |
+ rm -rf /opt/hostedtoolcache
sudo apt-get install openslide-tools -y
python -m pip install --user --upgrade pip setuptools wheel
- python -m pip install torch>=1.7 torchvision
+ python -m pip install torch torchvision
- name: Build Package
run: |
./runtests.sh --clean
@@ -144,7 +150,7 @@ jobs:
MONAI_ZOO_AUTH_TOKEN: ${{ github.token }}
strategy:
matrix:
- python-version: ["3.9"]
+ python-version: ["3.10"]
steps:
- uses: actions/checkout@v4
- name: Set up Python
@@ -165,6 +171,7 @@ jobs:
key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }}
- name: Install dependencies
run: |
+ rm -rf /opt/hostedtoolcache
sudo apt-get install openslide-tools -y
python -m pip install --upgrade pip wheel
python -m pip install -r docs/requirements.txt
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 4295b9d50..388b54176 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -49,12 +49,12 @@ jobs:
key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }}
- name: Install dependencies
run: |
+ rm -rf /opt/hostedtoolcache
sudo apt-get install openslide-tools -y
python -m pip install --user --upgrade pip setuptools wheel
- python -m pip install torch>=1.7 torchvision
+ python -m pip install torch torchvision
- name: Build Package
run: |
- rm -rf /opt/hostedtoolcache
./runtests.sh --clean
BUILD_OHIF=true python setup.py sdist bdist_wheel --build-number $(date +'%Y%m%d%H%M')
ls -l dist
diff --git a/.readthedocs.yml b/.readthedocs.yml
index af0905941..76aef6d84 100644
--- a/.readthedocs.yml
+++ b/.readthedocs.yml
@@ -20,7 +20,7 @@ version: 2
build:
os: ubuntu-22.04
tools:
- python: "3.9"
+ python: "3.10"
# Build documentation in the docs/ directory with Sphinx
sphinx:
diff --git a/README.md b/README.md
index 3265d98ad..3ae5bbc4d 100644
--- a/README.md
+++ b/README.md
@@ -83,6 +83,7 @@ In addition, you can find a table of the basic supported fields, modalities, vie
Segmentation
DeepGrow
DeepEdit
+ SAM2 (2D/3D)
@@ -114,6 +115,7 @@ In addition, you can find a table of the basic supported fields, modalities, vie
NuClick
Segmentation
Classification
+ SAM2 (2D)
|
@@ -143,6 +145,7 @@ In addition, you can find a table of the basic supported fields, modalities, vie
DeepEdit
Tooltracking
InBody/OutBody
+ SAM2 (2D)
|
@@ -164,6 +167,12 @@ In addition, you can find a table of the basic supported fields, modalities, vie
|
+> [**SAM2**](https://github.com/facebookresearch/sam2/)
+>
+> By default, SAM2 is included for all of the above apps, but only when **_Python >= 3.10_**:
+> - **sam_2d**: segments any organ or tissue on a given slice/2D image.
+> - **sam_3d**: supports SAM2 propagation over multiple slices (Radiology/MONAI-Bundle).
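+>
+> For example, once a server with one of these apps is running, the `sam_2d` model can be
+> exercised through the standard infer endpoint (a sketch: the image id and point
+> coordinates are placeholders, mirroring the unit tests in this change):
+>
+> ```bash
+> curl -s -X POST "http://127.0.0.1:8000/infer/sam_2d?image=spleen_3" \
+>   -F 'params={"foreground": [[140, 210, 28]], "background": []}'
+> ```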
+
# Getting Started with MONAI Label
### MONAI Label requires a few steps to get started:
- Step 1: [Install MONAI Label](#step-1-installation)
diff --git a/monailabel/sam2/__init__.py b/monailabel/sam2/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/monailabel/sam2/infer.py b/monailabel/sam2/infer.py
new file mode 100644
index 000000000..880e2ae7c
--- /dev/null
+++ b/monailabel/sam2/infer.py
@@ -0,0 +1,454 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+import logging
+import os
+import pathlib
+import shutil
+import tempfile
+from datetime import timedelta
+from time import time
+from typing import Any, Dict, Tuple, Union
+
+import numpy as np
+import pylab
+import schedule
+import torch
+from monai.transforms import KeepLargestConnectedComponent, LoadImaged
+from PIL import Image
+from sam2.build_sam import build_sam2, build_sam2_video_predictor
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+from skimage.util import img_as_ubyte
+from timeloop import Timeloop
+from tqdm import tqdm
+
+from monailabel.config import settings
+from monailabel.interfaces.tasks.infer_v2 import InferTask, InferType
+from monailabel.interfaces.utils.transform import run_transforms
+from monailabel.transform.writer import Writer
+from monailabel.utils.others.generic import (
+ device_list,
+ download_file,
+ get_basename_no_ext,
+ md5_digest,
+ name_to_device,
+ remove_file,
+ strtobool,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class ImageCache:
+ def __init__(self):
+ cache_path = settings.MONAI_LABEL_DATASTORE_CACHE_PATH
+ self.cache_path = (
+ os.path.join(cache_path, "sam2")
+ if cache_path
+ else os.path.join(pathlib.Path.home(), ".cache", "monailabel", "sam2")
+ )
+ self.cached_dirs = {}
+ self.cache_expiry_sec = 10 * 60
+
+ remove_file(self.cache_path)
+ os.makedirs(self.cache_path, exist_ok=True)
+ logger.info(f"Image Cache Initialized: {self.cache_path}")
+
+ def cleanup(self):
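+ # Evict any cached slice directories whose expiry timestamp has passed.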
+ ts = time()
+ expired = {k: v for k, v in self.cached_dirs.items() if v < ts}
+ for k, v in expired.items():
+ self.cached_dirs.pop(k)
+ logger.info(f"Remove Expired Image: {k}; ExpiryTs: {v}; CurrentTs: {ts}")
+ remove_file(k)
+
+ def monitor(self):
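+ # Run an immediate cleanup, then re-run it every 60 seconds from a non-blocking background loop.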
+ self.cleanup()
+ time_loop = Timeloop()
+ schedule.every(1).minutes.do(self.cleanup)
+
+ @time_loop.job(interval=timedelta(seconds=60))
+ def run_scheduler():
+ schedule.run_pending()
+
+ time_loop.start(block=False)
+
+
+image_cache = ImageCache()
+image_cache.monitor()
+
+
+class Sam2InferTask(InferTask):
+ def __init__(
+ self,
+ model_dir,
+ type=InferType.ANNOTATION,
+ dimension=2,
+ labels=None,
+ additional_info=None,
+ image_loader=LoadImaged(keys="image"),
+ post_trans=None,
+ writer=Writer(ref_image="image"),
+ config=None,
+ ):
+ super().__init__(
+ type=type,
+ dimension=dimension,
+ labels=labels,
+ description="SAM2 (Segment Anything Model)",
+ config={"device": device_list(), "reset_state": False, "largest_cc": False, "pylab": False},
+ )
+ self.additional_info = additional_info
+ self.image_loader = image_loader
+ self.post_trans = post_trans
+ self.writer = writer
+ if config:
+ self._config.update(config)
+
+ # Download PreTrained Model
+ # https://github.com/facebookresearch/sam2?tab=readme-ov-file#model-description
+ pt = "sam2.1_hiera_large.pt"
+ url = f"https://dl.fbaipublicfiles.com/segment_anything_2/092824/{pt}"
+ self.path = os.path.join(model_dir, f"pretrained_{pt}")
+ download_file(url, self.path)
+
+ self.config_path = "configs/sam2.1/sam2.1_hiera_l.yaml"
+ self.predictors = {}
+ self.image_cache = {}
+ self.inference_state = None
+
+ def info(self) -> Dict[str, Any]:
+ d = super().info()
+ if self.additional_info:
+ d.update(self.additional_info)
+ return d
+
+ def is_valid(self) -> bool:
+ return True
+
+ def run_2d(self, image_tensor, request, debug=False):
+ device = name_to_device(request.get("device", "cuda"))
+ predictor = self.predictors.get(device)
+ if predictor is None:
+ logger.info(f"Using Device: {device}")
+ device_t = torch.device(device)
+ if device_t.type == "cuda":
+ torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
+ if torch.cuda.get_device_properties(0).major >= 8:
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+ sam2_model = build_sam2(self.config_path, self.path, device=device)
+ predictor = SAM2ImagePredictor(sam2_model)
+ self.predictors[device] = predictor
+
+ slice_idx = request.get("slice")
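+ # No explicit slice requested: derive candidate slices from the z-coordinate of the point prompts.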
+ if slice_idx is None or slice_idx < 0:
+ slices = {p[2] for p in request["foreground"] if len(p) > 2}
+ slices.update({p[2] for p in request["background"] if len(p) > 2})
+ slices = list(slices)
+ slice_idx = slices[0] if len(slices) else -1
+ else:
+ slices = {slice_idx}
+
+ if slice_idx < 0 and len(request["roi"]) == 6:
+ slice_idx = round(request["roi"][4] + (request["roi"][5] - request["roi"][4]) // 2)
+ slices = {slice_idx}
+ logger.info(f"Slices: {slices}; Slice Index: {slice_idx}")
+
+ if slice_idx < 0:
+ slice_np = image_tensor.cpu().numpy()
+ slice_rgb_np = slice_np.astype(np.uint8) if np.max(slice_np) > 1 else img_as_ubyte(slice_np)
+ else:
+ slice_np = image_tensor[:, :, slice_idx].cpu().numpy()
+
+ if strtobool(request.get("pylab")):
+ slice_rgb_file = tempfile.NamedTemporaryFile(suffix=".jpg").name
+ pylab.imsave(slice_rgb_file, slice_np, format="jpg", cmap="Greys_r")
+ slice_rgb_np = np.array(Image.open(slice_rgb_file))
+ remove_file(slice_rgb_file)
+ else:
+ slice_rgb_np = np.array(Image.fromarray(slice_np).convert("RGB"))
+
+ logger.info(f"Slice Index:{slice_idx}; (Image) Slice Shape: {slice_np.shape}")
+ if debug:
+ logger.info(f"Slice {slice_np.shape} Type: {slice_np.dtype}; Max: {np.max(slice_np)}")
+ logger.info(f"Slice RGB {slice_rgb_np.shape} Type: {slice_rgb_np.dtype}; Max: {np.max(slice_rgb_np)}")
+ if slice_idx < 0 and image_tensor.meta.get("filename_or_obj"):
+ shutil.copy(image_tensor.meta["filename_or_obj"], "image.jpg")
+ else:
+ pylab.imsave("image.jpg", slice_np, format="jpg", cmap="Greys_r")
+ Image.fromarray(slice_rgb_np).save("slice.jpg")
+
+ predictor.reset_predictor()
+ predictor.set_image(slice_rgb_np)
+
+ location = request.get("location", (0, 0))
+ tx, ty = location[0], location[1]
+ fp = [[p[0] - tx, p[1] - ty] for p in request["foreground"]]
+ bp = [[p[0] - tx, p[1] - ty] for p in request["background"]]
+ roi = request.get("roi")
+ roi = [roi[0] - tx, roi[1] - ty, roi[2] - tx, roi[3] - ty] if roi else None
+
+ if debug:
+ slice_rgb_np_p = np.copy(slice_rgb_np)
+ if roi:
+ slice_rgb_np_p[roi[0] : roi[2], roi[1] : roi[3], 2] = 255
+ for k, ps in {1: fp, 0: bp}.items():
+ for p in ps:
+ slice_rgb_np_p[p[0] - 2 : p[0] + 2, p[1] - 2 : p[1] + 2, k] = 255
+ Image.fromarray(slice_rgb_np_p).save("slice_p.jpg")
+
+ point_coords = fp + bp
+ point_coords = [[p[1], p[0]] for p in point_coords] # Flip x,y => y,x
+ box = [roi[1], roi[0], roi[3], roi[2]] if roi else None
+
+ point_labels = [1] * len(fp) + [0] * len(bp)
+ logger.info(f"Coords: {point_coords}; Labels: {point_labels}; Box: {box}")
+
+ masks, scores, _ = predictor.predict(
+ point_coords=np.array(point_coords) if point_coords else None,
+ point_labels=np.array(point_labels) if point_labels else None,
+ multimask_output=False,
+ box=np.array(box) if box else None,
+ )
+ # sorted_ind = np.argsort(scores)[::-1]
+ # masks = masks[sorted_ind]
+ # scores = scores[sorted_ind]
+ if strtobool(request.get("largest_cc", False)):
+ masks = KeepLargestConnectedComponent()(masks).cpu().numpy()
+
+ logger.info(f"Masks Shape: {masks.shape}; Scores: {scores}")
+ if self.post_trans is None:
+ if slice_idx < 0:
+ pred = masks[0]
+ else:
+ pred = np.zeros(tuple(image_tensor.shape))
+ pred[:, :, slice_idx] = masks[0]
+
+ data = copy.copy(request)
+ data.update({"image_path": request["image"], "pred": pred, "image": image_tensor})
+ else:
+ data = copy.copy(request)
+ data.update({"image_path": request["image"], "pred": masks[0], "image": image_tensor})
+ data = run_transforms(data, self.post_trans, log_prefix="POST", use_compose=False)
+
+ if debug:
+ # pylab.imsave("mask.jpg", masks[0], format="jpg", cmap="Greys_r")
+ Image.fromarray(masks[0] > 0).save("mask.jpg")
+
+ return self.writer(data)
+
+ def run_3d(self, image_tensor, set_image_state, request, debug=False):
+ device = name_to_device(request.get("device", "cuda"))
+ reset_state = strtobool(request.get("reset_state", "false"))
+ predictor = self.predictors.get(device)
+ if predictor is None:
+ logger.info(f"Using Device: {device}")
+ device_t = torch.device(device)
+ if device_t.type == "cuda":
+ torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
+ if torch.cuda.get_device_properties(0).major >= 8:
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+ predictor = build_sam2_video_predictor(self.config_path, self.path, device=device)
+ self.predictors[device] = predictor
+
+ image_path = request["image"]
+ video_dir = os.path.join(
+ image_cache.cache_path, get_basename_no_ext(image_path) if debug else md5_digest(image_path)
+ )
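+ # SAM2's video predictor expects a directory of JPEG frames, so flatten the volume into one JPEG per slice (cached for reuse).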
+ if not os.path.isdir(video_dir):
+ os.makedirs(video_dir, exist_ok=True)
+ for slice_idx in tqdm(range(image_tensor.shape[-1])):
+ slice_np = image_tensor[:, :, slice_idx].numpy()
+ slice_file = os.path.join(video_dir, f"{str(slice_idx).zfill(5)}.jpg")
+
+ if strtobool(request.get("pylab")):
+ pylab.imsave(slice_file, slice_np, format="jpg", cmap="Greys_r")
+ else:
+ Image.fromarray(slice_np).convert("RGB").save(slice_file)
+ logger.info(f"Image (Flattened): {image_tensor.shape[-1]} slices; {video_dir}")
+
+ # Set Expiry Time
+ image_cache.cached_dirs[video_dir] = time() + image_cache.cache_expiry_sec
+
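+ # (Re)initialize predictor state whenever a new image is loaded or an explicit reset is requested.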
+ if reset_state or set_image_state:
+ if self.inference_state:
+ predictor.reset_state(self.inference_state)
+ self.inference_state = predictor.init_state(video_path=video_dir)
+
+ logger.info(f"Image Shape: {image_tensor.shape}")
+ fps: dict[int, Any] = {}
+ bps: dict[int, Any] = {}
+ sids = set()
+ for key in {"foreground", "background"}:
+ for p in request[key]:
+ sid = p[2]
+ sids.add(sid)
+ kps = fps if key == "foreground" else bps
+ if kps.get(sid):
+ kps[sid].append([p[0], p[1]])
+ else:
+ kps[sid] = [[p[0], p[1]]]
+
+ box = None
+ roi = request.get("roi")
+ if roi:
+ box = [roi[1], roi[0], roi[3], roi[2]]
+ sids.update(range(roi[4], roi[5]))
+
+ pred = np.zeros(tuple(image_tensor.shape))
+ for sid in sorted(sids):
+ fp = fps.get(sid, [])
+ bp = bps.get(sid, [])
+
+ point_coords = fp + bp
+ point_coords = [[p[1], p[0]] for p in point_coords] # Flip x,y => y,x
+ point_labels = [1] * len(fp) + [0] * len(bp)
+ # logger.info(f"{sid} - Coords: {point_coords}; Labels: {point_labels}; Box: {box}")
+
+ o_frame_ids, o_obj_ids, o_mask_logits = predictor.add_new_points_or_box(
+ inference_state=self.inference_state,
+ frame_idx=sid,
+ obj_id=1,
+ points=np.array(point_coords) if point_coords else None,
+ labels=np.array(point_labels) if point_labels else None,
+ box=np.array(box) if box else None,
+ )
+
+ # logger.info(f"{sid} - mask_logits: {o_mask_logits.shape}; frame_ids: {o_frame_ids}; obj_ids: {o_obj_ids}")
+ pred[:, :, sid] = (o_mask_logits[0][0] > 0.0).cpu().numpy()
+
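+ # Propagate the prompted masks through the remaining slices of the volume.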
+ for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(self.inference_state):
+ # logger.info(f"propagate: {out_frame_idx} - mask_logits: {out_mask_logits.shape}; obj_ids: {out_obj_ids}")
+ pred[:, :, out_frame_idx] = (out_mask_logits[0][0] > 0.0).cpu().numpy()
+
+ writer = Writer(ref_image="image")
+ data = copy.copy(request)
+ data.update({"image_path": request["image"], "pred": pred, "image": image_tensor})
+ return writer(data)
+
+ def __call__(self, request, debug=False) -> Tuple[Union[str, None], Dict]:
+ start_ts = time()
+
+ logger.info(f"Infer Request: {request}")
+ image_path = request["image"]
+ image_tensor = self.image_cache.get(image_path)
+ set_image_state = False
+ cache_image = request.get("cache_image", True)
+
+ if "foreground" not in request:
+ request["foreground"] = []
+ if "background" not in request:
+ request["background"] = []
+ if "roi" not in request:
+ request["roi"] = []
+
+ if not cache_image or image_tensor is None:
+ # TODO:: Fix this to cache more than one image session
+ self.image_cache.clear()
+ image_tensor = self.image_loader(request)["image"]
+ if debug:
+ logger.info(f"Image Meta: {image_tensor.meta}")
+ self.image_cache[image_path] = image_tensor
+ set_image_state = True
+
+ logger.info(f"Image Shape: {image_tensor.shape}; cached: {cache_image}")
+ if self.dimension == 2:
+ mask_file, result_json = self.run_2d(image_tensor, request, debug)
+ else:
+ mask_file, result_json = self.run_3d(image_tensor, set_image_state, request)
+
+ logger.info(f"Mask File: {mask_file}; Latency: {round(time() - start_ts, 4)} sec")
+ result_json["latencies"] = {
+ "pre": 0,
+ "infer": 0,
+ "invert": 0,
+ "post": 0,
+ "write": 0,
+ "total": round(time() - start_ts, 2),
+ "transform": None,
+ }
+ return mask_file, result_json
+
+
+"""
+def main():
+ import shutil
+
+ logging.basicConfig(
+ level=logging.INFO,
+ format="[%(asctime)s] [%(process)s] [%(threadName)s] [%(levelname)s] (%(name)s:%(lineno)d) - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ force=True,
+ )
+
+ app_name = "pathology"
+ app_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "sample-apps", app_name))
+ model_dir = os.path.join(app_dir, "model")
+ logger.info(f"Model Dir: {model_dir}")
+ if app_name == "pathology":
+ from lib.transforms import LoadImagePatchd
+
+ from monailabel.transform.post import FindContoursd
+ from monailabel.transform.writer import PolygonWriter
+
+ task = Sam2InferTask(
+ model_dir=model_dir,
+ dimension=2,
+ additional_info={"nuclick": True, "pathology": True},
+ image_loader=LoadImagePatchd(keys="image", padding=False),
+ post_trans=[FindContoursd(keys="pred")],
+ writer=PolygonWriter(),
+ )
+ request = {
+ "device": "cuda:1",
+ "reset_state": False,
+ "model": "sam2",
+ "image": "/home/sachi/Datasets/wsi/JP2K-33003-1.svs",
+ "output": "asap",
+ "level": 0,
+ "location": (2183, 4873),
+ "size": (128, 128),
+ "tile_size": [128, 128],
+ "min_poly_area": 30,
+ "foreground": [[2247, 4937]],
+ "background": [],
+ # "roi": [2220, 4900, 2320, 5000],
+ "max_workers": 1,
+ "id": 0,
+ "logging": "INFO",
+ "result_write_to_file": False,
+ "description": "SAM2 (Segment Anything Model)",
+ "save_label": False,
+ }
+ else:
+ task = Sam2InferTask(model_dir)
+ request = {
+ "image": "/home/sachi/Datasets/SAM2/image.nii.gz",
+ "foreground": [[71, 175, 105]], # [199, 129, 47], [200, 100, 41]],
+ # "background": [[286, 175, 105]],
+ "roi": [44, 110, 113, 239, 72, 178],
+ "largest_cc": True,
+ }
+
+ result = task(request, debug=True)
+ if app_name == "pathology":
+ print(result)
+ else:
+ shutil.move(result[0], "mask.nii.gz")
+
+
+if __name__ == "__main__":
+ main()
+"""
diff --git a/monailabel/sam2/utils.py b/monailabel/sam2/utils.py
new file mode 100644
index 000000000..8abaef244
--- /dev/null
+++ b/monailabel/sam2/utils.py
@@ -0,0 +1,9 @@
+from monai.utils import optional_import
+
+
+def is_sam2_module_available():
+ """Return True if the optional `sam2` package is importable."""
+ _, flag = optional_import("sam2") # optional_import does not raise; availability comes back via the flag
+ return flag
diff --git a/plugins/cvat/README.md b/plugins/cvat/README.md
index 84aa63aa9..735ad9ea5 100644
--- a/plugins/cvat/README.md
+++ b/plugins/cvat/README.md
@@ -57,8 +57,9 @@ chmod +x nuctl-$NUCLIO_VERSION-linux-amd64
ln -sf $(pwd)/nuctl-$NUCLIO_VERSION-linux-amd64 /usr/local/bin/nuctl
```
-#### Deployment of Endoscopy Models
+#### Deployment of Endoscopy/SAM2 Models
This step is to deploy MONAI Label plugin with endoscopic models using Nuclio tool.
+> **Prerequisite:** The MONAI Label Server is up and running with the _**endoscopy**_ app.
```bash
# Run MONAI Label Server (Make sure this Host/IP is accessible inside a docker)
@@ -68,8 +69,13 @@ git clone https://github.com/Project-MONAI/MONAILabel.git
# Deploy all endoscopy models
./plugins/cvat/deploy.sh endoscopy
+
# Or to deploy specific function and model, e.g., tooltracking
./plugins/cvat/deploy.sh endoscopy tooltracking
+
+# Deploy SAM2 Interactor
+./plugins/cvat/deploy.sh sam2 interactor
+
```
After model deployment, you can see the model names in the `Models` page of CVAT.
diff --git a/plugins/cvat/deploy.sh b/plugins/cvat/deploy.sh
index 4435e8bc8..6b9f5c243 100755
--- a/plugins/cvat/deploy.sh
+++ b/plugins/cvat/deploy.sh
@@ -30,9 +30,9 @@ do
echo "Using MONAI Label Server: $MONAI_LABEL_SERVER"
cp $func_config ${func_config}.bak
sed -i "s|http://monailabel.com|$MONAI_LABEL_SERVER|g" $func_config
- mv ${func_config}.bak $func_config
echo "Deploying $func_config..."
nuctl deploy --project-name cvat --path "$func_root" --file "$func_config" --platform local
+ mv ${func_config}.bak $func_config
done
nuctl get function
diff --git a/plugins/cvat/detector.py b/plugins/cvat/detector.py
index 86473febc..e099a31c5 100644
--- a/plugins/cvat/detector.py
+++ b/plugins/cvat/detector.py
@@ -105,6 +105,7 @@ def handler(context, event):
}
)
+ context.logger.info("=============================================================================\n")
return context.Response(
body=json.dumps(results),
headers={},
diff --git a/plugins/cvat/interactor.py b/plugins/cvat/interactor.py
index 185e7f9cc..19ac8d636 100644
--- a/plugins/cvat/interactor.py
+++ b/plugins/cvat/interactor.py
@@ -53,6 +53,7 @@ def handler(context, event):
image = Image.open(io.BytesIO(base64.b64decode(data["image"])))
foreground = data.get("pos_points")
background = data.get("neg_points")
+ roi = data.get("obj_bbox", None)
context.logger.info(f"Image: {image.size}; Foreground: {foreground}; Background: {background}")
image_file = tempfile.NamedTemporaryFile(suffix=".jpg").name
@@ -64,6 +65,11 @@ def handler(context, event):
"background": np.asarray(background, dtype=int).tolist() if background else [],
# "largest_cc": True,
}
+ if roi:
+ roi = np.asarray(roi, dtype=int).flatten().tolist()
+ params["roi"] = roi
+
+ context.logger.info(f"Model:{model}; Params: {params}")
output_mask, output_json = client.infer(model=model, image_id="", file=image_file, params=params)
if isinstance(output_json, str) or isinstance(output_json, bytes):
output_json = json.loads(output_json)
@@ -76,6 +82,8 @@ def handler(context, event):
resp = {"mask": mask_np.tolist()}
context.logger.info(f"Image: {image.size}; Mask: {mask_im.size} vs {mask_np.shape}; JSON: {output_json}")
+
+ context.logger.info("=============================================================================\n")
return context.Response(
body=json.dumps(resp),
headers={},
diff --git a/plugins/cvat/pathology/deepedit_nuclei.yaml b/plugins/cvat/pathology/deepedit_nuclei.yaml
index 1221bba7e..3084c67e2 100644
--- a/plugins/cvat/pathology/deepedit_nuclei.yaml
+++ b/plugins/cvat/pathology/deepedit_nuclei.yaml
@@ -34,7 +34,7 @@ spec:
directives:
preCopy:
- kind: ENV
- value: MONAI_LABEL_SERVER=http://10.117.9.63:8000
+ value: MONAI_LABEL_SERVER=http://monailabel.com
- kind: ENV
value: MONAI_LABEL_MODEL=deepedit_nuclei
diff --git a/plugins/cvat/pathology/nuclick.yaml b/plugins/cvat/pathology/nuclick.yaml
index 2084f94bf..82976e241 100644
--- a/plugins/cvat/pathology/nuclick.yaml
+++ b/plugins/cvat/pathology/nuclick.yaml
@@ -36,7 +36,7 @@ spec:
directives:
preCopy:
- kind: ENV
- value: MONAI_LABEL_SERVER=http://10.117.9.63:8000
+ value: MONAI_LABEL_SERVER=http://monailabel.com
- kind: ENV
value: MONAI_LABEL_MODEL=nuclick
diff --git a/plugins/cvat/pathology/segmentation_nuclei.yaml b/plugins/cvat/pathology/segmentation_nuclei.yaml
index 796f29178..f42a8a661 100644
--- a/plugins/cvat/pathology/segmentation_nuclei.yaml
+++ b/plugins/cvat/pathology/segmentation_nuclei.yaml
@@ -38,7 +38,7 @@ spec:
directives:
preCopy:
- kind: ENV
- value: MONAI_LABEL_SERVER=http://10.117.9.63:8000
+ value: MONAI_LABEL_SERVER=http://monailabel.com
- kind: ENV
value: MONAI_LABEL_MODEL=segmentation_nuclei
diff --git a/plugins/cvat/sam2/interactor.yaml b/plugins/cvat/sam2/interactor.yaml
new file mode 100644
index 000000000..6787c300e
--- /dev/null
+++ b/plugins/cvat/sam2/interactor.yaml
@@ -0,0 +1,56 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+metadata:
+ name: monailabel.sam2.interactor
+ namespace: cvat
+ annotations:
+ name: SAM2
+ version: 2
+ type: interactor
+ spec:
+ min_pos_points: 0
+ min_neg_points: 0
+ startswith_box_optional: true
+ help_message: The interactor allows annotating an object using the SAM2 model
+
+spec:
+ description: A pre-trained SAM2 model for interactive segmentation
+ runtime: 'python:3.8'
+ handler: interactor:handler
+ eventTimeout: 30s
+
+ build:
+ image: cvat/monailabel.sam2.interactor
+ baseImage: projectmonai/monailabel:latest
+
+ directives:
+ preCopy:
+ - kind: ENV
+ value: MONAI_LABEL_SERVER=http://monailabel.com
+ - kind: ENV
+ value: MONAI_LABEL_MODEL=sam_2d
+
+ triggers:
+ myHttpTrigger:
+ maxWorkers: 1
+ kind: 'http'
+ workerAvailabilityTimeoutMilliseconds: 10000
+ attributes:
+ maxRequestBodySize: 33554432 # 32MB
+
+ platform:
+ attributes:
+ restartPolicy:
+ name: always
+ maximumRetryCount: 1
+ mountMode: volume
+ network: cvat_cvat
diff --git a/plugins/cvat/sam2/tracker.yaml b/plugins/cvat/sam2/tracker.yaml
new file mode 100644
index 000000000..c6842f9bf
--- /dev/null
+++ b/plugins/cvat/sam2/tracker.yaml
@@ -0,0 +1,51 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+metadata:
+ name: monailabel.sam2.tracker
+ namespace: cvat
+ annotations:
+ name: SAM2T
+ type: tracker
+ spec:
+
+spec:
+ description: A pre-trained SAM2 model for tracking
+ runtime: 'python:3.8'
+ handler: tracker:handler
+ eventTimeout: 30s
+
+ build:
+ image: cvat/monailabel.sam2.tracker
+ baseImage: projectmonai/monailabel:latest
+
+ directives:
+ preCopy:
+ - kind: ENV
+ value: MONAI_LABEL_SERVER=http://monailabel.com
+ - kind: ENV
+ value: MONAI_LABEL_MODEL=sam_2d
+
+ triggers:
+ myHttpTrigger:
+ maxWorkers: 1
+ kind: 'http'
+ workerAvailabilityTimeoutMilliseconds: 10000
+ attributes:
+ maxRequestBodySize: 33554432 # 32MB
+
+ platform:
+ attributes:
+ restartPolicy:
+ name: always
+ maximumRetryCount: 1
+ mountMode: volume
+ network: cvat_cvat
diff --git a/plugins/cvat/tracker.py b/plugins/cvat/tracker.py
index 11f5a3ba4..60e75d4c4 100644
--- a/plugins/cvat/tracker.py
+++ b/plugins/cvat/tracker.py
@@ -49,6 +49,7 @@ def handler(context, event):
model: str = context.user_data.model
client: MONAILabelClient = context.user_data.model_handler
context.logger.info(f"Run model: {model}")
+ # TODO:: This is not really a tracker; need to accumulate previous images + ROIs and run actual SAM2 propagation.
data = event.body
image = Image.open(io.BytesIO(base64.b64decode(data["image"])))
@@ -59,47 +60,48 @@ def handler(context, event):
image.save(image_file)
shapes = data.get("shapes")
- context.logger.info(f"Shapes: {shapes}")
-
states = data.get("states")
- context.logger.info(f"States: {states}")
+ context.logger.info(f"Shapes: {shapes}; States: {states}")
- results = {"shapes": [], "states": []}
- bboxes = []
+ rois = []
for i, shape in enumerate(shapes):
- context.logger.info(f"{i} => Shape: {shape}")
+ roi = np.array(shape).astype(int).tolist()
+ context.logger.info(f"{i} => Shape: {shape}; roi: {roi}")
+ rois.append(roi)
- def bounding_box(pts):
- x, y = zip(*pts)
- return [min(x), min(y), max(x), max(y)]
-
- bbox = bounding_box(np.array(shape).astype(int).reshape(-1, 2).tolist())
- context.logger.info(f"bbox: {bbox}")
- bboxes.append(bbox)
+ roi = rois[-1] # Pick the last
+ params = {"output": "json", "roi": roi}
- bbox = bboxes[-1] # Pick the last
- params = {"output": "json", "largest_cc": True, "bbox": bbox}
+ # context.logger.info(f"Model:{model}; Params: {params}")
output_mask, output_json = client.infer(model=model, image_id="", file=image_file, params=params)
if isinstance(output_json, str) or isinstance(output_json, bytes):
output_json = json.loads(output_json)
- context.logger.info(f"Mask: {output_mask}; Output JSON: {output_json}")
+ # context.logger.info(f"Mask: {output_mask}; Output JSON: {output_json}")
mask_np = np.array(Image.open(output_mask)).astype(np.uint8)
os.remove(output_mask)
os.remove(image_file)
-
context.logger.info(f"Image: {image.size}; Mask: {mask_np.shape}; JSON: {output_json}")
+ results = {"shapes": [], "states": []}
d = FindContoursd(keys="pred")({"pred": mask_np})
annotation = d.get("result", {}).get("annotation")
for element in annotation.get("elements", []):
contours = element["contours"]
+ all_points = []
for contour in contours:
points = np.flip(np.array(contour, int))
- shape = points.flatten().tolist()
- results["shapes"].append(shape)
- break
+ all_points.append(points.flatten().tolist())
+
+ def bounding_box(pts):
+ x, y = zip(*pts)
+ return [min(x), min(y), max(x), max(y)]
+
+ bbox = bounding_box(np.concatenate([np.array(p, int).reshape(-1, 2) for p in all_points]).tolist()) # contours may differ in length
+ context.logger.info(f"Input Box: {roi}; Output Box: {bbox}")
+ results["shapes"].append(bbox)
+ context.logger.info("=============================================================================\n")
return context.Response(
body=json.dumps(results),
headers={},
diff --git a/plugins/ohifv3/build.sh b/plugins/ohifv3/build.sh
index 8c4679a6b..5c0e53aec 100755
--- a/plugins/ohifv3/build.sh
+++ b/plugins/ohifv3/build.sh
@@ -50,6 +50,9 @@ APP_CONFIG=config/monai_label.js PUBLIC_URL=/ohif/ QUICK_BUILD=true yarn run bui
rm -rf ${install_dir}
cp -r platform/app/dist/ ${install_dir}
echo "Copied OHIF to ${install_dir}"
-rm -rf ../Viewers
+
+cd ..
+rm -rf Viewers
+find . -type d -name "node_modules" -exec rm -rf "{}" +
cd ${curr_dir}
diff --git a/plugins/slicer/MONAILabel/MONAILabel.py b/plugins/slicer/MONAILabel/MONAILabel.py
index 1b32c91c0..d018631cd 100644
--- a/plugins/slicer/MONAILabel/MONAILabel.py
+++ b/plugins/slicer/MONAILabel/MONAILabel.py
@@ -223,6 +223,7 @@ def __init__(self, parent=None):
self._volumeNode = None
self._segmentNode = None
self._scribblesROINode = None
+ self._sam2ROINode = None
self._volumeNodes = []
self._updatingGUIFromParameterNode = False
@@ -254,6 +255,7 @@ def __init__(self, parent=None):
self.scribblesMode = None
self.ignoreScribblesLabelChangeEvent = False
self.deepedit_multi_label = False
+ self.resetLabelState = False
self.optionsSectionIndex = 0
self.optionsNameIndex = 0
@@ -308,6 +310,13 @@ def setup(self):
self.ui.dgNegativeControlPointPlacementWidget.placeButton().show()
self.ui.dgNegativeControlPointPlacementWidget.deleteButton().show()
+ # ROI placement for SAM2
+ self.ui.sam2PlaceWidget.setMRMLScene(slicer.mrmlScene)
+ self.ui.sam2PlaceWidget.placeButton().toolTip = _("ROI/BBOX Prompt")
+ self.ui.sam2PlaceWidget.buttonsVisible = False
+ self.ui.sam2PlaceWidget.placeButton().show()
+ self.ui.sam2PlaceWidget.deleteButton().show()
+
self.ui.dgUpdateButton.setIcon(self.icon("segment.png"))
# Connections
@@ -325,7 +334,7 @@ def setup(self):
self.ui.labelComboBox.connect("currentIndexChanged(int)", self.onSelectLabel)
self.ui.scribLabelComboBox.connect("currentIndexChanged(int)", self.onSelectScribLabel)
self.ui.dgUpdateButton.connect("clicked(bool)", self.onUpdateDeepgrow)
- self.ui.dgUpdateCheckBox.setStyleSheet("padding-left: 10px;")
+ self.ui.dgUpdateCheckBox.setStyleSheet("padding-left: 10px; padding-top: 10px;")
self.ui.optionsSection.connect("currentIndexChanged(int)", self.onSelectOptionsSection)
self.ui.optionsName.connect("currentIndexChanged(int)", self.onSelectOptionsName)
@@ -420,6 +429,7 @@ def onSceneStartClose(self, caller, event):
self.current_sample = None
self.samples.clear()
self._scribblesROINode = None
+ self._sam2ROINode = None
self.resetPointList(
self.ui.dgPositiveControlPointPlacementWidget,
@@ -434,6 +444,7 @@ def onSceneStartClose(self, caller, event):
)
self.dgNegativePointListNode = None
self.onResetScribbles()
+ self.resetSam2ROI()
def resetPointList(self, markupsPlaceWidget, pointListNode, pointListNodeObservers):
if markupsPlaceWidget.placeModeEnabled:
@@ -712,6 +723,10 @@ def updateParameterNodeFromGUI(self, caller=None, event=None):
currentLabelIndex = self.ui.labelComboBox.currentIndex
if currentLabelIndex >= 0:
currentLabel = self.ui.labelComboBox.itemText(currentLabelIndex)
+ oldLabel = self._parameterNode.GetParameter("CurrentLabel")
+ if oldLabel and oldLabel != currentLabel:
+ print(f"Old Label: {oldLabel}; New Label: {currentLabel}")
+ self.resetLabelState = True
self._parameterNode.SetParameter("CurrentLabel", currentLabel)
currentScribLabelIndex = self.ui.scribLabelComboBox.currentIndex
@@ -1237,6 +1252,7 @@ def onNextSampleButton(self):
):
return
self.onResetScribbles()
+ self.resetSam2ROI()
slicer.mrmlScene.Clear(0)
start = time.time()
@@ -1338,6 +1354,9 @@ def initSample(self, sample, autosegment=True):
self.createScribblesROINode()
self.ui.scribblesPlaceWidget.setCurrentNode(self._scribblesROINode)
+ self.createSam2ROINode()
+ self.ui.sam2PlaceWidget.setCurrentNode(self._sam2ROINode)
+
# check if user allows overlapping segments
if slicer.util.settingsValue("MONAILabel/allowOverlappingSegments", False, converter=slicer.util.toBool):
# set segment editor to allow overlaps
@@ -1435,6 +1454,7 @@ def onSaveLabel(self):
labelmapVolumeNode = None
result = None
self.onResetScribbles()
+ self.resetSam2ROI()
if self.current_sample.get("session"):
if not self.onUploadImage(init_sample=False):
@@ -1571,12 +1591,12 @@ def onUpdateDeepgrow(self):
def onClickDeepgrow(self, current_point, skip_infer=False):
model = self.ui.deepgrowModelSelector.currentText
if not model:
- slicer.util.warningDisplay(_("Please select a deepgrow model"))
+ slicer.util.warningDisplay(_("Please select a model"))
return
segment_id, segment = self.currentSegment()
if not segment:
- slicer.util.warningDisplay(_("Please add the required label to run deepgrow"))
+ slicer.util.warningDisplay(_("Please add the required label to run interactive model"))
return
foreground_all = self.getControlPointsXYZ(self.dgPositivePointListNode, "foreground")
@@ -1587,22 +1607,39 @@ def onClickDeepgrow(self, current_point, skip_infer=False):
if skip_infer:
return
- # use model info "deepgrow" to determine
- deepgrow_3d = False if self.models[model].get("dimension", 3) == 2 else True
- print(f"Is DeepGrow 3D: {deepgrow_3d}")
+ # use model info to determine
+ is_3d = self.models[model].get("dimension", 3) != 2
+ print(f"Is 3D Model: {is_3d}")
start = time.time()
label = segment.GetName()
- operationDescription = _("Run Deepgrow for segment: {label}; model: {model}; 3d {in3d}").format(
- label=label, model=model, in3d=_("enabled") if deepgrow_3d else _("disabled")
+ operationDescription = _("Run Inference for segment: {label}; model: {model}; 3d {in3d}").format(
+ label=label, model=model, in3d=_("enabled") if is_3d else _("disabled")
)
logging.debug(operationDescription)
+ # try to get roi if placed
+ roiNode = self.ui.sam2PlaceWidget.currentNode()
+ selected_roi = []
+ if roiNode and roiNode.GetControlPointPlacementComplete():
+ selected_roi = self.getROIPointsXYZ(roiNode)
+ print(f"Selected ROI: {selected_roi} => {not selected_roi}")
+
if not current_point:
- if not foreground_all and not deepgrow_3d:
- slicer.util.warningDisplay(operationDescription + " - points not added")
+ if not foreground_all and not is_3d and not selected_roi:
+ slicer.util.warningDisplay(operationDescription + " - points/roi not added")
return
- current_point = foreground_all[-1] if foreground_all else background_all[-1] if background_all else None
+
+ if not is_3d and selected_roi:
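+ # No clicked point is required when an ROI is placed: seed the prompt at the ROI center, taking the slice index from the Red view offset.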
+ layoutManager = slicer.app.layoutManager()
+ current_point = [
+ selected_roi[0] + (selected_roi[1] - selected_roi[0]) // 2,
+ selected_roi[2] + (selected_roi[3] - selected_roi[2]) // 2,
+ round(abs(layoutManager.sliceWidget("Red").sliceLogic().GetSliceOffset())),
+ ]
+ else:
+ current_point = foreground_all[-1] if foreground_all else background_all[-1] if background_all else None
+ print(f"(updated) Current Point: {current_point}")
try:
qt.QApplication.setOverrideCursor(qt.Qt.WaitCursor)
@@ -1630,7 +1667,7 @@ def onClickDeepgrow(self, current_point, skip_infer=False):
sliceIndex = current_point[2] if current_point else None
print(f"Slice Index: {sliceIndex}")
- if deepgrow_3d or not sliceIndex:
+ if is_3d or not sliceIndex:
foreground = foreground_all
background = background_all
else:
@@ -1650,11 +1687,19 @@ def onClickDeepgrow(self, current_point, skip_infer=False):
params["label"] = label
params.update(self.getParamsFromConfig("infer", model))
- print(f"Request Params for Deepgrow/Deepedit: {params}")
+ if selected_roi:
+ params["roi"] = selected_roi
+ if not is_3d:
+ params["slice"] = sliceIndex
+ if not params.get("reset_state") and self.resetLabelState:
+ print(f"Override State: {params.get('reset_state')} vs {self.resetLabelState}")
+ params["reset_state"] = True
+ self.resetLabelState = False
+ print(f"Request Params for Inference: {params}")
image_file = self.current_sample["id"]
result_file, params = self.logic.infer(model, image_file, params, session_id=self.getSessionId())
- print(f"Result Params for Deepgrow/Deepedit: {params}")
+ print(f"Result Params for Inference: {params}")
if labels is None:
labels = (
params.get("label_names")
@@ -1665,7 +1710,7 @@ def onClickDeepgrow(self, current_point, skip_infer=False):
labels = [k for k, _ in sorted(labels.items(), key=lambda item: item[1])]
freeze = label if self.ui.freezeUpdateCheckBox.checked else None
- self.updateSegmentationMask(result_file, labels, None if deepgrow_3d else sliceIndex, freeze=freeze)
+ self.updateSegmentationMask(result_file, labels, None if is_3d else sliceIndex, freeze=freeze)
except BaseException as e:
msg = f"Message: {e.msg}" if hasattr(e, "msg") else ""
slicer.util.errorDisplay(
@@ -1703,6 +1748,19 @@ def createScribblesROINode(self):
scribblesROINode.GetDisplayNode().SetActiveColor(1, 1, 1)
self._scribblesROINode = scribblesROINode
+ def createSam2ROINode(self):
+ if self._volumeNode is None:
+ return
+ if self._sam2ROINode is None:
+ sam2ROINode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsROINode")
+ sam2ROINode.SetName("SAM2 ROI")
+ sam2ROINode.CreateDefaultDisplayNodes()
+ sam2ROINode.GetDisplayNode().SetFillOpacity(0.4)
+ sam2ROINode.GetDisplayNode().SetSelectedColor(1, 1, 1)
+ sam2ROINode.GetDisplayNode().SetColor(1, 1, 1)
+ sam2ROINode.GetDisplayNode().SetActiveColor(1, 1, 1)
+ self._sam2ROINode = sam2ROINode
+
def getLabelColor(self, name):
color = GenericAnatomyColors.get(name.lower())
return [c / 255.0 for c in color] if color else None
@@ -2050,6 +2108,10 @@ def resetScribblesROI(self):
if self._scribblesROINode:
self._scribblesROINode.RemoveAllControlPoints()
+ def resetSam2ROI(self):
+ if self._sam2ROINode:
+ self._sam2ROINode.RemoveAllControlPoints()
+
def onClearScribbles(self):
# for clearing scribbles and resetting tools to default
# remove "scribbles" segments from label
diff --git a/plugins/slicer/MONAILabel/Resources/UI/MONAILabel.ui b/plugins/slicer/MONAILabel/Resources/UI/MONAILabel.ui
index 64be18207..9d0319936 100644
--- a/plugins/slicer/MONAILabel/Resources/UI/MONAILabel.ui
+++ b/plugins/slicer/MONAILabel/Resources/UI/MONAILabel.ui
@@ -386,7 +386,7 @@
-
- SmartEdit / Deepgrow
+ SmartEdit
true
@@ -410,7 +410,7 @@
- -
+
-
false
@@ -476,6 +476,16 @@
+ -
+
+
+ Auto
+
+
+ true
+
+
+
-
@@ -483,7 +493,7 @@
- -
+
-
-
@@ -518,18 +528,28 @@
+ -
+
+
+ ROI:
+
+
+
+ -
+
+
+ qSlicerMarkupsPlaceWidget::ForcePlaceSingleMarkup
+
+
+
+ 0
+ 0
+
+
+
+
- -
-
-
- Auto
-
-
- true
-
-
-
diff --git a/requirements.txt b/requirements.txt
index 040d0005a..3583645c2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -44,6 +44,7 @@ urllib3==2.2.2
scikit-learn
scipy
google-auth==2.29.0
+SAM-2 @ git+https://github.com/facebookresearch/sam2.git@c2ec8e14a185632b0a5d8b161928ceb50197eddc ; python_version >= '3.10'
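+# NOTE: SAM2 requires python >= 3.10; the environment marker above skips it on older interpreters.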
# scipy and scikit-learn latest packages are missing on python 3.8
diff --git a/runtests.sh b/runtests.sh
index 9d572549c..11667af7d 100755
--- a/runtests.sh
+++ b/runtests.sh
@@ -140,6 +140,7 @@ function clean_py() {
find sample-apps -type d -name "__pycache__" -exec rm -rf "{}" +
find monailabel -type d -name "__pycache__" -exec rm -rf "{}" +
+ find plugins -type d -name "node_modules" -exec rm -rf "{}" +
}
function torch_validate() {
diff --git a/sample-apps/README.md b/sample-apps/README.md
index 916c00f55..5b949fb56 100644
--- a/sample-apps/README.md
+++ b/sample-apps/README.md
@@ -54,6 +54,14 @@ The endoscopy template includes example models for interactive and automated too
#### [MONAI Bundle](./monaibundle)
The MONAI Bundle format provides a portable description of deep learning models. This template includes example models for interactive and automated segmentation using MONAI bundles defined in the MONAI Model Zoo. It can pull any bundle defined in the MONAI Model Zoo that is compatible and meets the requirements specified on the [MONAI Bundle Apps page](./monaibundle/).
+----
+> [**SAM2**](https://github.com/facebookresearch/sam2/)
+>
+> By default, SAM2 is included for all of the above apps, but only when **_Python >= 3.10_**:
+> - **sam_2d**: segments any organ or tissue on a given slice/2D image.
+> - **sam_3d**: supports SAM2 propagation over multiple slices (Radiology/MONAI-Bundle).
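+>
+> SAM2 can also be toggled per app through the `sam` config flag (default `true`);
+> a sketch, assuming the usual server invocation and placeholder paths:
+>
+> ```bash
+> monailabel start_server --app apps/radiology \
+>   --studies datasets/Task09_Spleen/imagesTr --conf sam false
+> ```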
+
+
### Creating a Custom App
Researchers may want to define and add their own models. Follow the steps below to add a new segmentation model:
diff --git a/sample-apps/endoscopy/main.py b/sample-apps/endoscopy/main.py
index b5ad3191e..cd38984c2 100644
--- a/sample-apps/endoscopy/main.py
+++ b/sample-apps/endoscopy/main.py
@@ -15,7 +15,9 @@
from typing import Dict
import lib.configs
+import numpy as np
import schedule
+from monai.transforms import ToNumpyd
from timeloop import Timeloop
import monailabel
@@ -24,10 +26,11 @@
from monailabel.interfaces.app import MONAILabelApp
from monailabel.interfaces.config import TaskConfig
from monailabel.interfaces.datastore import Datastore
-from monailabel.interfaces.tasks.infer_v2 import InferTask
+from monailabel.interfaces.tasks.infer_v2 import InferTask, InferType
from monailabel.interfaces.tasks.scoring import ScoringMethod
from monailabel.interfaces.tasks.strategy import Strategy
from monailabel.interfaces.tasks.train import TrainTask
+from monailabel.sam2.utils import is_sam2_module_available
from monailabel.tasks.activelearning.random import Random
from monailabel.utils.others.class_utils import get_class_names
from monailabel.utils.others.generic import create_dataset_from_path, strtobool
@@ -82,6 +85,7 @@ def __init__(self, app_dir, studies, conf):
logger.info(f"+++ Using Models: {list(self.models.keys())}")
+ self.sam = strtobool(conf.get("sam", "true"))
super().__init__(
app_dir=app_dir,
studies=studies,
@@ -123,6 +127,21 @@ def init_infers(self) -> Dict[str, InferTask]:
for k, v in c.items():
logger.info(f"+++ Adding Inferer:: {k} => {v}")
infers[k] = v
+
+ #################################################
+ # SAM
+ #################################################
+ if is_sam2_module_available() and self.sam:
+ from monailabel.sam2.infer import Sam2InferTask
+
+ infers["sam_2d"] = Sam2InferTask(
+ model_dir=self.model_dir,
+ type=InferType.ANNOTATION,
+ dimension=2,
+ post_trans=[ToNumpyd(keys="pred", dtype=np.uint8)],
+ config={"cache_image": False, "reset_state": True},
+ )
+
return infers
def init_trainers(self) -> Dict[str, TrainTask]:
diff --git a/sample-apps/monaibundle/main.py b/sample-apps/monaibundle/main.py
index 55c8967d0..c698d402a 100644
--- a/sample-apps/monaibundle/main.py
+++ b/sample-apps/monaibundle/main.py
@@ -17,10 +17,11 @@
import monailabel
from monailabel.interfaces.app import MONAILabelApp
-from monailabel.interfaces.tasks.infer_v2 import InferTask
+from monailabel.interfaces.tasks.infer_v2 import InferTask, InferType
from monailabel.interfaces.tasks.scoring import ScoringMethod
from monailabel.interfaces.tasks.strategy import Strategy
from monailabel.interfaces.tasks.train import TrainTask
+from monailabel.sam2.utils import is_sam2_module_available
from monailabel.tasks.activelearning.first import First
from monailabel.tasks.activelearning.random import Random
from monailabel.tasks.infer.bundle import BundleInferTask
@@ -33,6 +34,7 @@
class MyApp(MONAILabelApp):
def __init__(self, app_dir, studies, conf):
+ self.model_dir = os.path.join(app_dir, "model")
self.models = get_bundle_models(app_dir, conf)
# Add Epistemic model for scoring
self.epistemic_models = (
@@ -44,6 +46,7 @@ def __init__(self, app_dir, studies, conf):
self.epistemic_simulation_size = int(conf.get("epistemic_simulation_size", "5"))
self.epistemic_dropout = float(conf.get("epistemic_dropout", "0.2"))
+ self.sam = strtobool(conf.get("sam", "true"))
super().__init__(
app_dir=app_dir,
studies=studies,
@@ -74,6 +77,14 @@ def init_infers(self) -> Dict[str, InferTask]:
logger.info(f"+++ Adding Inferer:: {n} => {i}")
infers[n] = i
+ #################################################
+ # SAM
+ #################################################
+ if is_sam2_module_available() and self.sam:
+ from monailabel.sam2.infer import Sam2InferTask
+
+ infers["sam_2d"] = Sam2InferTask(model_dir=self.model_dir, type=InferType.DEEPGROW, dimension=2)
+ infers["sam_3d"] = Sam2InferTask(model_dir=self.model_dir, type=InferType.DEEPGROW, dimension=3)
return infers
def init_trainers(self) -> Dict[str, TrainTask]:
diff --git a/sample-apps/pathology/main.py b/sample-apps/pathology/main.py
index c1a04316c..095791792 100644
--- a/sample-apps/pathology/main.py
+++ b/sample-apps/pathology/main.py
@@ -18,16 +18,20 @@
import lib.configs
from lib.activelearning.random import WSIRandom
from lib.infers import NuClick
+from lib.transforms import LoadImagePatchd
import monailabel
from monailabel.datastore.dsa import DSADatastore
from monailabel.interfaces.app import MONAILabelApp
from monailabel.interfaces.config import TaskConfig
from monailabel.interfaces.datastore import Datastore
-from monailabel.interfaces.tasks.infer_v2 import InferTask
+from monailabel.interfaces.tasks.infer_v2 import InferTask, InferType
from monailabel.interfaces.tasks.strategy import Strategy
from monailabel.interfaces.tasks.train import TrainTask
+from monailabel.sam2.utils import is_sam2_module_available
from monailabel.tasks.infer.basic_infer import BasicInferTask
+from monailabel.transform.post import FindContoursd
+from monailabel.transform.writer import PolygonWriter
from monailabel.utils.others.class_utils import get_class_names
from monailabel.utils.others.generic import strtobool
@@ -86,6 +90,7 @@ def __init__(self, app_dir, studies, conf):
logger.info(f"+++ Using Models: {list(self.models.keys())}")
+ self.sam = strtobool(conf.get("sam", "true"))
super().__init__(
app_dir=app_dir,
studies=studies,
@@ -138,6 +143,23 @@ def init_infers(self) -> Dict[str, InferTask]:
if isinstance(p, NuClick) and isinstance(c, BasicInferTask):
p.init_classification(c)
+ #################################################
+ # SAM
+ #################################################
+ if is_sam2_module_available() and self.sam:
+ from monailabel.sam2.infer import Sam2InferTask
+
+ infers["sam_2d"] = Sam2InferTask(
+ model_dir=self.model_dir,
+ type=InferType.ANNOTATION,
+ dimension=2,
+ additional_info={"nuclick": True, "pathology": True},
+ image_loader=LoadImagePatchd(keys="image", padding=False),
+ post_trans=[FindContoursd(keys="pred")],
+ writer=PolygonWriter(),
+ config={"cache_image": False, "reset_state": True},
+ )
+
return infers
def init_trainers(self) -> Dict[str, TrainTask]:
diff --git a/sample-apps/radiology/main.py b/sample-apps/radiology/main.py
index abc3aaaa7..66ef3e4ad 100644
--- a/sample-apps/radiology/main.py
+++ b/sample-apps/radiology/main.py
@@ -23,10 +23,11 @@
from monailabel.interfaces.app import MONAILabelApp
from monailabel.interfaces.config import TaskConfig
from monailabel.interfaces.datastore import Datastore
-from monailabel.interfaces.tasks.infer_v2 import InferTask
+from monailabel.interfaces.tasks.infer_v2 import InferTask, InferType
from monailabel.interfaces.tasks.scoring import ScoringMethod
from monailabel.interfaces.tasks.strategy import Strategy
from monailabel.interfaces.tasks.train import TrainTask
+from monailabel.sam2.utils import is_sam2_module_available
from monailabel.tasks.activelearning.first import First
from monailabel.tasks.activelearning.random import Random
@@ -99,6 +100,7 @@ def __init__(self, app_dir, studies, conf):
# Load models from bundle config files, local or released in Model-Zoo, e.g., --conf bundles
self.bundles = get_bundle_models(app_dir, conf, conf_key="bundles") if conf.get("bundles") else None
+ self.sam = strtobool(conf.get("sam", "true"))
super().__init__(
app_dir=app_dir,
studies=studies,
@@ -163,6 +165,15 @@ def init_infers(self) -> Dict[str, InferTask]:
}
)
+ #################################################
+ # SAM
+ #################################################
+ if is_sam2_module_available() and self.sam:
+ from monailabel.sam2.infer import Sam2InferTask
+
+ infers["sam_2d"] = Sam2InferTask(model_dir=self.model_dir, type=InferType.DEEPGROW, dimension=2)
+ infers["sam_3d"] = Sam2InferTask(model_dir=self.model_dir, type=InferType.DEEPGROW, dimension=3)
+
#################################################
# Pipeline based on existing infers
#################################################
diff --git a/setup.cfg b/setup.cfg
index 54c50b366..15e24022d 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -70,6 +70,7 @@ install_requires =
scikit-learn
scipy
google-auth>=2.29.0
+ SAM-2 @ git+https://github.com/facebookresearch/sam2.git@c2ec8e14a185632b0a5d8b161928ceb50197eddc ; python_version >= '3.10'
[flake8]
select = B,C,E,F,N,P,T4,W,B9
diff --git a/tests/unit/endpoints/test_infer_v2.py b/tests/unit/endpoints/test_infer_v2.py
index 5afbec7c8..23239e86a 100644
--- a/tests/unit/endpoints/test_infer_v2.py
+++ b/tests/unit/endpoints/test_infer_v2.py
@@ -15,6 +15,8 @@
import torch
+from monailabel.sam2.utils import is_sam2_module_available
+
from .context import BasicBundleTestSuite, BasicDetectionBundleTestSuite, BasicEndpointV2TestSuite
@@ -42,6 +44,30 @@ def test_deepgrow_pipeline(self):
assert response.status_code == 200
time.sleep(1)
+ def test_sam_2d(self):
+ if not is_sam2_module_available() or not torch.cuda.is_available():
+ return
+
+ model = "sam_2d"
+ image = "spleen_3"
+ params = {"foreground": [[140, 210, 28]], "background": []}
+
+ response = self.client.post(f"/infer/{model}?image={image}", data={"params": json.dumps(params)})
+ assert response.status_code == 200
+ time.sleep(1)
+
+ def test_sam_3d(self):
+ if not is_sam2_module_available() or not torch.cuda.is_available():
+ return
+
+ model = "sam_3d"
+ image = "spleen_3"
+ params = {"foreground": [[140, 210, 28]], "background": []}
+
+ response = self.client.post(f"/infer/{model}?image={image}", data={"params": json.dumps(params)})
+ assert response.status_code == 200
+ time.sleep(1)
+
class TestBundleInferTask(BasicBundleTestSuite):
def test_spleen_bundle_infer(self):