agent.rs

   1mod db;
   2mod edit_agent;
   3mod legacy_thread;
   4mod native_agent_server;
   5pub mod outline;
   6mod templates;
   7#[cfg(test)]
   8mod tests;
   9mod thread;
  10mod thread_store;
  11mod tool_permissions;
  12mod tools;
  13
  14use context_server::ContextServerId;
  15pub use db::*;
  16pub use native_agent_server::NativeAgentServer;
  17pub use templates::*;
  18pub use thread::*;
  19pub use thread_store::*;
  20pub use tool_permissions::*;
  21pub use tools::*;
  22
  23use acp_thread::{
  24    AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
  25    AgentSessionListResponse, UserMessageId,
  26};
  27use agent_client_protocol as acp;
  28use anyhow::{Context as _, Result, anyhow};
  29use chrono::{DateTime, Utc};
  30use collections::{HashMap, HashSet, IndexMap};
  31use fs::Fs;
  32use futures::channel::{mpsc, oneshot};
  33use futures::future::Shared;
  34use futures::{StreamExt, future};
  35use gpui::{
  36    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  37};
  38use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  39use project::{Project, ProjectItem, ProjectPath, Worktree};
  40use prompt_store::{
  41    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  42    WorktreeContext,
  43};
  44use serde::{Deserialize, Serialize};
  45use settings::{LanguageModelSelection, update_settings_file};
  46use std::any::Any;
  47use std::path::{Path, PathBuf};
  48use std::rc::Rc;
  49use std::sync::Arc;
  50use util::ResultExt;
  51use util::rel_path::RelPath;
  52
/// A serializable snapshot of a project's worktrees together with a timestamp.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One telemetry snapshot per worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// UTC timestamp stored alongside the worktree snapshots.
    pub timestamp: DateTime<Utc>,
}
  58
/// Describes a failure to load a worktree's rules file.
pub struct RulesLoadingError {
    /// Human-readable description of what went wrong.
    pub message: SharedString,
}
  62
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// The most recently scheduled save of this session's thread; replaced
    /// every time `NativeAgent::save_thread` runs for this session.
    pending_save: Task<()>,
    /// Subscriptions created when the session was registered, stored here so
    /// they live exactly as long as the session entry.
    _subscriptions: Vec<Subscription>,
}
  72
/// Cache of the language models exposed by authenticated providers, plus a
/// watch channel that is signalled whenever the cache is rebuilt.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiving half of the refresh channel; cloned and handed out by `watch`.
    refresh_models_rx: watch::Receiver<()>,
    /// Sending half of the refresh channel; signalled at the end of `refresh_list`.
    refresh_models_tx: watch::Sender<()>,
    /// Background task that authenticates all visible providers, stored on the struct.
    _authenticate_all_providers_task: Task<()>,
}
  82
  83impl LanguageModels {
  84    fn new(cx: &mut App) -> Self {
  85        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  86
  87        let mut this = Self {
  88            models: HashMap::default(),
  89            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  90            refresh_models_rx,
  91            refresh_models_tx,
  92            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
  93        };
  94        this.refresh_list(cx);
  95        this
  96    }
  97
  98    fn refresh_list(&mut self, cx: &App) {
  99        let providers = LanguageModelRegistry::global(cx)
 100            .read(cx)
 101            .visible_providers()
 102            .into_iter()
 103            .filter(|provider| provider.is_authenticated(cx))
 104            .collect::<Vec<_>>();
 105
 106        let mut language_model_list = IndexMap::default();
 107        let mut recommended_models = HashSet::default();
 108
 109        let mut recommended = Vec::new();
 110        for provider in &providers {
 111            for model in provider.recommended_models(cx) {
 112                recommended_models.insert((model.provider_id(), model.id()));
 113                recommended.push(Self::map_language_model_to_info(&model, provider));
 114            }
 115        }
 116        if !recommended.is_empty() {
 117            language_model_list.insert(
 118                acp_thread::AgentModelGroupName("Recommended".into()),
 119                recommended,
 120            );
 121        }
 122
 123        let mut models = HashMap::default();
 124        for provider in providers {
 125            let mut provider_models = Vec::new();
 126            for model in provider.provided_models(cx) {
 127                let model_info = Self::map_language_model_to_info(&model, &provider);
 128                let model_id = model_info.id.clone();
 129                provider_models.push(model_info);
 130                models.insert(model_id, model);
 131            }
 132            if !provider_models.is_empty() {
 133                language_model_list.insert(
 134                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 135                    provider_models,
 136                );
 137            }
 138        }
 139
 140        self.models = models;
 141        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 142        self.refresh_models_tx.send(()).ok();
 143    }
 144
 145    fn watch(&self) -> watch::Receiver<()> {
 146        self.refresh_models_rx.clone()
 147    }
 148
 149    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 150        self.models.get(model_id).cloned()
 151    }
 152
 153    fn map_language_model_to_info(
 154        model: &Arc<dyn LanguageModel>,
 155        provider: &Arc<dyn LanguageModelProvider>,
 156    ) -> acp_thread::AgentModelInfo {
 157        acp_thread::AgentModelInfo {
 158            id: Self::model_id(model),
 159            name: model.name().0,
 160            description: None,
 161            icon: Some(match provider.icon() {
 162                IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
 163                IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
 164            }),
 165        }
 166    }
 167
 168    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 169        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 170    }
 171
 172    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 173        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 174            .read(cx)
 175            .visible_providers()
 176            .iter()
 177            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 178            .collect::<Vec<_>>();
 179
 180        cx.background_spawn(async move {
 181            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 182                if let Err(err) = authenticate_task.await {
 183                    match err {
 184                        language_model::AuthenticateError::CredentialsNotFound => {
 185                            // Since we're authenticating these providers in the
 186                            // background for the purposes of populating the
 187                            // language selector, we don't care about providers
 188                            // where the credentials are not found.
 189                        }
 190                        language_model::AuthenticateError::ConnectionRefused => {
 191                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 192                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 193                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 194                        }
 195                        _ => {
 196                            // Some providers have noisy failure states that we
 197                            // don't want to spam the logs with every time the
 198                            // language model selector is initialized.
 199                            //
 200                            // Ideally these should have more clear failure modes
 201                            // that we know are safe to ignore here, like what we do
 202                            // with `CredentialsNotFound` above.
 203                            match provider_id.0.as_ref() {
 204                                "lmstudio" | "ollama" => {
 205                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 206                                    //
 207                                    // These fail noisily, so we don't log them.
 208                                }
 209                                "copilot_chat" => {
 210                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 211                                }
 212                                _ => {
 213                                    log::error!(
 214                                        "Failed to authenticate provider: {}: {err:#}",
 215                                        provider_name.0
 216                                    );
 217                                }
 218                            }
 219                        }
 220                    }
 221                }
 222            }
 223        })
 224    }
 225}
 226
