Commit

Update demo.py

DevLinyan authored Mar 25, 2024
1 parent 714dd64 commit 1553ed5
Showing 1 changed file with 14 additions and 26 deletions.
40 changes: 14 additions & 26 deletions challenge/llama_adapter_v2_multimodal7b/demo.py
@@ -7,7 +7,7 @@
 import argparse
 import torchvision.transforms as transforms
 from torch.utils.data import Dataset, DataLoader
-from torch.multiprocessing import Process, Queue, set_start_method
+from threading import Thread
 import math
 
 try:
@@ -50,13 +50,11 @@ def __getitem__(self, idx):

     return image, prompt, ids, question, answer
 
-def worker(rank, gpu_id, args, queue):
+def worker(rank, gpu_id, args, data_dict):
     torch.cuda.set_device(gpu_id)
 
     device = torch.device("cuda")
     llama_dir = args.llama_dir
 
-    # Choose from BIAS-7B, LORA-BIAS-7B, CAPTION-7B.pth
     model, preprocess = llama.load(args.checkpoint, llama_dir, llama_type="7B", device=device)
     model.eval()

@@ -75,19 +73,18 @@ def worker(rank, gpu_id, args, queue):
     data_to_process = data_all[start_idx:end_idx]
 
     dataset = LLamaDataset(data_to_process, transform=transform_train)
-    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)
-
-    data_dict = []
+    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=8)
 
     for batch in tqdm(dataloader):
         images, prompts, ids, questions, gt_answers = batch
         images = images.to(device)
         results = model.generate(images, prompts, temperature=0.2, top_p=0.1)
 
         for i, result in enumerate(results):
+            print(f"Thread {rank}: Result - {result}")
             data_dict.append({'id': ids[i], 'question': questions[i], 'gt_answer': gt_answers[i], 'answer': result})
 
-    queue.put(data_dict)
+    print(f"Thread {rank} finished")
 
 # add args
 parser = argparse.ArgumentParser(description='LLAMA Adapter')
@@ -96,31 +93,22 @@ def worker(rank, gpu_id, args, queue):
 parser.add_argument('--data', type=str, default="../test_llama.json", help='path to test data')
 parser.add_argument('--output', type=str, default="../output.json", help='path to output file')
 parser.add_argument('--batch_size', type=int, default=8, help='batch size for parallel processing')
-parser.add_argument('--num_processes', type=int, default=8, help='number of processes to use')
+parser.add_argument('--num_processes', type=int, default=8, help='number of gpus to use')
 args = parser.parse_args()
 
 if __name__ == '__main__':
-    try:
-        set_start_method('spawn')
-    except RuntimeError:
-        pass
-
-    num_gpus = torch.cuda.device_count()
+    num_gpus = args.num_processes
    print(f"Using {num_gpus} GPUs")
 
-    queue = Queue()
-    processes = []
+    data_dict = []
+    threads = []
     for rank in range(num_gpus):
-        p = Process(target=worker, args=(rank, rank, args, queue))
-        p.start()
-        processes.append(p)
-
-    for p in processes:
-        p.join()
+        t = Thread(target=worker, args=(rank, rank, args, data_dict))
+        t.start()
+        threads.append(t)
 
-    data_dict = []
-    while not queue.empty():
-        data_dict.extend(queue.get())
+    for t in threads:
+        t.join()
 
     with open(args.output, "w") as f:
         json.dump(data_dict, f, indent=4)
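
Note: the pattern the new version adopts is one Python thread per GPU, with every thread appending into a single shared list that the main thread serializes after join(). A minimal sketch of that shape, with the model call replaced by a stub so it runs without a GPU (the stub loop and record fields stand in for the real dataloader loop and llama generate call):

import json
from threading import Thread

def worker(rank, gpu_id, args, data_dict):
    # Stand-in for the real per-GPU work: set the device, build this
    # rank's data shard, run generation, append one record per sample.
    for i in range(3):
        data_dict.append({'id': f'{rank}-{i}', 'answer': 'stub'})
    print(f"Thread {rank} finished")

if __name__ == '__main__':
    num_gpus = 2  # the script reads this from args.num_processes
    data_dict = []
    threads = []
    for rank in range(num_gpus):
        t = Thread(target=worker, args=(rank, rank, None, data_dict))
        t.start()
        threads.append(t)
    for t in threads:
        t.join()
    print(json.dumps(data_dict, indent=4))

No lock is needed around the shared list because CPython's GIL makes list.append atomic, and PyTorch's CUDA ops release the GIL while kernels run, so the threads still overlap on the GPU-bound part. This also drops the spawn start-method and Queue plumbing the multiprocessing version required.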

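Assuming the collapsed portion of the last hunk also defines --llama_dir and --checkpoint flags (the worker reads args.llama_dir and args.checkpoint, and the deleted comment names BIAS-7B, LORA-BIAS-7B, and CAPTION-7B.pth as valid checkpoints), a typical invocation might look like:

python demo.py --llama_dir /path/to/llama --checkpoint /path/to/BIAS-7B.pth --batch_size 8 --num_processes 8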