biblion/
sse.rs

//! SSE (Server-Sent Events) transport for the MCP server.
//!
//! # MCP SSE Protocol
//!
//! - `GET /sse` — open the SSE stream; the first event is `endpoint`, whose
//!   data is the URL to POST messages to.
//! - `POST /messages?session_id=<id>` — send a JSON-RPC message. Any response
//!   is delivered as a `message` event on that session's SSE stream; a
//!   successful POST itself is answered with `202 Accepted`.
//!
//! # Thread safety
//!
//! `rusqlite::Connection` is not `Sync`, so a single connection cannot be
//! shared between concurrently running async tasks. Instead, each request
//! opens a fresh read-only connection inside `spawn_blocking`. Since SQLite
//! reads take <1 ms and opening a connection takes ~0.5 ms, this overhead is
//! negligible for an SSE server handling roughly 1–5 requests per second.
//!
//! # Reference
//!
//! MCP SSE spec: <https://modelcontextprotocol.io/docs/concepts/transports#server-sent-events-sse>
19
20use std::collections::HashMap;
21use std::sync::Arc;
22
23use anyhow::Result;
24use axum::extract::{Query, State};
25use axum::response::IntoResponse;
26use axum::response::sse::{Event, KeepAlive, Sse};
27use axum::routing::{get, post};
28use tokio::sync::{RwLock, mpsc};
29use tokio_stream::StreamExt;
30use tokio_stream::wrappers::ReceiverStream;
31
32use crate::config::Config;
33use crate::db::DbPool;
34use crate::protocol::JsonRpcRequest;
35use crate::server::ServerContext;
36
37/// Per-session state: an SSE sender channel.
38type SessionMap = Arc<RwLock<HashMap<String, mpsc::Sender<String>>>>;
39
40/// Shared application state (thread-safe — no rusqlite in here).
41#[derive(Clone)]
42struct AppState {
43    config: Arc<Config>,
44    sessions: SessionMap,
45}
46
47/// Run the MCP server in SSE mode (async, multi-session).
48pub fn run_sse(ctx: ServerContext, host: &str, port: u16) -> Result<()> {
49    let rt = tokio::runtime::Builder::new_multi_thread()
50        .enable_all()
51        .build()?;
52
53    rt.block_on(async {
54        let state = AppState {
55            config: Arc::new(ctx.config),
56            sessions: Arc::new(RwLock::new(HashMap::new())),
57        };
58
59        let app = axum::Router::new()
60            .route("/sse", get(handle_sse))
61            .route("/messages", post(handle_message))
62            .route("/messages/", post(handle_message))
63            .with_state(state);
64
65        let addr = format!("{host}:{port}");
66        eprintln!("[biblion] SSE server listening on http://{addr}/sse");
67
68        let listener = tokio::net::TcpListener::bind(&addr).await?;
69        axum::serve(listener, app).await?;
70
71        Ok(())
72    })
73}
74
75/// GET /sse — establish SSE connection, return event stream.
76async fn handle_sse(
77    State(state): State<AppState>,
78) -> Sse<impl tokio_stream::Stream<Item = Result<Event, std::convert::Infallible>>> {
79    let session_id = uuid::Uuid::new_v4().to_string();
80    let (tx, rx) = mpsc::channel::<String>(64);
81
82    state
83        .sessions
84        .write()
85        .await
86        .insert(session_id.clone(), tx.clone());
87
88    // Clean up session when client disconnects (sender channel closes)
89    let tx_cleanup = tx.clone();
90    let sessions_cleanup = state.sessions.clone();
91    let session_id_cleanup = session_id.clone();
92    tokio::spawn(async move {
93        tx_cleanup.closed().await;
94        sessions_cleanup.write().await.remove(&session_id_cleanup);
95        eprintln!("[biblion] Session disconnected: {session_id_cleanup}");
96    });
97
98    eprintln!("[biblion] New SSE session: {session_id}");
99
100    // Send endpoint event
101    let endpoint_url = format!("/messages?session_id={session_id}");
102    let _ = tx.send(format!("endpoint:{endpoint_url}")).await;
103
104    let stream = ReceiverStream::new(rx).map(move |msg| {
105        if let Some(url) = msg.strip_prefix("endpoint:") {
106            Ok(Event::default().event("endpoint").data(url))
107        } else {
108            Ok(Event::default().event("message").data(msg))
109        }
110    });
111
112    Sse::new(stream).keep_alive(KeepAlive::default())
113}
114
115/// POST /messages?session_id=<id> — receive JSON-RPC, respond via SSE.
116async fn handle_message(
117    State(state): State<AppState>,
118    Query(params): Query<HashMap<String, String>>,
119    body: String,
120) -> impl IntoResponse {
121    let session_id = match params.get("session_id") {
122        Some(id) => id.clone(),
123        None => return axum::http::StatusCode::BAD_REQUEST,
124    };
125
126    // Parse JSON-RPC request
127    let request: JsonRpcRequest = match serde_json::from_str(&body) {
128        Ok(r) => r,
129        Err(_) => return axum::http::StatusCode::BAD_REQUEST,
130    };
131
132    let sessions = state.sessions.read().await;
133    let tx = match sessions.get(&session_id) {
134        Some(tx) => tx.clone(),
135        None => return axum::http::StatusCode::NOT_FOUND,
136    };
137    drop(sessions);
138
139    let is_notification = request.id.is_none();
140    let config = state.config.clone();
141
142    // Process in spawn_blocking (opens fresh SQLite connections per request
143    // because rusqlite::Connection is not Send+Sync)
144    let response = tokio::task::spawn_blocking(move || {
145        let db = DbPool::open(&config.zotero_sqlite_path, &config.bbt_migrated_path);
146        let ctx = ServerContext {
147            db,
148            config: (*config).clone(),
149        };
150        // Reuse shared dispatch from server module
151        crate::server::dispatch(&request, &ctx)
152    })
153    .await;
154
155    if !is_notification && let Ok(Some(resp)) = response {
156        let json = serde_json::to_string(&resp).unwrap_or_default();
157        let _ = tx.send(json).await;
158    }
159
160    axum::http::StatusCode::ACCEPTED
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Server context with an empty `DbPool` and placeholder paths, enough to
    /// exercise protocol-level dispatch without a Zotero database.
    fn test_ctx() -> ServerContext {
        let config = Config {
            zotero_sqlite_path: "/tmp/z.sqlite".into(),
            zotero_storage_path: "/tmp/storage".into(),
            bbt_migrated_path: "/tmp/bbt".into(),
            zotero_api_key: None,
            zotero_library_id: "1".into(),
            zotero_library_type: "user".into(),
            bbt_url: "http://localhost:23119".into(),
            log_level: crate::config::LogLevel::Quiet,
            writes_enabled: false,
            resolver: paper_resolver::ResolverConfig::default(),
            zotero_api_base_url: None,
        };
        ServerContext {
            db: DbPool::empty(),
            config,
        }
    }

    /// Build a JSON-RPC 2.0 request; `id: None` makes it a notification.
    fn rpc(
        id: Option<serde_json::Value>,
        method: &'static str,
        params: serde_json::Value,
    ) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: "2.0".into(),
            id,
            method: method.into(),
            params,
        }
    }

    #[test]
    fn sse_dispatch_initialize() {
        let ctx = test_ctx();
        let req = rpc(Some(serde_json::json!(1)), "initialize", serde_json::json!({}));

        let result = crate::server::dispatch(&req, &ctx)
            .expect("initialize must produce a response")
            .result
            .expect("initialize must succeed");

        assert_eq!(result["serverInfo"]["name"], "biblion");
    }

    #[test]
    fn sse_dispatch_ping() {
        let ctx = test_ctx();
        let req = rpc(Some(serde_json::json!(2)), "ping", serde_json::json!(null));

        let resp = crate::server::dispatch(&req, &ctx).expect("ping must produce a response");

        assert!(resp.result.is_some());
    }

    #[test]
    fn sse_dispatch_notification_ignored() {
        let ctx = test_ctx();
        let req = rpc(None, "notifications/initialized", serde_json::json!(null));

        assert!(crate::server::dispatch(&req, &ctx).is_none());
    }
}