首页 > 数据库技术 > 详细

beijing_taxi_2012 to pgdb

时间:2021-05-21 22:04:59      阅读:21      评论:0      收藏:0      [点我收藏+]
import psycopg2
import pandas as pd
from io import StringIO
import csv
from itertools import islice
import os
import numpy as np
from datetime import datetime

def if_contain_symbol(keyword):
    """Return True if *keyword* contains any character from a fixed symbol set.

    The caller uses this as a garbled-text check on ``unit_id`` values read
    from the raw taxi files.

    Args:
        keyword: the string to inspect.

    Returns:
        True if at least one symbol character occurs in *keyword*, otherwise
        False (including for an empty string).
    """
    # The trailing "\\/" is a backslash followed by a slash. The previous
    # spelling "\/" produced the same two characters only by way of an
    # invalid escape sequence (SyntaxWarning on Python 3.12+).
    symbols = "?>>~!@#$%^&*()_+-*/<>,.[]\\/"
    return any(symbol in keyword for symbol in symbols)

def table_exist(table_name=None, conn=None, cur=None):
    """Check whether *table_name* resolves to an existing relation.

    Runs ``select to_regclass(<name>) is not null`` on *cur*.

    Args:
        table_name: table name (optionally schema-qualified) to look up.
        conn: connection that owns *cur*; only used to roll back on failure.
            It is never closed here, because the caller keeps using it.
        cur: DB-API cursor used to run the query.

    Returns:
        The boolean reported by the database (True if the relation exists),
        or None if the query failed.
    """
    try:
        # Bind the name as a parameter instead of splicing it into the SQL
        # text, so the driver handles quoting.
        cur.execute("select to_regclass(%s) is not null", (table_name,))
        row = cur.fetchone()
    except Exception as e:
        # Report the failure and clear the aborted transaction so later
        # statements on the caller's connection can still run.
        print("table_exist failed: {0}".format(e))
        conn.rollback()
        return None
    return row[0] if row else None

def _reformat_ts(raw):
    """Convert a ``YYYYmmddHHMMSS`` string to ``YYYY-mm-dd HH:MM:SS``."""
    return datetime.strptime(raw, "%Y%m%d%H%M%S").strftime("%Y-%m-%d %H:%M:%S")


def _parse_trj_row(row, ts_m):
    """Return the 11 output fields for one raw row, or None if unit_id looks garbled."""
    row = list(row)  # work on a copy; the cleanup below edits fields in place
    row[1] = row[1].split("$")[-1]   # code_company: keep the text after the last "$"
    row[-1] = row[-1].split("#")[0]  # last field: keep the text before the first "#"
    if if_contain_symbol(row[2]):    # garbled-text check on unit_id
        return None
    # Raw columns 6 and 7 are not exported.
    return [row[0], row[1], row[2], ts_m, _reformat_ts(row[3]),
            row[4], row[5], row[8], row[9], row[10], row[11]]


def txt2csv(txt_dir, csv_save_dir):
    """Convert raw taxi GPS TXT files into headerless CSV files.

    Each regular file in *txt_dir* is parsed as comma-separated text, and
    its first line is skipped. The file name must end in
    ``_<YYYYmmddHHMMSS>.<ext>``. That timestamp, reformatted as
    ``YYYY-mm-dd HH:MM:SS``, becomes the ``ts_m`` column of every output row.
    Rows whose ``unit_id`` field fails ``if_contain_symbol`` are dropped.

    Output columns: serial_number, code_company, unit_id, ts_m, ts_s, lon,
    lat, speed, direction, state, event.

    Args:
        txt_dir: directory containing the raw TXT files.
        csv_save_dir: directory that receives the converted files; it is
            created if missing.
    """
    columns = ["serial_number", "code_company", "unit_id", "ts_m", "ts_s",
               "lon", "lat", "speed", "direction", "state", "event"]
    os.makedirs(csv_save_dir, exist_ok=True)
    for txt in os.listdir(txt_dir):
        txt_path = os.path.join(txt_dir, txt)
        if not os.path.isfile(txt_path):
            continue
        # File-level timestamp, e.g. "..._20121024100000.txt" -> "2012-10-24 10:00:00".
        ts_m = _reformat_ts(txt.split("_")[-1].split(".")[0])
        # NOTE(review): the output keeps a ".txt" extension even though it is
        # CSV; filterByinteral only parses the timestamp, so both work.
        csv_file = os.path.join(csv_save_dir, txt.split(".")[0] + ".txt")
        trjs = []
        # newline="" is what the csv module expects for file objects.
        with open(txt_path, "r", newline="") as read_file:
            reader = csv.reader(read_file)
            for row in islice(reader, 1, None):  # skip the first line
                record = _parse_trj_row(row, ts_m)
                if record is not None:
                    trjs.append(record)
        df = pd.DataFrame(trjs, columns=columns)
        df.to_csv(csv_file, header=False, index=False, encoding="utf-8")
    print("txt2csv, down!!!")

