PostgreSQL 数据库操作类优化

以下是对你的 GPDB 类的优化建议,包括性能改进、错误处理和代码结构优化:

import pandas as pd
import psycopg2
import psycopg2.extras
from io import StringIO
import contextlib
from typing import Optional, List, Dict, Any, Union

class GPDB:
    def __init__(self, dbname: str, user: str, password: str, host: str, port: str):
        """
        初始化数据库连接参数
        
        参数:
            dbname: 数据库名
            user: 用户名
            password: 密码
            host: 主机地址
            port: 端口号
        """
        self.dbname = dbname
        self.user = user
        self.password = password
        self.host = host
        self.port = port
        self._connection_pool = None  # 可以扩展为连接池

    @contextlib.contextmanager
    def _get_cursor(self, cursor_factory=None):
        """
        上下文管理器,自动处理连接和游标的创建与关闭
        
        参数:
            cursor_factory: 游标工厂,默认为DictCursor
        """
        conn = None
        cursor = None
        try:
            conn = self.gp_connect()
            cursor = conn.cursor(cursor_factory=cursor_factory or psycopg2.extras.DictCursor)
            yield cursor
            conn.commit()
        except Exception as e:
            if conn:
                conn.rollback()
            raise e
        finally:
            if cursor:
                cursor.close()
            if conn:
                conn.close()

    def gp_connect(self):
        """建立数据库连接"""
        try:
            return psycopg2.connect(
                dbname=self.dbname,
                user=self.user,
                password=self.password,
                host=self.host,
                port=self.port,
                connect_timeout=10  # 添加连接超时
            )
        except psycopg2.Error as e:
            raise ConnectionError(f"无法连接到Greenplum服务器: {e}")

    def select_data(self, sql: str, params: Optional[tuple] = None) -> List[Dict[str, Any]]:
        """
        执行查询并返回结果列表
        
        参数:
            sql: SQL查询语句
            params: SQL参数
            
        返回:
            包含查询结果的字典列表
        """
        with self._get_cursor() as cur:
            cur.execute(sql, params or ())
            return cur.fetchall()

    def execute_sql(self, sql: str, params: Optional[tuple] = None) -> int:
        """
        执行SQL语句(INSERT, UPDATE, DELETE等)
        
        参数:
            sql: SQL语句
            params: SQL参数
            
        返回:
            影响的行数
        """
        with self._get_cursor() as cur:
            cur.execute(sql, params or ())
            return cur.rowcount

    def truncate_table(self, table_name: str, cascade: bool = False) -> None:
        """
        清空表数据
        
        参数:
            table_name: 表名
            cascade: 是否级联清空相关表
        """
        sql = f"TRUNCATE TABLE {table_name}"
        if cascade:
            sql += " CASCADE"
        self.execute_sql(sql)

    def insert_df(self, table_name: str, df: pd.DataFrame, batch_size: int = 10000) -> int:
        """
        使用批量插入方式将DataFrame数据写入数据库
        
        参数:
            table_name: 目标表名
            df: 要插入的DataFrame
            batch_size: 每批插入的行数
            
        返回:
            插入的总行数
        """
        if df.empty:
            return 0

        columns = ', '.join(df.columns)
        placeholders = ', '.join(['%s'] * len(df.columns))
        sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"

        total_rows = 0
        with self._get_cursor() as cur:
            # 分批插入数据
            for i in range(0, len(df), batch_size):
                batch = df.iloc[i:i + batch_size]
                psycopg2.extras.execute_batch(cur, sql, batch.values.tolist())
                total_rows += len(batch)
        
        return total_rows

    def read_df(self, sql: str, params: Optional[tuple] = None) -> pd.DataFrame:
        """
        执行SQL查询并返回DataFrame
        
        参数:
            sql: SQL查询语句
            params: SQL参数
            
        返回:
            包含查询结果的DataFrame
        """
        with self._get_cursor() as cur:
            cur.execute(sql, params or ())
            columns = [desc[0] for desc in cur.description]
            data = cur.fetchall()
            return pd.DataFrame(data, columns=columns)

    def copy_from_df(self, table_name: str, df: pd.DataFrame, sep: str = '\t', null: str = '\\N') -> None:
        """
        使用COPY命令高效导入数据
        
        参数:
            table_name: 目标表名
            df: 要导入的DataFrame
            sep: 分隔符
            null: NULL值的表示方式
        """
        if df.empty:
            return

        with StringIO() as buffer:
            df.to_csv(buffer, sep=sep, index=False, header=False, na_rep=null)
            buffer.seek(0)
            
            with self._get_cursor() as cur:
                cur.copy_from(buffer, table_name, sep=sep, columns=df.columns.tolist(), null=null)

    def copy_from_file(self, table_name: str, file_path: str, sep: str = '\t', columns: Optional[List[str]] = None) -> None:
        """
        从文件导入数据到数据库表
        
        参数:
            table_name: 目标表名
            file_path: 文件路径
            sep: 分隔符
            columns: 要导入的列名列表
        """
        with open(file_path, 'r') as f:
            with self._get_cursor() as cur:
                cur.copy_from(f, table_name, sep=sep, columns=columns)

    def upsert_df(self, table_name: str, df: pd.DataFrame, conflict_columns: List[str], update_columns: List[str]) -> int:
        """
        执行UPSERT操作(存在则更新,不存在则插入)
        
        参数:
            table_name: 目标表名
            df: 要插入/更新的DataFrame
            conflict_columns: 冲突检测列
            update_columns: 需要更新的列
            
        返回:
            影响的总行数
        """
        if df.empty:
            return 0

        columns = ', '.join(df.columns)
        placeholders = ', '.join(['%s'] * len(df.columns))
        update_set = ', '.join([f"{col} = EXCLUDED.{col}" for col in update_columns])
        
        sql = f"""
        INSERT INTO {table_name} ({columns}) 
        VALUES ({placeholders})
        ON CONFLICT ({', '.join(conflict_columns)}) 
        DO UPDATE SET {update_set}
        """

        total_rows = 0
        with self._get_cursor() as cur:
            # 分批执行UPSERT
            for i in range(0, len(df), 10000):
                batch = df.iloc[i:i + 10000]
                psycopg2.extras.execute_batch(cur, sql, batch.values.tolist())
                total_rows += len(batch)
        
        return total_rows

