agent.rs

   1mod db;
   2mod edit_agent;
   3mod legacy_thread;
   4mod native_agent_server;
   5pub mod outline;
   6mod pattern_extraction;
   7mod templates;
   8#[cfg(test)]
   9mod tests;
  10mod thread;
  11mod thread_store;
  12mod tool_permissions;
  13mod tools;
  14
  15use context_server::ContextServerId;
  16pub use db::*;
  17use itertools::Itertools;
  18pub use native_agent_server::NativeAgentServer;
  19pub use pattern_extraction::*;
  20pub use shell_command_parser::extract_commands;
  21pub use templates::*;
  22pub use thread::*;
  23pub use thread_store::*;
  24pub use tool_permissions::*;
  25pub use tools::*;
  26
  27use acp_thread::{
  28    AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
  29    AgentSessionListResponse, TokenUsageRatio, UserMessageId,
  30};
  31use agent_client_protocol as acp;
  32use anyhow::{Context as _, Result, anyhow};
  33use chrono::{DateTime, Utc};
  34use collections::{HashMap, HashSet, IndexMap};
  35use fs::Fs;
  36use futures::channel::{mpsc, oneshot};
  37use futures::future::Shared;
  38use futures::{FutureExt as _, StreamExt as _, future};
  39use gpui::{
  40    App, AppContext, AsyncApp, Context, Entity, EntityId, SharedString, Subscription, Task,
  41    WeakEntity,
  42};
  43use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  44use project::{AgentId, Project, ProjectItem, ProjectPath, Worktree};
  45use prompt_store::{
  46    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  47    WorktreeContext,
  48};
  49use serde::{Deserialize, Serialize};
  50use settings::{LanguageModelSelection, update_settings_file};
  51use std::any::Any;
  52use std::path::PathBuf;
  53use std::rc::Rc;
  54use std::sync::{Arc, LazyLock};
  55use util::ResultExt;
  56use util::path_list::PathList;
  57use util::rel_path::RelPath;
  58
/// Snapshot of a project's worktrees for telemetry, paired with a UTC timestamp.
///
/// Derives `Serialize`/`Deserialize`, so the field names (and their declaration order)
/// are part of its serialized representation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Telemetry snapshots of the project's worktrees.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// Timestamp associated with the snapshot, in UTC.
    pub timestamp: DateTime<Utc>,
}
  64
/// Error describing why a worktree's rules file could not be loaded.
pub struct RulesLoadingError {
    /// Display text of the underlying load error.
    pub message: SharedString,
}
  68
/// Per-project state tracked by the native agent, keyed by the project's `EntityId`.
struct ProjectState {
    /// The project this state belongs to.
    project: Entity<Project>,
    /// Context (worktrees, rules files, user rules) handed to every thread of this
    /// project; rebuilt in place whenever a refresh is requested.
    project_context: Entity<ProjectContext>,
    /// Sending on this channel asks the maintenance task to rebuild `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task that rebuilds `project_context` on each refresh request; stored here so it
    /// lives as long as this state.
    _maintain_project_context: Task<Result<()>>,
    /// Context-server registry built from the project's context-server store.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Subscriptions to project, context-server-store and registry events.
    _subscriptions: Vec<Subscription>,
}
  77
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// `EntityId` of the project this session belongs to (key into the agent's
    /// project-state map).
    project_id: EntityId,
    /// Task for the thread's pending save; starts as an already-completed task.
    /// NOTE(review): presumably replaced by `save_thread` — confirm.
    pending_save: Task<Result<()>>,
    /// Subscriptions to the thread's title, token-usage and change notifications.
    _subscriptions: Vec<Subscription>,
}
  88
/// Cache of the language models offered by authenticated providers, in both a lookup
/// table and a grouped list form.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiving end of the refresh channel; cloned out to observers by `watch()`.
    refresh_models_rx: watch::Receiver<()>,
    /// Signalled every time the cache is rebuilt by `refresh_list`.
    refresh_models_tx: watch::Sender<()>,
    /// Background task, started in `new`, that authenticates all visible providers.
    _authenticate_all_providers_task: Task<()>,
}
  98
  99impl LanguageModels {
 100    fn new(cx: &mut App) -> Self {
 101        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
 102
 103        let mut this = Self {
 104            models: HashMap::default(),
 105            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
 106            refresh_models_rx,
 107            refresh_models_tx,
 108            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
 109        };
 110        this.refresh_list(cx);
 111        this
 112    }
 113
 114    fn refresh_list(&mut self, cx: &App) {
 115        let providers = LanguageModelRegistry::global(cx)
 116            .read(cx)
 117            .visible_providers()
 118            .into_iter()
 119            .filter(|provider| provider.is_authenticated(cx))
 120            .collect::<Vec<_>>();
 121
 122        let mut language_model_list = IndexMap::default();
 123        let mut recommended_models = HashSet::default();
 124
 125        let mut recommended = Vec::new();
 126        for provider in &providers {
 127            for model in provider.recommended_models(cx) {
 128                recommended_models.insert((model.provider_id(), model.id()));
 129                recommended.push(Self::map_language_model_to_info(&model, provider));
 130            }
 131        }
 132        if !recommended.is_empty() {
 133            language_model_list.insert(
 134                acp_thread::AgentModelGroupName("Recommended".into()),
 135                recommended,
 136            );
 137        }
 138
 139        let mut models = HashMap::default();
 140        for provider in providers {
 141            let mut provider_models = Vec::new();
 142            for model in provider.provided_models(cx) {
 143                let model_info = Self::map_language_model_to_info(&model, &provider);
 144                let model_id = model_info.id.clone();
 145                provider_models.push(model_info);
 146                models.insert(model_id, model);
 147            }
 148            if !provider_models.is_empty() {
 149                language_model_list.insert(
 150                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 151                    provider_models,
 152                );
 153            }
 154        }
 155
 156        self.models = models;
 157        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 158        self.refresh_models_tx.send(()).ok();
 159    }
 160
 161    fn watch(&self) -> watch::Receiver<()> {
 162        self.refresh_models_rx.clone()
 163    }
 164
 165    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 166        self.models.get(model_id).cloned()
 167    }
 168
 169    fn map_language_model_to_info(
 170        model: &Arc<dyn LanguageModel>,
 171        provider: &Arc<dyn LanguageModelProvider>,
 172    ) -> acp_thread::AgentModelInfo {
 173        acp_thread::AgentModelInfo {
 174            id: Self::model_id(model),
 175            name: model.name().0,
 176            description: None,
 177            icon: Some(match provider.icon() {
 178                IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
 179                IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
 180            }),
 181            is_latest: model.is_latest(),
 182            cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
 183        }
 184    }
 185
 186    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 187        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 188    }
 189
 190    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 191        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 192            .read(cx)
 193            .visible_providers()
 194            .iter()
 195            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 196            .collect::<Vec<_>>();
 197
 198        cx.background_spawn(async move {
 199            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 200                if let Err(err) = authenticate_task.await {
 201                    match err {
 202                        language_model::AuthenticateError::CredentialsNotFound => {
 203                            // Since we're authenticating these providers in the
 204                            // background for the purposes of populating the
 205                            // language selector, we don't care about providers
 206                            // where the credentials are not found.
 207                        }
 208                        language_model::AuthenticateError::ConnectionRefused => {
 209                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 210                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 211                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 212                        }
 213                        _ => {
 214                            // Some providers have noisy failure states that we
 215                            // don't want to spam the logs with every time the
 216                            // language model selector is initialized.
 217                            //
 218                            // Ideally these should have more clear failure modes
 219                            // that we know are safe to ignore here, like what we do
 220                            // with `CredentialsNotFound` above.
 221                            match provider_id.0.as_ref() {
 222                                "lmstudio" | "ollama" => {
 223                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 224                                    //
 225                                    // These fail noisily, so we don't log them.
 226                                }
 227                                "copilot_chat" => {
 228                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 229                                }
 230                                _ => {
 231                                    log::error!(
 232                                        "Failed to authenticate provider: {}: {err:#}",
 233                                        provider_name.0
 234                                    );
 235                                }
 236                            }
 237                        }
 238                    }
 239                }
 240            }
 241        })
 242    }
 243}
 244
