1pub mod client;
2pub mod listener;
3pub mod protocol;
4#[cfg(any(test, feature = "test-support"))]
5pub mod test;
6pub mod transport;
7pub mod types;
8
9use std::path::Path;
10use std::sync::Arc;
11use std::{fmt::Display, path::PathBuf};
12
13use anyhow::Result;
14use client::Client;
15use gpui::AsyncApp;
16use parking_lot::RwLock;
17pub use settings::ContextServerCommand;
18
/// Identifier of a context server, held as a shared `Arc<str>` so clones only bump a refcount.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ContextServerId(pub Arc<str>);

impl Display for ContextServerId {
    /// Emits the raw identifier text. Formatter options (width, fill, alignment) are not applied.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str(&self.0)
    }
}
27
28enum ContextServerTransport {
29 Stdio(ContextServerCommand, Option<PathBuf>),
30 Custom(Arc<dyn crate::transport::Transport>),
31}
32
33pub struct ContextServer {
34 id: ContextServerId,
35 client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
36 configuration: ContextServerTransport,
37}
38
39impl ContextServer {
40 pub fn stdio(
41 id: ContextServerId,
42 command: ContextServerCommand,
43 working_directory: Option<Arc<Path>>,
44 ) -> Self {
45 Self {
46 id,
47 client: RwLock::new(None),
48 configuration: ContextServerTransport::Stdio(
49 command,
50 working_directory.map(|directory| directory.to_path_buf()),
51 ),
52 }
53 }
54
55 pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
56 Self {
57 id,
58 client: RwLock::new(None),
59 configuration: ContextServerTransport::Custom(transport),
60 }
61 }
62
63 pub fn id(&self) -> ContextServerId {
64 self.id.clone()
65 }
66
67 pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
68 self.client.read().clone()
69 }
70
71 pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
72 self.initialize(self.new_client(cx)?).await
73 }
74
75 /// Starts the context server, making sure handlers are registered before initialization happens
76 pub async fn start_with_handlers(
77 &self,
78 notification_handlers: Vec<(
79 &'static str,
80 Box<dyn 'static + Send + FnMut(serde_json::Value, AsyncApp)>,
81 )>,
82 cx: &AsyncApp,
83 ) -> Result<()> {
84 let client = self.new_client(cx)?;
85 for (method, handler) in notification_handlers {
86 client.on_notification(method, handler);
87 }
88 self.initialize(client).await
89 }
90
91 fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
92 Ok(match &self.configuration {
93 ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
94 client::ContextServerId(self.id.0.clone()),
95 client::ModelContextServerBinary {
96 executable: Path::new(&command.path).to_path_buf(),
97 args: command.args.clone(),
98 env: command.env.clone(),
99 timeout: command.timeout,
100 },
101 working_directory,
102 cx.clone(),
103 )?,
104 ContextServerTransport::Custom(transport) => Client::new(
105 client::ContextServerId(self.id.0.clone()),
106 self.id().0,
107 transport.clone(),
108 None,
109 cx.clone(),
110 )?,
111 })
112 }
113
114 async fn initialize(&self, client: Client) -> Result<()> {
115 log::debug!("starting context server {}", self.id);
116 let protocol = crate::protocol::ModelContextProtocol::new(client);
117 let client_info = types::Implementation {
118 name: "Zed".to_string(),
119 version: env!("CARGO_PKG_VERSION").to_string(),
120 };
121 let initialized_protocol = protocol.initialize(client_info).await?;
122
123 log::debug!(
124 "context server {} initialized: {:?}",
125 self.id,
126 initialized_protocol.initialize,
127 );
128
129 *self.client.write() = Some(Arc::new(initialized_protocol));
130 Ok(())
131 }
132
133 pub fn stop(&self) -> Result<()> {
134 let mut client = self.client.write();
135 if let Some(protocol) = client.take() {
136 drop(protocol);
137 }
138 Ok(())
139 }
140}