1pub mod client;
2pub mod protocol;
3#[cfg(any(test, feature = "test-support"))]
4pub mod test;
5pub mod transport;
6pub mod types;
7
8use std::fmt::Display;
9use std::path::Path;
10use std::sync::Arc;
11
12use anyhow::Result;
13use client::Client;
14use collections::HashMap;
15use gpui::AsyncApp;
16use parking_lot::RwLock;
17use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19use util::redact::should_redact;
20
/// Identifier of a context server, backed by a cheaply clonable shared string.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ContextServerId(pub Arc<str>);

impl Display for ContextServerId {
    /// Writes the raw id text. Width/fill flags on the formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
29
30#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)]
31pub struct ContextServerCommand {
32 #[serde(rename = "command")]
33 pub path: String,
34 pub args: Vec<String>,
35 pub env: Option<HashMap<String, String>>,
36}
37
38impl std::fmt::Debug for ContextServerCommand {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 let filtered_env = self.env.as_ref().map(|env| {
41 env.iter()
42 .map(|(k, v)| (k, if should_redact(k) { "[REDACTED]" } else { v }))
43 .collect::<Vec<_>>()
44 });
45
46 f.debug_struct("ContextServerCommand")
47 .field("path", &self.path)
48 .field("args", &self.args)
49 .field("env", &filtered_env)
50 .finish()
51 }
52}
53
54enum ContextServerTransport {
55 Stdio(ContextServerCommand),
56 Custom(Arc<dyn crate::transport::Transport>),
57}
58
59pub struct ContextServer {
60 id: ContextServerId,
61 client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
62 configuration: ContextServerTransport,
63}
64
65impl ContextServer {
66 pub fn stdio(id: ContextServerId, command: ContextServerCommand) -> Self {
67 Self {
68 id,
69 client: RwLock::new(None),
70 configuration: ContextServerTransport::Stdio(command),
71 }
72 }
73
74 pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
75 Self {
76 id,
77 client: RwLock::new(None),
78 configuration: ContextServerTransport::Custom(transport),
79 }
80 }
81
82 pub fn id(&self) -> ContextServerId {
83 self.id.clone()
84 }
85
86 pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
87 self.client.read().clone()
88 }
89
90 pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
91 let client = match &self.configuration {
92 ContextServerTransport::Stdio(command) => Client::stdio(
93 client::ContextServerId(self.id.0.clone()),
94 client::ModelContextServerBinary {
95 executable: Path::new(&command.path).to_path_buf(),
96 args: command.args.clone(),
97 env: command.env.clone(),
98 },
99 cx.clone(),
100 )?,
101 ContextServerTransport::Custom(transport) => Client::new(
102 client::ContextServerId(self.id.0.clone()),
103 self.id().0,
104 transport.clone(),
105 cx.clone(),
106 )?,
107 };
108 self.initialize(client).await
109 }
110
111 async fn initialize(&self, client: Client) -> Result<()> {
112 log::info!("starting context server {}", self.id);
113 let protocol = crate::protocol::ModelContextProtocol::new(client);
114 let client_info = types::Implementation {
115 name: "Zed".to_string(),
116 version: env!("CARGO_PKG_VERSION").to_string(),
117 };
118 let initialized_protocol = protocol.initialize(client_info).await?;
119
120 log::debug!(
121 "context server {} initialized: {:?}",
122 self.id,
123 initialized_protocol.initialize,
124 );
125
126 *self.client.write() = Some(Arc::new(initialized_protocol));
127 Ok(())
128 }
129
130 pub fn stop(&self) -> Result<()> {
131 let mut client = self.client.write();
132 if let Some(protocol) = client.take() {
133 drop(protocol);
134 }
135 Ok(())
136 }
137}