From 1553ed50adc5e2c4f37526347a029279f6408efc Mon Sep 17 00:00:00 2001
From: Linyan Huang <89194485+DevLinyan@users.noreply.github.com>
Date: Mon, 25 Mar 2024 20:06:20 +0800
Subject: [PATCH] demo.py: use threads instead of multiprocessing for
 multi-GPU inference

Run one worker thread per GPU inside a single process and collect results
in a shared list, instead of spawning one CUDA process per GPU and passing
results back through a multiprocessing Queue. Take the number of GPUs from
--num_processes instead of torch.cuda.device_count(), and make each
worker's device explicit (cuda:<gpu_id>) since all workers now share one
process.
---
 .../llama_adapter_v2_multimodal7b/demo.py | 41 +++++++------------
 1 file changed, 15 insertions(+), 26 deletions(-)

diff --git a/challenge/llama_adapter_v2_multimodal7b/demo.py b/challenge/llama_adapter_v2_multimodal7b/demo.py
index 63447cd..a7726b4 100644
--- a/challenge/llama_adapter_v2_multimodal7b/demo.py
+++ b/challenge/llama_adapter_v2_multimodal7b/demo.py
@@ -7,7 +7,7 @@
 import argparse
 import torchvision.transforms as transforms
 from torch.utils.data import Dataset, DataLoader
-from torch.multiprocessing import Process, Queue, set_start_method
+from threading import Thread
 import math
 try:
     from torchvision.transforms import InterpolationMode
@@ -50,13 +50,12 @@ def __getitem__(self, idx):
         return image, prompt, ids, question, answer
 
 
-def worker(rank, gpu_id, args, queue):
+def worker(rank, gpu_id, args, data_dict):
     torch.cuda.set_device(gpu_id)
-    device = torch.device("cuda")
+    device = torch.device(f"cuda:{gpu_id}")
 
     llama_dir = args.llama_dir
 
-    # Choose from BIAS-7B, LORA-BIAS-7B, CAPTION-7B.pth
     model, preprocess = llama.load(args.checkpoint, llama_dir, llama_type="7B", device=device)
     model.eval()
 
@@ -75,9 +74,7 @@ def worker(rank, gpu_id, args, queue):
     data_to_process = data_all[start_idx:end_idx]
 
     dataset = LLamaDataset(data_to_process, transform=transform_train)
-    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)
-
-    data_dict = []
+    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=8)
 
     for batch in tqdm(dataloader):
         images, prompts, ids, questions, gt_answers = batch
@@ -85,9 +82,10 @@ def worker(rank, gpu_id, args, queue):
         results = model.generate(images, prompts, temperature=0.2, top_p=0.1)
 
         for i, result in enumerate(results):
+            print(f"Thread {rank}: Result - {result}")
             data_dict.append({'id': ids[i], 'question': questions[i], 'gt_answer': gt_answers[i], 'answer': result})
 
-    queue.put(data_dict)
+    print(f"Thread {rank} finished")
 
 # add args
 parser = argparse.ArgumentParser(description='LLAMA Adapter')
@@ -96,31 +94,22 @@ def worker(rank, gpu_id, args, queue):
 parser.add_argument('--data', type=str, default="../test_llama.json", help='path to test data')
 parser.add_argument('--output', type=str, default="../output.json", help='path to output file')
 parser.add_argument('--batch_size', type=int, default=8, help='batch size for parallel processing')
-parser.add_argument('--num_processes', type=int, default=8, help='number of processes to use')
+parser.add_argument('--num_processes', type=int, default=8, help='number of GPUs to use')
 args = parser.parse_args()
 
 if __name__ == '__main__':
-    try:
-        set_start_method('spawn')
-    except RuntimeError:
-        pass
-
-    num_gpus = torch.cuda.device_count()
+    num_gpus = args.num_processes
     print(f"Using {num_gpus} GPUs")
 
-    queue = Queue()
-    processes = []
+    data_dict = []
+    threads = []
     for rank in range(num_gpus):
-        p = Process(target=worker, args=(rank, rank, args, queue))
-        p.start()
-        processes.append(p)
-
-    for p in processes:
-        p.join()
+        t = Thread(target=worker, args=(rank, rank, args, data_dict))
+        t.start()
+        threads.append(t)
 
-    data_dict = []
-    while not queue.empty():
-        data_dict.extend(queue.get())
+    for t in threads:
+        t.join()
 
     with open(args.output, "w") as f:
         json.dump(data_dict, f, indent=4)
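
Usage sketch: with this change the script is launched once and fans out to
one thread per GPU, rather than once per process. Assuming the --llama_dir
and --checkpoint flags defined earlier in the argument parser, an 8-GPU
machine, and placeholder paths, an invocation would look like:

    python demo.py \
        --llama_dir /path/to/llama_weights \
        --checkpoint /path/to/adapter_checkpoint.pth \
        --data ../test_llama.json \
        --output ../output.json \
        --batch_size 8 \
        --num_processes 8

Each thread loads its own copy of the model onto its own GPU and appends
results into the shared data_dict list (safe under the GIL, since
list.append is atomic); --num_processes should not exceed the number of
visible GPUs, because the worker rank doubles as the CUDA device index.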