Skip to content

Commit

Permalink
new authorization method
Browse files Browse the repository at this point in the history
  • Loading branch information
seldereyy committed Nov 22, 2024
1 parent 96442f6 commit b5237db
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions lm_eval/models/gigachat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,42 +43,44 @@ def _create_payload(
**kwargs,
) -> dict:
if generate:
max_tokens = gen_kwargs.pop("max_tokens", None)
temperature = gen_kwargs.pop("temperature", None)
profanity_check = gen_kwargs.pop("profanity_check", True)
do_sample = gen_kwargs.pop("do_sample", None)

if (
"do_sample" in gen_kwargs
): # GigaChat API does not have do sample option.
do_sample = gen_kwargs.pop("do_sample")
if do_sample is not None: # GigaChat API does not have do sample option.
if not do_sample: # Ensure greedy decoding if do_sample=False
gen_kwargs["repetition_penalty"] = 1.0
gen_kwargs["top_p"] = 0.0
elif temperature == 0:
elif temperature == 0.0:
eval_logger.warning(
"You cannot set do_sample=True and temperature=0. Automatically setting temperature=1."
)
temperature = 1.0
if (
temperature == 0
temperature == 0.0
): # Ensure greedy decoding by setting top_p=0 and repetition_penalty = 1
temperature = (
1.0 # temperature cannot be set to zero. Use top_p instead
)
gen_kwargs["repetition_penalty"] = 1
gen_kwargs["top_p"] = 0
gen_kwargs["repetition_penalty"] = 1.0
gen_kwargs["top_p"] = 0.0
print(
{
"messages": messages,
"model": self.model,
"temperature": temperature,
**gen_kwargs,
}
)
return {
"messages": messages,
"model": self.model,
"max_tokens": max_tokens,
"temperature": temperature,
"profanity_check": profanity_check,
**gen_kwargs,
}
else:
return None

@property # Don't use cached_property as we need to check that the acess_token has not expired.
@property # Don't use cached_property as we need to check that the access_token has not expired.
def header(self) -> dict:
"""Override this property to return the headers for the API request."""

Expand All @@ -90,6 +92,11 @@ def header(self) -> dict:

@property # Don't use cached_property as we need to check that the acess_token has not expired.
def api_key(self):
self.key = os.environ.get(
"GIGACHAT_CREDENTIALS", None
) # GigaChat access token.
if self.key:
return self.key # If access token is available, return access token.
RqUID = os.environ.get(
"GIGACHAT_RQUID", None
) # Unique identification request. Complies with uuid4 format. Value must match regular expression (([0-9a-fA-F-])36)
Expand Down

0 comments on commit b5237db

Please sign in to comment.