1//! This module implements parts of the Model Context Protocol.
2//!
//! It handles the lifecycle messages and provides a general interface for
//! interacting with an MCP server. It uses the generic JSON-RPC client to
5//! read/write messages and the types from types.rs for serialization/deserialization
6//! of messages.
7
8use std::time::Duration;
9
10use anyhow::Result;
11use futures::channel::oneshot;
12use gpui::AsyncApp;
13use serde_json::Value;
14
15use crate::client::Client;
16use crate::types::{self, Notification, Request};
17
18pub struct ModelContextProtocol {
19 inner: Client,
20}
21
22impl ModelContextProtocol {
23 pub(crate) fn new(inner: Client) -> Self {
24 Self { inner }
25 }
26
27 fn supported_protocols() -> Vec<types::ProtocolVersion> {
28 vec![
29 types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
30 types::ProtocolVersion(types::VERSION_2024_11_05.to_string()),
31 ]
32 }
33
34 pub async fn initialize(
35 self,
36 client_info: types::Implementation,
37 ) -> Result<InitializedContextServerProtocol> {
38 let params = types::InitializeParams {
39 protocol_version: types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
40 capabilities: types::ClientCapabilities {
41 experimental: None,
42 sampling: None,
43 roots: None,
44 },
45 meta: None,
46 client_info,
47 };
48
49 let response: types::InitializeResponse = self
50 .inner
51 .request(types::requests::Initialize::METHOD, params)
52 .await?;
53
54 anyhow::ensure!(
55 Self::supported_protocols().contains(&response.protocol_version),
56 "Unsupported protocol version: {:?}",
57 response.protocol_version
58 );
59
60 log::trace!("mcp server info {:?}", response.server_info);
61
62 let initialized_protocol = InitializedContextServerProtocol {
63 inner: self.inner,
64 initialize: response,
65 };
66
67 initialized_protocol.notify::<types::notifications::Initialized>(())?;
68
69 Ok(initialized_protocol)
70 }
71}
72
73pub struct InitializedContextServerProtocol {
74 inner: Client,
75 pub initialize: types::InitializeResponse,
76}
77
/// A capability category that a server may advertise in its `initialize` response.
///
/// Query support for one with `InitializedContextServerProtocol::capable`.
/// `Eq` and `Hash` are derived so capabilities can be stored in sets or used as map keys.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum ServerCapability {
    /// Non-standard, experimental server features.
    Experimental,
    /// Server-side logging support.
    Logging,
    /// Prompt templates offered by the server.
    Prompts,
    /// Resources exposed by the server.
    Resources,
    /// Tools the server can execute.
    Tools,
}
86
87impl InitializedContextServerProtocol {
88 /// Check if the server supports a specific capability
89 pub fn capable(&self, capability: ServerCapability) -> bool {
90 match capability {
91 ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
92 ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
93 ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
94 ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
95 ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
96 }
97 }
98
99 pub async fn request<T: Request>(&self, params: T::Params) -> Result<T::Response> {
100 self.inner.request(T::METHOD, params).await
101 }
102
103 pub async fn request_with<T: Request>(
104 &self,
105 params: T::Params,
106 cancel_rx: Option<oneshot::Receiver<()>>,
107 timeout: Option<Duration>,
108 ) -> Result<T::Response> {
109 self.inner
110 .request_with(T::METHOD, params, cancel_rx, timeout)
111 .await
112 }
113
114 pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
115 self.inner.notify(T::METHOD, params)
116 }
117
118 pub fn on_notification(
119 &self,
120 method: &'static str,
121 f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
122 ) {
123 self.inner.on_notification(method, f);
124 }
125}