//! This module implements parts of the Model Context Protocol (MCP).
//!
//! It handles the lifecycle messages (the initialization handshake) and provides
//! a general interface for interacting with an MCP server. It uses the generic
//! JSON-RPC client to read/write messages, and the types from `types.rs` to
//! serialize and deserialize them.
7
8use anyhow::Result;
9use futures::channel::oneshot;
10use gpui::AsyncApp;
11use serde_json::Value;
12
13use crate::client::Client;
14use crate::types::{self, Notification, Request};
15
16pub struct ModelContextProtocol {
17 inner: Client,
18}
19
20impl ModelContextProtocol {
21 pub(crate) fn new(inner: Client) -> Self {
22 Self { inner }
23 }
24
25 fn supported_protocols() -> Vec<types::ProtocolVersion> {
26 vec![
27 types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
28 types::ProtocolVersion(types::VERSION_2024_11_05.to_string()),
29 ]
30 }
31
32 pub async fn initialize(
33 self,
34 client_info: types::Implementation,
35 ) -> Result<InitializedContextServerProtocol> {
36 let params = types::InitializeParams {
37 protocol_version: types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
38 capabilities: types::ClientCapabilities {
39 experimental: None,
40 sampling: None,
41 roots: None,
42 },
43 meta: None,
44 client_info,
45 };
46
47 let response: types::InitializeResponse = self
48 .inner
49 .request(types::requests::Initialize::METHOD, params)
50 .await?;
51
52 anyhow::ensure!(
53 Self::supported_protocols().contains(&response.protocol_version),
54 "Unsupported protocol version: {:?}",
55 response.protocol_version
56 );
57
58 log::trace!("mcp server info {:?}", response.server_info);
59
60 let initialized_protocol = InitializedContextServerProtocol {
61 inner: self.inner,
62 initialize: response,
63 };
64
65 initialized_protocol.notify::<types::notifications::Initialized>(())?;
66
67 Ok(initialized_protocol)
68 }
69}
70
71pub struct InitializedContextServerProtocol {
72 inner: Client,
73 pub initialize: types::InitializeResponse,
74}
75
/// A capability an MCP server may declare in its `initialize` response.
///
/// Each variant corresponds to one optional field of the server's declared
/// capabilities; a capability is considered present when that field is set
/// (see `InitializedContextServerProtocol::capable`).
///
/// `Eq` and `Hash` are derived so capabilities can be stored in sets or used as
/// map keys.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum ServerCapability {
    /// The `experimental` capabilities field.
    Experimental,
    /// The `logging` capabilities field.
    Logging,
    /// The `prompts` capabilities field.
    Prompts,
    /// The `resources` capabilities field.
    Resources,
    /// The `tools` capabilities field.
    Tools,
}
84
85impl InitializedContextServerProtocol {
86 /// Check if the server supports a specific capability
87 pub fn capable(&self, capability: ServerCapability) -> bool {
88 match capability {
89 ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
90 ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
91 ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
92 ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
93 ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
94 }
95 }
96
97 pub async fn request<T: Request>(&self, params: T::Params) -> Result<T::Response> {
98 self.inner.request(T::METHOD, params).await
99 }
100
101 pub async fn cancellable_request<T: Request>(
102 &self,
103 params: T::Params,
104 cancel_rx: oneshot::Receiver<()>,
105 ) -> Result<T::Response> {
106 self.inner
107 .cancellable_request(T::METHOD, params, cancel_rx)
108 .await
109 }
110
111 pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
112 self.inner.notify(T::METHOD, params)
113 }
114
115 pub fn on_notification<F>(&self, method: &'static str, f: F)
116 where
117 F: 'static + Send + FnMut(Value, AsyncApp),
118 {
119 self.inner.on_notification(method, f);
120 }
121}