Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SamAutomaticMaskGenerator after more than 2 in parallel process becomes very slow #791

Open
0930mcx opened this issue Nov 12, 2024 · 0 comments

Comments

@0930mcx
Copy link

0930mcx commented Nov 12, 2024

I'm using torch.multiprocessing for parallel image segmentation. I found that when I increase the number of parallel processes beyond two, the SamAutomaticMaskGenerator's generation speed slows down drastically. It starts out at a normal speed, about two or three seconds per image, but after a while it slows to about 60-200 seconds. Does anyone know why that is? The running environment is an H800.
the code is as following.

import math
import multiprocessing
import os
import time
import torch
import numpy as np

from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import torch.multiprocessing as mp

# Use the 'spawn' start method so every worker gets a fresh CUDA context;
# 'fork' is unsafe once CUDA has been initialized in the parent process.
multiprocessing.set_start_method('spawn', force=True)

# SAM checkpoint path and matching model type. The issue author blanked
# these values out; fill them in before running.
sam_checkpoint = ""  # TODO: path to checkpoint, e.g. "sam_vit_h_4b8939.pth"
model_type = ""      # TODO: registry key, e.g. "vit_h" / "vit_l" / "vit_b"

def load_ImageNet(ImageNet_PATH, batch_size=64, workers=3, pin_memory=True, batch_range=None):
    """Build ImageNet train/val DataLoaders, optionally slicing a batch window.

    Parameters
    ----------
    ImageNet_PATH : str
        Root directory containing ``train`` and ``val`` subfolders laid out
        for ``torchvision.datasets.ImageFolder``.
    batch_size : int
        Samples per batch.
    workers : int
        Number of DataLoader worker processes.
    pin_memory : bool
        Pin host memory for faster host-to-GPU copies.
    batch_range : tuple[int, int] | None
        Optional ``(start_batch, end_batch)``. When given, the returned
        "loaders" are plain lists containing only that window of batches.

    Returns
    -------
    tuple
        ``(train_loader, val_loader)`` — DataLoaders, or lists of batches
        when ``batch_range`` is set.
    """
    from itertools import islice

    traindir = os.path.join(ImageNet_PATH, 'train')
    valdir = os.path.join(ImageNet_PATH, 'val')

    normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])

    # Training dataset (with horizontal-flip augmentation).
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([transforms.Resize((192, 192)),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(),
                            normalizer])
    )

    # Validation dataset (no augmentation).
    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([transforms.Resize((192, 192)),
                            transforms.ToTensor(),
                            normalizer])
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=workers,
        pin_memory=pin_memory
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=workers,
        pin_memory=pin_memory
    )

    # Materialize only the requested batch window as plain lists.
    if batch_range is not None:
        max_train_batches = len(train_dataset) // batch_size
        max_val_batches = len(val_dataset) // batch_size
        print(len(train_loader))

        start_batch, end_batch = batch_range
        # Clamp the window so it never exceeds the available batches.
        # NOTE(review): the original clamped both loaders with the *train*
        # batch count; each loader is now clamped with its own count.
        train_start = min(start_batch, max_train_batches)
        train_end = min(end_batch, max_train_batches)
        val_start = min(start_batch, max_val_batches)
        val_end = min(end_batch, max_val_batches)

        # islice skips the first `start` batches. The original built
        # `[next(loader) for _ in range(start, end)]`, which always took
        # the first end-start batches regardless of `start` — a bug for
        # non-zero starts (harmless for the (0, 30) range used in main).
        train_loader = list(islice(iter(train_loader), train_start, train_end))
        val_loader = list(islice(iter(val_loader), val_start, val_end))

    return train_loader, val_loader

Function to create a model and mask generator inside each process

def create_model_and_generator(device):
    """Instantiate a fresh SAM model and mask generator on ``device``.

    Called inside each worker process so every process owns an independent
    model (CUDA state cannot be shared across 'spawn'-ed processes).

    Parameters
    ----------
    device : str
        Target device string, e.g. ``"cuda:4"``.

    Returns
    -------
    tuple
        ``(sam, mask_generator)``.
    """
    # The issue text rendered this line as "sam_model_registrymodel_type"
    # (markdown ate the brackets); the intended SAM registry call is:
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device)

    mask_generator = SamAutomaticMaskGenerator(model=sam,
                                               points_per_side=32,
                                               pred_iou_thresh=0.86,
                                               stability_score_thresh=0.92,
                                               crop_n_layers=1,
                                               crop_n_points_downscale_factor=2,
                                               min_mask_region_area=100)
    return sam, mask_generator

Function to process each batch with a separate model in each process