/// The native agent: owns every session, the per-project state, the shared prompt
/// templates and the cached language models.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread store handle. NOTE(review): presumably used to persist/list threads — confirm.
    thread_store: Entity<ThreadStore>,
    /// Project-specific state keyed by project EntityId
    projects: HashMap<EntityId, ProjectState>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Optional prompt store whose default prompts become user rules in each project's
    /// context.
    prompt_store: Option<Entity<PromptStore>>,
    /// File-system handle (its uses are outside this part of the file).
    fs: Arc<dyn Fs>,
    /// Subscriptions to the language-model registry and, if present, the prompt store.
    _subscriptions: Vec<Subscription>,
}
 259
 260impl NativeAgent {
 261    pub fn new(
 262        thread_store: Entity<ThreadStore>,
 263        templates: Arc<Templates>,
 264        prompt_store: Option<Entity<PromptStore>>,
 265        fs: Arc<dyn Fs>,
 266        cx: &mut App,
 267    ) -> Entity<NativeAgent> {
 268        log::debug!("Creating new NativeAgent");
 269
 270        cx.new(|cx| {
 271            let mut subscriptions = vec![cx.subscribe(
 272                &LanguageModelRegistry::global(cx),
 273                Self::handle_models_updated_event,
 274            )];
 275            if let Some(prompt_store) = prompt_store.as_ref() {
 276                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 277            }
 278
 279            Self {
 280                sessions: HashMap::default(),
 281                thread_store,
 282                projects: HashMap::default(),
 283                templates,
 284                models: LanguageModels::new(cx),
 285                prompt_store,
 286                fs,
 287                _subscriptions: subscriptions,
 288            }
 289        })
 290    }
 291
 292    fn new_session(
 293        &mut self,
 294        project: Entity<Project>,
 295        cx: &mut Context<Self>,
 296    ) -> Entity<AcpThread> {
 297        let project_id = self.get_or_create_project_state(&project, cx);
 298        let project_state = &self.projects[&project_id];
 299
 300        let registry = LanguageModelRegistry::read_global(cx);
 301        let available_count = registry.available_models(cx).count();
 302        log::debug!("Total available models: {}", available_count);
 303
 304        let default_model = registry.default_model().and_then(|default_model| {
 305            self.models
 306                .model_from_id(&LanguageModels::model_id(&default_model.model))
 307        });
 308        let thread = cx.new(|cx| {
 309            Thread::new(
 310                project,
 311                project_state.project_context.clone(),
 312                project_state.context_server_registry.clone(),
 313                self.templates.clone(),
 314                default_model,
 315                cx,
 316            )
 317        });
 318
 319        self.register_session(thread, project_id, cx)
 320    }
 321
    /// Wraps `thread_handle` in a new ACP thread and registers the pair as a session.
    ///
    /// The ACP thread is seeded from the internal thread's current state (parent session
    /// id, title, draft prompt, UI scroll position and latest token usage). The internal
    /// thread then receives the registry's thread-summary model and its default tools.
    /// The agent subscribes to the thread's title and token-usage events and to its
    /// change notifications, calling `save_thread` on every change. Finally, the
    /// project's available commands are pushed to all of that project's sessions.
    ///
    /// Returns the newly created ACP thread.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        project_id: EntityId,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        // Copy out everything needed from the thread first: `thread` borrows `cx`, and
        // that borrow must end before `cx.new` below.
        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let draft_prompt = thread.draft_prompt().map(Vec::from);
        let scroll_position = thread.ui_scroll_position();
        let token_usage = thread.latest_token_usage();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            let mut acp_thread = acp_thread::AcpThread::new(
                parent_session_id,
                title,
                None,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            );
            // Restore UI state carried by the internal thread.
            acp_thread.set_draft_prompt(draft_prompt);
            acp_thread.set_ui_scroll_position(scroll_position);
            acp_thread.update_token_usage(token_usage, cx);
            acp_thread
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        let weak = cx.weak_entity();
        let weak_thread = thread_handle.downgrade();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            // The tool environment refers back to the ACP thread, the internal thread and
            // this agent only through weak handles.
            thread.add_default_tools(
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    thread: weak_thread,
                    agent: weak,
                }) as _,
                cx,
            )
        });

        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            // Save the thread every time it notifies of a change.
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                project_id,
                _subscriptions: subscriptions,
                // No save in flight yet.
                pending_save: Task::ready(Ok(())),
            },
        );

        // Push the project's available commands to its sessions, including this new one.
        self.update_available_commands_for_project(project_id, cx);

        acp_thread
    }
 398
    /// Returns the agent's cache of available language models.
    pub fn models(&self) -> &LanguageModels {
        &self.models
    }
 402
 403    fn get_or_create_project_state(
 404        &mut self,
 405        project: &Entity<Project>,
 406        cx: &mut Context<Self>,
 407    ) -> EntityId {
 408        let project_id = project.entity_id();
 409        if self.projects.contains_key(&project_id) {
 410            return project_id;
 411        }
 412
 413        let project_context = cx.new(|_| ProjectContext::new(vec![], vec![]));
 414        self.register_project_with_initial_context(project.clone(), project_context, cx);
 415        if let Some(state) = self.projects.get_mut(&project_id) {
 416            state.project_context_needs_refresh.send(()).ok();
 417        }
 418        project_id
 419    }
 420
 421    fn register_project_with_initial_context(
 422        &mut self,
 423        project: Entity<Project>,
 424        project_context: Entity<ProjectContext>,
 425        cx: &mut Context<Self>,
 426    ) {
 427        let project_id = project.entity_id();
 428
 429        let context_server_store = project.read(cx).context_server_store();
 430        let context_server_registry =
 431            cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
 432
 433        let subscriptions = vec![
 434            cx.subscribe(&project, Self::handle_project_event),
 435            cx.subscribe(
 436                &context_server_store,
 437                Self::handle_context_server_store_updated,
 438            ),
 439            cx.subscribe(
 440                &context_server_registry,
 441                Self::handle_context_server_registry_event,
 442            ),
 443        ];
 444
 445        let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 446            watch::channel(());
 447
 448        self.projects.insert(
 449            project_id,
 450            ProjectState {
 451                project,
 452                project_context,
 453                project_context_needs_refresh: project_context_needs_refresh_tx,
 454                _maintain_project_context: cx.spawn(async move |this, cx| {
 455                    Self::maintain_project_context(
 456                        this,
 457                        project_id,
 458                        project_context_needs_refresh_rx,
 459                        cx,
 460                    )
 461                    .await
 462                }),
 463                context_server_registry,
 464                _subscriptions: subscriptions,
 465            },
 466        );
 467    }
 468
 469    fn session_project_state(&self, session_id: &acp::SessionId) -> Option<&ProjectState> {
 470        self.sessions
 471            .get(session_id)
 472            .and_then(|session| self.projects.get(&session.project_id))
 473    }
 474
 475    async fn maintain_project_context(
 476        this: WeakEntity<Self>,
 477        project_id: EntityId,
 478        mut needs_refresh: watch::Receiver<()>,
 479        cx: &mut AsyncApp,
 480    ) -> Result<()> {
 481        while needs_refresh.changed().await.is_ok() {
 482            let project_context = this
 483                .update(cx, |this, cx| {
 484                    let state = this
 485                        .projects
 486                        .get(&project_id)
 487                        .context("project state not found")?;
 488                    anyhow::Ok(Self::build_project_context(
 489                        &state.project,
 490                        this.prompt_store.as_ref(),
 491                        cx,
 492                    ))
 493                })??
 494                .await;
 495            this.update(cx, |this, cx| {
 496                if let Some(state) = this.projects.get(&project_id) {
 497                    state
 498                        .project_context
 499                        .update(cx, |current_project_context, _cx| {
 500                            *current_project_context = project_context;
 501                        });
 502                }
 503            })?;
 504        }
 505
 506        Ok(())
 507    }
 508
 509    fn build_project_context(
 510        project: &Entity<Project>,
 511        prompt_store: Option<&Entity<PromptStore>>,
 512        cx: &mut App,
 513    ) -> Task<ProjectContext> {
 514        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 515        let worktree_tasks = worktrees
 516            .into_iter()
 517            .map(|worktree| {
 518                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 519            })
 520            .collect::<Vec<_>>();
 521        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 522            prompt_store.read_with(cx, |prompt_store, cx| {
 523                let prompts = prompt_store.default_prompt_metadata();
 524                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 525                    let contents = prompt_store.load(prompt_metadata.id, cx);
 526                    async move { (contents.await, prompt_metadata) }
 527                });
 528                cx.background_spawn(future::join_all(load_tasks))
 529            })
 530        } else {
 531            Task::ready(vec![])
 532        };
 533
 534        cx.spawn(async move |_cx| {
 535            let (worktrees, default_user_rules) =
 536                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 537
 538            let worktrees = worktrees
 539                .into_iter()
 540                .map(|(worktree, _rules_error)| {
 541                    // TODO: show error message
 542                    // if let Some(rules_error) = rules_error {
 543                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 544                    // }
 545                    worktree
 546                })
 547                .collect::<Vec<_>>();
 548
 549            let default_user_rules = default_user_rules
 550                .into_iter()
 551                .flat_map(|(contents, prompt_metadata)| match contents {
 552                    Ok(contents) => Some(UserRulesContext {
 553                        uuid: prompt_metadata.id.as_user()?,
 554                        title: prompt_metadata.title.map(|title| title.to_string()),
 555                        contents,
 556                    }),
 557                    Err(_err) => {
 558                        // TODO: show error message
 559                        // this.update(cx, |_, cx| {
 560                        //     cx.emit(RulesLoadingError {
 561                        //         message: format!("{err:?}").into(),
 562                        //     });
 563                        // })
 564                        // .ok();
 565                        None
 566                    }
 567                })
 568                .collect::<Vec<_>>();
 569
 570            ProjectContext::new(worktrees, default_user_rules)
 571        })
 572    }
 573
 574    fn load_worktree_info_for_system_prompt(
 575        worktree: Entity<Worktree>,
 576        project: Entity<Project>,
 577        cx: &mut App,
 578    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 579        let tree = worktree.read(cx);
 580        let root_name = tree.root_name_str().into();
 581        let abs_path = tree.abs_path();
 582
 583        let mut context = WorktreeContext {
 584            root_name,
 585            abs_path,
 586            rules_file: None,
 587        };
 588
 589        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 590        let Some(rules_task) = rules_task else {
 591            return Task::ready((context, None));
 592        };
 593
 594        cx.spawn(async move |_| {
 595            let (rules_file, rules_file_error) = match rules_task.await {
 596                Ok(rules_file) => (Some(rules_file), None),
 597                Err(err) => (
 598                    None,
 599                    Some(RulesLoadingError {
 600                        message: format!("{err}").into(),
 601                    }),
 602                ),
 603            };
 604            context.rules_file = rules_file;
 605            (context, rules_file_error)
 606        })
 607    }
 608
 609    fn load_worktree_rules_file(
 610        worktree: Entity<Worktree>,
 611        project: Entity<Project>,
 612        cx: &mut App,
 613    ) -> Option<Task<Result<RulesFileContext>>> {
 614        let worktree = worktree.read(cx);
 615        let worktree_id = worktree.id();
 616        let selected_rules_file = RULES_FILE_NAMES
 617            .into_iter()
 618            .filter_map(|name| {
 619                worktree
 620                    .entry_for_path(RelPath::unix(name).unwrap())
 621                    .filter(|entry| entry.is_file())
 622                    .map(|entry| entry.path.clone())
 623            })
 624            .next();
 625
 626        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 627        // supported. This doesn't seem to occur often in GitHub repositories.
 628        selected_rules_file.map(|path_in_worktree| {
 629            let project_path = ProjectPath {
 630                worktree_id,
 631                path: path_in_worktree.clone(),
 632            };
 633            let buffer_task =
 634                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 635            let rope_task = cx.spawn(async move |cx| {
 636                let buffer = buffer_task.await?;
 637                let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
 638                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 639                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 640                })?;
 641                anyhow::Ok((project_entry_id, rope))
 642            });
 643            // Build a string from the rope on a background thread.
 644            cx.background_spawn(async move {
 645                let (project_entry_id, rope) = rope_task.await?;
 646                anyhow::Ok(RulesFileContext {
 647                    path_in_worktree,
 648                    text: rope.to_string().trim().to_string(),
 649                    project_entry_id: project_entry_id.to_usize(),
 650                })
 651            })
 652        })
 653    }
 654
 655    fn handle_thread_title_updated(
 656        &mut self,
 657        thread: Entity<Thread>,
 658        _: &TitleUpdated,
 659        cx: &mut Context<Self>,
 660    ) {
 661        let session_id = thread.read(cx).id();
 662        let Some(session) = self.sessions.get(session_id) else {
 663            return;
 664        };
 665
 666        let thread = thread.downgrade();
 667        let acp_thread = session.acp_thread.downgrade();
 668        cx.spawn(async move |_, cx| {
 669            let title = thread.read_with(cx, |thread, _| thread.title())?;
 670            if let Some(title) = title {
 671                let task =
 672                    acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 673                task.await?;
 674            }
 675            anyhow::Ok(())
 676        })
 677        .detach_and_log_err(cx);
 678    }
 679
 680    fn handle_thread_token_usage_updated(
 681        &mut self,
 682        thread: Entity<Thread>,
 683        usage: &TokenUsageUpdated,
 684        cx: &mut Context<Self>,
 685    ) {
 686        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 687            return;
 688        };
 689        session.acp_thread.update(cx, |acp_thread, cx| {
 690            acp_thread.update_token_usage(usage.0.clone(), cx);
 691        });
 692    }
 693
 694    fn handle_project_event(
 695        &mut self,
 696        project: Entity<Project>,
 697        event: &project::Event,
 698        _cx: &mut Context<Self>,
 699    ) {
 700        let project_id = project.entity_id();
 701        let Some(state) = self.projects.get_mut(&project_id) else {
 702            return;
 703        };
 704        match event {
 705            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 706                state.project_context_needs_refresh.send(()).ok();
 707            }
 708            project::Event::WorktreeUpdatedEntries(_, items) => {
 709                if items.iter().any(|(path, _, _)| {
 710                    RULES_FILE_NAMES
 711                        .iter()
 712                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 713                }) {
 714                    state.project_context_needs_refresh.send(()).ok();
 715                }
 716            }
 717            _ => {}
 718        }
 719    }
 720
 721    fn handle_prompts_updated_event(
 722        &mut self,
 723        _prompt_store: Entity<PromptStore>,
 724        _event: &prompt_store::PromptsUpdatedEvent,
 725        _cx: &mut Context<Self>,
 726    ) {
 727        for state in self.projects.values_mut() {
 728            state.project_context_needs_refresh.send(()).ok();
 729        }
 730    }
 731
 732    fn handle_models_updated_event(
 733        &mut self,
 734        _registry: Entity<LanguageModelRegistry>,
 735        event: &language_model::Event,
 736        cx: &mut Context<Self>,
 737    ) {
 738        self.models.refresh_list(cx);
 739
 740        let registry = LanguageModelRegistry::read_global(cx);
 741        let default_model = registry.default_model().map(|m| m.model);
 742        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 743
 744        for session in self.sessions.values_mut() {
 745            session.thread.update(cx, |thread, cx| {
 746                if thread.model().is_none()
 747                    && let Some(model) = default_model.clone()
 748                {
 749                    thread.set_model(model, cx);
 750                    cx.notify();
 751                }
 752                if let Some(model) = summarization_model.clone() {
 753                    if thread.summarization_model().is_none()
 754                        || matches!(event, language_model::Event::ThreadSummaryModelChanged)
 755                    {
 756                        thread.set_summarization_model(Some(model), cx);
 757                    }
 758                }
 759            });
 760        }
 761    }
 762
 763    fn handle_context_server_store_updated(
 764        &mut self,
 765        store: Entity<project::context_server_store::ContextServerStore>,
 766        _event: &project::context_server_store::ServerStatusChangedEvent,
 767        cx: &mut Context<Self>,
 768    ) {
 769        let project_id = self.projects.iter().find_map(|(id, state)| {
 770            if *state.context_server_registry.read(cx).server_store() == store {
 771                Some(*id)
 772            } else {
 773                None
 774            }
 775        });
 776        if let Some(project_id) = project_id {
 777            self.update_available_commands_for_project(project_id, cx);
 778        }
 779    }
 780
 781    fn handle_context_server_registry_event(
 782        &mut self,
 783        registry: Entity<ContextServerRegistry>,
 784        event: &ContextServerRegistryEvent,
 785        cx: &mut Context<Self>,
 786    ) {
 787        match event {
 788            ContextServerRegistryEvent::ToolsChanged => {}
 789            ContextServerRegistryEvent::PromptsChanged => {
 790                let project_id = self.projects.iter().find_map(|(id, state)| {
 791                    if state.context_server_registry == registry {
 792                        Some(*id)
 793                    } else {
 794                        None
 795                    }
 796                });
 797                if let Some(project_id) = project_id {
 798                    self.update_available_commands_for_project(project_id, cx);
 799                }
 800            }
 801        }
 802    }
 803
 804    fn update_available_commands_for_project(&self, project_id: EntityId, cx: &mut Context<Self>) {
 805        let available_commands =
 806            Self::build_available_commands_for_project(self.projects.get(&project_id), cx);
 807        for session in self.sessions.values() {
 808            if session.project_id != project_id {
 809                continue;
 810            }
 811            session.acp_thread.update(cx, |thread, cx| {
 812                thread
 813                    .handle_session_update(
 814                        acp::SessionUpdate::AvailableCommandsUpdate(
 815                            acp::AvailableCommandsUpdate::new(available_commands.clone()),
 816                        ),
 817                        cx,
 818                    )
 819                    .log_err();
 820            });
 821        }
 822    }
 823
 824    fn build_available_commands_for_project(
 825        project_state: Option<&ProjectState>,
 826        cx: &App,
 827    ) -> Vec<acp::AvailableCommand> {
 828        let Some(state) = project_state else {
 829            return vec![];
 830        };
 831        let registry = state.context_server_registry.read(cx);
 832
 833        let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
 834        for context_server_prompt in registry.prompts() {
 835            *prompt_name_counts
 836                .entry(context_server_prompt.prompt.name.as_str())
 837                .or_insert(0) += 1;
 838        }
 839
 840        registry
 841            .prompts()
 842            .flat_map(|context_server_prompt| {
 843                let prompt = &context_server_prompt.prompt;
 844
 845                let should_prefix = prompt_name_counts
 846                    .get(prompt.name.as_str())
 847                    .copied()
 848                    .unwrap_or(0)
 849                    > 1;
 850
 851                let name = if should_prefix {
 852                    format!("{}.{}", context_server_prompt.server_id, prompt.name)
 853                } else {
 854                    prompt.name.clone()
 855                };
 856
 857                let mut command = acp::AvailableCommand::new(
 858                    name,
 859                    prompt.description.clone().unwrap_or_default(),
 860                );
 861
 862                match prompt.arguments.as_deref() {
 863                    Some([arg]) => {
 864                        let hint = format!("<{}>", arg.name);
 865
 866                        command = command.input(acp::AvailableCommandInput::Unstructured(
 867                            acp::UnstructuredCommandInput::new(hint),
 868                        ));
 869                    }
 870                    Some([]) | None => {}
 871                    Some(_) => {
 872                        // skip >1 argument commands since we don't support them yet
 873                        return None;
 874                    }
 875                }
 876
 877                Some(command)
 878            })
 879            .collect()
 880    }
 881
 882    pub fn load_thread(
 883        &mut self,
 884        id: acp::SessionId,
 885        project: Entity<Project>,
 886        cx: &mut Context<Self>,
 887    ) -> Task<Result<Entity<Thread>>> {
 888        let database_future = ThreadsDatabase::connect(cx);
 889        cx.spawn(async move |this, cx| {
 890            let database = database_future.await.map_err(|err| anyhow!(err))?;
 891            let db_thread = database
 892                .load_thread(id.clone())
 893                .await?
 894                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 895
 896            this.update(cx, |this, cx| {
 897                let project_id = this.get_or_create_project_state(&project, cx);
 898                let project_state = this
 899                    .projects
 900                    .get(&project_id)
 901                    .context("project state not found")?;
 902                let summarization_model = LanguageModelRegistry::read_global(cx)
 903                    .thread_summary_model()
 904                    .map(|c| c.model);
 905
 906                Ok(cx.new(|cx| {
 907                    let mut thread = Thread::from_db(
 908                        id.clone(),
 909                        db_thread,
 910                        project_state.project.clone(),
 911                        project_state.project_context.clone(),
 912                        project_state.context_server_registry.clone(),
 913                        this.templates.clone(),
 914                        cx,
 915                    );
 916                    thread.set_summarization_model(summarization_model, cx);
 917                    thread
 918                }))
 919            })?
 920        })
 921    }
 922
 923    pub fn open_thread(
 924        &mut self,
 925        id: acp::SessionId,
 926        project: Entity<Project>,
 927        cx: &mut Context<Self>,
 928    ) -> Task<Result<Entity<AcpThread>>> {
 929        if let Some(session) = self.sessions.get(&id) {
 930            return Task::ready(Ok(session.acp_thread.clone()));
 931        }
 932
 933        let task = self.load_thread(id, project.clone(), cx);
 934        cx.spawn(async move |this, cx| {
 935            let thread = task.await?;
 936            let acp_thread = this.update(cx, |this, cx| {
 937                let project_id = this.get_or_create_project_state(&project, cx);
 938                this.register_session(thread.clone(), project_id, cx)
 939            })?;
 940            let events = thread.update(cx, |thread, cx| thread.replay(cx));
 941            cx.update(|cx| {
 942                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 943            })
 944            .await?;
 945            acp_thread.update(cx, |thread, cx| {
 946                thread.snapshot_completed_plan(cx);
 947            });
 948            Ok(acp_thread)
 949        })
 950    }
 951
 952    pub fn thread_summary(
 953        &mut self,
 954        id: acp::SessionId,
 955        project: Entity<Project>,
 956        cx: &mut Context<Self>,
 957    ) -> Task<Result<SharedString>> {
 958        let thread = self.open_thread(id.clone(), project, cx);
 959        cx.spawn(async move |this, cx| {
 960            let acp_thread = thread.await?;
 961            let result = this
 962                .update(cx, |this, cx| {
 963                    this.sessions
 964                        .get(&id)
 965                        .unwrap()
 966                        .thread
 967                        .update(cx, |thread, cx| thread.summary(cx))
 968                })?
 969                .await
 970                .context("Failed to generate summary")?;
 971            drop(acp_thread);
 972            Ok(result)
 973        })
 974    }
 975
 976    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 977        let id = thread.read(cx).id().clone();
 978        let Some(session) = self.sessions.get_mut(&id) else {
 979            return;
 980        };
 981
 982        let has_draft_prompt = session
 983            .acp_thread
 984            .read(cx)
 985            .draft_prompt()
 986            .is_some_and(|p| !p.is_empty());
 987        if thread.read(cx).is_empty() && !has_draft_prompt {
 988            return;
 989        }
 990
 991        let project_id = session.project_id;
 992        let Some(state) = self.projects.get(&project_id) else {
 993            return;
 994        };
 995
 996        let folder_paths = PathList::new(
 997            &state
 998                .project
 999                .read(cx)
1000                .visible_worktrees(cx)
1001                .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
1002                .collect::<Vec<_>>(),
1003        );
1004
1005        let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
1006        let database_future = ThreadsDatabase::connect(cx);
1007        let db_thread = thread.update(cx, |thread, cx| {
1008            thread.set_draft_prompt(draft_prompt);
1009            thread.to_db(cx)
1010        });
1011        let thread_store = self.thread_store.clone();
1012        session.pending_save = cx.spawn(async move |_, cx| {
1013            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
1014                return Ok(());
1015            };
1016            let db_thread = db_thread.await;
1017            database
1018                .save_thread(id, db_thread, folder_paths)
1019                .await
1020                .log_err();
1021            thread_store.update(cx, |store, cx| store.reload(cx));
1022            Ok(())
1023        });
1024    }
1025
    /// Runs an MCP prompt from a context server as a turn in session `session_id`.
    ///
    /// The steps are:
    /// 1. Fetch prompt `prompt_name` from server `server_id` with `arguments`.
    /// 2. Record the rest of the user's content on the native thread under `message_id`.
    ///    This is `original_content` minus its first block, which holds the slash command
    ///    that `Command::parse` recognised.
    /// 3. Replay each prompt message into both the `AcpThread` and the native `Thread`.
    /// 4. Start the model turn.
    ///
    /// # Errors
    /// Fails if the session has no project state, if the prompt cannot be fetched, if the
    /// session disappears before the messages are replayed, or if the turn fails.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: acp::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let Some(state) = self.session_project_state(&session_id) else {
            return Task::ready(Err(anyhow!("Project state not found for session")));
        };
        let server_store = state
            .context_server_registry
            .read(cx)
            .server_store()
            .clone();
        let path_style = state.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Decides below whether to `send_existing` or `resume`. Starting at `true` means a
            // prompt with no messages is sent as if it ended with a user message.
            let mut last_is_user = true;

            // Skip the first block: it is the slash-command text itself.
            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                // Each message is mirrored into the UI thread (`acp_thread`) and the native
                // thread that is sent to the model.
                match role {
                    context_server::types::Role::User => {
                        // Every user message from the prompt gets its own fresh id.
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        // NOTE(review): `handle_thread_events` passes `false` for text and `true`
                        // for thinking in this position, so `false` presumably marks this as
                        // non-thinking content; the trailing `true` looks like the indent flag.
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            // Forward the model's events into the UI thread until the turn stops.
            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
1128}
1129
/// Connection handle that exposes the in-process [`NativeAgent`] through the
/// [`acp_thread::AgentConnection`] trait.
///
/// It only wraps the agent's `Entity` handle, so cloning a connection yields another
/// handle to the same agent, not a copy of its state.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
1133
1134impl NativeAgentConnection {
1135    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
1136        self.0
1137            .read(cx)
1138            .sessions
1139            .get(session_id)
1140            .map(|session| session.thread.clone())
1141    }
1142
1143    pub fn load_thread(
1144        &self,
1145        id: acp::SessionId,
1146        project: Entity<Project>,
1147        cx: &mut App,
1148    ) -> Task<Result<Entity<Thread>>> {
1149        self.0
1150            .update(cx, |this, cx| this.load_thread(id, project, cx))
1151    }
1152
1153    fn run_turn(
1154        &self,
1155        session_id: acp::SessionId,
1156        cx: &mut App,
1157        f: impl 'static
1158        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1159    ) -> Task<Result<acp::PromptResponse>> {
1160        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1161            agent
1162                .sessions
1163                .get_mut(&session_id)
1164                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1165        }) else {
1166            return Task::ready(Err(anyhow!("Session not found")));
1167        };
1168        log::debug!("Found session for: {}", session_id);
1169
1170        let response_stream = match f(thread, cx) {
1171            Ok(stream) => stream,
1172            Err(err) => return Task::ready(Err(err)),
1173        };
1174        Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1175    }
1176
1177    fn handle_thread_events(
1178        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1179        acp_thread: WeakEntity<AcpThread>,
1180        cx: &App,
1181    ) -> Task<Result<acp::PromptResponse>> {
1182        cx.spawn(async move |cx| {
1183            // Handle response stream and forward to session.acp_thread
1184            while let Some(result) = events.next().await {
1185                match result {
1186                    Ok(event) => {
1187                        log::trace!("Received completion event: {:?}", event);
1188
1189                        match event {
1190                            ThreadEvent::UserMessage(message) => {
1191                                acp_thread.update(cx, |thread, cx| {
1192                                    for content in message.content {
1193                                        thread.push_user_content_block(
1194                                            Some(message.id.clone()),
1195                                            content.into(),
1196                                            cx,
1197                                        );
1198                                    }
1199                                })?;
1200                            }
1201                            ThreadEvent::AgentText(text) => {
1202                                acp_thread.update(cx, |thread, cx| {
1203                                    thread.push_assistant_content_block(text.into(), false, cx)
1204                                })?;
1205                            }
1206                            ThreadEvent::AgentThinking(text) => {
1207                                acp_thread.update(cx, |thread, cx| {
1208                                    thread.push_assistant_content_block(text.into(), true, cx)
1209                                })?;
1210                            }
1211                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1212                                tool_call,
1213                                options,
1214                                response,
1215                                context: _,
1216                            }) => {
1217                                let outcome_task = acp_thread.update(cx, |thread, cx| {
1218                                    thread.request_tool_call_authorization(tool_call, options, cx)
1219                                })??;
1220                                cx.background_spawn(async move {
1221                                    if let acp_thread::RequestPermissionOutcome::Selected(outcome) =
1222                                        outcome_task.await
1223                                    {
1224                                        response
1225                                            .send(outcome)
1226                                            .map(|_| anyhow!("authorization receiver was dropped"))
1227                                            .log_err();
1228                                    }
1229                                })
1230                                .detach();
1231                            }
1232                            ThreadEvent::ToolCall(tool_call) => {
1233                                acp_thread.update(cx, |thread, cx| {
1234                                    thread.upsert_tool_call(tool_call, cx)
1235                                })??;
1236                            }
1237                            ThreadEvent::ToolCallUpdate(update) => {
1238                                acp_thread.update(cx, |thread, cx| {
1239                                    thread.update_tool_call(update, cx)
1240                                })??;
1241                            }
1242                            ThreadEvent::Plan(plan) => {
1243                                acp_thread.update(cx, |thread, cx| thread.update_plan(plan, cx))?;
1244                            }
1245                            ThreadEvent::SubagentSpawned(session_id) => {
1246                                acp_thread.update(cx, |thread, cx| {
1247                                    thread.subagent_spawned(session_id, cx);
1248                                })?;
1249                            }
1250                            ThreadEvent::Retry(status) => {
1251                                acp_thread.update(cx, |thread, cx| {
1252                                    thread.update_retry_status(status, cx)
1253                                })?;
1254                            }
1255                            ThreadEvent::Stop(stop_reason) => {
1256                                log::debug!("Assistant message complete: {:?}", stop_reason);
1257                                return Ok(acp::PromptResponse::new(stop_reason));
1258                            }
1259                        }
1260                    }
1261                    Err(e) => {
1262                        log::error!("Error in model response stream: {:?}", e);
1263                        return Err(e);
1264                    }
1265                }
1266            }
1267
1268            log::debug!("Response stream completed");
1269            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1270        })
1271    }
1272}
1273
/// A slash command parsed from the first text block of a prompt (`/name arg` or
/// `/server.name arg`). All fields borrow from the prompt text.
struct Command<'a> {
    /// Prompt name, with any `server.` prefix removed.
    prompt_name: &'a str,
    /// Raw argument text following the command name (empty when there is none).
    arg_value: &'a str,
    /// Server id from a `server.prompt` form, if one was given.
    explicit_server_id: Option<&'a str>,
}
1279
1280impl<'a> Command<'a> {
1281    fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1282        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1283            return None;
1284        };
1285        let text = text_content.text.trim();
1286        let command = text.strip_prefix('/')?;
1287        let (command, arg_value) = command
1288            .split_once(char::is_whitespace)
1289            .unwrap_or((command, ""));
1290
1291        if let Some((server_id, prompt_name)) = command.split_once('.') {
1292            Some(Self {
1293                prompt_name,
1294                arg_value,
1295                explicit_server_id: Some(server_id),
1296            })
1297        } else {
1298            Some(Self {
1299                prompt_name: command,
1300                arg_value,
1301                explicit_server_id: None,
1302            })
1303        }
1304    }
1305}
1306
/// [`acp_thread::AgentModelSelector`] scoped to a single native-agent session.
struct NativeAgentModelSelector {
    /// Session whose thread's model is read and changed.
    session_id: acp::SessionId,
    /// Connection used to reach the agent's sessions, model list, and filesystem.
    connection: NativeAgentConnection,
}
1311
1312impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1313    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1314        log::debug!("NativeAgentConnection::list_models called");
1315        let list = self.connection.0.read(cx).models.model_list.clone();
1316        Task::ready(if list.is_empty() {
1317            Err(anyhow::anyhow!("No models available"))
1318        } else {
1319            Ok(list)
1320        })
1321    }
1322
1323    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1324        log::debug!(
1325            "Setting model for session {}: {}",
1326            self.session_id,
1327            model_id
1328        );
1329        let Some(thread) = self
1330            .connection
1331            .0
1332            .read(cx)
1333            .sessions
1334            .get(&self.session_id)
1335            .map(|session| session.thread.clone())
1336        else {
1337            return Task::ready(Err(anyhow!("Session not found")));
1338        };
1339
1340        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1341            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1342        };
1343
1344        // We want to reset the effort level when switching models, as the currently-selected effort level may
1345        // not be compatible.
1346        let effort = model
1347            .default_effort_level()
1348            .map(|effort_level| effort_level.value.to_string());
1349
1350        thread.update(cx, |thread, cx| {
1351            thread.set_model(model.clone(), cx);
1352            thread.set_thinking_effort(effort.clone(), cx);
1353            thread.set_thinking_enabled(model.supports_thinking(), cx);
1354        });
1355
1356        update_settings_file(
1357            self.connection.0.read(cx).fs.clone(),
1358            cx,
1359            move |settings, cx| {
1360                let provider = model.provider_id().0.to_string();
1361                let model = model.id().0.to_string();
1362                let enable_thinking = thread.read(cx).thinking_enabled();
1363                settings
1364                    .agent
1365                    .get_or_insert_default()
1366                    .set_model(LanguageModelSelection {
1367                        provider: provider.into(),
1368                        model,
1369                        enable_thinking,
1370                        effort,
1371                    });
1372            },
1373        );
1374
1375        Task::ready(Ok(()))
1376    }
1377
1378    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1379        let Some(thread) = self
1380            .connection
1381            .0
1382            .read(cx)
1383            .sessions
1384            .get(&self.session_id)
1385            .map(|session| session.thread.clone())
1386        else {
1387            return Task::ready(Err(anyhow!("Session not found")));
1388        };
1389        let Some(model) = thread.read(cx).model() else {
1390            return Task::ready(Err(anyhow!("Model not found")));
1391        };
1392        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1393        else {
1394            return Task::ready(Err(anyhow!("Provider not found")));
1395        };
1396        Task::ready(Ok(LanguageModels::map_language_model_to_info(
1397            model, &provider,
1398        )))
1399    }
1400
1401    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1402        Some(self.connection.0.read(cx).models.watch())
1403    }
1404
1405    fn should_render_footer(&self) -> bool {
1406        true
1407    }
1408}
1409
/// Agent identifier that [`NativeAgentConnection`] reports from `agent_id`.
pub static ZED_AGENT_ID: LazyLock<AgentId> = LazyLock::new(|| AgentId::new("Zed Agent"));
1411
1412impl acp_thread::AgentConnection for NativeAgentConnection {
1413    fn agent_id(&self) -> AgentId {
1414        ZED_AGENT_ID.clone()
1415    }
1416
1417    fn telemetry_id(&self) -> SharedString {
1418        "zed".into()
1419    }
1420
1421    fn new_session(
1422        self: Rc<Self>,
1423        project: Entity<Project>,
1424        work_dirs: PathList,
1425        cx: &mut App,
1426    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1427        log::debug!("Creating new thread for project at: {work_dirs:?}");
1428        Task::ready(Ok(self
1429            .0
1430            .update(cx, |agent, cx| agent.new_session(project, cx))))
1431    }
1432
1433    fn supports_load_session(&self) -> bool {
1434        true
1435    }
1436
1437    fn load_session(
1438        self: Rc<Self>,
1439        session_id: acp::SessionId,
1440        project: Entity<Project>,
1441        _work_dirs: PathList,
1442        _title: Option<SharedString>,
1443        cx: &mut App,
1444    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1445        self.0
1446            .update(cx, |agent, cx| agent.open_thread(session_id, project, cx))
1447    }
1448
1449    fn supports_close_session(&self) -> bool {
1450        true
1451    }
1452
1453    fn close_session(
1454        self: Rc<Self>,
1455        session_id: &acp::SessionId,
1456        cx: &mut App,
1457    ) -> Task<Result<()>> {
1458        self.0.update(cx, |agent, cx| {
1459            let thread = agent.sessions.get(session_id).map(|s| s.thread.clone());
1460            if let Some(thread) = thread {
1461                agent.save_thread(thread, cx);
1462            }
1463
1464            let Some(session) = agent.sessions.remove(session_id) else {
1465                return Task::ready(Ok(()));
1466            };
1467            let project_id = session.project_id;
1468
1469            let has_remaining = agent.sessions.values().any(|s| s.project_id == project_id);
1470            if !has_remaining {
1471                agent.projects.remove(&project_id);
1472            }
1473
1474            session.pending_save
1475        })
1476    }
1477
1478    fn auth_methods(&self) -> &[acp::AuthMethod] {
1479        &[] // No auth for in-process
1480    }
1481
1482    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1483        Task::ready(Ok(()))
1484    }
1485
1486    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1487        Some(Rc::new(NativeAgentModelSelector {
1488            session_id: session_id.clone(),
1489            connection: self.clone(),
1490        }) as Rc<dyn AgentModelSelector>)
1491    }
1492
1493    fn prompt(
1494        &self,
1495        id: Option<acp_thread::UserMessageId>,
1496        params: acp::PromptRequest,
1497        cx: &mut App,
1498    ) -> Task<Result<acp::PromptResponse>> {
1499        let id = id.expect("UserMessageId is required");
1500        let session_id = params.session_id.clone();
1501        log::info!("Received prompt request for session: {}", session_id);
1502        log::debug!("Prompt blocks count: {}", params.prompt.len());
1503
1504        let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
1505            return Task::ready(Err(anyhow::anyhow!("Session not found")));
1506        };
1507
1508        if let Some(parsed_command) = Command::parse(&params.prompt) {
1509            let registry = project_state.context_server_registry.read(cx);
1510
1511            let explicit_server_id = parsed_command
1512                .explicit_server_id
1513                .map(|server_id| ContextServerId(server_id.into()));
1514
1515            if let Some(prompt) =
1516                registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1517            {
1518                let arguments = if !parsed_command.arg_value.is_empty()
1519                    && let Some(arg_name) = prompt
1520                        .prompt
1521                        .arguments
1522                        .as_ref()
1523                        .and_then(|args| args.first())
1524                        .map(|arg| arg.name.clone())
1525                {
1526                    HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1527                } else {
1528                    Default::default()
1529                };
1530
1531                let prompt_name = prompt.prompt.name.clone();
1532                let server_id = prompt.server_id.clone();
1533
1534                return self.0.update(cx, |agent, cx| {
1535                    agent.send_mcp_prompt(
1536                        id,
1537                        session_id.clone(),
1538                        prompt_name,
1539                        server_id,
1540                        arguments,
1541                        params.prompt,
1542                        cx,
1543                    )
1544                });
1545            }
1546        };
1547
1548        let path_style = project_state.project.read(cx).path_style(cx);
1549
1550        self.run_turn(session_id, cx, move |thread, cx| {
1551            let content: Vec<UserMessageContent> = params
1552                .prompt
1553                .into_iter()
1554                .map(|block| UserMessageContent::from_content_block(block, path_style))
1555                .collect::<Vec<_>>();
1556            log::debug!("Converted prompt to message: {} chars", content.len());
1557            log::debug!("Message id: {:?}", id);
1558            log::debug!("Message content: {:?}", content);
1559
1560            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1561        })
1562    }
1563
1564    fn retry(
1565        &self,
1566        session_id: &acp::SessionId,
1567        _cx: &App,
1568    ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1569        Some(Rc::new(NativeAgentSessionRetry {
1570            connection: self.clone(),
1571            session_id: session_id.clone(),
1572        }) as _)
1573    }
1574
1575    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1576        log::info!("Cancelling on session: {}", session_id);
1577        self.0.update(cx, |agent, cx| {
1578            if let Some(session) = agent.sessions.get(session_id) {
1579                session
1580                    .thread
1581                    .update(cx, |thread, cx| thread.cancel(cx))
1582                    .detach();
1583            }
1584        });
1585    }
1586
1587    fn truncate(
1588        &self,
1589        session_id: &acp::SessionId,
1590        cx: &App,
1591    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1592        self.0.read_with(cx, |agent, _cx| {
1593            agent.sessions.get(session_id).map(|session| {
1594                Rc::new(NativeAgentSessionTruncate {
1595                    thread: session.thread.clone(),
1596                    acp_thread: session.acp_thread.downgrade(),
1597                }) as _
1598            })
1599        })
1600    }
1601
1602    fn set_title(
1603        &self,
1604        session_id: &acp::SessionId,
1605        cx: &App,
1606    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1607        self.0.read_with(cx, |agent, _cx| {
1608            agent
1609                .sessions
1610                .get(session_id)
1611                .filter(|s| !s.thread.read(cx).is_subagent())
1612                .map(|session| {
1613                    Rc::new(NativeAgentSessionSetTitle {
1614                        thread: session.thread.clone(),
1615                    }) as _
1616                })
1617        })
1618    }
1619
1620    fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1621        let thread_store = self.0.read(cx).thread_store.clone();
1622        Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1623    }
1624
1625    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1626        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1627    }
1628
1629    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1630        self
1631    }
1632}
1633
1634impl acp_thread::AgentTelemetry for NativeAgentConnection {
1635    fn thread_data(
1636        &self,
1637        session_id: &acp::SessionId,
1638        cx: &mut App,
1639    ) -> Task<Result<serde_json::Value>> {
1640        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1641            return Task::ready(Err(anyhow!("Session not found")));
1642        };
1643
1644        let task = session.thread.read(cx).to_db(cx);
1645        cx.background_spawn(async move {
1646            serde_json::to_value(task.await).context("Failed to serialize thread")
1647        })
1648    }
1649}
1650
/// An [`AgentSessionList`] backed by a [`ThreadStore`].
///
/// Notifications from the observed store, as well as explicit `notify_refresh` calls,
/// are forwarded as `SessionListUpdate::Refresh` on an unbounded channel.
pub struct NativeAgentSessionList {
    thread_store: Entity<ThreadStore>,
    // Sending half; also used by `notify_refresh` to push updates on demand.
    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
    // Receiving half; `watch` hands out clones of it.
    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
    // Keeps the `observe` callback on `thread_store` registered for the list's lifetime.
    _subscription: Subscription,
}
1657
1658impl NativeAgentSessionList {
1659    fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1660        let (tx, rx) = smol::channel::unbounded();
1661        let this_tx = tx.clone();
1662        let subscription = cx.observe(&thread_store, move |_, _| {
1663            this_tx
1664                .try_send(acp_thread::SessionListUpdate::Refresh)
1665                .ok();
1666        });
1667        Self {
1668            thread_store,
1669            updates_tx: tx,
1670            updates_rx: rx,
1671            _subscription: subscription,
1672        }
1673    }
1674
1675    pub fn thread_store(&self) -> &Entity<ThreadStore> {
1676        &self.thread_store
1677    }
1678}
1679
1680impl AgentSessionList for NativeAgentSessionList {
1681    fn list_sessions(
1682        &self,
1683        _request: AgentSessionListRequest,
1684        cx: &mut App,
1685    ) -> Task<Result<AgentSessionListResponse>> {
1686        let sessions = self
1687            .thread_store
1688            .read(cx)
1689            .entries()
1690            .map(|entry| AgentSessionInfo::from(&entry))
1691            .collect();
1692        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1693    }
1694
1695    fn supports_delete(&self) -> bool {
1696        true
1697    }
1698
1699    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1700        self.thread_store
1701            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1702    }
1703
1704    fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1705        self.thread_store
1706            .update(cx, |store, cx| store.delete_threads(cx))
1707    }
1708
1709    fn watch(
1710        &self,
1711        _cx: &mut App,
1712    ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1713        Some(self.updates_rx.clone())
1714    }
1715
1716    fn notify_refresh(&self) {
1717        self.updates_tx
1718            .try_send(acp_thread::SessionListUpdate::Refresh)
1719            .ok();
1720    }
1721
1722    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1723        self
1724    }
1725}
1726
/// Handle returned by `truncate`: truncates a native thread at a given user message and
/// reports the resulting token usage to the ACP thread.
struct NativeAgentSessionTruncate {
    // The native thread that gets truncated.
    thread: Entity<Thread>,
    // Held weakly; receives the post-truncation token usage if it is still alive.
    acp_thread: WeakEntity<AcpThread>,
}
1731
1732impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1733    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1734        match self.thread.update(cx, |thread, cx| {
1735            thread.truncate(message_id.clone(), cx)?;
1736            Ok(thread.latest_token_usage())
1737        }) {
1738            Ok(usage) => {
1739                self.acp_thread
1740                    .update(cx, |thread, cx| {
1741                        thread.update_token_usage(usage, cx);
1742                    })
1743                    .ok();
1744                Task::ready(Ok(()))
1745            }
1746            Err(error) => Task::ready(Err(error)),
1747        }
1748    }
1749}
1750
/// Handle returned by `retry`: runs a turn on `session_id` that resumes its thread.
struct NativeAgentSessionRetry {
    // Connection used to drive the turn via `run_turn`.
    connection: NativeAgentConnection,
    session_id: acp::SessionId,
}
1755
1756impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1757    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1758        self.connection
1759            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1760                thread.update(cx, |thread, cx| thread.resume(cx))
1761            })
1762    }
1763}
1764
/// Handle returned by `set_title`: renames a non-subagent native thread.
struct NativeAgentSessionSetTitle {
    thread: Entity<Thread>,
}
1768
1769impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1770    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1771        self.thread
1772            .update(cx, |thread, cx| thread.set_title(title, cx));
1773        Task::ready(Ok(()))
1774    }
1775}
1776
/// [`ThreadEnvironment`] implementation for native threads: creates terminals and
/// creates or resumes subagents on behalf of `thread`.
///
/// All references are weak, so the environment does not keep the agent or threads alive.
pub struct NativeThreadEnvironment {
    agent: WeakEntity<NativeAgent>,
    // The thread this environment serves; parent of any subagents it creates.
    thread: WeakEntity<Thread>,
    // ACP-side thread through which terminals are created and released.
    acp_thread: WeakEntity<AcpThread>,
}
1782
1783impl NativeThreadEnvironment {
1784    pub(crate) fn create_subagent_thread(
1785        &self,
1786        label: String,
1787        cx: &mut App,
1788    ) -> Result<Rc<dyn SubagentHandle>> {
1789        let Some(parent_thread_entity) = self.thread.upgrade() else {
1790            anyhow::bail!("Parent thread no longer exists".to_string());
1791        };
1792        let parent_thread = parent_thread_entity.read(cx);
1793        let current_depth = parent_thread.depth();
1794        let parent_session_id = parent_thread.id().clone();
1795
1796        if current_depth >= MAX_SUBAGENT_DEPTH {
1797            return Err(anyhow!(
1798                "Maximum subagent depth ({}) reached",
1799                MAX_SUBAGENT_DEPTH
1800            ));
1801        }
1802
1803        let subagent_thread: Entity<Thread> = cx.new(|cx| {
1804            let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1805            thread.set_title(label.into(), cx);
1806            thread
1807        });
1808
1809        let session_id = subagent_thread.read(cx).id().clone();
1810
1811        let acp_thread = self
1812            .agent
1813            .update(cx, |agent, cx| -> Result<Entity<AcpThread>> {
1814                let project_id = agent
1815                    .sessions
1816                    .get(&parent_session_id)
1817                    .map(|s| s.project_id)
1818                    .context("parent session not found")?;
1819                Ok(agent.register_session(subagent_thread.clone(), project_id, cx))
1820            })??;
1821
1822        let depth = current_depth + 1;
1823
1824        telemetry::event!(
1825            "Subagent Started",
1826            session = parent_thread_entity.read(cx).id().to_string(),
1827            subagent_session = session_id.to_string(),
1828            depth,
1829            is_resumed = false,
1830        );
1831
1832        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1833    }
1834
1835    pub(crate) fn resume_subagent_thread(
1836        &self,
1837        session_id: acp::SessionId,
1838        cx: &mut App,
1839    ) -> Result<Rc<dyn SubagentHandle>> {
1840        let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1841            let session = agent
1842                .sessions
1843                .get(&session_id)
1844                .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1845            anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1846        })??;
1847
1848        let depth = subagent_thread.read(cx).depth();
1849
1850        if let Some(parent_thread_entity) = self.thread.upgrade() {
1851            telemetry::event!(
1852                "Subagent Started",
1853                session = parent_thread_entity.read(cx).id().to_string(),
1854                subagent_session = session_id.to_string(),
1855                depth,
1856                is_resumed = true,
1857            );
1858        }
1859
1860        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1861    }
1862
1863    fn prompt_subagent(
1864        &self,
1865        session_id: acp::SessionId,
1866        subagent_thread: Entity<Thread>,
1867        acp_thread: Entity<acp_thread::AcpThread>,
1868    ) -> Result<Rc<dyn SubagentHandle>> {
1869        let Some(parent_thread_entity) = self.thread.upgrade() else {
1870            anyhow::bail!("Parent thread no longer exists".to_string());
1871        };
1872        Ok(Rc::new(NativeSubagentHandle::new(
1873            session_id,
1874            subagent_thread,
1875            acp_thread,
1876            parent_thread_entity,
1877        )) as _)
1878    }
1879}
1880
1881impl ThreadEnvironment for NativeThreadEnvironment {
1882    fn create_terminal(
1883        &self,
1884        command: String,
1885        cwd: Option<PathBuf>,
1886        output_byte_limit: Option<u64>,
1887        cx: &mut AsyncApp,
1888    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1889        let task = self.acp_thread.update(cx, |thread, cx| {
1890            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1891        });
1892
1893        let acp_thread = self.acp_thread.clone();
1894        cx.spawn(async move |cx| {
1895            let terminal = task?.await?;
1896
1897            let (drop_tx, drop_rx) = oneshot::channel();
1898            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1899
1900            cx.spawn(async move |cx| {
1901                drop_rx.await.ok();
1902                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1903            })
1904            .detach();
1905
1906            let handle = AcpTerminalHandle {
1907                terminal,
1908                _drop_tx: Some(drop_tx),
1909            };
1910
1911            Ok(Rc::new(handle) as _)
1912        })
1913    }
1914
1915    fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1916        self.create_subagent_thread(label, cx)
1917    }
1918
1919    fn resume_subagent(
1920        &self,
1921        session_id: acp::SessionId,
1922        cx: &mut App,
1923    ) -> Result<Rc<dyn SubagentHandle>> {
1924        self.resume_subagent_thread(session_id, cx)
1925    }
1926}
1927
/// Outcome of waiting on a single subagent prompt in `NativeSubagentHandle::send`.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The turn ended without error; the reply is read from the thread's last message.
    Completed,
    /// The turn stopped with `StopReason::Cancelled`.
    Cancelled,
    /// Token usage went from `Normal` to `Warning` during the turn, so it was cut short.
    ContextWindowWarning,
    /// The turn failed; carries the message that is returned as the error.
    Error(String),
}
1935
/// A [`SubagentHandle`] for a subagent backed by a native thread and its ACP thread.
pub struct NativeSubagentHandle {
    session_id: acp::SessionId,
    // Weak so the handle does not keep the parent thread alive.
    parent_thread: WeakEntity<Thread>,
    subagent_thread: Entity<Thread>,
    acp_thread: Entity<acp_thread::AcpThread>,
}
1942
1943impl NativeSubagentHandle {
1944    fn new(
1945        session_id: acp::SessionId,
1946        subagent_thread: Entity<Thread>,
1947        acp_thread: Entity<acp_thread::AcpThread>,
1948        parent_thread_entity: Entity<Thread>,
1949    ) -> Self {
1950        NativeSubagentHandle {
1951            session_id,
1952            subagent_thread,
1953            parent_thread: parent_thread_entity.downgrade(),
1954            acp_thread,
1955        }
1956    }
1957}
1958
1959impl SubagentHandle for NativeSubagentHandle {
1960    fn id(&self) -> acp::SessionId {
1961        self.session_id.clone()
1962    }
1963
1964    fn num_entries(&self, cx: &App) -> usize {
1965        self.acp_thread.read(cx).entries().len()
1966    }
1967
1968    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
1969        let thread = self.subagent_thread.clone();
1970        let acp_thread = self.acp_thread.clone();
1971        let subagent_session_id = self.session_id.clone();
1972        let parent_thread = self.parent_thread.clone();
1973
1974        cx.spawn(async move |cx| {
1975            let (task, _subscription) = cx.update(|cx| {
1976                let ratio_before_prompt = thread
1977                    .read(cx)
1978                    .latest_token_usage()
1979                    .map(|usage| usage.ratio());
1980
1981                parent_thread
1982                    .update(cx, |parent_thread, _cx| {
1983                        parent_thread.register_running_subagent(thread.downgrade())
1984                    })
1985                    .ok();
1986
1987                let task = acp_thread.update(cx, |acp_thread, cx| {
1988                    acp_thread.send(vec![message.into()], cx)
1989                });
1990
1991                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
1992                let mut token_limit_tx = Some(token_limit_tx);
1993
1994                let subscription = cx.subscribe(
1995                    &thread,
1996                    move |_thread, event: &TokenUsageUpdated, _cx| {
1997                        if let Some(usage) = &event.0 {
1998                            let old_ratio = ratio_before_prompt
1999                                .clone()
2000                                .unwrap_or(TokenUsageRatio::Normal);
2001                            let new_ratio = usage.ratio();
2002                            if old_ratio == TokenUsageRatio::Normal
2003                                && new_ratio == TokenUsageRatio::Warning
2004                            {
2005                                if let Some(tx) = token_limit_tx.take() {
2006                                    tx.send(()).ok();
2007                                }
2008                            }
2009                        }
2010                    },
2011                );
2012
2013                let wait_for_prompt = cx
2014                    .background_spawn(async move {
2015                        futures::select! {
2016                            response = task.fuse() => match response {
2017                                Ok(Some(response)) => {
2018                                    match response.stop_reason {
2019                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
2020                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
2021                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
2022                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
2023                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
2024                                    }
2025                                }
2026                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
2027                                Err(error) => SubagentPromptResult::Error(error.to_string()),
2028                            },
2029                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
2030                        }
2031                    });
2032
2033                (wait_for_prompt, subscription)
2034            });
2035
2036            let result = match task.await {
2037                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
2038                    thread
2039                        .last_message()
2040                        .and_then(|message| {
2041                            let content = message.as_agent_message()?
2042                                .content
2043                                .iter()
2044                                .filter_map(|c| match c {
2045                                    AgentMessageContent::Text(text) => Some(text.as_str()),
2046                                    _ => None,
2047                                })
2048                                .join("\n\n");
2049                            if content.is_empty() {
2050                                None
2051                            } else {
2052                                Some( content)
2053                            }
2054                        })
2055                        .context("No response from subagent")
2056                }),
2057                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
2058                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
2059                SubagentPromptResult::ContextWindowWarning => {
2060                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2061                    Err(anyhow!(
2062                        "The agent is nearing the end of its context window and has been \
2063                         stopped. You can prompt the thread again to have the agent wrap up \
2064                         or hand off its work."
2065                    ))
2066                }
2067            };
2068
2069            parent_thread
2070                .update(cx, |parent_thread, cx| {
2071                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
2072                })
2073                .ok();
2074
2075            result
2076        })
2077    }
2078}
2079
/// A [`TerminalHandle`] backed by an ACP terminal entity.
pub struct AcpTerminalHandle {
    terminal: Entity<acp_thread::Terminal>,
    // Never read. Dropping it (together with the handle) wakes the task spawned in
    // `create_terminal`, which then releases the terminal on the ACP thread.
    _drop_tx: Option<oneshot::Sender<()>>,
}
2084
2085impl TerminalHandle for AcpTerminalHandle {
2086    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
2087        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
2088    }
2089
2090    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
2091        Ok(self
2092            .terminal
2093            .read_with(cx, |term, _cx| term.wait_for_exit()))
2094    }
2095
2096    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
2097        Ok(self
2098            .terminal
2099            .read_with(cx, |term, cx| term.current_output(cx)))
2100    }
2101
2102    fn kill(&self, cx: &AsyncApp) -> Result<()> {
2103        cx.update(|cx| {
2104            self.terminal.update(cx, |terminal, cx| {
2105                terminal.kill(cx);
2106            });
2107        });
2108        Ok(())
2109    }
2110
2111    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
2112        Ok(self
2113            .terminal
2114            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
2115    }
2116}
2117
2118#[cfg(test)]
2119mod internal_tests {
2120    use std::path::Path;
2121
2122    use super::*;
2123    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
2124    use fs::FakeFs;
2125    use gpui::TestAppContext;
2126    use indoc::formatdoc;
2127    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2128    use language_model::{
2129        LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
2130    };
2131    use serde_json::json;
2132    use settings::SettingsStore;
2133    use util::{path, rel_path::rel_path};
2134
    /// Checks that the agent's per-project context, and the context seen by a session's
    /// thread, track worktree additions and the creation of a `.rules` file.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));

        // Creating a session registers the project and triggers context building.
        let connection = NativeAgentConnection(agent.clone());
        let _acp_thread = cx
            .update(|cx| {
                Rc::new(connection).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/")]),
                    cx,
                )
            })
            .await
            .unwrap();
        cx.run_until_parked();

        // Only one session exists, so take its thread.
        let thread = agent.read_with(cx, |agent, _cx| {
            agent.sessions.values().next().unwrap().thread.clone()
        });

        // The project starts without worktrees, so both contexts are empty.
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            assert_eq!(state.project_context.read(cx).worktrees, vec![]);
            assert_eq!(thread.read(cx).project_context().read(cx).worktrees, vec![]);
        });

        // Adding a worktree shows up in both contexts, still without a rules file.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            let expected_worktrees = vec![WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file: None,
            }];
            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
            assert_eq!(
                thread.read(cx).project_context().read(cx).worktrees,
                expected_worktrees
            );
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            let expected_worktrees = vec![WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file: Some(RulesFileContext {
                    path_in_worktree: rel_path(".rules").into(),
                    text: "".into(),
                    project_entry_id: rules_entry.id.to_usize(),
                }),
            }];
            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
            assert_eq!(
                thread.read(cx).project_context().read(cx).worktrees,
                expected_worktrees
            );
        });
    }
