#!/usr/bin/env python3
"""
PostgreSQL MCP Server - Database analysis and querying tools.

Provides safe, read-only database access for analysis.
Environment variable PG_CONNECTION_STRING required.
"""

import asyncio
import os
import json
import re
from contextlib import closing
from urllib.parse import urlparse

from mcp.server.models import InitializationOptions
from mcp.server import NotificationOptions, Server
from mcp.server.stdio import stdio_server
from mcp.types import Tool, TextContent
import mcp.types as types

# Database imports
import psycopg2
from psycopg2 import sql
from psycopg2.extras import RealDictCursor

server = Server("postgres-analyzer")

# Track connection info for error messages
_connection_info = None

# All catalogue lookups and table references are restricted to this schema.
_PUBLIC_SCHEMA = "public"

# Row-count bounds advertised in the tool schemas.
_DEFAULT_QUERY_ROWS = 100
_MAX_QUERY_ROWS = 1000
_DEFAULT_SAMPLE_ROWS = 5

# Keywords that mark a statement as modifying. Matched as whole words so that
# identifiers such as ``created_at`` or ``updated_by`` are not rejected.
_FORBIDDEN_KEYWORDS = (
    "insert", "update", "delete", "drop", "create",
    "alter", "truncate", "grant", "revoke",
)
_FORBIDDEN_RE = re.compile(
    r"\b(?:" + "|".join(_FORBIDDEN_KEYWORDS) + r")\b", re.IGNORECASE
)
_LIMIT_RE = re.compile(r"\blimit\b", re.IGNORECASE)

# ``information_schema.columns.data_type`` values treated as numeric.
# An exact match is required: a substring test for "int" would also match
# "interval" and "point".
_NUMERIC_TYPES = frozenset({
    "smallint", "integer", "bigint", "numeric", "real", "double precision",
})


def get_connection():
    """Open a read-only database connection configured from the environment.

    Reads ``PG_CONNECTION_STRING``, records a password-free description of the
    target in the module-level ``_connection_info``, connects, and switches
    the session to read-only so the server rejects writes even if a statement
    slips past :func:`check_read_only`.

    Returns:
        An open psycopg2 connection. The caller is responsible for closing it.

    Raises:
        ValueError: If ``PG_CONNECTION_STRING`` is unset or empty.
        psycopg2.Error: If connecting or configuring the session fails.
    """
    global _connection_info
    conn_str = os.environ.get("PG_CONNECTION_STRING")
    if not conn_str:
        raise ValueError("PG_CONNECTION_STRING environment variable not set")

    # Parse for safe logging (hide password)
    parsed = urlparse(conn_str)
    if parsed.scheme:
        try:
            port = f":{parsed.port}" if parsed.port else ""
        except ValueError:
            # e.g. a multi-host URI ("h1:5432,h2:5432"); valid for libpq but
            # not parseable as a single port.
            port = ""
        user = parsed.username or ""
        _connection_info = f"{parsed.scheme}://{user}@***{port}{parsed.path}"
    else:
        # A keyword/value DSN ("host=... password=...") lands entirely in
        # parsed.path, so echoing it would leak the password.
        _connection_info = "(keyword/value DSN; details hidden)"

    conn = psycopg2.connect(conn_str)
    try:
        conn.set_session(readonly=True)
    except Exception:
        conn.close()
        raise
    return conn


def check_read_only(query: str) -> bool:
    """Return True if *query* contains none of the modifying SQL keywords.

    This is a conservative keyword screen (a forbidden word inside a string
    literal still rejects the query); the read-only session set up in
    :func:`get_connection` is the actual enforcement.
    """
    return _FORBIDDEN_RE.search(query) is None


