|
|
from __future__ import annotations
|
|
|
|
|
|
import sqlite3
|
|
|
from contextlib import contextmanager
|
|
|
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
|
|
|
|
|
# --------------------------------------------------------------------------- #
|
|
|
# Module-level constants
|
|
|
# --------------------------------------------------------------------------- #
|
|
|
# Named connection targets. Not referenced by SqlDb in this file; the
# "file" entry is an empty placeholder, presumably meant to be filled with a
# real database path before use (TODO confirm with callers).
DB_PATH = dict(memory=":memory:", file="")


# --------------------------------------------------------------------------- #
# Helper type aliases
# --------------------------------------------------------------------------- #

# A single result row, keyed by column name.
Row = Dict[str, Any]

# A column definition as consumed by SqlDb.create_table:
# (column_name, sql_type, constraints) -- constraints is "" when there are none.
ColumnDef = Tuple[str, str, str]
|
|
|
|
|
|
|
|
|
class SqlDb:
    """
    Thin wrapper around sqlite3 that exposes CRUD helpers for arbitrary tables.

    Only *values* are sent to SQLite as bound ``?`` parameters. Table names,
    column names, column types/constraints and the raw ``where`` /
    ``order_by`` fragments are interpolated into the SQL text verbatim, so
    they must come from trusted code, never from end-user input.

    Every write helper commits before returning. The object can be used as a
    context manager; the connection is closed when the ``with`` block exits.
    """

    def __init__(self, db_path: str = ":memory:") -> None:
        """
        Open a connection to the SQLite database.

        Parameters
        ----------
        db_path : str
            Path to the database file. Use ':memory:' for an in-memory DB.
        """
        self.db_path = db_path
        self.conn = sqlite3.connect(self.db_path)
        # sqlite3.Row supports mapping-style access, so every fetched row can
        # be converted to a plain dict with dict(row).
        self.conn.row_factory = sqlite3.Row
        # One shared cursor is reused by every helper below.
        self.cursor = self.conn.cursor()

    # ----------------------------------------------------------------------- #
    # Context-manager support
    # ----------------------------------------------------------------------- #

    def __enter__(self) -> "SqlDb":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Close unconditionally; returning None lets any exception propagate.
        self.close()

    def close(self) -> None:
        """Close the shared cursor, then the database connection."""
        self.cursor.close()
        self.conn.close()

    # ----------------------------------------------------------------------- #
    # Table creation
    # ----------------------------------------------------------------------- #

    def create_table(
        self,
        table: str,
        columns: Iterable[ColumnDef],
        *,
        primary_key: Optional[str] = None,
    ) -> None:
        """
        Create a table with the given columns (no-op if it already exists).

        Parameters
        ----------
        table : str
            Table name.
        columns : iterable of ColumnDef
            Each item is a tuple: (column_name, sql_type, constraints).
            Example: ("name", "TEXT", "NOT NULL"). Use "" for no constraints.
        primary_key : str | None
            Name of the column that should be the PRIMARY KEY.
            If omitted, *no* primary key is created automatically.

        Notes
        -----
        The primary-key column is matched by exact name, and ``PRIMARY KEY``
        is appended to that column's definition. If no supplied column has
        that name, a table-level ``PRIMARY KEY (<name>)`` constraint is
        appended after the column definitions instead.
        """
        col_defs: List[str] = []
        pk_attached = False

        for name, col_type, constraints in columns:
            definition = f"{name} {col_type}"
            if constraints:
                definition = f"{definition} {constraints}"
            # Exact comparison: a prefix test would wrongly attach the key to
            # e.g. "identifier" when the primary key is "id".
            if primary_key and not pk_attached and name == primary_key:
                definition = f"{definition} PRIMARY KEY"
                pk_attached = True
            col_defs.append(definition)

        if primary_key and not pk_attached:
            # The PK column was not among the supplied definitions, so fall
            # back to a separate table-level constraint.
            col_defs.append(f"PRIMARY KEY ({primary_key})")

        sql = f"CREATE TABLE IF NOT EXISTS {table} ({', '.join(col_defs)});"
        self.cursor.execute(sql)
        self.conn.commit()

    # ----------------------------------------------------------------------- #
    # Generic CRUD helpers
    # ----------------------------------------------------------------------- #

    def _dict_to_columns_and_placeholders(
        self, data: Dict[str, Any]
    ) -> Tuple[str, str, Tuple[Any, ...]]:
        """
        Convert a dict into a comma-separated list of columns,
        a list of '?' placeholders and a tuple of values (in key order).
        """
        columns = ", ".join(data.keys())
        placeholders = ", ".join("?" for _ in data)
        values = tuple(data.values())
        return columns, placeholders, values

    def insert(self, table: str, data: Dict[str, Any]) -> int:
        """
        Insert a single row and commit.

        Parameters
        ----------
        table : str
            Table name.
        data : dict
            Column/value mapping.

        Returns
        -------
        int
            The lastrowid of the inserted row.
        """
        cols, phs, vals = self._dict_to_columns_and_placeholders(data)
        sql = f"INSERT INTO {table} ({cols}) VALUES ({phs})"
        self.cursor.execute(sql, vals)
        self.conn.commit()
        return self.cursor.lastrowid

    def insert_many(
        self, table: str, rows: Iterable[Dict[str, Any]]
    ) -> List[int]:
        """
        Insert many rows in a single transaction.

        Parameters
        ----------
        table : str
            Table name.
        rows : iterable of dict
            Each dict represents a row to insert. The column set is taken
            from the first row; every row must use the same columns.

        Returns
        -------
        list of int
            The lastrowid of each inserted row, in input order.

        Raises
        ------
        KeyError
            If a row lacks a column present in the first row.
        ValueError
            If a row has a column not present in the first row.
        sqlite3.Error
            If any insert fails. The whole batch is rolled back in that case.
        """
        batch = list(rows)
        if not batch:
            return []

        column_names = list(batch[0])
        expected = set(column_names)
        cols, phs, _ = self._dict_to_columns_and_placeholders(batch[0])
        sql = f"INSERT INTO {table} ({cols}) VALUES ({phs})"

        # Validate and order every row's values before touching the database,
        # so a malformed row cannot leave a partial batch behind.
        params: List[Tuple[Any, ...]] = []
        for index, row in enumerate(batch):
            extra = set(row) - expected
            if extra:
                raise ValueError(
                    f"row {index} has columns not present in the first row: "
                    f"{sorted(extra)}"
                )
            params.append(tuple(row[col] for col in column_names))

        # cursor.lastrowid is not updated by executemany(), so execute each
        # row individually to record its real rowid. `with self.conn` commits
        # on success and rolls the whole batch back if any insert fails.
        inserted_ids: List[int] = []
        with self.conn:
            for values in params:
                self.cursor.execute(sql, values)
                inserted_ids.append(self.cursor.lastrowid)
        return inserted_ids

    def get(
        self,
        table: str,
        pk: str,
        pk_value: Any,
        *,
        columns: Optional[List[str]] = None,
    ) -> Optional[Row]:
        """
        Retrieve a single row by primary key.

        Parameters
        ----------
        table : str
            Table name.
        pk : str
            Name of the primary-key column.
        pk_value : any
            Value of the primary key.
        columns : list | None
            Columns to retrieve. If None (or empty), all columns are returned.

        Returns
        -------
        dict or None
            The first matching row as a dict, or None if not found.
        """
        cols = ", ".join(columns) if columns else "*"
        sql = f"SELECT {cols} FROM {table} WHERE {pk} = ?"
        self.cursor.execute(sql, (pk_value,))
        row = self.cursor.fetchone()
        return dict(row) if row is not None else None

    def get_all(
        self,
        table: str,
        *,
        columns: Optional[List[str]] = None,
        where: Optional[str] = None,
        where_args: Optional[Iterable[Any]] = None,
        order_by: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> List[Row]:
        """
        Retrieve all rows, optionally filtered/sorted/paginated.

        Parameters
        ----------
        table : str
            Table name.
        columns : list | None
            Columns to fetch. If None (or empty), all columns are returned.
        where : str | None
            Optional WHERE clause (without the word WHERE), inserted verbatim.
        where_args : iterable | None
            Values for the WHERE clause placeholders. Ignored when ``where``
            is not given.
        order_by : str | None
            ORDER BY clause (without the words ORDER BY), inserted verbatim.
        limit : int | None
            Maximum number of rows; bound as a parameter.

        Returns
        -------
        list of dict
            List of rows as dictionaries.
        """
        cols = ", ".join(columns) if columns else "*"
        sql = f"SELECT {cols} FROM {table}"
        args: List[Any] = []

        if where:
            sql += f" WHERE {where}"
            if where_args:
                args.extend(where_args)

        if order_by:
            sql += f" ORDER BY {order_by}"

        if limit is not None:
            sql += " LIMIT ?"
            args.append(limit)

        self.cursor.execute(sql, tuple(args))
        return [dict(row) for row in self.cursor.fetchall()]

    def update(
        self,
        table: str,
        pk: str,
        pk_value: Any,
        data: Dict[str, Any],
    ) -> bool:
        """
        Update a row identified by its primary key and commit.

        Parameters
        ----------
        table : str
            Table name.
        pk : str
            Primary-key column name.
        pk_value : any
            Value of the primary key.
        data : dict
            Column/value pairs to update.

        Returns
        -------
        bool
            True if at least one row was updated, False otherwise (including
            when ``data`` is empty, in which case no SQL is executed).
        """
        if not data:
            return False

        set_clause = ", ".join(f"{col} = ?" for col in data)
        vals = tuple(data.values()) + (pk_value,)
        sql = f"UPDATE {table} SET {set_clause} WHERE {pk} = ?"
        self.cursor.execute(sql, vals)
        self.conn.commit()
        return self.cursor.rowcount > 0

    def delete(self, table: str, pk: str, pk_value: Any) -> bool:
        """
        Delete a row by primary key and commit.

        Parameters
        ----------
        table : str
            Table name.
        pk : str
            Primary-key column name.
        pk_value : any
            Value of the primary key.

        Returns
        -------
        bool
            True if at least one row was deleted, False otherwise.
        """
        sql = f"DELETE FROM {table} WHERE {pk} = ?"
        self.cursor.execute(sql, (pk_value,))
        self.conn.commit()
        return self.cursor.rowcount > 0
|
|
|
#
|
|
|
# # --------------------------------------------------------------------------- #
|
|
|
# # Example usage
|
|
|
# # --------------------------------------------------------------------------- #
|
|
|
# if __name__ == "__main__":
|
|
|
# # Using an in-memory database for demonstration
|
|
|
# with SqlDb(":memory:") as db:
|
|
|
# # Define two tables of different shapes
|
|
|
# user_columns = [
|
|
|
# ("id", "INTEGER", ""), # will become PRIMARY KEY via parameter
|
|
|
# ("name", "TEXT", "NOT NULL"),
|
|
|
# ("email", "TEXT", "UNIQUE NOT NULL"),
|
|
|
# ("created_at", "TEXT", ""),
|
|
|
# ]
|
|
|
#
|
|
|
# product_columns = [
|
|
|
# ("product_id", "INTEGER", ""),
|
|
|
# ("name", "TEXT", "NOT NULL"),
|
|
|
# ("price", "REAL", "NOT NULL"),
|
|
|
# ("stock", "INTEGER", ""),
|
|
|
# ]
|
|
|
#
|
|
|
# # Create tables
|
|
|
# db.create_table("users", user_columns, primary_key="id")
|
|
|
# db.create_table("products", product_columns, primary_key="product_id")
|
|
|
#
|
|
|
# # Insert a user
|
|
|
# user_id = db.insert(
|
|
|
# "users",
|
|
|
# {"id": 1, "name": "Alice", "email": "alice@example.com", "created_at": "2024-01-01"},
|
|
|
# )
|
|
|
# print(f"Inserted user id={user_id}")
|
|
|
#
|
|
|
# # Insert multiple products
|
|
|
# prod_ids = db.insert_many(
|
|
|
# "products",
|
|
|
# [
|
|
|
# {"product_id": 101, "name": "Gizmo", "price": 19.99, "stock": 50},
|
|
|
# {"product_id": 102, "name": "Widget", "price": 9.49, "stock": 200},
|
|
|
# ],
|
|
|
# )
|
|
|
# print(f"Inserted product ids={prod_ids}")
|
|
|
#
|
|
|
# # Retrieve a single user
|
|
|
# user = db.get("users", "id", user_id)
|
|
|
# print("Fetched user:", user)
|
|
|
#
|
|
|
# # Update the user's email
|
|
|
# db.update("users", "id", user_id, {"email": "alice@newdomain.com"})
|
|
|
# print("Updated user:", db.get("users", "id", user_id))
|
|
|
#
|
|
|
# # List all products that are in stock > 10
|
|
|
# print(
|
|
|
# "Products in stock > 10:",
|
|
|
# db.get_all("products", where="stock > ?", where_args=(10,)),
|
|
|
# )
|
|
|
#
|
|
|
# # Delete the user
|
|
|
# db.delete("users", "id", user_id)
|
|
|
# print("User after delete:", db.get("users", "id", user_id))
|