@@ -40,7 +40,7 @@ pub struct LanguageServer {
name: String,
capabilities: ServerCapabilities,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
- response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
+ response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
executor: Arc<executor::Background>,
#[allow(clippy::type_complexity)]
io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
@@ -170,12 +170,18 @@ impl LanguageServer {
let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
let notification_handlers =
Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
- let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::default()));
+ let response_handlers =
+ Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
let input_task = cx.spawn(|cx| {
let notification_handlers = notification_handlers.clone();
let response_handlers = response_handlers.clone();
async move {
- let _clear_response_handlers = ClearResponseHandlers(response_handlers.clone());
+ let _clear_response_handlers = util::defer({
+ let response_handlers = response_handlers.clone();
+ move || {
+ response_handlers.lock().take();
+ }
+ });
let mut buffer = Vec::new();
loop {
buffer.clear();
@@ -200,7 +206,11 @@ impl LanguageServer {
} else if let Ok(AnyResponse { id, error, result }) =
serde_json::from_slice(&buffer)
{
- if let Some(handler) = response_handlers.lock().remove(&id) {
+ if let Some(handler) = response_handlers
+ .lock()
+ .as_mut()
+ .and_then(|handlers| handlers.remove(&id))
+ {
if let Some(error) = error {
handler(Err(error));
} else if let Some(result) = result {
@@ -226,7 +236,12 @@ impl LanguageServer {
let output_task = cx.background().spawn({
let response_handlers = response_handlers.clone();
async move {
- let _clear_response_handlers = ClearResponseHandlers(response_handlers);
+ let _clear_response_handlers = util::defer({
+ let response_handlers = response_handlers.clone();
+ move || {
+ response_handlers.lock().take();
+ }
+ });
let mut content_len_buffer = Vec::new();
while let Ok(message) = outbound_rx.recv().await {
log::trace!("outgoing message:{}", String::from_utf8_lossy(&message));
@@ -366,7 +381,7 @@ impl LanguageServer {
async move {
log::debug!("language server shutdown started");
shutdown_request.await?;
- response_handlers.lock().clear();
+ response_handlers.lock().take();
exit?;
output_done.recv().await;
log::debug!("language server shutdown finished");
@@ -521,7 +536,7 @@ impl LanguageServer {
fn request_internal<T: request::Request>(
next_id: &AtomicUsize,
- response_handlers: &Mutex<HashMap<usize, ResponseHandler>>,
+ response_handlers: &Mutex<Option<HashMap<usize, ResponseHandler>>>,
outbound_tx: &channel::Sender<Vec<u8>>,
params: T::Params,
) -> impl 'static + Future<Output = Result<T::Result>>
@@ -537,25 +552,31 @@ impl LanguageServer {
})
.unwrap();
+ let (tx, rx) = oneshot::channel();
+ let handle_response = response_handlers
+ .lock()
+ .as_mut()
+ .ok_or_else(|| anyhow!("server shut down"))
+ .map(|handlers| {
+ handlers.insert(
+ id,
+ Box::new(move |result| {
+ let response = match result {
+ Ok(response) => serde_json::from_str(response)
+ .context("failed to deserialize response"),
+ Err(error) => Err(anyhow!("{}", error.message)),
+ };
+ let _ = tx.send(response);
+ }),
+ );
+ });
+
let send = outbound_tx
.try_send(message)
.context("failed to write to language server's stdin");
- let (tx, rx) = oneshot::channel();
- response_handlers.lock().insert(
- id,
- Box::new(move |result| {
- let response = match result {
- Ok(response) => {
- serde_json::from_str(response).context("failed to deserialize response")
- }
- Err(error) => Err(anyhow!("{}", error.message)),
- };
- let _ = tx.send(response);
- }),
- );
-
async move {
+ handle_response?;
send?;
rx.await?
}
@@ -762,14 +783,6 @@ impl FakeLanguageServer {
}
}
-struct ClearResponseHandlers(Arc<Mutex<HashMap<usize, ResponseHandler>>>);
-
-impl Drop for ClearResponseHandlers {
- fn drop(&mut self) {
- self.0.lock().clear();
- }
-}
-
#[cfg(test)]
mod tests {
use super::*;