/// Owns all sessions and the state shared between them: project context,
/// templates, cached models and the thread store.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Store that is reloaded after a thread has been saved to the database.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signalled whenever the project context should be rebuilt (worktree
    /// changes, rules file changes, prompt updates).
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context`, stored on the struct.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context servers; its prompts are the source of the
    /// available commands.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent operates on.
    project: Entity<Project>,
    /// Optional prompt store providing the default user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// File system handle supplied at construction time.
    fs: Arc<dyn Fs>,
    /// Event subscriptions created in `new`, stored on the struct.
    _subscriptions: Vec<Subscription>,
}
 245
 246impl NativeAgent {
 247    pub async fn new(
 248        project: Entity<Project>,
 249        thread_store: Entity<ThreadStore>,
 250        templates: Arc<Templates>,
 251        prompt_store: Option<Entity<PromptStore>>,
 252        fs: Arc<dyn Fs>,
 253        cx: &mut AsyncApp,
 254    ) -> Result<Entity<NativeAgent>> {
 255        log::debug!("Creating new NativeAgent");
 256
 257        let project_context = cx
 258            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
 259            .await;
 260
 261        Ok(cx.new(|cx| {
 262            let context_server_store = project.read(cx).context_server_store();
 263            let context_server_registry =
 264                cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
 265
 266            let mut subscriptions = vec![
 267                cx.subscribe(&project, Self::handle_project_event),
 268                cx.subscribe(
 269                    &LanguageModelRegistry::global(cx),
 270                    Self::handle_models_updated_event,
 271                ),
 272                cx.subscribe(
 273                    &context_server_store,
 274                    Self::handle_context_server_store_updated,
 275                ),
 276                cx.subscribe(
 277                    &context_server_registry,
 278                    Self::handle_context_server_registry_event,
 279                ),
 280            ];
 281            if let Some(prompt_store) = prompt_store.as_ref() {
 282                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 283            }
 284
 285            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 286                watch::channel(());
 287            Self {
 288                sessions: HashMap::default(),
 289                thread_store,
 290                project_context: cx.new(|_| project_context),
 291                project_context_needs_refresh: project_context_needs_refresh_tx,
 292                _maintain_project_context: cx.spawn(async move |this, cx| {
 293                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 294                }),
 295                context_server_registry,
 296                templates,
 297                models: LanguageModels::new(cx),
 298                project,
 299                prompt_store,
 300                fs,
 301                _subscriptions: subscriptions,
 302            }
 303        }))
 304    }
 305
 306    fn register_session(
 307        &mut self,
 308        thread_handle: Entity<Thread>,
 309        cx: &mut Context<Self>,
 310    ) -> Entity<AcpThread> {
 311        let connection = Rc::new(NativeAgentConnection(cx.entity()));
 312
 313        let thread = thread_handle.read(cx);
 314        let session_id = thread.id().clone();
 315        let title = thread.title();
 316        let project = thread.project.clone();
 317        let action_log = thread.action_log.clone();
 318        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
 319        let acp_thread = cx.new(|cx| {
 320            acp_thread::AcpThread::new(
 321                title,
 322                connection,
 323                project.clone(),
 324                action_log.clone(),
 325                session_id.clone(),
 326                prompt_capabilities_rx,
 327                cx,
 328            )
 329        });
 330
 331        let registry = LanguageModelRegistry::read_global(cx);
 332        let summarization_model = registry.thread_summary_model().map(|c| c.model);
 333
 334        thread_handle.update(cx, |thread, cx| {
 335            thread.set_summarization_model(summarization_model, cx);
 336            thread.add_default_tools(
 337                Rc::new(AcpThreadEnvironment {
 338                    acp_thread: acp_thread.downgrade(),
 339                }) as _,
 340                cx,
 341            )
 342        });
 343
 344        let subscriptions = vec![
 345            cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
 346                this.sessions.remove(acp_thread.session_id());
 347            }),
 348            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
 349            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
 350            cx.observe(&thread_handle, move |this, thread, cx| {
 351                this.save_thread(thread, cx)
 352            }),
 353        ];
 354
 355        self.sessions.insert(
 356            session_id,
 357            Session {
 358                thread: thread_handle,
 359                acp_thread: acp_thread.downgrade(),
 360                _subscriptions: subscriptions,
 361                pending_save: Task::ready(()),
 362            },
 363        );
 364
 365        self.update_available_commands(cx);
 366
 367        acp_thread
 368    }
 369
 370    pub fn models(&self) -> &LanguageModels {
 371        &self.models
 372    }
 373
 374    async fn maintain_project_context(
 375        this: WeakEntity<Self>,
 376        mut needs_refresh: watch::Receiver<()>,
 377        cx: &mut AsyncApp,
 378    ) -> Result<()> {
 379        while needs_refresh.changed().await.is_ok() {
 380            let project_context = this
 381                .update(cx, |this, cx| {
 382                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 383                })?
 384                .await;
 385            this.update(cx, |this, cx| {
 386                this.project_context = cx.new(|_| project_context);
 387            })?;
 388        }
 389
 390        Ok(())
 391    }
 392
 393    fn build_project_context(
 394        project: &Entity<Project>,
 395        prompt_store: Option<&Entity<PromptStore>>,
 396        cx: &mut App,
 397    ) -> Task<ProjectContext> {
 398        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 399        let worktree_tasks = worktrees
 400            .into_iter()
 401            .map(|worktree| {
 402                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 403            })
 404            .collect::<Vec<_>>();
 405        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 406            prompt_store.read_with(cx, |prompt_store, cx| {
 407                let prompts = prompt_store.default_prompt_metadata();
 408                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 409                    let contents = prompt_store.load(prompt_metadata.id, cx);
 410                    async move { (contents.await, prompt_metadata) }
 411                });
 412                cx.background_spawn(future::join_all(load_tasks))
 413            })
 414        } else {
 415            Task::ready(vec![])
 416        };
 417
 418        cx.spawn(async move |_cx| {
 419            let (worktrees, default_user_rules) =
 420                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 421
 422            let worktrees = worktrees
 423                .into_iter()
 424                .map(|(worktree, _rules_error)| {
 425                    // TODO: show error message
 426                    // if let Some(rules_error) = rules_error {
 427                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 428                    // }
 429                    worktree
 430                })
 431                .collect::<Vec<_>>();
 432
 433            let default_user_rules = default_user_rules
 434                .into_iter()
 435                .flat_map(|(contents, prompt_metadata)| match contents {
 436                    Ok(contents) => Some(UserRulesContext {
 437                        uuid: prompt_metadata.id.as_user()?,
 438                        title: prompt_metadata.title.map(|title| title.to_string()),
 439                        contents,
 440                    }),
 441                    Err(_err) => {
 442                        // TODO: show error message
 443                        // this.update(cx, |_, cx| {
 444                        //     cx.emit(RulesLoadingError {
 445                        //         message: format!("{err:?}").into(),
 446                        //     });
 447                        // })
 448                        // .ok();
 449                        None
 450                    }
 451                })
 452                .collect::<Vec<_>>();
 453
 454            ProjectContext::new(worktrees, default_user_rules)
 455        })
 456    }
 457
 458    fn load_worktree_info_for_system_prompt(
 459        worktree: Entity<Worktree>,
 460        project: Entity<Project>,
 461        cx: &mut App,
 462    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 463        let tree = worktree.read(cx);
 464        let root_name = tree.root_name_str().into();
 465        let abs_path = tree.abs_path();
 466
 467        let mut context = WorktreeContext {
 468            root_name,
 469            abs_path,
 470            rules_file: None,
 471        };
 472
 473        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 474        let Some(rules_task) = rules_task else {
 475            return Task::ready((context, None));
 476        };
 477
 478        cx.spawn(async move |_| {
 479            let (rules_file, rules_file_error) = match rules_task.await {
 480                Ok(rules_file) => (Some(rules_file), None),
 481                Err(err) => (
 482                    None,
 483                    Some(RulesLoadingError {
 484                        message: format!("{err}").into(),
 485                    }),
 486                ),
 487            };
 488            context.rules_file = rules_file;
 489            (context, rules_file_error)
 490        })
 491    }
 492
 493    fn load_worktree_rules_file(
 494        worktree: Entity<Worktree>,
 495        project: Entity<Project>,
 496        cx: &mut App,
 497    ) -> Option<Task<Result<RulesFileContext>>> {
 498        let worktree = worktree.read(cx);
 499        let worktree_id = worktree.id();
 500        let selected_rules_file = RULES_FILE_NAMES
 501            .into_iter()
 502            .filter_map(|name| {
 503                worktree
 504                    .entry_for_path(RelPath::unix(name).unwrap())
 505                    .filter(|entry| entry.is_file())
 506                    .map(|entry| entry.path.clone())
 507            })
 508            .next();
 509
 510        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 511        // supported. This doesn't seem to occur often in GitHub repositories.
 512        selected_rules_file.map(|path_in_worktree| {
 513            let project_path = ProjectPath {
 514                worktree_id,
 515                path: path_in_worktree.clone(),
 516            };
 517            let buffer_task =
 518                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 519            let rope_task = cx.spawn(async move |cx| {
 520                let buffer = buffer_task.await?;
 521                let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
 522                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 523                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 524                })?;
 525                anyhow::Ok((project_entry_id, rope))
 526            });
 527            // Build a string from the rope on a background thread.
 528            cx.background_spawn(async move {
 529                let (project_entry_id, rope) = rope_task.await?;
 530                anyhow::Ok(RulesFileContext {
 531                    path_in_worktree,
 532                    text: rope.to_string().trim().to_string(),
 533                    project_entry_id: project_entry_id.to_usize(),
 534                })
 535            })
 536        })
 537    }
 538
 539    fn handle_thread_title_updated(
 540        &mut self,
 541        thread: Entity<Thread>,
 542        _: &TitleUpdated,
 543        cx: &mut Context<Self>,
 544    ) {
 545        let session_id = thread.read(cx).id();
 546        let Some(session) = self.sessions.get(session_id) else {
 547            return;
 548        };
 549        let thread = thread.downgrade();
 550        let acp_thread = session.acp_thread.clone();
 551        cx.spawn(async move |_, cx| {
 552            let title = thread.read_with(cx, |thread, _| thread.title())?;
 553            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 554            task.await
 555        })
 556        .detach_and_log_err(cx);
 557    }
 558
 559    fn handle_thread_token_usage_updated(
 560        &mut self,
 561        thread: Entity<Thread>,
 562        usage: &TokenUsageUpdated,
 563        cx: &mut Context<Self>,
 564    ) {
 565        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 566            return;
 567        };
 568        session
 569            .acp_thread
 570            .update(cx, |acp_thread, cx| {
 571                acp_thread.update_token_usage(usage.0.clone(), cx);
 572            })
 573            .ok();
 574    }
 575
 576    fn handle_project_event(
 577        &mut self,
 578        _project: Entity<Project>,
 579        event: &project::Event,
 580        _cx: &mut Context<Self>,
 581    ) {
 582        match event {
 583            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 584                self.project_context_needs_refresh.send(()).ok();
 585            }
 586            project::Event::WorktreeUpdatedEntries(_, items) => {
 587                if items.iter().any(|(path, _, _)| {
 588                    RULES_FILE_NAMES
 589                        .iter()
 590                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 591                }) {
 592                    self.project_context_needs_refresh.send(()).ok();
 593                }
 594            }
 595            _ => {}
 596        }
 597    }
 598
 599    fn handle_prompts_updated_event(
 600        &mut self,
 601        _prompt_store: Entity<PromptStore>,
 602        _event: &prompt_store::PromptsUpdatedEvent,
 603        _cx: &mut Context<Self>,
 604    ) {
 605        self.project_context_needs_refresh.send(()).ok();
 606    }
 607
 608    fn handle_models_updated_event(
 609        &mut self,
 610        _registry: Entity<LanguageModelRegistry>,
 611        _event: &language_model::Event,
 612        cx: &mut Context<Self>,
 613    ) {
 614        self.models.refresh_list(cx);
 615
 616        let registry = LanguageModelRegistry::read_global(cx);
 617        let default_model = registry.default_model().map(|m| m.model);
 618        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 619
 620        for session in self.sessions.values_mut() {
 621            session.thread.update(cx, |thread, cx| {
 622                if thread.model().is_none()
 623                    && let Some(model) = default_model.clone()
 624                {
 625                    thread.set_model(model, cx);
 626                    cx.notify();
 627                }
 628                thread.set_summarization_model(summarization_model.clone(), cx);
 629            });
 630        }
 631    }
 632
 633    fn handle_context_server_store_updated(
 634        &mut self,
 635        _store: Entity<project::context_server_store::ContextServerStore>,
 636        _event: &project::context_server_store::Event,
 637        cx: &mut Context<Self>,
 638    ) {
 639        self.update_available_commands(cx);
 640    }
 641
 642    fn handle_context_server_registry_event(
 643        &mut self,
 644        _registry: Entity<ContextServerRegistry>,
 645        event: &ContextServerRegistryEvent,
 646        cx: &mut Context<Self>,
 647    ) {
 648        match event {
 649            ContextServerRegistryEvent::ToolsChanged => {}
 650            ContextServerRegistryEvent::PromptsChanged => {
 651                self.update_available_commands(cx);
 652            }
 653        }
 654    }
 655
 656    fn update_available_commands(&self, cx: &mut Context<Self>) {
 657        let available_commands = self.build_available_commands(cx);
 658        for session in self.sessions.values() {
 659            if let Some(acp_thread) = session.acp_thread.upgrade() {
 660                acp_thread.update(cx, |thread, cx| {
 661                    thread
 662                        .handle_session_update(
 663                            acp::SessionUpdate::AvailableCommandsUpdate(
 664                                acp::AvailableCommandsUpdate::new(available_commands.clone()),
 665                            ),
 666                            cx,
 667                        )
 668                        .log_err();
 669                });
 670            }
 671        }
 672    }
 673
 674    fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
 675        let registry = self.context_server_registry.read(cx);
 676
 677        let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
 678        for context_server_prompt in registry.prompts() {
 679            *prompt_name_counts
 680                .entry(context_server_prompt.prompt.name.as_str())
 681                .or_insert(0) += 1;
 682        }
 683
 684        registry
 685            .prompts()
 686            .flat_map(|context_server_prompt| {
 687                let prompt = &context_server_prompt.prompt;
 688
 689                let should_prefix = prompt_name_counts
 690                    .get(prompt.name.as_str())
 691                    .copied()
 692                    .unwrap_or(0)
 693                    > 1;
 694
 695                let name = if should_prefix {
 696                    format!("{}.{}", context_server_prompt.server_id, prompt.name)
 697                } else {
 698                    prompt.name.clone()
 699                };
 700
 701                let mut command = acp::AvailableCommand::new(
 702                    name,
 703                    prompt.description.clone().unwrap_or_default(),
 704                );
 705
 706                match prompt.arguments.as_deref() {
 707                    Some([arg]) => {
 708                        let hint = format!("<{}>", arg.name);
 709
 710                        command = command.input(acp::AvailableCommandInput::Unstructured(
 711                            acp::UnstructuredCommandInput::new(hint),
 712                        ));
 713                    }
 714                    Some([]) | None => {}
 715                    Some(_) => {
 716                        // skip >1 argument commands since we don't support them yet
 717                        return None;
 718                    }
 719                }
 720
 721                Some(command)
 722            })
 723            .collect()
 724    }
 725
 726    pub fn load_thread(
 727        &mut self,
 728        id: acp::SessionId,
 729        cx: &mut Context<Self>,
 730    ) -> Task<Result<Entity<Thread>>> {
 731        let database_future = ThreadsDatabase::connect(cx);
 732        cx.spawn(async move |this, cx| {
 733            let database = database_future.await.map_err(|err| anyhow!(err))?;
 734            let db_thread = database
 735                .load_thread(id.clone())
 736                .await?
 737                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 738
 739            this.update(cx, |this, cx| {
 740                let summarization_model = LanguageModelRegistry::read_global(cx)
 741                    .thread_summary_model()
 742                    .map(|c| c.model);
 743
 744                cx.new(|cx| {
 745                    let mut thread = Thread::from_db(
 746                        id.clone(),
 747                        db_thread,
 748                        this.project.clone(),
 749                        this.project_context.clone(),
 750                        this.context_server_registry.clone(),
 751                        this.templates.clone(),
 752                        cx,
 753                    );
 754                    thread.set_summarization_model(summarization_model, cx);
 755                    thread
 756                })
 757            })
 758        })
 759    }
 760
 761    pub fn open_thread(
 762        &mut self,
 763        id: acp::SessionId,
 764        cx: &mut Context<Self>,
 765    ) -> Task<Result<Entity<AcpThread>>> {
 766        let task = self.load_thread(id, cx);
 767        cx.spawn(async move |this, cx| {
 768            let thread = task.await?;
 769            let acp_thread =
 770                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
 771            let events = thread.update(cx, |thread, cx| thread.replay(cx));
 772            cx.update(|cx| {
 773                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 774            })
 775            .await?;
 776            Ok(acp_thread)
 777        })
 778    }
 779
 780    pub fn thread_summary(
 781        &mut self,
 782        id: acp::SessionId,
 783        cx: &mut Context<Self>,
 784    ) -> Task<Result<SharedString>> {
 785        let thread = self.open_thread(id.clone(), cx);
 786        cx.spawn(async move |this, cx| {
 787            let acp_thread = thread.await?;
 788            let result = this
 789                .update(cx, |this, cx| {
 790                    this.sessions
 791                        .get(&id)
 792                        .unwrap()
 793                        .thread
 794                        .update(cx, |thread, cx| thread.summary(cx))
 795                })?
 796                .await
 797                .context("Failed to generate summary")?;
 798            drop(acp_thread);
 799            Ok(result)
 800        })
 801    }
 802
 803    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 804        if thread.read(cx).is_empty() {
 805            return;
 806        }
 807
 808        let database_future = ThreadsDatabase::connect(cx);
 809        let (id, db_thread) =
 810            thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
 811        let Some(session) = self.sessions.get_mut(&id) else {
 812            return;
 813        };
 814        let thread_store = self.thread_store.clone();
 815        session.pending_save = cx.spawn(async move |_, cx| {
 816            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 817                return;
 818            };
 819            let db_thread = db_thread.await;
 820            database.save_thread(id, db_thread).await.log_err();
 821            thread_store.update(cx, |store, cx| store.reload(cx));
 822        });
 823    }
 824
 825    fn send_mcp_prompt(
 826        &self,
 827        message_id: UserMessageId,
 828        session_id: agent_client_protocol::SessionId,
 829        prompt_name: String,
 830        server_id: ContextServerId,
 831        arguments: HashMap<String, String>,
 832        original_content: Vec<acp::ContentBlock>,
 833        cx: &mut Context<Self>,
 834    ) -> Task<Result<acp::PromptResponse>> {
 835        let server_store = self.context_server_registry.read(cx).server_store().clone();
 836        let path_style = self.project.read(cx).path_style(cx);
 837
 838        cx.spawn(async move |this, cx| {
 839            let prompt =
 840                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
 841
 842            let (acp_thread, thread) = this.update(cx, |this, _cx| {
 843                let session = this
 844                    .sessions
 845                    .get(&session_id)
 846                    .context("Failed to get session")?;
 847                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
 848            })??;
 849
 850            let mut last_is_user = true;
 851
 852            thread.update(cx, |thread, cx| {
 853                thread.push_acp_user_block(
 854                    message_id,
 855                    original_content.into_iter().skip(1),
 856                    path_style,
 857                    cx,
 858                );
 859            });
 860
 861            for message in prompt.messages {
 862                let context_server::types::PromptMessage { role, content } = message;
 863                let block = mcp_message_content_to_acp_content_block(content);
 864
 865                match role {
 866                    context_server::types::Role::User => {
 867                        let id = acp_thread::UserMessageId::new();
 868
 869                        acp_thread.update(cx, |acp_thread, cx| {
 870                            acp_thread.push_user_content_block_with_indent(
 871                                Some(id.clone()),
 872                                block.clone(),
 873                                true,
 874                                cx,
 875                            );
 876                        })?;
 877
 878                        thread.update(cx, |thread, cx| {
 879                            thread.push_acp_user_block(id, [block], path_style, cx);
 880                        });
 881                    }
 882                    context_server::types::Role::Assistant => {
 883                        acp_thread.update(cx, |acp_thread, cx| {
 884                            acp_thread.push_assistant_content_block_with_indent(
 885                                block.clone(),
 886                                false,
 887                                true,
 888                                cx,
 889                            );
 890                        })?;
 891
 892                        thread.update(cx, |thread, cx| {
 893                            thread.push_acp_agent_block(block, cx);
 894                        });
 895                    }
 896                }
 897
 898                last_is_user = role == context_server::types::Role::User;
 899            }
 900
 901            let response_stream = thread.update(cx, |thread, cx| {
 902                if last_is_user {
 903                    thread.send_existing(cx)
 904                } else {
 905                    // Resume if MCP prompt did not end with a user message
 906                    thread.resume(cx)
 907                }
 908            })?;
 909
 910            cx.update(|cx| {
 911                NativeAgentConnection::handle_thread_events(response_stream, acp_thread, cx)
 912            })
 913            .await
 914        })
 915    }
 916}
 917
/// Newtype around the [`NativeAgent`] entity on which the
/// `acp_thread::AgentConnection` trait is implemented.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 921
 922impl NativeAgentConnection {
 923    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 924        self.0
 925            .read(cx)
 926            .sessions
 927            .get(session_id)
 928            .map(|session| session.thread.clone())
 929    }
 930
 931    pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
 932        self.0.update(cx, |this, cx| this.load_thread(id, cx))
 933    }
 934
 935    fn run_turn(
 936        &self,
 937        session_id: acp::SessionId,
 938        cx: &mut App,
 939        f: impl 'static
 940        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
 941    ) -> Task<Result<acp::PromptResponse>> {
 942        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 943            agent
 944                .sessions
 945                .get_mut(&session_id)
 946                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
 947        }) else {
 948            return Task::ready(Err(anyhow!("Session not found")));
 949        };
 950        log::debug!("Found session for: {}", session_id);
 951
 952        let response_stream = match f(thread, cx) {
 953            Ok(stream) => stream,
 954            Err(err) => return Task::ready(Err(err)),
 955        };
 956        Self::handle_thread_events(response_stream, acp_thread, cx)
 957    }
 958
 959    fn handle_thread_events(
 960        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
 961        acp_thread: WeakEntity<AcpThread>,
 962        cx: &App,
 963    ) -> Task<Result<acp::PromptResponse>> {
 964        cx.spawn(async move |cx| {
 965            // Handle response stream and forward to session.acp_thread
 966            while let Some(result) = events.next().await {
 967                match result {
 968                    Ok(event) => {
 969                        log::trace!("Received completion event: {:?}", event);
 970
 971                        match event {
 972                            ThreadEvent::UserMessage(message) => {
 973                                acp_thread.update(cx, |thread, cx| {
 974                                    for content in message.content {
 975                                        thread.push_user_content_block(
 976                                            Some(message.id.clone()),
 977                                            content.into(),
 978                                            cx,
 979                                        );
 980                                    }
 981                                })?;
 982                            }
 983                            ThreadEvent::AgentText(text) => {
 984                                acp_thread.update(cx, |thread, cx| {
 985                                    thread.push_assistant_content_block(text.into(), false, cx)
 986                                })?;
 987                            }
 988                            ThreadEvent::AgentThinking(text) => {
 989                                acp_thread.update(cx, |thread, cx| {
 990                                    thread.push_assistant_content_block(text.into(), true, cx)
 991                                })?;
 992                            }
 993                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
 994                                tool_call,
 995                                options,
 996                                response,
 997                            }) => {
 998                                let outcome_task = acp_thread.update(cx, |thread, cx| {
 999                                    thread.request_tool_call_authorization(
1000                                        tool_call, options, true, cx,
1001                                    )
1002                                })??;
1003                                cx.background_spawn(async move {
1004                                    if let acp::RequestPermissionOutcome::Selected(
1005                                        acp::SelectedPermissionOutcome { option_id, .. },
1006                                    ) = outcome_task.await
1007                                    {
1008                                        response
1009                                            .send(option_id)
1010                                            .map(|_| anyhow!("authorization receiver was dropped"))
1011                                            .log_err();
1012                                    }
1013                                })
1014                                .detach();
1015                            }
1016                            ThreadEvent::ToolCall(tool_call) => {
1017                                acp_thread.update(cx, |thread, cx| {
1018                                    thread.upsert_tool_call(tool_call, cx)
1019                                })??;
1020                            }
1021                            ThreadEvent::ToolCallUpdate(update) => {
1022                                acp_thread.update(cx, |thread, cx| {
1023                                    thread.update_tool_call(update, cx)
1024                                })??;
1025                            }
1026                            ThreadEvent::Retry(status) => {
1027                                acp_thread.update(cx, |thread, cx| {
1028                                    thread.update_retry_status(status, cx)
1029                                })?;
1030                            }
1031                            ThreadEvent::Stop(stop_reason) => {
1032                                log::debug!("Assistant message complete: {:?}", stop_reason);
1033                                return Ok(acp::PromptResponse::new(stop_reason));
1034                            }
1035                        }
1036                    }
1037                    Err(e) => {
1038                        log::error!("Error in model response stream: {:?}", e);
1039                        return Err(e);
1040                    }
1041                }
1042            }
1043
1044            log::debug!("Response stream completed");
1045            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1046        })
1047    }
1048}
1049
1050struct Command<'a> {
1051    prompt_name: &'a str,
1052    arg_value: &'a str,
1053    explicit_server_id: Option<&'a str>,
1054}
1055
1056impl<'a> Command<'a> {
1057    fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1058        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1059            return None;
1060        };
1061        let text = text_content.text.trim();
1062        let command = text.strip_prefix('/')?;
1063        let (command, arg_value) = command
1064            .split_once(char::is_whitespace)
1065            .unwrap_or((command, ""));
1066
1067        if let Some((server_id, prompt_name)) = command.split_once('.') {
1068            Some(Self {
1069                prompt_name,
1070                arg_value,
1071                explicit_server_id: Some(server_id),
1072            })
1073        } else {
1074            Some(Self {
1075                prompt_name: command,
1076                arg_value,
1077                explicit_server_id: None,
1078            })
1079        }
1080    }
1081}
1082
/// Model selector bound to one session of a [`NativeAgentConnection`].
struct NativeAgentModelSelector {
    /// Session whose thread the model is read from and set on.
    session_id: acp::SessionId,
    /// Connection whose agent holds the sessions and the known models.
    connection: NativeAgentConnection,
}
1087
1088impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1089    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1090        log::debug!("NativeAgentConnection::list_models called");
1091        let list = self.connection.0.read(cx).models.model_list.clone();
1092        Task::ready(if list.is_empty() {
1093            Err(anyhow::anyhow!("No models available"))
1094        } else {
1095            Ok(list)
1096        })
1097    }
1098
1099    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1100        log::debug!(
1101            "Setting model for session {}: {}",
1102            self.session_id,
1103            model_id
1104        );
1105        let Some(thread) = self
1106            .connection
1107            .0
1108            .read(cx)
1109            .sessions
1110            .get(&self.session_id)
1111            .map(|session| session.thread.clone())
1112        else {
1113            return Task::ready(Err(anyhow!("Session not found")));
1114        };
1115
1116        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1117            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1118        };
1119
1120        thread.update(cx, |thread, cx| {
1121            thread.set_model(model.clone(), cx);
1122        });
1123
1124        update_settings_file(
1125            self.connection.0.read(cx).fs.clone(),
1126            cx,
1127            move |settings, _cx| {
1128                let provider = model.provider_id().0.to_string();
1129                let model = model.id().0.to_string();
1130                settings
1131                    .agent
1132                    .get_or_insert_default()
1133                    .set_model(LanguageModelSelection {
1134                        provider: provider.into(),
1135                        model,
1136                    });
1137            },
1138        );
1139
1140        Task::ready(Ok(()))
1141    }
1142
1143    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1144        let Some(thread) = self
1145            .connection
1146            .0
1147            .read(cx)
1148            .sessions
1149            .get(&self.session_id)
1150            .map(|session| session.thread.clone())
1151        else {
1152            return Task::ready(Err(anyhow!("Session not found")));
1153        };
1154        let Some(model) = thread.read(cx).model() else {
1155            return Task::ready(Err(anyhow!("Model not found")));
1156        };
1157        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1158        else {
1159            return Task::ready(Err(anyhow!("Provider not found")));
1160        };
1161        Task::ready(Ok(LanguageModels::map_language_model_to_info(
1162            model, &provider,
1163        )))
1164    }
1165
1166    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1167        Some(self.connection.0.read(cx).models.watch())
1168    }
1169
    /// Always returns `true` for this selector.
    fn should_render_footer(&self) -> bool {
        true
    }
1173}
1174
1175impl acp_thread::AgentConnection for NativeAgentConnection {
1176    fn telemetry_id(&self) -> SharedString {
1177        "zed".into()
1178    }
1179
1180    fn new_thread(
1181        self: Rc<Self>,
1182        project: Entity<Project>,
1183        cwd: &Path,
1184        cx: &mut App,
1185    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1186        let agent = self.0.clone();
1187        log::debug!("Creating new thread for project at: {:?}", cwd);
1188
1189        cx.spawn(async move |cx| {
1190            log::debug!("Starting thread creation in async context");
1191
1192            // Create Thread
1193            let thread = agent.update(cx, |agent, cx| {
1194                // Fetch default model from registry settings
1195                let registry = LanguageModelRegistry::read_global(cx);
1196                // Log available models for debugging
1197                let available_count = registry.available_models(cx).count();
1198                log::debug!("Total available models: {}", available_count);
1199
1200                let default_model = registry.default_model().and_then(|default_model| {
1201                    agent
1202                        .models
1203                        .model_from_id(&LanguageModels::model_id(&default_model.model))
1204                });
1205                cx.new(|cx| {
1206                    Thread::new(
1207                        project.clone(),
1208                        agent.project_context.clone(),
1209                        agent.context_server_registry.clone(),
1210                        agent.templates.clone(),
1211                        default_model,
1212                        cx,
1213                    )
1214                })
1215            });
1216            Ok(agent.update(cx, |agent, cx| agent.register_session(thread, cx)))
1217        })
1218    }
1219
    /// Returns an empty slice: this connection declares no authentication methods.
    fn auth_methods(&self) -> &[acp::AuthMethod] {
        &[] // No auth for in-process
    }
1223
    /// Returns an already-completed successful task; `_method` is ignored.
    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
1227
1228    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1229        Some(Rc::new(NativeAgentModelSelector {
1230            session_id: session_id.clone(),
1231            connection: self.clone(),
1232        }) as Rc<dyn AgentModelSelector>)
1233    }
1234
1235    fn prompt(
1236        &self,
1237        id: Option<acp_thread::UserMessageId>,
1238        params: acp::PromptRequest,
1239        cx: &mut App,
1240    ) -> Task<Result<acp::PromptResponse>> {
1241        let id = id.expect("UserMessageId is required");
1242        let session_id = params.session_id.clone();
1243        log::info!("Received prompt request for session: {}", session_id);
1244        log::debug!("Prompt blocks count: {}", params.prompt.len());
1245
1246        if let Some(parsed_command) = Command::parse(&params.prompt) {
1247            let registry = self.0.read(cx).context_server_registry.read(cx);
1248
1249            let explicit_server_id = parsed_command
1250                .explicit_server_id
1251                .map(|server_id| ContextServerId(server_id.into()));
1252
1253            if let Some(prompt) =
1254                registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1255            {
1256                let arguments = if !parsed_command.arg_value.is_empty()
1257                    && let Some(arg_name) = prompt
1258                        .prompt
1259                        .arguments
1260                        .as_ref()
1261                        .and_then(|args| args.first())
1262                        .map(|arg| arg.name.clone())
1263                {
1264                    HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1265                } else {
1266                    Default::default()
1267                };
1268
1269                let prompt_name = prompt.prompt.name.clone();
1270                let server_id = prompt.server_id.clone();
1271
1272                return self.0.update(cx, |agent, cx| {
1273                    agent.send_mcp_prompt(
1274                        id,
1275                        session_id.clone(),
1276                        prompt_name,
1277                        server_id,
1278                        arguments,
1279                        params.prompt,
1280                        cx,
1281                    )
1282                });
1283            };
1284        };
1285
1286        let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1287
1288        self.run_turn(session_id, cx, move |thread, cx| {
1289            let content: Vec<UserMessageContent> = params
1290                .prompt
1291                .into_iter()
1292                .map(|block| UserMessageContent::from_content_block(block, path_style))
1293                .collect::<Vec<_>>();
1294            log::debug!("Converted prompt to message: {} chars", content.len());
1295            log::debug!("Message id: {:?}", id);
1296            log::debug!("Message content: {:?}", content);
1297
1298            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1299        })
1300    }
1301
1302    fn resume(
1303        &self,
1304        session_id: &acp::SessionId,
1305        _cx: &App,
1306    ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1307        Some(Rc::new(NativeAgentSessionResume {
1308            connection: self.clone(),
1309            session_id: session_id.clone(),
1310        }) as _)
1311    }
1312
1313    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1314        log::info!("Cancelling on session: {}", session_id);
1315        self.0.update(cx, |agent, cx| {
1316            if let Some(agent) = agent.sessions.get(session_id) {
1317                agent
1318                    .thread
1319                    .update(cx, |thread, cx| thread.cancel(cx))
1320                    .detach();
1321            }
1322        });
1323    }
1324
1325    fn truncate(
1326        &self,
1327        session_id: &agent_client_protocol::SessionId,
1328        cx: &App,
1329    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1330        self.0.read_with(cx, |agent, _cx| {
1331            agent.sessions.get(session_id).map(|session| {
1332                Rc::new(NativeAgentSessionTruncate {
1333                    thread: session.thread.clone(),
1334                    acp_thread: session.acp_thread.clone(),
1335                }) as _
1336            })
1337        })
1338    }
1339
1340    fn set_title(
1341        &self,
1342        session_id: &acp::SessionId,
1343        _cx: &App,
1344    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1345        Some(Rc::new(NativeAgentSessionSetTitle {
1346            connection: self.clone(),
1347            session_id: session_id.clone(),
1348        }) as _)
1349    }
1350
1351    fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1352        let thread_store = self.0.read(cx).thread_store.clone();
1353        Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1354    }
1355
1356    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1357        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1358    }
1359
    /// Returns `self` as an `Rc<dyn Any>`.
    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
        self
    }
1363}
1364
1365impl acp_thread::AgentTelemetry for NativeAgentConnection {
1366    fn thread_data(
1367        &self,
1368        session_id: &acp::SessionId,
1369        cx: &mut App,
1370    ) -> Task<Result<serde_json::Value>> {
1371        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1372            return Task::ready(Err(anyhow!("Session not found")));
1373        };
1374
1375        let task = session.thread.read(cx).to_db(cx);
1376        cx.background_spawn(async move {
1377            serde_json::to_value(task.await).context("Failed to serialize thread")
1378        })
1379    }
1380}
1381
1382pub struct NativeAgentSessionList {
1383    thread_store: Entity<ThreadStore>,
1384    updates_rx: watch::Receiver<()>,
1385    _subscription: Subscription,
1386}
1387
1388impl NativeAgentSessionList {
1389    fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1390        let (mut tx, rx) = watch::channel(());
1391        let subscription = cx.observe(&thread_store, move |_, _| {
1392            tx.send(()).ok();
1393        });
1394        Self {
1395            thread_store,
1396            updates_rx: rx,
1397            _subscription: subscription,
1398        }
1399    }
1400
1401    fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1402        AgentSessionInfo {
1403            session_id: entry.id,
1404            cwd: None,
1405            title: Some(entry.title),
1406            updated_at: Some(entry.updated_at),
1407            meta: None,
1408        }
1409    }
1410
1411    pub fn thread_store(&self) -> &Entity<ThreadStore> {
1412        &self.thread_store
1413    }
1414}
1415
1416impl AgentSessionList for NativeAgentSessionList {
1417    fn list_sessions(
1418        &self,
1419        _request: AgentSessionListRequest,
1420        cx: &mut App,
1421    ) -> Task<Result<AgentSessionListResponse>> {
1422        let sessions = self
1423            .thread_store
1424            .read(cx)
1425            .entries()
1426            .map(Self::to_session_info)
1427            .collect();
1428        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1429    }
1430
1431    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1432        self.thread_store
1433            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1434    }
1435
1436    fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1437        self.thread_store
1438            .update(cx, |store, cx| store.delete_threads(cx))
1439    }
1440
1441    fn watch(&self, _cx: &mut App) -> Option<watch::Receiver<()>> {
1442        Some(self.updates_rx.clone())
1443    }
1444
1445    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1446        self
1447    }
1448}
1449
1450struct NativeAgentSessionTruncate {
1451    thread: Entity<Thread>,
1452    acp_thread: WeakEntity<AcpThread>,
1453}
1454
1455impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1456    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1457        match self.thread.update(cx, |thread, cx| {
1458            thread.truncate(message_id.clone(), cx)?;
1459            Ok(thread.latest_token_usage())
1460        }) {
1461            Ok(usage) => {
1462                self.acp_thread
1463                    .update(cx, |thread, cx| {
1464                        thread.update_token_usage(usage, cx);
1465                    })
1466                    .ok();
1467                Task::ready(Ok(()))
1468            }
1469            Err(error) => Task::ready(Err(error)),
1470        }
1471    }
1472}
1473
1474struct NativeAgentSessionResume {
1475    connection: NativeAgentConnection,
1476    session_id: acp::SessionId,
1477}
1478
1479impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1480    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1481        self.connection
1482            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1483                thread.update(cx, |thread, cx| thread.resume(cx))
1484            })
1485    }
1486}
1487
1488struct NativeAgentSessionSetTitle {
1489    connection: NativeAgentConnection,
1490    session_id: acp::SessionId,
1491}
1492
1493impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1494    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1495        let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1496            return Task::ready(Err(anyhow!("session not found")));
1497        };
1498        let thread = session.thread.clone();
1499        thread.update(cx, |thread, cx| thread.set_title(title, cx));
1500        Task::ready(Ok(()))
1501    }
1502}
1503
1504pub struct AcpThreadEnvironment {
1505    acp_thread: WeakEntity<AcpThread>,
1506}
1507
1508impl ThreadEnvironment for AcpThreadEnvironment {
1509    fn create_terminal(
1510        &self,
1511        command: String,
1512        cwd: Option<PathBuf>,
1513        output_byte_limit: Option<u64>,
1514        cx: &mut AsyncApp,
1515    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1516        let task = self.acp_thread.update(cx, |thread, cx| {
1517            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1518        });
1519
1520        let acp_thread = self.acp_thread.clone();
1521        cx.spawn(async move |cx| {
1522            let terminal = task?.await?;
1523
1524            let (drop_tx, drop_rx) = oneshot::channel();
1525            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1526
1527            cx.spawn(async move |cx| {
1528                drop_rx.await.ok();
1529                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1530            })
1531            .detach();
1532
1533            let handle = AcpTerminalHandle {
1534                terminal,
1535                _drop_tx: Some(drop_tx),
1536            };
1537
1538            Ok(Rc::new(handle) as _)
1539        })
1540    }
1541}
1542
1543pub struct AcpTerminalHandle {
1544    terminal: Entity<acp_thread::Terminal>,
1545    _drop_tx: Option<oneshot::Sender<()>>,
1546}
1547
1548impl TerminalHandle for AcpTerminalHandle {
1549    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1550        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1551    }
1552
1553    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1554        Ok(self
1555            .terminal
1556            .read_with(cx, |term, _cx| term.wait_for_exit()))
1557    }
1558
1559    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1560        Ok(self
1561            .terminal
1562            .read_with(cx, |term, cx| term.current_output(cx)))
1563    }
1564
1565    fn kill(&self, cx: &AsyncApp) -> Result<()> {
1566        cx.update(|cx| {
1567            self.terminal.update(cx, |terminal, cx| {
1568                terminal.kill(cx);
1569            });
1570        });
1571        Ok(())
1572    }
1573
1574    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1575        Ok(self
1576            .terminal
1577            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1578    }
1579}
1580
1581#[cfg(test)]
1582mod internal_tests {
1583    use super::*;
1584    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1585    use fs::FakeFs;
1586    use gpui::TestAppContext;
1587    use indoc::formatdoc;
1588    use language_model::fake_provider::FakeLanguageModel;
1589    use serde_json::json;
1590    use settings::SettingsStore;
1591    use util::{path, rel_path::rel_path};
1592
1593    #[gpui::test]
1594    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1595        init_test(cx);
1596        let fs = FakeFs::new(cx.executor());
1597        fs.insert_tree(
1598            "/",
1599            json!({
1600                "a": {}
1601            }),
1602        )
1603        .await;
1604        let project = Project::test(fs.clone(), [], cx).await;
1605        let thread_store = cx.new(|cx| ThreadStore::new(cx));
1606        let agent = NativeAgent::new(
1607            project.clone(),
1608            thread_store,
1609            Templates::new(),
1610            None,
1611            fs.clone(),
1612            &mut cx.to_async(),
1613        )
1614        .await
1615        .unwrap();
1616        agent.read_with(cx, |agent, cx| {
1617            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1618        });
1619
1620        let worktree = project
1621            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1622            .await
1623            .unwrap();
1624        cx.run_until_parked();
1625        agent.read_with(cx, |agent, cx| {
1626            assert_eq!(
1627                agent.project_context.read(cx).worktrees,
1628                vec![WorktreeContext {
1629                    root_name: "a".into(),
1630                    abs_path: Path::new("/a").into(),
1631                    rules_file: None
1632                }]
1633            )
1634        });
1635
1636        // Creating `/a/.rules` updates the project context.
1637        fs.insert_file("/a/.rules", Vec::new()).await;
1638        cx.run_until_parked();
1639        agent.read_with(cx, |agent, cx| {
1640            let rules_entry = worktree
1641                .read(cx)
1642                .entry_for_path(rel_path(".rules"))
1643                .unwrap();
1644            assert_eq!(
1645                agent.project_context.read(cx).worktrees,
1646                vec![WorktreeContext {
1647                    root_name: "a".into(),
1648                    abs_path: Path::new("/a").into(),
1649                    rules_file: Some(RulesFileContext {
1650                        path_in_worktree: rel_path(".rules").into(),
1651                        text: "".into(),
1652                        project_entry_id: rules_entry.id.to_usize()
1653                    })
1654                }]
1655            )
1656        });
1657    }
1658
1659    #[gpui::test]
1660    async fn test_listing_models(cx: &mut TestAppContext) {
1661        init_test(cx);
1662        let fs = FakeFs::new(cx.executor());
1663        fs.insert_tree("/", json!({ "a": {}  })).await;
1664        let project = Project::test(fs.clone(), [], cx).await;
1665        let thread_store = cx.new(|cx| ThreadStore::new(cx));
1666        let connection = NativeAgentConnection(
1667            NativeAgent::new(
1668                project.clone(),
1669                thread_store,
1670                Templates::new(),
1671                None,
1672                fs.clone(),
1673                &mut cx.to_async(),
1674            )
1675            .await
1676            .unwrap(),
1677        );
1678
1679        // Create a thread/session
1680        let acp_thread = cx
1681            .update(|cx| {
1682                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
1683            })
1684            .await
1685            .unwrap();
1686
1687        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1688
1689        let models = cx
1690            .update(|cx| {
1691                connection
1692                    .model_selector(&session_id)
1693                    .unwrap()
1694                    .list_models(cx)
1695            })
1696            .await
1697            .unwrap();
1698
1699        let acp_thread::AgentModelList::Grouped(models) = models else {
1700            panic!("Unexpected model group");
1701        };
1702        assert_eq!(
1703            models,
1704            IndexMap::from_iter([(
1705                AgentModelGroupName("Fake".into()),
1706                vec![AgentModelInfo {
1707                    id: acp::ModelId::new("fake/fake"),
1708                    name: "Fake".into(),
1709                    description: None,
1710                    icon: Some(acp_thread::AgentModelIcon::Named(
1711                        ui::IconName::ZedAssistant
1712                    )),
1713                }]
1714            )])
1715        );
1716    }
1717
1718    #[gpui::test]
1719    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
1720        init_test(cx);
1721        let fs = FakeFs::new(cx.executor());
1722        fs.create_dir(paths::settings_file().parent().unwrap())
1723            .await
1724            .unwrap();
1725        fs.insert_file(
1726            paths::settings_file(),
1727            json!({
1728                "agent": {
1729                    "default_model": {
1730                        "provider": "foo",
1731                        "model": "bar"
1732                    }
1733                }
1734            })
1735            .to_string()
1736            .into_bytes(),
1737        )
1738        .await;
1739        let project = Project::test(fs.clone(), [], cx).await;
1740
1741        let thread_store = cx.new(|cx| ThreadStore::new(cx));
1742
1743        // Create the agent and connection
1744        let agent = NativeAgent::new(
1745            project.clone(),
1746            thread_store,
1747            Templates::new(),
1748            None,
1749            fs.clone(),
1750            &mut cx.to_async(),
1751        )
1752        .await
1753        .unwrap();
1754        let connection = NativeAgentConnection(agent.clone());
1755
1756        // Create a thread/session
1757        let acp_thread = cx
1758            .update(|cx| {
1759                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
1760            })
1761            .await
1762            .unwrap();
1763
1764        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1765
1766        // Select a model
1767        let selector = connection.model_selector(&session_id).unwrap();
1768        let model_id = acp::ModelId::new("fake/fake");
1769        cx.update(|cx| selector.select_model(model_id.clone(), cx))
1770            .await
1771            .unwrap();
1772
1773        // Verify the thread has the selected model
1774        agent.read_with(cx, |agent, _| {
1775            let session = agent.sessions.get(&session_id).unwrap();
1776            session.thread.read_with(cx, |thread, _| {
1777                assert_eq!(thread.model().unwrap().id().0, "fake");
1778            });
1779        });
1780
1781        cx.run_until_parked();
1782
1783        // Verify settings file was updated
1784        let settings_content = fs.load(paths::settings_file()).await.unwrap();
1785        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
1786
1787        // Check that the agent settings contain the selected model
1788        assert_eq!(
1789            settings_json["agent"]["default_model"]["model"],
1790            json!("fake")
1791        );
1792        assert_eq!(
1793            settings_json["agent"]["default_model"]["provider"],
1794            json!("fake")
1795        );
1796    }
1797
1798    #[gpui::test]
1799    async fn test_save_load_thread(cx: &mut TestAppContext) {
1800        init_test(cx);
1801        let fs = FakeFs::new(cx.executor());
1802        fs.insert_tree(
1803            "/",
1804            json!({
1805                "a": {
1806                    "b.md": "Lorem"
1807                }
1808            }),
1809        )
1810        .await;
1811        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
1812        let thread_store = cx.new(|cx| ThreadStore::new(cx));
1813        let agent = NativeAgent::new(
1814            project.clone(),
1815            thread_store.clone(),
1816            Templates::new(),
1817            None,
1818            fs.clone(),
1819            &mut cx.to_async(),
1820        )
1821        .await
1822        .unwrap();
1823        let connection = Rc::new(NativeAgentConnection(agent.clone()));
1824
1825        let acp_thread = cx
1826            .update(|cx| {
1827                connection
1828                    .clone()
1829                    .new_thread(project.clone(), Path::new(""), cx)
1830            })
1831            .await
1832            .unwrap();
1833        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
1834        let thread = agent.read_with(cx, |agent, _| {
1835            agent.sessions.get(&session_id).unwrap().thread.clone()
1836        });
1837
1838        // Ensure empty threads are not saved, even if they get mutated.
1839        let model = Arc::new(FakeLanguageModel::default());
1840        let summary_model = Arc::new(FakeLanguageModel::default());
1841        thread.update(cx, |thread, cx| {
1842            thread.set_model(model.clone(), cx);
1843            thread.set_summarization_model(Some(summary_model.clone()), cx);
1844        });
1845        cx.run_until_parked();
1846        assert_eq!(thread_entries(&thread_store, cx), vec![]);
1847
1848        let send = acp_thread.update(cx, |thread, cx| {
1849            thread.send(
1850                vec![
1851                    "What does ".into(),
1852                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
1853                        "b.md",
1854                        MentionUri::File {
1855                            abs_path: path!("/a/b.md").into(),
1856                        }
1857                        .to_uri()
1858                        .to_string(),
1859                    )),
1860                    " mean?".into(),
1861                ],
1862                cx,
1863            )
1864        });
1865        let send = cx.foreground_executor().spawn(send);
1866        cx.run_until_parked();
1867
1868        model.send_last_completion_stream_text_chunk("Lorem.");
1869        model.end_last_completion_stream();
1870        cx.run_until_parked();
1871        summary_model
1872            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
1873        summary_model.end_last_completion_stream();
1874
1875        send.await.unwrap();
1876        let uri = MentionUri::File {
1877            abs_path: path!("/a/b.md").into(),
1878        }
1879        .to_uri();
1880        acp_thread.read_with(cx, |thread, cx| {
1881            assert_eq!(
1882                thread.to_markdown(cx),
1883                formatdoc! {"
1884                    ## User
1885
1886                    What does [@b.md]({uri}) mean?
1887
1888                    ## Assistant
1889
1890                    Lorem.
1891
1892                "}
1893            )
1894        });
1895
1896        cx.run_until_parked();
1897
1898        // Drop the ACP thread, which should cause the session to be dropped as well.
1899        cx.update(|_| {
1900            drop(thread);
1901            drop(acp_thread);
1902        });
1903        agent.read_with(cx, |agent, _| {
1904            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
1905        });
1906
1907        // Ensure the thread can be reloaded from disk.
1908        assert_eq!(
1909            thread_entries(&thread_store, cx),
1910            vec![(
1911                session_id.clone(),
1912                format!("Explaining {}", path!("/a/b.md"))
1913            )]
1914        );
1915        let acp_thread = agent
1916            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
1917            .await
1918            .unwrap();
1919        acp_thread.read_with(cx, |thread, cx| {
1920            assert_eq!(
1921                thread.to_markdown(cx),
1922                formatdoc! {"
1923                    ## User
1924
1925                    What does [@b.md]({uri}) mean?
1926
1927                    ## Assistant
1928
1929                    Lorem.
1930
1931                "}
1932            )
1933        });
1934    }
1935
1936    fn thread_entries(
1937        thread_store: &Entity<ThreadStore>,
1938        cx: &mut TestAppContext,
1939    ) -> Vec<(acp::SessionId, String)> {
1940        thread_store.read_with(cx, |store, _| {
1941            store
1942                .entries()
1943                .map(|entry| (entry.id.clone(), entry.title.to_string()))
1944                .collect::<Vec<_>>()
1945        })
1946    }
1947
1948    fn init_test(cx: &mut TestAppContext) {
1949        env_logger::try_init().ok();
1950        cx.update(|cx| {
1951            let settings_store = SettingsStore::test(cx);
1952            cx.set_global(settings_store);
1953
1954            LanguageModelRegistry::test(cx);
1955        });
1956    }
1957}
1958
1959fn mcp_message_content_to_acp_content_block(
1960    content: context_server::types::MessageContent,
1961) -> acp::ContentBlock {
1962    match content {
1963        context_server::types::MessageContent::Text {
1964            text,
1965            annotations: _,
1966        } => text.into(),
1967        context_server::types::MessageContent::Image {
1968            data,
1969            mime_type,
1970            annotations: _,
1971        } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
1972        context_server::types::MessageContent::Audio {
1973            data,
1974            mime_type,
1975            annotations: _,
1976        } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
1977        context_server::types::MessageContent::Resource {
1978            resource,
1979            annotations: _,
1980        } => {
1981            let mut link =
1982                acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
1983            if let Some(mime_type) = resource.mime_type {
1984                link = link.mime_type(mime_type);
1985            }
1986            acp::ContentBlock::ResourceLink(link)
1987        }
1988    }
1989}