@server.list_tools()
async def handle_list_tools() -> list[Tool]:
    """List available database analysis tools."""
    return [
        Tool(
            name="get_schema",
            description="Get database schema - lists all tables and their columns",
            inputSchema={
                "type": "object",
                "properties": {
                    "table_name": {
                        "type": "string",
                        "description": "Optional: specific table name. If omitted, returns all tables."
                    }
                }
            },
        ),
        Tool(
            name="execute_query",
            description="Execute a read-only SQL query and return results (max 1000 rows)",
            inputSchema={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "SQL SELECT query to execute"
                    },
                    "limit": {
                        "type": "integer",
                        "description": "Maximum rows to return (default 100, max 1000)",
                        "default": 100
                    }
                },
                "required": ["query"]
            },
        ),
        Tool(
            name="get_table_stats",
            description="Get statistics for a table: row count, column stats, sample data",
            inputSchema={
                "type": "object",
                "properties": {
                    "table_name": {
                        "type": "string",
                        "description": "Table name to analyze"
                    },
                    "sample_size": {
                        "type": "integer",
                        "description": "Number of sample rows (default 5)",
                        "default": 5
                    }
                },
                "required": ["table_name"]
            },
        ),
        Tool(
            name="analyze_column",
            description="Analyze a specific column: distribution, nulls, unique values",
            inputSchema={
                "type": "object",
                "properties": {
                    "table_name": {
                        "type": "string",
                        "description": "Table name"
                    },
                    "column_name": {
                        "type": "string",
                        "description": "Column name to analyze"
                    }
                },
                "required": ["table_name", "column_name"]
            },
        ),
    ]


def _clamp_int(value, default: int, lo: int, hi: int) -> int:
    """Coerce *value* to an int within [lo, hi]; use *default* if not numeric."""
    try:
        number = int(value)
    except (TypeError, ValueError, OverflowError):
        number = default
    return max(lo, min(number, hi))


# Tool name -> callable(cursor, arguments). The lambdas look the helper
# functions up at call time, so the table can precede their definitions.
_TOOL_HANDLERS = {
    "get_schema": lambda cur, args: _get_schema(cur, args.get("table_name")),
    "execute_query": lambda cur, args: _execute_query(
        cur,
        args.get("query", ""),
        _clamp_int(args.get("limit"), _DEFAULT_QUERY_ROWS, 1, _MAX_QUERY_ROWS),
    ),
    "get_table_stats": lambda cur, args: _get_table_stats(
        cur,
        args.get("table_name", ""),
        _clamp_int(args.get("sample_size"), _DEFAULT_SAMPLE_ROWS, 0, _MAX_QUERY_ROWS),
    ),
    "analyze_column": lambda cur, args: _analyze_column(
        cur, args.get("table_name", ""), args.get("column_name", "")
    ),
}


def _run_tool(name: str, arguments: dict) -> list[TextContent]:
    """Run one tool synchronously on a fresh connection, mapping errors to text."""
    handler = _TOOL_HANDLERS.get(name)
    if handler is None:
        # Reject before paying for a database connection.
        return [TextContent(type="text", text=f"Unknown tool: {name}")]

    try:
        # closing() guarantees conn.close(); psycopg2's own ``with conn``
        # only ends the transaction and leaves the connection open.
        with closing(get_connection()) as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                return handler(cursor, arguments)
    except ValueError as e:
        return [TextContent(type="text", text=f"Configuration error: {str(e)}")]
    except psycopg2.Error as e:
        return [TextContent(type="text", text=f"Database error: {str(e)}")]
    except Exception as e:
        # Top-level boundary: a tool call must always yield a text result.
        return [TextContent(type="text", text=f"Error: {str(e)}")]


@server.call_tool()
async def handle_call_tool(name: str, arguments: dict | None) -> list[types.TextContent]:
    """Execute a database tool and return its textual result.

    psycopg2 is blocking, so the work runs in a worker thread to keep the
    event loop responsive. Failures are reported as text content rather than
    raised.
    """
    return await asyncio.to_thread(_run_tool, name, arguments or {})


