You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
getDiscography/python/utils/psqldb.py

411 lines
13 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from __future__ import annotations
from typing import Any, Dict, Iterable, List, Optional, Tuple
import psycopg2
import psycopg2.extras
# --------------------------------------------------------------------------- #
# Type aliases
# --------------------------------------------------------------------------- #
# One result row as returned by the dict-style cursor: column name -> value.
Row = Dict[str, Any]
# One column definition consumed by the CREATE TABLE helpers; the third
# element may be an empty string when the column has no extra constraints.
ColumnDef = Tuple[str, str, str] # (name, type, constraints)
# --------------------------------------------------------------------------- #
# Helper class
# --------------------------------------------------------------------------- #
class PSQLDB:
    """
    A thin wrapper around psycopg2 that provides generic CRUD helpers.

    One connection and one ``RealDictCursor`` are opened in ``__init__`` and
    shared by every helper, so rows come back as dictionaries keyed by
    column name.

    Values are always sent to the driver as bound parameters (``%s``), but
    table names, column names and raw clauses (``where``, ``order_by``,
    column constraints) are interpolated into the SQL text verbatim.
    Pass only trusted identifiers; never build them from user input.
    """

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: int = 5432,
        database: str,
        user: str,
        password: str,
        sslmode: str = "prefer",
        autocommit: bool = True,
    ) -> None:
        """
        Open a connection to the PostgreSQL database.

        Parameters
        ----------
        host : str
            Database host.
        port : int
            Database port.
        database : str
            Database name.
        user : str
            Username.
        password : str
            Password.
        sslmode : str
            SSL mode (default: "prefer").
        autocommit : bool
            If True, each statement is committed automatically.
        """
        self.conn = psycopg2.connect(
            host=host,
            port=port,
            dbname=database,
            user=user,
            password=password,
            sslmode=sslmode,
        )
        self.conn.autocommit = autocommit
        # Every helper shares this one dict-style cursor.
        self.cursor_factory = psycopg2.extras.RealDictCursor
        self.cursor = self.conn.cursor(cursor_factory=self.cursor_factory)

    # --------------------------------------------------------------------- #
    # Context manager support
    # --------------------------------------------------------------------- #
    def __enter__(self) -> "PSQLDB":
        """Return the wrapper itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """
        Roll back if the ``with`` body raised, otherwise commit, then close.

        The cursor and connection are closed even when the commit or
        rollback itself fails. Exceptions raised inside the ``with`` body
        are not suppressed.
        """
        try:
            if exc_type:
                self.conn.rollback()
            else:
                self.conn.commit()
        finally:
            self.close()

    def close(self) -> None:
        """Explicitly close the cursor and the underlying connection."""
        try:
            self.cursor.close()
        finally:
            # Release the connection even if closing the cursor failed.
            self.conn.close()

    # --------------------------------------------------------------------- #
    # Internal helpers
    # --------------------------------------------------------------------- #
    def _build_create_sql(
        self,
        table: str,
        columns: Iterable[ColumnDef],
        primary_key: Optional[str],
        *,
        if_not_exists: bool = True,
    ) -> str:
        """
        Build a ``CREATE TABLE`` statement.

        Parameters
        ----------
        table : str
            Table name (interpolated verbatim).
        columns : iterable of (name, type, constraints)
            Column definitions; empty constraints are omitted.
        primary_key : str | None
            Column to declare as ``PRIMARY KEY``; no clause is added if falsy.
        if_not_exists : bool
            If True (default) emit ``CREATE TABLE IF NOT EXISTS``.

        Returns
        -------
        str
            The complete statement, terminated by ``;``.
        """
        parts = [
            f"{name} {col_type} {constraints}" if constraints else f"{name} {col_type}"
            for name, col_type, constraints in columns
        ]
        pk_clause = f", PRIMARY KEY ({primary_key})" if primary_key else ""
        guard = "IF NOT EXISTS " if if_not_exists else ""
        return f"CREATE TABLE {guard}{table} ({', '.join(parts)}{pk_clause});"

    def _build_set_clause(self, data: Dict[str, Any]) -> Tuple[str, List[Any]]:
        """
        Helper for UPDATE. Returns a ``SET``-style clause and the
        corresponding values list.

        Example: data={'name':'bob', 'age':30}
        → ("name = %s, age = %s", ['bob', 30])

        Raises
        ------
        ValueError
            If ``data`` is empty.
        """
        if not data:
            raise ValueError("data dictionary must contain at least one key")
        clause = ", ".join(f"{k} = %s" for k in data)
        values = list(data.values())
        return clause, values

    # --------------------------------------------------------------------- #
    # Schema helpers
    # --------------------------------------------------------------------- #
    def create_table(
        self,
        table: str,
        columns: Iterable[ColumnDef],
        primary_key: Optional[str] = None,
        *,
        if_not_exists: bool = True,
        # You can add more constraints here e.g. UNIQUE, CHECK, etc.
    ) -> None:
        """
        Create a new table with the supplied column set.

        Parameters
        ----------
        table : str
            Table name.
        columns : iterable of (name, type, constraints)
            Constraints are appended verbatim e.g. "NOT NULL", "UNIQUE".
        primary_key : str | None
            Column name that should become the primary key. If None, no PK is added.
        if_not_exists : bool
            If True (default) the statement is `CREATE TABLE IF NOT EXISTS`.
        """
        sql = self._build_create_sql(
            table, columns, primary_key, if_not_exists=if_not_exists
        )
        self.cursor.execute(sql)

    # --------------------------------------------------------------------- #
    # CRUD helpers
    # --------------------------------------------------------------------- #
    def insert(self, table: str, data: Dict[str, Any]) -> int:
        """
        Insert a single row and return the value of its first column.

        The statement uses ``RETURNING *`` and the first column of the
        returned row is handed back; that is the generated primary key only
        when the PK is the table's first column.

        Parameters
        ----------
        table : str
            Table name.
        data : dict
            Column/value mapping. An empty mapping inserts a row made of
            column defaults (``INSERT ... DEFAULT VALUES``).

        Returns
        -------
        int
            The first column's value of the inserted row.
        """
        values = list(data.values())
        if data:
            cols = ", ".join(data)
            placeholders = ", ".join(["%s"] * len(values))
            sql = f"INSERT INTO {table} ({cols}) VALUES ({placeholders}) RETURNING *;"
        else:
            # "INSERT INTO t () VALUES ()" is not valid PostgreSQL syntax.
            sql = f"INSERT INTO {table} DEFAULT VALUES RETURNING *;"
        self.cursor.execute(sql, values)
        row = self.cursor.fetchone()
        # If you want *only* the PK you can change this to RETURNING id, etc.
        return next(iter(row.values()))  # first column value, usually the PK

    def insert_many(self, table: str, rows: Iterable[Dict[str, Any]], *, fetch_first: bool = False) -> List[int]:
        """
        Bulk insert via ``psycopg2.extras.execute_values``.

        The column list is taken from the first row; every other row must
        contain at least those keys (a missing key raises ``KeyError``,
        extra keys are ignored).

        Parameters
        ----------
        table : str
            Table name.
        rows : iterable of dict
            Each dict represents one row. Any iterable is accepted,
            including generators.
        fetch_first : bool
            If True, only the first row's value is returned, as a list of
            one element.

        Returns
        -------
        list[int]
            First-column values (the primary key by convention) of the
            inserted rows; empty if ``rows`` is empty.
        """
        # Materialise once: generators are always truthy and cannot be indexed.
        rows = list(rows)
        if not rows:
            return []
        keys = list(rows[0])
        columns = ", ".join(keys)
        values = [tuple(row[k] for k in keys) for row in rows]
        placeholder = f"({', '.join(['%s'] * len(keys))})"
        sql = f"INSERT INTO {table} ({columns}) VALUES %s RETURNING *;"
        # With fetch=True the rows of every page are returned by
        # execute_values itself; reading the cursor afterwards would yield
        # nothing because the results have already been consumed.
        inserted = psycopg2.extras.execute_values(
            self.cursor,
            sql,
            values,
            template=placeholder,
            fetch=True,
        )
        # First column of every inserted row (the PK by convention).
        ids = [next(iter(r.values())) for r in inserted]
        return ids[:1] if fetch_first else ids

    def get(self, table: str, pk_column: str, pk_value: Any, *, columns: Optional[List[str]] = None) -> Optional[Row]:
        """
        Retrieve a single row by its primary key.

        Parameters
        ----------
        table : str
            Table name.
        pk_column : str
            Primary-key column name.
        pk_value : Any
            Value of the primary key.
        columns : list[str] | None
            If supplied, only those columns will be selected.

        Returns
        -------
        dict | None
            The row as a dictionary, or None if not found.
        """
        cols = ", ".join(columns) if columns else "*"
        sql = f"SELECT {cols} FROM {table} WHERE {pk_column} = %s LIMIT 1;"
        self.cursor.execute(sql, (pk_value,))
        return self.cursor.fetchone()

    def get_all(
        self,
        table: str,
        *,
        columns: Optional[List[str]] = None,
        where: Optional[str] = None,
        where_args: Optional[Iterable[Any]] = None,
        order_by: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
    ) -> List[Row]:
        """
        Retrieve many rows with optional filtering, sorting and pagination.

        Parameters
        ----------
        table : str
            Table name.
        columns : list[str] | None
            List of columns to fetch. If None → SELECT *.
        where : str | None
            Raw `WHERE` clause *without* the leading `WHERE`. Use `%s` for placeholders.
        where_args : iterable | None
            Values that match the placeholders in `where`. Ignored when
            `where` is not given.
        order_by : str | None
            Raw `ORDER BY` clause (no leading `ORDER BY`).
        limit : int | None
            Max rows to return.
        offset : int | None
            Skip the first N rows.

        Returns
        -------
        list[dict]
            List of rows as dictionaries.
        """
        cols = ", ".join(columns) if columns else "*"
        sql = f"SELECT {cols} FROM {table}"
        params: List[Any] = []
        if where:
            sql += f" WHERE {where}"
            if where_args:
                params.extend(where_args)
        if order_by:
            sql += f" ORDER BY {order_by}"
        # "is not None" so that LIMIT 0 / OFFSET 0 are still honoured.
        if limit is not None:
            sql += " LIMIT %s"
            params.append(limit)
        if offset is not None:
            sql += " OFFSET %s"
            params.append(offset)
        self.cursor.execute(sql, params)
        return self.cursor.fetchall()

    def update(self, table: str, pk_column: str, pk_value: Any, data: Dict[str, Any]) -> bool:
        """
        Update one row by its primary key.

        Parameters
        ----------
        table : str
            Table name.
        pk_column : str
            Primary-key column name.
        pk_value : Any
            Primary-key value.
        data : dict
            Column/value pairs to update.

        Returns
        -------
        bool
            True if a row was modified, False otherwise (including when
            ``data`` is empty, in which case no statement is executed).
        """
        if not data:
            return False
        set_clause, params = self._build_set_clause(data)
        # The key value fills the final placeholder, in the WHERE clause.
        params.append(pk_value)
        sql = f"UPDATE {table} SET {set_clause} WHERE {pk_column} = %s;"
        self.cursor.execute(sql, params)
        return self.cursor.rowcount > 0

    def delete(self, table: str, pk_column: str, pk_value: Any) -> bool:
        """
        Delete a row by its primary key.

        Parameters
        ----------
        table : str
            Table name.
        pk_column : str
            Primary-key column name.
        pk_value : Any
            Primary-key value.

        Returns
        -------
        bool
            True if a row was deleted.
        """
        sql = f"DELETE FROM {table} WHERE {pk_column} = %s;"
        self.cursor.execute(sql, (pk_value,))
        return self.cursor.rowcount > 0
#
# # ----------------------------------------------------------------------- #
# # Demo usage
# # ----------------------------------------------------------------------- #
# def _demo():
# db = os.getenv("POSTGRES_DB", "testdb")
# user = os.getenv("POSTGRES_USER", "postgres")
# pwd = os.getenv("POSTGRES_PASSWORD", "")
# host = os.getenv("POSTGRES_HOST", "localhost")
# port = int(os.getenv("POSTGRES_PORT", "5432"))
#
# with PSQLDB(
# host=host, database=db, user=user, password=pwd, port=port
# ) as db:
#
# # 1. Create table
# db.create_table(
# "users",
# [
# ("id", "SERIAL", ""), # Identity / autoincrement
# ("username", "TEXT", "NOT NULL"),
# ("email", "TEXT", "UNIQUE"),
# ("age", "INTEGER", "CHECK (age >= 0)"),
# ],
# primary_key="id",
# )
#
# # 2. Insert
# user_id = db.insert(
# "users",
# {"username": "alice", "email": "alice@example.com", "age": 25},
# )
# print(f"Inserted user id={user_id}")
#
# # 3. Bulk insert
# bulk_ids = db.insert_many(
# "users",
# [
# {"username": "bob", "email": "bob@example.com", "age": 30},
# {"username": "carol", "email": "carol@example.com", "age": 22},
# ],
# )
# print(f"Bulk inserted ids: {bulk_ids}")
#
# # 4. Get single
# user = db.get("users", "id", user_id)
# print("Fetched user:", user)
#
# # 5. Get many
# all_users = db.get_all(
# "users",
# where="age >= %s",
# where_args=[20],
# order_by="age DESC",
# limit=10,
# )
# print("All users:", all_users)
#
# # 6. Update
# updated = db.update("users", "id", user_id, {"age": 26, "email": "alice26@example.com"})
# print("Update status:", updated)
#
# # 7. Delete
# deleted = db.delete("users", "id", user_id)
# print("Deleted:", deleted)
#
# # ----------------------------------------------------------------------- #
# # Entry point for `python -m` usage
# # ----------------------------------------------------------------------- #
# if __name__ == "__main__":
# _demo()