# nanoagent/pg_mcp_server/server.py
#!/usr/bin/env python3
"""
PostgreSQL MCP Server - Database analysis and querying tools.
Provides safe, read-only database access for analysis.
Environment variable PG_CONNECTION_STRING required.
"""
# Standard library
import asyncio
import json
import os
import re
from urllib.parse import urlparse

# MCP framework
import mcp.types as types
from mcp.server import NotificationOptions, Server
from mcp.server.models import InitializationOptions
from mcp.server.stdio import stdio_server
from mcp.types import Tool, TextContent

# Database imports
import psycopg2
from psycopg2 import sql
from psycopg2.extras import RealDictCursor
# Single shared MCP server instance; the @server.* decorators below register
# the tool handlers on it.
server = Server("postgres-analyzer")
# Track connection info for error messages
# Password-redacted "scheme://user@***:port/path" summary, written by
# get_connection(); None until the first connection attempt.
_connection_info = None
def get_connection():
    """Open a database connection from the PG_CONNECTION_STRING env var.

    Side effect: stores a password-redacted description of the target in
    the module-level ``_connection_info`` so error messages can reference
    the connection without leaking credentials.

    Returns:
        An open psycopg2 connection.

    Raises:
        ValueError: when PG_CONNECTION_STRING is unset or empty.
    """
    global _connection_info
    dsn = os.environ.get("PG_CONNECTION_STRING")
    if not dsn:
        raise ValueError("PG_CONNECTION_STRING environment variable not set")
    # Keep only a redacted summary around — never the raw password.
    parts = urlparse(dsn)
    _connection_info = "{}://{}@***:{}{}".format(
        parts.scheme, parts.username, parts.port, parts.path
    )
    return psycopg2.connect(dsn)
# Statement keywords that modify data or schema. Compiled once; matched as
# whole words so identifiers that merely contain a keyword (a column named
# "created_at", a table named "updates") are not false positives.
_FORBIDDEN_SQL = re.compile(
    r"\b(insert|update|delete|drop|create|alter|truncate|grant|revoke)\b",
    re.IGNORECASE,
)

def check_read_only(query: str) -> bool:
    """Return True if *query* contains no data/schema-modifying keyword.

    Word-boundary matching fixes the previous substring test, which rejected
    legitimate SELECTs such as ``SELECT created_at FROM events``.
    NOTE(review): this remains a heuristic, not a SQL parser — a read-only
    database role is the real safety boundary.
    """
    return _FORBIDDEN_SQL.search(query) is None
@server.list_tools()
async def handle_list_tools() -> list[Tool]:
    """Advertise the four database-analysis tools this server exposes."""
    schema_tool = Tool(
        name="get_schema",
        description="Get database schema - lists all tables and their columns",
        inputSchema={
            "type": "object",
            "properties": {
                "table_name": {
                    "type": "string",
                    "description": "Optional: specific table name. If omitted, returns all tables.",
                },
            },
        },
    )
    query_tool = Tool(
        name="execute_query",
        description="Execute a read-only SQL query and return results (max 1000 rows)",
        inputSchema={
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "SQL SELECT query to execute",
                },
                "limit": {
                    "type": "integer",
                    "description": "Maximum rows to return (default 100, max 1000)",
                    "default": 100,
                },
            },
            "required": ["query"],
        },
    )
    stats_tool = Tool(
        name="get_table_stats",
        description="Get statistics for a table: row count, column stats, sample data",
        inputSchema={
            "type": "object",
            "properties": {
                "table_name": {
                    "type": "string",
                    "description": "Table name to analyze",
                },
                "sample_size": {
                    "type": "integer",
                    "description": "Number of sample rows (default 5)",
                    "default": 5,
                },
            },
            "required": ["table_name"],
        },
    )
    column_tool = Tool(
        name="analyze_column",
        description="Analyze a specific column: distribution, nulls, unique values",
        inputSchema={
            "type": "object",
            "properties": {
                "table_name": {
                    "type": "string",
                    "description": "Table name",
                },
                "column_name": {
                    "type": "string",
                    "description": "Column name to analyze",
                },
            },
            "required": ["table_name", "column_name"],
        },
    )
    return [schema_tool, query_tool, stats_tool, column_tool]
