Instantiating ScribblePrompt-UNet #9

Open
chenyuanjiao342 opened this issue Aug 13, 2024 · 3 comments

Comments

@chenyuanjiao342

Hi, when I try to instantiate ScribblePrompt-UNet and make predictions, the segmentation results I get are not accurate. Part of my code is as follows. Is there anything wrong? Thank you very much!

def create_scribble_tensor(positive_coords, negative_coords, H, W):
    scribbles = torch.zeros((1, 2, H, W), dtype=torch.float32)
    for coord in positive_coords:
        x, y = coord
        if 0 <= x < W and 0 <= y < H:
            scribbles[0, 0, y, x] = 1

    for coord in negative_coords:
        x, y = coord
        if 0 <= x < W and 0 <= y < H:
            scribbles[0, 1, y, x] = 1
    return scribbles

def binary_mask_to_polygon(binary_mask, tolerance=0):
    padded_binary_mask = np.pad(binary_mask, pad_width=1, mode='constant', constant_values=0)
    contours = measure.find_contours(padded_binary_mask, 0.5)
    return contours

def parse_coords(coords_str):
    coords = list(map(int, coords_str.split(',')))
    coords_list = [coords[i:i+2] for i in range(0, len(coords), 2)]
    return coords_list

def main(args: argparse.Namespace) -> None:
    targets = [args.input_dir]
    
    for index, t in enumerate(targets):
        name = os.path.basename(t)
        image = pydicom.dcmread(t)
        image_w = image.Rows
        image_h = image.Columns
        image = image.pixel_array
        image = (image - image.min()) / (image.max() - image.min())
        image = np.clip(image, 0, 1)
        image = torch.tensor(image, dtype=torch.float32)
        image = image.unsqueeze(0).unsqueeze(0)
        image = image.permute(0, 1, 3, 2)
        image = F.interpolate(image, size=(128,128), mode='bilinear')
        
        pos_coords = parse_coords(args.pos_box)
        neg_coords = parse_coords(args.neg_box)
        positive_coords = [tuple(point) for point in pos_coords]
        negative_coords = [tuple(point) for point in neg_coords]
        scribbles = create_scribble_tensor(positive_coords, negative_coords, image_h, image_w)
        scribbles = F.interpolate(scribbles, size=(128,128), mode='bilinear') 

        sp_unet = ScribblePromptUNet()
        mask = sp_unet.predict(image, None, None, scribbles, None, None)
        mask = F.interpolate(mask, size=(image_h, image_w), mode='bilinear').squeeze()   
        mask = mask.cpu().numpy()     

        binary_mask = (mask > 0.5).astype(int)
        contours = binary_mask_to_polygon(binary_mask)
@halleewong
Owner

Can you show examples of the input image and coordinates you're using and the predictions you get?

@chenyuanjiao342
Author

> Can you show examples of the input image and coordinates you're using and the predictions you get?

The first image is the input image and scribbles, and the second image is the predicted result.
[image: input image with scribbles]
[image: predicted result]

@halleewong
Owner

Hmm, I'm not quite sure what's going on. Here are some ideas for troubleshooting:

  • Could xy be switched between your scribble mask and your image? I see you're swapping the last 2 dimensions of your input image. Have you tried visualizing the scribble mask and input image right before inference to check they are aligned?

  • Have you tried drawing a similar scribble on your image in our Hugging Face demo? You can clone the code for the demo here if you prefer to run it locally. If you get better results in the Gradio demo, that suggests there's a bug in your code.
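One way to run the alignment check from the first suggestion is sketched below. It is only an illustration, not code from the ScribblePrompt repository: `rasterize_scribbles` and `overlay_scribbles` are hypothetical helper names, and the sketch assumes click coordinates are given as (x, y) = (column, row). Note that in DICOM, `Rows` is the image height and `Columns` is the width, so `pixel_array` has shape (Rows, Columns); the posted code assigns them the other way around, which may be related to the problem.

```python
import numpy as np


def rasterize_scribbles(pos_xy, neg_xy, height, width):
    """Return a (2, H, W) float32 array: channel 0 positive, channel 1 negative.

    Assumption: each coordinate is (x, y) = (column, row) in pixels.
    """
    scribbles = np.zeros((2, height, width), dtype=np.float32)
    for channel, coords in enumerate((pos_xy, neg_xy)):
        for x, y in coords:
            if 0 <= x < width and 0 <= y < height:
                scribbles[channel, y, x] = 1.0  # row index first, then column
    return scribbles


def overlay_scribbles(image, scribbles):
    """Paint scribbles onto a grayscale image in [0, 1] for visual inspection.

    image: (H, W); scribbles: (2, H, W). Returns an (H, W, 3) RGB array with
    positive scribbles in green and negative scribbles in red.
    """
    if image.shape != scribbles.shape[1:]:
        raise ValueError(
            f"shape mismatch: image {image.shape} vs scribbles {scribbles.shape[1:]}"
        )
    rgb = np.stack([image, image, image], axis=-1).astype(np.float32)
    rgb[scribbles[0] > 0] = (0.0, 1.0, 0.0)  # positive -> green
    rgb[scribbles[1] > 0] = (1.0, 0.0, 0.0)  # negative -> red
    return rgb


# Toy non-square example: a swapped height/width shows up as a shape error
# instead of silently producing a transposed scribble.
height, width = 4, 6
image = np.zeros((height, width), dtype=np.float32)
image[1, 4] = 1.0  # bright pixel at row 1, column 4, i.e. (x, y) = (4, 1)
scribbles = rasterize_scribbles([(4, 1)], [(0, 3)], height, width)
overlay = overlay_scribbles(image, scribbles)
```

To inspect the actual model inputs, the same overlay could be built from `image[0, 0].cpu().numpy()` and `scribbles[0].cpu().numpy()` right before `predict` and shown with `matplotlib.pyplot.imshow(overlay)`. It may also be worth confirming that the scribble tensor still has nonzero values after resizing, since bilinear downsampling of isolated single-pixel clicks can weaken them or drop them entirely.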
