1pub mod client;
2pub mod listener;
3pub mod protocol;
4#[cfg(any(test, feature = "test-support"))]
5pub mod test;
6pub mod transport;
7pub mod types;
8
9use std::path::Path;
10use std::sync::Arc;
11use std::{fmt::Display, path::PathBuf};
12
13use anyhow::Result;
14use client::Client;
15use collections::HashMap;
16use gpui::AsyncApp;
17use parking_lot::RwLock;
18use schemars::JsonSchema;
19use serde::{Deserialize, Serialize};
20pub use settings::ContextServerCommand;
21use util::redact::should_redact;
22
/// Identifier of a context server, wrapping a cheaply clonable shared string.
///
/// Formats with `Display` as the bare identifier text.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ContextServerId(pub Arc<str>);

impl Display for ContextServerId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Emit the raw identifier; like the `write!(f, "{}", ..)` form, this
        // does not apply any width/fill flags from the outer formatter.
        f.write_str(&self.0)
    }
}
31
/// How a [`ContextServer`] connects to its server.
enum ContextServerTransport {
    /// Start the server from a command (handed to `Client::stdio`), with an
    /// optional working directory.
    Stdio(ContextServerCommand, Option<PathBuf>),
    /// Use a caller-supplied transport implementation (handed to `Client::new`).
    Custom(Arc<dyn crate::transport::Transport>),
}
36
/// A configured context server plus the protocol handle created when it is started.
pub struct ContextServer {
    /// Identifier used in log messages and passed to the underlying client.
    id: ContextServerId,
    /// The initialized protocol: `None` until a start succeeds, and reset to
    /// `None` by `stop`. Shared out as `Arc` clones via `client()`.
    client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
    /// How to reach the server (stdio command or custom transport).
    configuration: ContextServerTransport,
}
42
43impl ContextServer {
44 pub fn stdio(
45 id: ContextServerId,
46 command: ContextServerCommand,
47 working_directory: Option<Arc<Path>>,
48 ) -> Self {
49 Self {
50 id,
51 client: RwLock::new(None),
52 configuration: ContextServerTransport::Stdio(
53 command,
54 working_directory.map(|directory| directory.to_path_buf()),
55 ),
56 }
57 }
58
59 pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
60 Self {
61 id,
62 client: RwLock::new(None),
63 configuration: ContextServerTransport::Custom(transport),
64 }
65 }
66
67 pub fn id(&self) -> ContextServerId {
68 self.id.clone()
69 }
70
71 pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
72 self.client.read().clone()
73 }
74
75 pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
76 self.initialize(self.new_client(cx)?).await
77 }
78
79 /// Starts the context server, making sure handlers are registered before initialization happens
80 pub async fn start_with_handlers(
81 &self,
82 notification_handlers: Vec<(
83 &'static str,
84 Box<dyn 'static + Send + FnMut(serde_json::Value, AsyncApp)>,
85 )>,
86 cx: &AsyncApp,
87 ) -> Result<()> {
88 let client = self.new_client(cx)?;
89 for (method, handler) in notification_handlers {
90 client.on_notification(method, handler);
91 }
92 self.initialize(client).await
93 }
94
95 fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
96 Ok(match &self.configuration {
97 ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
98 client::ContextServerId(self.id.0.clone()),
99 client::ModelContextServerBinary {
100 executable: Path::new(&command.path).to_path_buf(),
101 args: command.args.clone(),
102 env: command.env.clone(),
103 timeout: command.timeout,
104 },
105 working_directory,
106 cx.clone(),
107 )?,
108 ContextServerTransport::Custom(transport) => Client::new(
109 client::ContextServerId(self.id.0.clone()),
110 self.id().0,
111 transport.clone(),
112 None,
113 cx.clone(),
114 )?,
115 })
116 }
117
118 async fn initialize(&self, client: Client) -> Result<()> {
119 log::debug!("starting context server {}", self.id);
120 let protocol = crate::protocol::ModelContextProtocol::new(client);
121 let client_info = types::Implementation {
122 name: "Zed".to_string(),
123 version: env!("CARGO_PKG_VERSION").to_string(),
124 };
125 let initialized_protocol = protocol.initialize(client_info).await?;
126
127 log::debug!(
128 "context server {} initialized: {:?}",
129 self.id,
130 initialized_protocol.initialize,
131 );
132
133 *self.client.write() = Some(Arc::new(initialized_protocol));
134 Ok(())
135 }
136
137 pub fn stop(&self) -> Result<()> {
138 let mut client = self.client.write();
139 if let Some(protocol) = client.take() {
140 drop(protocol);
141 }
142 Ok(())
143 }
144}