@server.call_tool()
async def handle_call_tool(name: str, arguments: dict | None) -> list[types.TextContent]:
    """Dispatch a tool invocation to the matching ``_*`` helper.

    Opens a fresh connection per call, always closes cursor and connection,
    and converts all failures into a TextContent error payload instead of
    raising back to the MCP framework.

    Args:
        name: tool name as advertised by handle_list_tools.
        arguments: tool arguments from the client; may be None.
    """
    if arguments is None:
        arguments = {}
    # Explicit None sentinels replace the fragile `'cursor' in locals()`
    # introspection the cleanup previously relied on.
    conn = None
    cursor = None
    try:
        conn = get_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)
        if name == "get_schema":
            return _get_schema(cursor, arguments.get("table_name"))
        elif name == "execute_query":
            query = arguments.get("query", "")
            # Clamp the caller-supplied limit to the documented 1000-row cap.
            limit = min(arguments.get("limit", 100), 1000)
            return _execute_query(cursor, query, limit)
        elif name == "get_table_stats":
            return _get_table_stats(
                cursor,
                arguments.get("table_name", ""),
                arguments.get("sample_size", 5),
            )
        elif name == "analyze_column":
            return _analyze_column(
                cursor,
                arguments.get("table_name", ""),
                arguments.get("column_name", ""),
            )
        else:
            return [TextContent(type="text", text=f"Unknown tool: {name}")]
    except ValueError as e:
        # Raised by get_connection() when PG_CONNECTION_STRING is missing.
        return [TextContent(type="text", text=f"Configuration error: {str(e)}")]
    except psycopg2.Error as e:
        return [TextContent(type="text", text=f"Database error: {str(e)}")]
    except Exception as e:
        return [TextContent(type="text", text=f"Error: {str(e)}")]
    finally:
        # Close in reverse-acquisition order.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
def _get_schema(cursor, table_name: str | None) -> list[TextContent]:
    """Get database schema information.

    With *table_name*: that table's columns (type, nullability, default).
    Without: every table in the public schema plus its column list.

    Args:
        cursor: open RealDictCursor (rows come back as dicts).
        table_name: table to describe, or None for the full listing.

    Returns:
        A single TextContent with a human-readable schema report.
    """
    if table_name:
        # Pin to the public schema: filtering on table_name alone would mix
        # in columns from a same-named table in another schema.
        cursor.execute("""
            SELECT
                column_name,
                data_type,
                is_nullable,
                column_default
            FROM information_schema.columns
            WHERE table_name = %s AND table_schema = 'public'
            ORDER BY ordinal_position
        """, (table_name,))
        columns = cursor.fetchall()
        if not columns:
            return [TextContent(type="text", text=f"Table '{table_name}' not found.")]
        result = f"Table: {table_name}\n"
        result += "-" * 60 + "\n"
        for col in columns:
            nullable = "NULL" if col['is_nullable'] == 'YES' else "NOT NULL"
            default = f" DEFAULT {col['column_default']}" if col['column_default'] else ""
            result += f" {col['column_name']}: {col['data_type']} {nullable}{default}\n"
        return [TextContent(type="text", text=result)]
    else:
        # All public tables, with a per-table column count. The count subquery
        # is also schema-filtered so duplicate table names elsewhere don't
        # inflate it.
        cursor.execute("""
            SELECT
                table_name,
                (SELECT COUNT(*) FROM information_schema.columns c
                 WHERE c.table_name = t.table_name
                   AND c.table_schema = 'public') as column_count
            FROM information_schema.tables t
            WHERE table_schema = 'public'
            ORDER BY table_name
        """)
        tables = cursor.fetchall()
        if not tables:
            return [TextContent(type="text", text="No tables found in public schema.")]
        result = "Database Schema\n"
        result += "=" * 60 + "\n\n"
        for table in tables:
            result += f"📋 {table['table_name']} ({table['column_count']} columns)\n"
            # Get columns for this table (schema-pinned, same reason as above)
            cursor.execute("""
                SELECT column_name, data_type
                FROM information_schema.columns
                WHERE table_name = %s AND table_schema = 'public'
                ORDER BY ordinal_position
            """, (table['table_name'],))
            columns = cursor.fetchall()
            for col in columns:
                result += f"{col['column_name']}: {col['data_type']}\n"
            result += "\n"
        return [TextContent(type="text", text=result)]
