Detailed changes
@@ -358,6 +358,7 @@ dependencies = [
"clock",
"collections",
"command_palette_hooks",
+ "context_servers",
"ctor",
"db",
"editor",
@@ -2668,6 +2669,27 @@ dependencies = [
"tiny-keccak",
]
+[[package]]
+name = "context_servers"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "collections",
+ "futures 0.3.30",
+ "gpui",
+ "log",
+ "parking_lot",
+ "postage",
+ "schemars",
+ "serde",
+ "serde_json",
+ "settings",
+ "smol",
+ "url",
+ "util",
+ "workspace",
+]
+
[[package]]
name = "convert_case"
version = "0.4.0"
@@ -19,6 +19,7 @@ members = [
"crates/collections",
"crates/command_palette",
"crates/command_palette_hooks",
+ "crates/context_servers",
"crates/copilot",
"crates/db",
"crates/dev_server_projects",
@@ -189,6 +190,7 @@ collab_ui = { path = "crates/collab_ui" }
collections = { path = "crates/collections" }
command_palette = { path = "crates/command_palette" }
command_palette_hooks = { path = "crates/command_palette_hooks" }
+context_servers = { path = "crates/context_servers" }
copilot = { path = "crates/copilot" }
db = { path = "crates/db" }
dev_server_projects = { path = "crates/dev_server_projects" }
@@ -1010,5 +1010,16 @@
// ]
// }
// ]
- "ssh_connections": null
+ "ssh_connections": null,
+ // Configures the Context Server Protocol binaries
+ //
+ // Examples:
+ // {
+ // "id": "server-1",
+ // "executable": "/path",
+ // "args": ['arg1", "args2"]
+ // }
+ "experimental.context_servers": {
+ "servers": []
+ }
}
@@ -33,6 +33,7 @@ clock.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
db.workspace = true
+context_servers.workspace = true
editor.workspace = true
feature_flags.workspace = true
fs.workspace = true
@@ -21,9 +21,11 @@ use assistant_slash_command::SlashCommandRegistry;
use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter;
pub use context::*;
+use context_servers::ContextServerRegistry;
pub use context_store::*;
use feature_flags::FeatureFlagAppExt;
use fs::Fs;
+use gpui::Context as _;
use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
use indexed_docs::IndexedDocsRegistry;
pub(crate) use inline_assistant::*;
@@ -37,9 +39,9 @@ use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
use serde::{Deserialize, Serialize};
use settings::{update_settings_file, Settings, SettingsStore};
use slash_command::{
- default_command, diagnostics_command, docs_command, fetch_command, file_command, now_command,
- project_command, prompt_command, search_command, symbols_command, tab_command,
- terminal_command, workflow_command,
+ context_server_command, default_command, diagnostics_command, docs_command, fetch_command,
+ file_command, now_command, project_command, prompt_command, search_command, symbols_command,
+ tab_command, terminal_command, workflow_command,
};
use std::sync::Arc;
pub(crate) use streaming_diff::*;
@@ -221,6 +223,7 @@ pub fn init(
init_language_model_settings(cx);
assistant_slash_command::init(cx);
assistant_panel::init(cx);
+ context_servers::init(cx);
let prompt_builder = prompts::PromptBuilder::new(Some(PromptOverrideContext {
dev_mode,
@@ -261,9 +264,69 @@ pub fn init(
})
.detach();
+ register_context_server_handlers(cx);
+
prompt_builder
}
+fn register_context_server_handlers(cx: &mut AppContext) {
+ cx.subscribe(
+ &context_servers::manager::ContextServerManager::global(cx),
+ |manager, event, cx| match event {
+ context_servers::manager::Event::ServerStarted { server_id } => {
+ cx.update_model(
+ &manager,
+ |manager: &mut context_servers::manager::ContextServerManager, cx| {
+ let slash_command_registry = SlashCommandRegistry::global(cx);
+ let context_server_registry = ContextServerRegistry::global(cx);
+ if let Some(server) = manager.get_server(server_id) {
+ cx.spawn(|_, _| async move {
+ let Some(protocol) = server.client.read().clone() else {
+ return;
+ };
+
+ if let Some(prompts) = protocol.list_prompts().await.log_err() {
+ for prompt in prompts
+ .into_iter()
+ .filter(context_server_command::acceptable_prompt)
+ {
+ log::info!(
+ "registering context server command: {:?}",
+ prompt.name
+ );
+ context_server_registry.register_command(
+ server.id.clone(),
+ prompt.name.as_str(),
+ );
+ slash_command_registry.register_command(
+ context_server_command::ContextServerSlashCommand::new(
+ &server, prompt,
+ ),
+ true,
+ );
+ }
+ }
+ })
+ .detach();
+ }
+ },
+ );
+ }
+ context_servers::manager::Event::ServerStopped { server_id } => {
+ let slash_command_registry = SlashCommandRegistry::global(cx);
+ let context_server_registry = ContextServerRegistry::global(cx);
+ if let Some(commands) = context_server_registry.get_commands(server_id) {
+ for command_name in commands {
+ slash_command_registry.unregister_command_by_name(&command_name);
+ context_server_registry.unregister_command(&server_id, &command_name);
+ }
+ }
+ }
+ },
+ )
+ .detach();
+}
+
fn init_language_model_settings(cx: &mut AppContext) {
update_active_language_model_from_settings(cx);
@@ -18,6 +18,7 @@ use std::{
use ui::ActiveTheme;
use workspace::Workspace;
+pub mod context_server_command;
pub mod default_command;
pub mod diagnostics_command;
pub mod docs_command;
@@ -0,0 +1,125 @@
+use anyhow::{anyhow, Result};
+use assistant_slash_command::{
+ ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
+};
+use collections::HashMap;
+use context_servers::{
+ manager::{ContextServer, ContextServerManager},
+ protocol::PromptInfo,
+};
+use gpui::{Task, WeakView, WindowContext};
+use language::LspAdapterDelegate;
+use std::sync::atomic::AtomicBool;
+use std::sync::Arc;
+use ui::{IconName, SharedString};
+use workspace::Workspace;
+
+pub struct ContextServerSlashCommand {
+ server_id: String,
+ prompt: PromptInfo,
+}
+
+impl ContextServerSlashCommand {
+ pub fn new(server: &Arc<ContextServer>, prompt: PromptInfo) -> Self {
+ Self {
+ server_id: server.id.clone(),
+ prompt,
+ }
+ }
+}
+
+impl SlashCommand for ContextServerSlashCommand {
+ fn name(&self) -> String {
+ self.prompt.name.clone()
+ }
+
+ fn description(&self) -> String {
+ format!("Run context server command: {}", self.prompt.name)
+ }
+
+ fn menu_text(&self) -> String {
+ format!("Run '{}' from {}", self.prompt.name, self.server_id)
+ }
+
+ fn requires_argument(&self) -> bool {
+ self.prompt
+ .arguments
+ .as_ref()
+ .map_or(false, |args| !args.is_empty())
+ }
+
+ fn complete_argument(
+ self: Arc<Self>,
+ _arguments: &[String],
+ _cancel: Arc<AtomicBool>,
+ _workspace: Option<WeakView<Workspace>>,
+ _cx: &mut WindowContext,
+ ) -> Task<Result<Vec<ArgumentCompletion>>> {
+ Task::ready(Ok(Vec::new()))
+ }
+
+ fn run(
+ self: Arc<Self>,
+ arguments: &[String],
+ _workspace: WeakView<Workspace>,
+ _delegate: Option<Arc<dyn LspAdapterDelegate>>,
+ cx: &mut WindowContext,
+ ) -> Task<Result<SlashCommandOutput>> {
+ let server_id = self.server_id.clone();
+ let prompt_name = self.prompt.name.clone();
+ let argument = arguments.first().cloned();
+
+ let manager = ContextServerManager::global(cx);
+ let manager = manager.read(cx);
+ if let Some(server) = manager.get_server(&server_id) {
+ cx.foreground_executor().spawn(async move {
+ let Some(protocol) = server.client.read().clone() else {
+ return Err(anyhow!("Context server not initialized"));
+ };
+
+ let result = protocol
+ .run_prompt(&prompt_name, prompt_arguments(&self.prompt, argument)?)
+ .await?;
+
+ Ok(SlashCommandOutput {
+ sections: vec![SlashCommandOutputSection {
+ range: 0..result.len(),
+ icon: IconName::ZedAssistant,
+ label: SharedString::from(format!("Result from {}", prompt_name)),
+ }],
+ text: result,
+ run_commands_in_text: false,
+ })
+ })
+ } else {
+ Task::ready(Err(anyhow!("Context server not found")))
+ }
+ }
+}
+
+fn prompt_arguments(
+ prompt: &PromptInfo,
+ argument: Option<String>,
+) -> Result<HashMap<String, String>> {
+ match &prompt.arguments {
+ Some(args) if args.len() >= 2 => Err(anyhow!(
+ "Prompt has more than one argument, which is not supported"
+ )),
+ Some(args) if args.len() == 1 => match argument {
+ Some(value) => Ok(HashMap::from_iter([(args[0].name.clone(), value)])),
+ None => Err(anyhow!("Prompt expects argument but none given")),
+ },
+ Some(_) | None => Ok(HashMap::default()),
+ }
+}
+
+/// MCP servers can return prompts with multiple arguments. Since we only
+/// support one argument, we ignore all others. This is the necessary predicate
+/// for this.
+pub fn acceptable_prompt(prompt: &PromptInfo) -> bool {
+ match &prompt.arguments {
+ None => true,
+ Some(args) if args.len() == 1 => true,
+ _ => false,
+ }
+}
@@ -58,10 +58,14 @@ impl SlashCommandRegistry {
/// Unregisters the provided [`SlashCommand`].
pub fn unregister_command(&self, command: impl SlashCommand) {
+ self.unregister_command_by_name(command.name().as_str())
+ }
+
+ /// Unregisters the command with the given name.
+ pub fn unregister_command_by_name(&self, command_name: &str) {
let mut state = self.state.write();
- let command_name: Arc<str> = command.name().into();
- state.featured_commands.remove(&command_name);
- state.commands.remove(&command_name);
+ state.featured_commands.remove(command_name);
+ state.commands.remove(command_name);
}
/// Returns the names of registered [`SlashCommand`]s.
@@ -0,0 +1,29 @@
+[package]
+name = "context_servers"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/context_servers.rs"
+
+[dependencies]
+anyhow.workspace = true
+collections.workspace = true
+futures.workspace = true
+gpui.workspace = true
+log.workspace = true
+parking_lot.workspace = true
+postage.workspace = true
+schemars.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+settings.workspace = true
+smol.workspace = true
+url = { workspace = true, features = ["serde"] }
+util.workspace = true
+workspace.workspace = true
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,432 @@
+use anyhow::{anyhow, Context, Result};
+use collections::HashMap;
+use futures::{channel::oneshot, io::BufWriter, select, AsyncRead, AsyncWrite, FutureExt};
+use gpui::{AsyncAppContext, BackgroundExecutor, Task};
+use parking_lot::Mutex;
+use postage::barrier;
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
+use serde_json::{value::RawValue, Value};
+use smol::{
+ channel,
+ io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
+ process::{self, Child},
+};
+use std::{
+ fmt,
+ path::PathBuf,
+ sync::{
+ atomic::{AtomicI32, Ordering::SeqCst},
+ Arc,
+ },
+ time::{Duration, Instant},
+};
+use util::TryFutureExt;
+
+const JSON_RPC_VERSION: &str = "2.0";
+const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
+
+type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
+type NotificationHandler = Box<dyn Send + FnMut(RequestId, Value, AsyncAppContext)>;
+
+#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum RequestId {
+ Int(i32),
+ Str(String),
+}
+
+pub struct Client {
+ server_id: ContextServerId,
+ next_id: AtomicI32,
+ outbound_tx: channel::Sender<String>,
+ name: Arc<str>,
+ notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
+ response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
+ #[allow(clippy::type_complexity)]
+ #[allow(dead_code)]
+ io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
+ #[allow(dead_code)]
+ output_done_rx: Mutex<Option<barrier::Receiver>>,
+ executor: BackgroundExecutor,
+ server: Arc<Mutex<Option<Child>>>,
+}
+
+#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
+#[repr(transparent)]
+pub struct ContextServerId(pub String);
+
+#[derive(Serialize, Deserialize)]
+struct Request<'a, T> {
+ jsonrpc: &'static str,
+ id: RequestId,
+ method: &'a str,
+ params: T,
+}
+
+#[derive(Serialize, Deserialize)]
+struct AnyResponse<'a> {
+ jsonrpc: &'a str,
+ id: RequestId,
+ #[serde(default)]
+ error: Option<Error>,
+ #[serde(borrow)]
+ result: Option<&'a RawValue>,
+}
+
+#[derive(Deserialize)]
+#[allow(dead_code)]
+struct Response<T> {
+ jsonrpc: &'static str,
+ id: RequestId,
+ #[serde(flatten)]
+ value: CspResult<T>,
+}
+
+#[derive(Deserialize)]
+#[serde(rename_all = "snake_case")]
+enum CspResult<T> {
+ #[serde(rename = "result")]
+ Ok(Option<T>),
+ #[allow(dead_code)]
+ Error(Option<Error>),
+}
+
+#[derive(Serialize, Deserialize)]
+struct Notification<'a, T> {
+ jsonrpc: &'static str,
+ id: RequestId,
+ #[serde(borrow)]
+ method: &'a str,
+ params: T,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+struct AnyNotification<'a> {
+ jsonrpc: &'a str,
+ id: RequestId,
+ method: String,
+ #[serde(default)]
+ params: Option<Value>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct Error {
+ message: String,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct ModelContextServerBinary {
+ pub executable: PathBuf,
+ pub args: Vec<String>,
+ pub env: Option<HashMap<String, String>>,
+}
+
+impl Client {
+ /// Creates a new Client instance for a context server.
+ ///
+ /// This function initializes a new Client by spawning a child process for the context server,
+ /// setting up communication channels, and initializing handlers for input/output operations.
+ /// It takes a server ID, binary information, and an async app context as input.
+ pub fn new(
+ server_id: ContextServerId,
+ binary: ModelContextServerBinary,
+ cx: AsyncAppContext,
+ ) -> Result<Self> {
+ log::info!(
+ "starting context server (executable={:?}, args={:?})",
+ binary.executable,
+ &binary.args
+ );
+
+ let mut command = process::Command::new(&binary.executable);
+ command
+ .args(&binary.args)
+ .envs(binary.env.unwrap_or_default())
+ .stdin(std::process::Stdio::piped())
+ .stdout(std::process::Stdio::piped())
+ .stderr(std::process::Stdio::piped())
+ .kill_on_drop(true);
+
+ let mut server = command.spawn().with_context(|| {
+ format!(
+ "failed to spawn command. (path={:?}, args={:?})",
+ binary.executable, &binary.args
+ )
+ })?;
+
+ let stdin = server.stdin.take().unwrap();
+ let stdout = server.stdout.take().unwrap();
+ let stderr = server.stderr.take().unwrap();
+
+ let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
+ let (output_done_tx, output_done_rx) = barrier::channel();
+
+ let notification_handlers =
+ Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
+ let response_handlers =
+ Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
+
+ let stdout_input_task = cx.spawn({
+ let notification_handlers = notification_handlers.clone();
+ let response_handlers = response_handlers.clone();
+ move |cx| {
+ Self::handle_input(stdout, notification_handlers, response_handlers, cx).log_err()
+ }
+ });
+ let stderr_input_task = cx.spawn(|_| Self::handle_stderr(stderr).log_err());
+ let input_task = cx.spawn(|_| async move {
+ let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task);
+ stdout.or(stderr)
+ });
+ let output_task = cx.background_executor().spawn({
+ Self::handle_output(
+ stdin,
+ outbound_rx,
+ output_done_tx,
+ response_handlers.clone(),
+ )
+ .log_err()
+ });
+
+ let mut context_server = Self {
+ server_id,
+ notification_handlers,
+ response_handlers,
+ name: "".into(),
+ next_id: Default::default(),
+ outbound_tx,
+ executor: cx.background_executor().clone(),
+ io_tasks: Mutex::new(Some((input_task, output_task))),
+ output_done_rx: Mutex::new(Some(output_done_rx)),
+ server: Arc::new(Mutex::new(Some(server))),
+ };
+
+ if let Some(name) = binary.executable.file_name() {
+ context_server.name = name.to_string_lossy().into();
+ }
+
+ Ok(context_server)
+ }
+
+ /// Handles input from the server's stdout.
+ ///
+ /// This function continuously reads lines from the provided stdout stream,
+ /// parses them as JSON-RPC responses or notifications, and dispatches them
+ /// to the appropriate handlers. It processes both responses (which are matched
+ /// to pending requests) and notifications (which trigger registered handlers).
+ async fn handle_input<Stdout>(
+ stdout: Stdout,
+ notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
+ response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
+ cx: AsyncAppContext,
+ ) -> anyhow::Result<()>
+ where
+ Stdout: AsyncRead + Unpin + Send + 'static,
+ {
+ let mut stdout = BufReader::new(stdout);
+ let mut buffer = String::new();
+
+ loop {
+ buffer.clear();
+ if stdout.read_line(&mut buffer).await? == 0 {
+ return Ok(());
+ }
+
+ let content = buffer.trim();
+
+ if !content.is_empty() {
+ if let Ok(response) = serde_json::from_str::<AnyResponse>(&content) {
+ if let Some(handlers) = response_handlers.lock().as_mut() {
+ if let Some(handler) = handlers.remove(&response.id) {
+ handler(Ok(content.to_string()));
+ }
+ }
+ } else if let Ok(notification) = serde_json::from_str::<AnyNotification>(&content) {
+ let mut notification_handlers = notification_handlers.lock();
+ if let Some(handler) =
+ notification_handlers.get_mut(notification.method.as_str())
+ {
+ handler(
+ notification.id,
+ notification.params.unwrap_or(Value::Null),
+ cx.clone(),
+ );
+ }
+ }
+ }
+
+ smol::future::yield_now().await;
+ }
+ }
+
+ /// Handles the stderr output from the context server.
+ /// Continuously reads and logs any error messages from the server.
+ async fn handle_stderr<Stderr>(stderr: Stderr) -> anyhow::Result<()>
+ where
+ Stderr: AsyncRead + Unpin + Send + 'static,
+ {
+ let mut stderr = BufReader::new(stderr);
+ let mut buffer = String::new();
+
+ loop {
+ buffer.clear();
+ if stderr.read_line(&mut buffer).await? == 0 {
+ return Ok(());
+ }
+ log::warn!("context server stderr: {}", buffer.trim());
+ smol::future::yield_now().await;
+ }
+ }
+
+ /// Handles the output to the context server's stdin.
+ /// This function continuously receives messages from the outbound channel,
+ /// writes them to the server's stdin, and manages the lifecycle of response handlers.
+ async fn handle_output<Stdin>(
+ stdin: Stdin,
+ outbound_rx: channel::Receiver<String>,
+ output_done_tx: barrier::Sender,
+ response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
+ ) -> anyhow::Result<()>
+ where
+ Stdin: AsyncWrite + Unpin + Send + 'static,
+ {
+ let mut stdin = BufWriter::new(stdin);
+ let _clear_response_handlers = util::defer({
+ let response_handlers = response_handlers.clone();
+ move || {
+ response_handlers.lock().take();
+ }
+ });
+ while let Ok(message) = outbound_rx.recv().await {
+ log::trace!("outgoing message: {}", message);
+
+ stdin.write_all(message.as_bytes()).await?;
+ stdin.write_all(b"\n").await?;
+ stdin.flush().await?;
+ }
+ drop(output_done_tx);
+ Ok(())
+ }
+
+ /// Sends a JSON-RPC request to the context server and waits for a response.
+ /// This function handles serialization, deserialization, timeout, and error handling.
+ pub async fn request<T: DeserializeOwned>(
+ &self,
+ method: &str,
+ params: impl Serialize,
+ ) -> Result<T> {
+ let id = self.next_id.fetch_add(1, SeqCst);
+ let request = serde_json::to_string(&Request {
+ jsonrpc: JSON_RPC_VERSION,
+ id: RequestId::Int(id),
+ method,
+ params,
+ })
+ .unwrap();
+
+ let (tx, rx) = oneshot::channel();
+ let handle_response = self
+ .response_handlers
+ .lock()
+ .as_mut()
+ .ok_or_else(|| anyhow!("server shut down"))
+ .map(|handlers| {
+ handlers.insert(
+ RequestId::Int(id),
+ Box::new(move |result| {
+ let _ = tx.send(result);
+ }),
+ );
+ });
+
+ let send = self
+ .outbound_tx
+ .try_send(request)
+ .context("failed to write to context server's stdin");
+
+ let executor = self.executor.clone();
+ let started = Instant::now();
+ handle_response?;
+ send?;
+
+ let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse();
+ select! {
+ response = rx.fuse() => {
+ let elapsed = started.elapsed();
+ log::trace!("took {elapsed:?} to receive response to {method:?} id {id}");
+ match response? {
+ Ok(response) => {
+ let parsed: AnyResponse = serde_json::from_str(&response)?;
+ if let Some(error) = parsed.error {
+ Err(anyhow!(error.message))
+ } else if let Some(result) = parsed.result {
+ Ok(serde_json::from_str(result.get())?)
+ } else {
+ Err(anyhow!("Invalid response: no result or error"))
+ }
+ }
+ Err(_) => anyhow::bail!("cancelled")
+ }
+ }
+ _ = timeout => {
+ log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT);
+ anyhow::bail!("Context server request timeout");
+ }
+ }
+ }
+
+ /// Sends a notification to the context server without expecting a response.
+ /// This function serializes the notification and sends it through the outbound channel.
+ pub fn notify(&self, method: &str, params: impl Serialize) -> Result<()> {
+ let id = self.next_id.fetch_add(1, SeqCst);
+ let notification = serde_json::to_string(&Notification {
+ jsonrpc: JSON_RPC_VERSION,
+ id: RequestId::Int(id),
+ method,
+ params,
+ })
+ .unwrap();
+ self.outbound_tx.try_send(notification)?;
+ Ok(())
+ }
+
+ pub fn on_notification<F>(&self, method: &'static str, mut f: F)
+ where
+ F: 'static + Send + FnMut(Value, AsyncAppContext),
+ {
+ self.notification_handlers
+ .lock()
+ .insert(method, Box::new(move |_, params, cx| f(params, cx)));
+ }
+
+ pub fn name(&self) -> &str {
+ &self.name
+ }
+
+ pub fn server_id(&self) -> ContextServerId {
+ self.server_id.clone()
+ }
+}
+
+impl Drop for Client {
+ fn drop(&mut self) {
+ if let Some(mut server) = self.server.lock().take() {
+ let _ = server.kill();
+ }
+ }
+}
+
+impl fmt::Display for ContextServerId {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ self.0.fmt(f)
+ }
+}
+
+impl fmt::Debug for Client {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("Context Server Client")
+ .field("id", &self.server_id.0)
+ .field("name", &self.name)
+ .finish_non_exhaustive()
+ }
+}
@@ -0,0 +1,36 @@
+use gpui::{actions, AppContext, Context, ViewContext};
+use log;
+use manager::ContextServerManager;
+use workspace::Workspace;
+
+pub mod client;
+pub mod manager;
+pub mod protocol;
+mod registry;
+pub mod types;
+
+pub use registry::*;
+
+actions!(context_servers, [Restart]);
+
+pub fn init(cx: &mut AppContext) {
+ log::info!("initializing context server client");
+ manager::init(cx);
+ ContextServerRegistry::register(cx);
+
+ cx.observe_new_views(
+ |workspace: &mut Workspace, _cx: &mut ViewContext<Workspace>| {
+ workspace.register_action(restart_servers);
+ },
+ )
+ .detach();
+}
+
+fn restart_servers(_workspace: &mut Workspace, _action: &Restart, cx: &mut ViewContext<Workspace>) {
+ let model = ContextServerManager::global(&cx);
+ cx.update_model(&model, |manager, cx| {
+ for server in manager.servers() {
+ manager.restart_server(&server.id, cx).detach();
+ }
+ });
+}
@@ -0,0 +1,278 @@
+//! This module implements a context server management system for Zed.
+//!
+//! It provides functionality to:
+//! - Define and load context server settings
+//! - Manage individual context servers (start, stop, restart)
+//! - Maintain a global manager for all context servers
+//!
+//! Key components:
+//! - `ContextServerSettings`: Defines the structure for server configurations
+//! - `ContextServer`: Represents an individual context server
+//! - `ContextServerManager`: Manages multiple context servers
+//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
+//!
+//! The module also includes initialization logic to set up the context server system
+//! and react to changes in settings.
+
+use collections::{HashMap, HashSet};
+use gpui::{AppContext, AsyncAppContext, Context, EventEmitter, Global, Model, ModelContext, Task};
+use log;
+use parking_lot::RwLock;
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use settings::{Settings, SettingsSources, SettingsStore};
+use std::path::Path;
+use std::sync::Arc;
+
+use crate::{
+ client::{self, Client},
+ types,
+};
+
+#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
+pub struct ContextServerSettings {
+ pub servers: Vec<ServerConfig>,
+}
+
+#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
+pub struct ServerConfig {
+ pub id: String,
+ pub executable: String,
+ pub args: Vec<String>,
+}
+
+impl Settings for ContextServerSettings {
+ const KEY: Option<&'static str> = Some("experimental.context_servers");
+
+ type FileContent = Self;
+
+ fn load(
+ sources: SettingsSources<Self::FileContent>,
+ _: &mut gpui::AppContext,
+ ) -> anyhow::Result<Self> {
+ sources.json_merge()
+ }
+}
+
+pub struct ContextServer {
+ pub id: String,
+ pub config: ServerConfig,
+ pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
+}
+
+impl ContextServer {
+ fn new(config: ServerConfig) -> Self {
+ Self {
+ id: config.id.clone(),
+ config,
+ client: RwLock::new(None),
+ }
+ }
+
+ async fn start(&self, cx: &AsyncAppContext) -> anyhow::Result<()> {
+ log::info!("starting context server {}", self.config.id);
+ let client = Client::new(
+ client::ContextServerId(self.config.id.clone()),
+ client::ModelContextServerBinary {
+ executable: Path::new(&self.config.executable).to_path_buf(),
+ args: self.config.args.clone(),
+ env: None,
+ },
+ cx.clone(),
+ )?;
+
+ let protocol = crate::protocol::ModelContextProtocol::new(client);
+ let client_info = types::EntityInfo {
+ name: "Zed".to_string(),
+ version: env!("CARGO_PKG_VERSION").to_string(),
+ };
+ let initialized_protocol = protocol.initialize(client_info).await?;
+
+ log::debug!(
+ "context server {} initialized: {:?}",
+ self.config.id,
+ initialized_protocol.initialize,
+ );
+
+ *self.client.write() = Some(Arc::new(initialized_protocol));
+ Ok(())
+ }
+
+ async fn stop(&self) -> anyhow::Result<()> {
+ let mut client = self.client.write();
+ if let Some(protocol) = client.take() {
+ drop(protocol);
+ }
+ Ok(())
+ }
+}
+
+/// A Context server manager manages the starting and stopping
+/// of all servers. To obtain a server to interact with, a crate
+/// must go through the `GlobalContextServerManager` which holds
+/// a model to the ContextServerManager.
+pub struct ContextServerManager {
+ servers: HashMap<String, Arc<ContextServer>>,
+ pending_servers: HashSet<String>,
+}
+
+pub enum Event {
+ ServerStarted { server_id: String },
+ ServerStopped { server_id: String },
+}
+
+impl Global for ContextServerManager {}
+impl EventEmitter<Event> for ContextServerManager {}
+
+impl ContextServerManager {
+ pub fn new() -> Self {
+ Self {
+ servers: HashMap::default(),
+ pending_servers: HashSet::default(),
+ }
+ }
+ pub fn global(cx: &AppContext) -> Model<Self> {
+ cx.global::<GlobalContextServerManager>().0.clone()
+ }
+
+ pub fn add_server(
+ &mut self,
+ config: ServerConfig,
+ cx: &mut ModelContext<Self>,
+ ) -> Task<anyhow::Result<()>> {
+ let server_id = config.id.clone();
+ let server_id2 = config.id.clone();
+
+ if self.servers.contains_key(&server_id) || self.pending_servers.contains(&server_id) {
+ return Task::ready(Ok(()));
+ }
+
+ let task = cx.spawn(|this, mut cx| async move {
+ let server = Arc::new(ContextServer::new(config));
+ server.start(&cx).await?;
+ this.update(&mut cx, |this, cx| {
+ this.servers.insert(server_id.clone(), server);
+ this.pending_servers.remove(&server_id);
+ cx.emit(Event::ServerStarted {
+ server_id: server_id.clone(),
+ });
+ })?;
+ Ok(())
+ });
+
+ self.pending_servers.insert(server_id2);
+ task
+ }
+
+ pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
+ self.servers.get(id).cloned()
+ }
+
+ pub fn remove_server(
+ &mut self,
+ id: &str,
+ cx: &mut ModelContext<Self>,
+ ) -> Task<anyhow::Result<()>> {
+ let id = id.to_string();
+ cx.spawn(|this, mut cx| async move {
+ if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
+ server.stop().await?;
+ }
+ this.update(&mut cx, |this, cx| {
+ this.pending_servers.remove(&id);
+ cx.emit(Event::ServerStopped {
+ server_id: id.clone(),
+ })
+ })?;
+ Ok(())
+ })
+ }
+
+ pub fn restart_server(
+ &mut self,
+ id: &str,
+ cx: &mut ModelContext<Self>,
+ ) -> Task<anyhow::Result<()>> {
+ let id = id.to_string();
+ cx.spawn(|this, mut cx| async move {
+ if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
+ server.stop().await?;
+ let config = server.config.clone();
+ let new_server = Arc::new(ContextServer::new(config));
+ new_server.start(&cx).await?;
+ this.update(&mut cx, |this, cx| {
+ this.servers.insert(id.clone(), new_server);
+ cx.emit(Event::ServerStopped {
+ server_id: id.clone(),
+ });
+ cx.emit(Event::ServerStarted {
+ server_id: id.clone(),
+ });
+ })?;
+ }
+ Ok(())
+ })
+ }
+
+ pub fn servers(&self) -> Vec<Arc<ContextServer>> {
+ self.servers.values().cloned().collect()
+ }
+
+ pub fn model(cx: &mut AppContext) -> Model<Self> {
+ cx.new_model(|_cx| ContextServerManager::new())
+ }
+}
+
+pub struct GlobalContextServerManager(Model<ContextServerManager>);
+impl Global for GlobalContextServerManager {}
+
+impl GlobalContextServerManager {
+ fn register(cx: &mut AppContext) {
+ let model = ContextServerManager::model(cx);
+ cx.set_global(Self(model));
+ }
+}
+
+pub fn init(cx: &mut AppContext) {
+ ContextServerSettings::register(cx);
+ GlobalContextServerManager::register(cx);
+ cx.observe_global::<SettingsStore>(|cx| {
+ let manager = ContextServerManager::global(cx);
+ cx.update_model(&manager, |manager, cx| {
+ let settings = ContextServerSettings::get_global(cx);
+ let current_servers: HashMap<String, ServerConfig> = manager
+ .servers()
+ .into_iter()
+ .map(|server| (server.id.clone(), server.config.clone()))
+ .collect();
+
+ let new_servers = settings
+ .servers
+ .iter()
+ .map(|config| (config.id.clone(), config.clone()))
+ .collect::<HashMap<_, _>>();
+
+ let servers_to_add = new_servers
+ .values()
+ .filter(|config| !current_servers.contains_key(&config.id))
+ .cloned()
+ .collect::<Vec<_>>();
+
+ let servers_to_remove = current_servers
+ .keys()
+ .filter(|id| !new_servers.contains_key(*id))
+ .cloned()
+ .collect::<Vec<_>>();
+
+ log::trace!("servers_to_add={:?}", servers_to_add);
+ for config in servers_to_add {
+ manager.add_server(config, cx).detach();
+ }
+
+ for id in servers_to_remove {
+ manager.remove_server(&id, cx).detach();
+ }
+ })
+ })
+ .detach();
+}
@@ -0,0 +1,140 @@
+//! This module implements parts of the Model Context Protocol.
+//!
+//! It handles the lifecycle messages, and provides a general interface to
+//! interacting with an MCP server. It uses the generic JSON-RPC client to
+//! read/write messages and the types from types.rs for serialization/deserialization
+//! of messages.
+
+use anyhow::Result;
+use collections::HashMap;
+
+use crate::client::Client;
+use crate::types;
+
+pub use types::PromptInfo;
+
+const PROTOCOL_VERSION: u32 = 1;
+
+pub struct ModelContextProtocol {
+ inner: Client,
+}
+
+impl ModelContextProtocol {
+ pub fn new(inner: Client) -> Self {
+ Self { inner }
+ }
+
+ pub async fn initialize(
+ self,
+ client_info: types::EntityInfo,
+ ) -> Result<InitializedContextServerProtocol> {
+ let params = types::InitializeParams {
+ protocol_version: PROTOCOL_VERSION,
+ capabilities: types::ClientCapabilities {
+ experimental: None,
+ sampling: None,
+ },
+ client_info,
+ };
+
+ let response: types::InitializeResponse = self
+ .inner
+ .request(types::RequestType::Initialize.as_str(), params)
+ .await?;
+
+ log::trace!("mcp server info {:?}", response.server_info);
+
+ self.inner.notify(
+ types::NotificationType::Initialized.as_str(),
+ serde_json::json!({}),
+ )?;
+
+ let initialized_protocol = InitializedContextServerProtocol {
+ inner: self.inner,
+ initialize: response,
+ };
+
+ Ok(initialized_protocol)
+ }
+}
+
+pub struct InitializedContextServerProtocol {
+ inner: Client,
+ pub initialize: types::InitializeResponse,
+}
+
+#[derive(Debug, PartialEq, Clone, Copy)]
+pub enum ServerCapability {
+ Experimental,
+ Logging,
+ Prompts,
+ Resources,
+ Tools,
+}
+
+impl InitializedContextServerProtocol {
+ /// Check if the server supports a specific capability
+ pub fn capable(&self, capability: ServerCapability) -> bool {
+ match capability {
+ ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
+ ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
+ ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
+ ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
+ ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
+ }
+ }
+
+ fn check_capability(&self, capability: ServerCapability) -> Result<()> {
+ if self.capable(capability) {
+ Ok(())
+ } else {
+ Err(anyhow::anyhow!(
+ "Server does not support {:?} capability",
+ capability
+ ))
+ }
+ }
+
+ /// List the MCP prompts.
+ pub async fn list_prompts(&self) -> Result<Vec<types::PromptInfo>> {
+ self.check_capability(ServerCapability::Prompts)?;
+
+ let response: types::PromptsListResponse = self
+ .inner
+ .request(types::RequestType::PromptsList.as_str(), ())
+ .await?;
+
+ Ok(response.prompts)
+ }
+
+ /// Executes a prompt with the given arguments and returns the result.
+ pub async fn run_prompt<P: AsRef<str>>(
+ &self,
+ prompt: P,
+ arguments: HashMap<String, String>,
+ ) -> Result<String> {
+ self.check_capability(ServerCapability::Prompts)?;
+
+ let params = types::PromptsGetParams {
+ name: prompt.as_ref().to_string(),
+ arguments: Some(arguments),
+ };
+
+ let response: types::PromptsGetResponse = self
+ .inner
+ .request(types::RequestType::PromptsGet.as_str(), params)
+ .await?;
+
+ Ok(response.prompt)
+ }
+}
+
+impl InitializedContextServerProtocol {
+ pub async fn request<R: serde::de::DeserializeOwned>(
+ &self,
+ method: &str,
+ params: impl serde::Serialize,
+ ) -> Result<R> {
+ self.inner.request(method, params).await
+ }
+}
@@ -0,0 +1,47 @@
+use std::sync::Arc;
+
+use collections::HashMap;
+use gpui::{AppContext, Global, ReadGlobal};
+use parking_lot::RwLock;
+
+struct GlobalContextServerRegistry(Arc<ContextServerRegistry>);
+
+impl Global for GlobalContextServerRegistry {}
+
+pub struct ContextServerRegistry {
+ registry: RwLock<HashMap<String, Vec<Arc<str>>>>,
+}
+
+impl ContextServerRegistry {
+ pub fn global(cx: &AppContext) -> Arc<Self> {
+ GlobalContextServerRegistry::global(cx).0.clone()
+ }
+
+ pub fn register(cx: &mut AppContext) {
+ cx.set_global(GlobalContextServerRegistry(Arc::new(
+ ContextServerRegistry {
+ registry: RwLock::new(HashMap::default()),
+ },
+ )))
+ }
+
+ pub fn register_command(&self, server_id: String, command_name: &str) {
+ let mut registry = self.registry.write();
+ registry
+ .entry(server_id)
+ .or_default()
+ .push(command_name.into());
+ }
+
+ pub fn unregister_command(&self, server_id: &str, command_name: &str) {
+ let mut registry = self.registry.write();
+ if let Some(commands) = registry.get_mut(server_id) {
+ commands.retain(|name| name.as_ref() != command_name);
+ }
+ }
+
+ pub fn get_commands(&self, server_id: &str) -> Option<Vec<Arc<str>>> {
+ let registry = self.registry.read();
+ registry.get(server_id).cloned()
+ }
+}
@@ -0,0 +1,234 @@
+use collections::HashMap;
+use serde::{Deserialize, Serialize};
+use url::Url;
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub enum RequestType {
+ Initialize,
+ CallTool,
+ ResourcesUnsubscribe,
+ ResourcesSubscribe,
+ ResourcesRead,
+ ResourcesList,
+ LoggingSetLevel,
+ PromptsGet,
+ PromptsList,
+}
+
+impl RequestType {
+ pub fn as_str(&self) -> &'static str {
+ match self {
+ RequestType::Initialize => "initialize",
+ RequestType::CallTool => "tools/call",
+ RequestType::ResourcesUnsubscribe => "resources/unsubscribe",
+ RequestType::ResourcesSubscribe => "resources/subscribe",
+ RequestType::ResourcesRead => "resources/read",
+ RequestType::ResourcesList => "resources/list",
+ RequestType::LoggingSetLevel => "logging/setLevel",
+ RequestType::PromptsGet => "prompts/get",
+ RequestType::PromptsList => "prompts/list",
+ }
+ }
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct InitializeParams {
+ pub protocol_version: u32,
+ pub capabilities: ClientCapabilities,
+ pub client_info: EntityInfo,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CallToolParams {
+ pub name: String,
+ pub arguments: Option<serde_json::Value>,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ResourcesUnsubscribeParams {
+ pub uri: Url,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ResourcesSubscribeParams {
+ pub uri: Url,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ResourcesReadParams {
+ pub uri: Url,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct LoggingSetLevelParams {
+ pub level: LoggingLevel,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PromptsGetParams {
+ pub name: String,
+ pub arguments: Option<HashMap<String, String>>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct InitializeResponse {
+ pub protocol_version: u32,
+ pub capabilities: ServerCapabilities,
+ pub server_info: EntityInfo,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ResourcesReadResponse {
+ pub contents: Vec<ResourceContent>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ResourcesListResponse {
+ pub resource_templates: Option<Vec<ResourceTemplate>>,
+ pub resources: Vec<Resource>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PromptsGetResponse {
+ pub prompt: String,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PromptsListResponse {
+ pub prompts: Vec<PromptInfo>,
+}
+
+#[derive(Debug, Deserialize, Clone)]
+#[serde(rename_all = "camelCase")]
+pub struct PromptInfo {
+ pub name: String,
+ pub arguments: Option<Vec<PromptArgument>>,
+}
+
+#[derive(Debug, Deserialize, Clone)]
+#[serde(rename_all = "camelCase")]
+pub struct PromptArgument {
+ pub name: String,
+ pub description: Option<String>,
+ pub required: Option<bool>,
+}
+
+// Shared Types
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ClientCapabilities {
+ pub experimental: Option<HashMap<String, serde_json::Value>>,
+ pub sampling: Option<HashMap<String, serde_json::Value>>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ServerCapabilities {
+ pub experimental: Option<HashMap<String, serde_json::Value>>,
+ pub logging: Option<HashMap<String, serde_json::Value>>,
+ pub prompts: Option<HashMap<String, serde_json::Value>>,
+ pub resources: Option<ResourcesCapabilities>,
+ pub tools: Option<HashMap<String, serde_json::Value>>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ResourcesCapabilities {
+ pub subscribe: Option<bool>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct Tool {
+ pub name: String,
+ pub description: Option<String>,
+ pub input_schema: serde_json::Value,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct EntityInfo {
+ pub name: String,
+ pub version: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct Resource {
+ pub uri: Url,
+ pub mime_type: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ResourceContent {
+ pub uri: Url,
+ pub mime_type: Option<String>,
+ pub content_type: String,
+ pub text: Option<String>,
+ pub data: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ResourceTemplate {
+ pub uri_template: String,
+ pub name: Option<String>,
+ pub description: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum LoggingLevel {
+ Debug,
+ Info,
+ Warning,
+ Error,
+}
+
+// Client Notifications
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub enum NotificationType {
+ Initialized,
+ Progress,
+}
+
+impl NotificationType {
+ pub fn as_str(&self) -> &'static str {
+ match self {
+ NotificationType::Initialized => "notifications/initialized",
+ NotificationType::Progress => "notifications/progress",
+ }
+ }
+}
+
+#[derive(Debug, Serialize)]
+#[serde(untagged)]
+pub enum ClientNotification {
+ Initialized,
+ Progress(ProgressParams),
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ProgressParams {
+ pub progress_token: String,
+ pub progress: f64,
+ pub total: Option<f64>,
+}