Skip to content

Commit

Permalink
Support 4bit quantized model format as default (#28)
Browse files Browse the repository at this point in the history
* Change default loading in 4bit quantized format

* Handle edge case failures better

* Add trigger for regeneration

* Upgrade version to 0.0.6
  • Loading branch information
pramitchoudhary authored Aug 30, 2023
1 parent 80cb1f4 commit 427ad8b
Show file tree
Hide file tree
Showing 6 changed files with 284 additions and 216 deletions.
2 changes: 1 addition & 1 deletion app.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ title = "SQL-Sidekick"
description = "QnA with tabular data using NLI"
LongDescription = "about.md"
Tags = ["DATA_SCIENCE", "MACHINE_LEARNING", "NLP"]
Version = "0.0.5"
Version = "0.0.6"

[Runtime]
MemoryLimit = "64Gi"
Expand Down
3 changes: 2 additions & 1 deletion sidekick/configs/env.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[MODEL_INFO]
OPENAI_API_KEY = "" # Needed only for openAI models
MODEL_NAME = "h2ogpt-sql" # Others: e.g. gpt-4, gpt-4-32k, text-davinci-003
QUANT_TYPE = '4bit'

[LOCAL_DB_CONFIG]
HOST_NAME = "localhost"
Expand All @@ -10,7 +11,7 @@ DB_NAME = "querydb"
PORT = "5432"

[LOGGING]
LOG-LEVEL = "INFO"
LOG-LEVEL = "DEBUG"

[DB-DIALECT]
DB_TYPE = "sqlite"
Expand Down
274 changes: 144 additions & 130 deletions sidekick/prompter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import gc
import json
import os
from pathlib import Path

import click
import openai
import toml
import torch
from colorama import Back as B
from colorama import Fore as F
from colorama import Style
Expand Down Expand Up @@ -156,8 +158,7 @@ def db_setup_api(
"""Creates context for the new Database"""
click.echo(f" Information supplied:\n {db_name}, {hostname}, {user_name}, {password}, {port}")
try:
res = None
err = None
res = err = None
env_settings["LOCAL_DB_CONFIG"]["HOST_NAME"] = hostname
env_settings["LOCAL_DB_CONFIG"]["USER_NAME"] = user_name
env_settings["LOCAL_DB_CONFIG"]["PASSWORD"] = password
Expand Down Expand Up @@ -373,140 +374,153 @@ def query_api(
f.close()
openai.api_key = api_key

# Set context
logger.info("Setting context...")
logger.info(f"Question: {question}")
# Get updated info from env.toml
host_name = env_settings["LOCAL_DB_CONFIG"]["HOST_NAME"]
user_name = env_settings["LOCAL_DB_CONFIG"]["USER_NAME"]
passwd = env_settings["LOCAL_DB_CONFIG"]["PASSWORD"]
db_name = env_settings["LOCAL_DB_CONFIG"]["DB_NAME"]

if db_dialect == "sqlite":
db_url = f"sqlite:///{base_path}/db/sqlite/{db_name}.db"
else:
db_url = f"{db_dialect}+psycopg2://{user_name}:{passwd}@{host_name}/{db_name}".format(
user_name, passwd, host_name, db_name
)

if table_info_path is None:
table_info_path = _get_table_info(path)
try:
# Set context
logger.info("Setting context...")
logger.info(f"Question: {question}")
# Get updated info from env.toml
host_name = env_settings["LOCAL_DB_CONFIG"]["HOST_NAME"]
user_name = env_settings["LOCAL_DB_CONFIG"]["USER_NAME"]
passwd = env_settings["LOCAL_DB_CONFIG"]["PASSWORD"]
db_name = env_settings["LOCAL_DB_CONFIG"]["DB_NAME"]

if db_dialect == "sqlite":
db_url = f"sqlite:///{base_path}/db/sqlite/{db_name}.db"
else:
db_url = f"{db_dialect}+psycopg2://{user_name}:{passwd}@{host_name}/{db_name}".format(
user_name, passwd, host_name, db_name
)

sql_g = SQLGenerator(
db_url,
api_key,
job_path=base_path,
data_input_path=table_info_path,
sample_queries_path=sample_queries_path,
regenerate=is_regenerate,
)
if "h2ogpt-sql" not in model_name:
sql_g._tasks = sql_g.generate_tasks(table_names, question)
results.extend(["List of Actions Generated: \n", sql_g._tasks, "\n"])
click.echo(sql_g._tasks)

updated_tasks = None
if sql_g._tasks is not None and is_command:
edit_val = click.prompt("Would you like to edit the tasks? (y/n)")
if edit_val.lower() == "y":
updated_tasks = click.edit(sql_g._tasks)
click.echo(f"Tasks:\n {updated_tasks}")
else:
click.echo("Skipping edit...")
if updated_tasks is not None:
sql_g._tasks = updated_tasks
if table_info_path is None:
table_info_path = _get_table_info(path)

res, alt_res = sql_g.generate_sql(
table_names, question, model_name=model_name, _dialect=db_dialect, is_regenerate=is_regenerate
)
logger.info(f"Input query: {question}")
logger.info(f"Generated response:\n\n{res}")

if res is not None:
updated_sql = None
res_val = "e"
if is_command:
while res_val.lower() in ["e", "edit", "r", "regenerate"]:
res_val = click.prompt(
"Would you like to 'edit' or 'regenerate' the SQL? Use 'e' to edit or 'r' to regenerate. "
"To skip, enter 's' or 'skip'"
)
if res_val.lower() == "e" or res_val.lower() == "edit":
updated_sql = click.edit(res)
click.echo(f"Updated SQL:\n {updated_sql}")
elif res_val.lower() == "r" or res_val.lower() == "regenerate":
click.echo("Attempting to regenerate...")
res, alt_res = sql_g.generate_sql(
table_names, question, model_name=model_name, _dialect=db_dialect, is_regenerate=is_regenerate
sql_g = SQLGenerator(
db_url,
api_key,
job_path=base_path,
data_input_path=table_info_path,
sample_queries_path=sample_queries_path,
is_regenerate = is_regenerate
)
if "h2ogpt-sql" not in model_name:
sql_g._tasks = sql_g.generate_tasks(table_names, question)
results.extend(["List of Actions Generated: \n", sql_g._tasks, "\n"])
click.echo(sql_g._tasks)

updated_tasks = None
if sql_g._tasks is not None and is_command:
edit_val = click.prompt("Would you like to edit the tasks? (y/n)")
if edit_val.lower() == "y":
updated_tasks = click.edit(sql_g._tasks)
click.echo(f"Tasks:\n {updated_tasks}")
else:
click.echo("Skipping edit...")
if updated_tasks is not None:
sql_g._tasks = updated_tasks
alt_res = None
res, alt_res = sql_g.generate_sql(
table_names, question, model_name=model_name, _dialect=db_dialect, is_regenerate=is_regenerate
)
logger.info(f"Input query: {question}")
logger.info(f"Generated response:\n\n{res}")

if res is not None:
updated_sql = None
res_val = "e"
if is_command:
while res_val.lower() in ["e", "edit", "r", "regenerate"]:
res_val = click.prompt(
"Would you like to 'edit' or 'regenerate' the SQL? Use 'e' to edit or 'r' to regenerate. "
"To skip, enter 's' or 'skip'"
)
if res_val.lower() == "e" or res_val.lower() == "edit":
updated_sql = click.edit(res)
click.echo(f"Updated SQL:\n {updated_sql}")
elif res_val.lower() == "r" or res_val.lower() == "regenerate":
click.echo("Attempting to regenerate...")
res, alt_res = sql_g.generate_sql(
table_names,
question,
model_name=model_name,
_dialect=db_dialect,
is_regenerate=is_regenerate,
)
logger.info(f"Input query: {question}")
logger.info(f"Generated response:\n\n{res}")

results.extend(["**Generated Query:**\n", res, "\n"])
logger.info(f"Alternate responses:\n\n{alt_res}")

exe_sql = click.prompt("Would you like to execute the generated SQL (y/n)?") if is_command else "y"
if exe_sql.lower() == "y" or exe_sql.lower() == "yes":
# For the time being, the default option is Pandas, but the user can be asked to select Database or pandas DF later.
q_res = None
option = "DB" # or DB
_val = updated_sql if updated_sql else res
if option == "DB":
hostname = env_settings["LOCAL_DB_CONFIG"]["HOST_NAME"]
user_name = env_settings["LOCAL_DB_CONFIG"]["USER_NAME"]
password = env_settings["LOCAL_DB_CONFIG"]["PASSWORD"]
port = env_settings["LOCAL_DB_CONFIG"]["PORT"]
db_name = env_settings["LOCAL_DB_CONFIG"]["DB_NAME"]

db_obj = DBConfig(
db_name, hostname, user_name, password, port, base_path=base_path, dialect=db_dialect
)
logger.info(f"Input query: {question}")
logger.info(f"Generated response:\n\n{res}")

results.extend(["**Generated Query:**\n", res, "\n"])
logger.info(f"Alternate responses:\n\n{alt_res}")

exe_sql = click.prompt("Would you like to execute the generated SQL (y/n)?") if is_command else "y"
if exe_sql.lower() == "y" or exe_sql.lower() == "yes":
# For the time being, the default option is Pandas, but the user can be asked to select Database or pandas DF later.
q_res = None
option = "DB" # or DB
_val = updated_sql if updated_sql else res
if option == "DB":
hostname = env_settings["LOCAL_DB_CONFIG"]["HOST_NAME"]
user_name = env_settings["LOCAL_DB_CONFIG"]["USER_NAME"]
password = env_settings["LOCAL_DB_CONFIG"]["PASSWORD"]
port = env_settings["LOCAL_DB_CONFIG"]["PORT"]
db_name = env_settings["LOCAL_DB_CONFIG"]["DB_NAME"]

db_obj = DBConfig(db_name, hostname, user_name, password, port, base_path=base_path, dialect=db_dialect)

q_res, err = db_obj.execute_query_db(query=_val)

elif option == "pandas":
tables = extract_table_names(_val)
tables_path = dict()
if Path(f"{path}/table_context.json").exists():
f = open(f"{path}/table_context.json", "r")
table_metadata = json.load(f)
for table in tables:
# Check if the local table_path exists in the cache
if table not in table_metadata["data_table_map"].keys():
val = enter_file_path(table)
if not os.path.isfile(val):
click.echo("In-correct Path. Please enter again! Yes(y) or no(n)")

q_res, err = db_obj.execute_query_db(query=_val)

elif option == "pandas":
tables = extract_table_names(_val)
tables_path = dict()
if Path(f"{path}/table_context.json").exists():
f = open(f"{path}/table_context.json", "r")
table_metadata = json.load(f)
for table in tables:
# Check if the local table_path exists in the cache
if table not in table_metadata["data_table_map"].keys():
val = enter_file_path(table)
if not os.path.isfile(val):
click.echo("In-correct Path. Please enter again! Yes(y) or no(n)")
else:
tables_path[table] = val
table_metadata["data_table_map"][table] = val
break
else:
tables_path[table] = val
table_metadata["data_table_map"][table] = val
break
else:
tables_path[table] = table_metadata["data_table_map"][table]
assert len(tables) == len(tables_path)
with open(f"{path}/table_context.json", "w") as outfile:
json.dump(table_metadata, outfile, indent=4, sort_keys=False)
try:
q_res = execute_query_pd(query=_val, tables_path=tables_path, n_rows=100)
tables_path[table] = table_metadata["data_table_map"][table]
assert len(tables) == len(tables_path)
with open(f"{path}/table_context.json", "w") as outfile:
json.dump(table_metadata, outfile, indent=4, sort_keys=False)
try:
q_res = execute_query_pd(query=_val, tables_path=tables_path, n_rows=100)
click.echo(f"The query results are:\n {q_res}")
except sqldf.PandaSQLException as e:
logger.error(f"Error in executing the query: {e}")
click.echo("Error in executing the query. Validate generated SQL and try again.")
click.echo("No result to display.")

results.append("**Query Results:** \n")
if q_res:
click.echo(f"The query results are:\n {q_res}")
except sqldf.PandaSQLException as e:
logger.error(f"Error in executing the query: {e}")
click.echo("Error in executing the query. Validate generated SQL and try again.")
click.echo("No result to display.")

results.append("**Query Results:** \n")
if q_res:
click.echo(f"The query results are:\n {q_res}")
results.extend([str(q_res), "\n"])
results.extend([str(q_res), "\n"])
else:
click.echo(f"While executing query:\n {err}")
results.extend([str(err), "\n"])

save_sql = click.prompt("Would you like to save the generated SQL (y/n)?") if is_command else "n"
if save_sql.lower() == "y" or save_sql.lower() == "yes":
# Persist for future use
_val = updated_sql if updated_sql else res
save_query(base_path, query=question, response=_val)
else:
click.echo(f"While executing query:\n {err}")
results.extend([str(err), "\n"])

save_sql = click.prompt("Would you like to save the generated SQL (y/n)?") if is_command else "n"
if save_sql.lower() == "y" or save_sql.lower() == "yes":
# Persist for future use
_val = updated_sql if updated_sql else res
save_query(base_path, query=question, response=_val)
else:
click.echo("Exiting...")

click.echo("Exiting...")
except (MemoryError, RuntimeError) as e:
logger.error(f"Something went wrong while generating response: {e}")
del sql_g
gc.collect()
torch.cuda.empty_cache()
alt_res, err = None, None
results = ["Something went wrong while generating response. Please try again."]
return results, alt_res, err


Expand Down
Loading

0 comments on commit 427ad8b

Please sign in to comment.