From a155abb393145f526175d024fe80e7e6e0067483 Mon Sep 17 00:00:00 2001 From: David Date: Tue, 12 May 2026 10:27:57 -0500 Subject: [PATCH 1/2] Add ER diagram feature with table picker and relationship visualization Adds the ability to generate Unicode ER diagrams showing tables, columns (with PK/FK markers), and foreign key relationship lines. Accessible via S key on explorer tree nodes and space+d leader menu for a table picker modal. Includes FK introspection for PostgreSQL and SQLite providers. --- sqlit/core/keymap.py | 2 + .../connections/providers/adapters/base.py | 22 ++ sqlit/domains/connections/providers/model.py | 4 + .../connections/providers/postgresql/base.py | 37 ++ .../connections/providers/sqlite/adapter.py | 22 ++ sqlit/domains/diagram/__init__.py | 0 sqlit/domains/diagram/app/__init__.py | 0 sqlit/domains/diagram/app/diagram_service.py | 52 +++ sqlit/domains/diagram/app/layout.py | 342 ++++++++++++++++++ sqlit/domains/diagram/ui/__init__.py | 0 sqlit/domains/diagram/ui/diagram_screen.py | 104 ++++++ sqlit/domains/diagram/ui/table_picker.py | 176 +++++++++ sqlit/domains/explorer/app/schema_service.py | 13 +- .../explorer/state/tree_on_database.py | 10 +- .../domains/explorer/state/tree_on_folder.py | 12 + sqlit/domains/explorer/state/tree_on_table.py | 9 + sqlit/domains/explorer/ui/mixins/tree.py | 105 ++++++ sqlit/domains/shell/ui/mixins/ui_leader.py | 3 + tests/unit/test_diagram_layout.py | 110 ++++++ 19 files changed, 1021 insertions(+), 2 deletions(-) create mode 100644 sqlit/domains/diagram/__init__.py create mode 100644 sqlit/domains/diagram/app/__init__.py create mode 100644 sqlit/domains/diagram/app/diagram_service.py create mode 100644 sqlit/domains/diagram/app/layout.py create mode 100644 sqlit/domains/diagram/ui/__init__.py create mode 100644 sqlit/domains/diagram/ui/diagram_screen.py create mode 100644 sqlit/domains/diagram/ui/table_picker.py create mode 100644 tests/unit/test_diagram_layout.py diff --git a/sqlit/core/keymap.py b/sqlit/core/keymap.py index bcd4aed3..c7d7bcef 100644 --- a/sqlit/core/keymap.py +++ b/sqlit/core/keymap.py @@ -196,6 +196,7 @@ def _build_leader_commands(self) -> list[LeaderCommandDef]: # Actions LeaderCommandDef("z", "cancel_operation", "Cancel", "Actions", guard="query_executing"), LeaderCommandDef("t", "change_theme", "Change Theme", "Actions"), + LeaderCommandDef("d", "show_diagram_picker", "ER Diagram", "Actions", guard="has_connection"), LeaderCommandDef("h", "show_help", "Help", "Actions"), LeaderCommandDef("space", "telescope", "Telescope", "Actions"), LeaderCommandDef("slash", "telescope_filter", "Telescope Search", "Actions"), @@ -309,6 +310,7 @@ def _build_action_keys(self) -> list[ActionKeyDef]: ActionKeyDef("v", "exit_tree_visual_mode", "tree_visual", primary=False), ActionKeyDef("escape", "clear_connection_selection", "tree"), ActionKeyDef("s", "select_table", "tree"), + ActionKeyDef("S", "show_diagram", "tree"), ActionKeyDef("f", "refresh_tree", "tree"), ActionKeyDef("R", "refresh_tree", "tree", primary=False), ActionKeyDef("e", "edit_connection", "tree"), diff --git a/sqlit/domains/connections/providers/adapters/base.py b/sqlit/domains/connections/providers/adapters/base.py index 109ea1da..ff1842f5 100644 --- a/sqlit/domains/connections/providers/adapters/base.py +++ b/sqlit/domains/connections/providers/adapters/base.py @@ -69,6 +69,19 @@ class SequenceInfo: name: str +@dataclass +class ForeignKeyInfo: + """Information about a foreign key relationship.""" + + constraint_name: str + source_table: str + source_column: str + target_table: str + target_column: str + source_schema: str = "" + target_schema: str = "" + + # Type alias for table/view info: (schema, name) TableInfo = tuple[str, str] @@ -343,6 +356,14 @@ def get_sequences(self, conn: Any, database: str | None = None) -> list[Sequence """ pass + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get list of foreign keys in the database. + + Returns: + List of ForeignKeyInfo objects describing FK relationships. + """ + return [] + def get_index_definition( self, conn: Any, index_name: str, table_name: str, database: str | None = None ) -> dict[str, Any]: @@ -518,6 +539,7 @@ def execute_non_query(self, conn: Any, query: str) -> int: __all__ = [ "ColumnInfo", "DatabaseAdapter", + "ForeignKeyInfo", "IndexInfo", "SequenceInfo", "TableInfo", diff --git a/sqlit/domains/connections/providers/model.py b/sqlit/domains/connections/providers/model.py index b104db3c..50ffe85d 100644 --- a/sqlit/domains/connections/providers/model.py +++ b/sqlit/domains/connections/providers/model.py @@ -116,6 +116,10 @@ def get_sequence_definition(self, conn: Any, sequence_name: str, database: str | @runtime_checkable +class ForeignKeyInspector(Protocol): + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[Any]: ... + + @runtime_checkable class ProcedureInspector(Protocol): def get_procedures(self, conn: Any, database: str | None = None) -> list[Any]: ... diff --git a/sqlit/domains/connections/providers/postgresql/base.py b/sqlit/domains/connections/providers/postgresql/base.py index c5140eca..034a7fe2 100644 --- a/sqlit/domains/connections/providers/postgresql/base.py +++ b/sqlit/domains/connections/providers/postgresql/base.py @@ -8,6 +8,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, CursorBasedAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -174,6 +175,42 @@ def get_sequences(self, conn: Any, database: str | None = None) -> list[Sequence ) return [SequenceInfo(name=row[0]) for row in cursor.fetchall()] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from PostgreSQL.""" + cursor = conn.cursor() + cursor.execute( + "SELECT " + " tc.constraint_name, " + " tc.table_schema AS source_schema, " + " tc.table_name AS source_table, " + " kcu.column_name AS source_column, " + " ccu.table_schema AS target_schema, " + " ccu.table_name AS target_table, " + " ccu.column_name AS target_column " + "FROM information_schema.table_constraints tc " + "JOIN information_schema.key_column_usage kcu " + " ON tc.constraint_name = kcu.constraint_name " + " AND tc.table_schema = kcu.table_schema " + "JOIN information_schema.constraint_column_usage ccu " + " ON ccu.constraint_name = tc.constraint_name " + " AND ccu.table_schema = tc.table_schema " + "WHERE tc.constraint_type = 'FOREIGN KEY' " + "AND tc.table_schema NOT IN ('pg_catalog', 'information_schema') " + "ORDER BY tc.table_name, tc.constraint_name" + ) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_schema=row[1], + source_table=row[2], + source_column=row[3], + target_schema=row[4], + target_table=row[5], + target_column=row[6], + ) + for row in cursor.fetchall() + ] + def get_index_definition( self, conn: Any, index_name: str, table_name: str, database: str | None = None ) -> dict[str, Any]: diff --git a/sqlit/domains/connections/providers/sqlite/adapter.py b/sqlit/domains/connections/providers/sqlite/adapter.py index 2b7315f4..98464367 100644 --- a/sqlit/domains/connections/providers/sqlite/adapter.py +++ b/sqlit/domains/connections/providers/sqlite/adapter.py @@ -7,6 +7,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, DatabaseAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -121,6 +122,27 @@ def get_sequences(self, conn: Any, database: str | None = None) -> list[Sequence """SQLite doesn't support sequences - return empty list.""" return [] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from SQLite using PRAGMA foreign_key_list for each table.""" + tables = self.get_tables(conn, database) + results: list[ForeignKeyInfo] = [] + for _schema, table_name in tables: + cursor = conn.cursor() + quoted = self.quote_identifier(table_name) + cursor.execute(f"PRAGMA foreign_key_list({quoted})") + for row in cursor.fetchall(): + # row: id, seq, table, from, to, on_update, on_delete, match + results.append( + ForeignKeyInfo( + constraint_name=f"fk_{table_name}_{row[3]}", + source_table=table_name, + source_column=row[3], + target_table=row[2], + target_column=row[4], + ) + ) + return results + def get_index_definition( self, conn: Any, index_name: str, table_name: str, database: str | None = None ) -> dict[str, Any]: diff --git a/sqlit/domains/diagram/__init__.py b/sqlit/domains/diagram/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sqlit/domains/diagram/app/__init__.py b/sqlit/domains/diagram/app/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sqlit/domains/diagram/app/diagram_service.py b/sqlit/domains/diagram/app/diagram_service.py new file mode 100644 index 00000000..bc0fc5d2 --- /dev/null +++ b/sqlit/domains/diagram/app/diagram_service.py @@ -0,0 +1,52 @@ +"""Service for building ER diagrams from schema data.""" + +from __future__ import annotations + +from typing import Any + +from sqlit.domains.connections.providers.adapters.base import ColumnInfo, ForeignKeyInfo +from sqlit.domains.diagram.app.layout import build_layout, render_diagram + + +def build_diagram_text( + table_names: list[str], + columns_by_table: dict[str, list[ColumnInfo]], + foreign_keys: list[ForeignKeyInfo], + schemas: dict[str, str] | None = None, +) -> str: + """Build diagram text from schema data.""" + tables: dict[str, list[tuple[str, str, bool]]] = {} + for name in table_names: + cols = columns_by_table.get(name, []) + tables[name] = [(c.name, c.data_type, c.is_primary_key) for c in cols] + + table_set = set(table_names) + fk_tuples: list[tuple[str, str, str, str]] = [] + for fk in foreign_keys: + if fk.source_table in table_set: + fk_tuples.append((fk.source_table, fk.source_column, fk.target_table, fk.target_column)) + + layout = build_layout(tables, fk_tuples, schemas) + lines = render_diagram(layout) + return "\n".join(lines) + + +def fetch_diagram_data( + schema_service: Any, + database: str | None, + table_names: list[str], +) -> tuple[dict[str, list[ColumnInfo]], list[ForeignKeyInfo]]: + """Fetch columns and FK data for a set of tables.""" + columns_by_table: dict[str, list[ColumnInfo]] = {} + for name in table_names: + try: + columns_by_table[name] = schema_service.list_columns(database, None, name) + except Exception: + columns_by_table[name] = [] + + try: + foreign_keys: list[ForeignKeyInfo] = schema_service.list_foreign_keys(database) + except Exception: + foreign_keys = [] + + return columns_by_table, foreign_keys diff --git a/sqlit/domains/diagram/app/layout.py b/sqlit/domains/diagram/app/layout.py new file mode 100644 index 00000000..ba6e0cb7 --- /dev/null +++ b/sqlit/domains/diagram/app/layout.py @@ -0,0 +1,342 @@ +"""ER diagram layout engine for terminal rendering.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + + +@dataclass +class TableBox: + """A rendered table box with position and content.""" + + name: str + schema: str + columns: list[tuple[str, str, bool, bool]] # (name, type, is_pk, is_fk) + x: int = 0 + y: int = 0 + width: int = 0 + height: int = 0 + + def compute_dimensions(self) -> None: + name_len = len(self.name) + 2 + col_widths = [len(f" {c[0]} {c[1]} ") + (3 if c[2] or c[3] else 0) for c in self.columns] + content_width = max([name_len, *col_widths]) if col_widths else name_len + self.width = content_width + 2 # borders + self.height = len(self.columns) + 3 # top border + title + separator + columns + bottom border + if not self.columns: + self.height = 3 + + +@dataclass +class Relationship: + """A foreign key relationship between two tables.""" + + source_table: str + source_column: str + target_table: str + target_column: str + + +@dataclass +class DiagramLayout: + """Computed diagram layout ready for rendering.""" + + tables: dict[str, TableBox] = field(default_factory=dict) + relationships: list[Relationship] = field(default_factory=list) + width: int = 0 + height: int = 0 + + +def build_layout( + tables: dict[str, list[tuple[str, str, bool]]], # name -> [(col_name, col_type, is_pk)] + foreign_keys: list[tuple[str, str, str, str]], # [(src_table, src_col, tgt_table, tgt_col)] + schemas: dict[str, str] | None = None, +) -> DiagramLayout: + """Build a diagram layout from table and FK data.""" + layout = DiagramLayout() + + fk_lookup: set[tuple[str, str]] = set() + for src_table, src_col, _tgt_table, _tgt_col in foreign_keys: + fk_lookup.add((src_table, src_col)) + + for table_name, columns in tables.items(): + box = TableBox( + name=table_name, + schema=(schemas or {}).get(table_name, ""), + columns=[ + (col_name, col_type, is_pk, (table_name, col_name) in fk_lookup) + for col_name, col_type, is_pk in columns + ], + ) + box.compute_dimensions() + layout.tables[table_name] = box + + for src_table, src_col, tgt_table, tgt_col in foreign_keys: + if src_table in layout.tables and tgt_table in layout.tables: + layout.relationships.append( + Relationship(src_table, src_col, tgt_table, tgt_col) + ) + + _position_tables(layout) + return layout + + +def _position_tables(layout: DiagramLayout) -> None: + """Position tables in a grid layout, placing related tables near each other.""" + if not layout.tables: + return + + adjacency: dict[str, set[str]] = {t: set() for t in layout.tables} + for rel in layout.relationships: + adjacency[rel.source_table].add(rel.target_table) + adjacency[rel.target_table].add(rel.source_table) + + ordered = _topo_order(layout.tables.keys(), adjacency) + + h_gap = 6 + v_gap = 2 + max_cols = _pick_grid_cols(len(ordered)) + + row_x = 0 + row_y = 0 + col_idx = 0 + row_max_h = 0 + + for table_name in ordered: + box = layout.tables[table_name] + box.x = row_x + box.y = row_y + row_x += box.width + h_gap + row_max_h = max(row_max_h, box.height) + col_idx += 1 + if col_idx >= max_cols: + col_idx = 0 + row_x = 0 + row_y += row_max_h + v_gap + row_max_h = 0 + + max_x = max(b.x + b.width for b in layout.tables.values()) + max_y = max(b.y + b.height for b in layout.tables.values()) + layout.width = max_x + 1 + layout.height = max_y + 1 + + +def _topo_order( + tables: object, + adjacency: dict[str, set[str]], +) -> list[str]: + """Order tables so that related tables appear near each other (BFS from most-connected).""" + table_list = list(tables) # type: ignore[arg-type] + if not table_list: + return [] + + table_list.sort(key=lambda t: len(adjacency.get(t, set())), reverse=True) + + visited: set[str] = set() + result: list[str] = [] + queue: list[str] = [] + + for start in table_list: + if start in visited: + continue + queue.append(start) + while queue: + current = queue.pop(0) + if current in visited: + continue + visited.add(current) + result.append(current) + neighbors = sorted(adjacency.get(current, set()), key=lambda t: len(adjacency.get(t, set())), reverse=True) + for neighbor in neighbors: + if neighbor not in visited: + queue.append(neighbor) + + return result + + +def _pick_grid_cols(count: int) -> int: + if count <= 2: + return count + if count <= 4: + return 2 + if count <= 9: + return 3 + return 4 + + +def render_diagram(layout: DiagramLayout) -> list[str]: + """Render the diagram layout to a list of text lines.""" + if not layout.tables: + return ["(no tables)"] + + canvas_h = layout.height + 4 + canvas_w = layout.width + 10 + canvas: list[list[str]] = [[" "] * canvas_w for _ in range(canvas_h)] + + for box in layout.tables.values(): + _draw_table_box(canvas, box) + + for ri, rel in enumerate(layout.relationships): + _draw_relationship(canvas, layout, rel, ri) + + return ["".join(row).rstrip() for row in canvas] + + +def _draw_table_box(canvas: list[list[str]], box: TableBox) -> None: + """Draw a single table box onto the canvas.""" + x, y, w = box.x, box.y, box.width + + _put(canvas, y, x, "╭") + _put(canvas, y, x + w - 1, "╮") + for i in range(1, w - 1): + _put(canvas, y, x + i, "─") + + title = f" {box.name} " + title_start = x + 1 + for i, ch in enumerate(title[:w - 2]): + _put(canvas, y, title_start + i, ch) + + sep_y = y + 1 + _put(canvas, sep_y, x, "├") + _put(canvas, sep_y, x + w - 1, "┤") + for i in range(1, w - 1): + _put(canvas, sep_y, x + i, "─") + + for ci, (col_name, col_type, is_pk, is_fk) in enumerate(box.columns): + row_y = y + 2 + ci + _put(canvas, row_y, x, "│") + _put(canvas, row_y, x + w - 1, "│") + + tag = "" + if is_pk and is_fk: + tag = "◆ " # PK+FK + elif is_pk: + tag = "● " # PK + elif is_fk: + tag = "○ " # FK + + cell = f" {tag}{col_name} {col_type}" + cell = cell[:w - 2].ljust(w - 2) + for i, ch in enumerate(cell): + _put(canvas, row_y, x + 1 + i, ch) + + bottom_y = y + box.height - 1 + _put(canvas, bottom_y, x, "╰") + _put(canvas, bottom_y, x + w - 1, "╯") + for i in range(1, w - 1): + _put(canvas, bottom_y, x + i, "─") + + +def _draw_relationship( + canvas: list[list[str]], + layout: DiagramLayout, + rel: Relationship, + index: int = 0, +) -> None: + """Draw a relationship line between two table boxes.""" + src_box = layout.tables.get(rel.source_table) + tgt_box = layout.tables.get(rel.target_table) + if not src_box or not tgt_box: + return + + src_col_idx = _find_column_index(src_box, rel.source_column) + tgt_col_idx = _find_column_index(tgt_box, rel.target_column) + if src_col_idx < 0 or tgt_col_idx < 0: + return + + src_row = src_box.y + 2 + src_col_idx + tgt_row = tgt_box.y + 2 + tgt_col_idx + offset = index % 3 + + if src_box.x + src_box.width <= tgt_box.x: + src_x = src_box.x + src_box.width + tgt_x = tgt_box.x - 1 + _draw_line_between(canvas, src_x, src_row, tgt_x, tgt_row, offset) + elif tgt_box.x + tgt_box.width <= src_box.x: + src_x = src_box.x - 1 + tgt_x = tgt_box.x + tgt_box.width + _draw_line_between(canvas, src_x, src_row, tgt_x, tgt_row, offset) + else: + sx = src_box.x + src_box.width + tx = tgt_box.x + tgt_box.width + mid_x = max(sx, tx) + 1 + offset + _draw_line_between(canvas, sx, src_row, tx, tgt_row, offset, mid_x) + + +def _draw_line_between( + canvas: list[list[str]], + src_x: int, + src_y: int, + tgt_x: int, + tgt_y: int, + offset: int = 0, + force_mid_x: int | None = None, +) -> None: + """Draw a line from (src_x, src_y) to (tgt_x, tgt_y) with a vertical segment.""" + if src_y == tgt_y: + lo, hi = min(src_x, tgt_x), max(src_x, tgt_x) + for x in range(lo, hi + 1): + _put_line(canvas, src_y, x, "─") + return + + if force_mid_x is not None: + mid_x = force_mid_x + else: + mid_x = (src_x + tgt_x) // 2 + offset + + h_char = "─" + + if src_x <= mid_x: + for x in range(src_x, mid_x): + _put_line(canvas, src_y, x, h_char) + else: + for x in range(mid_x + 1, src_x + 1): + _put_line(canvas, src_y, x, h_char) + + if src_y < tgt_y: + corner_src = "╮" if src_x <= mid_x else "╭" + corner_tgt = "╰" if tgt_x >= mid_x else "╯" + _put_line(canvas, src_y, mid_x, corner_src) + for y in range(src_y + 1, tgt_y): + _put_line(canvas, y, mid_x, "│") + _put_line(canvas, tgt_y, mid_x, corner_tgt) + else: + corner_src = "╯" if src_x <= mid_x else "╰" + corner_tgt = "╭" if tgt_x >= mid_x else "╮" + _put_line(canvas, src_y, mid_x, corner_src) + for y in range(tgt_y + 1, src_y): + _put_line(canvas, y, mid_x, "│") + _put_line(canvas, tgt_y, mid_x, corner_tgt) + + if tgt_x >= mid_x: + for x in range(mid_x + 1, tgt_x + 1): + _put_line(canvas, tgt_y, x, h_char) + else: + for x in range(tgt_x, mid_x): + _put_line(canvas, tgt_y, x, h_char) + + +def _find_column_index(box: TableBox, column_name: str) -> int: + for i, (name, _type, _pk, _fk) in enumerate(box.columns): + if name == column_name: + return i + return -1 + + +def _put(canvas: list[list[str]], row: int, col: int, ch: str) -> None: + if 0 <= row < len(canvas) and 0 <= col < len(canvas[row]): + canvas[row][col] = ch + + +_BOX_CHARS = frozenset("╭╮╰╯├┤│─●○◆") + + +def _put_line(canvas: list[list[str]], row: int, col: int, ch: str) -> None: + if 0 <= row < len(canvas) and 0 <= col < len(canvas[row]): + existing = canvas[row][col] + if existing == " ": + canvas[row][col] = ch + elif existing in _BOX_CHARS: + pass + elif existing.isalnum() or existing in ("_", "(", ")", ","): + pass diff --git a/sqlit/domains/diagram/ui/__init__.py b/sqlit/domains/diagram/ui/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sqlit/domains/diagram/ui/diagram_screen.py b/sqlit/domains/diagram/ui/diagram_screen.py new file mode 100644 index 00000000..7fe2aa03 --- /dev/null +++ b/sqlit/domains/diagram/ui/diagram_screen.py @@ -0,0 +1,104 @@ +"""ER diagram modal screen.""" + +from __future__ import annotations + +from textual.app import ComposeResult +from textual.binding import Binding +from textual.containers import VerticalScroll +from textual.screen import ModalScreen +from textual.widgets import Static + +from sqlit.shared.ui.widgets import Dialog + + +class DiagramScreen(ModalScreen): + """Full-screen modal displaying an ER diagram.""" + + BINDINGS = [ + Binding("escape", "close", "Close", show=False, priority=True), + Binding("q", "close", "Close", show=False), + Binding("j", "scroll_down", "Down", show=False), + Binding("k", "scroll_up", "Up", show=False), + Binding("h", "scroll_left", "Left", show=False), + Binding("l", "scroll_right", "Right", show=False), + Binding("down", "scroll_down", "Down", show=False), + Binding("up", "scroll_up", "Up", show=False), + Binding("left", "scroll_left", "Left", show=False), + Binding("right", "scroll_right", "Right", show=False), + Binding("ctrl+d", "page_down", "Page Down", show=False), + Binding("ctrl+u", "page_up", "Page Up", show=False), + Binding("g", "scroll_home", "Top", show=False), + Binding("G", "scroll_end", "Bottom", show=False), + ] + + CSS = """ + DiagramScreen { + align: center middle; + background: $background 80%; + } + + #diagram-dialog { + width: 95%; + height: 90%; + max-width: 200; + } + + #diagram-scroll { + height: 1fr; + overflow: auto; + } + + #diagram-content { + width: auto; + min-width: 100%; + color: $text; + } + """ + + def __init__(self, diagram_text: str, title: str = "ER Diagram") -> None: + super().__init__() + self._diagram_text = diagram_text + self._title = title + + def compose(self) -> ComposeResult: + shortcuts = [("Close", "q"), ("Scroll", "hjkl")] + with Dialog(id="diagram-dialog", title=self._title, shortcuts=shortcuts): + with VerticalScroll(id="diagram-scroll"): + yield Static(self._diagram_text, id="diagram-content") + + def on_mount(self) -> None: + self.query_one("#diagram-scroll").focus() + + def action_close(self) -> None: + self.dismiss() + + def action_scroll_down(self) -> None: + self.query_one("#diagram-scroll").scroll_down() + + def action_scroll_up(self) -> None: + self.query_one("#diagram-scroll").scroll_up() + + def action_scroll_left(self) -> None: + self.query_one("#diagram-scroll").scroll_left() + + def action_scroll_right(self) -> None: + self.query_one("#diagram-scroll").scroll_right() + + def action_page_down(self) -> None: + scroll = self.query_one("#diagram-scroll") + scroll.scroll_page_down() + + def action_page_up(self) -> None: + scroll = self.query_one("#diagram-scroll") + scroll.scroll_page_up() + + def action_scroll_home(self) -> None: + self.query_one("#diagram-scroll").scroll_home() + + def action_scroll_end(self) -> None: + self.query_one("#diagram-scroll").scroll_end() + + def check_action(self, action: str, parameters: tuple) -> bool | None: + if self.app.screen is not self: + return False + return super().check_action(action, parameters) diff --git a/sqlit/domains/diagram/ui/table_picker.py b/sqlit/domains/diagram/ui/table_picker.py new file mode 100644 index 00000000..fcf8f6d5 --- /dev/null +++ b/sqlit/domains/diagram/ui/table_picker.py @@ -0,0 +1,176 @@ +"""Table picker modal for ER diagram generation.""" + +from __future__ import annotations + +from typing import Any + +from textual.app import ComposeResult +from textual.binding import Binding +from textual.containers import Vertical +from textual.screen import ModalScreen +from textual.widgets import Input, OptionList, Static +from textual.widgets.option_list import Option + +from sqlit.shared.ui.widgets import Dialog + + +class DiagramTablePicker(ModalScreen[list[str] | None]): + """Modal to select tables for ER diagram generation.""" + + BINDINGS = [ + Binding("escape", "cancel", "Cancel", show=False, priority=True), + ] + + CSS = """ + DiagramTablePicker { + align: center middle; + background: transparent; + } + + #diagram-picker-dialog { + width: 60; + max-width: 80%; + max-height: 80%; + } + + #diagram-picker-body { + height: auto; + max-height: 100%; + } + + #diagram-picker-input { + margin-bottom: 1; + } + + #diagram-picker-list { + height: auto; + max-height: 30; + border: none; + } + + #diagram-picker-list > .option-list--option { + padding: 0; + } + + #diagram-picker-status { + height: 1; + margin-top: 1; + color: $text-muted; + text-align: right; + } + """ + + def __init__(self, tables: list[str]) -> None: + super().__init__() + self._tables = sorted(tables) + self._selected: set[str] = set() + self._filter_text = "" + + def compose(self) -> ComposeResult: + shortcuts = [("Toggle", ""), ("Toggle All", "a"), ("OK", "")] + with Dialog(id="diagram-picker-dialog", title="Select Tables for Diagram", shortcuts=shortcuts): + with Vertical(id="diagram-picker-body"): + yield Input(placeholder="Filter tables...", id="diagram-picker-input") + yield OptionList(*self._build_options(), id="diagram-picker-list") + yield Static(self._status_text(), id="diagram-picker-status") + + def _build_options(self) -> list[Option]: + options: list[Option] = [] + for table in self._tables: + if self._filter_text and self._filter_text.lower() not in table.lower(): + continue + check = "[*]" if table in self._selected else "[ ]" + options.append(Option(f"{check} {table}", id=table)) + return options + + def _status_text(self) -> str: + return f"{len(self._selected)}/{len(self._tables)} selected" + + def _rebuild(self) -> None: + option_list = self.query_one("#diagram-picker-list", OptionList) + highlighted = option_list.highlighted + option_list.clear_options() + for opt in self._build_options(): + option_list.add_option(opt) + if highlighted is not None and highlighted < option_list.option_count: + option_list.highlighted = highlighted + self.query_one("#diagram-picker-status", Static).update(self._status_text()) + + def on_mount(self) -> None: + self.query_one("#diagram-picker-list", OptionList).focus() + + def on_input_changed(self, event: Input.Changed) -> None: + if event.input.id == "diagram-picker-input": + self._filter_text = event.value + self._rebuild() + + def on_input_submitted(self, event: Input.Submitted) -> None: + if event.input.id == "diagram-picker-input": + option_list = self.query_one("#diagram-picker-list", OptionList) + option_list.focus() + + def _toggle_highlighted(self) -> None: + option_list = self.query_one("#diagram-picker-list", OptionList) + if option_list.highlighted is not None: + opt = option_list.get_option_at_index(option_list.highlighted) + table_name = opt.id or "" + if table_name: + if table_name in self._selected: + self._selected.discard(table_name) + else: + self._selected.add(table_name) + self._rebuild() + + def on_key(self, event: Any) -> None: + key = getattr(event, "key", "") + focused = self.app.focused + input_widget = self.query_one("#diagram-picker-input", Input) + option_list = self.query_one("#diagram-picker-list", OptionList) + in_input = focused is input_widget + + if key == "enter" and not in_input: + self._confirm() + event.prevent_default() + event.stop() + elif key == "tab" and in_input: + option_list.focus() + event.prevent_default() + event.stop() + elif key == "slash" and not in_input: + input_widget.focus() + event.prevent_default() + event.stop() + elif key == "space" and not in_input: + self._toggle_highlighted() + event.prevent_default() + event.stop() + elif key == "j" and not in_input: + option_list.action_cursor_down() + event.prevent_default() + event.stop() + elif key == "k" and not in_input: + option_list.action_cursor_up() + event.prevent_default() + event.stop() + elif key == "a" and not in_input: + if len(self._selected) == len(self._tables): + self._selected.clear() + else: + self._selected = set(self._tables) + self._rebuild() + event.prevent_default() + event.stop() + + def _confirm(self) -> None: + if self._selected: + self.dismiss(sorted(self._selected)) + else: + self.dismiss(None) + + def action_cancel(self) -> None: + self.dismiss(None) + + def check_action(self, action: str, parameters: tuple) -> bool | None: + if self.app.screen is not self: + return False + return super().check_action(action, parameters) diff --git a/sqlit/domains/explorer/app/schema_service.py b/sqlit/domains/explorer/app/schema_service.py index 1252b774..a588b5f1 100644 --- a/sqlit/domains/explorer/app/schema_service.py +++ b/sqlit/domains/explorer/app/schema_service.py @@ -7,8 +7,9 @@ from typing import Any, TypeVar from sqlit.domains.connections.app.session import ConnectionSession -from sqlit.domains.connections.providers.adapters.base import ColumnInfo +from sqlit.domains.connections.providers.adapters.base import ColumnInfo, ForeignKeyInfo from sqlit.domains.connections.providers.model import ( + ForeignKeyInspector, IndexInspector, ProcedureInspector, SequenceInspector, @@ -206,3 +207,13 @@ def get_sequence_definition(self, database: str | None, name: str) -> dict[str, lambda: inspector.get_sequence_definition(self.session.connection, name, db_arg), database, ) + + def list_foreign_keys(self, database: str | None) -> list[ForeignKeyInfo]: + inspector = self.session.provider.schema_inspector + if not isinstance(inspector, ForeignKeyInspector): + return [] + db_arg = self._resolve_db_arg(database) + return self._run_with_retry( + lambda: inspector.get_foreign_keys(self.session.connection, db_arg), + database, + ) diff --git a/sqlit/domains/explorer/state/tree_on_database.py b/sqlit/domains/explorer/state/tree_on_database.py index 8a2336eb..19868227 100644 --- a/sqlit/domains/explorer/state/tree_on_database.py +++ b/sqlit/domains/explorer/state/tree_on_database.py @@ -12,7 +12,7 @@ class TreeOnDatabaseState(State): help_category = "Explorer" def _setup_actions(self) -> None: - pass # Expanding a database now sets it as active automatically + self.allows("show_diagram", label="ER Diagram", help="Show ER diagram") def get_display_bindings(self, app: InputContext) -> tuple[list[DisplayBinding], list[DisplayBinding]]: left: list[DisplayBinding] = [] @@ -20,6 +20,14 @@ def get_display_bindings(self, app: InputContext) -> tuple[list[DisplayBinding], left.append(DisplayBinding(key="enter", label="Use database", action="toggle_node")) seen.add("toggle_node") + left.append( + DisplayBinding( + key=resolve_display_key("show_diagram") or "S", + label="Diagram", + action="show_diagram", + ) + ) + seen.add("show_diagram") left.append( DisplayBinding( key=resolve_display_key("refresh_tree") or "f", diff --git a/sqlit/domains/explorer/state/tree_on_folder.py b/sqlit/domains/explorer/state/tree_on_folder.py index 68c19394..e56f222c 100644 --- a/sqlit/domains/explorer/state/tree_on_folder.py +++ b/sqlit/domains/explorer/state/tree_on_folder.py @@ -25,6 +25,7 @@ def is_connection_folder(app: InputContext) -> bool: label="Delete", help="Delete connection folder", ) + self.allows("show_diagram", label="ER Diagram", help="Show ER diagram") def get_display_bindings(self, app: InputContext) -> tuple[list[DisplayBinding], list[DisplayBinding]]: left: list[DisplayBinding] = [] @@ -32,6 +33,17 @@ def get_display_bindings(self, app: InputContext) -> tuple[list[DisplayBinding], left.append(DisplayBinding(key="enter", label="Expand", action="toggle_node")) seen.add("toggle_node") + + if app.tree_node_kind == "folder": + left.append( + DisplayBinding( + key=resolve_display_key("show_diagram") or "S", + label="Diagram", + action="show_diagram", + ) + ) + seen.add("show_diagram") + left.append( DisplayBinding( key=resolve_display_key("refresh_tree") or "f", diff --git a/sqlit/domains/explorer/state/tree_on_table.py b/sqlit/domains/explorer/state/tree_on_table.py index ee092789..76fe9113 100644 --- a/sqlit/domains/explorer/state/tree_on_table.py +++ b/sqlit/domains/explorer/state/tree_on_table.py @@ -13,6 +13,7 @@ class TreeOnTableState(State): def _setup_actions(self) -> None: self.allows("select_table", label="Select TOP 100", help="Select TOP 100 (table/view)") + self.allows("show_diagram", label="ER Diagram", help="Show ER diagram") def get_display_bindings(self, app: InputContext) -> tuple[list[DisplayBinding], list[DisplayBinding]]: left: list[DisplayBinding] = [] @@ -28,6 +29,14 @@ def get_display_bindings(self, app: InputContext) -> tuple[list[DisplayBinding], ) ) seen.add("select_table") + left.append( + DisplayBinding( + key=resolve_display_key("show_diagram") or "S", + label="Diagram", + action="show_diagram", + ) + ) + seen.add("show_diagram") left.append( DisplayBinding( key=resolve_display_key("refresh_tree") or "f", diff --git a/sqlit/domains/explorer/ui/mixins/tree.py b/sqlit/domains/explorer/ui/mixins/tree.py index 26cf34b3..223f55ef 100644 --- a/sqlit/domains/explorer/ui/mixins/tree.py +++ b/sqlit/domains/explorer/ui/mixins/tree.py @@ -357,6 +357,111 @@ def action_select_table(self: TreeMixinHost) -> None: tree_object_info.show_sequence_info(self, data) return + def action_show_diagram(self: TreeMixinHost) -> None: + """Show ER diagram for selected table (S key) — single table only.""" + if not self.current_provider or not self._session: + return + + node = self.object_tree.cursor_node + if not node or not node.data: + return + + data = node.data + kind = self._get_node_kind(node) + + if kind in ("table", "view"): + self._show_diagram_for_tables(data.database, [data.name]) + elif kind == "database": + self._show_diagram_for_database(data.name) + elif kind == "folder" and data.folder_type == "tables": + self._show_diagram_for_database(data.database) + else: + db = None + if hasattr(self, "_get_effective_database"): + db = self._get_effective_database() + self._show_diagram_for_database(db) + + def action_show_diagram_picker(self: TreeMixinHost) -> None: + """Show table picker then ER diagram (space+d leader) — always opens picker.""" + if not self.current_provider or not self._session: + return + db = None + if hasattr(self, "_get_effective_database"): + db = self._get_effective_database() + self._show_diagram_for_database(db) + + def _show_diagram_for_tables(self: TreeMixinHost, database: str | None, table_names: list[str]) -> None: + """Show diagram for specific tables only (no auto-expansion).""" + import asyncio + + schema_service = self._get_schema_service() + if not schema_service: + self.notify("Not connected", severity="error") + return + + async def work() -> None: + from sqlit.domains.diagram.app.diagram_service import build_diagram_text, fetch_diagram_data + + cols, fks = await asyncio.to_thread(fetch_diagram_data, schema_service, database, table_names) + text = build_diagram_text(table_names, cols, fks) + label = ", ".join(table_names[:3]) + if len(table_names) > 3: + label += f" +{len(table_names) - 3}" + + from sqlit.domains.diagram.ui.diagram_screen import DiagramScreen + + self.app.push_screen(DiagramScreen(text, title=f"ER Diagram: {label}")) + + self.run_worker(work(), name="diagram-load", exclusive=False) + + def _show_diagram_for_database(self: TreeMixinHost, database: str | None) -> None: + """Show table picker then diagram for a database.""" + import asyncio + + schema_service = self._get_schema_service() + if not schema_service: + self.notify("Not connected", severity="error") + return + + async def work() -> None: + tables = await asyncio.to_thread(schema_service.list_folder_items, "tables", database) + table_names = [name for _, _schema, name in tables] + + if not table_names: + self.notify("No tables found", severity="warning") + return + + from sqlit.domains.diagram.ui.table_picker import DiagramTablePicker + + def on_selected(selected: list[str] | None) -> None: + if selected: + self._show_diagram_for_selected(database, selected) + + self.app.push_screen(DiagramTablePicker(table_names), on_selected) + + self.run_worker(work(), name="diagram-tables", exclusive=False) + + def _show_diagram_for_selected(self: TreeMixinHost, database: str | None, table_names: list[str]) -> None: + """Generate and show diagram for selected tables.""" + import asyncio + + schema_service = self._get_schema_service() + if not schema_service: + return + + async def work() -> None: + from sqlit.domains.diagram.app.diagram_service import build_diagram_text, fetch_diagram_data + + cols, fks = await asyncio.to_thread(fetch_diagram_data, schema_service, database, table_names) + text = build_diagram_text(table_names, cols, fks) + title = f"ER Diagram ({len(table_names)} tables)" + + from sqlit.domains.diagram.ui.diagram_screen import DiagramScreen + + self.app.push_screen(DiagramScreen(text, title=title)) + + self.run_worker(work(), name="diagram-render", exclusive=False) + def action_use_database(self: TreeMixinHost) -> None: """Toggle the selected database as the default for the current connection.""" node = self.object_tree.cursor_node diff --git a/sqlit/domains/shell/ui/mixins/ui_leader.py b/sqlit/domains/shell/ui/mixins/ui_leader.py index 55a225bc..3a04c120 100644 --- a/sqlit/domains/shell/ui/mixins/ui_leader.py +++ b/sqlit/domains/shell/ui/mixins/ui_leader.py @@ -118,5 +118,8 @@ def action_leader_telescope(self: UINavigationMixinHost) -> None: def action_leader_telescope_filter(self: UINavigationMixinHost) -> None: self._execute_leader_command("telescope_filter") + def action_leader_show_diagram_picker(self: UINavigationMixinHost) -> None: + self._execute_leader_command("show_diagram_picker") + def action_leader_quit(self: UINavigationMixinHost) -> None: self._execute_leader_command("quit") diff --git a/tests/unit/test_diagram_layout.py b/tests/unit/test_diagram_layout.py new file mode 100644 index 00000000..bcc89909 --- /dev/null +++ b/tests/unit/test_diagram_layout.py @@ -0,0 +1,110 @@ +"""Tests for ER diagram layout engine.""" + +from __future__ import annotations + +from sqlit.domains.diagram.app.layout import ( + DiagramLayout, + TableBox, + build_layout, + render_diagram, +) + + +class TestTableBox: + def test_compute_dimensions_basic(self) -> None: + box = TableBox( + name="users", + schema="", + columns=[("id", "int", True, False), ("name", "varchar", False, False)], + ) + box.compute_dimensions() + assert box.width > 0 + assert box.height == 5 # top + title_sep + 2 cols + bottom + + def test_compute_dimensions_empty_table(self) -> None: + box = TableBox(name="empty", schema="", columns=[]) + box.compute_dimensions() + assert box.width > 0 + assert box.height == 3 # top + title_sep + bottom + + +class TestBuildLayout: + def test_single_table(self) -> None: + tables = {"users": [("id", "int", True), ("name", "varchar", False)]} + layout = build_layout(tables, []) + assert "users" in layout.tables + assert layout.tables["users"].x == 0 + assert layout.tables["users"].y == 0 + + def test_two_related_tables(self) -> None: + tables = { + "orders": [("id", "int", True), ("user_id", "int", False)], + "users": [("id", "int", True), ("name", "varchar", False)], + } + fks = [("orders", "user_id", "users", "id")] + layout = build_layout(tables, fks) + assert len(layout.tables) == 2 + assert len(layout.relationships) == 1 + + def test_fk_marks_column_as_fk(self) -> None: + tables = { + "orders": [("id", "int", True), ("user_id", "int", False)], + "users": [("id", "int", True)], + } + fks = [("orders", "user_id", "users", "id")] + layout = build_layout(tables, fks) + order_cols = layout.tables["orders"].columns + assert order_cols[1][3] is True # user_id is_fk + + def test_ignores_fks_for_missing_tables(self) -> None: + tables = {"users": [("id", "int", True)]} + fks = [("orders", "user_id", "users", "id")] + layout = build_layout(tables, fks) + assert len(layout.relationships) == 0 + + def test_grid_layout_positions(self) -> None: + tables = {f"t{i}": [("id", "int", True)] for i in range(6)} + layout = build_layout(tables, []) + positions = {name: (box.x, box.y) for name, box in layout.tables.items()} + # All positions should be unique + assert len(set(positions.values())) == 6 + + +class TestRenderDiagram: + def test_empty_tables(self) -> None: + layout = DiagramLayout() + lines = render_diagram(layout) + assert lines == ["(no tables)"] + + def test_renders_table_box(self) -> None: + tables = {"users": [("id", "int", True), ("name", "varchar", False)]} + layout = build_layout(tables, []) + lines = render_diagram(layout) + text = "\n".join(lines) + assert "users" in text + assert "●" in text # PK indicator + assert "id" in text + assert "name" in text + + def test_renders_fk_indicator(self) -> None: + tables = { + "orders": [("id", "int", True), ("user_id", "int", False)], + "users": [("id", "int", True)], + } + fks = [("orders", "user_id", "users", "id")] + layout = build_layout(tables, fks) + lines = render_diagram(layout) + text = "\n".join(lines) + assert "○" in text # FK indicator + + def test_renders_relationship_lines(self) -> None: + tables = { + "orders": [("id", "int", True), ("user_id", "int", False)], + "users": [("id", "int", True)], + } + fks = [("orders", "user_id", "users", "id")] + layout = build_layout(tables, fks) + lines = render_diagram(layout) + text = "\n".join(lines) + # Should contain relationship line characters + assert any(ch in text for ch in "─│╮╯╰╭◀") From ceb0288fe3483001334ffe6df1e1e4999719d077 Mon Sep 17 00:00:00 2001 From: David Date: Tue, 12 May 2026 15:00:57 -0500 Subject: [PATCH 2/2] Add get_foreign_keys() to all database adapters Implements FK introspection for MySQL, MariaDB, MSSQL, Oracle, Firebird, DuckDB, Turso, Snowflake, Spanner, HANA, DB2, Teradata, Redshift, and MotherDuck. Each adapter queries the database's native metadata catalog to return ForeignKeyInfo objects used by the ER diagram feature. Includes 42 unit tests with mocked cursors and integration test coverage via BaseDatabaseTests. All 9 test fixtures updated with test_orders FK table. --- .../connections/providers/db2/adapter.py | 29 + .../connections/providers/duckdb/adapter.py | 40 ++ .../connections/providers/firebird/adapter.py | 26 + .../connections/providers/hana/adapter.py | 25 + .../connections/providers/mariadb/adapter.py | 32 + .../providers/motherduck/adapter.py | 47 +- .../connections/providers/mssql/adapter.py | 36 ++ .../connections/providers/mysql/base.py | 32 + .../connections/providers/oracle/adapter.py | 29 + .../connections/providers/redshift/adapter.py | 37 ++ .../providers/snowflake/adapter.py | 42 +- .../connections/providers/spanner/adapter.py | 31 + .../connections/providers/teradata/adapter.py | 35 + .../connections/providers/turso/adapter.py | 20 + tests/fixtures/duckdb.py | 9 + tests/fixtures/firebird.py | 10 + tests/fixtures/mariadb.py | 11 + tests/fixtures/mssql.py | 9 + tests/fixtures/mysql.py | 11 + tests/fixtures/oracle.py | 10 + tests/fixtures/postgres.py | 11 + tests/fixtures/sqlite.py | 9 + tests/fixtures/turso.py | 11 + tests/test_database_base.py | 42 ++ tests/unit/test_foreign_keys.py | 607 ++++++++++++++++++ 25 files changed, 1196 insertions(+), 5 deletions(-) create mode 100644 tests/unit/test_foreign_keys.py diff --git a/sqlit/domains/connections/providers/db2/adapter.py b/sqlit/domains/connections/providers/db2/adapter.py index 7aafbd82..c0555583 100644 --- a/sqlit/domains/connections/providers/db2/adapter.py +++ b/sqlit/domains/connections/providers/db2/adapter.py @@ -7,6 +7,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, CursorBasedAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -151,6 +152,34 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: ) return [row[0] for row in cursor.fetchall()] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from DB2.""" + cursor = conn.cursor() + cursor.execute( + "SELECT r.constname, r.tabschema, r.tabname, fk.colname, " + "r.reftabschema, r.reftabname, pk.colname " + "FROM syscat.references r " + "JOIN syscat.keycoluse fk ON r.constname = fk.constname " + "AND r.tabschema = fk.tabschema AND r.tabname = fk.tabname " + "JOIN syscat.keycoluse pk ON r.refkeyname = pk.constname " + "AND r.reftabschema = pk.tabschema AND r.reftabname = pk.tabname " + "AND fk.colseq = pk.colseq " + "WHERE r.tabschema NOT LIKE 'SYS%' " + "ORDER BY r.tabname, r.constname" + ) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_schema=row[1], + source_table=row[2], + source_column=row[3], + target_schema=row[4], + target_table=row[5], + target_column=row[6], + ) + for row in cursor.fetchall() + ] + def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: cursor = conn.cursor() cursor.execute( diff --git a/sqlit/domains/connections/providers/duckdb/adapter.py b/sqlit/domains/connections/providers/duckdb/adapter.py index e70fbaa7..021620a5 100644 --- a/sqlit/domains/connections/providers/duckdb/adapter.py +++ b/sqlit/domains/connections/providers/duckdb/adapter.py @@ -7,6 +7,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, DatabaseAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -155,6 +156,45 @@ def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo] for row in result.fetchall() ] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from DuckDB.""" + result = conn.execute( + "SELECT " + " tc.constraint_name, " + " tc.table_schema AS source_schema, " + " tc.table_name AS source_table, " + " kcu.column_name AS source_column, " + " kcu2.table_schema AS target_schema, " + " kcu2.table_name AS target_table, " + " kcu2.column_name AS target_column " + "FROM information_schema.table_constraints tc " + "JOIN information_schema.key_column_usage kcu " + " ON tc.constraint_name = kcu.constraint_name " + " AND tc.table_schema = kcu.table_schema " + "JOIN information_schema.referential_constraints rc " + " ON tc.constraint_name = rc.constraint_name " + " AND tc.constraint_schema = rc.constraint_schema " + "JOIN information_schema.key_column_usage kcu2 " + " ON rc.unique_constraint_name = kcu2.constraint_name " + " AND rc.unique_constraint_schema = kcu2.constraint_schema " + " AND kcu.ordinal_position = kcu2.ordinal_position " + "WHERE tc.constraint_type = 'FOREIGN KEY' " + "AND tc.table_schema NOT IN ('pg_catalog', 'information_schema') " + "ORDER BY tc.table_name, tc.constraint_name" + ) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_schema=row[1], + source_table=row[2], + source_column=row[3], + target_schema=row[4], + target_table=row[5], + target_column=row[6], + ) + for row in result.fetchall() + ] + def get_triggers(self, conn: Any, database: str | None = None) -> list[TriggerInfo]: """DuckDB doesn't support triggers - return empty list.""" return [] diff --git a/sqlit/domains/connections/providers/firebird/adapter.py b/sqlit/domains/connections/providers/firebird/adapter.py index f424cb0b..8541b254 100644 --- a/sqlit/domains/connections/providers/firebird/adapter.py +++ b/sqlit/domains/connections/providers/firebird/adapter.py @@ -5,6 +5,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, CursorBasedAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -166,6 +167,31 @@ def get_index_definition( "definition": " ".join(definition_parts), } + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from Firebird.""" + cursor = conn.cursor() + cursor.execute( + "SELECT rc.rdb$constraint_name, rc.rdb$relation_name, sg.rdb$field_name, " + "rfc.rdb$relation_name, sg2.rdb$field_name " + "FROM rdb$relation_constraints rc " + "JOIN rdb$ref_constraints ref ON rc.rdb$constraint_name = ref.rdb$constraint_name " + "JOIN rdb$relation_constraints rfc ON ref.rdb$const_name_uq = rfc.rdb$constraint_name " + "JOIN rdb$index_segments sg ON rc.rdb$index_name = sg.rdb$index_name " + "JOIN rdb$index_segments sg2 ON rfc.rdb$index_name = sg2.rdb$index_name " + "AND sg.rdb$field_position = sg2.rdb$field_position " + "WHERE rc.rdb$constraint_type = 'FOREIGN KEY'" + ) + return [ + ForeignKeyInfo( + constraint_name=row[0].rstrip(), + source_table=row[1].rstrip(), + source_column=row[2].rstrip(), + target_table=row[3].rstrip(), + target_column=row[4].rstrip(), + ) + for row in cursor.fetchall() + ] + def get_sequences(self, conn: Any, database: str | None = None) -> list[SequenceInfo]: cursor = conn.cursor() cursor.execute("SELECT rdb$generator_name FROM rdb$generators WHERE rdb$system_flag = 0") diff --git a/sqlit/domains/connections/providers/hana/adapter.py b/sqlit/domains/connections/providers/hana/adapter.py index 12afe5d3..d55801e6 100644 --- a/sqlit/domains/connections/providers/hana/adapter.py +++ b/sqlit/domains/connections/providers/hana/adapter.py @@ -7,6 +7,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, CursorBasedAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -143,6 +144,30 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: ) return [row[0] for row in cursor.fetchall()] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from SAP HANA.""" + cursor = conn.cursor() + cursor.execute( + "SELECT constraint_name, schema_name, table_name, column_name, " + "referenced_schema_name, referenced_table_name, referenced_column_name " + "FROM sys.referential_constraints " + "WHERE schema_name NOT LIKE '_SYS%' " + "AND schema_name NOT IN ('SYS', 'SYSTEM') " + "ORDER BY table_name, constraint_name" + ) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_schema=row[1], + source_table=row[2], + source_column=row[3], + target_schema=row[4], + target_table=row[5], + target_column=row[6], + ) + for row in cursor.fetchall() + ] + def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: cursor = conn.cursor() cursor.execute( diff --git a/sqlit/domains/connections/providers/mariadb/adapter.py b/sqlit/domains/connections/providers/mariadb/adapter.py index ba9459d2..f073c5b2 100644 --- a/sqlit/domains/connections/providers/mariadb/adapter.py +++ b/sqlit/domains/connections/providers/mariadb/adapter.py @@ -7,6 +7,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -218,6 +219,37 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: ) return [row[0] for row in cursor.fetchall()] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from MariaDB (uses ? placeholders).""" + cursor = conn.cursor() + if database: + cursor.execute( + "SELECT constraint_name, table_name, column_name, " + "referenced_table_name, referenced_column_name " + "FROM information_schema.key_column_usage " + "WHERE table_schema = ? AND referenced_table_name IS NOT NULL " + "ORDER BY table_name, constraint_name", + (database,), + ) + else: + cursor.execute( + "SELECT constraint_name, table_name, column_name, " + "referenced_table_name, referenced_column_name " + "FROM information_schema.key_column_usage " + "WHERE table_schema = DATABASE() AND referenced_table_name IS NOT NULL " + "ORDER BY table_name, constraint_name" + ) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_table=row[1], + source_column=row[2], + target_table=row[3], + target_column=row[4], + ) + for row in cursor.fetchall() + ] + def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: """Get indexes from MariaDB (uses ? placeholders).""" cursor = conn.cursor() diff --git a/sqlit/domains/connections/providers/motherduck/adapter.py b/sqlit/domains/connections/providers/motherduck/adapter.py index 72eec7dc..c1888aa8 100644 --- a/sqlit/domains/connections/providers/motherduck/adapter.py +++ b/sqlit/domains/connections/providers/motherduck/adapter.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any -from sqlit.domains.connections.providers.adapters.base import TableInfo +from sqlit.domains.connections.providers.adapters.base import ForeignKeyInfo, TableInfo from sqlit.domains.connections.providers.duckdb.adapter import DuckDBAdapter if TYPE_CHECKING: @@ -115,6 +115,51 @@ def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]: ) return [(row[0], row[1]) for row in result.fetchall()] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from a specific MotherDuck database.""" + query = ( + "SELECT " + " tc.constraint_name, " + " tc.table_schema AS source_schema, " + " tc.table_name AS source_table, " + " kcu.column_name AS source_column, " + " kcu2.table_schema AS target_schema, " + " kcu2.table_name AS target_table, " + " kcu2.column_name AS target_column " + "FROM information_schema.table_constraints tc " + "JOIN information_schema.key_column_usage kcu " + " ON tc.constraint_name = kcu.constraint_name " + " AND tc.table_schema = kcu.table_schema " + "JOIN information_schema.referential_constraints rc " + " ON tc.constraint_name = rc.constraint_name " + " AND tc.constraint_schema = rc.constraint_schema " + "JOIN information_schema.key_column_usage kcu2 " + " ON rc.unique_constraint_name = kcu2.constraint_name " + " AND rc.unique_constraint_schema = kcu2.constraint_schema " + " AND kcu.ordinal_position = kcu2.ordinal_position " + "WHERE tc.constraint_type = 'FOREIGN KEY' " + "AND tc.table_schema NOT IN ('pg_catalog', 'information_schema') " + ) + if database: + query += "AND tc.table_catalog = ? " + query += "ORDER BY tc.table_name, tc.constraint_name" + result = conn.execute(query, (database,)) + else: + query += "ORDER BY tc.table_name, tc.constraint_name" + result = conn.execute(query) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_schema=row[1], + source_table=row[2], + source_column=row[3], + target_schema=row[4], + target_table=row[5], + target_column=row[6], + ) + for row in result.fetchall() + ] + def build_select_query( self, table: str, limit: int, database: str | None = None, schema: str | None = None ) -> str: diff --git a/sqlit/domains/connections/providers/mssql/adapter.py b/sqlit/domains/connections/providers/mssql/adapter.py index 2b734707..60190d1c 100644 --- a/sqlit/domains/connections/providers/mssql/adapter.py +++ b/sqlit/domains/connections/providers/mssql/adapter.py @@ -7,6 +7,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, DatabaseAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -256,6 +257,41 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: ) return [row[0] for row in cursor.fetchall()] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from SQL Server.""" + cursor = self._get_cursor_for_database(conn, database) + cursor.execute( + "SELECT " + " fk.name, " + " sch1.name, " + " tp.name, " + " cp.name, " + " sch2.name, " + " tr.name, " + " cr.name " + "FROM sys.foreign_keys fk " + "JOIN sys.foreign_key_columns fkc ON fk.object_id = fkc.constraint_object_id " + "JOIN sys.tables tp ON fkc.parent_object_id = tp.object_id " + "JOIN sys.columns cp ON fkc.parent_object_id = cp.object_id AND fkc.parent_column_id = cp.column_id " + "JOIN sys.tables tr ON fkc.referenced_object_id = tr.object_id " + "JOIN sys.columns cr ON fkc.referenced_object_id = cr.object_id AND fkc.referenced_column_id = cr.column_id " + "JOIN sys.schemas sch1 ON tp.schema_id = sch1.schema_id " + "JOIN sys.schemas sch2 ON tr.schema_id = sch2.schema_id " + "ORDER BY tp.name, fk.name" + ) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_schema=row[1], + source_table=row[2], + source_column=row[3], + target_schema=row[4], + target_table=row[5], + target_column=row[6], + ) + for row in cursor.fetchall() + ] + def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: """Get indexes from SQL Server.""" cursor = self._get_cursor_for_database(conn, database) diff --git a/sqlit/domains/connections/providers/mysql/base.py b/sqlit/domains/connections/providers/mysql/base.py index d65fcb90..ebd0670b 100644 --- a/sqlit/domains/connections/providers/mysql/base.py +++ b/sqlit/domains/connections/providers/mysql/base.py @@ -10,6 +10,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, CursorBasedAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -148,6 +149,37 @@ def build_select_query(self, table: str, limit: int, database: str | None = None return f"SELECT * FROM `{database}`.`{table}` LIMIT {limit}" return f"SELECT * FROM `{table}` LIMIT {limit}" + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from MySQL.""" + cursor = conn.cursor() + if database: + cursor.execute( + "SELECT constraint_name, table_name, column_name, " + "referenced_table_name, referenced_column_name " + "FROM information_schema.key_column_usage " + "WHERE table_schema = %s AND referenced_table_name IS NOT NULL " + "ORDER BY table_name, constraint_name", + (database,), + ) + else: + cursor.execute( + "SELECT constraint_name, table_name, column_name, " + "referenced_table_name, referenced_column_name " + "FROM information_schema.key_column_usage " + "WHERE table_schema = DATABASE() AND referenced_table_name IS NOT NULL " + "ORDER BY table_name, constraint_name" + ) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_table=row[1], + source_column=row[2], + target_table=row[3], + target_column=row[4], + ) + for row in cursor.fetchall() + ] + def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: """Get indexes from MySQL/MariaDB.""" cursor = conn.cursor() diff --git a/sqlit/domains/connections/providers/oracle/adapter.py b/sqlit/domains/connections/providers/oracle/adapter.py index 76edc895..08a916b8 100644 --- a/sqlit/domains/connections/providers/oracle/adapter.py +++ b/sqlit/domains/connections/providers/oracle/adapter.py @@ -7,6 +7,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, DatabaseAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -167,6 +168,34 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: finally: cursor.close() + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from Oracle.""" + cursor = conn.cursor() + try: + cursor.execute( + "SELECT c.constraint_name, c.table_name, cc.column_name, " + "r.table_name, rc.column_name " + "FROM user_constraints c " + "JOIN user_cons_columns cc ON c.constraint_name = cc.constraint_name " + "JOIN user_constraints r ON c.r_constraint_name = r.constraint_name " + "JOIN user_cons_columns rc ON r.constraint_name = rc.constraint_name " + "AND cc.position = rc.position " + "WHERE c.constraint_type = 'R' " + "ORDER BY c.table_name, c.constraint_name" + ) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_table=row[1], + source_column=row[2], + target_table=row[3], + target_column=row[4], + ) + for row in cursor.fetchall() + ] + finally: + cursor.close() + def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: """Get indexes from Oracle.""" cursor = conn.cursor() diff --git a/sqlit/domains/connections/providers/redshift/adapter.py b/sqlit/domains/connections/providers/redshift/adapter.py index 6e29ef71..713af060 100644 --- a/sqlit/domains/connections/providers/redshift/adapter.py +++ b/sqlit/domains/connections/providers/redshift/adapter.py @@ -7,6 +7,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, CursorBasedAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -205,6 +206,42 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: ) return [row[0] for row in cursor.fetchall()] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from Redshift.""" + cursor = conn.cursor() + cursor.execute( + "SELECT " + " tc.constraint_name, " + " tc.table_schema AS source_schema, " + " tc.table_name AS source_table, " + " kcu.column_name AS source_column, " + " ccu.table_schema AS target_schema, " + " ccu.table_name AS target_table, " + " ccu.column_name AS target_column " + "FROM information_schema.table_constraints tc " + "JOIN information_schema.key_column_usage kcu " + " ON tc.constraint_name = kcu.constraint_name " + " AND tc.table_schema = kcu.table_schema " + "JOIN information_schema.constraint_column_usage ccu " + " ON ccu.constraint_name = tc.constraint_name " + " AND ccu.table_schema = tc.table_schema " + "WHERE tc.constraint_type = 'FOREIGN KEY' " + "AND tc.table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_internal') " + "ORDER BY tc.table_name, tc.constraint_name" + ) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_schema=row[1], + source_table=row[2], + source_column=row[3], + target_schema=row[4], + target_table=row[5], + target_column=row[6], + ) + for row in cursor.fetchall() + ] + def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: """Redshift doesn't have traditional indexes.""" return [] diff --git a/sqlit/domains/connections/providers/snowflake/adapter.py b/sqlit/domains/connections/providers/snowflake/adapter.py index 13e9882d..95a7b84e 100644 --- a/sqlit/domains/connections/providers/snowflake/adapter.py +++ b/sqlit/domains/connections/providers/snowflake/adapter.py @@ -7,6 +7,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, CursorBasedAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -211,12 +212,45 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: # deduplicate return sorted(list({row[0] for row in cursor.fetchall()})) + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from Snowflake.""" + cursor = conn.cursor() + db_prefix = f"{self.quote_identifier(database)}." if database else "" + sql = ( + "SELECT tc.constraint_name, tc.table_schema, tc.table_name, " + "kcu.column_name, rc.unique_constraint_schema, " + "kcu2.table_name, kcu2.column_name " + f"FROM {db_prefix}information_schema.table_constraints tc " + f"JOIN {db_prefix}information_schema.key_column_usage kcu " + " ON tc.constraint_name = kcu.constraint_name " + " AND tc.table_schema = kcu.table_schema " + f"JOIN {db_prefix}information_schema.referential_constraints rc " + " ON tc.constraint_name = rc.constraint_name " + " AND tc.constraint_schema = rc.constraint_schema " + f"JOIN {db_prefix}information_schema.key_column_usage kcu2 " + " ON rc.unique_constraint_name = kcu2.constraint_name " + " AND rc.unique_constraint_schema = kcu2.constraint_schema " + " AND kcu.ordinal_position = kcu2.ordinal_position " + "WHERE tc.constraint_type = 'FOREIGN KEY' " + "AND tc.table_schema != 'INFORMATION_SCHEMA' " + "ORDER BY tc.table_name, tc.constraint_name" + ) + cursor.execute(sql) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_schema=row[1], + source_table=row[2], + source_column=row[3], + target_schema=row[4], + target_table=row[5], + target_column=row[6], + ) + for row in cursor.fetchall() + ] + def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: """Get indexes.""" - # Snowflake doesn't really have traditional indexes like Postgres/MySQL. - # It has clustering keys, search optimization service, etc. - # But 'SHOW PRIMARY KEYS' or similar might work. - # For now, return empty list as Snowflake is mostly auto-managed. return [] def get_triggers(self, conn: Any, database: str | None = None) -> list[TriggerInfo]: diff --git a/sqlit/domains/connections/providers/spanner/adapter.py b/sqlit/domains/connections/providers/spanner/adapter.py index f5ed784b..c3327d6a 100644 --- a/sqlit/domains/connections/providers/spanner/adapter.py +++ b/sqlit/domains/connections/providers/spanner/adapter.py @@ -24,6 +24,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, CursorBasedAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -261,6 +262,36 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: """Spanner doesn't support stored procedures.""" return [] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from Spanner INFORMATION_SCHEMA.""" + query = """ + SELECT tc.CONSTRAINT_NAME, tc.TABLE_NAME, + kcu.COLUMN_NAME, + ccu.TABLE_NAME, ccu.COLUMN_NAME + FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc + JOIN INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc + ON tc.CONSTRAINT_NAME = rc.CONSTRAINT_NAME + JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu + ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME + JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE ccu + ON rc.UNIQUE_CONSTRAINT_NAME = ccu.CONSTRAINT_NAME + AND kcu.ORDINAL_POSITION = ccu.ORDINAL_POSITION + WHERE tc.CONSTRAINT_TYPE = 'FOREIGN KEY' + AND tc.TABLE_SCHEMA = '' + ORDER BY tc.TABLE_NAME, tc.CONSTRAINT_NAME + """ + rows = self._execute_readonly(conn, query) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_table=row[1], + source_column=row[2], + target_table=row[3], + target_column=row[4], + ) + for row in rows + ] + def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: """Get list of indexes from INFORMATION_SCHEMA.""" query = """ diff --git a/sqlit/domains/connections/providers/teradata/adapter.py b/sqlit/domains/connections/providers/teradata/adapter.py index aa8e0201..c98d0642 100644 --- a/sqlit/domains/connections/providers/teradata/adapter.py +++ b/sqlit/domains/connections/providers/teradata/adapter.py @@ -8,6 +8,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, CursorBasedAdapter, + ForeignKeyInfo, IndexInfo, TableInfo, TriggerInfo, @@ -209,6 +210,40 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: ) return [row[0] for row in cursor.fetchall()] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from Teradata.""" + cursor = conn.cursor() + if database: + cursor.execute( + "lock row for access " + "SELECT IndexName, ChildTable, ChildKeyColumn, " + "ParentTable, ParentKeyColumn, ChildDB, ParentDB " + "FROM DBC.All_RI_ChildrenV " + "WHERE ChildDB = ? " + "ORDER BY ChildTable, IndexName", + (database,), + ) + else: + cursor.execute( + "lock row for access " + "SELECT IndexName, ChildTable, ChildKeyColumn, " + "ParentTable, ParentKeyColumn, ChildDB, ParentDB " + "FROM DBC.All_RI_ChildrenV " + "ORDER BY ChildTable, IndexName" + ) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_table=row[1], + source_column=row[2], + target_table=row[3], + target_column=row[4], + source_schema=row[5], + target_schema=row[6], + ) + for row in cursor.fetchall() + ] + def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: cursor = conn.cursor() if database: diff --git a/sqlit/domains/connections/providers/turso/adapter.py b/sqlit/domains/connections/providers/turso/adapter.py index 1e65d546..d626ce23 100644 --- a/sqlit/domains/connections/providers/turso/adapter.py +++ b/sqlit/domains/connections/providers/turso/adapter.py @@ -7,6 +7,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, DatabaseAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -132,6 +133,25 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: """Turso doesn't support stored procedures - return empty list.""" return [] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from Turso using PRAGMA foreign_key_list.""" + tables = self.get_tables(conn, database) + results: list[ForeignKeyInfo] = [] + for _schema, table_name in tables: + quoted = self.quote_identifier(table_name) + rows = conn.execute(f"PRAGMA foreign_key_list({quoted})").fetchall() + for row in rows: + results.append( + ForeignKeyInfo( + constraint_name=f"fk_{table_name}_{row[3]}", + source_table=table_name, + source_column=row[3], + target_table=row[2], + target_column=row[4], + ) + ) + return results + def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: """Get indexes from Turso (SQLite-compatible).""" rows = conn.execute( diff --git a/tests/fixtures/duckdb.py b/tests/fixtures/duckdb.py index 9a658d53..3327092b 100644 --- a/tests/fixtures/duckdb.py +++ b/tests/fixtures/duckdb.py @@ -48,6 +48,15 @@ def duckdb_db(duckdb_db_path: Path) -> Path: SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) + conn.execute(""" + CREATE TABLE test_orders ( + id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + amount DECIMAL(10,2) NOT NULL, + CONSTRAINT fk_orders_user_id FOREIGN KEY (user_id) REFERENCES test_users(id) + ) + """) + # Create test index for integration tests conn.execute("CREATE INDEX idx_test_users_email ON test_users(email)") diff --git a/tests/fixtures/firebird.py b/tests/fixtures/firebird.py index 18812af9..50676ae8 100644 --- a/tests/fixtures/firebird.py +++ b/tests/fixtures/firebird.py @@ -86,6 +86,15 @@ def firebird_db(firebird_server_ready: bool) -> str: SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) + cursor.execute(""" + RECREATE TABLE test_orders ( + id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + amount DECIMAL(10,2) NOT NULL, + CONSTRAINT fk_orders_user_id FOREIGN KEY (user_id) REFERENCES test_users(id) + ) + """) + cursor.execute("CREATE INDEX idx_test_users_email ON test_users(email)") cursor.execute(""" @@ -135,6 +144,7 @@ def firebird_db(firebird_server_ready: bool) -> str: try: for cleanup in [ "DROP VIEW test_user_emails", + "DROP TABLE test_orders", "DROP TABLE test_users", "DROP TABLE test_products", "DROP TRIGGER trg_test_users_audit", diff --git a/tests/fixtures/mariadb.py b/tests/fixtures/mariadb.py index 7046abbd..73419160 100644 --- a/tests/fixtures/mariadb.py +++ b/tests/fixtures/mariadb.py @@ -54,6 +54,7 @@ def mariadb_db(mariadb_server_ready: bool) -> str: ) cursor = conn.cursor() + cursor.execute("DROP TABLE IF EXISTS test_orders") cursor.execute("DROP TABLE IF EXISTS test_users") cursor.execute("DROP TABLE IF EXISTS test_products") cursor.execute("DROP VIEW IF EXISTS test_user_emails") @@ -81,6 +82,15 @@ def mariadb_db(mariadb_server_ready: bool) -> str: SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) + cursor.execute(""" + CREATE TABLE test_orders ( + id INT PRIMARY KEY, + user_id INT NOT NULL, + amount DECIMAL(10,2) NOT NULL, + CONSTRAINT fk_orders_user_id FOREIGN KEY (user_id) REFERENCES test_users(id) + ) + """) + # Create test index for integration tests cursor.execute("CREATE INDEX idx_test_users_email ON test_users(email)") @@ -130,6 +140,7 @@ def mariadb_db(mariadb_server_ready: bool) -> str: ) cursor = conn.cursor() cursor.execute("DROP TRIGGER IF EXISTS trg_test_users_audit") + cursor.execute("DROP TABLE IF EXISTS test_orders") cursor.execute("DROP TABLE IF EXISTS test_users") cursor.execute("DROP TABLE IF EXISTS test_products") cursor.execute("DROP VIEW IF EXISTS test_user_emails") diff --git a/tests/fixtures/mssql.py b/tests/fixtures/mssql.py index 09f8d6de..0b4395a1 100644 --- a/tests/fixtures/mssql.py +++ b/tests/fixtures/mssql.py @@ -104,6 +104,15 @@ def mssql_db(mssql_server_ready: bool) -> str: END """) + cursor.execute(""" + CREATE TABLE test_orders ( + id INT PRIMARY KEY, + user_id INT NOT NULL, + amount DECIMAL(10,2) NOT NULL, + CONSTRAINT fk_orders_user_id FOREIGN KEY (user_id) REFERENCES test_users(id) + ) + """) + # Create test index for integration tests cursor.execute("CREATE INDEX idx_test_users_email ON test_users(email)") diff --git a/tests/fixtures/mysql.py b/tests/fixtures/mysql.py index 8ee19774..93723d8a 100644 --- a/tests/fixtures/mysql.py +++ b/tests/fixtures/mysql.py @@ -54,6 +54,7 @@ def mysql_db(mysql_server_ready: bool) -> str: ) cursor = conn.cursor() + cursor.execute("DROP TABLE IF EXISTS test_orders") cursor.execute("DROP TABLE IF EXISTS test_users") cursor.execute("DROP TABLE IF EXISTS test_products") cursor.execute("DROP VIEW IF EXISTS test_user_emails") @@ -80,6 +81,15 @@ def mysql_db(mysql_server_ready: bool) -> str: SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) + cursor.execute(""" + CREATE TABLE test_orders ( + id INT PRIMARY KEY, + user_id INT NOT NULL, + amount DECIMAL(10,2) NOT NULL, + CONSTRAINT fk_orders_user_id FOREIGN KEY (user_id) REFERENCES test_users(id) + ) + """) + # Create test index for integration tests cursor.execute("CREATE INDEX idx_test_users_email ON test_users(email)") @@ -126,6 +136,7 @@ def mysql_db(mysql_server_ready: bool) -> str: ) cursor = conn.cursor() cursor.execute("DROP TRIGGER IF EXISTS trg_test_users_audit") + cursor.execute("DROP TABLE IF EXISTS test_orders") cursor.execute("DROP TABLE IF EXISTS test_users") cursor.execute("DROP TABLE IF EXISTS test_products") cursor.execute("DROP VIEW IF EXISTS test_user_emails") diff --git a/tests/fixtures/oracle.py b/tests/fixtures/oracle.py index 312e12de..4584ddea 100644 --- a/tests/fixtures/oracle.py +++ b/tests/fixtures/oracle.py @@ -86,6 +86,15 @@ def oracle_db(oracle_server_ready: bool) -> str: SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) + cursor.execute(""" + CREATE TABLE test_orders ( + id NUMBER PRIMARY KEY, + user_id NUMBER NOT NULL, + amount NUMBER(10,2) NOT NULL, + CONSTRAINT fk_orders_user_id FOREIGN KEY (user_id) REFERENCES test_users(id) + ) + """) + cursor.execute("CREATE INDEX idx_test_users_email ON test_users(email)") cursor.execute(""" @@ -141,6 +150,7 @@ def oracle_db(oracle_server_ready: bool) -> str: for stmt in [ "DROP VIEW test_user_emails", + "DROP TABLE test_orders CASCADE CONSTRAINTS", "DROP TABLE test_users CASCADE CONSTRAINTS", "DROP TABLE test_products CASCADE CONSTRAINTS", "DROP SEQUENCE test_sequence", diff --git a/tests/fixtures/postgres.py b/tests/fixtures/postgres.py index 0fa9404a..f2f09d3e 100644 --- a/tests/fixtures/postgres.py +++ b/tests/fixtures/postgres.py @@ -54,6 +54,7 @@ def postgres_db(postgres_server_ready: bool) -> str: conn.autocommit = True cursor = conn.cursor() + cursor.execute("DROP TABLE IF EXISTS test_orders CASCADE") cursor.execute("DROP TABLE IF EXISTS test_users CASCADE") cursor.execute("DROP TABLE IF EXISTS test_products CASCADE") cursor.execute("DROP VIEW IF EXISTS test_user_emails") @@ -80,6 +81,15 @@ def postgres_db(postgres_server_ready: bool) -> str: SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) + cursor.execute(""" + CREATE TABLE test_orders ( + id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + amount DECIMAL(10,2) NOT NULL, + CONSTRAINT fk_orders_user_id FOREIGN KEY (user_id) REFERENCES test_users(id) + ) + """) + # Create test index for integration tests cursor.execute("CREATE INDEX idx_test_users_email ON test_users(email)") @@ -132,6 +142,7 @@ def postgres_db(postgres_server_ready: bool) -> str: ) conn.autocommit = True cursor = conn.cursor() + cursor.execute("DROP TABLE IF EXISTS test_orders CASCADE") cursor.execute("DROP TABLE IF EXISTS test_users CASCADE") cursor.execute("DROP TABLE IF EXISTS test_products CASCADE") cursor.execute("DROP VIEW IF EXISTS test_user_emails") diff --git a/tests/fixtures/sqlite.py b/tests/fixtures/sqlite.py index ac8f6a23..6b833e67 100644 --- a/tests/fixtures/sqlite.py +++ b/tests/fixtures/sqlite.py @@ -45,6 +45,15 @@ def sqlite_db(sqlite_db_path: Path) -> Path: SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) + cursor.execute(""" + CREATE TABLE test_orders ( + id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + amount REAL NOT NULL, + CONSTRAINT fk_orders_user_id FOREIGN KEY (user_id) REFERENCES test_users(id) + ) + """) + # Create test index for integration tests cursor.execute("CREATE INDEX idx_test_users_email ON test_users(email)") diff --git a/tests/fixtures/turso.py b/tests/fixtures/turso.py index be369088..47bb5428 100644 --- a/tests/fixtures/turso.py +++ b/tests/fixtures/turso.py @@ -86,6 +86,7 @@ def _setup_turso_test_tables(client) -> None: client.execute("DROP TRIGGER IF EXISTS trg_test_users_audit") client.execute("DROP INDEX IF EXISTS idx_test_users_email") client.execute("DROP VIEW IF EXISTS test_user_emails") + client.execute("DROP TABLE IF EXISTS test_orders") client.execute("DROP TABLE IF EXISTS test_users") client.execute("DROP TABLE IF EXISTS test_products") @@ -107,6 +108,15 @@ def _setup_turso_test_tables(client) -> None: ) """) + client.execute(""" + CREATE TABLE test_orders ( + id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + amount REAL NOT NULL, + CONSTRAINT fk_orders_user_id FOREIGN KEY (user_id) REFERENCES test_users(id) + ) + """) + # Create view client.execute(""" CREATE VIEW test_user_emails AS @@ -149,6 +159,7 @@ def _cleanup_turso_test_tables(client) -> None: client.execute("DROP TRIGGER IF EXISTS trg_test_users_audit") client.execute("DROP INDEX IF EXISTS idx_test_users_email") client.execute("DROP VIEW IF EXISTS test_user_emails") + client.execute("DROP TABLE IF EXISTS test_orders") client.execute("DROP TABLE IF EXISTS test_users") client.execute("DROP TABLE IF EXISTS test_products") client.commit() diff --git a/tests/test_database_base.py b/tests/test_database_base.py index 9030386d..40bd66c7 100644 --- a/tests/test_database_base.py +++ b/tests/test_database_base.py @@ -440,6 +440,48 @@ def test_get_indexes(self, request): f"Index should be on test_users table, got {test_index.table_name}" ) + def test_get_foreign_keys(self, request): + """Test that adapter correctly retrieves foreign keys. + + This tests that get_foreign_keys returns ForeignKeyInfo objects for + foreign keys created on the test tables. The test fixture should create + a test_orders table with a FK from user_id to test_users(id). + """ + from sqlit.domains.connections.app.session import ConnectionSession + from sqlit.domains.connections.providers.adapters.base import ForeignKeyInfo + from sqlit.domains.connections.providers.registry import get_adapter + from sqlit.domains.connections.store.connections import load_connections + + connection_name = request.getfixturevalue(self.config.connection_fixture) + connections = load_connections() + config = next((c for c in connections if c.name == connection_name), None) + assert config is not None, f"Connection {connection_name} not found" + + with ConnectionSession.create(config, get_adapter) as session: + foreign_keys = session.adapter.get_foreign_keys( + session.connection, + database=config.database if session.adapter.supports_multiple_databases else None, + ) + + assert isinstance(foreign_keys, list), "get_foreign_keys should return a list" + for fk in foreign_keys: + assert isinstance(fk, ForeignKeyInfo), f"Expected ForeignKeyInfo, got {type(fk)}" + + test_fk = next( + (fk for fk in foreign_keys if "user_id" in fk.source_column.lower()), + None, + ) + assert test_fk is not None, ( + f"FK on user_id not found. " + f"Found FKs: {[(fk.source_table, fk.source_column) for fk in foreign_keys]}" + ) + assert "test_orders" in test_fk.source_table.lower(), ( + f"FK source should be test_orders, got {test_fk.source_table}" + ) + assert "test_users" in test_fk.target_table.lower(), ( + f"FK target should be test_users, got {test_fk.target_table}" + ) + def test_get_triggers(self, request): """Test that adapter correctly retrieves triggers. diff --git a/tests/unit/test_foreign_keys.py b/tests/unit/test_foreign_keys.py new file mode 100644 index 00000000..9eb89ef1 --- /dev/null +++ b/tests/unit/test_foreign_keys.py @@ -0,0 +1,607 @@ +"""Unit tests for get_foreign_keys() across all database adapters.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +from sqlit.domains.connections.providers.adapters.base import ForeignKeyInfo + + +def _assert_fk_list(result: list, expected_count: int = 1) -> None: + assert isinstance(result, list) + assert len(result) == expected_count + for fk in result: + assert isinstance(fk, ForeignKeyInfo) + assert fk.constraint_name + assert fk.source_table + assert fk.source_column + assert fk.target_table + assert fk.target_column + + +# -- PostgreSQL ---------------------------------------------------------------- + +class TestPostgreSQLForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.postgresql.adapter import PostgreSQLAdapter + return PostgreSQLAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("fk_orders_user_id", "public", "orders", "user_id", "public", "users", "id"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_table == "orders" + assert result[0].target_table == "users" + assert result[0].source_schema == "public" + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + def test_sql_contains_foreign_key_filter(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn) + sql = cursor.execute.call_args[0][0] + assert "FOREIGN KEY" in sql + assert "table_constraints" in sql + + +# -- SQLite -------------------------------------------------------------------- + +class TestSQLiteForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.sqlite.adapter import SQLiteAdapter + return SQLiteAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + # get_tables calls sqlite_master -> returns raw name rows + cursor.fetchall.side_effect = [ + [("orders",)], # get_tables -> sqlite_master + [(0, 0, "users", "user_id", "id", "", "", "")], # PRAGMA foreign_key_list + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_table == "orders" + assert result[0].source_column == "user_id" + assert result[0].target_table == "users" + assert result[0].target_column == "id" + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.side_effect = [ + [("orders",)], # get_tables + [], # PRAGMA foreign_key_list returns nothing + ] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + +# -- DuckDB -------------------------------------------------------------------- + +class TestDuckDBForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.duckdb.adapter import DuckDBAdapter + return DuckDBAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + result_mock = MagicMock() + mock_conn.execute.return_value = result_mock + result_mock.fetchall.return_value = [ + ("fk_orders_user_id", "main", "orders", "user_id", "main", "users", "id"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_table == "orders" + assert result[0].source_schema == "main" + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + result_mock = MagicMock() + mock_conn.execute.return_value = result_mock + result_mock.fetchall.return_value = [] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + def test_sql_contains_foreign_key_filter(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + result_mock = MagicMock() + mock_conn.execute.return_value = result_mock + result_mock.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn) + sql = mock_conn.execute.call_args[0][0] + assert "FOREIGN KEY" in sql + assert "referential_constraints" in sql + + +# -- MySQL --------------------------------------------------------------------- + +class TestMySQLForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.mysql.adapter import MySQLAdapter + return MySQLAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("fk_orders_user_id", "orders", "user_id", "users", "id"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_table == "orders" + assert result[0].target_table == "users" + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + def test_with_database(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn, database="mydb") + sql = cursor.execute.call_args[0][0] + assert "table_schema = %s" in sql + + def test_without_database(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn) + sql = cursor.execute.call_args[0][0] + assert "DATABASE()" in sql + + +# -- MariaDB ------------------------------------------------------------------- + +class TestMariaDBForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.mariadb.adapter import MariaDBAdapter + return MariaDBAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("fk_orders_user_id", "orders", "user_id", "users", "id"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + + def test_with_database_uses_question_mark(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn, database="mydb") + sql = cursor.execute.call_args[0][0] + assert "table_schema = ?" in sql + assert "%s" not in sql + + +# -- SQL Server ---------------------------------------------------------------- + +class TestMSSQLForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.mssql.adapter import SQLServerAdapter + return SQLServerAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("FK_orders_users", "dbo", "orders", "user_id", "dbo", "users", "id"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_schema == "dbo" + assert result[0].target_schema == "dbo" + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + def test_sql_uses_sys_foreign_keys(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn) + sql = cursor.execute.call_args[0][0] + assert "sys.foreign_keys" in sql + + +# -- Oracle -------------------------------------------------------------------- + +class TestOracleForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.oracle.adapter import OracleAdapter + return OracleAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("FK_ORDERS_USERS", "ORDERS", "USER_ID", "USERS", "ID"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_table == "ORDERS" + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + def test_cursor_is_closed(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn) + cursor.close.assert_called_once() + + +# -- Firebird ------------------------------------------------------------------ + +class TestFirebirdForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.firebird.adapter import FirebirdAdapter + return FirebirdAdapter() + + def test_returns_fk_info_with_stripped_names(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("FK_ORDERS ", "ORDERS ", "USER_ID ", "USERS ", "ID "), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].constraint_name == "FK_ORDERS" + assert result[0].source_table == "ORDERS" + assert result[0].source_column == "USER_ID" + assert result[0].target_table == "USERS" + assert result[0].target_column == "ID" + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + +# -- Turso --------------------------------------------------------------------- + +class TestTursoForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.turso.adapter import TursoAdapter + return TursoAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + # get_tables uses conn.execute().fetchall() + tables_result = MagicMock() + fk_result = MagicMock() + mock_conn.execute.side_effect = [tables_result, fk_result] + tables_result.fetchall.return_value = [("orders",)] + fk_result.fetchall.return_value = [ + (0, 0, "users", "user_id", "id", "", "", ""), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_table == "orders" + assert result[0].target_table == "users" + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + tables_result = MagicMock() + fk_result = MagicMock() + mock_conn.execute.side_effect = [tables_result, fk_result] + tables_result.fetchall.return_value = [("orders",)] + fk_result.fetchall.return_value = [] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + +# -- Snowflake ----------------------------------------------------------------- + +class TestSnowflakeForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.snowflake.adapter import SnowflakeAdapter + return SnowflakeAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("FK_ORDERS_USERS", "PUBLIC", "ORDERS", "USER_ID", "PUBLIC", "USERS", "ID"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_schema == "PUBLIC" + + def test_with_database_uses_prefix(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn, database="MYDB") + sql = cursor.execute.call_args[0][0] + assert '"MYDB".' in sql + + +# -- Spanner ------------------------------------------------------------------- + +class TestSpannerForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.spanner.adapter import SpannerAdapter + return SpannerAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + mock_conn.autocommit = False + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("FK_Orders_Users", "Orders", "UserId", "Users", "Id"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + mock_conn.autocommit = False + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + +# -- HANA ---------------------------------------------------------------------- + +class TestHanaForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.hana.adapter import HanaAdapter + return HanaAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("FK_ORDERS_USERS", "MYSCHEMA", "ORDERS", "USER_ID", "MYSCHEMA", "USERS", "ID"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_schema == "MYSCHEMA" + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + +# -- DB2 ----------------------------------------------------------------------- + +class TestDb2ForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.db2.adapter import Db2Adapter + return Db2Adapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("FK_ORDERS", "MYSCHEMA", "ORDERS", "USER_ID", "MYSCHEMA", "USERS", "ID"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_schema == "MYSCHEMA" + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + def test_sql_uses_syscat(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn) + sql = cursor.execute.call_args[0][0] + assert "syscat.references" in sql + + +# -- Teradata ------------------------------------------------------------------ + +class TestTeradataForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.teradata.adapter import TeradataAdapter + return TeradataAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("FK_ORDERS", "ORDERS", "USER_ID", "USERS", "ID", "MYDB", "MYDB"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_schema == "MYDB" + + def test_with_database(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn, database="MYDB") + sql = cursor.execute.call_args[0][0] + assert "ChildDB = ?" in sql + + def test_sql_uses_lock_row(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn) + sql = cursor.execute.call_args[0][0] + assert "lock row for access" in sql + + +# -- Redshift ------------------------------------------------------------------ + +class TestRedshiftForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.redshift.adapter import RedshiftAdapter + return RedshiftAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("fk_orders_user_id", "public", "orders", "user_id", "public", "users", "id"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_schema == "public" + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + def test_sql_excludes_system_schemas(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn) + sql = cursor.execute.call_args[0][0] + assert "pg_internal" in sql + + +# -- MotherDuck ---------------------------------------------------------------- + +class TestMotherDuckForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.motherduck.adapter import MotherDuckAdapter + return MotherDuckAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + mock_conn.execute.return_value.fetchall.return_value = [ + ("fk_orders_user_id", "main", "orders", "user_id", "main", "users", "id"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_schema == "main" + + def test_with_database_filters_by_catalog(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + mock_conn.execute.return_value.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn, database="my_db") + sql = mock_conn.execute.call_args[0][0] + assert "table_catalog" in sql + assert mock_conn.execute.call_args[0][1] == ("my_db",) + + def test_without_database_no_catalog_filter(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + mock_conn.execute.return_value.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn) + sql = mock_conn.execute.call_args[0][0] + assert "table_catalog" not in sql