from __future__ import annotations import os import sys from typing import Any, Dict, Iterable, List, Optional, Tuple import psycopg2 import psycopg2.extras # --------------------------------------------------------------------------- # # Type aliases # --------------------------------------------------------------------------- # Row = Dict[str, Any] ColumnDef = Tuple[str, str, str] # (name, type, constraints) # --------------------------------------------------------------------------- # # Helper class # --------------------------------------------------------------------------- # class PSQLDB: """ A thin wrapper around psycopg2 that provides generic CRUD helpers. """ def __init__( self, *, host: str = "localhost", port: int = 5432, database: str, user: str, password: str, sslmode: str = "prefer", autocommit: bool = True, ) -> None: """ Open a connection to the PostgreSQL database. Parameters ---------- host : str Database host. port : int Database port. database : str Database name. user : str Username. password : str Password. sslmode : str SSL mode (default: "prefer"). autocommit : bool If True, each statement is committed automatically. """ self.conn = psycopg2.connect( host=host, port=port, dbname=database, user=user, password=password, sslmode=sslmode, ) self.conn.autocommit = autocommit self.cursor_factory = psycopg2.extras.RealDictCursor self.cursor = self.conn.cursor(cursor_factory=self.cursor_factory) # --------------------------------------------------------------------- # # Context manager support # --------------------------------------------------------------------- # def __enter__(self) -> "PSQLDB": return self def __exit__(self, exc_type, exc, tb) -> None: if exc_type: self.conn.rollback() else: self.conn.commit() self.cursor.close() self.conn.close() def close(self) -> None: """Explicitly close the underlying connection.""" self.cursor.close() self.conn.close() # --------------------------------------------------------------------- # # Internal helpers # --------------------------------------------------------------------- # def _build_create_sql(self, table: str, columns: Iterable[ColumnDef], primary_key: Optional[str]) -> str: parts = [] for name, col_type, constraints in columns: col_part = f"{name} {col_type}" if constraints: col_part += f" {constraints}" parts.append(col_part) pk_clause = f", PRIMARY KEY ({primary_key})" if primary_key else "" columns_sql = ", ".join(parts) return f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql}{pk_clause});" def _build_set_clause(self, data: Dict[str, Any]) -> Tuple[str, List[Any]]: """ Helper for UPDATE and INSERT. Returns a ``SET``‑style clause and the corresponding values list. Example: data={'name':'bob', 'age':30} → ("name = %s, age = %s", ['bob', 30]) """ if not data: raise ValueError("data dictionary must contain at least one key") clause = ", ".join(f"{k} = %s" for k in data) values = list(data.values()) return clause, values # --------------------------------------------------------------------- # # Schema helpers # --------------------------------------------------------------------- # def create_table( self, table: str, columns: Iterable[ColumnDef], primary_key: Optional[str] = None, *, if_not_exists: bool = True, # You can add more constraints here – e.g. UNIQUE, CHECK, etc. ) -> None: """ Create a new table with the supplied column set. Parameters ---------- table : str Table name. columns : iterable of (name, type, constraints) Constraints are appended verbatim – e.g. "NOT NULL", "UNIQUE". primary_key : str | None Column name that should become the primary key. If None, no PK is added. if_not_exists : bool If True (default) the statement is `CREATE TABLE IF NOT EXISTS`. """ sql = self._build_create_sql(table, columns, primary_key) if not if_not_exists: sql = sql.replace("IF NOT EXISTS", "") self.cursor.execute(sql) # --------------------------------------------------------------------- # # CRUD helpers # --------------------------------------------------------------------- # def insert(self, table: str, data: Dict[str, Any]) -> int: """ Insert a single row and return the *generated* primary key (if the table uses a `SERIAL`/`IDENTITY` PK). If you supply the PK yourself, the returned value will just be that value. Parameters ---------- table : str Table name. data : dict Column/value mapping. Returns ------- int The value of the last inserted primary key. """ cols = ", ".join(data.keys()) placeholders = ", ".join(f"%s" for _ in data) values = list(data.values()) sql = f"INSERT INTO {table} ({cols}) VALUES ({placeholders}) RETURNING *;" self.cursor.execute(sql, values) row = self.cursor.fetchone() # If you want *only* the PK you can change this to RETURNING id, etc. return row[next(iter(row))] # first column value – usually the PK def insert_many(self, table: str, rows: Iterable[Dict[str, Any]], *, fetch_first: bool = False) -> List[int]: """ Bulk insert – very fast thanks to execute_values. Parameters ---------- table : str Table name. rows : iterable of dict Each dict represents one row. fetch_first : bool If True, the first row’s PK will be returned as a list of one element. Useful for a quick “INSERT … RETURNING …” when you only care about the first id. Returns ------- list[int] List of primary‑key values of the inserted rows. """ if not rows: return [] keys = rows[0].keys() columns = ", ".join(keys) values = [tuple(row[k] for k in keys) for row in rows] placeholder = f"({', '.join(['%s'] * len(keys))})" sql = f"INSERT INTO {table} ({columns}) VALUES %s RETURNING *;" psycopg2.extras.execute_values( self.cursor, sql, values, template=placeholder, fetch=True, ) # Return the PK column of every inserted row (first column by convention) return [r[next(iter(r))] for r in self.cursor.fetchall()] def get(self, table: str, pk_column: str, pk_value: Any, *, columns: Optional[List[str]] = None) -> Optional[Row]: """ Retrieve a single row by its primary key. Parameters ---------- table : str Table name. pk_column : str Primary‑key column name. pk_value : Any Value of the primary key. columns : list[str] | None If supplied, only those columns will be selected. Returns ------- dict | None The row as a dictionary, or None if not found. """ cols = ", ".join(columns) if columns else "*" sql = f"SELECT {cols} FROM {table} WHERE {pk_column} = %s LIMIT 1;" self.cursor.execute(sql, (pk_value,)) return self.cursor.fetchone() def get_all( self, table: str, *, columns: Optional[List[str]] = None, where: Optional[str] = None, where_args: Optional[Iterable[Any]] = None, order_by: Optional[str] = None, limit: Optional[int] = None, offset: Optional[int] = None, ) -> List[Row]: """ Retrieve many rows with optional filtering, sorting and pagination. Parameters ---------- table : str Table name. columns : list[str] | None List of columns to fetch. If None → SELECT *. where : str | None Raw `WHERE` clause *without* the leading `WHERE`. Use `%s` for placeholders. where_args : iterable | None Values that match the placeholders in `where`. order_by : str | None Raw `ORDER BY` clause (no leading `ORDER BY`). limit : int | None Max rows to return. offset : int | None Skip the first N rows. Returns ------- list[dict] List of rows as dictionaries. """ cols = ", ".join(columns) if columns else "*" sql = f"SELECT {cols} FROM {table}" params: List[Any] = [] if where: sql += f" WHERE {where}" if where_args: params.extend(where_args) if order_by: sql += f" ORDER BY {order_by}" if limit is not None: sql += " LIMIT %s" params.append(limit) if offset is not None: sql += " OFFSET %s" params.append(offset) self.cursor.execute(sql, params) return self.cursor.fetchall() def update(self, table: str, pk_column: str, pk_value: Any, data: Dict[str, Any]) -> bool: """ Update one row by its primary key. Parameters ---------- table : str Table name. pk_column : str Primary‑key column name. pk_value : Any Primary‑key value. data : dict Column/value pairs to update. Returns ------- bool True if a row was modified, False otherwise. """ if not data: return False set_clause = ", ".join(f"{col} = %s" for col in data) params = list(data.values()) + [pk_value] sql = f"UPDATE {table} SET {set_clause} WHERE {pk_column} = %s;" self.cursor.execute(sql, params) return self.cursor.rowcount > 0 def delete(self, table: str, pk_column: str, pk_value: Any) -> bool: """ Delete a row by its primary key. Returns ------- bool True if a row was deleted. """ sql = f"DELETE FROM {table} WHERE {pk_column} = %s;" self.cursor.execute(sql, (pk_value,)) return self.cursor.rowcount > 0 # # # ----------------------------------------------------------------------- # # # Demo usage # # ----------------------------------------------------------------------- # # def _demo(): # db = os.getenv("POSTGRES_DB", "testdb") # user = os.getenv("POSTGRES_USER", "postgres") # pwd = os.getenv("POSTGRES_PASSWORD", "") # host = os.getenv("POSTGRES_HOST", "localhost") # port = int(os.getenv("POSTGRES_PORT", "5432")) # # with PostgreSQLDB( # host=host, database=db, user=user, password=pwd, port=port # ) as db: # # # 1. Create table # db.create_table( # "users", # [ # ("id", "SERIAL", ""), # Identity / auto‑increment # ("username", "TEXT", "NOT NULL"), # ("email", "TEXT", "UNIQUE"), # ("age", "INTEGER", "CHECK (age >= 0)"), # ], # primary_key="id", # ) # # # 2. Insert # user_id = db.insert( # "users", # {"username": "alice", "email": "alice@example.com", "age": 25}, # ) # print(f"Inserted user id={user_id}") # # # 3. Bulk insert # bulk_ids = db.insert_many( # "users", # [ # {"username": "bob", "email": "bob@example.com", "age": 30}, # {"username": "carol", "email": "carol@example.com", "age": 22}, # ], # ) # print(f"Bulk inserted ids: {bulk_ids}") # # # 4. Get single # user = db.get("users", "id", user_id) # print("Fetched user:", user) # # # 5. Get many # all_users = db.get_all( # "users", # where="age >= %s", # where_args=[20], # order_by="age DESC", # limit=10, # ) # print("All users:", all_users) # # # 6. Update # updated = db.update("users", "id", user_id, {"age": 26, "email": "alice26@example.com"}) # print("Update status:", updated) # # # 7. Delete # deleted = db.delete("users", "id", user_id) # print("Deleted:", deleted) # # # ----------------------------------------------------------------------- # # # Entry point for `python -m` usage # # ----------------------------------------------------------------------- # # if __name__ == "__main__": # _demo()