def filterByinteral(time_interal, csv_save_dir):
    """Return names of files in *csv_save_dir* whose timestamp lies in an interval.

    Each file name must end in ``_<YYYYmmddHHMMSS>.<ext>`` (the naming
    txt2csv writes). That timestamp is compared numerically against the
    inclusive interval ``[start, end]``.

    Args:
        time_interal: two-item sequence ``[start, end]`` of
            ``"YYYY-mm-dd HH:MM:SS"`` strings. It is not modified.
        csv_save_dir: directory to scan.

    Returns:
        List of matching file names (not full paths), in os.listdir order.
    """
    def to_number(text):
        # "2012-10-24 10:00:00" -> 20121024100000.0. strptime also zero-pads
        # components such as "2012-1-2 3:4:5" instead of gluing them together.
        return float(datetime.strptime(text, "%Y-%m-%d %H:%M:%S").strftime("%Y%m%d%H%M%S"))

    # Bounds are loop-invariant, so compute them once.
    tm_min = to_number(time_interal[0])
    tm_max = to_number(time_interal[1])
    return [name for name in os.listdir(csv_save_dir)
            if tm_min <= float(name.split("_")[-1].split(".")[0]) <= tm_max]

# ---------------------------------------------------------------------------
# Script: convert the raw TXT files to CSV, load the CSV files whose file-name
# timestamp falls inside `time_interal` into PostgreSQL, then add and fill a
# PostGIS point geometry column from lon/lat.
# ---------------------------------------------------------------------------
csv_save_dir = r"D:/DataWorkspace/data/20121024_csv"
txt_dir = r"D:/DataWorkspace/data/20121024"
# TXT --> CSV
txt2csv(txt_dir, csv_save_dir)

# Connect to the database.
# NOTE(review): credentials are hard-coded here.
conn = psycopg2.connect(database="beijing", user="jiangshan", password="jiangshan", host="localhost", port="5432")
cur = conn.cursor()

# Target table and the SRID used for the geometry column.
table_name = "taxi2012_bj"
the_geom_SRID = "4326"

# Spatio-temporal query point (lon, lat) and radius.
# NOTE(review): neither is used below; the radius unit is not stated.
point = (116.306251, 39.98070)
point_r = 1000.5

# Time-slice interval: only files whose timestamp lies inside it are loaded.
time_interal = ["2012-10-24 10:00:00", "2012-10-24 13:00:00"]

# Column order written by txt2csv; also used as the COPY column list.
header_name_list = ["serial_number", "code_company", "unit_id", "ts_m", "ts_s",
                    "lon", "lat", "speed", "direction", "state", "event"]
# Read every column as text except speed/direction, which are parsed as floats.
dtype_dic = {"serial_number": object, "code_company": object, "unit_id": object,
             "ts_m": object, "ts_s": object, "lon": object, "lat": object,
             "speed": float, "direction": float, "state": object, "event": object}

try:
    # Create the table unless table_exist confirmed it already exists.
    table_flg = table_exist(table_name, conn, cur)
    if not table_flg:
        sql = "CREATE TABLE IF NOT EXISTS {0} (serial_number BIGINT, code_company TEXT, unit_id BIGINT, ts_m TIMESTAMP, ts_s TIMESTAMP, lon DOUBLE PRECISION, lat DOUBLE PRECISION, speed FLOAT, direction FLOAT, state INT, event INT)".format(table_name)
        cur.execute(sql)
        conn.commit()

    files_list = filterByinteral(time_interal, csv_save_dir)

    # Bulk-insert the data. copy_from cannot load GEOMETRY values, so the
    # geometry column is filled afterwards with an UPDATE.
    print("IMPORT FILES......")
    # Loop variable is `csv_name`, not `csv`, so the csv module is not rebound.
    for csv_name in files_list:
        print(csv_name)
        csv_path = os.path.join(csv_save_dir, csv_name)
        data = pd.read_csv(csv_path, header=None, names=header_name_list, dtype=dtype_dic)

        # Serialise the DataFrame to tab-separated text in memory for COPY.
        buffer = StringIO()
        data.to_csv(buffer, sep="\t", index=False, header=False)
        buffer.seek(0)
        # Explicit column list: after a previous run the table also has
        # the_geom, which the CSV data does not contain.
        cur.copy_from(buffer, table_name, columns=header_name_list)
        conn.commit()
    print("IMPORT FILES OK!!")

    print("ADD A GEOMETRY COLUMN......")
    # IF NOT EXISTS lets the script be re-run against an existing table.
    cur.execute("alter table " + table_name + " add column if not exists the_geom GEOMETRY")
    conn.commit()

    print("UPDATE THE GEOMETRY.....")
    # Build each point directly from its lon/lat columns; rows whose geometry
    # was filled by an earlier run are left as they are.
    sql = ("UPDATE public.{0} SET the_geom = ST_SetSRID(ST_MakePoint(lon, lat), {1}) "
           "WHERE the_geom IS NULL").format(table_name, the_geom_SRID)
    cur.execute(sql)
    conn.commit()
finally:
    cur.close()
    conn.close()
print("done")

 

beijing_taxi_2012 to pgdb

原文:https://www.cnblogs.com/jeshy/p/14797234.html

(0)
(0)
   
举报
评论 一句话评论(0)
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!