def _text(message: str) -> list[TextContent]:
    """Wrap *message* in the single-item list shape every tool returns."""
    return [TextContent(type="text", text=message)]


def _public_table(table_name: str) -> sql.Composed:
    """Return the safely quoted, schema-qualified identifier ``public.<table>``."""
    return sql.SQL("{}.{}").format(
        sql.Identifier(_PUBLIC_SCHEMA), sql.Identifier(table_name)
    )


def _format_cell(value, width: int) -> str:
    """Render one value for a markdown table cell.

    None becomes ``NULL``; the text is truncated to *width* characters, and
    pipes/newlines are neutralised so they cannot break the table layout.
    """
    if value is None:
        return "NULL"
    text = str(value)[:width]
    return text.replace("|", "\\|").replace("\r", " ").replace("\n", " ")


def _markdown_table(rows, columns, width: int) -> list[str]:
    """Return markdown table lines (header, separator, one line per row)."""
    lines = [
        "| " + " | ".join(columns) + " |",
        "| " + " | ".join(["---"] * len(columns)) + " |",
    ]
    for row in rows:
        cells = [_format_cell(row[col], width) for col in columns]
        lines.append("| " + " | ".join(cells) + " |")
    return lines


def _get_schema(cursor, table_name: str | None) -> list[TextContent]:
    """Describe one table, or list every table with its columns.

    Only the ``public`` schema is consulted, so same-named tables in other
    schemas are never mixed into the output.

    Args:
        cursor: Open cursor that yields dict-like rows.
        table_name: Table to describe; falsy means "all tables".
    """
    if table_name:
        # Get specific table schema
        cursor.execute(
            """
            SELECT column_name, data_type, is_nullable, column_default
            FROM information_schema.columns
            WHERE table_schema = %s AND table_name = %s
            ORDER BY ordinal_position
            """,
            (_PUBLIC_SCHEMA, table_name),
        )
        columns = cursor.fetchall()

        if not columns:
            return _text(f"Table '{table_name}' not found.")

        lines = [f"Table: {table_name}", "-" * 60]
        for col in columns:
            nullable = "NULL" if col["is_nullable"] == "YES" else "NOT NULL"
            default = f" DEFAULT {col['column_default']}" if col["column_default"] else ""
            lines.append(f"  {col['column_name']}: {col['data_type']} {nullable}{default}")
        return _text("\n".join(lines) + "\n")

    # Get all tables and their columns in one round trip. LEFT JOIN keeps
    # tables that have no visible columns.
    cursor.execute(
        """
        SELECT t.table_name, c.column_name, c.data_type
        FROM information_schema.tables t
        LEFT JOIN information_schema.columns c
               ON c.table_schema = t.table_schema
              AND c.table_name = t.table_name
        WHERE t.table_schema = %s
        ORDER BY t.table_name, c.ordinal_position
        """,
        (_PUBLIC_SCHEMA,),
    )
    rows = cursor.fetchall()

    if not rows:
        return _text("No tables found in public schema.")

    # dict preserves insertion order, i.e. the ORDER BY of the query.
    tables: dict[str, list] = {}
    for row in rows:
        table_columns = tables.setdefault(row["table_name"], [])
        if row["column_name"] is not None:
            table_columns.append(row)

    lines = ["Database Schema", "=" * 60, ""]
    for name, table_columns in tables.items():
        lines.append(f"📋 {name} ({len(table_columns)} columns)")
        lines.extend(
            f"   • {col['column_name']}: {col['data_type']}" for col in table_columns
        )
        lines.append("")
    return _text("\n".join(lines) + "\n")


