|
|
from __future__ import annotations
|
|
|
|
|
|
import os
|
|
|
import sys
|
|
|
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
|
|
|
|
|
import psycopg2
|
|
|
import psycopg2.extras
|
|
|
|
|
|
|
|
|
# --------------------------------------------------------------------------- #
|
|
|
# Type aliases
|
|
|
# --------------------------------------------------------------------------- #
|
|
|
# Column definition triple consumed by the table-creation helpers:
# (column name, SQL type, extra constraint text appended verbatim).
ColumnDef = Tuple[str, str, str]

# One result row, mapping each column name to its value.
Row = Dict[str, Any]
|
|
|
|
|
|
|
|
|
# --------------------------------------------------------------------------- #
|
|
|
# Helper class
|
|
|
# --------------------------------------------------------------------------- #
|
|
|
class PSQLDB:
    """
    A thin wrapper around psycopg2 that provides generic CRUD helpers.

    All statements run on a single ``RealDictCursor``, so fetched rows are
    returned as dictionaries keyed by column name.

    Security note: table names, column names, column types/constraints and
    the raw ``where`` / ``order_by`` fragments are interpolated into the SQL
    text verbatim; only *values* are sent as bound parameters. Never pass
    untrusted input through those identifier/fragment arguments.
    """

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: int = 5432,
        database: str,
        user: str,
        password: str,
        sslmode: str = "prefer",
        autocommit: bool = True,
    ) -> None:
        """
        Open a connection to the PostgreSQL database.

        Parameters
        ----------
        host : str
            Database host.
        port : int
            Database port.
        database : str
            Database name.
        user : str
            Username.
        password : str
            Password.
        sslmode : str
            SSL mode (default: "prefer").
        autocommit : bool
            If True, each statement is committed automatically.
        """
        self.conn = psycopg2.connect(
            host=host,
            port=port,
            dbname=database,
            user=user,
            password=password,
            sslmode=sslmode,
        )
        try:
            self.conn.autocommit = autocommit
            self.cursor_factory = psycopg2.extras.RealDictCursor
            self.cursor = self.conn.cursor(cursor_factory=self.cursor_factory)
        except BaseException:
            # Don't leak the freshly opened connection if the remaining
            # setup fails; the caller never receives an object to close.
            self.conn.close()
            raise

    # --------------------------------------------------------------------- #
    # Context manager support
    # --------------------------------------------------------------------- #
    def __enter__(self) -> "PSQLDB":
        """Return ``self`` so the instance can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """
        Commit on success or roll back on error, then release resources.

        The cursor and connection are closed even if the commit/rollback
        itself raises. Exceptions from the ``with`` body are not suppressed.
        """
        try:
            if exc_type is not None:
                # Rolling back an already-closed connection would raise and
                # mask the exception that is currently propagating.
                if not self.conn.closed:
                    self.conn.rollback()
            else:
                self.conn.commit()
        finally:
            self.close()

    def close(self) -> None:
        """Explicitly close the cursor and the underlying connection."""
        try:
            if not self.cursor.closed:
                self.cursor.close()
        finally:
            # Always attempt to release the connection, even if closing the
            # cursor failed.
            self.conn.close()

    # --------------------------------------------------------------------- #
    # Internal helpers
    # --------------------------------------------------------------------- #
    def _build_create_sql(
        self,
        table: str,
        columns: Iterable[ColumnDef],
        primary_key: Optional[str],
        *,
        if_not_exists: bool = True,
    ) -> str:
        """
        Build a ``CREATE TABLE`` statement.

        Each ``(name, type, constraints)`` triple becomes ``"name type"``,
        followed by ``" constraints"`` when the constraint text is non-empty.
        A ``PRIMARY KEY (...)`` table constraint is appended when
        ``primary_key`` is truthy.

        Example: ("t", [("id", "SERIAL", "")], "id")
            -> "CREATE TABLE IF NOT EXISTS t (id SERIAL, PRIMARY KEY (id));"
        """
        parts = [
            f"{name} {col_type} {constraints}" if constraints else f"{name} {col_type}"
            for name, col_type, constraints in columns
        ]
        if primary_key:
            parts.append(f"PRIMARY KEY ({primary_key})")

        guard = "IF NOT EXISTS " if if_not_exists else ""
        columns_sql = ", ".join(parts)
        return f"CREATE TABLE {guard}{table} ({columns_sql});"

    def _build_set_clause(self, data: Dict[str, Any]) -> Tuple[str, List[Any]]:
        """
        Helper for UPDATE. Returns a ``SET``-style clause and the
        corresponding values list, in the dict's iteration order.

        Example: data={'name':'bob', 'age':30}
            -> ("name = %s, age = %s", ['bob', 30])

        Raises
        ------
        ValueError
            If ``data`` is empty.
        """
        if not data:
            raise ValueError("data dictionary must contain at least one key")

        clause = ", ".join(f"{k} = %s" for k in data)
        values = list(data.values())
        return clause, values

    # --------------------------------------------------------------------- #
    # Schema helpers
    # --------------------------------------------------------------------- #
    def create_table(
        self,
        table: str,
        columns: Iterable[ColumnDef],
        primary_key: Optional[str] = None,
        *,
        if_not_exists: bool = True,
        # You can add more constraints here – e.g. UNIQUE, CHECK, etc.
    ) -> None:
        """
        Create a new table with the supplied column set.

        Parameters
        ----------
        table : str
            Table name.
        columns : iterable of (name, type, constraints)
            Constraints are appended verbatim – e.g. "NOT NULL", "UNIQUE".
        primary_key : str | None
            Column name that should become the primary key. If None, no PK is added.
        if_not_exists : bool
            If True (default) the statement is `CREATE TABLE IF NOT EXISTS`.
        """
        sql = self._build_create_sql(
            table, columns, primary_key, if_not_exists=if_not_exists
        )
        self.cursor.execute(sql)

    # --------------------------------------------------------------------- #
    # CRUD helpers
    # --------------------------------------------------------------------- #
    def insert(self, table: str, data: Dict[str, Any]) -> int:
        """
        Insert a single row and return the value of its first column.

        The statement uses ``RETURNING *``, and the first returned column is
        assumed to be the primary key (e.g. a `SERIAL`/`IDENTITY` PK declared
        first). If you supply the PK yourself, that value is returned.

        An empty ``data`` dict inserts a row made entirely of column defaults
        (``DEFAULT VALUES``).

        Parameters
        ----------
        table : str
            Table name.
        data : dict
            Column/value mapping.

        Returns
        -------
        int
            The first column of the inserted row (usually the primary key;
            its actual type depends on the table definition).
        """
        if data:
            cols = ", ".join(data)
            placeholders = ", ".join(["%s"] * len(data))
            sql = f"INSERT INTO {table} ({cols}) VALUES ({placeholders}) RETURNING *;"
            self.cursor.execute(sql, list(data.values()))
        else:
            # "() VALUES ()" is not valid PostgreSQL; DEFAULT VALUES is.
            self.cursor.execute(f"INSERT INTO {table} DEFAULT VALUES RETURNING *;")

        row = self.cursor.fetchone()
        # If you want *only* the PK you can change this to RETURNING id, etc.
        return row[next(iter(row))]  # first column value – usually the PK

    def insert_many(
        self, table: str, rows: Iterable[Dict[str, Any]], *, fetch_first: bool = False
    ) -> List[int]:
        """
        Bulk insert using ``psycopg2.extras.execute_values``.

        The column list is taken from the keys of the first row; every other
        row must provide (at least) the same keys.

        Parameters
        ----------
        table : str
            Table name.
        rows : iterable of dict
            Each dict represents one row. Any iterable (including a
            generator) is accepted.
        fetch_first : bool
            If True, only the first row's PK is returned, as a one-element
            list. All rows are still inserted.

        Returns
        -------
        list[int]
            First-column (usually primary-key) values of the inserted rows.
        """
        # Materialise so generators work with both the emptiness check and
        # the rows[0] lookup below.
        rows = list(rows)
        if not rows:
            return []

        keys = list(rows[0])
        columns = ", ".join(keys)
        values = [tuple(row[k] for k in keys) for row in rows]
        template = "(" + ", ".join(["%s"] * len(keys)) + ")"

        sql = f"INSERT INTO {table} ({columns}) VALUES %s RETURNING *;"
        # With fetch=True, execute_values gathers the RETURNING rows of every
        # page into its return value, leaving the cursor drained. A later
        # cursor.fetchall() would therefore see nothing.
        inserted = psycopg2.extras.execute_values(
            self.cursor,
            sql,
            values,
            template=template,
            fetch=True,
        )
        # Return the PK column of every inserted row (first column by convention)
        ids = [r[next(iter(r))] for r in inserted]
        return ids[:1] if fetch_first else ids

    def get(
        self, table: str, pk_column: str, pk_value: Any, *, columns: Optional[List[str]] = None
    ) -> Optional[Row]:
        """
        Retrieve a single row by its primary key.

        Parameters
        ----------
        table : str
            Table name.
        pk_column : str
            Primary-key column name.
        pk_value : Any
            Value of the primary key.
        columns : list[str] | None
            If supplied, only those columns will be selected.

        Returns
        -------
        dict | None
            The row as a dictionary, or None if not found.
        """
        cols = ", ".join(columns) if columns else "*"
        sql = f"SELECT {cols} FROM {table} WHERE {pk_column} = %s LIMIT 1;"
        self.cursor.execute(sql, (pk_value,))
        return self.cursor.fetchone()

    def get_all(
        self,
        table: str,
        *,
        columns: Optional[List[str]] = None,
        where: Optional[str] = None,
        where_args: Optional[Iterable[Any]] = None,
        order_by: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
    ) -> List[Row]:
        """
        Retrieve many rows with optional filtering, sorting and pagination.

        Parameters
        ----------
        table : str
            Table name.
        columns : list[str] | None
            List of columns to fetch. If None → SELECT *.
        where : str | None
            Raw `WHERE` clause *without* the leading `WHERE`. Use `%s` for placeholders.
        where_args : iterable | None
            Values that match the placeholders in `where`. Ignored when
            `where` is not given.
        order_by : str | None
            Raw `ORDER BY` clause (no leading `ORDER BY`).
        limit : int | None
            Max rows to return (bound as a parameter).
        offset : int | None
            Skip the first N rows (bound as a parameter).

        Returns
        -------
        list[dict]
            List of rows as dictionaries.
        """
        cols = ", ".join(columns) if columns else "*"
        sql = f"SELECT {cols} FROM {table}"
        params: List[Any] = []

        if where:
            sql += f" WHERE {where}"
            if where_args:
                params.extend(where_args)

        if order_by:
            sql += f" ORDER BY {order_by}"

        if limit is not None:
            sql += " LIMIT %s"
            params.append(limit)

        if offset is not None:
            sql += " OFFSET %s"
            params.append(offset)

        self.cursor.execute(sql, params)
        return self.cursor.fetchall()

    def update(self, table: str, pk_column: str, pk_value: Any, data: Dict[str, Any]) -> bool:
        """
        Update one row by its primary key.

        Parameters
        ----------
        table : str
            Table name.
        pk_column : str
            Primary-key column name.
        pk_value : Any
            Primary-key value.
        data : dict
            Column/value pairs to update. An empty dict is a no-op.

        Returns
        -------
        bool
            True if a row was modified, False otherwise (including when
            ``data`` is empty, in which case no query is issued).
        """
        if not data:
            return False

        set_clause, params = self._build_set_clause(data)
        params.append(pk_value)
        sql = f"UPDATE {table} SET {set_clause} WHERE {pk_column} = %s;"
        self.cursor.execute(sql, params)
        return self.cursor.rowcount > 0

    def delete(self, table: str, pk_column: str, pk_value: Any) -> bool:
        """
        Delete a row by its primary key.

        Parameters
        ----------
        table : str
            Table name.
        pk_column : str
            Primary-key column name.
        pk_value : Any
            Primary-key value.

        Returns
        -------
        bool
            True if a row was deleted.
        """
        sql = f"DELETE FROM {table} WHERE {pk_column} = %s;"
        self.cursor.execute(sql, (pk_value,))
        return self.cursor.rowcount > 0
|
|
|
|
|
|
#
|
|
|
# # ----------------------------------------------------------------------- #
|
|
|
# # Demo usage
|
|
|
# # ----------------------------------------------------------------------- #
|
|
|
# def _demo():
|
|
|
# db = os.getenv("POSTGRES_DB", "testdb")
|
|
|
# user = os.getenv("POSTGRES_USER", "postgres")
|
|
|
# pwd = os.getenv("POSTGRES_PASSWORD", "")
|
|
|
# host = os.getenv("POSTGRES_HOST", "localhost")
|
|
|
# port = int(os.getenv("POSTGRES_PORT", "5432"))
|
|
|
#
|
|
|
# with PSQLDB(
|
|
|
# host=host, database=db, user=user, password=pwd, port=port
|
|
|
# ) as db:
|
|
|
#
|
|
|
# # 1. Create table
|
|
|
# db.create_table(
|
|
|
# "users",
|
|
|
# [
|
|
|
# ("id", "SERIAL", ""), # Identity / auto‑increment
|
|
|
# ("username", "TEXT", "NOT NULL"),
|
|
|
# ("email", "TEXT", "UNIQUE"),
|
|
|
# ("age", "INTEGER", "CHECK (age >= 0)"),
|
|
|
# ],
|
|
|
# primary_key="id",
|
|
|
# )
|
|
|
#
|
|
|
# # 2. Insert
|
|
|
# user_id = db.insert(
|
|
|
# "users",
|
|
|
# {"username": "alice", "email": "alice@example.com", "age": 25},
|
|
|
# )
|
|
|
# print(f"Inserted user id={user_id}")
|
|
|
#
|
|
|
# # 3. Bulk insert
|
|
|
# bulk_ids = db.insert_many(
|
|
|
# "users",
|
|
|
# [
|
|
|
# {"username": "bob", "email": "bob@example.com", "age": 30},
|
|
|
# {"username": "carol", "email": "carol@example.com", "age": 22},
|
|
|
# ],
|
|
|
# )
|
|
|
# print(f"Bulk inserted ids: {bulk_ids}")
|
|
|
#
|
|
|
# # 4. Get single
|
|
|
# user = db.get("users", "id", user_id)
|
|
|
# print("Fetched user:", user)
|
|
|
#
|
|
|
# # 5. Get many
|
|
|
# all_users = db.get_all(
|
|
|
# "users",
|
|
|
# where="age >= %s",
|
|
|
# where_args=[20],
|
|
|
# order_by="age DESC",
|
|
|
# limit=10,
|
|
|
# )
|
|
|
# print("All users:", all_users)
|
|
|
#
|
|
|
# # 6. Update
|
|
|
# updated = db.update("users", "id", user_id, {"age": 26, "email": "alice26@example.com"})
|
|
|
# print("Update status:", updated)
|
|
|
#
|
|
|
# # 7. Delete
|
|
|
# deleted = db.delete("users", "id", user_id)
|
|
|
# print("Deleted:", deleted)
|
|
|
#
|
|
|
# # ----------------------------------------------------------------------- #
|
|
|
# # Entry point for `python -m` usage
|
|
|
# # ----------------------------------------------------------------------- #
|
|
|
# if __name__ == "__main__":
|
|
|
# _demo() |