Detailed changes
@@ -484,7 +484,6 @@ dependencies = [
"client",
"collections",
"command_palette_hooks",
- "context_server",
"ctor",
"db",
"editor",
@@ -3328,43 +3327,22 @@ name = "context_server"
version = "0.1.0"
dependencies = [
"anyhow",
- "assistant_tool",
"async-trait",
"collections",
- "command_palette_hooks",
- "context_server_settings",
- "extension",
"futures 0.3.31",
"gpui",
- "icons",
- "language_model",
"log",
"parking_lot",
"postage",
- "project",
+ "schemars",
"serde",
"serde_json",
- "settings",
"smol",
"url",
"util",
"workspace-hack",
]
-[[package]]
-name = "context_server_settings"
-version = "0.1.0"
-dependencies = [
- "anyhow",
- "collections",
- "gpui",
- "schemars",
- "serde",
- "serde_json",
- "settings",
- "workspace-hack",
-]
-
[[package]]
name = "convert_case"
version = "0.4.0"
@@ -5029,7 +5007,6 @@ dependencies = [
"clap",
"client",
"collections",
- "context_server",
"dirs 4.0.0",
"dotenv",
"env_logger 0.11.8",
@@ -5182,7 +5159,6 @@ dependencies = [
"async-trait",
"client",
"collections",
- "context_server_settings",
"ctor",
"env_logger 0.11.8",
"extension",
@@ -11112,6 +11088,7 @@ dependencies = [
"client",
"clock",
"collections",
+ "context_server",
"dap",
"dap_adapters",
"env_logger 0.11.8",
@@ -34,7 +34,6 @@ members = [
"crates/component",
"crates/component_preview",
"crates/context_server",
- "crates/context_server_settings",
"crates/copilot",
"crates/credentials_provider",
"crates/dap",
@@ -243,7 +242,6 @@ command_palette_hooks = { path = "crates/command_palette_hooks" }
component = { path = "crates/component" }
component_preview = { path = "crates/component_preview" }
context_server = { path = "crates/context_server" }
-context_server_settings = { path = "crates/context_server_settings" }
copilot = { path = "crates/copilot" }
credentials_provider = { path = "crates/credentials_provider" }
dap = { path = "crates/dap" }
@@ -3487,7 +3487,6 @@ fn open_editor_at_position(
#[cfg(test)]
mod tests {
use assistant_tool::{ToolRegistry, ToolWorkingSet};
- use context_server::ContextServerSettings;
use editor::EditorSettings;
use fs::FakeFs;
use gpui::{TestAppContext, VisualTestContext};
@@ -3559,7 +3558,6 @@ mod tests {
workspace::init_settings(cx);
language_model::init_settings(cx);
ThemeSettings::register(cx);
- ContextServerSettings::register(cx);
EditorSettings::register(cx);
ToolRegistry::default_global(cx);
});
@@ -1748,7 +1748,6 @@ mod tests {
use crate::{Keep, ThreadStore, thread_store};
use assistant_settings::AssistantSettings;
use assistant_tool::ToolWorkingSet;
- use context_server::ContextServerSettings;
use editor::EditorSettings;
use gpui::{TestAppContext, UpdateGlobal, VisualTestContext};
use project::{FakeFs, Project};
@@ -1771,7 +1770,6 @@ mod tests {
thread_store::init(cx);
workspace::init_settings(cx);
ThemeSettings::register(cx);
- ContextServerSettings::register(cx);
EditorSettings::register(cx);
language_model::init_settings(cx);
});
@@ -1928,7 +1926,6 @@ mod tests {
thread_store::init(cx);
workspace::init_settings(cx);
ThemeSettings::register(cx);
- ContextServerSettings::register(cx);
EditorSettings::register(cx);
language_model::init_settings(cx);
workspace::register_project_item::<Editor>(cx);
@@ -7,6 +7,7 @@ mod buffer_codegen;
mod context;
mod context_picker;
mod context_server_configuration;
+mod context_server_tool;
mod context_store;
mod context_strip;
mod debug;
@@ -8,13 +8,14 @@ use std::{sync::Arc, time::Duration};
use assistant_settings::AssistantSettings;
use assistant_tool::{ToolSource, ToolWorkingSet};
use collections::HashMap;
-use context_server::manager::{ContextServer, ContextServerManager, ContextServerStatus};
+use context_server::ContextServerId;
use fs::Fs;
use gpui::{
Action, Animation, AnimationExt as _, AnyView, App, Entity, EventEmitter, FocusHandle,
Focusable, ScrollHandle, Subscription, pulsating_between,
};
use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
+use project::context_server_store::{ContextServerStatus, ContextServerStore};
use settings::{Settings, update_settings_file};
use ui::{
Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Scrollbar, ScrollbarState,
@@ -33,8 +34,8 @@ pub struct AssistantConfiguration {
fs: Arc<dyn Fs>,
focus_handle: FocusHandle,
configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
- context_server_manager: Entity<ContextServerManager>,
- expanded_context_server_tools: HashMap<Arc<str>, bool>,
+ context_server_store: Entity<ContextServerStore>,
+ expanded_context_server_tools: HashMap<ContextServerId, bool>,
tools: Entity<ToolWorkingSet>,
_registry_subscription: Subscription,
scroll_handle: ScrollHandle,
@@ -44,7 +45,7 @@ pub struct AssistantConfiguration {
impl AssistantConfiguration {
pub fn new(
fs: Arc<dyn Fs>,
- context_server_manager: Entity<ContextServerManager>,
+ context_server_store: Entity<ContextServerStore>,
tools: Entity<ToolWorkingSet>,
window: &mut Window,
cx: &mut Context<Self>,
@@ -75,7 +76,7 @@ impl AssistantConfiguration {
fs,
focus_handle,
configuration_views_by_provider: HashMap::default(),
- context_server_manager,
+ context_server_store,
expanded_context_server_tools: HashMap::default(),
tools,
_registry_subscription: registry_subscription,
@@ -306,7 +307,7 @@ impl AssistantConfiguration {
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement {
- let context_servers = self.context_server_manager.read(cx).all_servers().clone();
+ let context_server_ids = self.context_server_store.read(cx).all_server_ids().clone();
const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
@@ -322,9 +323,9 @@ impl AssistantConfiguration {
.child(Label::new(SUBHEADING).color(Color::Muted)),
)
.children(
- context_servers
- .into_iter()
- .map(|context_server| self.render_context_server(context_server, window, cx)),
+ context_server_ids.into_iter().map(|context_server_id| {
+ self.render_context_server(context_server_id, window, cx)
+ }),
)
.child(
h_flex()
@@ -374,19 +375,20 @@ impl AssistantConfiguration {
fn render_context_server(
&self,
- context_server: Arc<ContextServer>,
+ context_server_id: ContextServerId,
window: &mut Window,
cx: &mut Context<Self>,
) -> impl use<> + IntoElement {
let tools_by_source = self.tools.read(cx).tools_by_source(cx);
let server_status = self
- .context_server_manager
+ .context_server_store
.read(cx)
- .status_for_server(&context_server.id());
+ .status_for_server(&context_server_id)
+ .unwrap_or(ContextServerStatus::Stopped);
- let is_running = matches!(server_status, Some(ContextServerStatus::Running));
+ let is_running = matches!(server_status, ContextServerStatus::Running);
- let error = if let Some(ContextServerStatus::Error(error)) = server_status.clone() {
+ let error = if let ContextServerStatus::Error(error) = server_status.clone() {
Some(error)
} else {
None
@@ -394,13 +396,13 @@ impl AssistantConfiguration {
let are_tools_expanded = self
.expanded_context_server_tools
- .get(&context_server.id())
+ .get(&context_server_id)
.copied()
.unwrap_or_default();
let tools = tools_by_source
.get(&ToolSource::ContextServer {
- id: context_server.id().into(),
+ id: context_server_id.0.clone().into(),
})
.map_or([].as_slice(), |tools| tools.as_slice());
let tool_count = tools.len();
@@ -408,7 +410,7 @@ impl AssistantConfiguration {
let border_color = cx.theme().colors().border.opacity(0.6);
v_flex()
- .id(SharedString::from(context_server.id()))
+ .id(SharedString::from(context_server_id.0.clone()))
.border_1()
.rounded_md()
.border_color(border_color)
@@ -432,7 +434,7 @@ impl AssistantConfiguration {
)
.disabled(tool_count == 0)
.on_click(cx.listener({
- let context_server_id = context_server.id();
+ let context_server_id = context_server_id.clone();
move |this, _event, _window, _cx| {
let is_open = this
.expanded_context_server_tools
@@ -444,14 +446,14 @@ impl AssistantConfiguration {
})),
)
.child(match server_status {
- Some(ContextServerStatus::Starting) => {
+ ContextServerStatus::Starting => {
let color = Color::Success.color(cx);
Indicator::dot()
.color(Color::Success)
.with_animation(
SharedString::from(format!(
"{}-starting",
- context_server.id(),
+ context_server_id.0.clone(),
)),
Animation::new(Duration::from_secs(2))
.repeat()
@@ -462,15 +464,17 @@ impl AssistantConfiguration {
)
.into_any_element()
}
- Some(ContextServerStatus::Running) => {
+ ContextServerStatus::Running => {
Indicator::dot().color(Color::Success).into_any_element()
}
- Some(ContextServerStatus::Error(_)) => {
+ ContextServerStatus::Error(_) => {
Indicator::dot().color(Color::Error).into_any_element()
}
- None => Indicator::dot().color(Color::Muted).into_any_element(),
+ ContextServerStatus::Stopped => {
+ Indicator::dot().color(Color::Muted).into_any_element()
+ }
})
- .child(Label::new(context_server.id()).ml_0p5())
+ .child(Label::new(context_server_id.0.clone()).ml_0p5())
.when(is_running, |this| {
this.child(
Label::new(if tool_count == 1 {
@@ -487,32 +491,22 @@ impl AssistantConfiguration {
Switch::new("context-server-switch", is_running.into())
.color(SwitchColor::Accent)
.on_click({
- let context_server_manager = self.context_server_manager.clone();
- let context_server = context_server.clone();
+ let context_server_manager = self.context_server_store.clone();
+ let context_server_id = context_server_id.clone();
move |state, _window, cx| match state {
ToggleState::Unselected | ToggleState::Indeterminate => {
context_server_manager.update(cx, |this, cx| {
- this.stop_server(context_server.clone(), cx).log_err();
+ this.stop_server(&context_server_id, cx).log_err();
});
}
ToggleState::Selected => {
- cx.spawn({
- let context_server_manager =
- context_server_manager.clone();
- let context_server = context_server.clone();
- async move |cx| {
- if let Some(start_server_task) =
- context_server_manager
- .update(cx, |this, cx| {
- this.start_server(context_server, cx)
- })
- .log_err()
- {
- start_server_task.await.log_err();
- }
+ context_server_manager.update(cx, |this, cx| {
+ if let Some(server) =
+ this.get_server(&context_server_id)
+ {
+ this.start_server(server, cx).log_err();
}
})
- .detach();
}
}
}),
@@ -1,5 +1,6 @@
-use context_server::{ContextServerSettings, ServerCommand, ServerConfig};
+use context_server::ContextServerCommand;
use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, WeakEntity, prelude::*};
+use project::project_settings::{ContextServerConfiguration, ProjectSettings};
use serde_json::json;
use settings::update_settings_file;
use ui::{KeyBinding, Modal, ModalFooter, ModalHeader, Section, Tooltip, prelude::*};
@@ -77,11 +78,11 @@ impl AddContextServerModal {
if let Some(workspace) = self.workspace.upgrade() {
workspace.update(cx, |workspace, cx| {
let fs = workspace.app_state().fs.clone();
- update_settings_file::<ContextServerSettings>(fs.clone(), cx, |settings, _| {
+ update_settings_file::<ProjectSettings>(fs.clone(), cx, |settings, _| {
settings.context_servers.insert(
name.into(),
- ServerConfig {
- command: Some(ServerCommand {
+ ContextServerConfiguration {
+ command: Some(ContextServerCommand {
path,
args,
env: None,
@@ -4,9 +4,8 @@ use std::{
};
use anyhow::Context as _;
-use context_server::manager::{ContextServerManager, ContextServerStatus};
+use context_server::ContextServerId;
use editor::{Editor, EditorElement, EditorStyle};
-use extension::ContextServerConfiguration;
use gpui::{
Animation, AnimationExt, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Task,
TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, percentage,
@@ -14,6 +13,10 @@ use gpui::{
use language::{Language, LanguageRegistry};
use markdown::{Markdown, MarkdownElement, MarkdownStyle};
use notifications::status_toast::{StatusToast, ToastIcon};
+use project::{
+ context_server_store::{ContextServerStatus, ContextServerStore},
+ project_settings::{ContextServerConfiguration, ProjectSettings},
+};
use settings::{Settings as _, update_settings_file};
use theme::ThemeSettings;
use ui::{KeyBinding, Modal, ModalFooter, ModalHeader, Section, prelude::*};
@@ -23,11 +26,11 @@ use workspace::{ModalView, Workspace};
pub(crate) struct ConfigureContextServerModal {
workspace: WeakEntity<Workspace>,
context_servers_to_setup: Vec<ConfigureContextServer>,
- context_server_manager: Entity<ContextServerManager>,
+ context_server_store: Entity<ContextServerStore>,
}
struct ConfigureContextServer {
- id: Arc<str>,
+ id: ContextServerId,
installation_instructions: Entity<markdown::Markdown>,
settings_validator: Option<jsonschema::Validator>,
settings_editor: Entity<Editor>,
@@ -37,9 +40,9 @@ struct ConfigureContextServer {
impl ConfigureContextServerModal {
pub fn new(
- configurations: impl Iterator<Item = (Arc<str>, ContextServerConfiguration)>,
+ configurations: impl Iterator<Item = (ContextServerId, extension::ContextServerConfiguration)>,
+ context_server_store: Entity<ContextServerStore>,
jsonc_language: Option<Arc<Language>>,
- context_server_manager: Entity<ContextServerManager>,
language_registry: Arc<LanguageRegistry>,
workspace: WeakEntity<Workspace>,
window: &mut Window,
@@ -85,7 +88,7 @@ impl ConfigureContextServerModal {
Some(Self {
workspace,
context_servers_to_setup,
- context_server_manager,
+ context_server_store,
})
}
}
@@ -126,14 +129,14 @@ impl ConfigureContextServerModal {
}
let id = configuration.id.clone();
- let settings_changed = context_server::ContextServerSettings::get_global(cx)
+ let settings_changed = ProjectSettings::get_global(cx)
.context_servers
- .get(&id)
+ .get(&id.0)
.map_or(true, |config| {
config.settings.as_ref() != Some(&settings_value)
});
- let is_running = self.context_server_manager.read(cx).status_for_server(&id)
+ let is_running = self.context_server_store.read(cx).status_for_server(&id)
== Some(ContextServerStatus::Running);
if !settings_changed && is_running {
@@ -143,7 +146,7 @@ impl ConfigureContextServerModal {
configuration.waiting_for_context_server = true;
- let task = wait_for_context_server(&self.context_server_manager, id.clone(), cx);
+ let task = wait_for_context_server(&self.context_server_store, id.clone(), cx);
cx.spawn({
let id = id.clone();
async move |this, cx| {
@@ -167,29 +170,25 @@ impl ConfigureContextServerModal {
.detach();
// When we write the settings to the file, the context server will be restarted.
- update_settings_file::<context_server::ContextServerSettings>(
- workspace.read(cx).app_state().fs.clone(),
- cx,
- {
- let id = id.clone();
- |settings, _| {
- if let Some(server_config) = settings.context_servers.get_mut(&id) {
- server_config.settings = Some(settings_value);
- } else {
- settings.context_servers.insert(
- id,
- context_server::ServerConfig {
- settings: Some(settings_value),
- ..Default::default()
- },
- );
- }
+ update_settings_file::<ProjectSettings>(workspace.read(cx).app_state().fs.clone(), cx, {
+ let id = id.clone();
+ |settings, _| {
+ if let Some(server_config) = settings.context_servers.get_mut(&id.0) {
+ server_config.settings = Some(settings_value);
+ } else {
+ settings.context_servers.insert(
+ id.0,
+ ContextServerConfiguration {
+ settings: Some(settings_value),
+ ..Default::default()
+ },
+ );
}
- },
- );
+ }
+ });
}
- fn complete_setup(&mut self, id: Arc<str>, cx: &mut Context<Self>) {
+ fn complete_setup(&mut self, id: ContextServerId, cx: &mut Context<Self>) {
self.context_servers_to_setup.remove(0);
cx.notify();
@@ -223,31 +222,40 @@ impl ConfigureContextServerModal {
}
fn wait_for_context_server(
- context_server_manager: &Entity<ContextServerManager>,
- context_server_id: Arc<str>,
+ context_server_store: &Entity<ContextServerStore>,
+ context_server_id: ContextServerId,
cx: &mut App,
) -> Task<Result<(), Arc<str>>> {
let (tx, rx) = futures::channel::oneshot::channel();
let tx = Arc::new(Mutex::new(Some(tx)));
- let subscription = cx.subscribe(context_server_manager, move |_, event, _cx| match event {
- context_server::manager::Event::ServerStatusChanged { server_id, status } => match status {
- Some(ContextServerStatus::Running) => {
- if server_id == &context_server_id {
- if let Some(tx) = tx.lock().unwrap().take() {
- let _ = tx.send(Ok(()));
+ let subscription = cx.subscribe(context_server_store, move |_, event, _cx| match event {
+ project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
+ match status {
+ ContextServerStatus::Running => {
+ if server_id == &context_server_id {
+ if let Some(tx) = tx.lock().unwrap().take() {
+ let _ = tx.send(Ok(()));
+ }
}
}
- }
- Some(ContextServerStatus::Error(error)) => {
- if server_id == &context_server_id {
- if let Some(tx) = tx.lock().unwrap().take() {
- let _ = tx.send(Err(error.clone()));
+ ContextServerStatus::Stopped => {
+ if server_id == &context_server_id {
+ if let Some(tx) = tx.lock().unwrap().take() {
+ let _ = tx.send(Err("Context server stopped running".into()));
+ }
}
}
+ ContextServerStatus::Error(error) => {
+ if server_id == &context_server_id {
+ if let Some(tx) = tx.lock().unwrap().take() {
+ let _ = tx.send(Err(error.clone()));
+ }
+ }
+ }
+ _ => {}
}
- _ => {}
- },
+ }
});
cx.spawn(async move |_cx| {
@@ -1026,14 +1026,14 @@ impl AssistantPanel {
}
pub(crate) fn open_configuration(&mut self, window: &mut Window, cx: &mut Context<Self>) {
- let context_server_manager = self.thread_store.read(cx).context_server_manager();
+ let context_server_store = self.project.read(cx).context_server_store();
let tools = self.thread_store.read(cx).tools();
let fs = self.fs.clone();
self.set_active_view(ActiveView::Configuration, window, cx);
self.configuration =
Some(cx.new(|cx| {
- AssistantConfiguration::new(fs, context_server_manager, tools, window, cx)
+ AssistantConfiguration::new(fs, context_server_store, tools, window, cx)
}));
if let Some(configuration) = self.configuration.as_ref() {
@@ -1095,9 +1095,7 @@ mod tests {
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
- cx.set_global(cx.update(SettingsStore::test));
- cx.update(language_model::LanguageModelRegistry::test);
- cx.update(language_settings::init);
+ init_test(cx);
let text = indoc! {"
fn main() {
@@ -1167,8 +1165,7 @@ mod tests {
cx: &mut TestAppContext,
mut rng: StdRng,
) {
- cx.set_global(cx.update(SettingsStore::test));
- cx.update(language_settings::init);
+ init_test(cx);
let text = indoc! {"
fn main() {
@@ -1237,9 +1234,7 @@ mod tests {
cx: &mut TestAppContext,
mut rng: StdRng,
) {
- cx.update(LanguageModelRegistry::test);
- cx.set_global(cx.update(SettingsStore::test));
- cx.update(language_settings::init);
+ init_test(cx);
let text = concat!(
"fn main() {\n",
@@ -1305,9 +1300,7 @@ mod tests {
#[gpui::test(iterations = 10)]
async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
- cx.update(LanguageModelRegistry::test);
- cx.set_global(cx.update(SettingsStore::test));
- cx.update(language_settings::init);
+ init_test(cx);
let text = indoc! {"
func main() {
@@ -1367,9 +1360,7 @@ mod tests {
#[gpui::test]
async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
- cx.update(LanguageModelRegistry::test);
- cx.set_global(cx.update(SettingsStore::test));
- cx.update(language_settings::init);
+ init_test(cx);
let text = indoc! {"
fn main() {
@@ -1473,6 +1464,13 @@ mod tests {
}
}
+ fn init_test(cx: &mut TestAppContext) {
+ cx.update(LanguageModelRegistry::test);
+ cx.set_global(cx.update(SettingsStore::test));
+ cx.update(Project::init_settings);
+ cx.update(language_settings::init);
+ }
+
fn simulate_response_stream(
codegen: Entity<CodegenAlternative>,
cx: &mut TestAppContext,
@@ -1,14 +1,15 @@
use std::sync::Arc;
use anyhow::Context as _;
-use context_server::ContextServerDescriptorRegistry;
+use context_server::ContextServerId;
use extension::ExtensionManifest;
use language::LanguageRegistry;
+use project::context_server_store::registry::ContextServerDescriptorRegistry;
use ui::prelude::*;
use util::ResultExt;
use workspace::Workspace;
-use crate::{AssistantPanel, assistant_configuration::ConfigureContextServerModal};
+use crate::assistant_configuration::ConfigureContextServerModal;
pub(crate) fn init(language_registry: Arc<LanguageRegistry>, cx: &mut App) {
cx.observe_new(move |_: &mut Workspace, window, cx| {
@@ -60,18 +61,10 @@ fn show_configure_mcp_modal(
window: &mut Window,
cx: &mut Context<'_, Workspace>,
) {
- let Some(context_server_manager) = workspace.panel::<AssistantPanel>(cx).map(|panel| {
- panel
- .read(cx)
- .thread_store()
- .read(cx)
- .context_server_manager()
- }) else {
- return;
- };
+ let context_server_store = workspace.project().read(cx).context_server_store();
- let registry = ContextServerDescriptorRegistry::global(cx).read(cx);
- let project = workspace.project().clone();
+ let registry = ContextServerDescriptorRegistry::default_global(cx).read(cx);
+ let worktree_store = workspace.project().read(cx).worktree_store();
let configuration_tasks = manifest
.context_servers
.keys()
@@ -80,15 +73,15 @@ fn show_configure_mcp_modal(
|key| {
let descriptor = registry.context_server_descriptor(&key)?;
Some(cx.spawn({
- let project = project.clone();
+ let worktree_store = worktree_store.clone();
async move |_, cx| {
descriptor
- .configuration(project, &cx)
+ .configuration(worktree_store.clone(), &cx)
.await
.context("Failed to resolve context server configuration")
.log_err()
.flatten()
- .map(|config| (key, config))
+ .map(|config| (ContextServerId(key), config))
}
}))
}
@@ -104,8 +97,8 @@ fn show_configure_mcp_modal(
this.update_in(cx, |this, window, cx| {
let modal = ConfigureContextServerModal::new(
descriptors.into_iter().flatten(),
+ context_server_store,
jsonc_language,
- context_server_manager,
language_registry,
cx.entity().downgrade(),
window,
@@ -2,29 +2,27 @@ use std::sync::Arc;
use anyhow::{Result, anyhow, bail};
use assistant_tool::{ActionLog, Tool, ToolResult, ToolSource};
+use context_server::{ContextServerId, types};
use gpui::{AnyWindowHandle, App, Entity, Task};
-use icons::IconName;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
-use project::Project;
-
-use crate::manager::ContextServerManager;
-use crate::types;
+use project::{Project, context_server_store::ContextServerStore};
+use ui::IconName;
pub struct ContextServerTool {
- server_manager: Entity<ContextServerManager>,
- server_id: Arc<str>,
+ store: Entity<ContextServerStore>,
+ server_id: ContextServerId,
tool: types::Tool,
}
impl ContextServerTool {
pub fn new(
- server_manager: Entity<ContextServerManager>,
- server_id: impl Into<Arc<str>>,
+ store: Entity<ContextServerStore>,
+ server_id: ContextServerId,
tool: types::Tool,
) -> Self {
Self {
- server_manager,
- server_id: server_id.into(),
+ store,
+ server_id,
tool,
}
}
@@ -45,7 +43,7 @@ impl Tool for ContextServerTool {
fn source(&self) -> ToolSource {
ToolSource::ContextServer {
- id: self.server_id.clone().into(),
+ id: self.server_id.clone().0.into(),
}
}
@@ -80,7 +78,7 @@ impl Tool for ContextServerTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
- if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) {
+ if let Some(server) = self.store.read(cx).get_running_server(&self.server_id) {
let tool_name = self.tool.name.clone();
let server_clone = server.clone();
let input_clone = input.clone();
@@ -2660,7 +2660,6 @@ mod tests {
use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
use assistant_settings::AssistantSettings;
use assistant_tool::ToolRegistry;
- use context_server::ContextServerSettings;
use editor::EditorSettings;
use gpui::TestAppContext;
use language_model::fake_provider::FakeLanguageModel;
@@ -3082,7 +3081,6 @@ fn main() {{
workspace::init_settings(cx);
language_model::init_settings(cx);
ThemeSettings::register(cx);
- ContextServerSettings::register(cx);
EditorSettings::register(cx);
ToolRegistry::default_global(cx);
});
@@ -9,8 +9,7 @@ use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::HashMap;
-use context_server::manager::{ContextServerManager, ContextServerStatus};
-use context_server::{ContextServerDescriptorRegistry, ContextServerTool};
+use context_server::ContextServerId;
use futures::channel::{mpsc, oneshot};
use futures::future::{self, BoxFuture, Shared};
use futures::{FutureExt as _, StreamExt as _};
@@ -21,6 +20,7 @@ use gpui::{
use heed::Database;
use heed::types::SerdeBincode;
use language_model::{LanguageModelToolUseId, Role, TokenUsage};
+use project::context_server_store::{ContextServerStatus, ContextServerStore};
use project::{Project, ProjectItem, ProjectPath, Worktree};
use prompt_store::{
ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
@@ -30,6 +30,7 @@ use serde::{Deserialize, Serialize};
use settings::{Settings as _, SettingsStore};
use util::ResultExt as _;
+use crate::context_server_tool::ContextServerTool;
use crate::thread::{
DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
};
@@ -62,8 +63,7 @@ pub struct ThreadStore {
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
prompt_store: Option<Entity<PromptStore>>,
- context_server_manager: Entity<ContextServerManager>,
- context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
+ context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
threads: Vec<SerializedThreadMetadata>,
project_context: SharedProjectContext,
reload_system_prompt_tx: mpsc::Sender<()>,
@@ -108,11 +108,6 @@ impl ThreadStore {
prompt_store: Option<Entity<PromptStore>>,
cx: &mut Context<Self>,
) -> (Self, oneshot::Receiver<()>) {
- let context_server_factory_registry = ContextServerDescriptorRegistry::default_global(cx);
- let context_server_manager = cx.new(|cx| {
- ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
- });
-
let mut subscriptions = vec![
cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
this.load_default_profile(cx);
@@ -159,7 +154,6 @@ impl ThreadStore {
tools,
prompt_builder,
prompt_store,
- context_server_manager,
context_server_tool_ids: HashMap::default(),
threads: Vec::new(),
project_context: SharedProjectContext::default(),
@@ -354,10 +348,6 @@ impl ThreadStore {
})
}
- pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
- self.context_server_manager.clone()
- }
-
pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
&self.prompt_store
}
@@ -494,11 +484,17 @@ impl ThreadStore {
});
if profile.enable_all_context_servers {
- for context_server in self.context_server_manager.read(cx).all_servers() {
+ for context_server_id in self
+ .project
+ .read(cx)
+ .context_server_store()
+ .read(cx)
+ .all_server_ids()
+ {
self.tools.update(cx, |tools, cx| {
tools.enable_source(
ToolSource::ContextServer {
- id: context_server.id().into(),
+ id: context_server_id.0.into(),
},
cx,
);
@@ -541,7 +537,7 @@ impl ThreadStore {
fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
cx.subscribe(
- &self.context_server_manager.clone(),
+ &self.project.read(cx).context_server_store(),
Self::handle_context_server_event,
)
.detach();
@@ -549,18 +545,19 @@ impl ThreadStore {
fn handle_context_server_event(
&mut self,
- context_server_manager: Entity<ContextServerManager>,
- event: &context_server::manager::Event,
+ context_server_store: Entity<ContextServerStore>,
+ event: &project::context_server_store::Event,
cx: &mut Context<Self>,
) {
let tool_working_set = self.tools.clone();
match event {
- context_server::manager::Event::ServerStatusChanged { server_id, status } => {
+ project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
match status {
- Some(ContextServerStatus::Running) => {
- if let Some(server) = context_server_manager.read(cx).get_server(server_id)
+ ContextServerStatus::Running => {
+ if let Some(server) =
+ context_server_store.read(cx).get_running_server(server_id)
{
- let context_server_manager = context_server_manager.clone();
+ let context_server_manager = context_server_store.clone();
cx.spawn({
let server = server.clone();
let server_id = server_id.clone();
@@ -608,7 +605,7 @@ impl ThreadStore {
.detach();
}
}
- None => {
+ ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
tool_working_set.update(cx, |tool_working_set, _| {
tool_working_set.remove(&tool_ids);
@@ -31,7 +31,6 @@ async-watch.workspace = true
client.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
-context_server.workspace = true
db.workspace = true
editor.workspace = true
feature_flags.workspace = true
@@ -106,7 +106,6 @@ pub fn init(
assistant_slash_command::init(cx);
assistant_tool::init(cx);
assistant_panel::init(cx);
- context_server::init(cx);
register_slash_commands(cx);
inline_assistant::init(
@@ -1192,21 +1192,19 @@ impl AssistantPanel {
fn restart_context_servers(
workspace: &mut Workspace,
- _action: &context_server::Restart,
+ _action: &project::context_server_store::Restart,
_: &mut Window,
cx: &mut Context<Workspace>,
) {
- let Some(assistant_panel) = workspace.panel::<AssistantPanel>(cx) else {
- return;
- };
-
- assistant_panel.update(cx, |assistant_panel, cx| {
- assistant_panel
- .context_store
- .update(cx, |context_store, cx| {
- context_store.restart_context_servers(cx);
- });
- });
+ workspace
+ .project()
+ .read(cx)
+ .context_server_store()
+ .update(cx, |store, cx| {
+ for server in store.running_servers() {
+ store.restart_server(&server.id(), cx).log_err();
+ }
+ });
}
}
@@ -7,15 +7,17 @@ use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
use client::{Client, TypedEnvelope, proto, telemetry::Telemetry};
use clock::ReplicaId;
use collections::HashMap;
-use context_server::ContextServerDescriptorRegistry;
-use context_server::manager::{ContextServerManager, ContextServerStatus};
+use context_server::ContextServerId;
use fs::{Fs, RemoveOptions};
use futures::StreamExt;
use fuzzy::StringMatchCandidate;
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
use language::LanguageRegistry;
use paths::contexts_dir;
-use project::Project;
+use project::{
+ Project,
+ context_server_store::{ContextServerStatus, ContextServerStore},
+};
use prompt_store::PromptBuilder;
use regex::Regex;
use rpc::AnyProtoClient;
@@ -40,8 +42,7 @@ pub struct RemoteContextMetadata {
pub struct ContextStore {
contexts: Vec<ContextHandle>,
contexts_metadata: Vec<SavedContextMetadata>,
- context_server_manager: Entity<ContextServerManager>,
- context_server_slash_command_ids: HashMap<Arc<str>, Vec<SlashCommandId>>,
+ context_server_slash_command_ids: HashMap<ContextServerId, Vec<SlashCommandId>>,
host_contexts: Vec<RemoteContextMetadata>,
fs: Arc<dyn Fs>,
languages: Arc<LanguageRegistry>,
@@ -98,15 +99,9 @@ impl ContextStore {
let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
let this = cx.new(|cx: &mut Context<Self>| {
- let context_server_factory_registry =
- ContextServerDescriptorRegistry::default_global(cx);
- let context_server_manager = cx.new(|cx| {
- ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
- });
let mut this = Self {
contexts: Vec::new(),
contexts_metadata: Vec::new(),
- context_server_manager,
context_server_slash_command_ids: HashMap::default(),
host_contexts: Vec::new(),
fs,
@@ -802,22 +797,9 @@ impl ContextStore {
})
}
- pub fn restart_context_servers(&mut self, cx: &mut Context<Self>) {
- cx.update_entity(
- &self.context_server_manager,
- |context_server_manager, cx| {
- for server in context_server_manager.running_servers() {
- context_server_manager
- .restart_server(&server.id(), cx)
- .detach_and_log_err(cx);
- }
- },
- );
- }
-
fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
cx.subscribe(
- &self.context_server_manager.clone(),
+ &self.project.read(cx).context_server_store(),
Self::handle_context_server_event,
)
.detach();
@@ -825,16 +807,18 @@ impl ContextStore {
fn handle_context_server_event(
&mut self,
- context_server_manager: Entity<ContextServerManager>,
- event: &context_server::manager::Event,
+ context_server_manager: Entity<ContextServerStore>,
+ event: &project::context_server_store::Event,
cx: &mut Context<Self>,
) {
let slash_command_working_set = self.slash_commands.clone();
match event {
- context_server::manager::Event::ServerStatusChanged { server_id, status } => {
+ project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
match status {
- Some(ContextServerStatus::Running) => {
- if let Some(server) = context_server_manager.read(cx).get_server(server_id)
+ ContextServerStatus::Running => {
+ if let Some(server) = context_server_manager
+ .read(cx)
+ .get_running_server(server_id)
{
let context_server_manager = context_server_manager.clone();
cx.spawn({
@@ -858,7 +842,7 @@ impl ContextStore {
slash_command_working_set.insert(Arc::new(
assistant_slash_commands::ContextServerSlashCommand::new(
context_server_manager.clone(),
- &server,
+ server.id(),
prompt,
),
))
@@ -877,7 +861,7 @@ impl ContextStore {
.detach();
}
}
- None => {
+ ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
if let Some(slash_command_ids) =
self.context_server_slash_command_ids.remove(server_id)
{
@@ -4,12 +4,10 @@ use assistant_slash_command::{
SlashCommandOutputSection, SlashCommandResult,
};
use collections::HashMap;
-use context_server::{
- manager::{ContextServer, ContextServerManager},
- types::Prompt,
-};
+use context_server::{ContextServerId, types::Prompt};
use gpui::{App, Entity, Task, WeakEntity, Window};
use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
+use project::context_server_store::ContextServerStore;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use text::LineEnding;
@@ -19,21 +17,17 @@ use workspace::Workspace;
use crate::create_label_for_command;
pub struct ContextServerSlashCommand {
- server_manager: Entity<ContextServerManager>,
- server_id: Arc<str>,
+ store: Entity<ContextServerStore>,
+ server_id: ContextServerId,
prompt: Prompt,
}
impl ContextServerSlashCommand {
- pub fn new(
- server_manager: Entity<ContextServerManager>,
- server: &Arc<ContextServer>,
- prompt: Prompt,
- ) -> Self {
+ pub fn new(store: Entity<ContextServerStore>, id: ContextServerId, prompt: Prompt) -> Self {
Self {
- server_id: server.id(),
+ server_id: id,
prompt,
- server_manager,
+ store,
}
}
}
@@ -88,7 +82,7 @@ impl SlashCommand for ContextServerSlashCommand {
let server_id = self.server_id.clone();
let prompt_name = self.prompt.name.clone();
- if let Some(server) = self.server_manager.read(cx).get_server(&server_id) {
+ if let Some(server) = self.store.read(cx).get_running_server(&server_id) {
cx.foreground_executor().spawn(async move {
let Some(protocol) = server.client() else {
return Err(anyhow!("Context server not initialized"));
@@ -142,8 +136,8 @@ impl SlashCommand for ContextServerSlashCommand {
Err(e) => return Task::ready(Err(e)),
};
- let manager = self.server_manager.read(cx);
- if let Some(server) = manager.get_server(&server_id) {
+ let store = self.store.read(cx);
+ if let Some(server) = store.get_running_server(&server_id) {
cx.foreground_executor().spawn(async move {
let Some(protocol) = server.client() else {
return Err(anyhow!("Context server not initialized"));
@@ -6709,8 +6709,6 @@ async fn test_context_collaboration_with_reconnect(
assert_eq!(project.collaborators().len(), 1);
});
- cx_a.update(context_server::init);
- cx_b.update(context_server::init);
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context_store_a = cx_a
.update(|cx| {
@@ -709,6 +709,7 @@ impl TestClient {
worktree
.read_with(cx, |tree, _| tree.as_local().unwrap().scan_complete())
.await;
+ cx.run_until_parked();
(project, worktree.read_with(cx, |tree, _| tree.id()))
}
@@ -13,28 +13,17 @@ path = "src/context_server.rs"
[dependencies]
anyhow.workspace = true
-assistant_tool.workspace = true
async-trait.workspace = true
collections.workspace = true
-command_palette_hooks.workspace = true
-context_server_settings.workspace = true
-extension.workspace = true
futures.workspace = true
gpui.workspace = true
-icons.workspace = true
-language_model.workspace = true
log.workspace = true
parking_lot.workspace = true
postage.workspace = true
-project.workspace = true
+schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
-settings.workspace = true
smol.workspace = true
url = { workspace = true, features = ["serde"] }
util.workspace = true
workspace-hack.workspace = true
-
-[dev-dependencies]
-gpui = { workspace = true, features = ["test-support"] }
-project = { workspace = true, features = ["test-support"] }
@@ -40,7 +40,7 @@ pub enum RequestId {
Str(String),
}
-pub struct Client {
+pub(crate) struct Client {
server_id: ContextServerId,
next_id: AtomicI32,
outbound_tx: channel::Sender<String>,
@@ -59,7 +59,7 @@ pub struct Client {
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
-pub struct ContextServerId(pub Arc<str>);
+pub(crate) struct ContextServerId(pub Arc<str>);
fn is_null_value<T: Serialize>(value: &T) -> bool {
if let Ok(Value::Null) = serde_json::to_value(value) {
@@ -367,6 +367,7 @@ impl Client {
Ok(())
}
+ #[allow(unused)]
pub fn on_notification<F>(&self, method: &'static str, f: F)
where
F: 'static + Send + FnMut(Value, AsyncApp),
@@ -375,14 +376,6 @@ impl Client {
.lock()
.insert(method, Box::new(f));
}
-
- pub fn name(&self) -> &str {
- &self.name
- }
-
- pub fn server_id(&self) -> ContextServerId {
- self.server_id.clone()
- }
}
impl fmt::Display for ContextServerId {
@@ -1,30 +1,117 @@
pub mod client;
-mod context_server_tool;
-mod extension_context_server;
-pub mod manager;
pub mod protocol;
-mod registry;
-mod transport;
+pub mod transport;
pub mod types;
-use command_palette_hooks::CommandPaletteFilter;
-pub use context_server_settings::{ContextServerSettings, ServerCommand, ServerConfig};
-use gpui::{App, actions};
+use std::fmt::Display;
+use std::path::Path;
+use std::sync::Arc;
-pub use crate::context_server_tool::ContextServerTool;
-pub use crate::registry::ContextServerDescriptorRegistry;
+use anyhow::Result;
+use client::Client;
+use collections::HashMap;
+use gpui::AsyncApp;
+use parking_lot::RwLock;
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
-actions!(context_servers, [Restart]);
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub struct ContextServerId(pub Arc<str>);
-/// The namespace for the context servers actions.
-pub const CONTEXT_SERVERS_NAMESPACE: &'static str = "context_servers";
+impl Display for ContextServerId {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}", self.0)
+ }
+}
+
+#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
+pub struct ContextServerCommand {
+ pub path: String,
+ pub args: Vec<String>,
+ pub env: Option<HashMap<String, String>>,
+}
+
+enum ContextServerTransport {
+ Stdio(ContextServerCommand),
+ Custom(Arc<dyn crate::transport::Transport>),
+}
+
+pub struct ContextServer {
+ id: ContextServerId,
+ client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
+ configuration: ContextServerTransport,
+}
+
+impl ContextServer {
+ pub fn stdio(id: ContextServerId, command: ContextServerCommand) -> Self {
+ Self {
+ id,
+ client: RwLock::new(None),
+ configuration: ContextServerTransport::Stdio(command),
+ }
+ }
+
+ pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
+ Self {
+ id,
+ client: RwLock::new(None),
+ configuration: ContextServerTransport::Custom(transport),
+ }
+ }
+
+ pub fn id(&self) -> ContextServerId {
+ self.id.clone()
+ }
+
+ pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
+ self.client.read().clone()
+ }
+
+ pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
+ let client = match &self.configuration {
+ ContextServerTransport::Stdio(command) => Client::stdio(
+ client::ContextServerId(self.id.0.clone()),
+ client::ModelContextServerBinary {
+ executable: Path::new(&command.path).to_path_buf(),
+ args: command.args.clone(),
+ env: command.env.clone(),
+ },
+ cx.clone(),
+ )?,
+ ContextServerTransport::Custom(transport) => Client::new(
+ client::ContextServerId(self.id.0.clone()),
+ self.id().0,
+ transport.clone(),
+ cx.clone(),
+ )?,
+ };
+ self.initialize(client).await
+ }
+
+ async fn initialize(&self, client: Client) -> Result<()> {
+ log::info!("starting context server {}", self.id);
+ let protocol = crate::protocol::ModelContextProtocol::new(client);
+ let client_info = types::Implementation {
+ name: "Zed".to_string(),
+ version: env!("CARGO_PKG_VERSION").to_string(),
+ };
+ let initialized_protocol = protocol.initialize(client_info).await?;
+
+ log::debug!(
+ "context server {} initialized: {:?}",
+ self.id,
+ initialized_protocol.initialize,
+ );
-pub fn init(cx: &mut App) {
- context_server_settings::init(cx);
- ContextServerDescriptorRegistry::default_global(cx);
- extension_context_server::init(cx);
+ *self.client.write() = Some(Arc::new(initialized_protocol));
+ Ok(())
+ }
- CommandPaletteFilter::update_global(cx, |filter, _cx| {
- filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE);
- });
+ pub fn stop(&self) -> Result<()> {
+ let mut client = self.client.write();
+ if let Some(protocol) = client.take() {
+ drop(protocol);
+ }
+ Ok(())
+ }
}
@@ -1,584 +0,0 @@
-//! This module implements a context server management system for Zed.
-//!
-//! It provides functionality to:
-//! - Define and load context server settings
-//! - Manage individual context servers (start, stop, restart)
-//! - Maintain a global manager for all context servers
-//!
-//! Key components:
-//! - `ContextServerSettings`: Defines the structure for server configurations
-//! - `ContextServer`: Represents an individual context server
-//! - `ContextServerManager`: Manages multiple context servers
-//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
-//!
-//! The module also includes initialization logic to set up the context server system
-//! and react to changes in settings.
-
-use std::path::Path;
-use std::sync::Arc;
-
-use anyhow::{Result, bail};
-use collections::HashMap;
-use command_palette_hooks::CommandPaletteFilter;
-use gpui::{AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity};
-use log;
-use parking_lot::RwLock;
-use project::Project;
-use settings::{Settings, SettingsStore};
-use util::ResultExt as _;
-
-use crate::transport::Transport;
-use crate::{ContextServerSettings, ServerConfig};
-
-use crate::{
- CONTEXT_SERVERS_NAMESPACE, ContextServerDescriptorRegistry,
- client::{self, Client},
- types,
-};
-
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
-pub enum ContextServerStatus {
- Starting,
- Running,
- Error(Arc<str>),
-}
-
-pub struct ContextServer {
- pub id: Arc<str>,
- pub config: Arc<ServerConfig>,
- pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
- transport: Option<Arc<dyn Transport>>,
-}
-
-impl ContextServer {
- pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self {
- Self {
- id,
- config,
- client: RwLock::new(None),
- transport: None,
- }
- }
-
- #[cfg(any(test, feature = "test-support"))]
- pub fn test(id: Arc<str>, transport: Arc<dyn crate::transport::Transport>) -> Arc<Self> {
- Arc::new(Self {
- id,
- client: RwLock::new(None),
- config: Arc::new(ServerConfig::default()),
- transport: Some(transport),
- })
- }
-
- pub fn id(&self) -> Arc<str> {
- self.id.clone()
- }
-
- pub fn config(&self) -> Arc<ServerConfig> {
- self.config.clone()
- }
-
- pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
- self.client.read().clone()
- }
-
- pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
- let client = if let Some(transport) = self.transport.clone() {
- Client::new(
- client::ContextServerId(self.id.clone()),
- self.id(),
- transport,
- cx.clone(),
- )?
- } else {
- let Some(command) = &self.config.command else {
- bail!("no command specified for server {}", self.id);
- };
- Client::stdio(
- client::ContextServerId(self.id.clone()),
- client::ModelContextServerBinary {
- executable: Path::new(&command.path).to_path_buf(),
- args: command.args.clone(),
- env: command.env.clone(),
- },
- cx.clone(),
- )?
- };
- self.initialize(client).await
- }
-
- async fn initialize(&self, client: Client) -> Result<()> {
- log::info!("starting context server {}", self.id);
- let protocol = crate::protocol::ModelContextProtocol::new(client);
- let client_info = types::Implementation {
- name: "Zed".to_string(),
- version: env!("CARGO_PKG_VERSION").to_string(),
- };
- let initialized_protocol = protocol.initialize(client_info).await?;
-
- log::debug!(
- "context server {} initialized: {:?}",
- self.id,
- initialized_protocol.initialize,
- );
-
- *self.client.write() = Some(Arc::new(initialized_protocol));
- Ok(())
- }
-
- pub fn stop(&self) -> Result<()> {
- let mut client = self.client.write();
- if let Some(protocol) = client.take() {
- drop(protocol);
- }
- Ok(())
- }
-}
-
-pub struct ContextServerManager {
- servers: HashMap<Arc<str>, Arc<ContextServer>>,
- server_status: HashMap<Arc<str>, ContextServerStatus>,
- project: Entity<Project>,
- registry: Entity<ContextServerDescriptorRegistry>,
- update_servers_task: Option<Task<Result<()>>>,
- needs_server_update: bool,
- _subscriptions: Vec<Subscription>,
-}
-
-pub enum Event {
- ServerStatusChanged {
- server_id: Arc<str>,
- status: Option<ContextServerStatus>,
- },
-}
-
-impl EventEmitter<Event> for ContextServerManager {}
-
-impl ContextServerManager {
- pub fn new(
- registry: Entity<ContextServerDescriptorRegistry>,
- project: Entity<Project>,
- cx: &mut Context<Self>,
- ) -> Self {
- let mut this = Self {
- _subscriptions: vec![
- cx.observe(®istry, |this, _registry, cx| {
- this.available_context_servers_changed(cx);
- }),
- cx.observe_global::<SettingsStore>(|this, cx| {
- this.available_context_servers_changed(cx);
- }),
- ],
- project,
- registry,
- needs_server_update: false,
- servers: HashMap::default(),
- server_status: HashMap::default(),
- update_servers_task: None,
- };
- this.available_context_servers_changed(cx);
- this
- }
-
- fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
- if self.update_servers_task.is_some() {
- self.needs_server_update = true;
- } else {
- self.update_servers_task = Some(cx.spawn(async move |this, cx| {
- this.update(cx, |this, _| {
- this.needs_server_update = false;
- })?;
-
- if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
- log::error!("Error maintaining context servers: {}", err);
- }
-
- this.update(cx, |this, cx| {
- let has_any_context_servers = !this.running_servers().is_empty();
- if has_any_context_servers {
- CommandPaletteFilter::update_global(cx, |filter, _cx| {
- filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
- });
- }
-
- this.update_servers_task.take();
- if this.needs_server_update {
- this.available_context_servers_changed(cx);
- }
- })?;
-
- Ok(())
- }));
- }
- }
-
- pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
- self.servers
- .get(id)
- .filter(|server| server.client().is_some())
- .cloned()
- }
-
- pub fn status_for_server(&self, id: &str) -> Option<ContextServerStatus> {
- self.server_status.get(id).cloned()
- }
-
- pub fn start_server(
- &self,
- server: Arc<ContextServer>,
- cx: &mut Context<Self>,
- ) -> Task<Result<()>> {
- cx.spawn(async move |this, cx| Self::run_server(this, server, cx).await)
- }
-
- pub fn stop_server(
- &mut self,
- server: Arc<ContextServer>,
- cx: &mut Context<Self>,
- ) -> Result<()> {
- server.stop().log_err();
- self.update_server_status(server.id().clone(), None, cx);
- Ok(())
- }
-
- pub fn restart_server(&mut self, id: &Arc<str>, cx: &mut Context<Self>) -> Task<Result<()>> {
- let id = id.clone();
- cx.spawn(async move |this, cx| {
- if let Some(server) = this.update(cx, |this, _cx| this.servers.remove(&id))? {
- let config = server.config();
-
- this.update(cx, |this, cx| this.stop_server(server, cx))??;
- let new_server = Arc::new(ContextServer::new(id.clone(), config));
- Self::run_server(this, new_server, cx).await?;
- }
- Ok(())
- })
- }
-
- pub fn all_servers(&self) -> Vec<Arc<ContextServer>> {
- self.servers.values().cloned().collect()
- }
-
- pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
- self.servers
- .values()
- .filter(|server| server.client().is_some())
- .cloned()
- .collect()
- }
-
- async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
- let mut desired_servers = HashMap::default();
-
- let (registry, project) = this.update(cx, |this, cx| {
- let location = this
- .project
- .read(cx)
- .visible_worktrees(cx)
- .next()
- .map(|worktree| settings::SettingsLocation {
- worktree_id: worktree.read(cx).id(),
- path: Path::new(""),
- });
- let settings = ContextServerSettings::get(location, cx);
- desired_servers = settings.context_servers.clone();
-
- (this.registry.clone(), this.project.clone())
- })?;
-
- for (id, descriptor) in
- registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
- {
- let config = desired_servers.entry(id).or_default();
- if config.command.is_none() {
- if let Some(extension_command) =
- descriptor.command(project.clone(), &cx).await.log_err()
- {
- config.command = Some(extension_command);
- }
- }
- }
-
- let mut servers_to_start = HashMap::default();
- let mut servers_to_stop = HashMap::default();
-
- this.update(cx, |this, _cx| {
- this.servers.retain(|id, server| {
- if desired_servers.contains_key(id) {
- true
- } else {
- servers_to_stop.insert(id.clone(), server.clone());
- false
- }
- });
-
- for (id, config) in desired_servers {
- let existing_config = this.servers.get(&id).map(|server| server.config());
- if existing_config.as_deref() != Some(&config) {
- let server = Arc::new(ContextServer::new(id.clone(), Arc::new(config)));
- servers_to_start.insert(id.clone(), server.clone());
- if let Some(old_server) = this.servers.remove(&id) {
- servers_to_stop.insert(id, old_server);
- }
- }
- }
- })?;
-
- for (_, server) in servers_to_stop {
- this.update(cx, |this, cx| this.stop_server(server, cx).ok())?;
- }
-
- for (_, server) in servers_to_start {
- Self::run_server(this.clone(), server, cx).await.ok();
- }
-
- Ok(())
- }
-
- async fn run_server(
- this: WeakEntity<Self>,
- server: Arc<ContextServer>,
- cx: &mut AsyncApp,
- ) -> Result<()> {
- let id = server.id();
-
- this.update(cx, |this, cx| {
- this.update_server_status(id.clone(), Some(ContextServerStatus::Starting), cx);
- this.servers.insert(id.clone(), server.clone());
- })?;
-
- match server.start(&cx).await {
- Ok(_) => {
- log::debug!("`{}` context server started", id);
- this.update(cx, |this, cx| {
- this.update_server_status(id.clone(), Some(ContextServerStatus::Running), cx)
- })?;
- Ok(())
- }
- Err(err) => {
- log::error!("`{}` context server failed to start\n{}", id, err);
- this.update(cx, |this, cx| {
- this.update_server_status(
- id.clone(),
- Some(ContextServerStatus::Error(err.to_string().into())),
- cx,
- )
- })?;
- Err(err)
- }
- }
- }
-
- fn update_server_status(
- &mut self,
- id: Arc<str>,
- status: Option<ContextServerStatus>,
- cx: &mut Context<Self>,
- ) {
- if let Some(status) = status.clone() {
- self.server_status.insert(id.clone(), status);
- } else {
- self.server_status.remove(&id);
- }
-
- cx.emit(Event::ServerStatusChanged {
- server_id: id,
- status,
- });
- }
-}
-
-#[cfg(test)]
-mod tests {
- use std::pin::Pin;
-
- use crate::types::{
- Implementation, InitializeResponse, ProtocolVersion, RequestType, ServerCapabilities,
- };
-
- use super::*;
- use futures::{Stream, StreamExt as _, lock::Mutex};
- use gpui::{AppContext as _, TestAppContext};
- use project::FakeFs;
- use serde_json::json;
- use util::path;
-
- #[gpui::test]
- async fn test_context_server_status(cx: &mut TestAppContext) {
- init_test_settings(cx);
- let project = create_test_project(cx, json!({"code.rs": ""})).await;
-
- let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
- let manager = cx.new(|cx| ContextServerManager::new(registry.clone(), project, cx));
-
- let server_1_id: Arc<str> = "mcp-1".into();
- let server_2_id: Arc<str> = "mcp-2".into();
-
- let transport_1 = Arc::new(FakeTransport::new(
- |_, request_type, _| match request_type {
- Some(RequestType::Initialize) => {
- Some(create_initialize_response("mcp-1".to_string()))
- }
- _ => None,
- },
- ));
-
- let transport_2 = Arc::new(FakeTransport::new(
- |_, request_type, _| match request_type {
- Some(RequestType::Initialize) => {
- Some(create_initialize_response("mcp-2".to_string()))
- }
- _ => None,
- },
- ));
-
- let server_1 = ContextServer::test(server_1_id.clone(), transport_1.clone());
- let server_2 = ContextServer::test(server_2_id.clone(), transport_2.clone());
-
- manager
- .update(cx, |manager, cx| manager.start_server(server_1, cx))
- .await
- .unwrap();
-
- cx.update(|cx| {
- assert_eq!(
- manager.read(cx).status_for_server(&server_1_id),
- Some(ContextServerStatus::Running)
- );
- assert_eq!(manager.read(cx).status_for_server(&server_2_id), None);
- });
-
- manager
- .update(cx, |manager, cx| manager.start_server(server_2.clone(), cx))
- .await
- .unwrap();
-
- cx.update(|cx| {
- assert_eq!(
- manager.read(cx).status_for_server(&server_1_id),
- Some(ContextServerStatus::Running)
- );
- assert_eq!(
- manager.read(cx).status_for_server(&server_2_id),
- Some(ContextServerStatus::Running)
- );
- });
-
- manager
- .update(cx, |manager, cx| manager.stop_server(server_2, cx))
- .unwrap();
-
- cx.update(|cx| {
- assert_eq!(
- manager.read(cx).status_for_server(&server_1_id),
- Some(ContextServerStatus::Running)
- );
- assert_eq!(manager.read(cx).status_for_server(&server_2_id), None);
- });
- }
-
- async fn create_test_project(
- cx: &mut TestAppContext,
- files: serde_json::Value,
- ) -> Entity<Project> {
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(path!("/test"), files).await;
- Project::test(fs, [path!("/test").as_ref()], cx).await
- }
-
- fn init_test_settings(cx: &mut TestAppContext) {
- cx.update(|cx| {
- let settings_store = SettingsStore::test(cx);
- cx.set_global(settings_store);
- Project::init_settings(cx);
- ContextServerSettings::register(cx);
- });
- }
-
- fn create_initialize_response(server_name: String) -> serde_json::Value {
- serde_json::to_value(&InitializeResponse {
- protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
- server_info: Implementation {
- name: server_name,
- version: "1.0.0".to_string(),
- },
- capabilities: ServerCapabilities::default(),
- meta: None,
- })
- .unwrap()
- }
-
- struct FakeTransport {
- on_request: Arc<
- dyn Fn(u64, Option<RequestType>, serde_json::Value) -> Option<serde_json::Value>
- + Send
- + Sync,
- >,
- tx: futures::channel::mpsc::UnboundedSender<String>,
- rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
- }
-
- impl FakeTransport {
- fn new(
- on_request: impl Fn(
- u64,
- Option<RequestType>,
- serde_json::Value,
- ) -> Option<serde_json::Value>
- + 'static
- + Send
- + Sync,
- ) -> Self {
- let (tx, rx) = futures::channel::mpsc::unbounded();
- Self {
- on_request: Arc::new(on_request),
- tx,
- rx: Arc::new(Mutex::new(rx)),
- }
- }
- }
-
- #[async_trait::async_trait]
- impl Transport for FakeTransport {
- async fn send(&self, message: String) -> Result<()> {
- if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
- let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
-
- if let Some(method) = msg.get("method") {
- let request_type = method
- .as_str()
- .and_then(|method| types::RequestType::try_from(method).ok());
- if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) {
- let response = serde_json::json!({
- "jsonrpc": "2.0",
- "id": id,
- "result": payload
- });
-
- self.tx
- .unbounded_send(response.to_string())
- .map_err(|e| anyhow::anyhow!("Failed to send message: {}", e))?;
- }
- }
- }
- Ok(())
- }
-
- fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
- let rx = self.rx.clone();
- Box::pin(futures::stream::unfold(rx, |rx| async move {
- let mut rx_guard = rx.lock().await;
- if let Some(message) = rx_guard.next().await {
- drop(rx_guard);
- Some((message, rx))
- } else {
- None
- }
- }))
- }
-
- fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
- Box::pin(futures::stream::empty())
- }
- }
-}
@@ -16,7 +16,7 @@ pub struct ModelContextProtocol {
}
impl ModelContextProtocol {
- pub fn new(inner: Client) -> Self {
+ pub(crate) fn new(inner: Client) -> Self {
Self { inner }
}
@@ -610,7 +610,7 @@ pub enum ToolResponseContent {
Resource { resource: ResourceContents },
}
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ListToolsResponse {
pub tools: Vec<Tool>,
@@ -1,22 +0,0 @@
-[package]
-name = "context_server_settings"
-version = "0.1.0"
-edition.workspace = true
-publish.workspace = true
-license = "GPL-3.0-or-later"
-
-[lints]
-workspace = true
-
-[lib]
-path = "src/context_server_settings.rs"
-
-[dependencies]
-anyhow.workspace = true
-collections.workspace = true
-gpui.workspace = true
-schemars.workspace = true
-serde.workspace = true
-serde_json.workspace = true
-settings.workspace = true
-workspace-hack.workspace = true
@@ -1 +0,0 @@
-../../LICENSE-GPL
@@ -1,99 +0,0 @@
-use std::sync::Arc;
-
-use collections::HashMap;
-use gpui::App;
-use schemars::JsonSchema;
-use schemars::r#gen::SchemaGenerator;
-use schemars::schema::{InstanceType, Schema, SchemaObject};
-use serde::{Deserialize, Serialize};
-use settings::{Settings, SettingsSources};
-
-pub fn init(cx: &mut App) {
- ContextServerSettings::register(cx);
-}
-
-#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug, Default)]
-pub struct ServerConfig {
- /// The command to run this context server.
- ///
- /// This will override the command set by an extension.
- pub command: Option<ServerCommand>,
- /// The settings for this context server.
- ///
- /// Consult the documentation for the context server to see what settings
- /// are supported.
- #[schemars(schema_with = "server_config_settings_json_schema")]
- pub settings: Option<serde_json::Value>,
-}
-
-fn server_config_settings_json_schema(_generator: &mut SchemaGenerator) -> Schema {
- Schema::Object(SchemaObject {
- instance_type: Some(InstanceType::Object.into()),
- ..Default::default()
- })
-}
-
-#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
-pub struct ServerCommand {
- pub path: String,
- pub args: Vec<String>,
- pub env: Option<HashMap<String, String>>,
-}
-
-#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
-pub struct ContextServerSettings {
- /// Settings for context servers used in the Assistant.
- #[serde(default)]
- pub context_servers: HashMap<Arc<str>, ServerConfig>,
-}
-
-impl Settings for ContextServerSettings {
- const KEY: Option<&'static str> = None;
-
- type FileContent = Self;
-
- fn load(
- sources: SettingsSources<Self::FileContent>,
- _: &mut gpui::App,
- ) -> anyhow::Result<Self> {
- sources.json_merge()
- }
-
- fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
- // we don't handle "inputs" replacement strings, see perplexity-key in this example:
- // https://code.visualstudio.com/docs/copilot/chat/mcp-servers#_configuration-example
- #[derive(Deserialize)]
- struct VsCodeServerCommand {
- command: String,
- args: Option<Vec<String>>,
- env: Option<HashMap<String, String>>,
- // note: we don't support envFile and type
- }
- impl From<VsCodeServerCommand> for ServerCommand {
- fn from(cmd: VsCodeServerCommand) -> Self {
- Self {
- path: cmd.command,
- args: cmd.args.unwrap_or_default(),
- env: cmd.env,
- }
- }
- }
- if let Some(mcp) = vscode.read_value("mcp").and_then(|v| v.as_object()) {
- current
- .context_servers
- .extend(mcp.iter().filter_map(|(k, v)| {
- Some((
- k.clone().into(),
- ServerConfig {
- command: Some(
- serde_json::from_value::<VsCodeServerCommand>(v.clone())
- .ok()?
- .into(),
- ),
- settings: None,
- },
- ))
- }));
- }
- }
-}
@@ -30,7 +30,6 @@ chrono.workspace = true
clap.workspace = true
client.workspace = true
collections.workspace = true
-context_server.workspace = true
dirs.workspace = true
dotenv.workspace = true
env_logger.workspace = true
@@ -426,7 +426,6 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
languages::init(languages.clone(), node_runtime.clone(), cx);
- context_server::init(cx);
prompt_store::init(cx);
let stdout_is_a_pty = false;
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
@@ -362,6 +362,8 @@ pub trait ExtensionContextServerProxy: Send + Sync + 'static {
server_id: Arc<str>,
cx: &mut App,
);
+
+ fn unregister_context_server(&self, server_id: Arc<str>, cx: &mut App);
}
impl ExtensionContextServerProxy for ExtensionHostProxy {
@@ -377,6 +379,14 @@ impl ExtensionContextServerProxy for ExtensionHostProxy {
proxy.register_context_server(extension, server_id, cx)
}
+
+ fn unregister_context_server(&self, server_id: Arc<str>, cx: &mut App) {
+ let Some(proxy) = self.context_server_proxy.read().clone() else {
+ return;
+ };
+
+ proxy.unregister_context_server(server_id, cx)
+ }
}
pub trait ExtensionIndexedDocsProviderProxy: Send + Sync + 'static {
@@ -22,7 +22,6 @@ async-tar.workspace = true
async-trait.workspace = true
client.workspace = true
collections.workspace = true
-context_server_settings.workspace = true
extension.workspace = true
fs.workspace = true
futures.workspace = true
@@ -1130,6 +1130,10 @@ impl ExtensionStore {
.remove_language_server(&language, language_server_name);
}
}
+
+ for (server_id, _) in extension.manifest.context_servers.iter() {
+ self.proxy.unregister_context_server(server_id.clone(), cx);
+ }
}
self.wasm_extensions
@@ -7,7 +7,6 @@ use anyhow::{Context, Result, anyhow, bail};
use async_compression::futures::bufread::GzipDecoder;
use async_tar::Archive;
use async_trait::async_trait;
-use context_server_settings::ContextServerSettings;
use extension::{
ExtensionLanguageServerProxy, KeyValueStoreDelegate, ProjectDelegate, WorktreeDelegate,
};
@@ -676,21 +675,23 @@ impl ExtensionImports for WasmState {
})?)
}
"context_servers" => {
- let settings = key
+ let configuration = key
.and_then(|key| {
- ContextServerSettings::get(location, cx)
+ ProjectSettings::get(location, cx)
.context_servers
.get(key.as_str())
})
.cloned()
.unwrap_or_default();
Ok(serde_json::to_string(&settings::ContextServerSettings {
- command: settings.command.map(|command| settings::CommandSettings {
- path: Some(command.path),
- arguments: Some(command.args),
- env: command.env.map(|env| env.into_iter().collect()),
+ command: configuration.command.map(|command| {
+ settings::CommandSettings {
+ path: Some(command.path),
+ arguments: Some(command.args),
+ env: command.env.map(|env| env.into_iter().collect()),
+ }
}),
- settings: settings.settings,
+ settings: configuration.settings,
})?)
}
_ => {
@@ -36,6 +36,7 @@ circular-buffer.workspace = true
client.workspace = true
clock.workspace = true
collections.workspace = true
+context_server.workspace = true
dap.workspace = true
extension.workspace = true
fancy-regex.workspace = true
@@ -0,0 +1,1129 @@
+pub mod extension;
+pub mod registry;
+
+use std::{path::Path, sync::Arc};
+
+use anyhow::{Context as _, Result};
+use collections::{HashMap, HashSet};
+use context_server::{ContextServer, ContextServerId};
+use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
+use registry::ContextServerDescriptorRegistry;
+use settings::{Settings as _, SettingsStore};
+use util::ResultExt as _;
+
+use crate::{
+ project_settings::{ContextServerConfiguration, ProjectSettings},
+ worktree_store::WorktreeStore,
+};
+
+pub fn init(cx: &mut App) {
+ extension::init(cx);
+}
+
+actions!(context_server, [Restart]);
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub enum ContextServerStatus {
+ Starting,
+ Running,
+ Stopped,
+ Error(Arc<str>),
+}
+
+impl ContextServerStatus {
+ fn from_state(state: &ContextServerState) -> Self {
+ match state {
+ ContextServerState::Starting { .. } => ContextServerStatus::Starting,
+ ContextServerState::Running { .. } => ContextServerStatus::Running,
+ ContextServerState::Stopped { error, .. } => {
+ if let Some(error) = error {
+ ContextServerStatus::Error(error.clone())
+ } else {
+ ContextServerStatus::Stopped
+ }
+ }
+ }
+ }
+}
+
+enum ContextServerState {
+ Starting {
+ server: Arc<ContextServer>,
+ configuration: Arc<ContextServerConfiguration>,
+ _task: Task<()>,
+ },
+ Running {
+ server: Arc<ContextServer>,
+ configuration: Arc<ContextServerConfiguration>,
+ },
+ Stopped {
+ server: Arc<ContextServer>,
+ configuration: Arc<ContextServerConfiguration>,
+ error: Option<Arc<str>>,
+ },
+}
+
+impl ContextServerState {
+ pub fn server(&self) -> Arc<ContextServer> {
+ match self {
+ ContextServerState::Starting { server, .. } => server.clone(),
+ ContextServerState::Running { server, .. } => server.clone(),
+ ContextServerState::Stopped { server, .. } => server.clone(),
+ }
+ }
+
+ pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
+ match self {
+ ContextServerState::Starting { configuration, .. } => configuration.clone(),
+ ContextServerState::Running { configuration, .. } => configuration.clone(),
+ ContextServerState::Stopped { configuration, .. } => configuration.clone(),
+ }
+ }
+}
+
+pub type ContextServerFactory =
+ Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
+
+pub struct ContextServerStore {
+ servers: HashMap<ContextServerId, ContextServerState>,
+ worktree_store: Entity<WorktreeStore>,
+ registry: Entity<ContextServerDescriptorRegistry>,
+ update_servers_task: Option<Task<Result<()>>>,
+ context_server_factory: Option<ContextServerFactory>,
+ needs_server_update: bool,
+ _subscriptions: Vec<Subscription>,
+}
+
+pub enum Event {
+ ServerStatusChanged {
+ server_id: ContextServerId,
+ status: ContextServerStatus,
+ },
+}
+
+impl EventEmitter<Event> for ContextServerStore {}
+
+impl ContextServerStore {
+ pub fn new(worktree_store: Entity<WorktreeStore>, cx: &mut Context<Self>) -> Self {
+ Self::new_internal(
+ true,
+ None,
+ ContextServerDescriptorRegistry::default_global(cx),
+ worktree_store,
+ cx,
+ )
+ }
+
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn test(
+ registry: Entity<ContextServerDescriptorRegistry>,
+ worktree_store: Entity<WorktreeStore>,
+ cx: &mut Context<Self>,
+ ) -> Self {
+ Self::new_internal(false, None, registry, worktree_store, cx)
+ }
+
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn test_maintain_server_loop(
+ context_server_factory: ContextServerFactory,
+ registry: Entity<ContextServerDescriptorRegistry>,
+ worktree_store: Entity<WorktreeStore>,
+ cx: &mut Context<Self>,
+ ) -> Self {
+ Self::new_internal(
+ true,
+ Some(context_server_factory),
+ registry,
+ worktree_store,
+ cx,
+ )
+ }
+
+ fn new_internal(
+ maintain_server_loop: bool,
+ context_server_factory: Option<ContextServerFactory>,
+ registry: Entity<ContextServerDescriptorRegistry>,
+ worktree_store: Entity<WorktreeStore>,
+ cx: &mut Context<Self>,
+ ) -> Self {
+ let subscriptions = if maintain_server_loop {
+ vec![
+ cx.observe(®istry, |this, _registry, cx| {
+ this.available_context_servers_changed(cx);
+ }),
+ cx.observe_global::<SettingsStore>(|this, cx| {
+ this.available_context_servers_changed(cx);
+ }),
+ ]
+ } else {
+ Vec::new()
+ };
+
+ let mut this = Self {
+ _subscriptions: subscriptions,
+ worktree_store,
+ registry,
+ needs_server_update: false,
+ servers: HashMap::default(),
+ update_servers_task: None,
+ context_server_factory,
+ };
+ if maintain_server_loop {
+ this.available_context_servers_changed(cx);
+ }
+ this
+ }
+
+ pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
+ self.servers.get(id).map(|state| state.server())
+ }
+
+ pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
+ if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
+ Some(server.clone())
+ } else {
+ None
+ }
+ }
+
+ pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
+ self.servers.get(id).map(ContextServerStatus::from_state)
+ }
+
+ pub fn all_server_ids(&self) -> Vec<ContextServerId> {
+ self.servers.keys().cloned().collect()
+ }
+
+ pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
+ self.servers
+ .values()
+ .filter_map(|state| {
+ if let ContextServerState::Running { server, .. } = state {
+ Some(server.clone())
+ } else {
+ None
+ }
+ })
+ .collect()
+ }
+
+ pub fn start_server(
+ &mut self,
+ server: Arc<ContextServer>,
+ cx: &mut Context<Self>,
+ ) -> Result<()> {
+ let location = self
+ .worktree_store
+ .read(cx)
+ .visible_worktrees(cx)
+ .next()
+ .map(|worktree| settings::SettingsLocation {
+ worktree_id: worktree.read(cx).id(),
+ path: Path::new(""),
+ });
+ let settings = ProjectSettings::get(location, cx);
+ let configuration = settings
+ .context_servers
+ .get(&server.id().0)
+ .context("Failed to load context server configuration from settings")?
+ .clone();
+
+ self.run_server(server, Arc::new(configuration), cx);
+ Ok(())
+ }
+
+ pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
+ let Some(state) = self.servers.remove(id) else {
+ return Err(anyhow::anyhow!("Context server not found"));
+ };
+
+ let server = state.server();
+ let configuration = state.configuration();
+ let mut result = Ok(());
+ if let ContextServerState::Running { server, .. } = &state {
+ result = server.stop();
+ }
+ drop(state);
+
+ self.update_server_state(
+ id.clone(),
+ ContextServerState::Stopped {
+ configuration,
+ server,
+ error: None,
+ },
+ cx,
+ );
+
+ result
+ }
+
+ pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
+ if let Some(state) = self.servers.get(&id) {
+ let configuration = state.configuration();
+
+ self.stop_server(&state.server().id(), cx)?;
+ let new_server = self.create_context_server(id.clone(), configuration.clone())?;
+ self.run_server(new_server, configuration, cx);
+ }
+ Ok(())
+ }
+
+ fn run_server(
+ &mut self,
+ server: Arc<ContextServer>,
+ configuration: Arc<ContextServerConfiguration>,
+ cx: &mut Context<Self>,
+ ) {
+ let id = server.id();
+ if matches!(
+ self.servers.get(&id),
+ Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
+ ) {
+ self.stop_server(&id, cx).log_err();
+ }
+
+ let task = cx.spawn({
+ let id = server.id();
+ let server = server.clone();
+ let configuration = configuration.clone();
+ async move |this, cx| {
+ match server.clone().start(&cx).await {
+ Ok(_) => {
+ log::info!("Started {} context server", id);
+ debug_assert!(server.client().is_some());
+
+ this.update(cx, |this, cx| {
+ this.update_server_state(
+ id.clone(),
+ ContextServerState::Running {
+ server,
+ configuration,
+ },
+ cx,
+ )
+ })
+ .log_err()
+ }
+ Err(err) => {
+ log::error!("{} context server failed to start: {}", id, err);
+ this.update(cx, |this, cx| {
+ this.update_server_state(
+ id.clone(),
+ ContextServerState::Stopped {
+ configuration,
+ server,
+ error: Some(err.to_string().into()),
+ },
+ cx,
+ )
+ })
+ .log_err()
+ }
+ };
+ }
+ });
+
+ self.update_server_state(
+ id.clone(),
+ ContextServerState::Starting {
+ configuration,
+ _task: task,
+ server,
+ },
+ cx,
+ );
+ }
+
+ fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
+ let Some(state) = self.servers.remove(id) else {
+ return Err(anyhow::anyhow!("Context server not found"));
+ };
+ drop(state);
+ cx.emit(Event::ServerStatusChanged {
+ server_id: id.clone(),
+ status: ContextServerStatus::Stopped,
+ });
+ Ok(())
+ }
+
+ fn is_configuration_valid(&self, configuration: &ContextServerConfiguration) -> bool {
+ // Command must be some when we are running in stdio mode.
+ self.context_server_factory.as_ref().is_some() || configuration.command.is_some()
+ }
+
+ fn create_context_server(
+ &self,
+ id: ContextServerId,
+ configuration: Arc<ContextServerConfiguration>,
+ ) -> Result<Arc<ContextServer>> {
+ if let Some(factory) = self.context_server_factory.as_ref() {
+ Ok(factory(id, configuration))
+ } else {
+ let command = configuration
+ .command
+ .clone()
+ .context("Missing command to run context server")?;
+ Ok(Arc::new(ContextServer::stdio(id, command)))
+ }
+ }
+
+ fn update_server_state(
+ &mut self,
+ id: ContextServerId,
+ state: ContextServerState,
+ cx: &mut Context<Self>,
+ ) {
+ let status = ContextServerStatus::from_state(&state);
+ self.servers.insert(id.clone(), state);
+ cx.emit(Event::ServerStatusChanged {
+ server_id: id,
+ status,
+ });
+ }
+
+ fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
+ if self.update_servers_task.is_some() {
+ self.needs_server_update = true;
+ } else {
+ self.needs_server_update = false;
+ self.update_servers_task = Some(cx.spawn(async move |this, cx| {
+ if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
+ log::error!("Error maintaining context servers: {}", err);
+ }
+
+ this.update(cx, |this, cx| {
+ this.update_servers_task.take();
+ if this.needs_server_update {
+ this.available_context_servers_changed(cx);
+ }
+ })?;
+
+ Ok(())
+ }));
+ }
+ }
+
+ async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
+ let mut desired_servers = HashMap::default();
+
+ let (registry, worktree_store) = this.update(cx, |this, cx| {
+ let location = this
+ .worktree_store
+ .read(cx)
+ .visible_worktrees(cx)
+ .next()
+ .map(|worktree| settings::SettingsLocation {
+ worktree_id: worktree.read(cx).id(),
+ path: Path::new(""),
+ });
+ let settings = ProjectSettings::get(location, cx);
+ desired_servers = settings.context_servers.clone();
+
+ (this.registry.clone(), this.worktree_store.clone())
+ })?;
+
+ for (id, descriptor) in
+ registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
+ {
+ let config = desired_servers.entry(id.clone()).or_default();
+ if config.command.is_none() {
+ if let Some(extension_command) = descriptor
+ .command(worktree_store.clone(), &cx)
+ .await
+ .log_err()
+ {
+ config.command = Some(extension_command);
+ }
+ }
+ }
+
+ this.update(cx, |this, _| {
+ // Filter out configurations without commands, the user uninstalled an extension.
+ desired_servers.retain(|_, configuration| this.is_configuration_valid(configuration));
+ })?;
+
+ let mut servers_to_start = Vec::new();
+ let mut servers_to_remove = HashSet::default();
+ let mut servers_to_stop = HashSet::default();
+
+ this.update(cx, |this, _cx| {
+ for server_id in this.servers.keys() {
+ // All servers that are not in desired_servers should be removed from the store.
+ // E.g. this can happen if the user removed a server from the configuration,
+ // or the user uninstalled an extension.
+ if !desired_servers.contains_key(&server_id.0) {
+ servers_to_remove.insert(server_id.clone());
+ }
+ }
+
+ for (id, config) in desired_servers {
+ let id = ContextServerId(id.clone());
+
+ let existing_config = this.servers.get(&id).map(|state| state.configuration());
+ if existing_config.as_deref() != Some(&config) {
+ let config = Arc::new(config);
+ if let Some(server) = this
+ .create_context_server(id.clone(), config.clone())
+ .log_err()
+ {
+ servers_to_start.push((server, config));
+ if this.servers.contains_key(&id) {
+ servers_to_stop.insert(id);
+ }
+ }
+ }
+ }
+ })?;
+
+ for id in servers_to_stop {
+ this.update(cx, |this, cx| this.stop_server(&id, cx).ok())?;
+ }
+
+ for id in servers_to_remove {
+ this.update(cx, |this, cx| this.remove_server(&id, cx).ok())?;
+ }
+
+ for (server, config) in servers_to_start {
+ this.update(cx, |this, cx| this.run_server(server, config, cx))
+ .log_err();
+ }
+
+ Ok(())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{FakeFs, Project, project_settings::ProjectSettings};
+ use context_server::{
+ transport::Transport,
+ types::{
+ self, Implementation, InitializeResponse, ProtocolVersion, RequestType,
+ ServerCapabilities,
+ },
+ };
+ use futures::{Stream, StreamExt as _, lock::Mutex};
+ use gpui::{AppContext, BackgroundExecutor, TestAppContext, UpdateGlobal as _};
+ use serde_json::json;
+ use std::{cell::RefCell, pin::Pin, rc::Rc};
+ use util::path;
+
+ #[gpui::test]
+ async fn test_context_server_status(cx: &mut TestAppContext) {
+ const SERVER_1_ID: &'static str = "mcp-1";
+ const SERVER_2_ID: &'static str = "mcp-2";
+
+ let (_fs, project) = setup_context_server_test(
+ cx,
+ json!({"code.rs": ""}),
+ vec![
+ (SERVER_1_ID.into(), ContextServerConfiguration::default()),
+ (SERVER_2_ID.into(), ContextServerConfiguration::default()),
+ ],
+ )
+ .await;
+
+ let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
+ let store = cx.new(|cx| {
+ ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
+ });
+
+ let server_1_id = ContextServerId("mcp-1".into());
+ let server_2_id = ContextServerId("mcp-2".into());
+
+ let transport_1 =
+ Arc::new(FakeTransport::new(
+ cx.executor(),
+ |_, request_type, _| match request_type {
+ Some(RequestType::Initialize) => {
+ Some(create_initialize_response("mcp-1".to_string()))
+ }
+ _ => None,
+ },
+ ));
+
+ let transport_2 =
+ Arc::new(FakeTransport::new(
+ cx.executor(),
+ |_, request_type, _| match request_type {
+ Some(RequestType::Initialize) => {
+ Some(create_initialize_response("mcp-2".to_string()))
+ }
+ _ => None,
+ },
+ ));
+
+ let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone()));
+ let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone()));
+
+ store
+ .update(cx, |store, cx| store.start_server(server_1, cx))
+ .unwrap();
+
+ cx.run_until_parked();
+
+ cx.update(|cx| {
+ assert_eq!(
+ store.read(cx).status_for_server(&server_1_id),
+ Some(ContextServerStatus::Running)
+ );
+ assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
+ });
+
+ store
+ .update(cx, |store, cx| store.start_server(server_2.clone(), cx))
+ .unwrap();
+
+ cx.run_until_parked();
+
+ cx.update(|cx| {
+ assert_eq!(
+ store.read(cx).status_for_server(&server_1_id),
+ Some(ContextServerStatus::Running)
+ );
+ assert_eq!(
+ store.read(cx).status_for_server(&server_2_id),
+ Some(ContextServerStatus::Running)
+ );
+ });
+
+ store
+ .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
+ .unwrap();
+
+ cx.update(|cx| {
+ assert_eq!(
+ store.read(cx).status_for_server(&server_1_id),
+ Some(ContextServerStatus::Running)
+ );
+ assert_eq!(
+ store.read(cx).status_for_server(&server_2_id),
+ Some(ContextServerStatus::Stopped)
+ );
+ });
+ }
+
+ #[gpui::test]
+ async fn test_context_server_status_events(cx: &mut TestAppContext) {
+ const SERVER_1_ID: &'static str = "mcp-1";
+ const SERVER_2_ID: &'static str = "mcp-2";
+
+ let (_fs, project) = setup_context_server_test(
+ cx,
+ json!({"code.rs": ""}),
+ vec![
+ (SERVER_1_ID.into(), ContextServerConfiguration::default()),
+ (SERVER_2_ID.into(), ContextServerConfiguration::default()),
+ ],
+ )
+ .await;
+
+ let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
+ let store = cx.new(|cx| {
+ ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
+ });
+
+ let server_1_id = ContextServerId("mcp-1".into());
+ let server_2_id = ContextServerId("mcp-2".into());
+
+ let transport_1 =
+ Arc::new(FakeTransport::new(
+ cx.executor(),
+ |_, request_type, _| match request_type {
+ Some(RequestType::Initialize) => {
+ Some(create_initialize_response("mcp-1".to_string()))
+ }
+ _ => None,
+ },
+ ));
+
+ let transport_2 =
+ Arc::new(FakeTransport::new(
+ cx.executor(),
+ |_, request_type, _| match request_type {
+ Some(RequestType::Initialize) => {
+ Some(create_initialize_response("mcp-2".to_string()))
+ }
+ _ => None,
+ },
+ ));
+
+ let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone()));
+ let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone()));
+
+ let _server_events = assert_server_events(
+ &store,
+ vec![
+ (server_1_id.clone(), ContextServerStatus::Starting),
+ (server_1_id.clone(), ContextServerStatus::Running),
+ (server_2_id.clone(), ContextServerStatus::Starting),
+ (server_2_id.clone(), ContextServerStatus::Running),
+ (server_2_id.clone(), ContextServerStatus::Stopped),
+ ],
+ cx,
+ );
+
+ store
+ .update(cx, |store, cx| store.start_server(server_1, cx))
+ .unwrap();
+
+ cx.run_until_parked();
+
+ store
+ .update(cx, |store, cx| store.start_server(server_2.clone(), cx))
+ .unwrap();
+
+ cx.run_until_parked();
+
+ store
+ .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
+ .unwrap();
+ }
+
+ #[gpui::test(iterations = 25)]
+ async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
+ const SERVER_1_ID: &'static str = "mcp-1";
+
+ let (_fs, project) = setup_context_server_test(
+ cx,
+ json!({"code.rs": ""}),
+ vec![(SERVER_1_ID.into(), ContextServerConfiguration::default())],
+ )
+ .await;
+
+ let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
+ let store = cx.new(|cx| {
+ ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
+ });
+
+ let server_id = ContextServerId(SERVER_1_ID.into());
+
+ let transport_1 =
+ Arc::new(FakeTransport::new(
+ cx.executor(),
+ |_, request_type, _| match request_type {
+ Some(RequestType::Initialize) => {
+ Some(create_initialize_response(SERVER_1_ID.to_string()))
+ }
+ _ => None,
+ },
+ ));
+
+ let transport_2 =
+ Arc::new(FakeTransport::new(
+ cx.executor(),
+ |_, request_type, _| match request_type {
+ Some(RequestType::Initialize) => {
+ Some(create_initialize_response(SERVER_1_ID.to_string()))
+ }
+ _ => None,
+ },
+ ));
+
+ let server_with_same_id_1 = Arc::new(ContextServer::new(server_id.clone(), transport_1));
+ let server_with_same_id_2 = Arc::new(ContextServer::new(server_id.clone(), transport_2));
+
+ // If we start another server with the same id, we should report that we stopped the previous one
+ let _server_events = assert_server_events(
+ &store,
+ vec![
+ (server_id.clone(), ContextServerStatus::Starting),
+ (server_id.clone(), ContextServerStatus::Stopped),
+ (server_id.clone(), ContextServerStatus::Starting),
+ (server_id.clone(), ContextServerStatus::Running),
+ ],
+ cx,
+ );
+
+ store
+ .update(cx, |store, cx| {
+ store.start_server(server_with_same_id_1.clone(), cx)
+ })
+ .unwrap();
+ store
+ .update(cx, |store, cx| {
+ store.start_server(server_with_same_id_2.clone(), cx)
+ })
+ .unwrap();
+ cx.update(|cx| {
+ assert_eq!(
+ store.read(cx).status_for_server(&server_id),
+ Some(ContextServerStatus::Starting)
+ );
+ });
+
+ cx.run_until_parked();
+
+ cx.update(|cx| {
+ assert_eq!(
+ store.read(cx).status_for_server(&server_id),
+ Some(ContextServerStatus::Running)
+ );
+ });
+ }
+
+ #[gpui::test]
+ async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
+ const SERVER_1_ID: &'static str = "mcp-1";
+ const SERVER_2_ID: &'static str = "mcp-2";
+
+ let server_1_id = ContextServerId(SERVER_1_ID.into());
+ let server_2_id = ContextServerId(SERVER_2_ID.into());
+
+ let (_fs, project) = setup_context_server_test(
+ cx,
+ json!({"code.rs": ""}),
+ vec![(
+ SERVER_1_ID.into(),
+ ContextServerConfiguration {
+ command: None,
+ settings: Some(json!({
+ "somevalue": true
+ })),
+ },
+ )],
+ )
+ .await;
+
+ let executor = cx.executor();
+ let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
+ let store = cx.new(|cx| {
+ ContextServerStore::test_maintain_server_loop(
+ Box::new(move |id, _| {
+ let transport = FakeTransport::new(executor.clone(), {
+ let id = id.0.clone();
+ move |_, request_type, _| match request_type {
+ Some(RequestType::Initialize) => {
+ Some(create_initialize_response(id.clone().to_string()))
+ }
+ _ => None,
+ }
+ });
+ Arc::new(ContextServer::new(id.clone(), Arc::new(transport)))
+ }),
+ registry.clone(),
+ project.read(cx).worktree_store(),
+ cx,
+ )
+ });
+
+ // Ensure that mcp-1 starts up
+ {
+ let _server_events = assert_server_events(
+ &store,
+ vec![
+ (server_1_id.clone(), ContextServerStatus::Starting),
+ (server_1_id.clone(), ContextServerStatus::Running),
+ ],
+ cx,
+ );
+ cx.run_until_parked();
+ }
+
+ // Ensure that mcp-1 is restarted when the configuration was changed
+ {
+ let _server_events = assert_server_events(
+ &store,
+ vec![
+ (server_1_id.clone(), ContextServerStatus::Stopped),
+ (server_1_id.clone(), ContextServerStatus::Starting),
+ (server_1_id.clone(), ContextServerStatus::Running),
+ ],
+ cx,
+ );
+ set_context_server_configuration(
+ vec![(
+ server_1_id.0.clone(),
+ ContextServerConfiguration {
+ command: None,
+ settings: Some(json!({
+ "somevalue": false
+ })),
+ },
+ )],
+ cx,
+ );
+
+ cx.run_until_parked();
+ }
+
+ // Ensure that mcp-1 is not restarted when the configuration was not changed
+ {
+ let _server_events = assert_server_events(&store, vec![], cx);
+ set_context_server_configuration(
+ vec![(
+ server_1_id.0.clone(),
+ ContextServerConfiguration {
+ command: None,
+ settings: Some(json!({
+ "somevalue": false
+ })),
+ },
+ )],
+ cx,
+ );
+
+ cx.run_until_parked();
+ }
+
+ // Ensure that mcp-2 is started once it is added to the settings
+ {
+ let _server_events = assert_server_events(
+ &store,
+ vec![
+ (server_2_id.clone(), ContextServerStatus::Starting),
+ (server_2_id.clone(), ContextServerStatus::Running),
+ ],
+ cx,
+ );
+ set_context_server_configuration(
+ vec![
+ (
+ server_1_id.0.clone(),
+ ContextServerConfiguration {
+ command: None,
+ settings: Some(json!({
+ "somevalue": false
+ })),
+ },
+ ),
+ (
+ server_2_id.0.clone(),
+ ContextServerConfiguration {
+ command: None,
+ settings: Some(json!({
+ "somevalue": true
+ })),
+ },
+ ),
+ ],
+ cx,
+ );
+
+ cx.run_until_parked();
+ }
+
+ // Ensure that mcp-2 is removed once it is removed from the settings
+ {
+ let _server_events = assert_server_events(
+ &store,
+ vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
+ cx,
+ );
+ set_context_server_configuration(
+ vec![(
+ server_1_id.0.clone(),
+ ContextServerConfiguration {
+ command: None,
+ settings: Some(json!({
+ "somevalue": false
+ })),
+ },
+ )],
+ cx,
+ );
+
+ cx.run_until_parked();
+
+ cx.update(|cx| {
+ assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
+ });
+ }
+ }
+
+ fn set_context_server_configuration(
+ context_servers: Vec<(Arc<str>, ContextServerConfiguration)>,
+ cx: &mut TestAppContext,
+ ) {
+ cx.update(|cx| {
+ SettingsStore::update_global(cx, |store, cx| {
+ let mut settings = ProjectSettings::default();
+ for (id, config) in context_servers {
+ settings.context_servers.insert(id, config);
+ }
+ store
+ .set_user_settings(&serde_json::to_string(&settings).unwrap(), cx)
+ .unwrap();
+ })
+ });
+ }
+
+ struct ServerEvents {
+ received_event_count: Rc<RefCell<usize>>,
+ expected_event_count: usize,
+ _subscription: Subscription,
+ }
+
+ impl Drop for ServerEvents {
+ fn drop(&mut self) {
+ let actual_event_count = *self.received_event_count.borrow();
+ assert_eq!(
+ actual_event_count, self.expected_event_count,
+ "
+ Expected to receive {} context server store events, but received {} events",
+ self.expected_event_count, actual_event_count
+ );
+ }
+ }
+
+ fn assert_server_events(
+ store: &Entity<ContextServerStore>,
+ expected_events: Vec<(ContextServerId, ContextServerStatus)>,
+ cx: &mut TestAppContext,
+ ) -> ServerEvents {
+ cx.update(|cx| {
+ let mut ix = 0;
+ let received_event_count = Rc::new(RefCell::new(0));
+ let expected_event_count = expected_events.len();
+ let subscription = cx.subscribe(store, {
+ let received_event_count = received_event_count.clone();
+ move |_, event, _| match event {
+ Event::ServerStatusChanged {
+ server_id: actual_server_id,
+ status: actual_status,
+ } => {
+ let (expected_server_id, expected_status) = &expected_events[ix];
+
+ assert_eq!(
+ actual_server_id, expected_server_id,
+ "Expected different server id at index {}",
+ ix
+ );
+ assert_eq!(
+ actual_status, expected_status,
+ "Expected different status at index {}",
+ ix
+ );
+ ix += 1;
+ *received_event_count.borrow_mut() += 1;
+ }
+ }
+ });
+ ServerEvents {
+ expected_event_count,
+ received_event_count,
+ _subscription: subscription,
+ }
+ })
+ }
+
+ async fn setup_context_server_test(
+ cx: &mut TestAppContext,
+ files: serde_json::Value,
+ context_server_configurations: Vec<(Arc<str>, ContextServerConfiguration)>,
+ ) -> (Arc<FakeFs>, Entity<Project>) {
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ Project::init_settings(cx);
+ let mut settings = ProjectSettings::get_global(cx).clone();
+ for (id, config) in context_server_configurations {
+ settings.context_servers.insert(id, config);
+ }
+ ProjectSettings::override_global(settings, cx);
+ });
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/test"), files).await;
+ let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
+
+ (fs, project)
+ }
+
+ fn create_initialize_response(server_name: String) -> serde_json::Value {
+ serde_json::to_value(&InitializeResponse {
+ protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
+ server_info: Implementation {
+ name: server_name,
+ version: "1.0.0".to_string(),
+ },
+ capabilities: ServerCapabilities::default(),
+ meta: None,
+ })
+ .unwrap()
+ }
+
+ struct FakeTransport {
+ on_request: Arc<
+ dyn Fn(u64, Option<RequestType>, serde_json::Value) -> Option<serde_json::Value>
+ + Send
+ + Sync,
+ >,
+ tx: futures::channel::mpsc::UnboundedSender<String>,
+ rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
+ executor: BackgroundExecutor,
+ }
+
+ impl FakeTransport {
+ fn new(
+ executor: BackgroundExecutor,
+ on_request: impl Fn(
+ u64,
+ Option<RequestType>,
+ serde_json::Value,
+ ) -> Option<serde_json::Value>
+ + 'static
+ + Send
+ + Sync,
+ ) -> Self {
+ let (tx, rx) = futures::channel::mpsc::unbounded();
+ Self {
+ on_request: Arc::new(on_request),
+ tx,
+ rx: Arc::new(Mutex::new(rx)),
+ executor,
+ }
+ }
+ }
+
+ #[async_trait::async_trait]
+ impl Transport for FakeTransport {
+ async fn send(&self, message: String) -> Result<()> {
+ if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
+ let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
+
+ if let Some(method) = msg.get("method") {
+ let request_type = method
+ .as_str()
+ .and_then(|method| types::RequestType::try_from(method).ok());
+ if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) {
+ let response = serde_json::json!({
+ "jsonrpc": "2.0",
+ "id": id,
+ "result": payload
+ });
+
+ self.tx
+ .unbounded_send(response.to_string())
+ .map_err(|e| anyhow::anyhow!("Failed to send message: {}", e))?;
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
+ let rx = self.rx.clone();
+ let executor = self.executor.clone();
+ Box::pin(futures::stream::unfold(rx, move |rx| {
+ let executor = executor.clone();
+ async move {
+ let mut rx_guard = rx.lock().await;
+ executor.simulate_random_delay().await;
+ if let Some(message) = rx_guard.next().await {
+ drop(rx_guard);
+ Some((message, rx))
+ } else {
+ None
+ }
+ }
+ }))
+ }
+
+ fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
+ Box::pin(futures::stream::empty())
+ }
+ }
+}
@@ -1,19 +1,21 @@
use std::sync::Arc;
use anyhow::Result;
+use context_server::ContextServerCommand;
use extension::{
ContextServerConfiguration, Extension, ExtensionContextServerProxy, ExtensionHostProxy,
ProjectDelegate,
};
use gpui::{App, AsyncApp, Entity, Task};
-use project::Project;
-use crate::{ContextServerDescriptorRegistry, ServerCommand, registry};
+use crate::worktree_store::WorktreeStore;
+
+use super::registry::{self, ContextServerDescriptorRegistry};
pub fn init(cx: &mut App) {
let proxy = ExtensionHostProxy::default_global(cx);
proxy.register_context_server_proxy(ContextServerDescriptorRegistryProxy {
- context_server_factory_registry: ContextServerDescriptorRegistry::global(cx),
+ context_server_factory_registry: ContextServerDescriptorRegistry::default_global(cx),
});
}
@@ -32,10 +34,13 @@ struct ContextServerDescriptor {
extension: Arc<dyn Extension>,
}
-fn extension_project(project: Entity<Project>, cx: &mut AsyncApp) -> Result<Arc<ExtensionProject>> {
- project.update(cx, |project, cx| {
+fn extension_project(
+ worktree_store: Entity<WorktreeStore>,
+ cx: &mut AsyncApp,
+) -> Result<Arc<ExtensionProject>> {
+ worktree_store.update(cx, |worktree_store, cx| {
Arc::new(ExtensionProject {
- worktree_ids: project
+ worktree_ids: worktree_store
.visible_worktrees(cx)
.map(|worktree| worktree.read(cx).id().to_proto())
.collect(),
@@ -44,11 +49,15 @@ fn extension_project(project: Entity<Project>, cx: &mut AsyncApp) -> Result<Arc<
}
impl registry::ContextServerDescriptor for ContextServerDescriptor {
- fn command(&self, project: Entity<Project>, cx: &AsyncApp) -> Task<Result<ServerCommand>> {
+ fn command(
+ &self,
+ worktree_store: Entity<WorktreeStore>,
+ cx: &AsyncApp,
+ ) -> Task<Result<ContextServerCommand>> {
let id = self.id.clone();
let extension = self.extension.clone();
cx.spawn(async move |cx| {
- let extension_project = extension_project(project, cx)?;
+ let extension_project = extension_project(worktree_store, cx)?;
let mut command = extension
.context_server_command(id.clone(), extension_project.clone())
.await?;
@@ -59,7 +68,7 @@ impl registry::ContextServerDescriptor for ContextServerDescriptor {
log::info!("loaded command for context server {id}: {command:?}");
- Ok(ServerCommand {
+ Ok(ContextServerCommand {
path: command.command,
args: command.args,
env: Some(command.env.into_iter().collect()),
@@ -69,13 +78,13 @@ impl registry::ContextServerDescriptor for ContextServerDescriptor {
fn configuration(
&self,
- project: Entity<Project>,
+ worktree_store: Entity<WorktreeStore>,
cx: &AsyncApp,
) -> Task<Result<Option<ContextServerConfiguration>>> {
let id = self.id.clone();
let extension = self.extension.clone();
cx.spawn(async move |cx| {
- let extension_project = extension_project(project, cx)?;
+ let extension_project = extension_project(worktree_store, cx)?;
let configuration = extension
.context_server_configuration(id.clone(), extension_project)
.await?;
@@ -102,4 +111,11 @@ impl ExtensionContextServerProxy for ContextServerDescriptorRegistryProxy {
)
});
}
+
+ fn unregister_context_server(&self, server_id: Arc<str>, cx: &mut App) {
+ self.context_server_factory_registry
+ .update(cx, |registry, _| {
+ registry.unregister_context_server_descriptor_by_id(&server_id)
+ });
+ }
}
@@ -2,17 +2,21 @@ use std::sync::Arc;
use anyhow::Result;
use collections::HashMap;
+use context_server::ContextServerCommand;
use extension::ContextServerConfiguration;
-use gpui::{App, AppContext as _, AsyncApp, Entity, Global, ReadGlobal, Task};
-use project::Project;
+use gpui::{App, AppContext as _, AsyncApp, Entity, Global, Task};
-use crate::ServerCommand;
+use crate::worktree_store::WorktreeStore;
pub trait ContextServerDescriptor {
- fn command(&self, project: Entity<Project>, cx: &AsyncApp) -> Task<Result<ServerCommand>>;
+ fn command(
+ &self,
+ worktree_store: Entity<WorktreeStore>,
+ cx: &AsyncApp,
+ ) -> Task<Result<ContextServerCommand>>;
fn configuration(
&self,
- project: Entity<Project>,
+ worktree_store: Entity<WorktreeStore>,
cx: &AsyncApp,
) -> Task<Result<Option<ContextServerConfiguration>>>;
}
@@ -27,11 +31,6 @@ pub struct ContextServerDescriptorRegistry {
}
impl ContextServerDescriptorRegistry {
- /// Returns the global [`ContextServerDescriptorRegistry`].
- pub fn global(cx: &App) -> Entity<Self> {
- GlobalContextServerDescriptorRegistry::global(cx).0.clone()
- }
-
/// Returns the global [`ContextServerDescriptorRegistry`].
///
/// Inserts a default [`ContextServerDescriptorRegistry`] if one does not yet exist.
@@ -244,9 +244,8 @@ mod tests {
use git::status::{FileStatus, StatusCode, TrackedSummary, UnmergedStatus, UnmergedStatusCode};
use gpui::TestAppContext;
use serde_json::json;
- use settings::{Settings as _, SettingsStore};
+ use settings::SettingsStore;
use util::path;
- use worktree::WorktreeSettings;
const CONFLICT: FileStatus = FileStatus::Unmerged(UnmergedStatus {
first_head: UnmergedStatusCode::Updated,
@@ -682,7 +681,7 @@ mod tests {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
- WorktreeSettings::register(cx);
+ Project::init_settings(cx);
});
}
@@ -1,6 +1,7 @@
pub mod buffer_store;
mod color_extractor;
pub mod connection_manager;
+pub mod context_server_store;
pub mod debounced_delay;
pub mod debugger;
pub mod git_store;
@@ -23,6 +24,7 @@ mod project_tests;
mod direnv;
mod environment;
use buffer_diff::BufferDiff;
+use context_server_store::ContextServerStore;
pub use environment::{EnvironmentErrorMessage, ProjectEnvironmentEvent};
use git_store::{Repository, RepositoryId};
pub mod search_history;
@@ -182,6 +184,7 @@ pub struct Project {
client_subscriptions: Vec<client::Subscription>,
worktree_store: Entity<WorktreeStore>,
buffer_store: Entity<BufferStore>,
+ context_server_store: Entity<ContextServerStore>,
image_store: Entity<ImageStore>,
lsp_store: Entity<LspStore>,
_subscriptions: Vec<gpui::Subscription>,
@@ -845,6 +848,7 @@ impl Project {
ToolchainStore::init(&client);
DapStore::init(&client, cx);
BreakpointStore::init(&client);
+ context_server_store::init(cx);
}
pub fn local(
@@ -865,6 +869,9 @@ impl Project {
cx.subscribe(&worktree_store, Self::on_worktree_store_event)
.detach();
+ let context_server_store =
+ cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx));
+
let environment = cx.new(|_| ProjectEnvironment::new(env));
let toolchain_store = cx.new(|cx| {
ToolchainStore::local(
@@ -965,6 +972,7 @@ impl Project {
buffer_store,
image_store,
lsp_store,
+ context_server_store,
join_project_response_message_id: 0,
client_state: ProjectClientState::Local,
git_store,
@@ -1025,6 +1033,9 @@ impl Project {
cx.subscribe(&worktree_store, Self::on_worktree_store_event)
.detach();
+ let context_server_store =
+ cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx));
+
let buffer_store = cx.new(|cx| {
BufferStore::remote(
worktree_store.clone(),
@@ -1109,6 +1120,7 @@ impl Project {
buffer_store,
image_store,
lsp_store,
+ context_server_store,
breakpoint_store,
dap_store,
join_project_response_message_id: 0,
@@ -1267,6 +1279,8 @@ impl Project {
let image_store = cx.new(|cx| {
ImageStore::remote(worktree_store.clone(), client.clone().into(), remote_id, cx)
})?;
+ let context_server_store =
+ cx.new(|cx| ContextServerStore::new(worktree_store.clone(), cx))?;
let environment = cx.new(|_| ProjectEnvironment::new(None))?;
@@ -1360,6 +1374,7 @@ impl Project {
image_store,
worktree_store: worktree_store.clone(),
lsp_store: lsp_store.clone(),
+ context_server_store,
active_entry: None,
collaborators: Default::default(),
join_project_response_message_id: response.message_id,
@@ -1590,6 +1605,10 @@ impl Project {
self.worktree_store.clone()
}
+ pub fn context_server_store(&self) -> Entity<ContextServerStore> {
+ self.context_server_store.clone()
+ }
+
pub fn buffer_for_id(&self, remote_id: BufferId, cx: &App) -> Option<Entity<Buffer>> {
self.buffer_store.read(cx).get(remote_id)
}
@@ -1,5 +1,6 @@
use anyhow::Context as _;
use collections::HashMap;
+use context_server::ContextServerCommand;
use dap::adapters::DebugAdapterName;
use fs::Fs;
use futures::StreamExt as _;
@@ -51,6 +52,10 @@ pub struct ProjectSettings {
#[serde(default)]
pub dap: HashMap<DebugAdapterName, DapSettings>,
+ /// Settings for context servers used for AI-related features.
+ #[serde(default)]
+ pub context_servers: HashMap<Arc<str>, ContextServerConfiguration>,
+
/// Configuration for Diagnostics-related features.
#[serde(default)]
pub diagnostics: DiagnosticsSettings,
@@ -78,6 +83,19 @@ pub struct DapSettings {
pub binary: Option<String>,
}
+#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug, Default)]
+pub struct ContextServerConfiguration {
+ /// The command to run this context server.
+ ///
+ /// This will override the command set by an extension.
+ pub command: Option<ContextServerCommand>,
+ /// The settings for this context server.
+ ///
+ /// Consult the documentation for the context server to see what settings
+ /// are supported.
+ pub settings: Option<serde_json::Value>,
+}
+
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct NodeBinarySettings {
/// The path to the Node binary.
@@ -376,6 +394,40 @@ impl Settings for ProjectSettings {
}
}
+ #[derive(Deserialize)]
+ struct VsCodeContextServerCommand {
+ command: String,
+ args: Option<Vec<String>>,
+ env: Option<HashMap<String, String>>,
+ // note: we don't support envFile and type
+ }
+ impl From<VsCodeContextServerCommand> for ContextServerCommand {
+ fn from(cmd: VsCodeContextServerCommand) -> Self {
+ Self {
+ path: cmd.command,
+ args: cmd.args.unwrap_or_default(),
+ env: cmd.env,
+ }
+ }
+ }
+ if let Some(mcp) = vscode.read_value("mcp").and_then(|v| v.as_object()) {
+ current
+ .context_servers
+ .extend(mcp.iter().filter_map(|(k, v)| {
+ Some((
+ k.clone().into(),
+ ContextServerConfiguration {
+ command: Some(
+ serde_json::from_value::<VsCodeContextServerCommand>(v.clone())
+ .ok()?
+ .into(),
+ ),
+ settings: None,
+ },
+ ))
+ }));
+ }
+
// TODO: translate lsp settings for rust-analyzer and other popular ones to old.lsp
}
}