Skip to content

Commit

Permalink
fix(stablelm): running on GPU
Browse files Browse the repository at this point in the history
Signed-off-by: aarnphm-ec2-dev <[email protected]>
  • Loading branch information
aarnphm committed Jun 11, 2023
1 parent 8762a56 commit a5efb7f
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/openllm/models/stablelm/modeling_stablelm.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
"stopping_criteria": StoppingCriteriaList([StopOnTokens()]),
}

if torch.cuda.is_available():
self.model.cuda()

inputs = t.cast("torch.Tensor", self.tokenizer(prompt, return_tensors="pt")).to(self.device)
tokens = self.model.generate(**inputs, **generation_kwargs)
return [self.tokenizer.decode(tokens[0], skip_special_tokens=True)]

0 comments on commit a5efb7f

Please sign in to comment.