Allow both integer and string request IDs in LSP (#7662)

Michal Příhoda created

Zed's LSP support expects request messages to have integer ID, Metals
LSP uses string. According to specification, both is acceptable:

interface RequestMessage extends Message {

	/**
	 * The request id.
	 */
	id: integer | string;

...
This pull requests modifies the types and serialization/deserialization
so that string IDs are accepted.

Release Notes:

- Make Zed LSP request ids compliant to the LSP specification

Change summary

crates/lsp/src/lsp.rs | 79 +++++++++++++++++++++++++++++++++-----------
1 file changed, 58 insertions(+), 21 deletions(-)

Detailed changes

crates/lsp/src/lsp.rs 🔗

@@ -23,7 +23,7 @@ use std::{
     path::PathBuf,
     str::{self, FromStr as _},
     sync::{
-        atomic::{AtomicUsize, Ordering::SeqCst},
+        atomic::{AtomicI32, Ordering::SeqCst},
         Arc, Weak,
     },
     time::{Duration, Instant},
@@ -36,7 +36,7 @@ const JSON_RPC_VERSION: &str = "2.0";
 const CONTENT_LEN_HEADER: &str = "Content-Length: ";
 const LSP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60 * 2);
 
-type NotificationHandler = Box<dyn Send + FnMut(Option<usize>, &str, AsyncAppContext)>;
+type NotificationHandler = Box<dyn Send + FnMut(Option<RequestId>, &str, AsyncAppContext)>;
 type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
 type IoHandler = Box<dyn Send + FnMut(IoKind, &str)>;
 
@@ -59,14 +59,14 @@ pub struct LanguageServerBinary {
 /// A running language server process.
 pub struct LanguageServer {
     server_id: LanguageServerId,
-    next_id: AtomicUsize,
+    next_id: AtomicI32,
     outbound_tx: channel::Sender<String>,
     name: String,
     capabilities: ServerCapabilities,
     code_action_kinds: Option<Vec<CodeActionKind>>,
     notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
-    response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
-    io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>,
+    response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
+    io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
     executor: BackgroundExecutor,
     #[allow(clippy::type_complexity)]
     io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
@@ -87,18 +87,28 @@ pub enum Subscription {
         notification_handlers: Option<Arc<Mutex<HashMap<&'static str, NotificationHandler>>>>,
     },
     Io {
-        id: usize,
-        io_handlers: Option<Weak<Mutex<HashMap<usize, IoHandler>>>>,
+        id: i32,
+        io_handlers: Option<Weak<Mutex<HashMap<i32, IoHandler>>>>,
     },
 }
 
+/// Language server protocol RPC request message ID.
+///
+/// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#requestMessage)
+#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum RequestId {
+    Int(i32),
+    Str(String),
+}
+
 /// Language server protocol RPC request message.
 ///
 /// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#requestMessage)
 #[derive(Serialize, Deserialize)]
 pub struct Request<'a, T> {
     jsonrpc: &'static str,
-    id: usize,
+    id: RequestId,
     method: &'a str,
     params: T,
 }
@@ -107,7 +117,7 @@ pub struct Request<'a, T> {
 #[derive(Serialize, Deserialize)]
 struct AnyResponse<'a> {
     jsonrpc: &'a str,
-    id: usize,
+    id: RequestId,
     #[serde(default)]
     error: Option<Error>,
     #[serde(borrow)]
@@ -120,7 +130,7 @@ struct AnyResponse<'a> {
 #[derive(Serialize)]
 struct Response<T> {
     jsonrpc: &'static str,
-    id: usize,
+    id: RequestId,
     result: Option<T>,
     error: Option<Error>,
 }
@@ -140,7 +150,7 @@ struct Notification<'a, T> {
 #[derive(Debug, Clone, Deserialize)]
 struct AnyNotification<'a> {
     #[serde(default)]
-    id: Option<usize>,
+    id: Option<RequestId>,
     #[serde(borrow)]
     method: &'a str,
     #[serde(borrow, default)]
@@ -305,8 +315,8 @@ impl LanguageServer {
         stdout: Stdout,
         mut on_unhandled_notification: F,
         notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
-        response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
-        io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>,
+        response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
+        io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
         cx: AsyncAppContext,
     ) -> anyhow::Result<()>
     where
@@ -387,7 +397,7 @@ impl LanguageServer {
 
     async fn handle_stderr<Stderr>(
         stderr: Stderr,
-        io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>,
+        io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
         stderr_capture: Arc<Mutex<Option<String>>>,
     ) -> anyhow::Result<()>
     where
@@ -424,8 +434,8 @@ impl LanguageServer {
         stdin: Stdin,
         outbound_rx: channel::Receiver<String>,
         output_done_tx: barrier::Sender,
-        response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
-        io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>,
+        response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
+        io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
     ) -> anyhow::Result<()>
     where
         Stdin: AsyncWrite + Unpin + Send + 'static,
@@ -621,7 +631,7 @@ impl LanguageServer {
     pub fn shutdown(&self) -> Option<impl 'static + Send + Future<Output = Option<()>>> {
         if let Some(tasks) = self.io_tasks.lock().take() {
             let response_handlers = self.response_handlers.clone();
-            let next_id = AtomicUsize::new(self.next_id.load(SeqCst));
+            let next_id = AtomicI32::new(self.next_id.load(SeqCst));
             let outbound_tx = self.outbound_tx.clone();
             let executor = self.executor.clone();
             let mut output_done = self.output_done_rx.lock().take().unwrap();
@@ -850,8 +860,8 @@ impl LanguageServer {
     }
 
     fn request_internal<T: request::Request>(
-        next_id: &AtomicUsize,
-        response_handlers: &Mutex<Option<HashMap<usize, ResponseHandler>>>,
+        next_id: &AtomicI32,
+        response_handlers: &Mutex<Option<HashMap<RequestId, ResponseHandler>>>,
         outbound_tx: &channel::Sender<String>,
         executor: &BackgroundExecutor,
         params: T::Params,
@@ -862,7 +872,7 @@ impl LanguageServer {
         let id = next_id.fetch_add(1, SeqCst);
         let message = serde_json::to_string(&Request {
             jsonrpc: JSON_RPC_VERSION,
-            id,
+            id: RequestId::Int(id),
             method: T::METHOD,
             params,
         })
@@ -876,7 +886,7 @@ impl LanguageServer {
             .map(|handlers| {
                 let executor = executor.clone();
                 handlers.insert(
-                    id,
+                    RequestId::Int(id),
                     Box::new(move |result| {
                         executor
                             .spawn(async move {
@@ -1340,4 +1350,31 @@ mod tests {
             b"Content-Length: 1235\r\nContent-Type: application/vscode-jsonrpc\r\n\r\n"
         );
     }
+
+    #[gpui::test]
+    fn test_deserialize_string_digit_id() {
+        let json = r#"{"jsonrpc":"2.0","id":"2","method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#;
+        let notification = serde_json::from_str::<AnyNotification>(json)
+            .expect("message with string id should be parsed");
+        let expected_id = RequestId::Str("2".to_string());
+        assert_eq!(notification.id, Some(expected_id));
+    }
+
+    #[gpui::test]
+    fn test_deserialize_string_id() {
+        let json = r#"{"jsonrpc":"2.0","id":"anythingAtAll","method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#;
+        let notification = serde_json::from_str::<AnyNotification>(json)
+            .expect("message with string id should be parsed");
+        let expected_id = RequestId::Str("anythingAtAll".to_string());
+        assert_eq!(notification.id, Some(expected_id));
+    }
+
+    #[gpui::test]
+    fn test_deserialize_int_id() {
+        let json = r#"{"jsonrpc":"2.0","id":2,"method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#;
+        let notification = serde_json::from_str::<AnyNotification>(json)
+            .expect("message with string id should be parsed");
+        let expected_id = RequestId::Int(2);
+        assert_eq!(notification.id, Some(expected_id));
+    }
 }