def _execute_query(cursor, query: str, limit: int) -> list[TextContent]:
    """Execute a read-only query and format the rows as a markdown table.

    Args:
        cursor: Open cursor that yields dict-like rows.
        query: SQL text supplied by the caller.
        limit: Maximum number of rows to show; appended as ``LIMIT`` when the
            statement has no LIMIT clause of its own.
    """
    if not check_read_only(query):
        return _text("Error: Only SELECT queries are allowed for safety.")

    # Drop trailing semicolons/whitespace so a LIMIT clause can be appended.
    statement = query.strip().rstrip("; \t\r\n")
    if not statement:
        return _text("Error: No query provided.")

    # Add limit if not present. It goes on its own line so a trailing
    # "-- comment" in the query cannot swallow it.
    if not _LIMIT_RE.search(statement):
        statement = f"{statement}\nLIMIT {limit}"

    cursor.execute(statement)
    if cursor.description is None:
        # The statement produced no result set (e.g. SET); nothing to fetch.
        return _text("Query returned no results.")

    rows = cursor.fetchall()
    if not rows:
        return _text("Query returned no results.")

    # Format as markdown table
    columns = list(rows[0].keys())
    lines = _markdown_table(rows[:limit], columns, 50)
    result = "\n".join(lines) + "\n"

    # Only reachable when the query carried its own, larger LIMIT.
    if len(rows) > limit:
        result += f"\n... and {len(rows) - limit} more rows"

    return _text(result)


def _get_table_stats(cursor, table_name: str, sample_size: int) -> list[TextContent]:
    """Report row count, per-column null/distinct counts, and sample rows.

    Args:
        cursor: Open cursor that yields dict-like rows.
        table_name: Name of a table in the ``public`` schema.
        sample_size: Number of sample rows to show; negative values are
            treated as 0.
    """
    # Check if table exists
    cursor.execute(
        """
        SELECT COUNT(*) as count
        FROM information_schema.tables
        WHERE table_schema = %s AND table_name = %s
        """,
        (_PUBLIC_SCHEMA, table_name),
    )
    if cursor.fetchone()["count"] == 0:
        return _text(f"Table '{table_name}' not found.")

    # Qualify the table so the data comes from the same schema that was just
    # checked, regardless of the session's search_path.
    table = _public_table(table_name)
    sample_size = max(sample_size, 0)

    lines = [f"📊 Table Analysis: {table_name}", "=" * 60, ""]

    # Row count
    cursor.execute(sql.SQL("SELECT COUNT(*) as count FROM {}").format(table))
    row_count = cursor.fetchone()["count"]
    lines += [f"Total Rows: {row_count:,}", ""]

    # Column analysis
    cursor.execute(
        """
        SELECT column_name, data_type
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position
        """,
        (_PUBLIC_SCHEMA, table_name),
    )
    columns = cursor.fetchall()

    lines += ["Column Statistics:", "-" * 60]
    stats_query = sql.SQL(
        """
        SELECT
            COUNT(*) - COUNT({col}) as null_count,
            COUNT(DISTINCT {col}) as distinct_count
        FROM {table}
        """
    )
    for col in columns:
        col_name = col["column_name"]

        # Get null count and distinct count
        cursor.execute(stats_query.format(col=sql.Identifier(col_name), table=table))
        stats = cursor.fetchone()

        null_pct = (stats["null_count"] / row_count * 100) if row_count > 0 else 0
        lines += [
            f"  {col_name} ({col['data_type']}):",
            f"    • Nulls: {stats['null_count']} ({null_pct:.1f}%)",
            f"    • Unique values: {stats['distinct_count']:,}",
        ]

    # Sample data
    lines += ["", f"📝 Sample Data ({min(sample_size, row_count)} rows):", "-" * 60]
    cursor.execute(sql.SQL("SELECT * FROM {} LIMIT %s").format(table), (sample_size,))
    samples = cursor.fetchall()

    if samples:
        lines += _markdown_table(samples, list(samples[0].keys()), 30)

    return _text("\n".join(lines) + "\n")


