
About evaluation and training codes #10

Open
yan9qu opened this issue Jan 14, 2024 · 3 comments

yan9qu commented Jan 14, 2024

Well done! I wonder when the training and inference-related code will be provided?

csuhan (Owner) commented Jan 16, 2024

For inference, you can refer to our demo code:

def model_worker(
    rank: int, args: argparse.Namespace, barrier: mp.Barrier,
    request_queue: mp.Queue, response_queue: Optional[mp.Queue] = None,
) -> None:
    """
    The worker function that manipulates the GPU to run the inference.
    Exactly n_gpu workers are started, each operating on a separate GPU.

    Args:
        rank (int): Distributed rank of the worker.
        args (argparse.Namespace): All command line arguments.
        barrier (multiprocessing.Barrier): A barrier used to delay the start
            of the Web UI until after the model has started.
    """
    world_size = len(args.gpu_ids)
    gpu_id = args.gpu_ids[rank]
    dist.init_process_group(
        backend="nccl", rank=rank, world_size=world_size,
        init_method=f"tcp://{args.master_addr}:{args.master_port}",
    )
    print(f"| distributed init on worker {rank}/{world_size}. "
          f"using gpu: {gpu_id}")
    fs_init.initialize_model_parallel(world_size)
    torch.cuda.set_device(gpu_id)

    torch.manual_seed(1)
    np.random.seed(1)

    # set the print behavior.
    setup_for_distributed(rank == 0)

    target_dtype = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
    }[args.dtype]
    with default_tensor_type(dtype=target_dtype, device="cuda"):
        model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path)
    print("Loading pretrained weights ...")
    checkpoint = torch.load(args.pretrained_path, map_location='cpu')
    msg = model.load_state_dict(checkpoint, strict=False)
    print("load result:\n", msg)
    model.cuda()
    model.eval()
    print(f"Model = {str(model)}")

    barrier.wait()

    while True:
        img_path, audio_path, video_path, chatbot, max_gen_len, temperature, top_p, modality = request_queue.get()
        if 'image' in modality and img_path is not None:
            image = Image.open(img_path).convert('RGB')
            inputs = T_random_resized_crop(image)
        elif 'video' in modality and video_path is not None:
            inputs = load_video(video_path)
        elif 'audio' in modality and audio_path is not None:
            inputs = load_audio(audio_path)
        else:
            inputs = None

        if inputs is not None:
            inputs = inputs[None].cuda().to(target_dtype)

        conv = conv_templates["v1"].copy()
        for user, bot in chatbot:
            conv.append_message(conv.roles[0], user)
            conv.append_message(conv.roles[1], bot)

        with torch.cuda.amp.autocast(dtype=target_dtype):
            print(conv.get_prompt())
            for stream_response in model.stream_generate(
                conv.get_prompt(), inputs,
                max_gen_len=max_gen_len, temperature=temperature, top_p=top_p,
                modal=modality,
            ):
                conv_sep = (
                    conv.sep
                    if conv.sep_style == SeparatorStyle.SINGLE
                    else conv.sep2
                )
                end_pos = stream_response["text"].find(conv_sep)
                if end_pos != -1:
                    stream_response["text"] = (
                        stream_response["text"][:end_pos].rstrip() + "\n"
                    )
                    stream_response["end_of_content"] = True

                # keep a few characters if not end_of_content to avoid sending
                # part of conv_sep before all of it is generated.
                if not stream_response["end_of_content"]:
                    if len(stream_response["text"]) < len(conv_sep):
                        continue
                    stream_response["text"] = (
                        stream_response["text"][:-len(conv_sep)]
                    )

                if response_queue is not None:
                    response_queue.put(stream_response)

                if stream_response["end_of_content"]:
                    break

For the training code and data, we plan to release them within the next 1-2 months.

csuhan (Owner) commented Mar 8, 2024

Hi @yan9qu , we have just released the training code. Feel free to tell us if you need any help.

yan9qu (Author) commented Mar 8, 2024

That's great!
