427 lines
15 KiB
Python
427 lines
15 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
PostgreSQL MCP Server - Database analysis and querying tools.
|
|
|
|
Provides safe, read-only database access for analysis.
|
|
Environment variable PG_CONNECTION_STRING required.
|
|
"""
|
|
|
|
# Standard library
import asyncio
import json
import os
import re
from urllib.parse import urlparse

# MCP framework
import mcp.types as types
from mcp.server import NotificationOptions, Server
from mcp.server.models import InitializationOptions
from mcp.server.stdio import stdio_server
from mcp.types import Tool, TextContent

# Database imports
import psycopg2
from psycopg2 import sql
from psycopg2.extras import RealDictCursor
|
|
|
|
# MCP server instance; the decorated handlers below register against it.
server = Server("postgres-analyzer")

# Redacted DSN summary (scheme://user@***:port/db), set by get_connection()
# so diagnostics can mention the target without leaking the password.
_connection_info = None
|
|
|
|
|
|
def get_connection():
    """Open a psycopg2 connection using the PG_CONNECTION_STRING env var.

    Raises:
        ValueError: when PG_CONNECTION_STRING is not set.

    Side effect: stores a password-free summary of the DSN in the
    module-level _connection_info for diagnostics.
    """
    global _connection_info

    dsn = os.environ.get("PG_CONNECTION_STRING")
    if not dsn:
        raise ValueError("PG_CONNECTION_STRING environment variable not set")

    # Keep a redacted form (host and password hidden) around for safe logging.
    pieces = urlparse(dsn)
    _connection_info = f"{pieces.scheme}://{pieces.username}@***:{pieces.port}{pieces.path}"

    return psycopg2.connect(dsn)
|
|
|
|
|
|
def check_read_only(query: str) -> bool:
    """Return True when *query* contains no data/schema-modifying keywords.

    Keywords are matched on word boundaries, so identifiers such as
    "created_at" or "updated_at" no longer trigger false rejections
    (the original substring check refused any query mentioning them).

    NOTE(review): a keyword blocklist is defense-in-depth, not a complete
    safety barrier (side-effecting functions can still mutate data); a
    read-only database role is the robust fix.
    """
    forbidden = ('insert', 'update', 'delete', 'drop', 'create', 'alter',
                 'truncate', 'grant', 'revoke')
    query_lower = query.lower()
    # \b anchors keep e.g. "updates" or "recreated" from matching "update"/"create".
    return not any(re.search(rf'\b{keyword}\b', query_lower) for keyword in forbidden)
|
|
|
|
|
|
@server.list_tools()
async def handle_list_tools() -> list[Tool]:
    """Advertise the database analysis tools this server provides."""

    def _tool(name, description, properties, required=None):
        # Assemble the JSON-schema wrapper shared by every tool definition.
        schema = {"type": "object", "properties": properties}
        if required:
            schema["required"] = required
        return Tool(name=name, description=description, inputSchema=schema)

    return [
        _tool(
            "get_schema",
            "Get database schema - lists all tables and their columns",
            {
                "table_name": {
                    "type": "string",
                    "description": "Optional: specific table name. If omitted, returns all tables."
                }
            },
        ),
        _tool(
            "execute_query",
            "Execute a read-only SQL query and return results (max 1000 rows)",
            {
                "query": {
                    "type": "string",
                    "description": "SQL SELECT query to execute"
                },
                "limit": {
                    "type": "integer",
                    "description": "Maximum rows to return (default 100, max 1000)",
                    "default": 100
                },
            },
            required=["query"],
        ),
        _tool(
            "get_table_stats",
            "Get statistics for a table: row count, column stats, sample data",
            {
                "table_name": {
                    "type": "string",
                    "description": "Table name to analyze"
                },
                "sample_size": {
                    "type": "integer",
                    "description": "Number of sample rows (default 5)",
                    "default": 5
                },
            },
            required=["table_name"],
        ),
        _tool(
            "analyze_column",
            "Analyze a specific column: distribution, nulls, unique values",
            {
                "table_name": {
                    "type": "string",
                    "description": "Table name"
                },
                "column_name": {
                    "type": "string",
                    "description": "Column name to analyze"
                },
            },
            required=["table_name", "column_name"],
        ),
    ]
|
|
|
|
|
|
@server.call_tool()
async def handle_call_tool(name: str, arguments: dict | None) -> list[types.TextContent]:
    """Dispatch a tool call to the matching database helper.

    Opens a fresh connection per call and always closes it. Errors are
    reported as text content rather than raised, so the MCP client sees
    a readable message instead of a protocol failure.
    """
    if arguments is None:
        arguments = {}

    # Explicit sentinels instead of the previous `'cursor' in locals()`
    # introspection: clearer, and safe if get_connection() raises before
    # either name is bound.
    conn = None
    cursor = None
    try:
        conn = get_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)

        if name == "get_schema":
            table_name = arguments.get("table_name")
            return _get_schema(cursor, table_name)

        elif name == "execute_query":
            query = arguments.get("query", "")
            # Clamp to the advertised hard cap of 1000 rows.
            limit = min(arguments.get("limit", 100), 1000)
            return _execute_query(cursor, query, limit)

        elif name == "get_table_stats":
            table_name = arguments.get("table_name", "")
            sample_size = arguments.get("sample_size", 5)
            return _get_table_stats(cursor, table_name, sample_size)

        elif name == "analyze_column":
            table_name = arguments.get("table_name", "")
            column_name = arguments.get("column_name", "")
            return _analyze_column(cursor, table_name, column_name)

        else:
            return [TextContent(type="text", text=f"Unknown tool: {name}")]

    except ValueError as e:
        # Raised by get_connection() when PG_CONNECTION_STRING is missing.
        return [TextContent(type="text", text=f"Configuration error: {str(e)}")]
    except psycopg2.Error as e:
        return [TextContent(type="text", text=f"Database error: {str(e)}")]
    except Exception as e:
        return [TextContent(type="text", text=f"Error: {str(e)}")]
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
|
|
|
|
|
|
def _get_schema(cursor, table_name: str | None) -> list[TextContent]:
    """Describe the public schema: one table in detail, or all tables.

    Fix: every information_schema.columns lookup now filters on
    table_schema = 'public' to match the table listing — previously a
    same-named table in another schema (e.g. pg_catalog or a tenant
    schema) could leak columns into, or inflate counts in, the output.
    """
    if table_name:
        # Column detail for a single table.
        cursor.execute("""
            SELECT
                column_name,
                data_type,
                is_nullable,
                column_default
            FROM information_schema.columns
            WHERE table_name = %s AND table_schema = 'public'
            ORDER BY ordinal_position
        """, (table_name,))
        columns = cursor.fetchall()

        if not columns:
            return [TextContent(type="text", text=f"Table '{table_name}' not found.")]

        result = f"Table: {table_name}\n"
        result += "-" * 60 + "\n"
        for col in columns:
            nullable = "NULL" if col['is_nullable'] == 'YES' else "NOT NULL"
            default = f" DEFAULT {col['column_default']}" if col['column_default'] else ""
            result += f"  {col['column_name']}: {col['data_type']} {nullable}{default}\n"

        return [TextContent(type="text", text=result)]

    else:
        # All public tables; the column-count subquery is schema-filtered
        # for consistency with the outer WHERE clause.
        cursor.execute("""
            SELECT
                table_name,
                (SELECT COUNT(*) FROM information_schema.columns c
                 WHERE c.table_name = t.table_name
                   AND c.table_schema = 'public') as column_count
            FROM information_schema.tables t
            WHERE table_schema = 'public'
            ORDER BY table_name
        """)
        tables = cursor.fetchall()

        if not tables:
            return [TextContent(type="text", text="No tables found in public schema.")]

        result = "Database Schema\n"
        result += "=" * 60 + "\n\n"
        for table in tables:
            result += f"📋 {table['table_name']} ({table['column_count']} columns)\n"

            # Columns for this table (one query per table; fine for the
            # modest table counts this tool targets).
            cursor.execute("""
                SELECT column_name, data_type
                FROM information_schema.columns
                WHERE table_name = %s AND table_schema = 'public'
                ORDER BY ordinal_position
            """, (table['table_name'],))
            columns = cursor.fetchall()
            for col in columns:
                result += f"  • {col['column_name']}: {col['data_type']}\n"
            result += "\n"

        return [TextContent(type="text", text=result)]
|
|
|
|
|
|
def _execute_query(cursor, query: str, limit: int) -> list[TextContent]:
    """Run a read-only query and render the rows as a markdown table.

    Args:
        cursor: dict-row cursor to execute against.
        query: SQL text; rejected unless check_read_only() passes.
        limit: maximum number of rows to render (caller caps at 1000).

    Fixes: the old `"limit" in query.lower()` substring test misfired on
    e.g. a column named "limits" or a string literal containing the word;
    and appending " LIMIT n" after a trailing semicolon produced invalid
    SQL. Now a word-boundary match is used and the semicolon is stripped.
    """
    if not check_read_only(query):
        return [TextContent(type="text",
                            text="Error: Only SELECT queries are allowed for safety.")]

    # Append our LIMIT only when the query has no LIMIT keyword of its own.
    trimmed = query.rstrip().rstrip(";")
    if not re.search(r"\blimit\b", trimmed, re.IGNORECASE):
        trimmed = f"{trimmed} LIMIT {limit}"

    cursor.execute(trimmed)
    rows = cursor.fetchall()

    if not rows:
        return [TextContent(type="text", text="Query returned no results.")]

    # Format as a markdown table; cell values truncated to 50 chars.
    columns = list(rows[0].keys())
    result = "| " + " | ".join(columns) + " |\n"
    result += "| " + " | ".join(["---"] * len(columns)) + " |\n"

    for row in rows[:limit]:
        values = [str(row.get(col, "NULL"))[:50] for col in columns]
        result += "| " + " | ".join(values) + " |\n"

    if len(rows) > limit:
        result += f"\n... and {len(rows) - limit} more rows"

    return [TextContent(type="text", text=result)]
|
|
|
|
|
|
def _get_table_stats(cursor, table_name: str, sample_size: int) -> list[TextContent]:
    """Summarize a table: row count, per-column null/distinct stats, samples.

    Table and column identifiers are interpolated via psycopg2.sql.Identifier
    so they are safely quoted.
    """
    # Bail out early when the table isn't in the public schema.
    cursor.execute("""
        SELECT COUNT(*) as count
        FROM information_schema.tables
        WHERE table_name = %s AND table_schema = 'public'
    """, (table_name,))
    if cursor.fetchone()['count'] == 0:
        return [TextContent(type="text", text=f"Table '{table_name}' not found.")]

    parts = [f"📊 Table Analysis: {table_name}\n", "=" * 60 + "\n\n"]

    # Total row count.
    cursor.execute(sql.SQL("SELECT COUNT(*) as count FROM {}").format(
        sql.Identifier(table_name)))
    total_rows = cursor.fetchone()['count']
    parts.append(f"Total Rows: {total_rows:,}\n\n")

    # Enumerate columns in declaration order.
    cursor.execute("""
        SELECT column_name, data_type
        FROM information_schema.columns
        WHERE table_name = %s
        ORDER BY ordinal_position
    """, (table_name,))
    column_rows = cursor.fetchall()

    parts.append("Column Statistics:\n")
    parts.append("-" * 60 + "\n")

    for column in column_rows:
        cname = column['column_name']
        ctype = column['data_type']

        # Null count and distinct count in one pass per column.
        cursor.execute(sql.SQL("""
            SELECT
                COUNT(*) - COUNT({col}) as null_count,
                COUNT(DISTINCT {col}) as distinct_count
            FROM {table}
        """).format(col=sql.Identifier(cname), table=sql.Identifier(table_name)))
        col_stats = cursor.fetchone()

        pct_null = (col_stats['null_count'] / total_rows * 100) if total_rows > 0 else 0
        parts.append(f"  {cname} ({ctype}):\n")
        parts.append(f"    • Nulls: {col_stats['null_count']} ({pct_null:.1f}%)\n")
        parts.append(f"    • Unique values: {col_stats['distinct_count']:,}\n")

    parts.append(f"\n📝 Sample Data ({min(sample_size, total_rows)} rows):\n")
    parts.append("-" * 60 + "\n")

    cursor.execute(sql.SQL("SELECT * FROM {} LIMIT %s").format(
        sql.Identifier(table_name)), (sample_size,))
    sample_rows = cursor.fetchall()

    if sample_rows:
        headers = list(sample_rows[0].keys())
        parts.append("| " + " | ".join(headers) + " |\n")
        parts.append("| " + " | ".join(["---"] * len(headers)) + " |\n")
        for row in sample_rows:
            cells = [str(row.get(col, "NULL"))[:30] for col in headers]
            parts.append("| " + " | ".join(cells) + " |\n")

    return [TextContent(type="text", text="".join(parts))]
|
|
|
|
|
|
def _analyze_column(cursor, table_name: str, column_name: str) -> list[TextContent]:
    """Deep analysis of a single column: nulls, cardinality, min/max,
    numeric stats (when the declared type is numeric) and top-10 values.

    Fixes: `row['value'][:50]` raised TypeError for non-string columns
    (now str() first); null percentage divided by zero on an empty table;
    and a truthiness test on avg_val suppressed a legitimate 0.00 average.
    """
    result = f"🔍 Column Analysis: {table_name}.{column_name}\n"
    result += "=" * 60 + "\n\n"

    # Basic stats in a single scan; identifiers quoted via sql.Identifier.
    cursor.execute(sql.SQL("""
        SELECT
            COUNT(*) as total,
            COUNT({col}) as non_null,
            COUNT(*) - COUNT({col}) as null_count,
            COUNT(DISTINCT {col}) as unique_count,
            MIN({col}) as min_val,
            MAX({col}) as max_val
        FROM {table}
    """).format(col=sql.Identifier(column_name), table=sql.Identifier(table_name)))

    stats = cursor.fetchone()

    # Guard against division by zero on an empty table.
    null_pct = (stats['null_count'] / stats['total'] * 100) if stats['total'] else 0.0

    result += f"Total Rows: {stats['total']:,}\n"
    result += f"Non-Null: {stats['non_null']:,}\n"
    result += f"Null: {stats['null_count']:,} ({null_pct:.1f}%)\n"
    result += f"Unique Values: {stats['unique_count']:,}\n"

    if stats['min_val'] is not None:
        result += f"Min: {stats['min_val']}\n"
        result += f"Max: {stats['max_val']}\n"

    # Numeric stats only when the declared type looks numeric.
    cursor.execute("""
        SELECT data_type
        FROM information_schema.columns
        WHERE table_name = %s AND column_name = %s
    """, (table_name, column_name))

    type_info = cursor.fetchone()
    if type_info and any(t in type_info['data_type'].lower()
                         for t in ['int', 'float', 'double', 'decimal', 'numeric', 'real']):
        cursor.execute(sql.SQL("""
            SELECT
                AVG({col})::numeric(10,2) as avg_val,
                STDDEV({col})::numeric(10,2) as stddev_val
            FROM {table}
        """).format(col=sql.Identifier(column_name), table=sql.Identifier(table_name)))

        num_stats = cursor.fetchone()
        # "is not None" rather than truthiness: an average of exactly 0
        # is falsy but still worth reporting.
        if num_stats['avg_val'] is not None:
            result += f"\n📈 Numeric Statistics:\n"
            result += f"  Average: {num_stats['avg_val']}\n"
            result += f"  Std Dev: {num_stats['stddev_val']}\n"

    # Ten most frequent non-null values.
    cursor.execute(sql.SQL("""
        SELECT {col} as value, COUNT(*) as count
        FROM {table}
        WHERE {col} IS NOT NULL
        GROUP BY {col}
        ORDER BY count DESC
        LIMIT 10
    """).format(col=sql.Identifier(column_name), table=sql.Identifier(table_name)))

    top_values = cursor.fetchall()
    if top_values:
        result += f"\n🏆 Top Values:\n"
        for i, row in enumerate(top_values, 1):
            # top_values non-empty implies stats['total'] > 0.
            pct = row['count'] / stats['total'] * 100
            # str() before slicing: values may be ints, dates, etc.
            result += f"  {i}. {str(row['value'])[:50]} ({row['count']:,}, {pct:.1f}%)\n"

    return [TextContent(type="text", text=result)]
|
|
|
|
|
|
async def main():
    """Run the MCP server over stdio until the client disconnects."""
    # Build the handshake options once, outside the stream context.
    init_options = InitializationOptions(
        server_name="postgres-analyzer",
        server_version="0.1.0",
        capabilities=server.get_capabilities(
            notification_options=NotificationOptions(),
            experimental_capabilities={},
        ),
    )
    async with stdio_server() as (reader, writer):
        await server.run(reader, writer, init_options)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the async MCP event loop to completion.
    asyncio.run(main())
|