2222
    /// Lists the models available to a new session and checks that the grouped result
    /// contains exactly the single "Fake" model (presumably registered by `init_test`).
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let connection =
            NativeAgentConnection(cx.update(|cx| {
                NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)
            }));

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Ask the session's model selector for the available models.
        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(acp_thread::AgentModelIcon::Named(
                        ui::IconName::ZedAssistant
                    )),
                    is_latest: false,
                    cost: None,
                }]
            )])
        );
    }
2279
    /// Selecting a model through a session's model selector applies it to the thread and
    /// writes it as `agent.default_model` in the settings file. Selecting a thinking model
    /// additionally persists `enable_thinking: true`.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Start from a settings file whose default model points elsewhere ("foo"/"bar").
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));

        // Create the agent and connection
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let the settings-file update settle before reading it back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );

        // Register a thinking model and select it.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // The agent's model list must be refreshed to see the new provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // Verify enable_thinking was written to settings as true.
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
        assert_eq!(
            settings_json["agent"]["default_model"]["enable_thinking"],
            json!(true),
            "selecting a thinking model should persist enable_thinking: true to settings"
        );
    }
2391
2392    #[gpui::test]
2393    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2394        init_test(cx);
2395        let fs = FakeFs::new(cx.executor());
2396        fs.create_dir(paths::settings_file().parent().unwrap())
2397            .await
2398            .unwrap();
2399        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2400        let project = Project::test(fs.clone(), [], cx).await;
2401
2402        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2403        let agent =
2404            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2405        let connection = NativeAgentConnection(agent.clone());
2406
2407        let acp_thread = cx
2408            .update(|cx| {
2409                Rc::new(connection.clone()).new_session(
2410                    project.clone(),
2411                    PathList::new(&[Path::new("/a")]),
2412                    cx,
2413                )
2414            })
2415            .await
2416            .unwrap();
2417        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2418
2419        // Register a second provider with a thinking model.
2420        cx.update(|cx| {
2421            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2422                "fake-corp",
2423                "fake-thinking",
2424                "Fake Thinking",
2425                true,
2426            ));
2427            let thinking_provider = Arc::new(
2428                FakeLanguageModelProvider::new(
2429                    LanguageModelProviderId::from("fake-corp".to_string()),
2430                    LanguageModelProviderName::from("Fake Corp".to_string()),
2431                )
2432                .with_models(vec![thinking_model]),
2433            );
2434            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2435                registry.register_provider(thinking_provider, cx);
2436            });
2437        });
2438        // Refresh the agent's model list so it picks up the new provider.
2439        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2440
2441        // Thread starts with thinking_enabled = false (the default).
2442        agent.read_with(cx, |agent, _| {
2443            let session = agent.sessions.get(&session_id).unwrap();
2444            session.thread.read_with(cx, |thread, _| {
2445                assert!(!thread.thinking_enabled(), "thinking defaults to false");
2446            });
2447        });
2448
2449        // Select the thinking model via select_model.
2450        let selector = connection.model_selector(&session_id).unwrap();
2451        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2452            .await
2453            .unwrap();
2454
2455        // select_model should have enabled thinking based on the model's supports_thinking().
2456        agent.read_with(cx, |agent, _| {
2457            let session = agent.sessions.get(&session_id).unwrap();
2458            session.thread.read_with(cx, |thread, _| {
2459                assert!(
2460                    thread.thinking_enabled(),
2461                    "select_model should enable thinking when model supports it"
2462                );
2463            });
2464        });
2465
2466        // Switch back to the non-thinking model.
2467        let selector = connection.model_selector(&session_id).unwrap();
2468        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2469            .await
2470            .unwrap();
2471
2472        // select_model should have disabled thinking.
2473        agent.read_with(cx, |agent, _| {
2474            let session = agent.sessions.get(&session_id).unwrap();
2475            session.thread.read_with(cx, |thread, _| {
2476                assert!(
2477                    !thread.thinking_enabled(),
2478                    "select_model should disable thinking when model does not support it"
2479                );
2480            });
2481        });
2482    }
2483
2484    #[gpui::test]
2485    async fn test_summarization_model_survives_transient_registry_clearing(
2486        cx: &mut TestAppContext,
2487    ) {
2488        init_test(cx);
2489        let fs = FakeFs::new(cx.executor());
2490        fs.insert_tree("/", json!({ "a": {} })).await;
2491        let project = Project::test(fs.clone(), [], cx).await;
2492
2493        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2494        let agent =
2495            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2496        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2497
2498        let acp_thread = cx
2499            .update(|cx| {
2500                connection.clone().new_session(
2501                    project.clone(),
2502                    PathList::new(&[Path::new("/a")]),
2503                    cx,
2504                )
2505            })
2506            .await
2507            .unwrap();
2508        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2509
2510        let thread = agent.read_with(cx, |agent, _| {
2511            agent.sessions.get(&session_id).unwrap().thread.clone()
2512        });
2513
2514        thread.read_with(cx, |thread, _| {
2515            assert!(
2516                thread.summarization_model().is_some(),
2517                "session should have a summarization model from the test registry"
2518            );
2519        });
2520
2521        // Simulate what happens during a provider blip:
2522        // update_active_language_model_from_settings calls set_default_model(None)
2523        // when it can't resolve the model, clearing all fallbacks.
2524        cx.update(|cx| {
2525            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2526                registry.set_default_model(None, cx);
2527            });
2528        });
2529        cx.run_until_parked();
2530
2531        thread.read_with(cx, |thread, _| {
2532            assert!(
2533                thread.summarization_model().is_some(),
2534                "summarization model should survive a transient default model clearing"
2535            );
2536        });
2537    }
2538
    /// Selecting a thinking-capable model enables thinking on the session's
    /// thread, and that flag must still be set after the thread is persisted,
    /// its session is closed, and the thread is reopened via `open_thread`.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model.
        // The Arc is cloned into the provider so the test can keep driving the
        // model's completion stream further down.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Rebuild the agent's model list so "fake-corp/fake-thinking" can be selected.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future is spawned so it keeps running while the test feeds
        // the fake model's response below.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles before checking that the agent no
        // longer tracks the session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        drop(reloaded_acp_thread);
    }
