Skip to content

Commit

Permalink
refactoring table creation #42 #39 fix #35
Browse files Browse the repository at this point in the history
  • Loading branch information
behrisch committed Jan 9, 2023
1 parent b4e72b8 commit fb63553
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 41 deletions.
20 changes: 20 additions & 0 deletions db_manipulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,26 @@ def insertmany(conn, table, columns, parameters):
conn.commit()


def create_table(conn, schema, table, createQuery):
schema_table, table, _ = check_schema_table(conn, schema, table)
cursor = conn.cursor()
try:
cursor.execute("DROP TABLE IF EXISTS " + schema_table)
cursor.execute(createQuery % (schema_table, table))
except Exception:
conn.rollback()
schema_table += "_fallback"
table += "_fallback"
cursor.execute("DROP TABLE IF EXISTS " + schema_table)
cursor.execute(createQuery % (schema_table, table))
if not isinstance(conn, sqlite3.Connection):
cursor.execute("SELECT 1 FROM pg_roles WHERE rolname='tapas_admin_group'")
if len(cursor.fetchall()) > 0:
cursor.execute("GRANT ALL PRIVILEGES ON TABLE %s TO tapas_admin_group" % schema_table)
conn.commit()
return schema_table


def run_sql(conn, sql):
cursor = conn.cursor()
command = ""
Expand Down
9 changes: 5 additions & 4 deletions get_trips.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,12 @@ def write_all_pairs(conn, vType, depart, limit, tripfile, params, seed, mode=MOD
num_samples = 5
reps = collections.defaultdict(list)
cursor = conn.cursor()
command = """SELECT taz_num_id, id, st_X(representative_coordinate) AS X, st_Y(representative_coordinate) AS Y
FROM core.%s r, core.%s t WHERE r.taz_id = t.taz_id""" % (
params[SP.representatives], params[SP.taz_table])
x, y = "st_X(representative_coordinate)", "st_Y(representative_coordinate)"
command = "SELECT taz_num_id, id, %s, %s FROM core.%s r, core.%s t WHERE r.taz_id = t.taz_id" % (
x, y, params[SP.representatives], params[SP.taz_table])
if bbox:
command += " AND X > %s AND Y > %s AND X < %s AND Y < %s" % tuple(bbox.split(","))
b = bbox.split(",")
command += " AND %s > %s AND %s > %s AND %s < %s AND %s < %s" % (x, b[0], y, b[1], x, b[2], y, b[3])
cursor.execute(command + " ORDER BY taz_num_id, id")
for row in cursor:
reps[row[0]].append(row[1:])
Expand Down
45 changes: 8 additions & 37 deletions s2t_miv.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,8 @@ def upload_trip_results(conn, key, params, routes, limit=None):
print("Warning! No database connection, writing trip info to file %s.csv." % table)
print('\n'.join(map(str, tripstats[:limit])), file=open(table + ".csv", "w"))
return
schema_table, table, exists = db_manipulator.check_schema_table(conn, 'temp', table)
cursor = conn.cursor()
if exists:
cursor.execute("DROP TABLE " + schema_table)
createQuery = """
if tripstats:
createQuery = """
CREATE TABLE %s
(
p_id integer NOT NULL,
Expand All @@ -146,13 +143,10 @@ def upload_trip_results(conn, key, params, routes, limit=None):
distance_real double precision[],
CONSTRAINT %s_pkey PRIMARY KEY (p_id, hh_id, start_time_min, clone_id)
)
""" % (schema_table, table)
cursor.execute(createQuery)
if tripstats:
values = [str(tuple([str(e).replace("(", "{").replace(")", "}") for e in t])) for t in tripstats[:limit]]
insertQuery = "INSERT INTO %s (p_id, hh_id, start_time_min, clone_id, travel_time_sec, distance_real) VALUES "
cursor.execute(insertQuery % schema_table + ','.join(values))
conn.commit()
"""
schema_table = db_manipulator.create_table(conn, 'temp', table, createQuery)
values = [tuple([str(e).replace("(", "{").replace(")", "}") for e in t]) for t in tripstats[:limit]]
db_manipulator.insertmany(conn, schema_table, "p_id, hh_id, start_time_min, clone_id, travel_time_sec, distance_real", values)


@benchmark
Expand Down Expand Up @@ -223,8 +217,6 @@ def upload_all_pairs(conn, tables, start, end, vType, real_routes, rep_routes, n


def create_all_pairs(conn, key, params):
cursor = conn.cursor()
schema_table, table, exists = db_manipulator.check_schema_table(conn, 'temp', '%s_%s' % (params[SP.od_output], key))
createQuery = """
CREATE TABLE %s
(
Expand All @@ -238,20 +230,7 @@ def create_all_pairs(conn, key, params):
CONSTRAINT %s_pkey PRIMARY KEY (taz_id_start, taz_id_end, sumo_type, is_restricted, interval_end)
)
"""
try:
cursor.execute("DROP TABLE IF EXISTS " + schema_table)
cursor.execute(createQuery % (schema_table, table))
except Exception:
conn.rollback()
schema_table += "_fallback"
table += "_fallback"
cursor.execute("DROP TABLE IF EXISTS " + schema_table)
cursor.execute(createQuery % (schema_table, table))
conn.commit()

entry_schema_table, table, exists = db_manipulator.check_schema_table(conn, 'temp', '%s_%s' % (params[SP.od_entry], key))
if exists:
cursor.execute("DROP TABLE " + schema_table)
schema_table = db_manipulator.create_table(conn, 'temp', '%s_%s' % (params[SP.od_output], key), createQuery)
createQuery = """
CREATE TABLE %s
(
Expand All @@ -266,15 +245,7 @@ def create_all_pairs(conn, key, params):
CONSTRAINT %s_pkey PRIMARY KEY (entry_id, used_modes)
)
"""
try:
cursor.execute(createQuery % (entry_schema_table, table))
except Exception:
conn.rollback()
entry_schema_table += "_fallback"
table += "_fallback"
cursor.execute("DROP TABLE IF EXISTS " + entry_schema_table)
cursor.execute(createQuery % (entry_schema_table, table))
conn.commit()
entry_schema_table = db_manipulator.create_table(conn, 'temp', '%s_%s' % (params[SP.od_entry], key), createQuery)
return schema_table, entry_schema_table


Expand Down

0 comments on commit fb63553

Please sign in to comment.