1pub mod client;
2pub mod protocol;
3#[cfg(any(test, feature = "test-support"))]
4pub mod test;
5pub mod transport;
6pub mod types;
7
8use std::fmt::Display;
9use std::path::Path;
10use std::sync::Arc;
11
12use anyhow::Result;
13use client::Client;
14use collections::HashMap;
15use gpui::AsyncApp;
16use parking_lot::RwLock;
17use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19
/// Identifier of a context server, stored as a shared `Arc<str>` so clones are
/// a reference-count bump rather than a string copy.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ContextServerId(pub Arc<str>);

impl Display for ContextServerId {
    /// Writes the raw identifier text. Width/fill flags of the outer format
    /// spec are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
28
// Command used to launch a context server over stdio (consumed by
// `ContextServer::stdio`).
// NOTE: plain `//` comments are used on purpose: `///` doc comments on a
// `JsonSchema`-derived type would be emitted into the generated schema.
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
pub struct ContextServerCommand {
    // Executable to run; converted to a filesystem path when the server starts.
    pub path: String,
    // Arguments forwarded as `args` to the stdio client binary description.
    pub args: Vec<String>,
    // Optional environment variables forwarded as `env` to the stdio client.
    // NOTE(review): whether these extend or replace the inherited environment
    // is decided by the client implementation — confirm there.
    pub env: Option<HashMap<String, String>>,
}
35
/// Source of the client connection that a [`ContextServer`] builds when started.
enum ContextServerTransport {
    /// Build a stdio client from this command (see `ContextServer::stdio`).
    Stdio(ContextServerCommand),
    /// Build a client on top of a caller-supplied transport (see `ContextServer::new`).
    Custom(Arc<dyn crate::transport::Transport>),
}
40
/// A context server identified by `id`. Its protocol client is created by
/// `start` and cleared by `stop`.
pub struct ContextServer {
    id: ContextServerId,
    /// Initialized protocol client; `None` until `start` succeeds and after `stop`.
    client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
    /// Transport chosen at construction time; read by `start` to build the client.
    configuration: ContextServerTransport,
}
46
47impl ContextServer {
48 pub fn stdio(id: ContextServerId, command: ContextServerCommand) -> Self {
49 Self {
50 id,
51 client: RwLock::new(None),
52 configuration: ContextServerTransport::Stdio(command),
53 }
54 }
55
56 pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
57 Self {
58 id,
59 client: RwLock::new(None),
60 configuration: ContextServerTransport::Custom(transport),
61 }
62 }
63
64 pub fn id(&self) -> ContextServerId {
65 self.id.clone()
66 }
67
68 pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
69 self.client.read().clone()
70 }
71
72 pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
73 let client = match &self.configuration {
74 ContextServerTransport::Stdio(command) => Client::stdio(
75 client::ContextServerId(self.id.0.clone()),
76 client::ModelContextServerBinary {
77 executable: Path::new(&command.path).to_path_buf(),
78 args: command.args.clone(),
79 env: command.env.clone(),
80 },
81 cx.clone(),
82 )?,
83 ContextServerTransport::Custom(transport) => Client::new(
84 client::ContextServerId(self.id.0.clone()),
85 self.id().0,
86 transport.clone(),
87 cx.clone(),
88 )?,
89 };
90 self.initialize(client).await
91 }
92
93 async fn initialize(&self, client: Client) -> Result<()> {
94 log::info!("starting context server {}", self.id);
95 let protocol = crate::protocol::ModelContextProtocol::new(client);
96 let client_info = types::Implementation {
97 name: "Zed".to_string(),
98 version: env!("CARGO_PKG_VERSION").to_string(),
99 };
100 let initialized_protocol = protocol.initialize(client_info).await?;
101
102 log::debug!(
103 "context server {} initialized: {:?}",
104 self.id,
105 initialized_protocol.initialize,
106 );
107
108 *self.client.write() = Some(Arc::new(initialized_protocol));
109 Ok(())
110 }
111
112 pub fn stop(&self) -> Result<()> {
113 let mut client = self.client.write();
114 if let Some(protocol) = client.take() {
115 drop(protocol);
116 }
117 Ok(())
118 }
119}