2641
    /// A thread whose selected model has an id that differs from its display
    /// name must come back with that same model (matched by id) after being
    /// persisted, closed, and reopened, instead of falling back to the default.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Rebuild the agent's model list so the new provider's model can be selected.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // The selector addresses the model as "<provider id>/<model id>".
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future is spawned so it keeps running while the test feeds
        // the fake model's response below.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles before checking that the agent no
        // longer tracks the session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        drop(reloaded_acp_thread);
    }
2749
    /// Round-trips a thread through the thread store. An empty thread must not
    /// be saved even after it is mutated. After one exchange, the following must
    /// all survive closing the session and reopening it with `open_thread`:
    /// the conversation markdown, the title streamed by the fake summarization
    /// model, the draft prompt, token usage, and the UI scroll position.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // The user message mixes plain text with a resource link to /a/b.md.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // Spawned so the send keeps running while the fake models are driven below.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Main model reply, plus a usage update whose totals are checked after reload.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
            language_model::TokenUsage {
                input_tokens: 150,
                output_tokens: 75,
                ..Default::default()
            },
        ));
        model.end_last_completion_stream();
        // NOTE(review): presumably this drain is what lets the summarization
        // request reach `summary_model` before it is answered below — confirm.
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Set a draft prompt with rich content blocks and scroll position
        // AFTER run_until_parked, so the only save that captures these
        // changes is the one performed by close_session itself.
        let draft_blocks = vec![
            acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
            acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
            acp::ContentBlock::Text(acp::TextContent::new(" please")),
        ];
        acp_thread.update(cx, |thread, _cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()));
        });
        thread.update(cx, |thread, _cx| {
            thread.set_ui_scroll_position(Some(gpui::ListOffset {
                item_ix: 5,
                offset_in_item: gpui::px(12.5),
            }));
        });

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles before checking that the agent no
        // longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        // The stored entry's title is the text streamed by `summary_model` above.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // Ensure the draft prompt with rich content blocks survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
        });

        // Ensure token usage survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let usage = thread
                .token_usage()
                .expect("token usage should be restored after reload");
            assert_eq!(usage.input_tokens, 150);
            assert_eq!(usage.output_tokens, 75);
        });

        // Ensure scroll position survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let scroll = thread
                .ui_scroll_position()
                .expect("scroll position should be restored after reload");
            assert_eq!(scroll.item_ix, 5);
            assert_eq!(scroll.offset_in_item, gpui::px(12.5));
        });
    }