优化说明

  1. 类型提示:添加了类型提示,提高代码可读性和IDE支持

  2. 上下文管理器:使用contextlib.contextmanager创建上下文管理器,自动处理连接和事务

  3. 批量操作

    • 添加了批量插入和批量更新功能
    • 默认分批处理大数据量,避免内存问题
  4. 错误处理

    • 更完善的错误处理和事务回滚
    • 连接超时设置
  5. 新增功能

    • 添加了upsert_df方法实现存在则更新,不存在则插入
    • 添加了通用execute_sql方法
  6. 性能优化

    • 使用execute_batch替代executemany提高批量插入性能
    • 改进了COPY命令的实现
  7. 代码结构

    • 更清晰的文档字符串
    • 更合理的参数命名
    • 分离不同功能的方法

使用示例

# 初始化
db = GPDB(dbname="mydb", user="user", password="pass", host="localhost", port="5432")

# 查询数据
results = db.select_data("SELECT * FROM users WHERE age > %s", (30,))

# 读取为DataFrame
df = db.read_df("SELECT * FROM products")

# 插入DataFrame
db.insert_df("products", df)

# 高效导入大数据
db.copy_from_df("large_table", large_df)

# UPSERT操作
db.upsert_df("users", user_df, conflict_columns=["id"], update_columns=["name", "email"])

这个优化版本提供了更好的性能、更强的健壮性和更清晰的接口设计。

Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