def process_batch_parallel(samples, device, idx):
    """Segment every image in one batch and return stacked label masks.

    Parameters
    ----------
    samples : torch.Tensor
        Batch of images shaped ``(batch, channels, h, w)`` — assumed to be
        normalized float tensors in roughly [0, 1] (TODO confirm: values
        are multiplied by 255 and cast to uint8 below).
    device : str
        Device to load the per-process SAM model on.
    idx : int
        Batch index, used only for progress logging.

    Returns
    -------
    torch.Tensor
        Int16 tensor of shape ``(batch, h, w)``; each pixel holds the
        1-based index of the mask covering it (0 = background).
    """
    # Create the model and generator inside this process so nothing is
    # shared across workers.
    sam, mask_generator = create_model_and_generator(device)
    batch_size, _, h, w = samples.shape
    batch_masks = []

    for i, img in enumerate(samples):
        start = time.time()
        img = img.to(device)  # Move the image to the correct GPU
        # SAM expects an HWC uint8 numpy image, so convert on the CPU.
        img_np = img.permute(1, 2, 0).cpu().numpy()
        img_np = (img_np * 255).astype(np.uint8)

        # Generate segmentation masks.
        masks = mask_generator.generate(img_np)
        combined_mask = np.zeros((h, w), dtype=np.int16)

        # # Sort masks by area in descending order
        # masks = sorted(masks, key=lambda x: x['area'], reverse=True)
        # Later masks overwrite earlier ones where they overlap.
        for j, mask in enumerate(masks):
            segmentation = mask['segmentation']
            combined_mask[segmentation] = j + 1

        batch_masks.append(torch.tensor(combined_mask, dtype=torch.int16).cpu())
        end = time.time()
        print(f"Processed sample {i}/{len(samples)} in batch {idx}, time {end - start}s")

    return torch.stack(batch_masks)

def process_batch_parallel2(rank, train_loader, devices, lefts, rights):
    """Worker entry point: process this rank's slice of batches and save masks.

    Each spawned process handles the inclusive batch range
    ``[lefts[rank], rights[rank]]`` on device ``devices[rank]``, writing one
    ``sam_output/sam_batch{idx}.pth`` tensor per batch.

    Parameters
    ----------
    rank : int
        Worker index supplied by ``torch.multiprocessing.spawn``.
    train_loader : iterable
        Yields ``(samples, targets)`` batches.
    devices, lefts, rights : list
        Per-rank device strings and inclusive batch-range bounds.

    Returns
    -------
    tuple
        ``(left, right)`` — the range this worker was assigned.
    """
    device = devices[rank]
    left = lefts[rank]
    right = rights[rank]
    print(f"left to right is {left} to {right}, device is {device}, rank is {rank}")
    # Create model and generator per process so nothing is shared.
    sam, mask_generator = create_model_and_generator(device)
    for idx, (samples, targets) in enumerate(train_loader):
        batch_masks = []
        if idx > right:
            # Past this worker's range — nothing more to do.
            return left, right
        if idx < left:
            # Not yet in this worker's range — skip.
            continue
        start_time = time.time()
        # Shape is constant per batch; hoisted out of the per-image loop.
        batch_size, _, h, w = samples.shape
        for i, img in enumerate(samples):
            start = time.time()
            # SAM expects an HWC uint8 numpy image.
            img_np = img.permute(1, 2, 0).cpu().numpy()
            img_np = (img_np * 255).astype(np.uint8)
            masks = mask_generator.generate(img_np)
            combined_mask = np.zeros((h, w), dtype=np.int16)
            # Later masks overwrite earlier ones where they overlap.
            for j, mask in enumerate(masks):
                segmentation = mask['segmentation']
                combined_mask[segmentation] = j + 1
            batch_masks.append(torch.tensor(combined_mask, dtype=torch.int16).cpu())
            end = time.time()
            print(f"Processed sample {i}/{len(samples)} in batch {idx}, time {end - start}s")

        batch = torch.stack(batch_masks)
        torch.save(batch, f"sam_output/sam_batch{idx}.pth")
        end_time = time.time()
        print(f"Processed batch {idx-left+1}/{right - left + 1}, time {end_time - start_time}s")
    return left, right

def process_in_parallel(train_loader, max_tasks=4):
    """Split the loader's batches into ``max_tasks`` contiguous ranges and
    process each range in its own spawned process on its own GPU.

    Parameters
    ----------
    train_loader : iterable
        Batches to distribute across workers.
    max_tasks : int
        Number of worker processes (and GPUs) to use.
    """
    length = math.ceil(len(train_loader) / max_tasks)
    # One device per worker; the cuda:4.. offset matches the author's setup
    # (presumably GPUs 0-3 are in use elsewhere — TODO confirm).
    devices = [f"cuda:{i+4}" for i in range(max_tasks)]
    # Markdown ate the '*' characters in the issue; the intended inclusive
    # per-worker range is [length*i, length*(i+1)-1].
    lefts = [length * i for i in range(max_tasks)]
    rights = [length * (i + 1) - 1 for i in range(max_tasks)]
    mp.spawn(process_batch_parallel2, nprocs=max_tasks,
             args=(train_loader, devices, lefts, rights))

# Markdown stripped the dunder underscores in the issue; the guard must be
# __name__ == '__main__' or spawned workers would re-run this block.
if __name__ == '__main__':
    max_tasks = 3
    # torch.set_num_threads(8)
    # Load dataset (path blanked out by the issue author; only the first
    # 30 batches of 4 images are materialized).
    train_loader, val_loader = load_ImageNet(
        "", 4, 128, True, batch_range=(0, 30))
    print(len(train_loader))
    print("Start processing batches")
    start = time.time()
    # Start parallel processing with a limit of max_tasks simultaneous tasks
    process_in_parallel(train_loader, max_tasks=max_tasks)
    end = time.time()
    print(f"Total processing time: {end - start}s")

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant