protocol.rs

//! This module implements parts of the Model Context Protocol (MCP).
//!
//! It handles the lifecycle messages and provides a general interface for
//! interacting with an MCP server. It uses the generic JSON-RPC client to
//! read and write messages, and the types from `types.rs` to serialize and
//! deserialize those messages.
  7
  8use anyhow::Result;
  9use collections::HashMap;
 10
 11use crate::client::Client;
 12use crate::types;
 13
 14const PROTOCOL_VERSION: u32 = 1;
 15
 16pub struct ModelContextProtocol {
 17    inner: Client,
 18}
 19
 20impl ModelContextProtocol {
 21    pub fn new(inner: Client) -> Self {
 22        Self { inner }
 23    }
 24
 25    pub async fn initialize(
 26        self,
 27        client_info: types::Implementation,
 28    ) -> Result<InitializedContextServerProtocol> {
 29        let params = types::InitializeParams {
 30            protocol_version: PROTOCOL_VERSION,
 31            capabilities: types::ClientCapabilities {
 32                experimental: None,
 33                sampling: None,
 34            },
 35            client_info,
 36        };
 37
 38        let response: types::InitializeResponse = self
 39            .inner
 40            .request(types::RequestType::Initialize.as_str(), params)
 41            .await?;
 42
 43        log::trace!("mcp server info {:?}", response.server_info);
 44
 45        self.inner.notify(
 46            types::NotificationType::Initialized.as_str(),
 47            serde_json::json!({}),
 48        )?;
 49
 50        let initialized_protocol = InitializedContextServerProtocol {
 51            inner: self.inner,
 52            initialize: response,
 53        };
 54
 55        Ok(initialized_protocol)
 56    }
 57}
 58
 59pub struct InitializedContextServerProtocol {
 60    inner: Client,
 61    pub initialize: types::InitializeResponse,
 62}
 63
 64#[derive(Debug, PartialEq, Clone, Copy)]
 65pub enum ServerCapability {
 66    Experimental,
 67    Logging,
 68    Prompts,
 69    Resources,
 70    Tools,
 71}
 72
 73impl InitializedContextServerProtocol {
 74    /// Check if the server supports a specific capability
 75    pub fn capable(&self, capability: ServerCapability) -> bool {
 76        match capability {
 77            ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
 78            ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
 79            ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
 80            ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
 81            ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
 82        }
 83    }
 84
 85    fn check_capability(&self, capability: ServerCapability) -> Result<()> {
 86        if self.capable(capability) {
 87            Ok(())
 88        } else {
 89            Err(anyhow::anyhow!(
 90                "Server does not support {:?} capability",
 91                capability
 92            ))
 93        }
 94    }
 95
 96    /// List the MCP prompts.
 97    pub async fn list_prompts(&self) -> Result<Vec<types::Prompt>> {
 98        self.check_capability(ServerCapability::Prompts)?;
 99
100        let response: types::PromptsListResponse = self
101            .inner
102            .request(types::RequestType::PromptsList.as_str(), ())
103            .await?;
104
105        Ok(response.prompts)
106    }
107
108    /// List the MCP resources.
109    pub async fn list_resources(&self) -> Result<types::ResourcesListResponse> {
110        self.check_capability(ServerCapability::Resources)?;
111
112        let response: types::ResourcesListResponse = self
113            .inner
114            .request(types::RequestType::ResourcesList.as_str(), ())
115            .await?;
116
117        Ok(response)
118    }
119
120    /// Executes a prompt with the given arguments and returns the result.
121    pub async fn run_prompt<P: AsRef<str>>(
122        &self,
123        prompt: P,
124        arguments: HashMap<String, String>,
125    ) -> Result<types::PromptsGetResponse> {
126        self.check_capability(ServerCapability::Prompts)?;
127
128        let params = types::PromptsGetParams {
129            name: prompt.as_ref().to_string(),
130            arguments: Some(arguments),
131        };
132
133        let response: types::PromptsGetResponse = self
134            .inner
135            .request(types::RequestType::PromptsGet.as_str(), params)
136            .await?;
137
138        Ok(response)
139    }
140
141    pub async fn completion<P: Into<String>>(
142        &self,
143        reference: types::CompletionReference,
144        argument: P,
145        value: P,
146    ) -> Result<types::Completion> {
147        let params = types::CompletionCompleteParams {
148            r#ref: reference,
149            argument: types::CompletionArgument {
150                name: argument.into(),
151                value: value.into(),
152            },
153        };
154        let result: types::CompletionCompleteResponse = self
155            .inner
156            .request(types::RequestType::CompletionComplete.as_str(), params)
157            .await?;
158
159        let completion = types::Completion {
160            values: result.completion.values,
161            total: types::CompletionTotal::from_options(
162                result.completion.has_more,
163                result.completion.total,
164            ),
165        };
166
167        Ok(completion)
168    }
169}
170
171impl InitializedContextServerProtocol {
172    pub async fn request<R: serde::de::DeserializeOwned>(
173        &self,
174        method: &str,
175        params: impl serde::Serialize,
176    ) -> Result<R> {
177        self.inner.request(method, params).await
178    }
179}