diff --git a/db_manipulator.py b/db_manipulator.py index 00fb79c..49a0048 100755 --- a/db_manipulator.py +++ b/db_manipulator.py @@ -127,6 +127,26 @@ def insertmany(conn, table, columns, parameters): conn.commit() +def create_table(conn, schema, table, createQuery): + schema_table, table, _ = check_schema_table(conn, schema, table) + cursor = conn.cursor() + try: + cursor.execute("DROP TABLE IF EXISTS " + schema_table) + cursor.execute(createQuery % (schema_table, table)) + except Exception: + conn.rollback() + schema_table += "_fallback" + table += "_fallback" + cursor.execute("DROP TABLE IF EXISTS " + schema_table) + cursor.execute(createQuery % (schema_table, table)) + if not isinstance(conn, sqlite3.Connection): + cursor.execute("SELECT 1 FROM pg_roles WHERE rolname='tapas_admin_group'") + if len(cursor.fetchall()) > 0: + cursor.execute("GRANT ALL PRIVILEGES ON TABLE %s TO tapas_admin_group" % schema_table) + conn.commit() + return schema_table + + def run_sql(conn, sql): cursor = conn.cursor() command = "" diff --git a/get_trips.py b/get_trips.py index 0420192..13a59c8 100755 --- a/get_trips.py +++ b/get_trips.py @@ -196,11 +196,12 @@ def write_all_pairs(conn, vType, depart, limit, tripfile, params, seed, mode=MOD num_samples = 5 reps = collections.defaultdict(list) cursor = conn.cursor() - command = """SELECT taz_num_id, id, st_X(representative_coordinate) AS X, st_Y(representative_coordinate) AS Y - FROM core.%s r, core.%s t WHERE r.taz_id = t.taz_id""" % ( - params[SP.representatives], params[SP.taz_table]) + x, y = "st_X(representative_coordinate)", "st_Y(representative_coordinate)" + command = "SELECT taz_num_id, id, %s, %s FROM core.%s r, core.%s t WHERE r.taz_id = t.taz_id" % ( + x, y, params[SP.representatives], params[SP.taz_table]) if bbox: - command += " AND X > %s AND Y > %s AND X < %s AND Y < %s" % tuple(bbox.split(",")) + b = bbox.split(",") + command += " AND %s > %s AND %s > %s AND %s < %s AND %s < %s" % (x, b[0], y, b[1], x, b[2], y, b[3]) cursor.execute(command + " ORDER BY taz_num_id, id") for row in cursor: reps[row[0]].append(row[1:]) diff --git a/s2t_miv.py b/s2t_miv.py index 92ae2bb..e4820d8 100755 --- a/s2t_miv.py +++ b/s2t_miv.py @@ -131,11 +131,8 @@ def upload_trip_results(conn, key, params, routes, limit=None): print("Warning! No database connection, writing trip info to file %s.csv." % table) print('\n'.join(map(str, tripstats[:limit])), file=open(table + ".csv", "w")) return - schema_table, table, exists = db_manipulator.check_schema_table(conn, 'temp', table) - cursor = conn.cursor() - if exists: - cursor.execute("DROP TABLE " + schema_table) - createQuery = """ + if tripstats: + createQuery = """ CREATE TABLE %s ( p_id integer NOT NULL, @@ -146,13 +143,10 @@ def upload_trip_results(conn, key, params, routes, limit=None): distance_real double precision[], CONSTRAINT %s_pkey PRIMARY KEY (p_id, hh_id, start_time_min, clone_id) ) -""" % (schema_table, table) - cursor.execute(createQuery) - if tripstats: - values = [str(tuple([str(e).replace("(", "{").replace(")", "}") for e in t])) for t in tripstats[:limit]] - insertQuery = "INSERT INTO %s (p_id, hh_id, start_time_min, clone_id, travel_time_sec, distance_real) VALUES " - cursor.execute(insertQuery % schema_table + ','.join(values)) - conn.commit() +""" + schema_table = db_manipulator.create_table(conn, 'temp', table, createQuery) + values = [tuple([str(e).replace("(", "{").replace(")", "}") for e in t]) for t in tripstats[:limit]] + db_manipulator.insertmany(conn, schema_table, "p_id, hh_id, start_time_min, clone_id, travel_time_sec, distance_real", values) @benchmark @@ -223,8 +217,6 @@ def upload_all_pairs(conn, tables, start, end, vType, real_routes, rep_routes, n def create_all_pairs(conn, key, params): - cursor = conn.cursor() - schema_table, table, exists = db_manipulator.check_schema_table(conn, 'temp', '%s_%s' % (params[SP.od_output], key)) createQuery = """ CREATE TABLE %s ( @@ -238,20 +230,7 @@ def create_all_pairs(conn, key, params): CONSTRAINT %s_pkey PRIMARY KEY (taz_id_start, taz_id_end, sumo_type, is_restricted, interval_end) ) """ - try: - cursor.execute("DROP TABLE IF EXISTS " + schema_table) - cursor.execute(createQuery % (schema_table, table)) - except Exception: - conn.rollback() - schema_table += "_fallback" - table += "_fallback" - cursor.execute("DROP TABLE IF EXISTS " + schema_table) - cursor.execute(createQuery % (schema_table, table)) - conn.commit() - - entry_schema_table, table, exists = db_manipulator.check_schema_table(conn, 'temp', '%s_%s' % (params[SP.od_entry], key)) - if exists: - cursor.execute("DROP TABLE " + schema_table) + schema_table = db_manipulator.create_table(conn, 'temp', '%s_%s' % (params[SP.od_output], key), createQuery) createQuery = """ CREATE TABLE %s ( @@ -266,15 +245,7 @@ def create_all_pairs(conn, key, params): CONSTRAINT %s_pkey PRIMARY KEY (entry_id, used_modes) ) """ - try: - cursor.execute(createQuery % (entry_schema_table, table)) - except Exception: - conn.rollback() - entry_schema_table += "_fallback" - table += "_fallback" - cursor.execute("DROP TABLE IF EXISTS " + entry_schema_table) - cursor.execute(createQuery % (entry_schema_table, table)) - conn.commit() + entry_schema_table = db_manipulator.create_table(conn, 'temp', '%s_%s' % (params[SP.od_entry], key), createQuery) return schema_table, entry_schema_table