2931
    /// Regression test: `close_session` must itself persist the thread. A draft
    /// prompt set immediately before closing, with no `run_until_parked` in
    /// between, must be present after the thread is reopened.
    #[gpui::test]
    async fn test_close_session_saves_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "file.txt": "hello"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        let model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
        });

        // Send a message so the thread is non-empty (empty threads aren't saved).
        // Spawned so the send keeps running while the fake model is answered below.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("world");
        model.end_last_completion_stream();
        send.await.unwrap();
        cx.run_until_parked();

        // Set a draft prompt WITHOUT calling run_until_parked afterwards.
        // This means no observe-triggered save has run for this change.
        // The only way this data gets persisted is if close_session
        // itself performs the save.
        let draft_blocks = vec![acp::ContentBlock::Text(acp::TextContent::new(
            "unsaved draft",
        ))];
        acp_thread.update(cx, |thread, _cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()));
        });

        // Close the session immediately — no run_until_parked in between.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // NOTE(review): unlike test_save_load_thread, `thread` and `acp_thread`
        // are still alive at this point. Confirm that `open_thread` really
        // reloads from the store here rather than reusing live in-memory state.
        // Reopen and verify the draft prompt was saved.
        let reloaded = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        reloaded.read_with(cx, |thread, _| {
            assert_eq!(
                thread.draft_prompt(),
                Some(draft_blocks.as_slice()),
                "close_session must save the thread; draft prompt was lost"
            );
        });
    }
