diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index ca80efd57..1830d5975 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -28,26 +28,13 @@ jobs: MONAI_ZOO_AUTH_TOKEN: ${{ github.token }} steps: - uses: actions/checkout@v4 - - name: Set up Python 3.9 - uses: actions/setup-python@v5 - with: - python-version: 3.9 - - name: clean up - run: | - sudo rm -rf /usr/share/dotnet - sudo rm -rf "$AGENT_TOOLSDIRECTORY" - - name: Install dependencies - run: | - python -m pip install --upgrade pip wheel - name: Build run: | rm -rf /opt/hostedtoolcache - ./runtests.sh --clean docker system prune -f DOCKER_BUILDKIT=1 docker build -t projectmonai/monailabel:${{ github.event.inputs.tag || 'latest' }} -f Dockerfile . - name: Verify run: | - ./runtests.sh --clean docker run --rm -i --ipc=host --net=host -v $(pwd):/workspace projectmonai/monailabel:${{ github.event.inputs.tag || 'latest' }} /workspace/runtests.sh --net - name: Publish run: | diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index c1f45112b..0ffe72520 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -39,6 +39,10 @@ jobs: uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + - if: runner.os == 'Linux' + name: Cleanup (Linux only) + run: | + rm -rf /opt/hostedtoolcache - name: Install dependencies run: | python -m pip install --upgrade pip wheel @@ -59,6 +63,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | + rm -rf /opt/hostedtoolcache sudo apt-get install openslide-tools -y python -m pip install --upgrade pip wheel pip install -r requirements-dev.txt @@ -107,9 +112,10 @@ jobs: key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }} - name: Install dependencies run: | + rm -rf /opt/hostedtoolcache sudo apt-get install openslide-tools -y python -m pip install --user --upgrade pip setuptools wheel - python -m pip install torch>=1.7 torchvision + python -m pip install torch torchvision - name: Build Package run: | ./runtests.sh --clean @@ -144,7 +150,7 @@ jobs: MONAI_ZOO_AUTH_TOKEN: ${{ github.token }} strategy: matrix: - python-version: ["3.9"] + python-version: ["3.10"] steps: - uses: actions/checkout@v4 - name: Set up Python @@ -165,6 +171,7 @@ jobs: key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }} - name: Install dependencies run: | + rm -rf /opt/hostedtoolcache sudo apt-get install openslide-tools -y python -m pip install --upgrade pip wheel python -m pip install -r docs/requirements.txt diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 4295b9d50..388b54176 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -49,12 +49,12 @@ jobs: key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }} - name: Install dependencies run: | + rm -rf /opt/hostedtoolcache sudo apt-get install openslide-tools -y python -m pip install --user --upgrade pip setuptools wheel - python -m pip install torch>=1.7 torchvision + python -m pip install torch torchvision - name: Build Package run: | - rm -rf /opt/hostedtoolcache ./runtests.sh --clean BUILD_OHIF=true python setup.py sdist bdist_wheel --build-number $(date +'%Y%m%d%H%M') ls -l dist diff --git a/.readthedocs.yml b/.readthedocs.yml index af0905941..76aef6d84 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -20,7 +20,7 @@ version: 2 build: os: ubuntu-22.04 tools: - python: "3.9" + python: "3.10" # Build documentation in the docs/ directory with Sphinx sphinx: diff --git a/README.md b/README.md index 3265d98ad..3ae5bbc4d 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,7 @@ In addition, you can find a table of the basic supported fields, modalities, vie
  • Segmentation
  • DeepGrow
  • DeepEdit
  • +
  • SAM2 (2D/3D)
  • @@ -114,6 +115,7 @@ In addition, you can find a table of the basic supported fields, modalities, vie
  • NuClick
  • Segmentation
  • Classification
  • +
  • SAM2 (2D)
  • @@ -143,6 +145,7 @@ In addition, you can find a table of the basic supported fields, modalities, vie
  • DeepEdit
  • Tooltracking
  • InBody/OutBody
  • +
  • SAM2 (2D)
  • @@ -164,6 +167,12 @@ In addition, you can find a table of the basic supported fields, modalities, vie +> [**SAM2**](https://github.com/facebookresearch/sam2/) +> +> By default, SAM2 is included for all the above Apps only when **_python >= 3.10_** +> - **sam_2d**: for any organ or tissue and others over a given slice/2D image. +> - **sam_3d**: to support SAM2 propagation over multiple slices (Radiology/MONAI-Bundle). + # Getting Started with MONAI Label ### MONAI Label requires a few steps to get started: - Step 1: [Install MONAI Label](#step-1-installation) diff --git a/monailabel/sam2/__init__.py b/monailabel/sam2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/monailabel/sam2/infer.py b/monailabel/sam2/infer.py new file mode 100644 index 000000000..880e2ae7c --- /dev/null +++ b/monailabel/sam2/infer.py @@ -0,0 +1,454 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import logging +import os +import pathlib +import shutil +import tempfile +from datetime import timedelta +from time import time +from typing import Any, Dict, Tuple, Union + +import numpy as np +import pylab +import schedule +import torch +from monai.transforms import KeepLargestConnectedComponent, LoadImaged +from PIL import Image +from sam2.build_sam import build_sam2, build_sam2_video_predictor +from sam2.sam2_image_predictor import SAM2ImagePredictor +from skimage.util import img_as_ubyte +from timeloop import Timeloop +from tqdm import tqdm + +from monailabel.config import settings +from monailabel.interfaces.tasks.infer_v2 import InferTask, InferType +from monailabel.interfaces.utils.transform import run_transforms +from monailabel.transform.writer import Writer +from monailabel.utils.others.generic import ( + device_list, + download_file, + get_basename_no_ext, + md5_digest, + name_to_device, + remove_file, + strtobool, +) + +logger = logging.getLogger(__name__) + + +class ImageCache: + def __init__(self): + cache_path = settings.MONAI_LABEL_DATASTORE_CACHE_PATH + self.cache_path = ( + os.path.join(cache_path, "sam2") + if cache_path + else os.path.join(pathlib.Path.home(), ".cache", "monailabel", "sam2") + ) + self.cached_dirs = {} + self.cache_expiry_sec = 10 * 60 + + remove_file(self.cache_path) + os.makedirs(self.cache_path, exist_ok=True) + logger.info(f"Image Cache Initialized: {self.cache_path}") + + def cleanup(self): + ts = time() + expired = {k: v for k, v in self.cached_dirs.items() if v < ts} + for k, v in expired.items(): + self.cached_dirs.pop(k) + logger.info(f"Remove Expired Image: {k}; ExpiryTs: {v}; CurrentTs: {ts}") + remove_file(k) + + def monitor(self): + self.cleanup() + time_loop = Timeloop() + schedule.every(1).minutes.do(self.cleanup) + + @time_loop.job(interval=timedelta(seconds=60)) + def run_scheduler(): + schedule.run_pending() + + time_loop.start(block=False) + + +image_cache = ImageCache() +image_cache.monitor() + + +class Sam2InferTask(InferTask): + def __init__( + self, + model_dir, + type=InferType.ANNOTATION, + dimension=2, + labels=None, + additional_info=None, + image_loader=LoadImaged(keys="image"), + post_trans=None, + writer=Writer(ref_image="image"), + config=None, + ): + super().__init__( + type=type, + dimension=dimension, + labels=labels, + description="SAM2 (Segment Anything Model)", + config={"device": device_list(), "reset_state": False, "largest_cc": False, "pylab": False}, + ) + self.additional_info = additional_info + self.image_loader = image_loader + self.post_trans = post_trans + self.writer = writer + if config: + self._config.update(config) + + # Download PreTrained Model + # https://github.com/facebookresearch/sam2?tab=readme-ov-file#model-description + pt = "sam2.1_hiera_large.pt" + url = f"https://dl.fbaipublicfiles.com/segment_anything_2/092824/{pt}" + self.path = os.path.join(model_dir, f"pretrained_{pt}") + download_file(url, self.path) + + self.config_path = "configs/sam2.1/sam2.1_hiera_l.yaml" + self.predictors = {} + self.image_cache = {} + self.inference_state = None + + def info(self) -> Dict[str, Any]: + d = super().info() + if self.additional_info: + d.update(self.additional_info) + return d + + def is_valid(self) -> bool: + return True + + def run2d(self, image_tensor, request, debug=False): + device = name_to_device(request.get("device", "cuda")) + predictor = self.predictors.get(device) + if predictor is None: + logger.info(f"Using Device: {device}") + device_t = torch.device(device) + if device_t.type == "cuda": + torch.autocast("cuda", dtype=torch.bfloat16).__enter__() + if torch.cuda.get_device_properties(0).major >= 8: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + sam2_model = build_sam2(self.config_path, self.path, device=device) + predictor = SAM2ImagePredictor(sam2_model) + self.predictors[device] = predictor + + slice_idx = request.get("slice") + if slice_idx is None or slice_idx < 0: + slices = {p[2] for p in request["foreground"] if len(p) > 2} + slices.update({p[2] for p in request["background"] if len(p) > 2}) + slices = list(slices) + slice_idx = slices[0] if len(slices) else -1 + else: + slices = {slice_idx} + + if slice_idx < 0 and len(request["roi"]) == 6: + slice_idx = round(request["roi"][4] + (request["roi"][5] - request["roi"][4]) // 2) + slices = {slice_idx} + logger.info(f"Slices: {slices}; Slice Index: {slice_idx}") + + if slice_idx < 0: + slice_np = image_tensor.cpu().numpy() + slice_rgb_np = slice_np.astype(np.uint8) if np.max(slice_np) > 1 else img_as_ubyte(slice_np) + else: + slice_np = image_tensor[:, :, slice_idx].cpu().numpy() + + if strtobool(request.get("pylab")): + slice_rgb_file = tempfile.NamedTemporaryFile(suffix=".jpg").name + pylab.imsave(slice_rgb_file, slice_np, format="jpg", cmap="Greys_r") + slice_rgb_np = np.array(Image.open(slice_rgb_file)) + remove_file(slice_rgb_file) + else: + slice_rgb_np = np.array(Image.fromarray(slice_np).convert("RGB")) + + logger.info(f"Slice Index:{slice_idx}; (Image) Slice Shape: {slice_np.shape}") + if debug: + logger.info(f"Slice {slice_np.shape} Type: {slice_np.dtype}; Max: {np.max(slice_np)}") + logger.info(f"Slice RGB {slice_rgb_np.shape} Type: {slice_rgb_np.dtype}; Max: {np.max(slice_rgb_np)}") + if slice_idx < 0 and image_tensor.meta.get("filename_or_obj"): + shutil.copy(image_tensor.meta["filename_or_obj"], "image.jpg") + else: + pylab.imsave("image.jpg", slice_np, format="jpg", cmap="Greys_r") + Image.fromarray(slice_rgb_np).save("slice.jpg") + + predictor.reset_predictor() + predictor.set_image(slice_rgb_np) + + location = request.get("location", (0, 0)) + tx, ty = location[0], location[1] + fp = [[p[0] - tx, p[1] - ty] for p in request["foreground"]] + bp = [[p[0] - tx, p[1] - ty] for p in request["background"]] + roi = request.get("roi") + roi = [roi[0] - tx, roi[1] - ty, roi[2] - tx, roi[3] - ty] if roi else None + + if debug: + slice_rgb_np_p = np.copy(slice_rgb_np) + if roi: + slice_rgb_np_p[roi[0] : roi[2], roi[1] : roi[3], 2] = 255 + for k, ps in {1: fp, 0: bp}.items(): + for p in ps: + slice_rgb_np_p[p[0] - 2 : p[0] + 2, p[1] - 2 : p[1] + 2, k] = 255 + Image.fromarray(slice_rgb_np_p).save("slice_p.jpg") + + point_coords = fp + bp + point_coords = [[p[1], p[0]] for p in point_coords] # Flip x,y => y,x + box = [roi[1], roi[0], roi[3], roi[2]] if roi else None + + point_labels = [1] * len(fp) + [0] * len(bp) + logger.info(f"Coords: {point_coords}; Labels: {point_labels}; Box: {box}") + + masks, scores, _ = predictor.predict( + point_coords=np.array(point_coords) if point_coords else None, + point_labels=np.array(point_labels) if point_labels else None, + multimask_output=False, + box=np.array(box) if box else None, + ) + # sorted_ind = np.argsort(scores)[::-1] + # masks = masks[sorted_ind] + # scores = scores[sorted_ind] + if strtobool(request.get("largest_cc", False)): + masks = KeepLargestConnectedComponent()(masks).cpu().numpy() + + logger.info(f"Masks Shape: {masks.shape}; Scores: {scores}") + if self.post_trans is None: + if slice_idx < 0: + pred = masks[0] + else: + pred = np.zeros(tuple(image_tensor.shape)) + pred[:, :, slice_idx] = masks[0] + + data = copy.copy(request) + data.update({"image_path": request["image"], "pred": pred, "image": image_tensor}) + else: + data = copy.copy(request) + data.update({"image_path": request["image"], "pred": masks[0], "image": image_tensor}) + data = run_transforms(data, self.post_trans, log_prefix="POST", use_compose=False) + + if debug: + # pylab.imsave("mask.jpg", masks[0], format="jpg", cmap="Greys_r") + Image.fromarray(masks[0] > 0).save("mask.jpg") + + return self.writer(data) + + def run_3d(self, image_tensor, set_image_state, request, debug=False): + device = name_to_device(request.get("device", "cuda")) + reset_state = strtobool(request.get("reset_state", "false")) + predictor = self.predictors.get(device) + if predictor is None: + logger.info(f"Using Device: {device}") + device_t = torch.device(device) + if device_t.type == "cuda": + torch.autocast("cuda", dtype=torch.bfloat16).__enter__() + if torch.cuda.get_device_properties(0).major >= 8: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + predictor = build_sam2_video_predictor(self.config_path, self.path, device=device) + self.predictors[device] = predictor + + image_path = request["image"] + video_dir = os.path.join( + image_cache.cache_path, get_basename_no_ext(image_path) if debug else md5_digest(image_path) + ) + if not os.path.isdir(video_dir): + os.makedirs(video_dir, exist_ok=True) + for slice_idx in tqdm(range(image_tensor.shape[-1])): + slice_np = image_tensor[:, :, slice_idx].numpy() + slice_file = os.path.join(video_dir, f"{str(slice_idx).zfill(5)}.jpg") + + if strtobool(request.get("pylab")): + pylab.imsave(slice_file, slice_np, format="jpg", cmap="Greys_r") + else: + Image.fromarray(slice_np).convert("RGB").save(slice_file) + logger.info(f"Image (Flattened): {image_tensor.shape[-1]} slices; {video_dir}") + + # Set Expiry Time + image_cache.cached_dirs[video_dir] = time() + image_cache.cache_expiry_sec + + if reset_state or set_image_state: + if self.inference_state: + predictor.reset_state(self.inference_state) + self.inference_state = predictor.init_state(video_path=video_dir) + + logger.info(f"Image Shape: {image_tensor.shape}") + fps: dict[int, Any] = {} + bps: dict[int, Any] = {} + sids = set() + for key in {"foreground", "background"}: + for p in request[key]: + sid = p[2] + sids.add(sid) + kps = fps if key == "foreground" else bps + if kps.get(sid): + kps[sid].append([p[0], p[1]]) + else: + kps[sid] = [[p[0], p[1]]] + + box = None + roi = request.get("roi") + if roi: + box = [roi[1], roi[0], roi[3], roi[2]] + sids.update([i for i in range(roi[4], roi[5])]) + + pred = np.zeros(tuple(image_tensor.shape)) + for sid in sorted(sids): + fp = fps.get(sid, []) + bp = bps.get(sid, []) + + point_coords = fp + bp + point_coords = [[p[1], p[0]] for p in point_coords] # Flip x,y => y,x + point_labels = [1] * len(fp) + [0] * len(bp) + # logger.info(f"{sid} - Coords: {point_coords}; Labels: {point_labels}; Box: {box}") + + o_frame_ids, o_obj_ids, o_mask_logits = predictor.add_new_points_or_box( + inference_state=self.inference_state, + frame_idx=sid, + obj_id=1, + points=np.array(point_coords) if point_coords else None, + labels=np.array(point_labels) if point_labels else None, + box=np.array(box) if box else None, + ) + + # logger.info(f"{sid} - mask_logits: {o_mask_logits.shape}; frame_ids: {o_frame_ids}; obj_ids: {o_obj_ids}") + pred[:, :, sid] = (o_mask_logits[0][0] > 0.0).cpu().numpy() + + for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(self.inference_state): + # logger.info(f"propagate: {out_frame_idx} - mask_logits: {out_mask_logits.shape}; obj_ids: {out_obj_ids}") + pred[:, :, out_frame_idx] = (out_mask_logits[0][0] > 0.0).cpu().numpy() + + writer = Writer(ref_image="image") + data = copy.copy(request) + data.update({"image_path": request["image"], "pred": pred, "image": image_tensor}) + return writer(data) + + def __call__(self, request, debug=False) -> Tuple[Union[str, None], Dict]: + start_ts = time() + + logger.info(f"Infer Request: {request}") + image_path = request["image"] + image_tensor = self.image_cache.get(image_path) + set_image_state = False + cache_image = request.get("cache_image", True) + + if "foreground" not in request: + request["foreground"] = [] + if "background" not in request: + request["background"] = [] + if "roi" not in request: + request["roi"] = [] + + if not cache_image or image_tensor is None: + # TODO:: Fix this to cache more than one image session + self.image_cache.clear() + image_tensor = self.image_loader(request)["image"] + if debug: + logger.info(f"Image Meta: {image_tensor.meta}") + self.image_cache[image_path] = image_tensor + set_image_state = True + + logger.info(f"Image Shape: {image_tensor.shape}; cached: {cache_image}") + if self.dimension == 2: + mask_file, result_json = self.run2d(image_tensor, request, debug) + else: + mask_file, result_json = self.run_3d(image_tensor, set_image_state, request) + + logger.info(f"Mask File: {mask_file}; Latency: {round(time() - start_ts, 4)} sec") + result_json["latencies"] = { + "pre": 0, + "infer": 0, + "invert": 0, + "post": 0, + "write": 0, + "total": round(time() - start_ts, 2), + "transform": None, + } + return mask_file, result_json + + +""" +def main(): + import shutil + + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] [%(process)s] [%(threadName)s] [%(levelname)s] (%(name)s:%(lineno)d) - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + force=True, + ) + + app_name = "pathology" + app_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "sample-apps", app_name)) + model_dir = os.path.join(app_dir, "model") + logger.info(f"Model Dir: {model_dir}") + if app_name == "pathology": + from lib.transforms import LoadImagePatchd + + from monailabel.transform.post import FindContoursd + from monailabel.transform.writer import PolygonWriter + + task = Sam2InferTask( + model_dir=model_dir, + dimension=2, + additional_info={"nuclick": True, "pathology": True}, + image_loader=LoadImagePatchd(keys="image", padding=False), + post_trans=[FindContoursd(keys="pred")], + writer=PolygonWriter(), + ) + request = { + "device": "cuda:1", + "reset_state": False, + "model": "sam2", + "image": "/home/sachi/Datasets/wsi/JP2K-33003-1.svs", + "output": "asap", + "level": 0, + "location": (2183, 4873), + "size": (128, 128), + "tile_size": [128, 128], + "min_poly_area": 30, + "foreground": [[2247, 4937]], + "background": [], + # "roi": [2220, 4900, 2320, 5000], + "max_workers": 1, + "id": 0, + "logging": "INFO", + "result_write_to_file": False, + "description": "SAM2 (Segment Anything Model)", + "save_label": False, + } + else: + task = Sam2InferTask(model_dir) + request = { + "image": "/home/sachi/Datasets/SAM2/image.nii.gz", + "foreground": [[71, 175, 105]], # [199, 129, 47], [200, 100, 41]], + # "background": [[286, 175, 105]], + "roi": [44, 110, 113, 239, 72, 178], + "largest_cc": True, + } + + result = task(request, debug=True) + if app_name == "pathology": + print(result) + else: + shutil.move(result[0], "mask.nii.gz") + + +if __name__ == "__main__": + main() +""" diff --git a/monailabel/sam2/utils.py b/monailabel/sam2/utils.py new file mode 100644 index 000000000..8abaef244 --- /dev/null +++ b/monailabel/sam2/utils.py @@ -0,0 +1,9 @@ +from monai.utils import optional_import + + +def is_sam2_module_available(): + try: + _, flag = optional_import("sam2") + return flag + except ImportError: + return False diff --git a/plugins/cvat/README.md b/plugins/cvat/README.md index 84aa63aa9..735ad9ea5 100644 --- a/plugins/cvat/README.md +++ b/plugins/cvat/README.md @@ -57,8 +57,9 @@ chmod +x nuctl-$NUCLIO_VERSION-linux-amd64 ln -sf $(pwd)/nuctl-$NUCLIO_VERSION-linux-amd64 /usr/local/bin/nuctl ``` -#### Deployment of Endoscopy Models +#### Deployment of Endoscopy/SAM2 Models This step is to deploy MONAI Label plugin with endoscopic models using Nuclio tool. +> **Prerequisite:** MONAI Label Server is up and running for _**endoscopy**_ app. ```bash # Run MONAI Label Server (Make sure this Host/IP is accessible inside a docker) @@ -68,8 +69,13 @@ git clone https://github.com/Project-MONAI/MONAILabel.git # Deploy all endoscopy models ./plugins/cvat/deploy.sh endoscopy + # Or to deploy specific function and model, e.g., tooltracking ./plugins/cvat/deploy.sh endoscopy tooltracking + +# Deploy SAM2 Interactor +./plugins/cvat/deploy.sh sam2 interactor + ``` After model deployment, you can see the model names in the `Models` page of CVAT. diff --git a/plugins/cvat/deploy.sh b/plugins/cvat/deploy.sh index 4435e8bc8..6b9f5c243 100755 --- a/plugins/cvat/deploy.sh +++ b/plugins/cvat/deploy.sh @@ -30,9 +30,9 @@ do echo "Using MONAI Label Server: $MONAI_LABEL_SERVER" cp $func_config ${func_config}.bak sed -i "s|http://monailabel.com|$MONAI_LABEL_SERVER|g" $func_config - mv ${func_config}.bak $func_config echo "Deploying $func_config..." nuctl deploy --project-name cvat --path "$func_root" --file "$func_config" --platform local + mv ${func_config}.bak $func_config done nuctl get function diff --git a/plugins/cvat/detector.py b/plugins/cvat/detector.py index 86473febc..e099a31c5 100644 --- a/plugins/cvat/detector.py +++ b/plugins/cvat/detector.py @@ -105,6 +105,7 @@ def handler(context, event): } ) + context.logger.info("=============================================================================\n") return context.Response( body=json.dumps(results), headers={}, diff --git a/plugins/cvat/interactor.py b/plugins/cvat/interactor.py index 185e7f9cc..19ac8d636 100644 --- a/plugins/cvat/interactor.py +++ b/plugins/cvat/interactor.py @@ -53,6 +53,7 @@ def handler(context, event): image = Image.open(io.BytesIO(base64.b64decode(data["image"]))) foreground = data.get("pos_points") background = data.get("neg_points") + roi = data.get("obj_bbox", None) context.logger.info(f"Image: {image.size}; Foreground: {foreground}; Background: {background}") image_file = tempfile.NamedTemporaryFile(suffix=".jpg").name @@ -64,6 +65,11 @@ def handler(context, event): "background": np.asarray(background, dtype=int).tolist() if background else [], # "largest_cc": True, } + if roi: + roi = np.asarray(roi, dtype=int).flatten().tolist() + params["roi"] = roi + + context.logger.info(f"Model:{model}; Params: {params}") output_mask, output_json = client.infer(model=model, image_id="", file=image_file, params=params) if isinstance(output_json, str) or isinstance(output_json, bytes): output_json = json.loads(output_json) @@ -76,6 +82,8 @@ def handler(context, event): resp = {"mask": mask_np.tolist()} context.logger.info(f"Image: {image.size}; Mask: {mask_im.size} vs {mask_np.shape}; JSON: {output_json}") + + context.logger.info("=============================================================================\n") return context.Response( body=json.dumps(resp), headers={}, diff --git a/plugins/cvat/pathology/deepedit_nuclei.yaml b/plugins/cvat/pathology/deepedit_nuclei.yaml index 1221bba7e..3084c67e2 100644 --- a/plugins/cvat/pathology/deepedit_nuclei.yaml +++ b/plugins/cvat/pathology/deepedit_nuclei.yaml @@ -34,7 +34,7 @@ spec: directives: preCopy: - kind: ENV - value: MONAI_LABEL_SERVER=http://10.117.9.63:8000 + value: MONAI_LABEL_SERVER=http://monailabel.com - kind: ENV value: MONAI_LABEL_MODEL=deepedit_nuclei diff --git a/plugins/cvat/pathology/nuclick.yaml b/plugins/cvat/pathology/nuclick.yaml index 2084f94bf..82976e241 100644 --- a/plugins/cvat/pathology/nuclick.yaml +++ b/plugins/cvat/pathology/nuclick.yaml @@ -36,7 +36,7 @@ spec: directives: preCopy: - kind: ENV - value: MONAI_LABEL_SERVER=http://10.117.9.63:8000 + value: MONAI_LABEL_SERVER=http://monailabel.com - kind: ENV value: MONAI_LABEL_MODEL=nuclick diff --git a/plugins/cvat/pathology/segmentation_nuclei.yaml b/plugins/cvat/pathology/segmentation_nuclei.yaml index 796f29178..f42a8a661 100644 --- a/plugins/cvat/pathology/segmentation_nuclei.yaml +++ b/plugins/cvat/pathology/segmentation_nuclei.yaml @@ -38,7 +38,7 @@ spec: directives: preCopy: - kind: ENV - value: MONAI_LABEL_SERVER=http://10.117.9.63:8000 + value: MONAI_LABEL_SERVER=http://monailabel.com - kind: ENV value: MONAI_LABEL_MODEL=segmentation_nuclei diff --git a/plugins/cvat/sam2/interactor.yaml b/plugins/cvat/sam2/interactor.yaml new file mode 100644 index 000000000..6787c300e --- /dev/null +++ b/plugins/cvat/sam2/interactor.yaml @@ -0,0 +1,56 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +metadata: + name: monailabel.sam2.interactor + namespace: cvat + annotations: + name: SAM2 + version: 2 + type: interactor + spec: + min_pos_points: 0 + min_neg_points: 0 + startswith_box_optional: true + help_message: The interactor allows to annotate a Tool using SAM2 model + +spec: + description: A pre-trained SAM2 model for interactive model + runtime: 'python:3.8' + handler: interactor:handler + eventTimeout: 30s + + build: + image: cvat/monailabel.sam2.interactor + baseImage: projectmonai/monailabel:latest + + directives: + preCopy: + - kind: ENV + value: MONAI_LABEL_SERVER=http://monailabel.com + - kind: ENV + value: MONAI_LABEL_MODEL=sam_2d + + triggers: + myHttpTrigger: + maxWorkers: 1 + kind: 'http' + workerAvailabilityTimeoutMilliseconds: 10000 + attributes: + maxRequestBodySize: 33554432 # 32MB + + platform: + attributes: + restartPolicy: + name: always + maximumRetryCount: 1 + mountMode: volume + network: cvat_cvat diff --git a/plugins/cvat/sam2/tracker.yaml b/plugins/cvat/sam2/tracker.yaml new file mode 100644 index 000000000..c6842f9bf --- /dev/null +++ b/plugins/cvat/sam2/tracker.yaml @@ -0,0 +1,51 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +metadata: + name: monailabel.sam2.tracker + namespace: cvat + annotations: + name: SAM2T + type: tracker + spec: + +spec: + description: A pre-trained SAM2 model for tracking model + runtime: 'python:3.8' + handler: tracker:handler + eventTimeout: 30s + + build: + image: cvat/monailabel.sam2.tracker + baseImage: projectmonai/monailabel:latest + + directives: + preCopy: + - kind: ENV + value: MONAI_LABEL_SERVER=http://monailabel.com + - kind: ENV + value: MONAI_LABEL_MODEL=sam_2d + + triggers: + myHttpTrigger: + maxWorkers: 1 + kind: 'http' + workerAvailabilityTimeoutMilliseconds: 10000 + attributes: + maxRequestBodySize: 33554432 # 32MB + + platform: + attributes: + restartPolicy: + name: always + maximumRetryCount: 1 + mountMode: volume + network: cvat_cvat diff --git a/plugins/cvat/tracker.py b/plugins/cvat/tracker.py index 11f5a3ba4..60e75d4c4 100644 --- a/plugins/cvat/tracker.py +++ b/plugins/cvat/tracker.py @@ -49,6 +49,7 @@ def handler(context, event): model: str = context.user_data.model client: MONAILabelClient = context.user_data.model_handler context.logger.info(f"Run model: {model}") + # TODO:: This is not really a tracker; Need to accumulate previous images + rois and do actual SAM2 Propagation. data = event.body image = Image.open(io.BytesIO(base64.b64decode(data["image"]))) @@ -59,47 +60,48 @@ def handler(context, event): image.save(image_file) shapes = data.get("shapes") - context.logger.info(f"Shapes: {shapes}") - states = data.get("states") - context.logger.info(f"States: {states}") + context.logger.info(f"Shapes: {shapes}; States: {states}") - results = {"shapes": [], "states": []} - bboxes = [] + rois = [] for i, shape in enumerate(shapes): - context.logger.info(f"{i} => Shape: {shape}") + roi = np.array(shape).astype(int).tolist() + context.logger.info(f"{i} => Shape: {shape}; roi: {roi}") + rois.append(roi) - def bounding_box(pts): - x, y = zip(*pts) - return [min(x), min(y), max(x), max(y)] - - bbox = bounding_box(np.array(shape).astype(int).reshape(-1, 2).tolist()) - context.logger.info(f"bbox: {bbox}") - bboxes.append(bbox) + roi = rois[-1] # Pick the last + params = {"output": "json", "roi": roi} - bbox = bboxes[-1] # Pick the last - params = {"output": "json", "largest_cc": True, "bbox": bbox} + # context.logger.info(f"Model:{model}; Params: {params}") output_mask, output_json = client.infer(model=model, image_id="", file=image_file, params=params) if isinstance(output_json, str) or isinstance(output_json, bytes): output_json = json.loads(output_json) - context.logger.info(f"Mask: {output_mask}; Output JSON: {output_json}") + # context.logger.info(f"Mask: {output_mask}; Output JSON: {output_json}") mask_np = np.array(Image.open(output_mask)).astype(np.uint8) os.remove(output_mask) os.remove(image_file) - context.logger.info(f"Image: {image.size}; Mask: {mask_np.shape}; JSON: {output_json}") + results = {"shapes": [], "states": []} d = FindContoursd(keys="pred")({"pred": mask_np}) annotation = d.get("result", {}).get("annotation") for element in annotation.get("elements", []): contours = element["contours"] + all_points = [] for contour in contours: points = np.flip(np.array(contour, int)) - shape = points.flatten().tolist() - results["shapes"].append(shape) - break + all_points.append(points.flatten().tolist()) + + def bounding_box(pts): + x, y = zip(*pts) + return [min(x), min(y), max(x), max(y)] + + bbox = bounding_box(np.array(all_points).astype(int).reshape(-1, 2).tolist()) + context.logger.info(f"Input Box: {roi}; Output Box: {bbox}") + results["shapes"].append(bbox) + context.logger.info("=============================================================================\n") return context.Response( body=json.dumps(results), headers={}, diff --git a/plugins/ohifv3/build.sh b/plugins/ohifv3/build.sh index 8c4679a6b..5c0e53aec 100755 --- a/plugins/ohifv3/build.sh +++ b/plugins/ohifv3/build.sh @@ -50,6 +50,9 @@ APP_CONFIG=config/monai_label.js PUBLIC_URL=/ohif/ QUICK_BUILD=true yarn run bui rm -rf ${install_dir} cp -r platform/app/dist/ ${install_dir} echo "Copied OHIF to ${install_dir}" -rm -rf ../Viewers + +cd .. +rm -rf Viewers +find . -type d -name "node_modules" -exec rm -rf "{}" + cd ${curr_dir} diff --git a/plugins/slicer/MONAILabel/MONAILabel.py b/plugins/slicer/MONAILabel/MONAILabel.py index 1b32c91c0..d018631cd 100644 --- a/plugins/slicer/MONAILabel/MONAILabel.py +++ b/plugins/slicer/MONAILabel/MONAILabel.py @@ -223,6 +223,7 @@ def __init__(self, parent=None): self._volumeNode = None self._segmentNode = None self._scribblesROINode = None + self._sam2ROINode = None self._volumeNodes = [] self._updatingGUIFromParameterNode = False @@ -254,6 +255,7 @@ def __init__(self, parent=None): self.scribblesMode = None self.ignoreScribblesLabelChangeEvent = False self.deepedit_multi_label = False + self.resetLabelState = False self.optionsSectionIndex = 0 self.optionsNameIndex = 0 @@ -308,6 +310,13 @@ def setup(self): self.ui.dgNegativeControlPointPlacementWidget.placeButton().show() self.ui.dgNegativeControlPointPlacementWidget.deleteButton().show() + # ROI placement for SAM2 + self.ui.sam2PlaceWidget.setMRMLScene(slicer.mrmlScene) + self.ui.sam2PlaceWidget.placeButton().toolTip = _("ROI/BBOX Prompt") + self.ui.sam2PlaceWidget.buttonsVisible = False + self.ui.sam2PlaceWidget.placeButton().show() + self.ui.sam2PlaceWidget.deleteButton().show() + self.ui.dgUpdateButton.setIcon(self.icon("segment.png")) # Connections @@ -325,7 +334,7 @@ def setup(self): self.ui.labelComboBox.connect("currentIndexChanged(int)", self.onSelectLabel) self.ui.scribLabelComboBox.connect("currentIndexChanged(int)", self.onSelectScribLabel) self.ui.dgUpdateButton.connect("clicked(bool)", self.onUpdateDeepgrow) - self.ui.dgUpdateCheckBox.setStyleSheet("padding-left: 10px;") + self.ui.dgUpdateCheckBox.setStyleSheet("padding-left: 10px; padding-top: 10px;") self.ui.optionsSection.connect("currentIndexChanged(int)", self.onSelectOptionsSection) self.ui.optionsName.connect("currentIndexChanged(int)", self.onSelectOptionsName) @@ -420,6 +429,7 @@ def onSceneStartClose(self, caller, event): self.current_sample = None self.samples.clear() self._scribblesROINode = None + self._sam2ROINode = None self.resetPointList( self.ui.dgPositiveControlPointPlacementWidget, @@ -434,6 +444,7 @@ def onSceneStartClose(self, caller, event): ) self.dgNegativePointListNode = None self.onResetScribbles() + self.resetSam2ROI() def resetPointList(self, markupsPlaceWidget, pointListNode, pointListNodeObservers): if markupsPlaceWidget.placeModeEnabled: @@ -712,6 +723,10 @@ def updateParameterNodeFromGUI(self, caller=None, event=None): currentLabelIndex = self.ui.labelComboBox.currentIndex if currentLabelIndex >= 0: currentLabel = self.ui.labelComboBox.itemText(currentLabelIndex) + oldLabel = self._parameterNode.GetParameter("CurrentLabel") + if oldLabel and oldLabel != currentLabel: + print(f"Old Label: {oldLabel}; New Label: {currentLabel}") + self.resetLabelState = True self._parameterNode.SetParameter("CurrentLabel", currentLabel) currentScribLabelIndex = self.ui.scribLabelComboBox.currentIndex @@ -1237,6 +1252,7 @@ def onNextSampleButton(self): ): return self.onResetScribbles() + self.resetSam2ROI() slicer.mrmlScene.Clear(0) start = time.time() @@ -1338,6 +1354,9 @@ def initSample(self, sample, autosegment=True): self.createScribblesROINode() self.ui.scribblesPlaceWidget.setCurrentNode(self._scribblesROINode) + self.createSam2ROINode() + self.ui.sam2PlaceWidget.setCurrentNode(self._sam2ROINode) + # check if user allows overlapping segments if slicer.util.settingsValue("MONAILabel/allowOverlappingSegments", False, converter=slicer.util.toBool): # set segment editor to allow overlaps @@ -1435,6 +1454,7 @@ def onSaveLabel(self): labelmapVolumeNode = None result = None self.onResetScribbles() + self.resetSam2ROI() if self.current_sample.get("session"): if not self.onUploadImage(init_sample=False): @@ -1571,12 +1591,12 @@ def onUpdateDeepgrow(self): def onClickDeepgrow(self, current_point, skip_infer=False): model = self.ui.deepgrowModelSelector.currentText if not model: - slicer.util.warningDisplay(_("Please select a deepgrow model")) + slicer.util.warningDisplay(_("Please select a model")) return segment_id, segment = self.currentSegment() if not segment: - slicer.util.warningDisplay(_("Please add the required label to run deepgrow")) + slicer.util.warningDisplay(_("Please add the required label to run interactive model")) return foreground_all = self.getControlPointsXYZ(self.dgPositivePointListNode, "foreground") @@ -1587,22 +1607,39 @@ def onClickDeepgrow(self, current_point, skip_infer=False): if skip_infer: return - # use model info "deepgrow" to determine - deepgrow_3d = False if self.models[model].get("dimension", 3) == 2 else True - print(f"Is DeepGrow 3D: {deepgrow_3d}") + # use model info to determine + is_3d = False if self.models[model].get("dimension", 3) == 2 else True + print(f"Is 3D Model: {is_3d}") start = time.time() label = segment.GetName() - operationDescription = _("Run Deepgrow for segment: {label}; model: {model}; 3d {in3d}").format( - label=label, model=model, in3d=_("enabled") if deepgrow_3d else _("disabled") + operationDescription = _("Run Inference for segment: {label}; model: {model}; 3d {in3d}").format( + label=label, model=model, in3d=_("enabled") if is_3d else _("disabled") ) logging.debug(operationDescription) + # try to get roi if placed + roiNode = self.ui.sam2PlaceWidget.currentNode() + selected_roi = [] + if roiNode and roiNode.GetControlPointPlacementComplete(): + selected_roi = self.getROIPointsXYZ(roiNode) + print(f"Selected ROI: {selected_roi} => {not selected_roi}") + if not current_point: - if not foreground_all and not deepgrow_3d: - slicer.util.warningDisplay(operationDescription + " - points not added") + if not foreground_all and not is_3d and not selected_roi: + slicer.util.warningDisplay(operationDescription + " - points/roi not added") return - current_point = foreground_all[-1] if foreground_all else background_all[-1] if background_all else None + + if not is_3d and selected_roi: + layoutManager = slicer.app.layoutManager() + current_point = [ + selected_roi[1] + (selected_roi[1] - selected_roi[0]) // 2, + selected_roi[3] + (selected_roi[3] - selected_roi[2]) // 2, + round(abs(layoutManager.sliceWidget("Red").sliceLogic().GetSliceOffset())), + ] + else: + current_point = foreground_all[-1] if foreground_all else background_all[-1] if background_all else None + print(f"(updated) Current Point: {current_point}") try: qt.QApplication.setOverrideCursor(qt.Qt.WaitCursor) @@ -1630,7 +1667,7 @@ def onClickDeepgrow(self, current_point, skip_infer=False): sliceIndex = current_point[2] if current_point else None print(f"Slice Index: {sliceIndex}") - if deepgrow_3d or not sliceIndex: + if is_3d or not sliceIndex: foreground = foreground_all background = background_all else: @@ -1650,11 +1687,19 @@ def onClickDeepgrow(self, current_point, skip_infer=False): params["label"] = label params.update(self.getParamsFromConfig("infer", model)) - print(f"Request Params for Deepgrow/Deepedit: {params}") + if selected_roi: + params["roi"] = selected_roi + if not is_3d: + params["slice"] = sliceIndex + if not params.get("reset_state") and self.resetLabelState: + print(f"Override State: {params.get('reset_state')} vs {self.resetLabelState}") + params["reset_state"] = True + self.resetLabelState = False + print(f"Request Params for Inference: {params}") image_file = self.current_sample["id"] result_file, params = self.logic.infer(model, image_file, params, session_id=self.getSessionId()) - print(f"Result Params for Deepgrow/Deepedit: {params}") + print(f"Result Params for Inference: {params}") if labels is None: labels = ( params.get("label_names") @@ -1665,7 +1710,7 @@ def onClickDeepgrow(self, current_point, skip_infer=False): labels = [k for k, _ in sorted(labels.items(), key=lambda item: item[1])] freeze = label if self.ui.freezeUpdateCheckBox.checked else None - self.updateSegmentationMask(result_file, labels, None if deepgrow_3d else sliceIndex, freeze=freeze) + self.updateSegmentationMask(result_file, labels, None if is_3d else sliceIndex, freeze=freeze) except BaseException as e: msg = f"Message: {e.msg}" if hasattr(e, "msg") else "" slicer.util.errorDisplay( @@ -1703,6 +1748,19 @@ def createScribblesROINode(self): scribblesROINode.GetDisplayNode().SetActiveColor(1, 1, 1) self._scribblesROINode = scribblesROINode + def createSam2ROINode(self): + if self._volumeNode is None: + return + if self._sam2ROINode is None: + sam2ROINode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsROINode") + sam2ROINode.SetName("SAM2 ROI") + sam2ROINode.CreateDefaultDisplayNodes() + sam2ROINode.GetDisplayNode().SetFillOpacity(0.4) + sam2ROINode.GetDisplayNode().SetSelectedColor(1, 1, 1) + sam2ROINode.GetDisplayNode().SetColor(1, 1, 1) + sam2ROINode.GetDisplayNode().SetActiveColor(1, 1, 1) + self._sam2ROINode = sam2ROINode + def getLabelColor(self, name): color = GenericAnatomyColors.get(name.lower()) return [c / 255.0 for c in color] if color else None @@ -2050,6 +2108,10 @@ def resetScribblesROI(self): if self._scribblesROINode: self._scribblesROINode.RemoveAllControlPoints() + def resetSam2ROI(self): + if self._sam2ROINode: + self._sam2ROINode.RemoveAllControlPoints() + def onClearScribbles(self): # for clearing scribbles and resetting tools to default # remove "scribbles" segments from label diff --git a/plugins/slicer/MONAILabel/Resources/UI/MONAILabel.ui b/plugins/slicer/MONAILabel/Resources/UI/MONAILabel.ui index 64be18207..9d0319936 100644 --- a/plugins/slicer/MONAILabel/Resources/UI/MONAILabel.ui +++ b/plugins/slicer/MONAILabel/Resources/UI/MONAILabel.ui @@ -386,7 +386,7 @@ - SmartEdit / Deepgrow + SmartEdit true @@ -410,7 +410,7 @@ - + false @@ -476,6 +476,16 @@ + + + + Auto + + + true + + + @@ -483,7 +493,7 @@ - + @@ -518,18 +528,28 @@ + + + + ROI: + + + + + + + qSlicerMarkupsPlaceWidget::ForcePlaceSingleMarkup + + + + 0 + 0 + + + + - - - - Auto - - - true - - - diff --git a/requirements.txt b/requirements.txt index 040d0005a..3583645c2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -44,6 +44,7 @@ urllib3==2.2.2 scikit-learn scipy google-auth==2.29.0 +SAM-2 @ git+https://github.com/facebookresearch/sam2.git@c2ec8e14a185632b0a5d8b161928ceb50197eddc ; python_version >= '3.10' # scipy and scikit-learn latest packages are missing on python 3.8 diff --git a/runtests.sh b/runtests.sh index 9d572549c..11667af7d 100755 --- a/runtests.sh +++ b/runtests.sh @@ -140,6 +140,7 @@ function clean_py() { find sample-apps -type d -name "__pycache__" -exec rm -rf "{}" + find monailabel -type d -name "__pycache__" -exec rm -rf "{}" + + find plugins -type d -name "node_modules" -exec rm -rf "{}" + } function torch_validate() { diff --git a/sample-apps/README.md b/sample-apps/README.md index 916c00f55..5b949fb56 100644 --- a/sample-apps/README.md +++ b/sample-apps/README.md @@ -54,6 +54,14 @@ The endoscopy template includes example models for interactive and automated too #### [MONAI Bundle](./monaibundle) The MONAI Bundle format provides a portable description of deep learning models. This template includes example models for interactive and automated segmentation using MONAI bundles defined in the MONAI Model Zoo. It can pull any bundle defined in the MONAI Model Zoo that is compatible and meets the requirements specified on the [MONAI Bundle Apps page](./monaibundle/). +---- +> [**SAM2**](https://github.com/facebookresearch/sam2/) +> +> By default, SAM2 is included for all the above Apps only when **_python >= 3.10_** +> - **sam_2d**: for any organ or tissue and others over a given slice/2D image. +> - **sam_3d**: to support SAM2 propagation over multiple slices (Radiology/MONAI-Bundle). + + ### Creating a Custom App Researchers may want to define and add their own models. Follow the steps below to add a new segmentation model: diff --git a/sample-apps/endoscopy/main.py b/sample-apps/endoscopy/main.py index b5ad3191e..cd38984c2 100644 --- a/sample-apps/endoscopy/main.py +++ b/sample-apps/endoscopy/main.py @@ -15,7 +15,9 @@ from typing import Dict import lib.configs +import numpy as np import schedule +from monai.transforms import ToNumpyd from timeloop import Timeloop import monailabel @@ -24,10 +26,11 @@ from monailabel.interfaces.app import MONAILabelApp from monailabel.interfaces.config import TaskConfig from monailabel.interfaces.datastore import Datastore -from monailabel.interfaces.tasks.infer_v2 import InferTask +from monailabel.interfaces.tasks.infer_v2 import InferTask, InferType from monailabel.interfaces.tasks.scoring import ScoringMethod from monailabel.interfaces.tasks.strategy import Strategy from monailabel.interfaces.tasks.train import TrainTask +from monailabel.sam2.utils import is_sam2_module_available from monailabel.tasks.activelearning.random import Random from monailabel.utils.others.class_utils import get_class_names from monailabel.utils.others.generic import create_dataset_from_path, strtobool @@ -82,6 +85,7 @@ def __init__(self, app_dir, studies, conf): logger.info(f"+++ Using Models: {list(self.models.keys())}") + self.sam = strtobool(conf.get("sam", "true")) super().__init__( app_dir=app_dir, studies=studies, @@ -123,6 +127,21 @@ def init_infers(self) -> Dict[str, InferTask]: for k, v in c.items(): logger.info(f"+++ Adding Inferer:: {k} => {v}") infers[k] = v + + ################################################# + # SAM + ################################################# + if is_sam2_module_available() and self.sam: + from monailabel.sam2.infer import Sam2InferTask + + infers["sam_2d"] = Sam2InferTask( + model_dir=self.model_dir, + type=InferType.ANNOTATION, + dimension=2, + post_trans=[ToNumpyd(keys="pred", dtype=np.uint8)], + config={"cache_image": False, "reset_state": True}, + ) + return infers def init_trainers(self) -> Dict[str, TrainTask]: diff --git a/sample-apps/monaibundle/main.py b/sample-apps/monaibundle/main.py index 55c8967d0..c698d402a 100644 --- a/sample-apps/monaibundle/main.py +++ b/sample-apps/monaibundle/main.py @@ -17,10 +17,11 @@ import monailabel from monailabel.interfaces.app import MONAILabelApp -from monailabel.interfaces.tasks.infer_v2 import InferTask +from monailabel.interfaces.tasks.infer_v2 import InferTask, InferType from monailabel.interfaces.tasks.scoring import ScoringMethod from monailabel.interfaces.tasks.strategy import Strategy from monailabel.interfaces.tasks.train import TrainTask +from monailabel.sam2.utils import is_sam2_module_available from monailabel.tasks.activelearning.first import First from monailabel.tasks.activelearning.random import Random from monailabel.tasks.infer.bundle import BundleInferTask @@ -33,6 +34,7 @@ class MyApp(MONAILabelApp): def __init__(self, app_dir, studies, conf): + self.model_dir = os.path.join(app_dir, "model") self.models = get_bundle_models(app_dir, conf) # Add Epistemic model for scoring self.epistemic_models = ( @@ -44,6 +46,7 @@ def __init__(self, app_dir, studies, conf): self.epistemic_simulation_size = int(conf.get("epistemic_simulation_size", "5")) self.epistemic_dropout = float(conf.get("epistemic_dropout", "0.2")) + self.sam = strtobool(conf.get("sam", "true")) super().__init__( app_dir=app_dir, studies=studies, @@ -74,6 +77,14 @@ def init_infers(self) -> Dict[str, InferTask]: logger.info(f"+++ Adding Inferer:: {n} => {i}") infers[n] = i + ################################################# + # SAM + ################################################# + if is_sam2_module_available() and self.sam: + from monailabel.sam2.infer import Sam2InferTask + + infers["sam_2d"] = Sam2InferTask(model_dir=self.model_dir, type=InferType.DEEPGROW, dimension=2) + infers["sam_3d"] = Sam2InferTask(model_dir=self.model_dir, type=InferType.DEEPGROW, dimension=3) return infers def init_trainers(self) -> Dict[str, TrainTask]: diff --git a/sample-apps/pathology/main.py b/sample-apps/pathology/main.py index c1a04316c..095791792 100644 --- a/sample-apps/pathology/main.py +++ b/sample-apps/pathology/main.py @@ -18,16 +18,20 @@ import lib.configs from lib.activelearning.random import WSIRandom from lib.infers import NuClick +from lib.transforms import LoadImagePatchd import monailabel from monailabel.datastore.dsa import DSADatastore from monailabel.interfaces.app import MONAILabelApp from monailabel.interfaces.config import TaskConfig from monailabel.interfaces.datastore import Datastore -from monailabel.interfaces.tasks.infer_v2 import InferTask +from monailabel.interfaces.tasks.infer_v2 import InferTask, InferType from monailabel.interfaces.tasks.strategy import Strategy from monailabel.interfaces.tasks.train import TrainTask +from monailabel.sam2.utils import is_sam2_module_available from monailabel.tasks.infer.basic_infer import BasicInferTask +from monailabel.transform.post import FindContoursd +from monailabel.transform.writer import PolygonWriter from monailabel.utils.others.class_utils import get_class_names from monailabel.utils.others.generic import strtobool @@ -86,6 +90,7 @@ def __init__(self, app_dir, studies, conf): logger.info(f"+++ Using Models: {list(self.models.keys())}") + self.sam = strtobool(conf.get("sam", "true")) super().__init__( app_dir=app_dir, studies=studies, @@ -138,6 +143,23 @@ def init_infers(self) -> Dict[str, InferTask]: if isinstance(p, NuClick) and isinstance(c, BasicInferTask): p.init_classification(c) + ################################################# + # SAM + ################################################# + if is_sam2_module_available() and self.sam: + from monailabel.sam2.infer import Sam2InferTask + + infers["sam_2d"] = Sam2InferTask( + model_dir=self.model_dir, + type=InferType.ANNOTATION, + dimension=2, + additional_info={"nuclick": True, "pathology": True}, + image_loader=LoadImagePatchd(keys="image", padding=False), + post_trans=[FindContoursd(keys="pred")], + writer=PolygonWriter(), + config={"cache_image": False, "reset_state": True}, + ) + return infers def init_trainers(self) -> Dict[str, TrainTask]: diff --git a/sample-apps/radiology/main.py b/sample-apps/radiology/main.py index abc3aaaa7..66ef3e4ad 100644 --- a/sample-apps/radiology/main.py +++ b/sample-apps/radiology/main.py @@ -23,10 +23,11 @@ from monailabel.interfaces.app import MONAILabelApp from monailabel.interfaces.config import TaskConfig from monailabel.interfaces.datastore import Datastore -from monailabel.interfaces.tasks.infer_v2 import InferTask +from monailabel.interfaces.tasks.infer_v2 import InferTask, InferType from monailabel.interfaces.tasks.scoring import ScoringMethod from monailabel.interfaces.tasks.strategy import Strategy from monailabel.interfaces.tasks.train import TrainTask +from monailabel.sam2.utils import is_sam2_module_available from monailabel.tasks.activelearning.first import First from monailabel.tasks.activelearning.random import Random @@ -99,6 +100,7 @@ def __init__(self, app_dir, studies, conf): # Load models from bundle config files, local or released in Model-Zoo, e.g., --conf bundles self.bundles = get_bundle_models(app_dir, conf, conf_key="bundles") if conf.get("bundles") else None + self.sam = strtobool(conf.get("sam", "true")) super().__init__( app_dir=app_dir, studies=studies, @@ -163,6 +165,15 @@ def init_infers(self) -> Dict[str, InferTask]: } ) + ################################################# + # SAM + ################################################# + if is_sam2_module_available() and self.sam: + from monailabel.sam2.infer import Sam2InferTask + + infers["sam_2d"] = Sam2InferTask(model_dir=self.model_dir, type=InferType.DEEPGROW, dimension=2) + infers["sam_3d"] = Sam2InferTask(model_dir=self.model_dir, type=InferType.DEEPGROW, dimension=3) + ################################################# # Pipeline based on existing infers ################################################# diff --git a/setup.cfg b/setup.cfg index 54c50b366..15e24022d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -70,6 +70,7 @@ install_requires = scikit-learn scipy google-auth>=2.29.0 + SAM-2 @ git+https://github.com/facebookresearch/sam2.git@c2ec8e14a185632b0a5d8b161928ceb50197eddc ; python_version >= '3.10' [flake8] select = B,C,E,F,N,P,T4,W,B9 diff --git a/tests/unit/endpoints/test_infer_v2.py b/tests/unit/endpoints/test_infer_v2.py index 5afbec7c8..23239e86a 100644 --- a/tests/unit/endpoints/test_infer_v2.py +++ b/tests/unit/endpoints/test_infer_v2.py @@ -15,6 +15,8 @@ import torch +from monailabel.sam2.utils import is_sam2_module_available + from .context import BasicBundleTestSuite, BasicDetectionBundleTestSuite, BasicEndpointV2TestSuite @@ -42,6 +44,30 @@ def test_deepgrow_pipeline(self): assert response.status_code == 200 time.sleep(1) + def test_sam_2d(self): + if not is_sam2_module_available() or not torch.cuda.is_available(): + return + + model = "sam_2d" + image = "spleen_3" + params = {"foreground": [[140, 210, 28]], "background": []} + + response = self.client.post(f"/infer/{model}?image={image}", data={"params": json.dumps(params)}) + assert response.status_code == 200 + time.sleep(1) + + def test_sam_3d(self): + if not is_sam2_module_available or not torch.cuda.is_available(): + return + + model = "sam_3d" + image = "spleen_3" + params = {"foreground": [[140, 210, 28]], "background": []} + + response = self.client.post(f"/infer/{model}?image={image}", data={"params": json.dumps(params)}) + assert response.status_code == 200 + time.sleep(1) + class TestBundleInferTask(BasicBundleTestSuite): def test_spleen_bundle_infer(self):