protocol.rs

//! This module implements parts of the Model Context Protocol.
//!
//! It handles the lifecycle messages and provides a general interface for
//! interacting with an MCP server. It uses the generic JSON-RPC client to
//! read/write messages and the types from `types.rs` for serialization/deserialization
//! of messages.
  7
  8use anyhow::Result;
  9use collections::HashMap;
 10
 11use crate::client::Client;
 12use crate::types;
 13
 14pub use types::PromptInfo;
 15
 16const PROTOCOL_VERSION: u32 = 1;
 17
 18pub struct ModelContextProtocol {
 19    inner: Client,
 20}
 21
 22impl ModelContextProtocol {
 23    pub fn new(inner: Client) -> Self {
 24        Self { inner }
 25    }
 26
 27    pub async fn initialize(
 28        self,
 29        client_info: types::EntityInfo,
 30    ) -> Result<InitializedContextServerProtocol> {
 31        let params = types::InitializeParams {
 32            protocol_version: PROTOCOL_VERSION,
 33            capabilities: types::ClientCapabilities {
 34                experimental: None,
 35                sampling: None,
 36            },
 37            client_info,
 38        };
 39
 40        let response: types::InitializeResponse = self
 41            .inner
 42            .request(types::RequestType::Initialize.as_str(), params)
 43            .await?;
 44
 45        log::trace!("mcp server info {:?}", response.server_info);
 46
 47        self.inner.notify(
 48            types::NotificationType::Initialized.as_str(),
 49            serde_json::json!({}),
 50        )?;
 51
 52        let initialized_protocol = InitializedContextServerProtocol {
 53            inner: self.inner,
 54            initialize: response,
 55        };
 56
 57        Ok(initialized_protocol)
 58    }
 59}
 60
 61pub struct InitializedContextServerProtocol {
 62    inner: Client,
 63    pub initialize: types::InitializeResponse,
 64}
 65
 66#[derive(Debug, PartialEq, Clone, Copy)]
 67pub enum ServerCapability {
 68    Experimental,
 69    Logging,
 70    Prompts,
 71    Resources,
 72    Tools,
 73}
 74
 75impl InitializedContextServerProtocol {
 76    /// Check if the server supports a specific capability
 77    pub fn capable(&self, capability: ServerCapability) -> bool {
 78        match capability {
 79            ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
 80            ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
 81            ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
 82            ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
 83            ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
 84        }
 85    }
 86
 87    fn check_capability(&self, capability: ServerCapability) -> Result<()> {
 88        if self.capable(capability) {
 89            Ok(())
 90        } else {
 91            Err(anyhow::anyhow!(
 92                "Server does not support {:?} capability",
 93                capability
 94            ))
 95        }
 96    }
 97
 98    /// List the MCP prompts.
 99    pub async fn list_prompts(&self) -> Result<Vec<types::PromptInfo>> {
100        self.check_capability(ServerCapability::Prompts)?;
101
102        let response: types::PromptsListResponse = self
103            .inner
104            .request(types::RequestType::PromptsList.as_str(), ())
105            .await?;
106
107        Ok(response.prompts)
108    }
109
110    /// Executes a prompt with the given arguments and returns the result.
111    pub async fn run_prompt<P: AsRef<str>>(
112        &self,
113        prompt: P,
114        arguments: HashMap<String, String>,
115    ) -> Result<types::PromptsGetResponse> {
116        self.check_capability(ServerCapability::Prompts)?;
117
118        let params = types::PromptsGetParams {
119            name: prompt.as_ref().to_string(),
120            arguments: Some(arguments),
121        };
122
123        let response: types::PromptsGetResponse = self
124            .inner
125            .request(types::RequestType::PromptsGet.as_str(), params)
126            .await?;
127
128        Ok(response)
129    }
130}
131
132impl InitializedContextServerProtocol {
133    pub async fn request<R: serde::de::DeserializeOwned>(
134        &self,
135        method: &str,
136        params: impl serde::Serialize,
137    ) -> Result<R> {
138        self.inner.request(method, params).await
139    }
140}