3012
3013    #[gpui::test]
3014    async fn test_rapid_title_changes_do_not_loop(cx: &mut TestAppContext) {
3015        // Regression test: rapid title changes must not cause a propagation loop
3016        // between Thread and AcpThread via handle_thread_title_updated.
3017        init_test(cx);
3018        let fs = FakeFs::new(cx.executor());
3019        fs.insert_tree("/", json!({ "a": {} })).await;
3020        let project = Project::test(fs.clone(), [], cx).await;
3021        let thread_store = cx.new(|cx| ThreadStore::new(cx));
3022        let agent = cx.update(|cx| {
3023            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
3024        });
3025        let connection = Rc::new(NativeAgentConnection(agent.clone()));
3026
3027        let acp_thread = cx
3028            .update(|cx| {
3029                connection
3030                    .clone()
3031                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
3032            })
3033            .await
3034            .unwrap();
3035
3036        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3037        let thread = agent.read_with(cx, |agent, _| {
3038            agent.sessions.get(&session_id).unwrap().thread.clone()
3039        });
3040
3041        let title_updated_count = Rc::new(std::cell::RefCell::new(0usize));
3042        cx.update(|cx| {
3043            let count = title_updated_count.clone();
3044            cx.subscribe(
3045                &thread,
3046                move |_entity: Entity<Thread>, _event: &TitleUpdated, _cx: &mut App| {
3047                    let new_count = {
3048                        let mut count = count.borrow_mut();
3049                        *count += 1;
3050                        *count
3051                    };
3052                    assert!(
3053                        new_count <= 2,
3054                        "TitleUpdated fired {new_count} times; \
3055                         title updates are looping"
3056                    );
3057                },
3058            )
3059            .detach();
3060        });
3061
3062        thread.update(cx, |thread, cx| thread.set_title("first".into(), cx));
3063        thread.update(cx, |thread, cx| thread.set_title("second".into(), cx));
3064
3065        cx.run_until_parked();
3066
3067        thread.read_with(cx, |thread, _| {
3068            assert_eq!(thread.title(), Some("second".into()));
3069        });
3070        acp_thread.read_with(cx, |acp_thread, _| {
3071            assert_eq!(acp_thread.title(), Some("second".into()));
3072        });
3073
3074        assert_eq!(*title_updated_count.borrow(), 2);
3075    }
3076
3077    fn thread_entries(
3078        thread_store: &Entity<ThreadStore>,
3079        cx: &mut TestAppContext,
3080    ) -> Vec<(acp::SessionId, String)> {
3081        thread_store.read_with(cx, |store, _| {
3082            store
3083                .entries()
3084                .map(|entry| (entry.id.clone(), entry.title.to_string()))
3085                .collect::<Vec<_>>()
3086        })
3087    }
3088
3089    fn init_test(cx: &mut TestAppContext) {
3090        env_logger::try_init().ok();
3091        cx.update(|cx| {
3092            let settings_store = SettingsStore::test(cx);
3093            cx.set_global(settings_store);
3094
3095            LanguageModelRegistry::test(cx);
3096        });
3097    }
3098}
3099
3100fn mcp_message_content_to_acp_content_block(
3101    content: context_server::types::MessageContent,
3102) -> acp::ContentBlock {
3103    match content {
3104        context_server::types::MessageContent::Text {
3105            text,
3106            annotations: _,
3107        } => text.into(),
3108        context_server::types::MessageContent::Image {
3109            data,
3110            mime_type,
3111            annotations: _,
3112        } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
3113        context_server::types::MessageContent::Audio {
3114            data,
3115            mime_type,
3116            annotations: _,
3117        } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
3118        context_server::types::MessageContent::Resource {
3119            resource,
3120            annotations: _,
3121        } => {
3122            let mut link =
3123                acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
3124            if let Some(mime_type) = resource.mime_type {
3125                link = link.mime_type(mime_type);
3126            }
3127            acp::ContentBlock::ResourceLink(link)
3128        }
3129    }
3130}