diff --git a/sqlit/core/keymap.py b/sqlit/core/keymap.py index bcd4aed3..c7d7bcef 100644 --- a/sqlit/core/keymap.py +++ b/sqlit/core/keymap.py @@ -196,6 +196,7 @@ def _build_leader_commands(self) -> list[LeaderCommandDef]: # Actions LeaderCommandDef("z", "cancel_operation", "Cancel", "Actions", guard="query_executing"), LeaderCommandDef("t", "change_theme", "Change Theme", "Actions"), + LeaderCommandDef("d", "show_diagram_picker", "ER Diagram", "Actions", guard="has_connection"), LeaderCommandDef("h", "show_help", "Help", "Actions"), LeaderCommandDef("space", "telescope", "Telescope", "Actions"), LeaderCommandDef("slash", "telescope_filter", "Telescope Search", "Actions"), @@ -309,6 +310,7 @@ def _build_action_keys(self) -> list[ActionKeyDef]: ActionKeyDef("v", "exit_tree_visual_mode", "tree_visual", primary=False), ActionKeyDef("escape", "clear_connection_selection", "tree"), ActionKeyDef("s", "select_table", "tree"), + ActionKeyDef("S", "show_diagram", "tree"), ActionKeyDef("f", "refresh_tree", "tree"), ActionKeyDef("R", "refresh_tree", "tree", primary=False), ActionKeyDef("e", "edit_connection", "tree"), diff --git a/sqlit/domains/connections/providers/adapters/base.py b/sqlit/domains/connections/providers/adapters/base.py index 109ea1da..ff1842f5 100644 --- a/sqlit/domains/connections/providers/adapters/base.py +++ b/sqlit/domains/connections/providers/adapters/base.py @@ -69,6 +69,19 @@ class SequenceInfo: name: str +@dataclass +class ForeignKeyInfo: + """Information about a foreign key relationship.""" + + constraint_name: str + source_table: str + source_column: str + target_table: str + target_column: str + source_schema: str = "" + target_schema: str = "" + + # Type alias for table/view info: (schema, name) TableInfo = tuple[str, str] @@ -343,6 +356,14 @@ def get_sequences(self, conn: Any, database: str | None = None) -> list[Sequence """ pass + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get list of foreign keys in the database. + + Returns: + List of ForeignKeyInfo objects describing FK relationships. + """ + return [] + def get_index_definition( self, conn: Any, index_name: str, table_name: str, database: str | None = None ) -> dict[str, Any]: @@ -518,6 +539,7 @@ def execute_non_query(self, conn: Any, query: str) -> int: __all__ = [ "ColumnInfo", "DatabaseAdapter", + "ForeignKeyInfo", "IndexInfo", "SequenceInfo", "TableInfo", diff --git a/sqlit/domains/connections/providers/db2/adapter.py b/sqlit/domains/connections/providers/db2/adapter.py index 7aafbd82..c0555583 100644 --- a/sqlit/domains/connections/providers/db2/adapter.py +++ b/sqlit/domains/connections/providers/db2/adapter.py @@ -7,6 +7,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, CursorBasedAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -151,6 +152,34 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: ) return [row[0] for row in cursor.fetchall()] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from DB2.""" + cursor = conn.cursor() + cursor.execute( + "SELECT r.constname, r.tabschema, r.tabname, fk.colname, " + "r.reftabschema, r.reftabname, pk.colname " + "FROM syscat.references r " + "JOIN syscat.keycoluse fk ON r.constname = fk.constname " + "AND r.tabschema = fk.tabschema AND r.tabname = fk.tabname " + "JOIN syscat.keycoluse pk ON r.refkeyname = pk.constname " + "AND r.reftabschema = pk.tabschema AND r.reftabname = pk.tabname " + "AND fk.colseq = pk.colseq " + "WHERE r.tabschema NOT LIKE 'SYS%' " + "ORDER BY r.tabname, r.constname" + ) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_schema=row[1], + source_table=row[2], + source_column=row[3], + target_schema=row[4], + target_table=row[5], + target_column=row[6], + ) + for row in cursor.fetchall() + ] + def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: cursor = conn.cursor() cursor.execute( diff --git a/sqlit/domains/connections/providers/duckdb/adapter.py b/sqlit/domains/connections/providers/duckdb/adapter.py index e70fbaa7..021620a5 100644 --- a/sqlit/domains/connections/providers/duckdb/adapter.py +++ b/sqlit/domains/connections/providers/duckdb/adapter.py @@ -7,6 +7,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, DatabaseAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -155,6 +156,45 @@ def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo] for row in result.fetchall() ] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from DuckDB.""" + result = conn.execute( + "SELECT " + " tc.constraint_name, " + " tc.table_schema AS source_schema, " + " tc.table_name AS source_table, " + " kcu.column_name AS source_column, " + " kcu2.table_schema AS target_schema, " + " kcu2.table_name AS target_table, " + " kcu2.column_name AS target_column " + "FROM information_schema.table_constraints tc " + "JOIN information_schema.key_column_usage kcu " + " ON tc.constraint_name = kcu.constraint_name " + " AND tc.table_schema = kcu.table_schema " + "JOIN information_schema.referential_constraints rc " + " ON tc.constraint_name = rc.constraint_name " + " AND tc.constraint_schema = rc.constraint_schema " + "JOIN information_schema.key_column_usage kcu2 " + " ON rc.unique_constraint_name = kcu2.constraint_name " + " AND rc.unique_constraint_schema = kcu2.constraint_schema " + " AND kcu.ordinal_position = kcu2.ordinal_position " + "WHERE tc.constraint_type = 'FOREIGN KEY' " + "AND tc.table_schema NOT IN ('pg_catalog', 'information_schema') " + "ORDER BY tc.table_name, tc.constraint_name" + ) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_schema=row[1], + source_table=row[2], + source_column=row[3], + target_schema=row[4], + target_table=row[5], + target_column=row[6], + ) + for row in result.fetchall() + ] + def get_triggers(self, conn: Any, database: str | None = None) -> list[TriggerInfo]: """DuckDB doesn't support triggers - return empty list.""" return [] diff --git a/sqlit/domains/connections/providers/firebird/adapter.py b/sqlit/domains/connections/providers/firebird/adapter.py index f424cb0b..8541b254 100644 --- a/sqlit/domains/connections/providers/firebird/adapter.py +++ b/sqlit/domains/connections/providers/firebird/adapter.py @@ -5,6 +5,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, CursorBasedAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -166,6 +167,31 @@ def get_index_definition( "definition": " ".join(definition_parts), } + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from Firebird.""" + cursor = conn.cursor() + cursor.execute( + "SELECT rc.rdb$constraint_name, rc.rdb$relation_name, sg.rdb$field_name, " + "rfc.rdb$relation_name, sg2.rdb$field_name " + "FROM rdb$relation_constraints rc " + "JOIN rdb$ref_constraints ref ON rc.rdb$constraint_name = ref.rdb$constraint_name " + "JOIN rdb$relation_constraints rfc ON ref.rdb$const_name_uq = rfc.rdb$constraint_name " + "JOIN rdb$index_segments sg ON rc.rdb$index_name = sg.rdb$index_name " + "JOIN rdb$index_segments sg2 ON rfc.rdb$index_name = sg2.rdb$index_name " + "AND sg.rdb$field_position = sg2.rdb$field_position " + "WHERE rc.rdb$constraint_type = 'FOREIGN KEY'" + ) + return [ + ForeignKeyInfo( + constraint_name=row[0].rstrip(), + source_table=row[1].rstrip(), + source_column=row[2].rstrip(), + target_table=row[3].rstrip(), + target_column=row[4].rstrip(), + ) + for row in cursor.fetchall() + ] + def get_sequences(self, conn: Any, database: str | None = None) -> list[SequenceInfo]: cursor = conn.cursor() cursor.execute("SELECT rdb$generator_name FROM rdb$generators WHERE rdb$system_flag = 0") diff --git a/sqlit/domains/connections/providers/hana/adapter.py b/sqlit/domains/connections/providers/hana/adapter.py index 12afe5d3..d55801e6 100644 --- a/sqlit/domains/connections/providers/hana/adapter.py +++ b/sqlit/domains/connections/providers/hana/adapter.py @@ -7,6 +7,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, CursorBasedAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -143,6 +144,30 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: ) return [row[0] for row in cursor.fetchall()] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from SAP HANA.""" + cursor = conn.cursor() + cursor.execute( + "SELECT constraint_name, schema_name, table_name, column_name, " + "referenced_schema_name, referenced_table_name, referenced_column_name " + "FROM sys.referential_constraints " + "WHERE schema_name NOT LIKE '_SYS%' " + "AND schema_name NOT IN ('SYS', 'SYSTEM') " + "ORDER BY table_name, constraint_name" + ) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_schema=row[1], + source_table=row[2], + source_column=row[3], + target_schema=row[4], + target_table=row[5], + target_column=row[6], + ) + for row in cursor.fetchall() + ] + def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: cursor = conn.cursor() cursor.execute( diff --git a/sqlit/domains/connections/providers/mariadb/adapter.py b/sqlit/domains/connections/providers/mariadb/adapter.py index ba9459d2..f073c5b2 100644 --- a/sqlit/domains/connections/providers/mariadb/adapter.py +++ b/sqlit/domains/connections/providers/mariadb/adapter.py @@ -7,6 +7,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -218,6 +219,37 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: ) return [row[0] for row in cursor.fetchall()] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from MariaDB (uses ? placeholders).""" + cursor = conn.cursor() + if database: + cursor.execute( + "SELECT constraint_name, table_name, column_name, " + "referenced_table_name, referenced_column_name " + "FROM information_schema.key_column_usage " + "WHERE table_schema = ? AND referenced_table_name IS NOT NULL " + "ORDER BY table_name, constraint_name", + (database,), + ) + else: + cursor.execute( + "SELECT constraint_name, table_name, column_name, " + "referenced_table_name, referenced_column_name " + "FROM information_schema.key_column_usage " + "WHERE table_schema = DATABASE() AND referenced_table_name IS NOT NULL " + "ORDER BY table_name, constraint_name" + ) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_table=row[1], + source_column=row[2], + target_table=row[3], + target_column=row[4], + ) + for row in cursor.fetchall() + ] + def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: """Get indexes from MariaDB (uses ? placeholders).""" cursor = conn.cursor() diff --git a/sqlit/domains/connections/providers/model.py b/sqlit/domains/connections/providers/model.py index b104db3c..50ffe85d 100644 --- a/sqlit/domains/connections/providers/model.py +++ b/sqlit/domains/connections/providers/model.py @@ -116,6 +116,10 @@ def get_sequence_definition(self, conn: Any, sequence_name: str, database: str | @runtime_checkable +class ForeignKeyInspector(Protocol): + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[Any]: ... + + @runtime_checkable class ProcedureInspector(Protocol): def get_procedures(self, conn: Any, database: str | None = None) -> list[Any]: ... diff --git a/sqlit/domains/connections/providers/motherduck/adapter.py b/sqlit/domains/connections/providers/motherduck/adapter.py index 72eec7dc..c1888aa8 100644 --- a/sqlit/domains/connections/providers/motherduck/adapter.py +++ b/sqlit/domains/connections/providers/motherduck/adapter.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any -from sqlit.domains.connections.providers.adapters.base import TableInfo +from sqlit.domains.connections.providers.adapters.base import ForeignKeyInfo, TableInfo from sqlit.domains.connections.providers.duckdb.adapter import DuckDBAdapter if TYPE_CHECKING: @@ -115,6 +115,51 @@ def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]: ) return [(row[0], row[1]) for row in result.fetchall()] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from a specific MotherDuck database.""" + query = ( + "SELECT " + " tc.constraint_name, " + " tc.table_schema AS source_schema, " + " tc.table_name AS source_table, " + " kcu.column_name AS source_column, " + " kcu2.table_schema AS target_schema, " + " kcu2.table_name AS target_table, " + " kcu2.column_name AS target_column " + "FROM information_schema.table_constraints tc " + "JOIN information_schema.key_column_usage kcu " + " ON tc.constraint_name = kcu.constraint_name " + " AND tc.table_schema = kcu.table_schema " + "JOIN information_schema.referential_constraints rc " + " ON tc.constraint_name = rc.constraint_name " + " AND tc.constraint_schema = rc.constraint_schema " + "JOIN information_schema.key_column_usage kcu2 " + " ON rc.unique_constraint_name = kcu2.constraint_name " + " AND rc.unique_constraint_schema = kcu2.constraint_schema " + " AND kcu.ordinal_position = kcu2.ordinal_position " + "WHERE tc.constraint_type = 'FOREIGN KEY' " + "AND tc.table_schema NOT IN ('pg_catalog', 'information_schema') " + ) + if database: + query += "AND tc.table_catalog = ? " + query += "ORDER BY tc.table_name, tc.constraint_name" + result = conn.execute(query, (database,)) + else: + query += "ORDER BY tc.table_name, tc.constraint_name" + result = conn.execute(query) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_schema=row[1], + source_table=row[2], + source_column=row[3], + target_schema=row[4], + target_table=row[5], + target_column=row[6], + ) + for row in result.fetchall() + ] + def build_select_query( self, table: str, limit: int, database: str | None = None, schema: str | None = None ) -> str: diff --git a/sqlit/domains/connections/providers/mssql/adapter.py b/sqlit/domains/connections/providers/mssql/adapter.py index 2b734707..60190d1c 100644 --- a/sqlit/domains/connections/providers/mssql/adapter.py +++ b/sqlit/domains/connections/providers/mssql/adapter.py @@ -7,6 +7,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, DatabaseAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -256,6 +257,41 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: ) return [row[0] for row in cursor.fetchall()] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from SQL Server.""" + cursor = self._get_cursor_for_database(conn, database) + cursor.execute( + "SELECT " + " fk.name, " + " sch1.name, " + " tp.name, " + " cp.name, " + " sch2.name, " + " tr.name, " + " cr.name " + "FROM sys.foreign_keys fk " + "JOIN sys.foreign_key_columns fkc ON fk.object_id = fkc.constraint_object_id " + "JOIN sys.tables tp ON fkc.parent_object_id = tp.object_id " + "JOIN sys.columns cp ON fkc.parent_object_id = cp.object_id AND fkc.parent_column_id = cp.column_id " + "JOIN sys.tables tr ON fkc.referenced_object_id = tr.object_id " + "JOIN sys.columns cr ON fkc.referenced_object_id = cr.object_id AND fkc.referenced_column_id = cr.column_id " + "JOIN sys.schemas sch1 ON tp.schema_id = sch1.schema_id " + "JOIN sys.schemas sch2 ON tr.schema_id = sch2.schema_id " + "ORDER BY tp.name, fk.name" + ) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_schema=row[1], + source_table=row[2], + source_column=row[3], + target_schema=row[4], + target_table=row[5], + target_column=row[6], + ) + for row in cursor.fetchall() + ] + def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: """Get indexes from SQL Server.""" cursor = self._get_cursor_for_database(conn, database) diff --git a/sqlit/domains/connections/providers/mysql/base.py b/sqlit/domains/connections/providers/mysql/base.py index d65fcb90..ebd0670b 100644 --- a/sqlit/domains/connections/providers/mysql/base.py +++ b/sqlit/domains/connections/providers/mysql/base.py @@ -10,6 +10,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, CursorBasedAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -148,6 +149,37 @@ def build_select_query(self, table: str, limit: int, database: str | None = None return f"SELECT * FROM `{database}`.`{table}` LIMIT {limit}" return f"SELECT * FROM `{table}` LIMIT {limit}" + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from MySQL.""" + cursor = conn.cursor() + if database: + cursor.execute( + "SELECT constraint_name, table_name, column_name, " + "referenced_table_name, referenced_column_name " + "FROM information_schema.key_column_usage " + "WHERE table_schema = %s AND referenced_table_name IS NOT NULL " + "ORDER BY table_name, constraint_name", + (database,), + ) + else: + cursor.execute( + "SELECT constraint_name, table_name, column_name, " + "referenced_table_name, referenced_column_name " + "FROM information_schema.key_column_usage " + "WHERE table_schema = DATABASE() AND referenced_table_name IS NOT NULL " + "ORDER BY table_name, constraint_name" + ) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_table=row[1], + source_column=row[2], + target_table=row[3], + target_column=row[4], + ) + for row in cursor.fetchall() + ] + def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: """Get indexes from MySQL/MariaDB.""" cursor = conn.cursor() diff --git a/sqlit/domains/connections/providers/oracle/adapter.py b/sqlit/domains/connections/providers/oracle/adapter.py index 76edc895..08a916b8 100644 --- a/sqlit/domains/connections/providers/oracle/adapter.py +++ b/sqlit/domains/connections/providers/oracle/adapter.py @@ -7,6 +7,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, DatabaseAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -167,6 +168,34 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: finally: cursor.close() + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from Oracle.""" + cursor = conn.cursor() + try: + cursor.execute( + "SELECT c.constraint_name, c.table_name, cc.column_name, " + "r.table_name, rc.column_name " + "FROM user_constraints c " + "JOIN user_cons_columns cc ON c.constraint_name = cc.constraint_name " + "JOIN user_constraints r ON c.r_constraint_name = r.constraint_name " + "JOIN user_cons_columns rc ON r.constraint_name = rc.constraint_name " + "AND cc.position = rc.position " + "WHERE c.constraint_type = 'R' " + "ORDER BY c.table_name, c.constraint_name" + ) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_table=row[1], + source_column=row[2], + target_table=row[3], + target_column=row[4], + ) + for row in cursor.fetchall() + ] + finally: + cursor.close() + def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: """Get indexes from Oracle.""" cursor = conn.cursor() diff --git a/sqlit/domains/connections/providers/postgresql/base.py b/sqlit/domains/connections/providers/postgresql/base.py index c5140eca..034a7fe2 100644 --- a/sqlit/domains/connections/providers/postgresql/base.py +++ b/sqlit/domains/connections/providers/postgresql/base.py @@ -8,6 +8,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, CursorBasedAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -174,6 +175,42 @@ def get_sequences(self, conn: Any, database: str | None = None) -> list[Sequence ) return [SequenceInfo(name=row[0]) for row in cursor.fetchall()] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from PostgreSQL.""" + cursor = conn.cursor() + cursor.execute( + "SELECT " + " tc.constraint_name, " + " tc.table_schema AS source_schema, " + " tc.table_name AS source_table, " + " kcu.column_name AS source_column, " + " ccu.table_schema AS target_schema, " + " ccu.table_name AS target_table, " + " ccu.column_name AS target_column " + "FROM information_schema.table_constraints tc " + "JOIN information_schema.key_column_usage kcu " + " ON tc.constraint_name = kcu.constraint_name " + " AND tc.table_schema = kcu.table_schema " + "JOIN information_schema.constraint_column_usage ccu " + " ON ccu.constraint_name = tc.constraint_name " + " AND ccu.table_schema = tc.table_schema " + "WHERE tc.constraint_type = 'FOREIGN KEY' " + "AND tc.table_schema NOT IN ('pg_catalog', 'information_schema') " + "ORDER BY tc.table_name, tc.constraint_name" + ) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_schema=row[1], + source_table=row[2], + source_column=row[3], + target_schema=row[4], + target_table=row[5], + target_column=row[6], + ) + for row in cursor.fetchall() + ] + def get_index_definition( self, conn: Any, index_name: str, table_name: str, database: str | None = None ) -> dict[str, Any]: diff --git a/sqlit/domains/connections/providers/redshift/adapter.py b/sqlit/domains/connections/providers/redshift/adapter.py index 6e29ef71..713af060 100644 --- a/sqlit/domains/connections/providers/redshift/adapter.py +++ b/sqlit/domains/connections/providers/redshift/adapter.py @@ -7,6 +7,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, CursorBasedAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -205,6 +206,42 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: ) return [row[0] for row in cursor.fetchall()] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from Redshift.""" + cursor = conn.cursor() + cursor.execute( + "SELECT " + " tc.constraint_name, " + " tc.table_schema AS source_schema, " + " tc.table_name AS source_table, " + " kcu.column_name AS source_column, " + " ccu.table_schema AS target_schema, " + " ccu.table_name AS target_table, " + " ccu.column_name AS target_column " + "FROM information_schema.table_constraints tc " + "JOIN information_schema.key_column_usage kcu " + " ON tc.constraint_name = kcu.constraint_name " + " AND tc.table_schema = kcu.table_schema " + "JOIN information_schema.constraint_column_usage ccu " + " ON ccu.constraint_name = tc.constraint_name " + " AND ccu.table_schema = tc.table_schema " + "WHERE tc.constraint_type = 'FOREIGN KEY' " + "AND tc.table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_internal') " + "ORDER BY tc.table_name, tc.constraint_name" + ) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_schema=row[1], + source_table=row[2], + source_column=row[3], + target_schema=row[4], + target_table=row[5], + target_column=row[6], + ) + for row in cursor.fetchall() + ] + def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: """Redshift doesn't have traditional indexes.""" return [] diff --git a/sqlit/domains/connections/providers/snowflake/adapter.py b/sqlit/domains/connections/providers/snowflake/adapter.py index 13e9882d..95a7b84e 100644 --- a/sqlit/domains/connections/providers/snowflake/adapter.py +++ b/sqlit/domains/connections/providers/snowflake/adapter.py @@ -7,6 +7,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, CursorBasedAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -211,12 +212,45 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: # deduplicate return sorted(list({row[0] for row in cursor.fetchall()})) + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from Snowflake.""" + cursor = conn.cursor() + db_prefix = f"{self.quote_identifier(database)}." if database else "" + sql = ( + "SELECT tc.constraint_name, tc.table_schema, tc.table_name, " + "kcu.column_name, rc.unique_constraint_schema, " + "kcu2.table_name, kcu2.column_name " + f"FROM {db_prefix}information_schema.table_constraints tc " + f"JOIN {db_prefix}information_schema.key_column_usage kcu " + " ON tc.constraint_name = kcu.constraint_name " + " AND tc.table_schema = kcu.table_schema " + f"JOIN {db_prefix}information_schema.referential_constraints rc " + " ON tc.constraint_name = rc.constraint_name " + " AND tc.constraint_schema = rc.constraint_schema " + f"JOIN {db_prefix}information_schema.key_column_usage kcu2 " + " ON rc.unique_constraint_name = kcu2.constraint_name " + " AND rc.unique_constraint_schema = kcu2.constraint_schema " + " AND kcu.ordinal_position = kcu2.ordinal_position " + "WHERE tc.constraint_type = 'FOREIGN KEY' " + "AND tc.table_schema != 'INFORMATION_SCHEMA' " + "ORDER BY tc.table_name, tc.constraint_name" + ) + cursor.execute(sql) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_schema=row[1], + source_table=row[2], + source_column=row[3], + target_schema=row[4], + target_table=row[5], + target_column=row[6], + ) + for row in cursor.fetchall() + ] + def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: """Get indexes.""" - # Snowflake doesn't really have traditional indexes like Postgres/MySQL. - # It has clustering keys, search optimization service, etc. - # But 'SHOW PRIMARY KEYS' or similar might work. - # For now, return empty list as Snowflake is mostly auto-managed. return [] def get_triggers(self, conn: Any, database: str | None = None) -> list[TriggerInfo]: diff --git a/sqlit/domains/connections/providers/spanner/adapter.py b/sqlit/domains/connections/providers/spanner/adapter.py index f5ed784b..c3327d6a 100644 --- a/sqlit/domains/connections/providers/spanner/adapter.py +++ b/sqlit/domains/connections/providers/spanner/adapter.py @@ -24,6 +24,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, CursorBasedAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -261,6 +262,36 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: """Spanner doesn't support stored procedures.""" return [] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from Spanner INFORMATION_SCHEMA.""" + query = """ + SELECT tc.CONSTRAINT_NAME, tc.TABLE_NAME, + kcu.COLUMN_NAME, + ccu.TABLE_NAME, ccu.COLUMN_NAME + FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc + JOIN INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc + ON tc.CONSTRAINT_NAME = rc.CONSTRAINT_NAME + JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu + ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME + JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE ccu + ON rc.UNIQUE_CONSTRAINT_NAME = ccu.CONSTRAINT_NAME + AND kcu.ORDINAL_POSITION = ccu.ORDINAL_POSITION + WHERE tc.CONSTRAINT_TYPE = 'FOREIGN KEY' + AND tc.TABLE_SCHEMA = '' + ORDER BY tc.TABLE_NAME, tc.CONSTRAINT_NAME + """ + rows = self._execute_readonly(conn, query) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_table=row[1], + source_column=row[2], + target_table=row[3], + target_column=row[4], + ) + for row in rows + ] + def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: """Get list of indexes from INFORMATION_SCHEMA.""" query = """ diff --git a/sqlit/domains/connections/providers/sqlite/adapter.py b/sqlit/domains/connections/providers/sqlite/adapter.py index 2b7315f4..98464367 100644 --- a/sqlit/domains/connections/providers/sqlite/adapter.py +++ b/sqlit/domains/connections/providers/sqlite/adapter.py @@ -7,6 +7,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, DatabaseAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -121,6 +122,27 @@ def get_sequences(self, conn: Any, database: str | None = None) -> list[Sequence """SQLite doesn't support sequences - return empty list.""" return [] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from SQLite using PRAGMA foreign_key_list for each table.""" + tables = self.get_tables(conn, database) + results: list[ForeignKeyInfo] = [] + for _schema, table_name in tables: + cursor = conn.cursor() + quoted = self.quote_identifier(table_name) + cursor.execute(f"PRAGMA foreign_key_list({quoted})") + for row in cursor.fetchall(): + # row: id, seq, table, from, to, on_update, on_delete, match + results.append( + ForeignKeyInfo( + constraint_name=f"fk_{table_name}_{row[3]}", + source_table=table_name, + source_column=row[3], + target_table=row[2], + target_column=row[4], + ) + ) + return results + def get_index_definition( self, conn: Any, index_name: str, table_name: str, database: str | None = None ) -> dict[str, Any]: diff --git a/sqlit/domains/connections/providers/teradata/adapter.py b/sqlit/domains/connections/providers/teradata/adapter.py index aa8e0201..c98d0642 100644 --- a/sqlit/domains/connections/providers/teradata/adapter.py +++ b/sqlit/domains/connections/providers/teradata/adapter.py @@ -8,6 +8,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, CursorBasedAdapter, + ForeignKeyInfo, IndexInfo, TableInfo, TriggerInfo, @@ -209,6 +210,40 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: ) return [row[0] for row in cursor.fetchall()] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from Teradata.""" + cursor = conn.cursor() + if database: + cursor.execute( + "lock row for access " + "SELECT IndexName, ChildTable, ChildKeyColumn, " + "ParentTable, ParentKeyColumn, ChildDB, ParentDB " + "FROM DBC.All_RI_ChildrenV " + "WHERE ChildDB = ? " + "ORDER BY ChildTable, IndexName", + (database,), + ) + else: + cursor.execute( + "lock row for access " + "SELECT IndexName, ChildTable, ChildKeyColumn, " + "ParentTable, ParentKeyColumn, ChildDB, ParentDB " + "FROM DBC.All_RI_ChildrenV " + "ORDER BY ChildTable, IndexName" + ) + return [ + ForeignKeyInfo( + constraint_name=row[0], + source_table=row[1], + source_column=row[2], + target_table=row[3], + target_column=row[4], + source_schema=row[5], + target_schema=row[6], + ) + for row in cursor.fetchall() + ] + def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: cursor = conn.cursor() if database: diff --git a/sqlit/domains/connections/providers/turso/adapter.py b/sqlit/domains/connections/providers/turso/adapter.py index 1e65d546..d626ce23 100644 --- a/sqlit/domains/connections/providers/turso/adapter.py +++ b/sqlit/domains/connections/providers/turso/adapter.py @@ -7,6 +7,7 @@ from sqlit.domains.connections.providers.adapters.base import ( ColumnInfo, DatabaseAdapter, + ForeignKeyInfo, IndexInfo, SequenceInfo, TableInfo, @@ -132,6 +133,25 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: """Turso doesn't support stored procedures - return empty list.""" return [] + def get_foreign_keys(self, conn: Any, database: str | None = None) -> list[ForeignKeyInfo]: + """Get foreign keys from Turso using PRAGMA foreign_key_list.""" + tables = self.get_tables(conn, database) + results: list[ForeignKeyInfo] = [] + for _schema, table_name in tables: + quoted = self.quote_identifier(table_name) + rows = conn.execute(f"PRAGMA foreign_key_list({quoted})").fetchall() + for row in rows: + results.append( + ForeignKeyInfo( + constraint_name=f"fk_{table_name}_{row[3]}", + source_table=table_name, + source_column=row[3], + target_table=row[2], + target_column=row[4], + ) + ) + return results + def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: """Get indexes from Turso (SQLite-compatible).""" rows = conn.execute( diff --git a/sqlit/domains/diagram/__init__.py b/sqlit/domains/diagram/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sqlit/domains/diagram/app/__init__.py b/sqlit/domains/diagram/app/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sqlit/domains/diagram/app/diagram_service.py b/sqlit/domains/diagram/app/diagram_service.py new file mode 100644 index 00000000..bc0fc5d2 --- /dev/null +++ b/sqlit/domains/diagram/app/diagram_service.py @@ -0,0 +1,52 @@ +"""Service for building ER diagrams from schema data.""" + +from __future__ import annotations + +from typing import Any + +from sqlit.domains.connections.providers.adapters.base import ColumnInfo, ForeignKeyInfo +from sqlit.domains.diagram.app.layout import build_layout, render_diagram + + +def build_diagram_text( + table_names: list[str], + columns_by_table: dict[str, list[ColumnInfo]], + foreign_keys: list[ForeignKeyInfo], + schemas: dict[str, str] | None = None, +) -> str: + """Build diagram text from schema data.""" + tables: dict[str, list[tuple[str, str, bool]]] = {} + for name in table_names: + cols = columns_by_table.get(name, []) + tables[name] = [(c.name, c.data_type, c.is_primary_key) for c in cols] + + table_set = set(table_names) + fk_tuples: list[tuple[str, str, str, str]] = [] + for fk in foreign_keys: + if fk.source_table in table_set: + fk_tuples.append((fk.source_table, fk.source_column, fk.target_table, fk.target_column)) + + layout = build_layout(tables, fk_tuples, schemas) + lines = render_diagram(layout) + return "\n".join(lines) + + +def fetch_diagram_data( + schema_service: Any, + database: str | None, + table_names: list[str], +) -> tuple[dict[str, list[ColumnInfo]], list[ForeignKeyInfo]]: + """Fetch columns and FK data for a set of tables.""" + columns_by_table: dict[str, list[ColumnInfo]] = {} + for name in table_names: + try: + columns_by_table[name] = schema_service.list_columns(database, None, name) + except Exception: + columns_by_table[name] = [] + + try: + foreign_keys: list[ForeignKeyInfo] = schema_service.list_foreign_keys(database) + except Exception: + foreign_keys = [] + + return columns_by_table, foreign_keys diff --git a/sqlit/domains/diagram/app/layout.py b/sqlit/domains/diagram/app/layout.py new file mode 100644 index 00000000..ba6e0cb7 --- /dev/null +++ b/sqlit/domains/diagram/app/layout.py @@ -0,0 +1,342 @@ +"""ER diagram layout engine for terminal rendering.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + + +@dataclass +class TableBox: + """A rendered table box with position and content.""" + + name: str + schema: str + columns: list[tuple[str, str, bool, bool]] # (name, type, is_pk, is_fk) + x: int = 0 + y: int = 0 + width: int = 0 + height: int = 0 + + def compute_dimensions(self) -> None: + name_len = len(self.name) + 2 + col_widths = [len(f" {c[0]} {c[1]} ") + (3 if c[2] or c[3] else 0) for c in self.columns] + content_width = max([name_len, *col_widths]) if col_widths else name_len + self.width = content_width + 2 # borders + self.height = len(self.columns) + 3 # top border + title + separator + columns + bottom border + if not self.columns: + self.height = 3 + + +@dataclass +class Relationship: + """A foreign key relationship between two tables.""" + + source_table: str + source_column: str + target_table: str + target_column: str + + +@dataclass +class DiagramLayout: + """Computed diagram layout ready for rendering.""" + + tables: dict[str, TableBox] = field(default_factory=dict) + relationships: list[Relationship] = field(default_factory=list) + width: int = 0 + height: int = 0 + + +def build_layout( + tables: dict[str, list[tuple[str, str, bool]]], # name -> [(col_name, col_type, is_pk)] + foreign_keys: list[tuple[str, str, str, str]], # [(src_table, src_col, tgt_table, tgt_col)] + schemas: dict[str, str] | None = None, +) -> DiagramLayout: + """Build a diagram layout from table and FK data.""" + layout = DiagramLayout() + + fk_lookup: set[tuple[str, str]] = set() + for src_table, src_col, _tgt_table, _tgt_col in foreign_keys: + fk_lookup.add((src_table, src_col)) + + for table_name, columns in tables.items(): + box = TableBox( + name=table_name, + schema=(schemas or {}).get(table_name, ""), + columns=[ + (col_name, col_type, is_pk, (table_name, col_name) in fk_lookup) + for col_name, col_type, is_pk in columns + ], + ) + box.compute_dimensions() + layout.tables[table_name] = box + + for src_table, src_col, tgt_table, tgt_col in foreign_keys: + if src_table in layout.tables and tgt_table in layout.tables: + layout.relationships.append( + Relationship(src_table, src_col, tgt_table, tgt_col) + ) + + _position_tables(layout) + return layout + + +def _position_tables(layout: DiagramLayout) -> None: + """Position tables in a grid layout, placing related tables near each other.""" + if not layout.tables: + return + + adjacency: dict[str, set[str]] = {t: set() for t in layout.tables} + for rel in layout.relationships: + adjacency[rel.source_table].add(rel.target_table) + adjacency[rel.target_table].add(rel.source_table) + + ordered = _topo_order(layout.tables.keys(), adjacency) + + h_gap = 6 + v_gap = 2 + max_cols = _pick_grid_cols(len(ordered)) + + row_x = 0 + row_y = 0 + col_idx = 0 + row_max_h = 0 + + for table_name in ordered: + box = layout.tables[table_name] + box.x = row_x + box.y = row_y + row_x += box.width + h_gap + row_max_h = max(row_max_h, box.height) + col_idx += 1 + if col_idx >= max_cols: + col_idx = 0 + row_x = 0 + row_y += row_max_h + v_gap + row_max_h = 0 + + max_x = max(b.x + b.width for b in layout.tables.values()) + max_y = max(b.y + b.height for b in layout.tables.values()) + layout.width = max_x + 1 + layout.height = max_y + 1 + + +def _topo_order( + tables: object, + adjacency: dict[str, set[str]], +) -> list[str]: + """Order tables so that related tables appear near each other (BFS from most-connected).""" + table_list = list(tables) # type: ignore[arg-type] + if not table_list: + return [] + + table_list.sort(key=lambda t: len(adjacency.get(t, set())), reverse=True) + + visited: set[str] = set() + result: list[str] = [] + queue: list[str] = [] + + for start in table_list: + if start in visited: + continue + queue.append(start) + while queue: + current = queue.pop(0) + if current in visited: + continue + visited.add(current) + result.append(current) + neighbors = sorted(adjacency.get(current, set()), key=lambda t: len(adjacency.get(t, set())), reverse=True) + for neighbor in neighbors: + if neighbor not in visited: + queue.append(neighbor) + + return result + + +def _pick_grid_cols(count: int) -> int: + if count <= 2: + return count + if count <= 4: + return 2 + if count <= 9: + return 3 + return 4 + + +def render_diagram(layout: DiagramLayout) -> list[str]: + """Render the diagram layout to a list of text lines.""" + if not layout.tables: + return ["(no tables)"] + + canvas_h = layout.height + 4 + canvas_w = layout.width + 10 + canvas: list[list[str]] = [[" "] * canvas_w for _ in range(canvas_h)] + + for box in layout.tables.values(): + _draw_table_box(canvas, box) + + for ri, rel in enumerate(layout.relationships): + _draw_relationship(canvas, layout, rel, ri) + + return ["".join(row).rstrip() for row in canvas] + + +def _draw_table_box(canvas: list[list[str]], box: TableBox) -> None: + """Draw a single table box onto the canvas.""" + x, y, w = box.x, box.y, box.width + + _put(canvas, y, x, "╭") + _put(canvas, y, x + w - 1, "╮") + for i in range(1, w - 1): + _put(canvas, y, x + i, "─") + + title = f" {box.name} " + title_start = x + 1 + for i, ch in enumerate(title[:w - 2]): + _put(canvas, y, title_start + i, ch) + + sep_y = y + 1 + _put(canvas, sep_y, x, "├") + _put(canvas, sep_y, x + w - 1, "┤") + for i in range(1, w - 1): + _put(canvas, sep_y, x + i, "─") + + for ci, (col_name, col_type, is_pk, is_fk) in enumerate(box.columns): + row_y = y + 2 + ci + _put(canvas, row_y, x, "│") + _put(canvas, row_y, x + w - 1, "│") + + tag = "" + if is_pk and is_fk: + tag = "◆ " # PK+FK + elif is_pk: + tag = "● " # PK + elif is_fk: + tag = "○ " # FK + + cell = f" {tag}{col_name} {col_type}" + cell = cell[:w - 2].ljust(w - 2) + for i, ch in enumerate(cell): + _put(canvas, row_y, x + 1 + i, ch) + + bottom_y = y + box.height - 1 + _put(canvas, bottom_y, x, "╰") + _put(canvas, bottom_y, x + w - 1, "╯") + for i in range(1, w - 1): + _put(canvas, bottom_y, x + i, "─") + + +def _draw_relationship( + canvas: list[list[str]], + layout: DiagramLayout, + rel: Relationship, + index: int = 0, +) -> None: + """Draw a relationship line between two table boxes.""" + src_box = layout.tables.get(rel.source_table) + tgt_box = layout.tables.get(rel.target_table) + if not src_box or not tgt_box: + return + + src_col_idx = _find_column_index(src_box, rel.source_column) + tgt_col_idx = _find_column_index(tgt_box, rel.target_column) + if src_col_idx < 0 or tgt_col_idx < 0: + return + + src_row = src_box.y + 2 + src_col_idx + tgt_row = tgt_box.y + 2 + tgt_col_idx + offset = index % 3 + + if src_box.x + src_box.width <= tgt_box.x: + src_x = src_box.x + src_box.width + tgt_x = tgt_box.x - 1 + _draw_line_between(canvas, src_x, src_row, tgt_x, tgt_row, offset) + elif tgt_box.x + tgt_box.width <= src_box.x: + src_x = src_box.x - 1 + tgt_x = tgt_box.x + tgt_box.width + _draw_line_between(canvas, src_x, src_row, tgt_x, tgt_row, offset) + else: + sx = src_box.x + src_box.width + tx = tgt_box.x + tgt_box.width + mid_x = max(sx, tx) + 1 + offset + _draw_line_between(canvas, sx, src_row, tx, tgt_row, offset, mid_x) + + +def _draw_line_between( + canvas: list[list[str]], + src_x: int, + src_y: int, + tgt_x: int, + tgt_y: int, + offset: int = 0, + force_mid_x: int | None = None, +) -> None: + """Draw a line from (src_x, src_y) to (tgt_x, tgt_y) with a vertical segment.""" + if src_y == tgt_y: + lo, hi = min(src_x, tgt_x), max(src_x, tgt_x) + for x in range(lo, hi + 1): + _put_line(canvas, src_y, x, "─") + return + + if force_mid_x is not None: + mid_x = force_mid_x + else: + mid_x = (src_x + tgt_x) // 2 + offset + + h_char = "─" + + if src_x <= mid_x: + for x in range(src_x, mid_x): + _put_line(canvas, src_y, x, h_char) + else: + for x in range(mid_x + 1, src_x + 1): + _put_line(canvas, src_y, x, h_char) + + if src_y < tgt_y: + corner_src = "╮" if src_x <= mid_x else "╭" + corner_tgt = "╰" if tgt_x >= mid_x else "╯" + _put_line(canvas, src_y, mid_x, corner_src) + for y in range(src_y + 1, tgt_y): + _put_line(canvas, y, mid_x, "│") + _put_line(canvas, tgt_y, mid_x, corner_tgt) + else: + corner_src = "╯" if src_x <= mid_x else "╰" + corner_tgt = "╭" if tgt_x >= mid_x else "╮" + _put_line(canvas, src_y, mid_x, corner_src) + for y in range(tgt_y + 1, src_y): + _put_line(canvas, y, mid_x, "│") + _put_line(canvas, tgt_y, mid_x, corner_tgt) + + if tgt_x >= mid_x: + for x in range(mid_x + 1, tgt_x + 1): + _put_line(canvas, tgt_y, x, h_char) + else: + for x in range(tgt_x, mid_x): + _put_line(canvas, tgt_y, x, h_char) + + +def _find_column_index(box: TableBox, column_name: str) -> int: + for i, (name, _type, _pk, _fk) in enumerate(box.columns): + if name == column_name: + return i + return -1 + + +def _put(canvas: list[list[str]], row: int, col: int, ch: str) -> None: + if 0 <= row < len(canvas) and 0 <= col < len(canvas[row]): + canvas[row][col] = ch + + +_BOX_CHARS = frozenset("╭╮╰╯├┤│─●○◆") + + +def _put_line(canvas: list[list[str]], row: int, col: int, ch: str) -> None: + if 0 <= row < len(canvas) and 0 <= col < len(canvas[row]): + existing = canvas[row][col] + if existing == " ": + canvas[row][col] = ch + elif existing in _BOX_CHARS: + pass + elif existing.isalnum() or existing in ("_", "(", ")", ","): + pass diff --git a/sqlit/domains/diagram/ui/__init__.py b/sqlit/domains/diagram/ui/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sqlit/domains/diagram/ui/diagram_screen.py b/sqlit/domains/diagram/ui/diagram_screen.py new file mode 100644 index 00000000..7fe2aa03 --- /dev/null +++ b/sqlit/domains/diagram/ui/diagram_screen.py @@ -0,0 +1,104 @@ +"""ER diagram modal screen.""" + +from __future__ import annotations + +from textual.app import ComposeResult +from textual.binding import Binding +from textual.containers import VerticalScroll +from textual.screen import ModalScreen +from textual.widgets import Static + +from sqlit.shared.ui.widgets import Dialog + + +class DiagramScreen(ModalScreen): + """Full-screen modal displaying an ER diagram.""" + + BINDINGS = [ + Binding("escape", "close", "Close", show=False, priority=True), + Binding("q", "close", "Close", show=False), + Binding("j", "scroll_down", "Down", show=False), + Binding("k", "scroll_up", "Up", show=False), + Binding("h", "scroll_left", "Left", show=False), + Binding("l", "scroll_right", "Right", show=False), + Binding("down", "scroll_down", "Down", show=False), + Binding("up", "scroll_up", "Up", show=False), + Binding("left", "scroll_left", "Left", show=False), + Binding("right", "scroll_right", "Right", show=False), + Binding("ctrl+d", "page_down", "Page Down", show=False), + Binding("ctrl+u", "page_up", "Page Up", show=False), + Binding("g", "scroll_home", "Top", show=False), + Binding("G", "scroll_end", "Bottom", show=False), + ] + + CSS = """ + DiagramScreen { + align: center middle; + background: $background 80%; + } + + #diagram-dialog { + width: 95%; + height: 90%; + max-width: 200; + } + + #diagram-scroll { + height: 1fr; + overflow: auto; + } + + #diagram-content { + width: auto; + min-width: 100%; + color: $text; + } + """ + + def __init__(self, diagram_text: str, title: str = "ER Diagram") -> None: + super().__init__() + self._diagram_text = diagram_text + self._title = title + + def compose(self) -> ComposeResult: + shortcuts = [("Close", "q"), ("Scroll", "hjkl")] + with Dialog(id="diagram-dialog", title=self._title, shortcuts=shortcuts): + with VerticalScroll(id="diagram-scroll"): + yield Static(self._diagram_text, id="diagram-content") + + def on_mount(self) -> None: + self.query_one("#diagram-scroll").focus() + + def action_close(self) -> None: + self.dismiss() + + def action_scroll_down(self) -> None: + self.query_one("#diagram-scroll").scroll_down() + + def action_scroll_up(self) -> None: + self.query_one("#diagram-scroll").scroll_up() + + def action_scroll_left(self) -> None: + self.query_one("#diagram-scroll").scroll_left() + + def action_scroll_right(self) -> None: + self.query_one("#diagram-scroll").scroll_right() + + def action_page_down(self) -> None: + scroll = self.query_one("#diagram-scroll") + scroll.scroll_page_down() + + def action_page_up(self) -> None: + scroll = self.query_one("#diagram-scroll") + scroll.scroll_page_up() + + def action_scroll_home(self) -> None: + self.query_one("#diagram-scroll").scroll_home() + + def action_scroll_end(self) -> None: + self.query_one("#diagram-scroll").scroll_end() + + def check_action(self, action: str, parameters: tuple) -> bool | None: + if self.app.screen is not self: + return False + return super().check_action(action, parameters) diff --git a/sqlit/domains/diagram/ui/table_picker.py b/sqlit/domains/diagram/ui/table_picker.py new file mode 100644 index 00000000..fcf8f6d5 --- /dev/null +++ b/sqlit/domains/diagram/ui/table_picker.py @@ -0,0 +1,176 @@ +"""Table picker modal for ER diagram generation.""" + +from __future__ import annotations + +from typing import Any + +from textual.app import ComposeResult +from textual.binding import Binding +from textual.containers import Vertical +from textual.screen import ModalScreen +from textual.widgets import Input, OptionList, Static +from textual.widgets.option_list import Option + +from sqlit.shared.ui.widgets import Dialog + + +class DiagramTablePicker(ModalScreen[list[str] | None]): + """Modal to select tables for ER diagram generation.""" + + BINDINGS = [ + Binding("escape", "cancel", "Cancel", show=False, priority=True), + ] + + CSS = """ + DiagramTablePicker { + align: center middle; + background: transparent; + } + + #diagram-picker-dialog { + width: 60; + max-width: 80%; + max-height: 80%; + } + + #diagram-picker-body { + height: auto; + max-height: 100%; + } + + #diagram-picker-input { + margin-bottom: 1; + } + + #diagram-picker-list { + height: auto; + max-height: 30; + border: none; + } + + #diagram-picker-list > .option-list--option { + padding: 0; + } + + #diagram-picker-status { + height: 1; + margin-top: 1; + color: $text-muted; + text-align: right; + } + """ + + def __init__(self, tables: list[str]) -> None: + super().__init__() + self._tables = sorted(tables) + self._selected: set[str] = set() + self._filter_text = "" + + def compose(self) -> ComposeResult: + shortcuts = [("Toggle", ""), ("Toggle All", "a"), ("OK", "")] + with Dialog(id="diagram-picker-dialog", title="Select Tables for Diagram", shortcuts=shortcuts): + with Vertical(id="diagram-picker-body"): + yield Input(placeholder="Filter tables...", id="diagram-picker-input") + yield OptionList(*self._build_options(), id="diagram-picker-list") + yield Static(self._status_text(), id="diagram-picker-status") + + def _build_options(self) -> list[Option]: + options: list[Option] = [] + for table in self._tables: + if self._filter_text and self._filter_text.lower() not in table.lower(): + continue + check = "[*]" if table in self._selected else "[ ]" + options.append(Option(f"{check} {table}", id=table)) + return options + + def _status_text(self) -> str: + return f"{len(self._selected)}/{len(self._tables)} selected" + + def _rebuild(self) -> None: + option_list = self.query_one("#diagram-picker-list", OptionList) + highlighted = option_list.highlighted + option_list.clear_options() + for opt in self._build_options(): + option_list.add_option(opt) + if highlighted is not None and highlighted < option_list.option_count: + option_list.highlighted = highlighted + self.query_one("#diagram-picker-status", Static).update(self._status_text()) + + def on_mount(self) -> None: + self.query_one("#diagram-picker-list", OptionList).focus() + + def on_input_changed(self, event: Input.Changed) -> None: + if event.input.id == "diagram-picker-input": + self._filter_text = event.value + self._rebuild() + + def on_input_submitted(self, event: Input.Submitted) -> None: + if event.input.id == "diagram-picker-input": + option_list = self.query_one("#diagram-picker-list", OptionList) + option_list.focus() + + def _toggle_highlighted(self) -> None: + option_list = self.query_one("#diagram-picker-list", OptionList) + if option_list.highlighted is not None: + opt = option_list.get_option_at_index(option_list.highlighted) + table_name = opt.id or "" + if table_name: + if table_name in self._selected: + self._selected.discard(table_name) + else: + self._selected.add(table_name) + self._rebuild() + + def on_key(self, event: Any) -> None: + key = getattr(event, "key", "") + focused = self.app.focused + input_widget = self.query_one("#diagram-picker-input", Input) + option_list = self.query_one("#diagram-picker-list", OptionList) + in_input = focused is input_widget + + if key == "enter" and not in_input: + self._confirm() + event.prevent_default() + event.stop() + elif key == "tab" and in_input: + option_list.focus() + event.prevent_default() + event.stop() + elif key == "slash" and not in_input: + input_widget.focus() + event.prevent_default() + event.stop() + elif key == "space" and not in_input: + self._toggle_highlighted() + event.prevent_default() + event.stop() + elif key == "j" and not in_input: + option_list.action_cursor_down() + event.prevent_default() + event.stop() + elif key == "k" and not in_input: + option_list.action_cursor_up() + event.prevent_default() + event.stop() + elif key == "a" and not in_input: + if len(self._selected) == len(self._tables): + self._selected.clear() + else: + self._selected = set(self._tables) + self._rebuild() + event.prevent_default() + event.stop() + + def _confirm(self) -> None: + if self._selected: + self.dismiss(sorted(self._selected)) + else: + self.dismiss(None) + + def action_cancel(self) -> None: + self.dismiss(None) + + def check_action(self, action: str, parameters: tuple) -> bool | None: + if self.app.screen is not self: + return False + return super().check_action(action, parameters) diff --git a/sqlit/domains/explorer/app/schema_service.py b/sqlit/domains/explorer/app/schema_service.py index 1252b774..a588b5f1 100644 --- a/sqlit/domains/explorer/app/schema_service.py +++ b/sqlit/domains/explorer/app/schema_service.py @@ -7,8 +7,9 @@ from typing import Any, TypeVar from sqlit.domains.connections.app.session import ConnectionSession -from sqlit.domains.connections.providers.adapters.base import ColumnInfo +from sqlit.domains.connections.providers.adapters.base import ColumnInfo, ForeignKeyInfo from sqlit.domains.connections.providers.model import ( + ForeignKeyInspector, IndexInspector, ProcedureInspector, SequenceInspector, @@ -206,3 +207,13 @@ def get_sequence_definition(self, database: str | None, name: str) -> dict[str, lambda: inspector.get_sequence_definition(self.session.connection, name, db_arg), database, ) + + def list_foreign_keys(self, database: str | None) -> list[ForeignKeyInfo]: + inspector = self.session.provider.schema_inspector + if not isinstance(inspector, ForeignKeyInspector): + return [] + db_arg = self._resolve_db_arg(database) + return self._run_with_retry( + lambda: inspector.get_foreign_keys(self.session.connection, db_arg), + database, + ) diff --git a/sqlit/domains/explorer/state/tree_on_database.py b/sqlit/domains/explorer/state/tree_on_database.py index 8a2336eb..19868227 100644 --- a/sqlit/domains/explorer/state/tree_on_database.py +++ b/sqlit/domains/explorer/state/tree_on_database.py @@ -12,7 +12,7 @@ class TreeOnDatabaseState(State): help_category = "Explorer" def _setup_actions(self) -> None: - pass # Expanding a database now sets it as active automatically + self.allows("show_diagram", label="ER Diagram", help="Show ER diagram") def get_display_bindings(self, app: InputContext) -> tuple[list[DisplayBinding], list[DisplayBinding]]: left: list[DisplayBinding] = [] @@ -20,6 +20,14 @@ def get_display_bindings(self, app: InputContext) -> tuple[list[DisplayBinding], left.append(DisplayBinding(key="enter", label="Use database", action="toggle_node")) seen.add("toggle_node") + left.append( + DisplayBinding( + key=resolve_display_key("show_diagram") or "S", + label="Diagram", + action="show_diagram", + ) + ) + seen.add("show_diagram") left.append( DisplayBinding( key=resolve_display_key("refresh_tree") or "f", diff --git a/sqlit/domains/explorer/state/tree_on_folder.py b/sqlit/domains/explorer/state/tree_on_folder.py index 68c19394..e56f222c 100644 --- a/sqlit/domains/explorer/state/tree_on_folder.py +++ b/sqlit/domains/explorer/state/tree_on_folder.py @@ -25,6 +25,7 @@ def is_connection_folder(app: InputContext) -> bool: label="Delete", help="Delete connection folder", ) + self.allows("show_diagram", label="ER Diagram", help="Show ER diagram") def get_display_bindings(self, app: InputContext) -> tuple[list[DisplayBinding], list[DisplayBinding]]: left: list[DisplayBinding] = [] @@ -32,6 +33,17 @@ def get_display_bindings(self, app: InputContext) -> tuple[list[DisplayBinding], left.append(DisplayBinding(key="enter", label="Expand", action="toggle_node")) seen.add("toggle_node") + + if app.tree_node_kind == "folder": + left.append( + DisplayBinding( + key=resolve_display_key("show_diagram") or "S", + label="Diagram", + action="show_diagram", + ) + ) + seen.add("show_diagram") + left.append( DisplayBinding( key=resolve_display_key("refresh_tree") or "f", diff --git a/sqlit/domains/explorer/state/tree_on_table.py b/sqlit/domains/explorer/state/tree_on_table.py index ee092789..76fe9113 100644 --- a/sqlit/domains/explorer/state/tree_on_table.py +++ b/sqlit/domains/explorer/state/tree_on_table.py @@ -13,6 +13,7 @@ class TreeOnTableState(State): def _setup_actions(self) -> None: self.allows("select_table", label="Select TOP 100", help="Select TOP 100 (table/view)") + self.allows("show_diagram", label="ER Diagram", help="Show ER diagram") def get_display_bindings(self, app: InputContext) -> tuple[list[DisplayBinding], list[DisplayBinding]]: left: list[DisplayBinding] = [] @@ -28,6 +29,14 @@ def get_display_bindings(self, app: InputContext) -> tuple[list[DisplayBinding], ) ) seen.add("select_table") + left.append( + DisplayBinding( + key=resolve_display_key("show_diagram") or "S", + label="Diagram", + action="show_diagram", + ) + ) + seen.add("show_diagram") left.append( DisplayBinding( key=resolve_display_key("refresh_tree") or "f", diff --git a/sqlit/domains/explorer/ui/mixins/tree.py b/sqlit/domains/explorer/ui/mixins/tree.py index 26cf34b3..223f55ef 100644 --- a/sqlit/domains/explorer/ui/mixins/tree.py +++ b/sqlit/domains/explorer/ui/mixins/tree.py @@ -357,6 +357,111 @@ def action_select_table(self: TreeMixinHost) -> None: tree_object_info.show_sequence_info(self, data) return + def action_show_diagram(self: TreeMixinHost) -> None: + """Show ER diagram for selected table (S key) — single table only.""" + if not self.current_provider or not self._session: + return + + node = self.object_tree.cursor_node + if not node or not node.data: + return + + data = node.data + kind = self._get_node_kind(node) + + if kind in ("table", "view"): + self._show_diagram_for_tables(data.database, [data.name]) + elif kind == "database": + self._show_diagram_for_database(data.name) + elif kind == "folder" and data.folder_type == "tables": + self._show_diagram_for_database(data.database) + else: + db = None + if hasattr(self, "_get_effective_database"): + db = self._get_effective_database() + self._show_diagram_for_database(db) + + def action_show_diagram_picker(self: TreeMixinHost) -> None: + """Show table picker then ER diagram (space+d leader) — always opens picker.""" + if not self.current_provider or not self._session: + return + db = None + if hasattr(self, "_get_effective_database"): + db = self._get_effective_database() + self._show_diagram_for_database(db) + + def _show_diagram_for_tables(self: TreeMixinHost, database: str | None, table_names: list[str]) -> None: + """Show diagram for specific tables only (no auto-expansion).""" + import asyncio + + schema_service = self._get_schema_service() + if not schema_service: + self.notify("Not connected", severity="error") + return + + async def work() -> None: + from sqlit.domains.diagram.app.diagram_service import build_diagram_text, fetch_diagram_data + + cols, fks = await asyncio.to_thread(fetch_diagram_data, schema_service, database, table_names) + text = build_diagram_text(table_names, cols, fks) + label = ", ".join(table_names[:3]) + if len(table_names) > 3: + label += f" +{len(table_names) - 3}" + + from sqlit.domains.diagram.ui.diagram_screen import DiagramScreen + + self.app.push_screen(DiagramScreen(text, title=f"ER Diagram: {label}")) + + self.run_worker(work(), name="diagram-load", exclusive=False) + + def _show_diagram_for_database(self: TreeMixinHost, database: str | None) -> None: + """Show table picker then diagram for a database.""" + import asyncio + + schema_service = self._get_schema_service() + if not schema_service: + self.notify("Not connected", severity="error") + return + + async def work() -> None: + tables = await asyncio.to_thread(schema_service.list_folder_items, "tables", database) + table_names = [name for _, _schema, name in tables] + + if not table_names: + self.notify("No tables found", severity="warning") + return + + from sqlit.domains.diagram.ui.table_picker import DiagramTablePicker + + def on_selected(selected: list[str] | None) -> None: + if selected: + self._show_diagram_for_selected(database, selected) + + self.app.push_screen(DiagramTablePicker(table_names), on_selected) + + self.run_worker(work(), name="diagram-tables", exclusive=False) + + def _show_diagram_for_selected(self: TreeMixinHost, database: str | None, table_names: list[str]) -> None: + """Generate and show diagram for selected tables.""" + import asyncio + + schema_service = self._get_schema_service() + if not schema_service: + return + + async def work() -> None: + from sqlit.domains.diagram.app.diagram_service import build_diagram_text, fetch_diagram_data + + cols, fks = await asyncio.to_thread(fetch_diagram_data, schema_service, database, table_names) + text = build_diagram_text(table_names, cols, fks) + title = f"ER Diagram ({len(table_names)} tables)" + + from sqlit.domains.diagram.ui.diagram_screen import DiagramScreen + + self.app.push_screen(DiagramScreen(text, title=title)) + + self.run_worker(work(), name="diagram-render", exclusive=False) + def action_use_database(self: TreeMixinHost) -> None: """Toggle the selected database as the default for the current connection.""" node = self.object_tree.cursor_node diff --git a/sqlit/domains/shell/ui/mixins/ui_leader.py b/sqlit/domains/shell/ui/mixins/ui_leader.py index 55a225bc..3a04c120 100644 --- a/sqlit/domains/shell/ui/mixins/ui_leader.py +++ b/sqlit/domains/shell/ui/mixins/ui_leader.py @@ -118,5 +118,8 @@ def action_leader_telescope(self: UINavigationMixinHost) -> None: def action_leader_telescope_filter(self: UINavigationMixinHost) -> None: self._execute_leader_command("telescope_filter") + def action_leader_show_diagram_picker(self: UINavigationMixinHost) -> None: + self._execute_leader_command("show_diagram_picker") + def action_leader_quit(self: UINavigationMixinHost) -> None: self._execute_leader_command("quit") diff --git a/tests/fixtures/duckdb.py b/tests/fixtures/duckdb.py index 9a658d53..3327092b 100644 --- a/tests/fixtures/duckdb.py +++ b/tests/fixtures/duckdb.py @@ -48,6 +48,15 @@ def duckdb_db(duckdb_db_path: Path) -> Path: SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) + conn.execute(""" + CREATE TABLE test_orders ( + id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + amount DECIMAL(10,2) NOT NULL, + CONSTRAINT fk_orders_user_id FOREIGN KEY (user_id) REFERENCES test_users(id) + ) + """) + # Create test index for integration tests conn.execute("CREATE INDEX idx_test_users_email ON test_users(email)") diff --git a/tests/fixtures/firebird.py b/tests/fixtures/firebird.py index 18812af9..50676ae8 100644 --- a/tests/fixtures/firebird.py +++ b/tests/fixtures/firebird.py @@ -86,6 +86,15 @@ def firebird_db(firebird_server_ready: bool) -> str: SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) + cursor.execute(""" + RECREATE TABLE test_orders ( + id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + amount DECIMAL(10,2) NOT NULL, + CONSTRAINT fk_orders_user_id FOREIGN KEY (user_id) REFERENCES test_users(id) + ) + """) + cursor.execute("CREATE INDEX idx_test_users_email ON test_users(email)") cursor.execute(""" @@ -135,6 +144,7 @@ def firebird_db(firebird_server_ready: bool) -> str: try: for cleanup in [ "DROP VIEW test_user_emails", + "DROP TABLE test_orders", "DROP TABLE test_users", "DROP TABLE test_products", "DROP TRIGGER trg_test_users_audit", diff --git a/tests/fixtures/mariadb.py b/tests/fixtures/mariadb.py index 7046abbd..73419160 100644 --- a/tests/fixtures/mariadb.py +++ b/tests/fixtures/mariadb.py @@ -54,6 +54,7 @@ def mariadb_db(mariadb_server_ready: bool) -> str: ) cursor = conn.cursor() + cursor.execute("DROP TABLE IF EXISTS test_orders") cursor.execute("DROP TABLE IF EXISTS test_users") cursor.execute("DROP TABLE IF EXISTS test_products") cursor.execute("DROP VIEW IF EXISTS test_user_emails") @@ -81,6 +82,15 @@ def mariadb_db(mariadb_server_ready: bool) -> str: SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) + cursor.execute(""" + CREATE TABLE test_orders ( + id INT PRIMARY KEY, + user_id INT NOT NULL, + amount DECIMAL(10,2) NOT NULL, + CONSTRAINT fk_orders_user_id FOREIGN KEY (user_id) REFERENCES test_users(id) + ) + """) + # Create test index for integration tests cursor.execute("CREATE INDEX idx_test_users_email ON test_users(email)") @@ -130,6 +140,7 @@ def mariadb_db(mariadb_server_ready: bool) -> str: ) cursor = conn.cursor() cursor.execute("DROP TRIGGER IF EXISTS trg_test_users_audit") + cursor.execute("DROP TABLE IF EXISTS test_orders") cursor.execute("DROP TABLE IF EXISTS test_users") cursor.execute("DROP TABLE IF EXISTS test_products") cursor.execute("DROP VIEW IF EXISTS test_user_emails") diff --git a/tests/fixtures/mssql.py b/tests/fixtures/mssql.py index 09f8d6de..0b4395a1 100644 --- a/tests/fixtures/mssql.py +++ b/tests/fixtures/mssql.py @@ -104,6 +104,15 @@ def mssql_db(mssql_server_ready: bool) -> str: END """) + cursor.execute(""" + CREATE TABLE test_orders ( + id INT PRIMARY KEY, + user_id INT NOT NULL, + amount DECIMAL(10,2) NOT NULL, + CONSTRAINT fk_orders_user_id FOREIGN KEY (user_id) REFERENCES test_users(id) + ) + """) + # Create test index for integration tests cursor.execute("CREATE INDEX idx_test_users_email ON test_users(email)") diff --git a/tests/fixtures/mysql.py b/tests/fixtures/mysql.py index 8ee19774..93723d8a 100644 --- a/tests/fixtures/mysql.py +++ b/tests/fixtures/mysql.py @@ -54,6 +54,7 @@ def mysql_db(mysql_server_ready: bool) -> str: ) cursor = conn.cursor() + cursor.execute("DROP TABLE IF EXISTS test_orders") cursor.execute("DROP TABLE IF EXISTS test_users") cursor.execute("DROP TABLE IF EXISTS test_products") cursor.execute("DROP VIEW IF EXISTS test_user_emails") @@ -80,6 +81,15 @@ def mysql_db(mysql_server_ready: bool) -> str: SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) + cursor.execute(""" + CREATE TABLE test_orders ( + id INT PRIMARY KEY, + user_id INT NOT NULL, + amount DECIMAL(10,2) NOT NULL, + CONSTRAINT fk_orders_user_id FOREIGN KEY (user_id) REFERENCES test_users(id) + ) + """) + # Create test index for integration tests cursor.execute("CREATE INDEX idx_test_users_email ON test_users(email)") @@ -126,6 +136,7 @@ def mysql_db(mysql_server_ready: bool) -> str: ) cursor = conn.cursor() cursor.execute("DROP TRIGGER IF EXISTS trg_test_users_audit") + cursor.execute("DROP TABLE IF EXISTS test_orders") cursor.execute("DROP TABLE IF EXISTS test_users") cursor.execute("DROP TABLE IF EXISTS test_products") cursor.execute("DROP VIEW IF EXISTS test_user_emails") diff --git a/tests/fixtures/oracle.py b/tests/fixtures/oracle.py index 312e12de..4584ddea 100644 --- a/tests/fixtures/oracle.py +++ b/tests/fixtures/oracle.py @@ -86,6 +86,15 @@ def oracle_db(oracle_server_ready: bool) -> str: SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) + cursor.execute(""" + CREATE TABLE test_orders ( + id NUMBER PRIMARY KEY, + user_id NUMBER NOT NULL, + amount NUMBER(10,2) NOT NULL, + CONSTRAINT fk_orders_user_id FOREIGN KEY (user_id) REFERENCES test_users(id) + ) + """) + cursor.execute("CREATE INDEX idx_test_users_email ON test_users(email)") cursor.execute(""" @@ -141,6 +150,7 @@ def oracle_db(oracle_server_ready: bool) -> str: for stmt in [ "DROP VIEW test_user_emails", + "DROP TABLE test_orders CASCADE CONSTRAINTS", "DROP TABLE test_users CASCADE CONSTRAINTS", "DROP TABLE test_products CASCADE CONSTRAINTS", "DROP SEQUENCE test_sequence", diff --git a/tests/fixtures/postgres.py b/tests/fixtures/postgres.py index 0fa9404a..f2f09d3e 100644 --- a/tests/fixtures/postgres.py +++ b/tests/fixtures/postgres.py @@ -54,6 +54,7 @@ def postgres_db(postgres_server_ready: bool) -> str: conn.autocommit = True cursor = conn.cursor() + cursor.execute("DROP TABLE IF EXISTS test_orders CASCADE") cursor.execute("DROP TABLE IF EXISTS test_users CASCADE") cursor.execute("DROP TABLE IF EXISTS test_products CASCADE") cursor.execute("DROP VIEW IF EXISTS test_user_emails") @@ -80,6 +81,15 @@ def postgres_db(postgres_server_ready: bool) -> str: SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) + cursor.execute(""" + CREATE TABLE test_orders ( + id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + amount DECIMAL(10,2) NOT NULL, + CONSTRAINT fk_orders_user_id FOREIGN KEY (user_id) REFERENCES test_users(id) + ) + """) + # Create test index for integration tests cursor.execute("CREATE INDEX idx_test_users_email ON test_users(email)") @@ -132,6 +142,7 @@ def postgres_db(postgres_server_ready: bool) -> str: ) conn.autocommit = True cursor = conn.cursor() + cursor.execute("DROP TABLE IF EXISTS test_orders CASCADE") cursor.execute("DROP TABLE IF EXISTS test_users CASCADE") cursor.execute("DROP TABLE IF EXISTS test_products CASCADE") cursor.execute("DROP VIEW IF EXISTS test_user_emails") diff --git a/tests/fixtures/sqlite.py b/tests/fixtures/sqlite.py index ac8f6a23..6b833e67 100644 --- a/tests/fixtures/sqlite.py +++ b/tests/fixtures/sqlite.py @@ -45,6 +45,15 @@ def sqlite_db(sqlite_db_path: Path) -> Path: SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) + cursor.execute(""" + CREATE TABLE test_orders ( + id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + amount REAL NOT NULL, + CONSTRAINT fk_orders_user_id FOREIGN KEY (user_id) REFERENCES test_users(id) + ) + """) + # Create test index for integration tests cursor.execute("CREATE INDEX idx_test_users_email ON test_users(email)") diff --git a/tests/fixtures/turso.py b/tests/fixtures/turso.py index be369088..47bb5428 100644 --- a/tests/fixtures/turso.py +++ b/tests/fixtures/turso.py @@ -86,6 +86,7 @@ def _setup_turso_test_tables(client) -> None: client.execute("DROP TRIGGER IF EXISTS trg_test_users_audit") client.execute("DROP INDEX IF EXISTS idx_test_users_email") client.execute("DROP VIEW IF EXISTS test_user_emails") + client.execute("DROP TABLE IF EXISTS test_orders") client.execute("DROP TABLE IF EXISTS test_users") client.execute("DROP TABLE IF EXISTS test_products") @@ -107,6 +108,15 @@ def _setup_turso_test_tables(client) -> None: ) """) + client.execute(""" + CREATE TABLE test_orders ( + id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + amount REAL NOT NULL, + CONSTRAINT fk_orders_user_id FOREIGN KEY (user_id) REFERENCES test_users(id) + ) + """) + # Create view client.execute(""" CREATE VIEW test_user_emails AS @@ -149,6 +159,7 @@ def _cleanup_turso_test_tables(client) -> None: client.execute("DROP TRIGGER IF EXISTS trg_test_users_audit") client.execute("DROP INDEX IF EXISTS idx_test_users_email") client.execute("DROP VIEW IF EXISTS test_user_emails") + client.execute("DROP TABLE IF EXISTS test_orders") client.execute("DROP TABLE IF EXISTS test_users") client.execute("DROP TABLE IF EXISTS test_products") client.commit() diff --git a/tests/test_database_base.py b/tests/test_database_base.py index 9030386d..40bd66c7 100644 --- a/tests/test_database_base.py +++ b/tests/test_database_base.py @@ -440,6 +440,48 @@ def test_get_indexes(self, request): f"Index should be on test_users table, got {test_index.table_name}" ) + def test_get_foreign_keys(self, request): + """Test that adapter correctly retrieves foreign keys. + + This tests that get_foreign_keys returns ForeignKeyInfo objects for + foreign keys created on the test tables. The test fixture should create + a test_orders table with a FK from user_id to test_users(id). + """ + from sqlit.domains.connections.app.session import ConnectionSession + from sqlit.domains.connections.providers.adapters.base import ForeignKeyInfo + from sqlit.domains.connections.providers.registry import get_adapter + from sqlit.domains.connections.store.connections import load_connections + + connection_name = request.getfixturevalue(self.config.connection_fixture) + connections = load_connections() + config = next((c for c in connections if c.name == connection_name), None) + assert config is not None, f"Connection {connection_name} not found" + + with ConnectionSession.create(config, get_adapter) as session: + foreign_keys = session.adapter.get_foreign_keys( + session.connection, + database=config.database if session.adapter.supports_multiple_databases else None, + ) + + assert isinstance(foreign_keys, list), "get_foreign_keys should return a list" + for fk in foreign_keys: + assert isinstance(fk, ForeignKeyInfo), f"Expected ForeignKeyInfo, got {type(fk)}" + + test_fk = next( + (fk for fk in foreign_keys if "user_id" in fk.source_column.lower()), + None, + ) + assert test_fk is not None, ( + f"FK on user_id not found. " + f"Found FKs: {[(fk.source_table, fk.source_column) for fk in foreign_keys]}" + ) + assert "test_orders" in test_fk.source_table.lower(), ( + f"FK source should be test_orders, got {test_fk.source_table}" + ) + assert "test_users" in test_fk.target_table.lower(), ( + f"FK target should be test_users, got {test_fk.target_table}" + ) + def test_get_triggers(self, request): """Test that adapter correctly retrieves triggers. diff --git a/tests/unit/test_diagram_layout.py b/tests/unit/test_diagram_layout.py new file mode 100644 index 00000000..bcc89909 --- /dev/null +++ b/tests/unit/test_diagram_layout.py @@ -0,0 +1,110 @@ +"""Tests for ER diagram layout engine.""" + +from __future__ import annotations + +from sqlit.domains.diagram.app.layout import ( + DiagramLayout, + TableBox, + build_layout, + render_diagram, +) + + +class TestTableBox: + def test_compute_dimensions_basic(self) -> None: + box = TableBox( + name="users", + schema="", + columns=[("id", "int", True, False), ("name", "varchar", False, False)], + ) + box.compute_dimensions() + assert box.width > 0 + assert box.height == 5 # top + title_sep + 2 cols + bottom + + def test_compute_dimensions_empty_table(self) -> None: + box = TableBox(name="empty", schema="", columns=[]) + box.compute_dimensions() + assert box.width > 0 + assert box.height == 3 # top + title_sep + bottom + + +class TestBuildLayout: + def test_single_table(self) -> None: + tables = {"users": [("id", "int", True), ("name", "varchar", False)]} + layout = build_layout(tables, []) + assert "users" in layout.tables + assert layout.tables["users"].x == 0 + assert layout.tables["users"].y == 0 + + def test_two_related_tables(self) -> None: + tables = { + "orders": [("id", "int", True), ("user_id", "int", False)], + "users": [("id", "int", True), ("name", "varchar", False)], + } + fks = [("orders", "user_id", "users", "id")] + layout = build_layout(tables, fks) + assert len(layout.tables) == 2 + assert len(layout.relationships) == 1 + + def test_fk_marks_column_as_fk(self) -> None: + tables = { + "orders": [("id", "int", True), ("user_id", "int", False)], + "users": [("id", "int", True)], + } + fks = [("orders", "user_id", "users", "id")] + layout = build_layout(tables, fks) + order_cols = layout.tables["orders"].columns + assert order_cols[1][3] is True # user_id is_fk + + def test_ignores_fks_for_missing_tables(self) -> None: + tables = {"users": [("id", "int", True)]} + fks = [("orders", "user_id", "users", "id")] + layout = build_layout(tables, fks) + assert len(layout.relationships) == 0 + + def test_grid_layout_positions(self) -> None: + tables = {f"t{i}": [("id", "int", True)] for i in range(6)} + layout = build_layout(tables, []) + positions = {name: (box.x, box.y) for name, box in layout.tables.items()} + # All positions should be unique + assert len(set(positions.values())) == 6 + + +class TestRenderDiagram: + def test_empty_tables(self) -> None: + layout = DiagramLayout() + lines = render_diagram(layout) + assert lines == ["(no tables)"] + + def test_renders_table_box(self) -> None: + tables = {"users": [("id", "int", True), ("name", "varchar", False)]} + layout = build_layout(tables, []) + lines = render_diagram(layout) + text = "\n".join(lines) + assert "users" in text + assert "●" in text # PK indicator + assert "id" in text + assert "name" in text + + def test_renders_fk_indicator(self) -> None: + tables = { + "orders": [("id", "int", True), ("user_id", "int", False)], + "users": [("id", "int", True)], + } + fks = [("orders", "user_id", "users", "id")] + layout = build_layout(tables, fks) + lines = render_diagram(layout) + text = "\n".join(lines) + assert "○" in text # FK indicator + + def test_renders_relationship_lines(self) -> None: + tables = { + "orders": [("id", "int", True), ("user_id", "int", False)], + "users": [("id", "int", True)], + } + fks = [("orders", "user_id", "users", "id")] + layout = build_layout(tables, fks) + lines = render_diagram(layout) + text = "\n".join(lines) + # Should contain relationship line characters + assert any(ch in text for ch in "─│╮╯╰╭◀") diff --git a/tests/unit/test_foreign_keys.py b/tests/unit/test_foreign_keys.py new file mode 100644 index 00000000..9eb89ef1 --- /dev/null +++ b/tests/unit/test_foreign_keys.py @@ -0,0 +1,607 @@ +"""Unit tests for get_foreign_keys() across all database adapters.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +from sqlit.domains.connections.providers.adapters.base import ForeignKeyInfo + + +def _assert_fk_list(result: list, expected_count: int = 1) -> None: + assert isinstance(result, list) + assert len(result) == expected_count + for fk in result: + assert isinstance(fk, ForeignKeyInfo) + assert fk.constraint_name + assert fk.source_table + assert fk.source_column + assert fk.target_table + assert fk.target_column + + +# -- PostgreSQL ---------------------------------------------------------------- + +class TestPostgreSQLForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.postgresql.adapter import PostgreSQLAdapter + return PostgreSQLAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("fk_orders_user_id", "public", "orders", "user_id", "public", "users", "id"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_table == "orders" + assert result[0].target_table == "users" + assert result[0].source_schema == "public" + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + def test_sql_contains_foreign_key_filter(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn) + sql = cursor.execute.call_args[0][0] + assert "FOREIGN KEY" in sql + assert "table_constraints" in sql + + +# -- SQLite -------------------------------------------------------------------- + +class TestSQLiteForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.sqlite.adapter import SQLiteAdapter + return SQLiteAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + # get_tables calls sqlite_master -> returns raw name rows + cursor.fetchall.side_effect = [ + [("orders",)], # get_tables -> sqlite_master + [(0, 0, "users", "user_id", "id", "", "", "")], # PRAGMA foreign_key_list + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_table == "orders" + assert result[0].source_column == "user_id" + assert result[0].target_table == "users" + assert result[0].target_column == "id" + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.side_effect = [ + [("orders",)], # get_tables + [], # PRAGMA foreign_key_list returns nothing + ] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + +# -- DuckDB -------------------------------------------------------------------- + +class TestDuckDBForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.duckdb.adapter import DuckDBAdapter + return DuckDBAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + result_mock = MagicMock() + mock_conn.execute.return_value = result_mock + result_mock.fetchall.return_value = [ + ("fk_orders_user_id", "main", "orders", "user_id", "main", "users", "id"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_table == "orders" + assert result[0].source_schema == "main" + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + result_mock = MagicMock() + mock_conn.execute.return_value = result_mock + result_mock.fetchall.return_value = [] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + def test_sql_contains_foreign_key_filter(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + result_mock = MagicMock() + mock_conn.execute.return_value = result_mock + result_mock.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn) + sql = mock_conn.execute.call_args[0][0] + assert "FOREIGN KEY" in sql + assert "referential_constraints" in sql + + +# -- MySQL --------------------------------------------------------------------- + +class TestMySQLForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.mysql.adapter import MySQLAdapter + return MySQLAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("fk_orders_user_id", "orders", "user_id", "users", "id"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_table == "orders" + assert result[0].target_table == "users" + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + def test_with_database(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn, database="mydb") + sql = cursor.execute.call_args[0][0] + assert "table_schema = %s" in sql + + def test_without_database(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn) + sql = cursor.execute.call_args[0][0] + assert "DATABASE()" in sql + + +# -- MariaDB ------------------------------------------------------------------- + +class TestMariaDBForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.mariadb.adapter import MariaDBAdapter + return MariaDBAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("fk_orders_user_id", "orders", "user_id", "users", "id"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + + def test_with_database_uses_question_mark(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn, database="mydb") + sql = cursor.execute.call_args[0][0] + assert "table_schema = ?" in sql + assert "%s" not in sql + + +# -- SQL Server ---------------------------------------------------------------- + +class TestMSSQLForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.mssql.adapter import SQLServerAdapter + return SQLServerAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("FK_orders_users", "dbo", "orders", "user_id", "dbo", "users", "id"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_schema == "dbo" + assert result[0].target_schema == "dbo" + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + def test_sql_uses_sys_foreign_keys(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn) + sql = cursor.execute.call_args[0][0] + assert "sys.foreign_keys" in sql + + +# -- Oracle -------------------------------------------------------------------- + +class TestOracleForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.oracle.adapter import OracleAdapter + return OracleAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("FK_ORDERS_USERS", "ORDERS", "USER_ID", "USERS", "ID"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_table == "ORDERS" + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + def test_cursor_is_closed(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn) + cursor.close.assert_called_once() + + +# -- Firebird ------------------------------------------------------------------ + +class TestFirebirdForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.firebird.adapter import FirebirdAdapter + return FirebirdAdapter() + + def test_returns_fk_info_with_stripped_names(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("FK_ORDERS ", "ORDERS ", "USER_ID ", "USERS ", "ID "), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].constraint_name == "FK_ORDERS" + assert result[0].source_table == "ORDERS" + assert result[0].source_column == "USER_ID" + assert result[0].target_table == "USERS" + assert result[0].target_column == "ID" + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + +# -- Turso --------------------------------------------------------------------- + +class TestTursoForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.turso.adapter import TursoAdapter + return TursoAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + # get_tables uses conn.execute().fetchall() + tables_result = MagicMock() + fk_result = MagicMock() + mock_conn.execute.side_effect = [tables_result, fk_result] + tables_result.fetchall.return_value = [("orders",)] + fk_result.fetchall.return_value = [ + (0, 0, "users", "user_id", "id", "", "", ""), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_table == "orders" + assert result[0].target_table == "users" + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + tables_result = MagicMock() + fk_result = MagicMock() + mock_conn.execute.side_effect = [tables_result, fk_result] + tables_result.fetchall.return_value = [("orders",)] + fk_result.fetchall.return_value = [] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + +# -- Snowflake ----------------------------------------------------------------- + +class TestSnowflakeForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.snowflake.adapter import SnowflakeAdapter + return SnowflakeAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("FK_ORDERS_USERS", "PUBLIC", "ORDERS", "USER_ID", "PUBLIC", "USERS", "ID"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_schema == "PUBLIC" + + def test_with_database_uses_prefix(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn, database="MYDB") + sql = cursor.execute.call_args[0][0] + assert '"MYDB".' in sql + + +# -- Spanner ------------------------------------------------------------------- + +class TestSpannerForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.spanner.adapter import SpannerAdapter + return SpannerAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + mock_conn.autocommit = False + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("FK_Orders_Users", "Orders", "UserId", "Users", "Id"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + mock_conn.autocommit = False + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + +# -- HANA ---------------------------------------------------------------------- + +class TestHanaForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.hana.adapter import HanaAdapter + return HanaAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("FK_ORDERS_USERS", "MYSCHEMA", "ORDERS", "USER_ID", "MYSCHEMA", "USERS", "ID"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_schema == "MYSCHEMA" + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + +# -- DB2 ----------------------------------------------------------------------- + +class TestDb2ForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.db2.adapter import Db2Adapter + return Db2Adapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("FK_ORDERS", "MYSCHEMA", "ORDERS", "USER_ID", "MYSCHEMA", "USERS", "ID"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_schema == "MYSCHEMA" + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + def test_sql_uses_syscat(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn) + sql = cursor.execute.call_args[0][0] + assert "syscat.references" in sql + + +# -- Teradata ------------------------------------------------------------------ + +class TestTeradataForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.teradata.adapter import TeradataAdapter + return TeradataAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("FK_ORDERS", "ORDERS", "USER_ID", "USERS", "ID", "MYDB", "MYDB"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_schema == "MYDB" + + def test_with_database(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn, database="MYDB") + sql = cursor.execute.call_args[0][0] + assert "ChildDB = ?" in sql + + def test_sql_uses_lock_row(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn) + sql = cursor.execute.call_args[0][0] + assert "lock row for access" in sql + + +# -- Redshift ------------------------------------------------------------------ + +class TestRedshiftForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.redshift.adapter import RedshiftAdapter + return RedshiftAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [ + ("fk_orders_user_id", "public", "orders", "user_id", "public", "users", "id"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_schema == "public" + + def test_empty_result(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + result = adapter.get_foreign_keys(mock_conn) + assert result == [] + + def test_sql_excludes_system_schemas(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn) + sql = cursor.execute.call_args[0][0] + assert "pg_internal" in sql + + +# -- MotherDuck ---------------------------------------------------------------- + +class TestMotherDuckForeignKeys: + def _make_adapter(self): + from sqlit.domains.connections.providers.motherduck.adapter import MotherDuckAdapter + return MotherDuckAdapter() + + def test_returns_fk_info(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + mock_conn.execute.return_value.fetchall.return_value = [ + ("fk_orders_user_id", "main", "orders", "user_id", "main", "users", "id"), + ] + result = adapter.get_foreign_keys(mock_conn) + _assert_fk_list(result, 1) + assert result[0].source_schema == "main" + + def test_with_database_filters_by_catalog(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + mock_conn.execute.return_value.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn, database="my_db") + sql = mock_conn.execute.call_args[0][0] + assert "table_catalog" in sql + assert mock_conn.execute.call_args[0][1] == ("my_db",) + + def test_without_database_no_catalog_filter(self): + adapter = self._make_adapter() + mock_conn = MagicMock() + mock_conn.execute.return_value.fetchall.return_value = [] + adapter.get_foreign_keys(mock_conn) + sql = mock_conn.execute.call_args[0][0] + assert "table_catalog" not in sql