1pub mod client;
2pub mod protocol;
3pub mod transport;
4pub mod types;
5
6use std::fmt::Display;
7use std::path::Path;
8use std::sync::Arc;
9
10use anyhow::Result;
11use client::Client;
12use collections::HashMap;
13use gpui::AsyncApp;
14use parking_lot::RwLock;
15use schemars::JsonSchema;
16use serde::{Deserialize, Serialize};
17
/// Identifier of a context server.
///
/// Wraps an `Arc<str>`, so cloning the identifier only bumps a reference count.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ContextServerId(pub Arc<str>);

impl Display for ContextServerId {
    /// Writes the raw identifier text verbatim (formatter width/fill flags are not applied).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
26
27#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
28pub struct ContextServerCommand {
29 pub path: String,
30 pub args: Vec<String>,
31 pub env: Option<HashMap<String, String>>,
32}
33
34enum ContextServerTransport {
35 Stdio(ContextServerCommand),
36 Custom(Arc<dyn crate::transport::Transport>),
37}
38
39pub struct ContextServer {
40 id: ContextServerId,
41 client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
42 configuration: ContextServerTransport,
43}
44
45impl ContextServer {
46 pub fn stdio(id: ContextServerId, command: ContextServerCommand) -> Self {
47 Self {
48 id,
49 client: RwLock::new(None),
50 configuration: ContextServerTransport::Stdio(command),
51 }
52 }
53
54 pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
55 Self {
56 id,
57 client: RwLock::new(None),
58 configuration: ContextServerTransport::Custom(transport),
59 }
60 }
61
62 pub fn id(&self) -> ContextServerId {
63 self.id.clone()
64 }
65
66 pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
67 self.client.read().clone()
68 }
69
70 pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
71 let client = match &self.configuration {
72 ContextServerTransport::Stdio(command) => Client::stdio(
73 client::ContextServerId(self.id.0.clone()),
74 client::ModelContextServerBinary {
75 executable: Path::new(&command.path).to_path_buf(),
76 args: command.args.clone(),
77 env: command.env.clone(),
78 },
79 cx.clone(),
80 )?,
81 ContextServerTransport::Custom(transport) => Client::new(
82 client::ContextServerId(self.id.0.clone()),
83 self.id().0,
84 transport.clone(),
85 cx.clone(),
86 )?,
87 };
88 self.initialize(client).await
89 }
90
91 async fn initialize(&self, client: Client) -> Result<()> {
92 log::info!("starting context server {}", self.id);
93 let protocol = crate::protocol::ModelContextProtocol::new(client);
94 let client_info = types::Implementation {
95 name: "Zed".to_string(),
96 version: env!("CARGO_PKG_VERSION").to_string(),
97 };
98 let initialized_protocol = protocol.initialize(client_info).await?;
99
100 log::debug!(
101 "context server {} initialized: {:?}",
102 self.id,
103 initialized_protocol.initialize,
104 );
105
106 *self.client.write() = Some(Arc::new(initialized_protocol));
107 Ok(())
108 }
109
110 pub fn stop(&self) -> Result<()> {
111 let mut client = self.client.write();
112 if let Some(protocol) = client.take() {
113 drop(protocol);
114 }
115 Ok(())
116 }
117}