From 427ad8b7f4451b538f8ddb028129c3b385c02a87 Mon Sep 17 00:00:00 2001
From: Pramit Choudhary
Date: Wed, 30 Aug 2023 12:14:59 -0700
Subject: [PATCH] Support 4bit quantized model format as default (#28)

* Change default loading to 4bit quantized format
* Handle edge case failures better
* Add trigger for regeneration
* Upgrade version to 0.0.6
---
 app.toml                  |   2 +-
 sidekick/configs/env.toml |   3 +-
 sidekick/prompter.py      | 274 ++++++++++++++++++++------------------
 sidekick/query.py         |  41 +++---
 sidekick/utils.py         |  77 +++++++----
 ui/app.py                 | 103 ++++++++------
 6 files changed, 284 insertions(+), 216 deletions(-)

diff --git a/app.toml b/app.toml
index efd0b20..b336d39 100644
--- a/app.toml
+++ b/app.toml
@@ -4,7 +4,7 @@ title = "SQL-Sidekick"
 description = "QnA with tabular data using NLI"
 LongDescription = "about.md"
 Tags = ["DATA_SCIENCE", "MACHINE_LEARNING", "NLP"]
-Version = "0.0.5"
+Version = "0.0.6"

 [Runtime]
 MemoryLimit = "64Gi"
diff --git a/sidekick/configs/env.toml b/sidekick/configs/env.toml
index a4019c8..2289673 100644
--- a/sidekick/configs/env.toml
+++ b/sidekick/configs/env.toml
@@ -1,6 +1,7 @@
 [MODEL_INFO]
 OPENAI_API_KEY = "" # Needed only for openAI models
 MODEL_NAME = "h2ogpt-sql" # Others: e.g. gpt-4, gpt-4-32k, text-davinci-003
+QUANT_TYPE = '4bit'

 [LOCAL_DB_CONFIG]
 HOST_NAME = "localhost"
@@ -10,7 +11,7 @@ DB_NAME = "querydb"
 PORT = "5432"

 [LOGGING]
-LOG-LEVEL = "INFO"
+LOG-LEVEL = "DEBUG"

 [DB-DIALECT]
 DB_TYPE = "sqlite"
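For context, the new QUANT_TYPE key would be read the same way the app already consumes env.toml via the toml package. A minimal sketch, assuming the repo-relative config path and a defensive fallback to 4-bit when the key is absent:

    import toml

    # Load the app settings (path assumed from the repo layout).
    env_settings = toml.load("sidekick/configs/env.toml")

    # New in this patch: defaults to '4bit'; other values would select the 8-bit path.
    quant_type = env_settings["MODEL_INFO"].get("QUANT_TYPE", "4bit")
    print(f"Requested quantization: {quant_type}")
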
diff --git a/sidekick/prompter.py b/sidekick/prompter.py
index e926897..93dcc8b 100644
--- a/sidekick/prompter.py
+++ b/sidekick/prompter.py
@@ -1,3 +1,4 @@
+import gc
 import json
 import os
 from pathlib import Path
@@ -5,6 +6,7 @@
 import click
 import openai
 import toml
+import torch
 from colorama import Back as B
 from colorama import Fore as F
 from colorama import Style
@@ -156,8 +158,7 @@ def db_setup_api(
     """Creates context for the new Database"""
     click.echo(f" Information supplied:\n {db_name}, {hostname}, {user_name}, {password}, {port}")
     try:
-        res = None
-        err = None
+        res = err = None
         env_settings["LOCAL_DB_CONFIG"]["HOST_NAME"] = hostname
         env_settings["LOCAL_DB_CONFIG"]["USER_NAME"] = user_name
         env_settings["LOCAL_DB_CONFIG"]["PASSWORD"] = password
@@ -373,140 +374,153 @@ def query_api(
         f.close()
         openai.api_key = api_key

-    # Set context
-    logger.info("Setting context...")
-    logger.info(f"Question: {question}")
-    # Get updated info from env.toml
-    host_name = env_settings["LOCAL_DB_CONFIG"]["HOST_NAME"]
-    user_name = env_settings["LOCAL_DB_CONFIG"]["USER_NAME"]
-    passwd = env_settings["LOCAL_DB_CONFIG"]["PASSWORD"]
-    db_name = env_settings["LOCAL_DB_CONFIG"]["DB_NAME"]
-
-    if db_dialect == "sqlite":
-        db_url = f"sqlite:///{base_path}/db/sqlite/{db_name}.db"
-    else:
-        db_url = f"{db_dialect}+psycopg2://{user_name}:{passwd}@{host_name}/{db_name}".format(
-            user_name, passwd, host_name, db_name
-        )
-
-    if table_info_path is None:
-        table_info_path = _get_table_info(path)
-
-    sql_g = SQLGenerator(
-        db_url,
-        api_key,
-        job_path=base_path,
-        data_input_path=table_info_path,
-        sample_queries_path=sample_queries_path,
-        regenerate=is_regenerate,
-    )
-    if "h2ogpt-sql" not in model_name:
-        sql_g._tasks = sql_g.generate_tasks(table_names, question)
-        results.extend(["List of Actions Generated: \n", sql_g._tasks, "\n"])
-        click.echo(sql_g._tasks)
-
-    updated_tasks = None
-    if sql_g._tasks is not None and is_command:
-        edit_val = click.prompt("Would you like to edit the tasks? (y/n)")
-        if edit_val.lower() == "y":
-            updated_tasks = click.edit(sql_g._tasks)
-            click.echo(f"Tasks:\n {updated_tasks}")
-        else:
-            click.echo("Skipping edit...")
-        if updated_tasks is not None:
-            sql_g._tasks = updated_tasks
-
-    res, alt_res = sql_g.generate_sql(
-        table_names, question, model_name=model_name, _dialect=db_dialect, is_regenerate=is_regenerate
-    )
-    logger.info(f"Input query: {question}")
-    logger.info(f"Generated response:\n\n{res}")
-
-    if res is not None:
-        updated_sql = None
-        res_val = "e"
-        if is_command:
-            while res_val.lower() in ["e", "edit", "r", "regenerate"]:
-                res_val = click.prompt(
-                    "Would you like to 'edit' or 'regenerate' the SQL? Use 'e' to edit or 'r' to regenerate. "
-                    "To skip, enter 's' or 'skip'"
-                )
-                if res_val.lower() == "e" or res_val.lower() == "edit":
-                    updated_sql = click.edit(res)
-                    click.echo(f"Updated SQL:\n {updated_sql}")
-                elif res_val.lower() == "r" or res_val.lower() == "regenerate":
-                    click.echo("Attempting to regenerate...")
-                    res, alt_res = sql_g.generate_sql(
-                        table_names, question, model_name=model_name, _dialect=db_dialect, is_regenerate=is_regenerate
-                    )
-                    logger.info(f"Input query: {question}")
-                    logger.info(f"Generated response:\n\n{res}")
-
-        results.extend(["**Generated Query:**\n", res, "\n"])
-        logger.info(f"Alternate responses:\n\n{alt_res}")
-
-        exe_sql = click.prompt("Would you like to execute the generated SQL (y/n)?") if is_command else "y"
-        if exe_sql.lower() == "y" or exe_sql.lower() == "yes":
-            # For the time being, the default option is Pandas, but the user can be asked to select Database or pandas DF later.
-            q_res = None
-            option = "DB"  # or DB
-            _val = updated_sql if updated_sql else res
-            if option == "DB":
-                hostname = env_settings["LOCAL_DB_CONFIG"]["HOST_NAME"]
-                user_name = env_settings["LOCAL_DB_CONFIG"]["USER_NAME"]
-                password = env_settings["LOCAL_DB_CONFIG"]["PASSWORD"]
-                port = env_settings["LOCAL_DB_CONFIG"]["PORT"]
-                db_name = env_settings["LOCAL_DB_CONFIG"]["DB_NAME"]
-
-                db_obj = DBConfig(db_name, hostname, user_name, password, port, base_path=base_path, dialect=db_dialect)
-
-                q_res, err = db_obj.execute_query_db(query=_val)
-
-            elif option == "pandas":
-                tables = extract_table_names(_val)
-                tables_path = dict()
-                if Path(f"{path}/table_context.json").exists():
-                    f = open(f"{path}/table_context.json", "r")
-                    table_metadata = json.load(f)
-                    for table in tables:
-                        # Check if the local table_path exists in the cache
-                        if table not in table_metadata["data_table_map"].keys():
-                            val = enter_file_path(table)
-                            if not os.path.isfile(val):
-                                click.echo("In-correct Path. Please enter again! Yes(y) or no(n)")
-                            else:
-                                tables_path[table] = val
-                                table_metadata["data_table_map"][table] = val
-                                break
-                        else:
-                            tables_path[table] = table_metadata["data_table_map"][table]
-                    assert len(tables) == len(tables_path)
-                    with open(f"{path}/table_context.json", "w") as outfile:
-                        json.dump(table_metadata, outfile, indent=4, sort_keys=False)
-                try:
-                    q_res = execute_query_pd(query=_val, tables_path=tables_path, n_rows=100)
-                    click.echo(f"The query results are:\n {q_res}")
-                except sqldf.PandaSQLException as e:
-                    logger.error(f"Error in executing the query: {e}")
-                    click.echo("Error in executing the query. Validate generated SQL and try again.")
-                    click.echo("No result to display.")
-
-            results.append("**Query Results:** \n")
-            if q_res:
-                click.echo(f"The query results are:\n {q_res}")
-                results.extend([str(q_res), "\n"])
-            else:
-                click.echo(f"While executing query:\n {err}")
-                results.extend([str(err), "\n"])
-
-        save_sql = click.prompt("Would you like to save the generated SQL (y/n)?") if is_command else "n"
-        if save_sql.lower() == "y" or save_sql.lower() == "yes":
-            # Persist for future use
-            _val = updated_sql if updated_sql else res
-            save_query(base_path, query=question, response=_val)
-        else:
-            click.echo("Exiting...")
-
+    try:
+        # Set context
+        logger.info("Setting context...")
+        logger.info(f"Question: {question}")
+        # Get updated info from env.toml
+        host_name = env_settings["LOCAL_DB_CONFIG"]["HOST_NAME"]
+        user_name = env_settings["LOCAL_DB_CONFIG"]["USER_NAME"]
+        passwd = env_settings["LOCAL_DB_CONFIG"]["PASSWORD"]
+        db_name = env_settings["LOCAL_DB_CONFIG"]["DB_NAME"]
+
+        if db_dialect == "sqlite":
+            db_url = f"sqlite:///{base_path}/db/sqlite/{db_name}.db"
+        else:
+            db_url = f"{db_dialect}+psycopg2://{user_name}:{passwd}@{host_name}/{db_name}".format(
+                user_name, passwd, host_name, db_name
+            )
+
+        if table_info_path is None:
+            table_info_path = _get_table_info(path)
+
+        sql_g = SQLGenerator(
+            db_url,
+            api_key,
+            job_path=base_path,
+            data_input_path=table_info_path,
+            sample_queries_path=sample_queries_path,
+            is_regenerate=is_regenerate,
+        )
+        if "h2ogpt-sql" not in model_name:
+            sql_g._tasks = sql_g.generate_tasks(table_names, question)
+            results.extend(["List of Actions Generated: \n", sql_g._tasks, "\n"])
+            click.echo(sql_g._tasks)
+
+            updated_tasks = None
+            if sql_g._tasks is not None and is_command:
+                edit_val = click.prompt("Would you like to edit the tasks? (y/n)")
+                if edit_val.lower() == "y":
+                    updated_tasks = click.edit(sql_g._tasks)
+                    click.echo(f"Tasks:\n {updated_tasks}")
+                else:
+                    click.echo("Skipping edit...")
+                if updated_tasks is not None:
+                    sql_g._tasks = updated_tasks
+        alt_res = None
+        res, alt_res = sql_g.generate_sql(
+            table_names, question, model_name=model_name, _dialect=db_dialect, is_regenerate=is_regenerate
+        )
+        logger.info(f"Input query: {question}")
+        logger.info(f"Generated response:\n\n{res}")
+
+        if res is not None:
+            updated_sql = None
+            res_val = "e"
+            if is_command:
+                while res_val.lower() in ["e", "edit", "r", "regenerate"]:
+                    res_val = click.prompt(
+                        "Would you like to 'edit' or 'regenerate' the SQL? Use 'e' to edit or 'r' to regenerate. "
+                        "To skip, enter 's' or 'skip'"
+                    )
+                    if res_val.lower() == "e" or res_val.lower() == "edit":
+                        updated_sql = click.edit(res)
+                        click.echo(f"Updated SQL:\n {updated_sql}")
+                    elif res_val.lower() == "r" or res_val.lower() == "regenerate":
+                        click.echo("Attempting to regenerate...")
+                        res, alt_res = sql_g.generate_sql(
+                            table_names,
+                            question,
+                            model_name=model_name,
+                            _dialect=db_dialect,
+                            is_regenerate=is_regenerate,
+                        )
+                        logger.info(f"Input query: {question}")
+                        logger.info(f"Generated response:\n\n{res}")
+
+            results.extend(["**Generated Query:**\n", res, "\n"])
+            logger.info(f"Alternate responses:\n\n{alt_res}")
+
+            exe_sql = click.prompt("Would you like to execute the generated SQL (y/n)?") if is_command else "y"
+            if exe_sql.lower() == "y" or exe_sql.lower() == "yes":
+                # For the time being, the default execution target is the database; the user could be
+                # asked to choose between the database and a pandas DataFrame later.
+                q_res = None
+                option = "DB"  # or "pandas"
+                _val = updated_sql if updated_sql else res
+                if option == "DB":
+                    hostname = env_settings["LOCAL_DB_CONFIG"]["HOST_NAME"]
+                    user_name = env_settings["LOCAL_DB_CONFIG"]["USER_NAME"]
+                    password = env_settings["LOCAL_DB_CONFIG"]["PASSWORD"]
+                    port = env_settings["LOCAL_DB_CONFIG"]["PORT"]
+                    db_name = env_settings["LOCAL_DB_CONFIG"]["DB_NAME"]
+
+                    db_obj = DBConfig(
+                        db_name, hostname, user_name, password, port, base_path=base_path, dialect=db_dialect
+                    )
+
+                    q_res, err = db_obj.execute_query_db(query=_val)
+
+                elif option == "pandas":
+                    tables = extract_table_names(_val)
+                    tables_path = dict()
+                    if Path(f"{path}/table_context.json").exists():
+                        f = open(f"{path}/table_context.json", "r")
+                        table_metadata = json.load(f)
+                        for table in tables:
+                            # Check if the local table_path exists in the cache
+                            if table not in table_metadata["data_table_map"].keys():
+                                val = enter_file_path(table)
+                                if not os.path.isfile(val):
+                                    click.echo("Incorrect path. Please enter again! Yes(y) or no(n)")
+                                else:
+                                    tables_path[table] = val
+                                    table_metadata["data_table_map"][table] = val
+                                    break
+                            else:
+                                tables_path[table] = table_metadata["data_table_map"][table]
+                        assert len(tables) == len(tables_path)
+                        with open(f"{path}/table_context.json", "w") as outfile:
+                            json.dump(table_metadata, outfile, indent=4, sort_keys=False)
+                    try:
+                        q_res = execute_query_pd(query=_val, tables_path=tables_path, n_rows=100)
+                        click.echo(f"The query results are:\n {q_res}")
+                    except sqldf.PandaSQLException as e:
+                        logger.error(f"Error in executing the query: {e}")
+                        click.echo("Error in executing the query. Validate generated SQL and try again.")
+                        click.echo("No result to display.")
+
+                results.append("**Query Results:** \n")
+                if q_res:
+                    click.echo(f"The query results are:\n {q_res}")
+                    results.extend([str(q_res), "\n"])
+                else:
+                    click.echo(f"While executing query:\n {err}")
+                    results.extend([str(err), "\n"])
+
+            save_sql = click.prompt("Would you like to save the generated SQL (y/n)?") if is_command else "n"
+            if save_sql.lower() == "y" or save_sql.lower() == "yes":
+                # Persist for future use
+                _val = updated_sql if updated_sql else res
+                save_query(base_path, query=question, response=_val)
+            else:
+                click.echo("Exiting...")
+    except (MemoryError, RuntimeError) as e:
+        logger.error(f"Something went wrong while generating response: {e}")
+        del sql_g
+        gc.collect()
+        torch.cuda.empty_cache()
+        alt_res, err = None, None
+        results = ["Something went wrong while generating response. Please try again."]
     return results, alt_res, err
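query_api now wraps the whole generation flow in try/except so an out-of-memory condition frees the generator and the CUDA cache instead of crashing the app. The same recovery pattern in isolation, as a minimal sketch (generate_fn is a stand-in for any callable that drives the model, not an API from this repo):

    import gc

    import torch


    def generate_with_recovery(generate_fn):
        """Run generation; on OOM-style failures, free GPU memory and report."""
        try:
            return generate_fn(), None
        except (MemoryError, RuntimeError) as err:
            gc.collect()              # drop lingering Python references first
            torch.cuda.empty_cache()  # then return cached CUDA blocks to the driver
            return None, f"Something went wrong while generating response: {err}"
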
" + "To skip, enter 's' or 'skip'" + ) + if res_val.lower() == "e" or res_val.lower() == "edit": + updated_sql = click.edit(res) + click.echo(f"Updated SQL:\n {updated_sql}") + elif res_val.lower() == "r" or res_val.lower() == "regenerate": + click.echo("Attempting to regenerate...") + res, alt_res = sql_g.generate_sql( + table_names, + question, + model_name=model_name, + _dialect=db_dialect, + is_regenerate=is_regenerate, + ) + logger.info(f"Input query: {question}") + logger.info(f"Generated response:\n\n{res}") + + results.extend(["**Generated Query:**\n", res, "\n"]) + logger.info(f"Alternate responses:\n\n{alt_res}") + + exe_sql = click.prompt("Would you like to execute the generated SQL (y/n)?") if is_command else "y" + if exe_sql.lower() == "y" or exe_sql.lower() == "yes": + # For the time being, the default option is Pandas, but the user can be asked to select Database or pandas DF later. + q_res = None + option = "DB" # or DB + _val = updated_sql if updated_sql else res + if option == "DB": + hostname = env_settings["LOCAL_DB_CONFIG"]["HOST_NAME"] + user_name = env_settings["LOCAL_DB_CONFIG"]["USER_NAME"] + password = env_settings["LOCAL_DB_CONFIG"]["PASSWORD"] + port = env_settings["LOCAL_DB_CONFIG"]["PORT"] + db_name = env_settings["LOCAL_DB_CONFIG"]["DB_NAME"] + + db_obj = DBConfig( + db_name, hostname, user_name, password, port, base_path=base_path, dialect=db_dialect ) - logger.info(f"Input query: {question}") - logger.info(f"Generated response:\n\n{res}") - - results.extend(["**Generated Query:**\n", res, "\n"]) - logger.info(f"Alternate responses:\n\n{alt_res}") - - exe_sql = click.prompt("Would you like to execute the generated SQL (y/n)?") if is_command else "y" - if exe_sql.lower() == "y" or exe_sql.lower() == "yes": - # For the time being, the default option is Pandas, but the user can be asked to select Database or pandas DF later. - q_res = None - option = "DB" # or DB - _val = updated_sql if updated_sql else res - if option == "DB": - hostname = env_settings["LOCAL_DB_CONFIG"]["HOST_NAME"] - user_name = env_settings["LOCAL_DB_CONFIG"]["USER_NAME"] - password = env_settings["LOCAL_DB_CONFIG"]["PASSWORD"] - port = env_settings["LOCAL_DB_CONFIG"]["PORT"] - db_name = env_settings["LOCAL_DB_CONFIG"]["DB_NAME"] - - db_obj = DBConfig(db_name, hostname, user_name, password, port, base_path=base_path, dialect=db_dialect) - - q_res, err = db_obj.execute_query_db(query=_val) - - elif option == "pandas": - tables = extract_table_names(_val) - tables_path = dict() - if Path(f"{path}/table_context.json").exists(): - f = open(f"{path}/table_context.json", "r") - table_metadata = json.load(f) - for table in tables: - # Check if the local table_path exists in the cache - if table not in table_metadata["data_table_map"].keys(): - val = enter_file_path(table) - if not os.path.isfile(val): - click.echo("In-correct Path. Please enter again! Yes(y) or no(n)") + + q_res, err = db_obj.execute_query_db(query=_val) + + elif option == "pandas": + tables = extract_table_names(_val) + tables_path = dict() + if Path(f"{path}/table_context.json").exists(): + f = open(f"{path}/table_context.json", "r") + table_metadata = json.load(f) + for table in tables: + # Check if the local table_path exists in the cache + if table not in table_metadata["data_table_map"].keys(): + val = enter_file_path(table) + if not os.path.isfile(val): + click.echo("In-correct Path. Please enter again! 
Yes(y) or no(n)") + else: + tables_path[table] = val + table_metadata["data_table_map"][table] = val + break else: - tables_path[table] = val - table_metadata["data_table_map"][table] = val - break - else: - tables_path[table] = table_metadata["data_table_map"][table] - assert len(tables) == len(tables_path) - with open(f"{path}/table_context.json", "w") as outfile: - json.dump(table_metadata, outfile, indent=4, sort_keys=False) - try: - q_res = execute_query_pd(query=_val, tables_path=tables_path, n_rows=100) + tables_path[table] = table_metadata["data_table_map"][table] + assert len(tables) == len(tables_path) + with open(f"{path}/table_context.json", "w") as outfile: + json.dump(table_metadata, outfile, indent=4, sort_keys=False) + try: + q_res = execute_query_pd(query=_val, tables_path=tables_path, n_rows=100) + click.echo(f"The query results are:\n {q_res}") + except sqldf.PandaSQLException as e: + logger.error(f"Error in executing the query: {e}") + click.echo("Error in executing the query. Validate generated SQL and try again.") + click.echo("No result to display.") + + results.append("**Query Results:** \n") + if q_res: click.echo(f"The query results are:\n {q_res}") - except sqldf.PandaSQLException as e: - logger.error(f"Error in executing the query: {e}") - click.echo("Error in executing the query. Validate generated SQL and try again.") - click.echo("No result to display.") - - results.append("**Query Results:** \n") - if q_res: - click.echo(f"The query results are:\n {q_res}") - results.extend([str(q_res), "\n"]) + results.extend([str(q_res), "\n"]) + else: + click.echo(f"While executing query:\n {err}") + results.extend([str(err), "\n"]) + + save_sql = click.prompt("Would you like to save the generated SQL (y/n)?") if is_command else "n" + if save_sql.lower() == "y" or save_sql.lower() == "yes": + # Persist for future use + _val = updated_sql if updated_sql else res + save_query(base_path, query=question, response=_val) else: - click.echo(f"While executing query:\n {err}") - results.extend([str(err), "\n"]) - - save_sql = click.prompt("Would you like to save the generated SQL (y/n)?") if is_command else "n" - if save_sql.lower() == "y" or save_sql.lower() == "yes": - # Persist for future use - _val = updated_sql if updated_sql else res - save_query(base_path, query=question, response=_val) - else: - click.echo("Exiting...") - + click.echo("Exiting...") + except (MemoryError, RuntimeError) as e: + logger.error(f"Something went wrong while generating response: {e}") + del sql_g + gc.collect() + torch.cuda.empty_cache() + alt_res, err = None, None + results = ["Something went wrong while generating response. 
Please try again."] return results, alt_res, err diff --git a/sidekick/query.py b/sidekick/query.py index 41f3582..ebd6c1b 100644 --- a/sidekick/query.py +++ b/sidekick/query.py @@ -1,5 +1,6 @@ import json import os +import gc import random import sys from pathlib import Path @@ -21,7 +22,7 @@ load_embedding_model, read_sample_pairs, remove_duplicates, - offload_state, + is_resource_low, ) from sqlalchemy import create_engine from transformers import AutoModelForCausalLM, AutoTokenizer @@ -39,16 +40,23 @@ def __new__( sample_queries_path: str = "./samples.csv", job_path: str = "./", device: str = "auto", - regenerate: bool = False + is_regenerate: bool = False, ): - offloading = offload_state() - if offloading and regenerate: + offloading = is_resource_low() + if offloading and is_regenerate: + del cls._instance cls._instance = None - logger.info(f"Offloading state : {offloading}") + gc.collect() + torch.cuda.empty_cache() + logger.info(f"Low memory: {offloading}/ Model re-initialization: True") if cls._instance is None: cls._instance = super().__new__(cls) cls._instance.model, cls._instance.tokenizer = load_causal_lm_model( - model_name, cache_path=f"{job_path}/models/", device=device, off_load=offloading + model_name, + cache_path=f"{job_path}/models/", + device=device, + off_load=offloading, + re_generate=is_regenerate, ) model_embed_path = f"{job_path}/models/sentence_transformers" device = "cuda" if torch.cuda.is_available() else "cpu" if device == "auto" else device @@ -64,7 +72,7 @@ def __init__( sample_queries_path: str = "./samples.csv", job_path: str = "./", device: str = "cpu", - regenerate: bool = False + is_regenerate: bool = False, ): self.db_url = db_url self.engine = create_engine(db_url) @@ -80,6 +88,7 @@ def __init__( self.content_queries = None def clear(self): + del SQLGenerator._instance SQLGenerator._instance = None def load_column_samples(self, tables: list): @@ -301,17 +310,6 @@ def generate_sql( logger.info("We did the best we could, there might be still be some error:\n") logger.info(f"Realized query so far:\n {res}") else: - if self.model is None: - # Load h2oGPT.NSQL if not initialized self.model is None - # https://github.com/pytorch/pytorch/issues/52291 - offloading = offload_state() - if offloading: - self.clear() - logger.info(f"Offloading state: {offloading}") - self.model, self.tokenizer = load_causal_lm_model( - self.model_name, cache_path=f"{self.path}/models/", device="auto", off_load=offloading - ) - # TODO Update needed for multiple tables columns_w_type = ( self.context_builder.full_context_dict[table_names[0]].split(":")[2].split("and")[0].strip() @@ -439,12 +437,13 @@ def generate_sql( max_new_tokens=300, temperature=0.5, output_scores=True, + do_sample=True, return_dict_in_generate=True, ) generated_tokens = output.sequences[:, input_length:][0] else: - self.model.eval() + logger.info("Regeneration requested on previous query ...") random_seed = random.randint(0, 50) torch.manual_seed(random_seed) random_temperature = round(random.uniform(0.5, 0.75), 2) @@ -452,7 +451,7 @@ def generate_sql( **inputs.to(device_type), max_new_tokens=300, temperature=random_temperature, - top_k=10, + top_k=5, top_p=0.95, num_beams=5, num_beam_groups=5, @@ -489,7 +488,7 @@ def generate_sql( res = "SELECT " + result.strip() + " LIMIT 100;" else: res = "SELECT " + result.strip() + ";" - alt_res = f"Option {idx+1}: (_probability_: {probabilities_scores[idx]})\n{res}" + alt_res = f"Option {idx+1}: (_probability_: {probabilities_scores[idx]})\n{res}\n" 

 def load_causal_lm_model(
-    model_name: str, cache_path: str, device: str, load_in_8bit: bool = True, off_load: bool = False
+    model_name: str,
+    cache_path: str,
+    device: str,
+    load_in_8bit: bool = False,
+    load_in_4bit=True,
+    off_load: bool = False,
+    re_generate: bool = False,
 ):
     try:
         # Load h2oGPT.NSQL model
         device = {"": 0} if torch.cuda.is_available() else "cpu" if device == "auto" else device
-        free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
         total_memory = int(torch.cuda.get_device_properties(0).total_memory / 1024**3)
+        free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
+        logger.info(f"Free GPU memory: {free_in_GB}GB")
         n_gpus = torch.cuda.device_count()
+        _load_in_8bit = load_in_8bit

         # 22GB (the minimum GPU requirement) is a magic number for the current model size.
-        if off_load and total_memory < 22:
+        if off_load and re_generate and total_memory < 22:
+            # Prevent the system from crashing in case memory runs low.
             # TODO: Performance when offloading to CPU.
             max_memory = f"{4}GB"
             max_memory = {i: max_memory for i in range(n_gpus)}
-            logger.info(f"Max Memory: {max_memory}")
-            config = AutoConfig.from_pretrained(model_name, cache_dir=cache_path, offload_folder=cache_path)
-
-            model = AutoModelForCausalLM.from_config(config)
-            device = infer_auto_device_map(model, max_memory=max_memory)
-            device["lm_head"] = 0
+            logger.info(f"Max Memory: {max_memory}, offloading to CPU")
+            with init_empty_weights():
+                config = AutoConfig.from_pretrained(model_name, cache_dir=cache_path, offload_folder=cache_path)
+                # A blank model with the desired config.
+                model = AutoModelForCausalLM.from_config(config)
+                device = infer_auto_device_map(model, max_memory=max_memory)
+                device["lm_head"] = 0
             _offload_state_dict = True
             _llm_int8_enable_fp32_cpu_offload = True
+            _load_in_8bit = True
+            load_in_4bit = False
         else:
             max_memory = f"{int(free_in_GB)-2}GB"
             max_memory = {i: max_memory for i in range(n_gpus)}
             _offload_state_dict = False
             _llm_int8_enable_fp32_cpu_offload = False

-        if load_in_8bit:
+        if _load_in_8bit and _offload_state_dict and not load_in_4bit:
             _load_in_8bit = False if "cpu" in device else True
+            logger.debug(
+                f"Loading in 8 bit mode: {_load_in_8bit} with offloading state: {_llm_int8_enable_fp32_cpu_offload}"
+            )
+            model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                cache_dir=cache_path,
+                device_map=device,
+                load_in_8bit=_load_in_8bit,
+                llm_int8_enable_fp32_cpu_offload=_llm_int8_enable_fp32_cpu_offload,
+                offload_state_dict=_offload_state_dict,
+                max_memory=max_memory,
+                offload_folder=f"{cache_path}/weights/",
+            )
         else:
-            _load_in_8bit = False
-        logger.debug(f"Current device config: {device}")
+            logger.debug(f"Loading in 4 bit mode: {load_in_4bit} with device {device}")
+            nf4_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_compute_dtype=torch.bfloat16,
+            )
+
+            model = AutoModelForCausalLM.from_pretrained(
+                model_name, cache_dir=cache_path, device_map=device, quantization_config=nf4_config
+            )

         tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_path, device_map=device)
-        model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            cache_dir=cache_path,
-            device_map=device,
-            load_in_8bit=_load_in_8bit,
-            llm_int8_enable_fp32_cpu_offload=_offload_state_dict,
-            offload_state_dict=_llm_int8_enable_fp32_cpu_offload,
-            max_memory=max_memory,
-            offload_folder=f"{cache_path}/weights/",
-        )
+
         return model, tokenizer
     except Exception as e:
         logger.info(f"An error occurred while loading the model: {e}")
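This is the patch's headline change: when the CPU-offload path is not taken, the model now loads as 4-bit NF4 with double quantization and bfloat16 compute. The same configuration in isolation (the checkpoint id is a placeholder, not pinned by this diff):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,                      # 4-bit weights instead of 8-bit
        bnb_4bit_quant_type="nf4",              # normal-float-4 quantization
        bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
        bnb_4bit_compute_dtype=torch.bfloat16,  # matmuls run in bfloat16
    )

    model_id = "NumbersStation/nsql-6B"  # placeholder checkpoint
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="auto", quantization_config=nf4_config
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

NF4 roughly halves weight memory versus 8-bit at a small accuracy cost, which is what makes 4-bit a workable default on mid-size GPUs.
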
diff --git a/ui/app.py b/ui/app.py
index 24cbac5..f62867e 100644
--- a/ui/app.py
+++ b/ui/app.py
@@ -1,3 +1,4 @@
+import gc
 import json
 import logging
 from pathlib import Path
@@ -5,7 +6,9 @@
 import openai
 import toml
+import torch
 from h2o_wave import Q, app, data, handle_on, main, on, ui
+from h2o_wave.core import expando_to_dict
 from sidekick.prompter import db_setup_api, query_api
 from sidekick.utils import get_table_keys, setup_dir, update_tables
@@ -97,6 +100,7 @@ async def chat(q: Q):
             box=ui.box("vertical", height="500px"),
             name="chatbot",
             data=data(fields="content from_user", t="list", size=-50),
+            commands=[ui.command(name="regenerate_event", icon="RepeatAll", caption="Regenerate", label="Regenerate")],
         ),
     )
@@ -117,46 +121,53 @@ async def chatbot(q: Q):
     question = f"{q.args.chatbot}"
     logging.info(f"Question: {question}")

-    if q.args.chatbot.lower() == "db setup":
-        llm_response, err = db_setup_api(
-            db_name=q.user.db_name,
-            hostname=q.user.host_name,
-            user_name=q.user.user_name,
-            password=q.user.password,
-            port=q.user.port,
-            table_info_path=q.user.table_info_path,
-            table_samples_path=q.user.table_samples_path,
-            table_name=q.user.table_name,
-        )
-    elif q.args.chatbot.lower() == "regenerate":
-        if q.client.query is not None and q.client.query.strip() != "":
-            llm_response, alt_response, err = query_api(
-                question=q.client.query,
-                sample_queries_path=q.user.sample_qna_path,
-                table_info_path=q.user.table_info_path,
-                table_name=q.user.table_name,
-                is_regenerate=True,
-            )
-            response = "\n".join(llm_response)
-            if alt_response:
-                llm_response = response + "\n\n" + "**Alternate options:**\n" + "\n".join(alt_response)
-                logging.info(f"Regenerate response: {llm_response}")
-            else:
-                llm_response = response
-        else:
-            llm_response, err = (
-                "Sure, I can generate a new response for you. However, in order to assist you "
-                "effectively could you please provide me with your question?"
-            ), None
-    else:
-        q.client.query = question
-        llm_response, alt_response, err = query_api(
-            question=question,
-            sample_queries_path=q.user.sample_qna_path,
-            table_info_path=q.user.table_info_path,
-            table_name=q.user.table_name,
-        )
-        llm_response = "\n".join(llm_response)
+    try:
+        if q.args.chatbot.lower() == "db setup":
+            llm_response, err = db_setup_api(
+                db_name=q.user.db_name,
+                hostname=q.user.host_name,
+                user_name=q.user.user_name,
+                password=q.user.password,
+                port=q.user.port,
+                table_info_path=q.user.table_info_path,
+                table_samples_path=q.user.table_samples_path,
+                table_name=q.user.table_name,
+            )
+        elif q.args.chatbot.lower() == "regenerate" or q.args.regenerate_event:
+            # Attempts to regenerate a response for the last supplied query
+            if q.client.query is not None and q.client.query.strip() != "":
+                llm_response, alt_response, err = query_api(
+                    question=q.client.query,
+                    sample_queries_path=q.user.sample_qna_path,
+                    table_info_path=q.user.table_info_path,
+                    table_name=q.user.table_name,
+                    is_regenerate=True,
+                )
+                response = "\n".join(llm_response)
+                if alt_response:
+                    llm_response = response + "\n\n" + "**Alternate options:**\n" + "\n".join(alt_response)
+                    logging.info(f"Regenerate response: {llm_response}")
+                else:
+                    llm_response = response
+            else:
+                llm_response, err = (
+                    "Sure, I can generate a new response for you. However, in order to assist you "
+                    "effectively, could you please provide me with your question?"
+                ), None
+        else:
+            q.client.query = question
+            llm_response, alt_response, err = query_api(
+                question=question,
+                sample_queries_path=q.user.sample_qna_path,
+                table_info_path=q.user.table_info_path,
+                table_name=q.user.table_name,
+            )
+            llm_response = "\n".join(llm_response)
+    except (MemoryError, RuntimeError) as e:
+        logging.error(f"Something went wrong while generating response: {e}")
+        gc.collect()
+        torch.cuda.empty_cache()
+        llm_response = "Something went wrong, try executing the query again!"

     q.page["chat_card"].data += [llm_response, False]
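For context, the Regenerate entry added to the chatbot card's menu is what raises q.args.regenerate_event on click. A trimmed sketch of the wiring, reusing the argument names from the hunk above (the exact Wave card constructor may vary by version):

    from h2o_wave import data, ui

    chat_card = ui.chatbot_card(
        box=ui.box("vertical", height="500px"),
        name="chatbot",
        data=data(fields="content from_user", t="list", size=-50),
        commands=[ui.command(name="regenerate_event", icon="RepeatAll", label="Regenerate")],
    )
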
@@ -393,7 +404,22 @@ async def init(q: Q) -> None:


 def on_shutdown():
-    logging.debug("App stopped. Goodbye!")
+    logging.info("App stopped. Goodbye!")
+
+
+async def on_event(q: Q):
+    event_handled = False
+    logging.info("Event handled ...")
+    args_dict = expando_to_dict(q.args)
+    logging.debug(f"Args dict {args_dict}")
+    if q.args.regenerate_event:
+        q.args.chatbot = "regenerate"
+        await chatbot(q)
+        event_handled = True
+    else:  # default chatbot event
+        await handle_on(q)
+        event_handled = True
+    return event_handled


 @app("/", on_shutdown=on_shutdown)
@@ -406,5 +432,6 @@ async def serve(q: Q):
         q.client.initialized = True

     # Handle routing.
-    await handle_on(q)
-    await q.page.save()
+    if await on_event(q):
+        await q.page.save()
+    return