from __future__ import annotations import sqlite3 from contextlib import contextmanager from typing import Any, Dict, Iterable, List, Optional, Tuple, Union # --------------------------------------------------------------------------- # # Variables # --------------------------------------------------------------------------- # DB_PATH = {'memory': ":memory:", "file": ""} # --------------------------------------------------------------------------- # # Helper type aliases # --------------------------------------------------------------------------- # Row = Dict[str, Any] # A column description is: (name, type, optional constraints) ColumnDef = Tuple[str, str, str] # name, type, constraints (empty if none) class SqlDb: """ Thin wrapper around sqlite3 that exposes CRUD helpers for arbitrary tables. """ def __init__(self, db_path: str = ":memory:") -> None: """ Open a connection to the SQLite database. Parameters ---------- db_path : str Path to the database file. Use ':memory:' for an in‑memory DB. """ self.db_path = db_path self.conn = sqlite3.connect(self.db_path) # Use sqlite.Row so that we can return rows as dictionaries self.conn.row_factory = sqlite3.Row self.cursor = self.conn.cursor() # ----------------------------------------------------------------------- # # Context‑manager support # ----------------------------------------------------------------------- # def __enter__(self) -> "SqlDb": return self def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.close() def close(self) -> None: """Close the database connection.""" self.cursor.close() self.conn.close() # ----------------------------------------------------------------------- # # Table creation # ----------------------------------------------------------------------- # def create_table( self, table: str, columns: Iterable[ColumnDef], *, primary_key: Optional[str] = None, ) -> None: """ Create a table with the given columns. Parameters ---------- table : str Table name. columns : iterable of ColumnDef Each item is a tuple: (column_name, sql_type, constraints) Example: ("name", "TEXT", "NOT NULL") primary_key : str | None Name of the column that should be the PRIMARY KEY. If omitted, *no* primary key is created automatically. """ col_defs: List[str] = [] for name, col_type, constraints in columns: if constraints: col_defs.append(f"{name} {col_type} {constraints}") else: col_defs.append(f"{name} {col_type}") if primary_key: # SQLite allows only one PRIMARY KEY per table, so we make sure # that the PK column is part of the definition. # We will just add "PRIMARY KEY" to the column definition. pk_index = next( (i for i, col in enumerate(col_defs) if col.startswith(primary_key)), None, ) if pk_index is not None: col_defs[pk_index] = f"{col_defs[pk_index]} PRIMARY KEY" else: # If the PK column wasn't part of the supplied list, # we add a separate primary key constraint. col_defs.append(f"PRIMARY KEY ({primary_key})") sql = f"CREATE TABLE IF NOT EXISTS {table} ({', '.join(col_defs)});" self.cursor.execute(sql) self.conn.commit() # ----------------------------------------------------------------------- # # Generic CRUD helpers # ----------------------------------------------------------------------- # def _dict_to_columns_and_placeholders( self, data: Dict[str, Any] ) -> Tuple[str, str, Tuple[Any, ...]]: """ Convert a dict into a comma‑separated list of columns, a list of '?' placeholders and a tuple of values. """ columns = ", ".join(data.keys()) placeholders = ", ".join("?" for _ in data) values = tuple(data.values()) return columns, placeholders, values def insert( self, table: str, data: Dict[str, Any] ) -> int: """ Insert a single row. Parameters ---------- table : str Table name. data : dict Column/value mapping. Returns ------- int The lastrowid of the inserted row. """ cols, phs, vals = self._dict_to_columns_and_placeholders(data) sql = f"INSERT INTO {table} ({cols}) VALUES ({phs})" self.cursor.execute(sql, vals) self.conn.commit() return self.cursor.lastrowid def insert_many( self, table: str, rows: Iterable[Dict[str, Any]] ) -> List[int]: """ Insert many rows at once. Parameters ---------- table : str Table name. rows : iterable of dict Each dict represents a row to insert. Returns ------- list of int List of lastrowid values for each inserted row. """ # All rows must share the same keys rows = list(rows) if not rows: return [] cols, phs, _ = self._dict_to_columns_and_placeholders(rows[0]) sql = f"INSERT INTO {table} ({cols}) VALUES ({phs})" # Build a list of tuples for the values values = [tuple(row[col] for col in cols.split(", ")) for row in rows] self.cursor.executemany(sql, values) self.conn.commit() # SQLite gives only the last row id; we approximate last_id = self.cursor.lastrowid # For deterministic behaviour we will simply return the last id for each row. return [last_id - len(rows) + i + 1 for i in range(len(rows))] def get( self, table: str, pk: str, pk_value: Any, *, columns: Optional[List[str]] = None, ) -> Optional[Row]: """ Retrieve a single row by primary key. Parameters ---------- table : str Table name. pk : str Name of the primary‑key column. pk_value : any Value of the primary key. columns : list | None Columns to retrieve. If None, all columns are returned. Returns ------- dict or None Row as a dict, or None if not found. """ cols = ", ".join(columns) if columns else "*" sql = f"SELECT {cols} FROM {table} WHERE {pk} = ?" self.cursor.execute(sql, (pk_value,)) row = self.cursor.fetchone() return dict(row) if row else None def get_all( self, table: str, *, columns: Optional[List[str]] = None, where: Optional[str] = None, where_args: Optional[Iterable[Any]] = None, order_by: Optional[str] = None, limit: Optional[int] = None, ) -> List[Row]: """ Retrieve all rows, optionally filtered/sorted/paginated. Parameters ---------- table : str Table name. columns : list | None Columns to fetch. If None, all columns are returned. where : str | None Optional WHERE clause (without the word WHERE). where_args : iterable | None Values for the WHERE clause placeholders. order_by : str | None ORDER BY clause (without the word ORDER BY). limit : int | None LIMIT clause. Returns ------- list of dict List of rows as dictionaries. """ cols = ", ".join(columns) if columns else "*" sql = f"SELECT {cols} FROM {table}" args: List[Any] = [] if where: sql += f" WHERE {where}" if where_args: args.extend(where_args) if order_by: sql += f" ORDER BY {order_by}" if limit is not None: sql += " LIMIT ?" args.append(limit) self.cursor.execute(sql, tuple(args)) rows = self.cursor.fetchall() return [dict(row) for row in rows] def update( self, table: str, pk: str, pk_value: Any, data: Dict[str, Any], ) -> bool: """ Update a row identified by its primary key. Parameters ---------- table : str Table name. pk : str Primary‑key column name. pk_value : any Value of the primary key. data : dict Column/value pairs to update. Returns ------- bool True if a row was updated, False otherwise. """ if not data: return False set_clause = ", ".join(f"{col} = ?" for col in data) vals = tuple(data.values()) + (pk_value,) sql = f"UPDATE {table} SET {set_clause} WHERE {pk} = ?" self.cursor.execute(sql, vals) self.conn.commit() return self.cursor.rowcount > 0 def delete(self, table: str, pk: str, pk_value: Any) -> bool: """ Delete a row by primary key. Parameters ---------- table : str Table name. pk : str Primary‑key column name. pk_value : any Value of the primary key. Returns ------- bool True if a row was deleted, False otherwise. """ sql = f"DELETE FROM {table} WHERE {pk} = ?" self.cursor.execute(sql, (pk_value,)) self.conn.commit() return self.cursor.rowcount > 0 # # # --------------------------------------------------------------------------- # # # Example usage # # --------------------------------------------------------------------------- # # if __name__ == "__main__": # # Using an in‑memory database for demonstration # with GenericSQLiteDB(":memory:") as db: # # Define two tables of different shapes # user_columns = [ # ("id", "INTEGER", ""), # will become PRIMARY KEY via parameter # ("name", "TEXT", "NOT NULL"), # ("email", "TEXT", "UNIQUE NOT NULL"), # ("created_at", "TEXT", ""), # ] # # product_columns = [ # ("product_id", "INTEGER", ""), # ("name", "TEXT", "NOT NULL"), # ("price", "REAL", "NOT NULL"), # ("stock", "INTEGER", ""), # ] # # # Create tables # db.create_table("users", user_columns, primary_key="id") # db.create_table("products", product_columns, primary_key="product_id") # # # Insert a user # user_id = db.insert( # "users", # {"id": 1, "name": "Alice", "email": "alice@example.com", "created_at": "2024‑01‑01"}, # ) # print(f"Inserted user id={user_id}") # # # Insert multiple products # prod_ids = db.insert_many( # "products", # [ # {"product_id": 101, "name": "Gizmo", "price": 19.99, "stock": 50}, # {"product_id": 102, "name": "Widget", "price": 9.49, "stock": 200}, # ], # ) # print(f"Inserted product ids={prod_ids}") # # # Retrieve a single user # user = db.get("users", "id", user_id) # print("Fetched user:", user) # # # Update the user's email # db.update("users", "id", user_id, {"email": "alice@newdomain.com"}) # print("Updated user:", db.get("users", "id", user_id)) # # # List all products that are in stock > 10 # print( # "Products in stock > 10:", # db.get_all("products", where="stock > ?", where_args=(10,)), # ) # # # Delete the user # db.delete("users", "id", user_id) # print("User after delete:", db.get("users", "id", user_id))