def _execute_query(cursor, query: str, limit: int) -> list[TextContent]:
    """Execute a read-only query and render the results as a markdown table.

    Args:
        cursor: open RealDictCursor.
        query: SQL text; rejected unless check_read_only() passes.
        limit: row cap, appended as a LIMIT clause when the query has none.

    Returns:
        A single TextContent: markdown table, an error, or a no-results note.
    """
    if not check_read_only(query):
        return [TextContent(type="text",
            text="Error: Only SELECT queries are allowed for safety.")]
    # Strip a trailing semicolon so an appended LIMIT stays valid SQL.
    query = query.strip().rstrip(";")
    # Word-boundary match: an identifier such as "rate_limit" must not be
    # mistaken for an existing LIMIT clause (the old substring test was).
    if not re.search(r"\blimit\b", query, re.IGNORECASE):
        query = f"{query} LIMIT {limit}"
    cursor.execute(query)
    rows = cursor.fetchall()
    if not rows:
        return [TextContent(type="text", text="Query returned no results.")]
    # Format as markdown table; cell values truncated to 50 chars.
    columns = list(rows[0].keys())
    result = "| " + " | ".join(columns) + " |\n"
    result += "| " + " | ".join(["---"] * len(columns)) + " |\n"
    for row in rows[:limit]:
        values = [str(row.get(col, "NULL"))[:50] for col in columns]
        result += "| " + " | ".join(values) + " |\n"
    # Only reachable when the query carried its own (larger) LIMIT.
    if len(rows) > limit:
        result += f"\n... and {len(rows) - limit} more rows"
    return [TextContent(type="text", text=result)]
def _get_table_stats(cursor, table_name: str, sample_size: int) -> list[TextContent]:
    """Get comprehensive table statistics.

    Report contains: total row count, per-column null/distinct counts, and
    a small sample of rows rendered as a markdown table.

    Args:
        cursor: open RealDictCursor (rows come back as dicts).
        table_name: table in the public schema to analyze.
        sample_size: number of sample rows to show.

    Returns:
        A single TextContent with the formatted report, or a not-found note.
    """
    # Check if table exists (public schema only) so we can return a friendly
    # message instead of a database error from the COUNT(*) below.
    cursor.execute("""
        SELECT COUNT(*) as count
        FROM information_schema.tables
        WHERE table_name = %s AND table_schema = 'public'
    """, (table_name,))
    if cursor.fetchone()['count'] == 0:
        return [TextContent(type="text", text=f"Table '{table_name}' not found.")]
    result = f"📊 Table Analysis: {table_name}\n"
    result += "=" * 60 + "\n\n"
    # Row count. Identifiers cannot be bound as %s parameters, so the table
    # name is composed via sql.Identifier (safe quoting).
    cursor.execute(sql.SQL("SELECT COUNT(*) as count FROM {}").format(
        sql.Identifier(table_name)))
    row_count = cursor.fetchone()['count']
    result += f"Total Rows: {row_count:,}\n\n"
    # Column analysis
    cursor.execute("""
        SELECT column_name, data_type
        FROM information_schema.columns
        WHERE table_name = %s
        ORDER BY ordinal_position
    """, (table_name,))
    columns = cursor.fetchall()
    result += "Column Statistics:\n"
    result += "-" * 60 + "\n"
    for col in columns:
        col_name = col['column_name']
        data_type = col['data_type']
        # Get null count and distinct count. NOTE(review): this is one full
        # scan per column — can be slow on very large tables.
        cursor.execute(sql.SQL("""
            SELECT
                COUNT(*) - COUNT({col}) as null_count,
                COUNT(DISTINCT {col}) as distinct_count
            FROM {table}
        """).format(col=sql.Identifier(col_name), table=sql.Identifier(table_name)))
        stats = cursor.fetchone()
        # Guard: an empty table would otherwise divide by zero.
        null_pct = (stats['null_count'] / row_count * 100) if row_count > 0 else 0
        result += f" {col_name} ({data_type}):\n"
        result += f" • Nulls: {stats['null_count']} ({null_pct:.1f}%)\n"
        result += f" • Unique values: {stats['distinct_count']:,}\n"
    # Sample data (LIMIT value is a plain parameter, unlike the identifier).
    result += f"\n📝 Sample Data ({min(sample_size, row_count)} rows):\n"
    result += "-" * 60 + "\n"
    cursor.execute(sql.SQL("SELECT * FROM {} LIMIT %s").format(
        sql.Identifier(table_name)), (sample_size,))
    samples = cursor.fetchall()
    if samples:
        col_names = list(samples[0].keys())
        result += "| " + " | ".join(col_names) + " |\n"
        result += "| " + " | ".join(["---"] * len(col_names)) + " |\n"
        for row in samples:
            # Cell values truncated to 30 chars to keep the table readable.
            values = [str(row.get(col, "NULL"))[:30] for col in col_names]
            result += "| " + " | ".join(values) + " |\n"
    return [TextContent(type="text", text=result)]
