1use std::io::{BufRead, Write};
28
29use anyhow::Result;
30use serde_json::json;
31
32use crate::config::{Config, LogLevel};
33use crate::db::DbPool;
34use crate::protocol::{
35 InitializeResult, JsonRpcRequest, JsonRpcResponse, ServerCapabilities, ServerInfo,
36 ToolCallParams, ToolsCapability, ToolsListResult,
37};
38use crate::tools;
39
/// Instructions text returned to clients in the `initialize` response
/// (`InitializeResult::instructions`).
///
/// Embedded at compile time from `MCP_INSTRUCTIONS.md`, resolved relative to
/// this source file's directory (one level up).
const SERVER_INSTRUCTIONS: &str = include_str!("../MCP_INSTRUCTIONS.md");
48
/// State shared by every request handler in this module.
///
/// `run_stdio` borrows it for the whole session and passes it down to
/// `dispatch` and the individual `handle_*` functions.
pub struct ServerContext {
    /// Database handle. Not read directly in this file — presumably consumed by
    /// the tool implementations reached through `tools::handle_tool_call`
    /// (TODO confirm).
    pub db: DbPool,
    /// Runtime configuration. Within this file only `log_level` is read (by `log`).
    pub config: Config,
}
57
58pub fn run_stdio(ctx: &ServerContext) -> Result<()> {
63 let stdin = std::io::stdin();
64 let stdout = std::io::stdout();
65 let mut reader = stdin.lock();
66 let mut writer = stdout.lock();
67
68 log(
69 ctx,
70 LogLevel::Info,
71 "Zotero MCP server started (Rust, stdio)",
72 );
73
74 let mut line = String::new();
75 loop {
76 line.clear();
77 let bytes_read = reader.read_line(&mut line)?;
78 if bytes_read == 0 {
79 log(ctx, LogLevel::Info, "Client disconnected (EOF)");
80 break;
81 }
82
83 let trimmed = line.trim();
84 if trimmed.is_empty() {
85 continue;
86 }
87
88 log(ctx, LogLevel::Debug, &format!("< {trimmed}"));
89
90 let request: JsonRpcRequest = match serde_json::from_str(trimmed) {
91 Ok(r) => r,
92 Err(e) => {
93 let resp = JsonRpcResponse::error(None, -32700, format!("Parse error: {e}"));
94 write_response(&mut writer, &resp, ctx)?;
95 continue;
96 }
97 };
98
99 let is_notification = request.id.is_none();
100 let response = dispatch(&request, ctx);
101
102 if !is_notification && let Some(resp) = response {
103 write_response(&mut writer, &resp, ctx)?;
104 }
105 }
106
107 Ok(())
108}
109
110pub(crate) fn dispatch(request: &JsonRpcRequest, ctx: &ServerContext) -> Option<JsonRpcResponse> {
113 match request.method.as_str() {
114 "initialize" => Some(handle_initialize(request)),
115 "notifications/initialized" => {
116 log(ctx, LogLevel::Debug, "Client initialized");
117 None
118 }
119 "tools/list" => Some(handle_tools_list(request)),
120 "tools/call" => Some(handle_tools_call(request, ctx)),
121 "ping" => Some(JsonRpcResponse::success(request.id.clone(), json!({}))),
122 _ => {
123 if request.id.is_some() {
124 Some(JsonRpcResponse::method_not_found(
125 request.id.clone(),
126 &request.method,
127 ))
128 } else {
129 None
130 }
131 }
132 }
133}
134
135fn handle_initialize(request: &JsonRpcRequest) -> JsonRpcResponse {
136 let result = InitializeResult {
137 protocol_version: "2024-11-05".into(),
138 capabilities: ServerCapabilities {
139 tools: ToolsCapability {
140 list_changed: Some(false),
141 },
142 },
143 server_info: ServerInfo {
144 name: "biblion".into(),
145 version: env!("CARGO_PKG_VERSION").into(),
146 },
147 instructions: Some(SERVER_INSTRUCTIONS.into()),
148 };
149
150 JsonRpcResponse::success(
151 request.id.clone(),
152 serde_json::to_value(result).unwrap_or_default(),
153 )
154}
155
156fn handle_tools_list(request: &JsonRpcRequest) -> JsonRpcResponse {
157 let catalog = tools::tool_catalog();
158 let result = ToolsListResult { tools: catalog };
159 JsonRpcResponse::success(
160 request.id.clone(),
161 serde_json::to_value(result).unwrap_or_default(),
162 )
163}
164
165fn handle_tools_call(request: &JsonRpcRequest, ctx: &ServerContext) -> JsonRpcResponse {
166 let params: ToolCallParams = match serde_json::from_value(request.params.clone()) {
167 Ok(p) => p,
168 Err(e) => {
169 return JsonRpcResponse::error(
170 request.id.clone(),
171 -32602,
172 format!("Invalid params: {e}"),
173 );
174 }
175 };
176
177 log(
178 ctx,
179 LogLevel::Debug,
180 &format!("Tool call: {} args={}", params.name, params.arguments),
181 );
182
183 let result = tools::handle_tool_call(¶ms.name, ¶ms.arguments, ctx);
184
185 JsonRpcResponse::success(
186 request.id.clone(),
187 serde_json::to_value(result).unwrap_or_default(),
188 )
189}
190
191fn write_response(
192 writer: &mut impl Write,
193 response: &JsonRpcResponse,
194 ctx: &ServerContext,
195) -> Result<()> {
196 let json = serde_json::to_string(response)?;
197 log(ctx, LogLevel::Debug, &format!("> {json}"));
198 writeln!(writer, "{json}")?;
199 writer.flush()?;
200 Ok(())
201}
202
203pub fn log(ctx: &ServerContext, level: LogLevel, message: &str) {
204 if level_enabled(ctx.config.log_level, level) {
205 let prefix = match level {
206 LogLevel::Quiet => "",
207 LogLevel::Info => "[biblion] ",
208 LogLevel::Debug => "[biblion:debug] ",
209 };
210 eprintln!("{prefix}{message}");
211 }
212}
213
214fn level_enabled(configured: LogLevel, requested: LogLevel) -> bool {
215 match configured {
216 LogLevel::Quiet => false,
217 LogLevel::Info => matches!(requested, LogLevel::Info),
218 LogLevel::Debug => matches!(requested, LogLevel::Info | LogLevel::Debug),
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::JsonRpcRequest;

    /// Context with an empty DB pool, placeholder paths, quiet logging and
    /// writes disabled.
    fn test_ctx() -> ServerContext {
        ServerContext {
            db: DbPool::empty(),
            config: Config {
                zotero_sqlite_path: "/tmp/nonexistent.sqlite".into(),
                zotero_storage_path: "/tmp/storage".into(),
                bbt_migrated_path: "/tmp/nonexistent.migrated".into(),
                zotero_api_key: None,
                zotero_library_id: "1".into(),
                zotero_library_type: "user".into(),
                bbt_url: "http://localhost:23119/better-bibtex/json-rpc".into(),
                log_level: LogLevel::Quiet,
                writes_enabled: false,
                resolver: paper_resolver::ResolverConfig::default(),
                zotero_api_base_url: None,
            },
        }
    }

    /// Builds a JSON-RPC 2.0 message for `method` with the given `id` and `params`.
    fn request(
        id: Option<serde_json::Value>,
        method: &str,
        params: serde_json::Value,
    ) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: "2.0".into(),
            id,
            method: method.into(),
            params,
        }
    }

    /// Dispatches `req` against a fresh test context.
    fn run(req: &JsonRpcRequest) -> Option<JsonRpcResponse> {
        dispatch(req, &test_ctx())
    }

    #[test]
    fn dispatch_initialize() {
        let resp = run(&request(Some(json!(1)), "initialize", json!({}))).unwrap();
        let result = resp.result.unwrap();
        assert_eq!(result["protocolVersion"], "2024-11-05");
        assert_eq!(result["serverInfo"]["name"], "biblion");
    }

    #[test]
    fn dispatch_ping() {
        let resp = run(&request(Some(json!(2)), "ping", json!(null))).unwrap();
        assert!(resp.result.is_some());
        assert!(resp.error.is_none());
    }

    #[test]
    fn dispatch_tools_list() {
        let resp = run(&request(Some(json!(3)), "tools/list", json!({}))).unwrap();
        let result = resp.result.unwrap();
        let listed = result["tools"].as_array().unwrap();
        assert!(!listed.is_empty());
    }

    #[test]
    fn dispatch_notification_returns_none() {
        let req = request(None, "notifications/initialized", json!(null));
        assert!(run(&req).is_none());
    }

    #[test]
    fn dispatch_unknown_method_returns_error() {
        let resp = run(&request(Some(json!(99)), "bogus/method", json!(null))).unwrap();
        assert!(resp.error.is_some());
        assert_eq!(resp.error.unwrap().code, -32601);
    }

    #[test]
    fn dispatch_unknown_notification_ignored() {
        let req = request(None, "bogus/notification", json!(null));
        assert!(run(&req).is_none());
    }

    #[test]
    fn dispatch_tools_call_invalid_params() {
        let req = request(Some(json!(10)), "tools/call", json!("not an object"));
        let resp = run(&req).unwrap();
        assert!(resp.error.is_some());
        assert_eq!(resp.error.unwrap().code, -32602);
    }

    #[test]
    fn level_quiet_blocks_all() {
        assert!(!level_enabled(LogLevel::Quiet, LogLevel::Info));
        assert!(!level_enabled(LogLevel::Quiet, LogLevel::Debug));
    }

    #[test]
    fn level_info_passes_info_only() {
        assert!(level_enabled(LogLevel::Info, LogLevel::Info));
        assert!(!level_enabled(LogLevel::Info, LogLevel::Debug));
    }

    #[test]
    fn level_debug_passes_both() {
        assert!(level_enabled(LogLevel::Debug, LogLevel::Info));
        assert!(level_enabled(LogLevel::Debug, LogLevel::Debug));
    }
}