1use std::collections::HashMap;
21use std::sync::Arc;
22
23use anyhow::Result;
24use axum::extract::{Query, State};
25use axum::response::IntoResponse;
26use axum::response::sse::{Event, KeepAlive, Sse};
27use axum::routing::{get, post};
28use tokio::sync::{RwLock, mpsc};
29use tokio_stream::StreamExt;
30use tokio_stream::wrappers::ReceiverStream;
31
32use crate::config::Config;
33use crate::db::DbPool;
34use crate::protocol::JsonRpcRequest;
35use crate::server::ServerContext;
36
37type SessionMap = Arc<RwLock<HashMap<String, mpsc::Sender<String>>>>;
39
40#[derive(Clone)]
42struct AppState {
43 config: Arc<Config>,
44 sessions: SessionMap,
45}
46
47pub fn run_sse(ctx: ServerContext, host: &str, port: u16) -> Result<()> {
49 let rt = tokio::runtime::Builder::new_multi_thread()
50 .enable_all()
51 .build()?;
52
53 rt.block_on(async {
54 let state = AppState {
55 config: Arc::new(ctx.config),
56 sessions: Arc::new(RwLock::new(HashMap::new())),
57 };
58
59 let app = axum::Router::new()
60 .route("/sse", get(handle_sse))
61 .route("/messages", post(handle_message))
62 .route("/messages/", post(handle_message))
63 .with_state(state);
64
65 let addr = format!("{host}:{port}");
66 eprintln!("[biblion] SSE server listening on http://{addr}/sse");
67
68 let listener = tokio::net::TcpListener::bind(&addr).await?;
69 axum::serve(listener, app).await?;
70
71 Ok(())
72 })
73}
74
75async fn handle_sse(
77 State(state): State<AppState>,
78) -> Sse<impl tokio_stream::Stream<Item = Result<Event, std::convert::Infallible>>> {
79 let session_id = uuid::Uuid::new_v4().to_string();
80 let (tx, rx) = mpsc::channel::<String>(64);
81
82 state
83 .sessions
84 .write()
85 .await
86 .insert(session_id.clone(), tx.clone());
87
88 let tx_cleanup = tx.clone();
90 let sessions_cleanup = state.sessions.clone();
91 let session_id_cleanup = session_id.clone();
92 tokio::spawn(async move {
93 tx_cleanup.closed().await;
94 sessions_cleanup.write().await.remove(&session_id_cleanup);
95 eprintln!("[biblion] Session disconnected: {session_id_cleanup}");
96 });
97
98 eprintln!("[biblion] New SSE session: {session_id}");
99
100 let endpoint_url = format!("/messages?session_id={session_id}");
102 let _ = tx.send(format!("endpoint:{endpoint_url}")).await;
103
104 let stream = ReceiverStream::new(rx).map(move |msg| {
105 if let Some(url) = msg.strip_prefix("endpoint:") {
106 Ok(Event::default().event("endpoint").data(url))
107 } else {
108 Ok(Event::default().event("message").data(msg))
109 }
110 });
111
112 Sse::new(stream).keep_alive(KeepAlive::default())
113}
114
115async fn handle_message(
117 State(state): State<AppState>,
118 Query(params): Query<HashMap<String, String>>,
119 body: String,
120) -> impl IntoResponse {
121 let session_id = match params.get("session_id") {
122 Some(id) => id.clone(),
123 None => return axum::http::StatusCode::BAD_REQUEST,
124 };
125
126 let request: JsonRpcRequest = match serde_json::from_str(&body) {
128 Ok(r) => r,
129 Err(_) => return axum::http::StatusCode::BAD_REQUEST,
130 };
131
132 let sessions = state.sessions.read().await;
133 let tx = match sessions.get(&session_id) {
134 Some(tx) => tx.clone(),
135 None => return axum::http::StatusCode::NOT_FOUND,
136 };
137 drop(sessions);
138
139 let is_notification = request.id.is_none();
140 let config = state.config.clone();
141
142 let response = tokio::task::spawn_blocking(move || {
145 let db = DbPool::open(&config.zotero_sqlite_path, &config.bbt_migrated_path);
146 let ctx = ServerContext {
147 db,
148 config: (*config).clone(),
149 };
150 crate::server::dispatch(&request, &ctx)
152 })
153 .await;
154
155 if !is_notification && let Ok(Some(resp)) = response {
156 let json = serde_json::to_string(&resp).unwrap_or_default();
157 let _ = tx.send(json).await;
158 }
159
160 axum::http::StatusCode::ACCEPTED
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ServerContext` with an empty database, placeholder paths
    /// under `/tmp`, and writes disabled.
    fn test_ctx() -> ServerContext {
        let db = DbPool::empty();
        let config = Config {
            zotero_sqlite_path: "/tmp/z.sqlite".into(),
            zotero_storage_path: "/tmp/storage".into(),
            bbt_migrated_path: "/tmp/bbt".into(),
            zotero_api_key: None,
            zotero_library_id: "1".into(),
            zotero_library_type: "user".into(),
            bbt_url: "http://localhost:23119".into(),
            log_level: crate::config::LogLevel::Quiet,
            writes_enabled: false,
            resolver: paper_resolver::ResolverConfig::default(),
            zotero_api_base_url: None,
        };
        ServerContext { db, config }
    }

    /// Shorthand for a JSON-RPC 2.0 request with the given id, method and params.
    fn rpc(
        id: Option<serde_json::Value>,
        method: &'static str,
        params: serde_json::Value,
    ) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: "2.0".into(),
            id,
            method: method.into(),
            params,
        }
    }

    #[test]
    fn sse_dispatch_initialize() {
        let ctx = test_ctx();
        let request = rpc(Some(serde_json::json!(1)), "initialize", serde_json::json!({}));

        let reply = crate::server::dispatch(&request, &ctx)
            .expect("a request with an id must produce a reply");
        let result = reply.result.expect("initialize must succeed");

        assert_eq!(result["serverInfo"]["name"], "biblion");
    }

    #[test]
    fn sse_dispatch_ping() {
        let ctx = test_ctx();
        let request = rpc(Some(serde_json::json!(2)), "ping", serde_json::json!(null));

        let reply = crate::server::dispatch(&request, &ctx)
            .expect("a request with an id must produce a reply");

        assert!(reply.result.is_some());
    }

    #[test]
    fn sse_dispatch_notification_ignored() {
        let ctx = test_ctx();
        let notification = rpc(None, "notifications/initialized", serde_json::json!(null));

        let reply = crate::server::dispatch(&notification, &ctx);

        assert!(reply.is_none());
    }
}