def _analyze_column(cursor, table_name: str, column_name: str) -> list[TextContent]:
    """Deep analysis of a single column.

    Reports null/unique counts, min/max, mean and stddev when the column
    type is numeric, and the ten most frequent values.

    Args:
        cursor: open RealDictCursor.
        table_name: table containing the column.
        column_name: column to analyze.

    Returns:
        A single TextContent with the formatted analysis.
    """
    result = f"🔍 Column Analysis: {table_name}.{column_name}\n"
    result += "=" * 60 + "\n\n"
    # Basic stats — identifiers composed via sql.Identifier for safe quoting.
    cursor.execute(sql.SQL("""
        SELECT
            COUNT(*) as total,
            COUNT({col}) as non_null,
            COUNT(*) - COUNT({col}) as null_count,
            COUNT(DISTINCT {col}) as unique_count,
            MIN({col}) as min_val,
            MAX({col}) as max_val
        FROM {table}
    """).format(col=sql.Identifier(column_name), table=sql.Identifier(table_name)))
    stats = cursor.fetchone()
    total = stats['total']
    # Guard: an empty table previously caused a ZeroDivisionError here.
    null_pct = (stats['null_count'] / total * 100) if total > 0 else 0.0
    result += f"Total Rows: {stats['total']:,}\n"
    result += f"Non-Null: {stats['non_null']:,}\n"
    result += f"Null: {stats['null_count']:,} ({null_pct:.1f}%)\n"
    result += f"Unique Values: {stats['unique_count']:,}\n"
    if stats['min_val'] is not None:
        result += f"Min: {stats['min_val']}\n"
        result += f"Max: {stats['max_val']}\n"
    # Numeric stats if applicable
    cursor.execute("""
        SELECT data_type
        FROM information_schema.columns
        WHERE table_name = %s AND column_name = %s
    """, (table_name, column_name))
    type_info = cursor.fetchone()
    if type_info and any(t in type_info['data_type'].lower()
                         for t in ['int', 'float', 'double', 'decimal', 'numeric', 'real']):
        cursor.execute(sql.SQL("""
            SELECT
                AVG({col})::numeric(10,2) as avg_val,
                STDDEV({col})::numeric(10,2) as stddev_val
            FROM {table}
        """).format(col=sql.Identifier(column_name), table=sql.Identifier(table_name)))
        num_stats = cursor.fetchone()
        # `is not None`: a true average of exactly 0 is falsy and previously
        # suppressed this whole section.
        if num_stats['avg_val'] is not None:
            result += f"\n📈 Numeric Statistics:\n"
            result += f" Average: {num_stats['avg_val']}\n"
            result += f" Std Dev: {num_stats['stddev_val']}\n"
    # Top values
    cursor.execute(sql.SQL("""
        SELECT {col} as value, COUNT(*) as count
        FROM {table}
        WHERE {col} IS NOT NULL
        GROUP BY {col}
        ORDER BY count DESC
        LIMIT 10
    """).format(col=sql.Identifier(column_name), table=sql.Identifier(table_name)))
    top_values = cursor.fetchall()
    if top_values and total > 0:
        result += f"\n🏆 Top Values:\n"
        for i, row in enumerate(top_values, 1):
            pct = row['count'] / total * 100
            # str() first: values may be ints/dates, which don't support [:50].
            result += f" {i}. {str(row['value'])[:50]} ({row['count']:,}, {pct:.1f}%)\n"
    return [TextContent(type="text", text=result)]
async def main():
    """Run the analyzer as a stdio MCP server until the client disconnects."""
    async with stdio_server() as (reader, writer):
        options = InitializationOptions(
            server_name="postgres-analyzer",
            server_version="0.1.0",
            capabilities=server.get_capabilities(
                notification_options=NotificationOptions(),
                experimental_capabilities={},
            ),
        )
        await server.run(reader, writer, options)


if __name__ == "__main__":
    asyncio.run(main())