def _analyze_column(cursor, table_name: str, column_name: str) -> list[TextContent]:
    """Deep analysis of a single column: counts, range, averages, top values.

    Args:
        cursor: Open cursor that yields dict-like rows.
        table_name: Name of a table in the ``public`` schema.
        column_name: Column of that table to analyze.
    """
    # Look the column up first: this validates both names and gives the type
    # that decides whether numeric statistics apply.
    cursor.execute(
        """
        SELECT data_type
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s AND column_name = %s
        """,
        (_PUBLIC_SCHEMA, table_name, column_name),
    )
    type_info = cursor.fetchone()
    if type_info is None:
        return _text(f"Column '{table_name}.{column_name}' not found.")

    col = sql.Identifier(column_name)
    table = _public_table(table_name)

    lines = [f"🔍 Column Analysis: {table_name}.{column_name}", "=" * 60, ""]

    # Basic stats
    cursor.execute(
        sql.SQL(
            """
            SELECT
                COUNT(*) as total,
                COUNT({col}) as non_null,
                COUNT(*) - COUNT({col}) as null_count,
                COUNT(DISTINCT {col}) as unique_count,
                MIN({col}) as min_val,
                MAX({col}) as max_val
            FROM {table}
            """
        ).format(col=col, table=table)
    )
    stats = cursor.fetchone()
    total = stats["total"]

    # An empty table must not divide by zero.
    null_pct = stats["null_count"] / total * 100 if total else 0.0
    lines += [
        f"Total Rows: {total:,}",
        f"Non-Null: {stats['non_null']:,}",
        f"Null: {stats['null_count']:,} ({null_pct:.1f}%)",
        f"Unique Values: {stats['unique_count']:,}",
    ]
    if stats["min_val"] is not None:
        lines += [f"Min: {stats['min_val']}", f"Max: {stats['max_val']}"]

    # Numeric stats if applicable
    if type_info["data_type"].lower() in _NUMERIC_TYPES:
        # ROUND(...::numeric, 2) keeps two decimals without the precision cap
        # of numeric(10,2), which overflows for values >= 10^8.
        cursor.execute(
            sql.SQL(
                """
                SELECT
                    ROUND(AVG({col})::numeric, 2) as avg_val,
                    ROUND(STDDEV({col})::numeric, 2) as stddev_val
                FROM {table}
                """
            ).format(col=col, table=table)
        )
        num_stats = cursor.fetchone()
        # "is not None" rather than truthiness: an average of 0 is valid.
        if num_stats["avg_val"] is not None:
            stddev = num_stats["stddev_val"]
            lines += [
                "",
                "📈 Numeric Statistics:",
                f"  Average: {num_stats['avg_val']}",
                # STDDEV is NULL when there is only one non-null value.
                f"  Std Dev: {stddev if stddev is not None else 'N/A'}",
            ]

    # Top values
    cursor.execute(
        sql.SQL(
            """
            SELECT {col} as value, COUNT(*) as count
            FROM {table}
            WHERE {col} IS NOT NULL
            GROUP BY {col}
            ORDER BY count DESC
            LIMIT 10
            """
        ).format(col=col, table=table)
    )
    top_values = cursor.fetchall()

    if top_values:
        lines += ["", "🏆 Top Values:"]
        for rank, row in enumerate(top_values, 1):
            pct = row["count"] / total * 100 if total else 0.0
            # str() first: values may be ints, dates, Decimals, ...
            value = str(row["value"])[:50]
            lines.append(f"  {rank}. {value} ({row['count']:,}, {pct:.1f}%)")

    return _text("\n".join(lines) + "\n")


async def main():
    """Run the MCP server over stdio until the client disconnects."""
    async with stdio_server() as (read_stream, write_stream):
        await server.run(
            read_stream,
            write_stream,
            InitializationOptions(
                server_name="postgres-analyzer",
                server_version="0.1.0",
                capabilities=server.get_capabilities(
                    notification_options=NotificationOptions(),
                    experimental_capabilities={},
                ),
            ),
        )


if __name__ == "__main__":
    asyncio.run(main())