//! This module implements parts of the Model Context Protocol.
//!
//! It handles the lifecycle messages and provides a general interface for
//! interacting with an MCP server. It uses the generic JSON-RPC client to
//! read/write messages and the types from `types.rs` for serialization/deserialization
//! of messages.
7
8use anyhow::Result;
9use gpui::AsyncApp;
10use serde_json::Value;
11
12use crate::client::Client;
13use crate::types::{self, Notification, Request};
14
15pub struct ModelContextProtocol {
16 inner: Client,
17}
18
19impl ModelContextProtocol {
20 pub(crate) fn new(inner: Client) -> Self {
21 Self { inner }
22 }
23
24 fn supported_protocols() -> Vec<types::ProtocolVersion> {
25 vec![
26 types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
27 types::ProtocolVersion(types::VERSION_2024_11_05.to_string()),
28 ]
29 }
30
31 pub async fn initialize(
32 self,
33 client_info: types::Implementation,
34 ) -> Result<InitializedContextServerProtocol> {
35 let params = types::InitializeParams {
36 protocol_version: types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
37 capabilities: types::ClientCapabilities {
38 experimental: None,
39 sampling: None,
40 roots: None,
41 },
42 meta: None,
43 client_info,
44 };
45
46 let response: types::InitializeResponse = self
47 .inner
48 .request(types::requests::Initialize::METHOD, params)
49 .await?;
50
51 anyhow::ensure!(
52 Self::supported_protocols().contains(&response.protocol_version),
53 "Unsupported protocol version: {:?}",
54 response.protocol_version
55 );
56
57 log::trace!("mcp server info {:?}", response.server_info);
58
59 let initialized_protocol = InitializedContextServerProtocol {
60 inner: self.inner,
61 initialize: response,
62 };
63
64 initialized_protocol.notify::<types::notifications::Initialized>(())?;
65
66 Ok(initialized_protocol)
67 }
68}
69
70pub struct InitializedContextServerProtocol {
71 inner: Client,
72 pub initialize: types::InitializeResponse,
73}
74
/// Capability groups that can be queried on an initialized server.
///
/// Each variant maps to one optional field of the `capabilities` value in the
/// server's `initialize` response; see
/// [`InitializedContextServerProtocol::capable`].
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so the enum can be used
/// as a key in `HashSet`/`HashMap`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ServerCapability {
    /// The `experimental` capabilities field.
    Experimental,
    /// The `logging` capabilities field.
    Logging,
    /// The `prompts` capabilities field.
    Prompts,
    /// The `resources` capabilities field.
    Resources,
    /// The `tools` capabilities field.
    Tools,
}
83
84impl InitializedContextServerProtocol {
85 /// Check if the server supports a specific capability
86 pub fn capable(&self, capability: ServerCapability) -> bool {
87 match capability {
88 ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
89 ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
90 ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
91 ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
92 ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
93 }
94 }
95
96 pub async fn request<T: Request>(&self, params: T::Params) -> Result<T::Response> {
97 self.inner.request(T::METHOD, params).await
98 }
99
100 pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
101 self.inner.notify(T::METHOD, params)
102 }
103
104 pub fn on_notification<F>(&self, method: &'static str, f: F)
105 where
106 F: 'static + Send + FnMut(Value, AsyncApp),
107 {
108 self.inner.on_notification(method, f);
109 }
110}