agent.rs

   1mod db;
   2mod edit_agent;
   3mod legacy_thread;
   4mod native_agent_server;
   5pub mod outline;
   6mod pattern_extraction;
   7mod skills;
   8mod templates;
   9#[cfg(test)]
  10mod tests;
  11mod thread;
  12mod thread_store;
  13mod tool_permissions;
  14mod tools;
  15
  16use context_server::ContextServerId;
  17pub use db::*;
  18use itertools::Itertools;
  19pub use native_agent_server::NativeAgentServer;
  20pub use pattern_extraction::*;
  21pub use shell_command_parser::extract_commands;
  22pub use skills::*;
  23pub use templates::*;
  24pub use thread::*;
  25pub use thread_store::*;
  26pub use tool_permissions::*;
  27pub use tools::*;
  28
  29use acp_thread::{
  30    AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
  31    AgentSessionListResponse, TokenUsageRatio, UserMessageId,
  32};
  33use agent_client_protocol as acp;
  34use anyhow::{Context as _, Result, anyhow};
  35use chrono::{DateTime, Utc};
  36use collections::{HashMap, HashSet, IndexMap};
  37use fs::Fs;
  38use futures::channel::{mpsc, oneshot};
  39use futures::future::Shared;
  40use futures::{FutureExt as _, StreamExt as _, future};
  41use gpui::{
  42    App, AppContext, AsyncApp, Context, Entity, EntityId, SharedString, Subscription, Task,
  43    WeakEntity,
  44};
  45use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  46use project::{AgentId, Project, ProjectItem, ProjectPath, Worktree};
  47use prompt_store::{
  48    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  49    WorktreeContext,
  50};
  51use serde::{Deserialize, Serialize};
  52use settings::{LanguageModelSelection, update_settings_file};
  53use std::any::Any;
  54use std::path::PathBuf;
  55use std::rc::Rc;
  56use std::sync::{Arc, LazyLock};
  57use util::ResultExt;
  58use util::path_list::PathList;
  59use util::rel_path::RelPath;
  60
/// A serializable snapshot of a project's worktrees, paired with a UTC timestamp.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One telemetry snapshot per worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// UTC timestamp associated with this snapshot.
    pub timestamp: DateTime<Utc>,
}
  66
  67pub struct RulesLoadingError {
  68    pub message: SharedString,
  69}
  70
/// Per-project state owned by [`NativeAgent`], keyed by the project's `EntityId`.
struct ProjectState {
    /// The project this state belongs to.
    project: Entity<Project>,
    /// The current project context (worktrees, rules files, default user rules) handed to
    /// threads created for this project.
    project_context: Entity<ProjectContext>,
    /// Sending on this channel asks `maintain_project_context` to rebuild `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context`; kept alive for as long as this state exists.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context servers (tools and prompts) for this project.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Subscriptions to project, context-server-store and context-server-registry events.
    _subscriptions: Vec<Subscription>,
}
  79
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Key into `NativeAgent::projects` for the project this session belongs to.
    project_id: EntityId,
    /// Task for the session's pending save. `register_session` initializes it to an
    /// already-completed task.
    pending_save: Task<Result<()>>,
    /// Subscriptions that keep the ACP thread and persisted state in sync with `thread`.
    _subscriptions: Vec<Subscription>,
}
  90
/// Cache of the language models exposed by authenticated providers, in two forms: a lookup by
/// `acp::ModelId` and a grouped list for display.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver end of the refresh notification channel; clones are handed out by `watch()`.
    refresh_models_rx: watch::Receiver<()>,
    /// Signalled at the end of every `refresh_list` call.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating all visible providers; held so it keeps running.
    _authenticate_all_providers_task: Task<()>,
}
 100
 101impl LanguageModels {
 102    fn new(cx: &mut App) -> Self {
 103        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
 104
 105        let mut this = Self {
 106            models: HashMap::default(),
 107            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
 108            refresh_models_rx,
 109            refresh_models_tx,
 110            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
 111        };
 112        this.refresh_list(cx);
 113        this
 114    }
 115
 116    fn refresh_list(&mut self, cx: &App) {
 117        let providers = LanguageModelRegistry::global(cx)
 118            .read(cx)
 119            .visible_providers()
 120            .into_iter()
 121            .filter(|provider| provider.is_authenticated(cx))
 122            .collect::<Vec<_>>();
 123
 124        let mut language_model_list = IndexMap::default();
 125        let mut recommended_models = HashSet::default();
 126
 127        let mut recommended = Vec::new();
 128        for provider in &providers {
 129            for model in provider.recommended_models(cx) {
 130                recommended_models.insert((model.provider_id(), model.id()));
 131                recommended.push(Self::map_language_model_to_info(&model, provider));
 132            }
 133        }
 134        if !recommended.is_empty() {
 135            language_model_list.insert(
 136                acp_thread::AgentModelGroupName("Recommended".into()),
 137                recommended,
 138            );
 139        }
 140
 141        let mut models = HashMap::default();
 142        for provider in providers {
 143            let mut provider_models = Vec::new();
 144            for model in provider.provided_models(cx) {
 145                let model_info = Self::map_language_model_to_info(&model, &provider);
 146                let model_id = model_info.id.clone();
 147                provider_models.push(model_info);
 148                models.insert(model_id, model);
 149            }
 150            if !provider_models.is_empty() {
 151                language_model_list.insert(
 152                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 153                    provider_models,
 154                );
 155            }
 156        }
 157
 158        self.models = models;
 159        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 160        self.refresh_models_tx.send(()).ok();
 161    }
 162
 163    fn watch(&self) -> watch::Receiver<()> {
 164        self.refresh_models_rx.clone()
 165    }
 166
 167    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 168        self.models.get(model_id).cloned()
 169    }
 170
 171    fn map_language_model_to_info(
 172        model: &Arc<dyn LanguageModel>,
 173        provider: &Arc<dyn LanguageModelProvider>,
 174    ) -> acp_thread::AgentModelInfo {
 175        acp_thread::AgentModelInfo {
 176            id: Self::model_id(model),
 177            name: model.name().0,
 178            description: None,
 179            icon: Some(match provider.icon() {
 180                IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
 181                IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
 182            }),
 183            is_latest: model.is_latest(),
 184            cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
 185        }
 186    }
 187
 188    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 189        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 190    }
 191
 192    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 193        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 194            .read(cx)
 195            .visible_providers()
 196            .iter()
 197            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 198            .collect::<Vec<_>>();
 199
 200        cx.background_spawn(async move {
 201            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 202                if let Err(err) = authenticate_task.await {
 203                    match err {
 204                        language_model::AuthenticateError::CredentialsNotFound => {
 205                            // Since we're authenticating these providers in the
 206                            // background for the purposes of populating the
 207                            // language selector, we don't care about providers
 208                            // where the credentials are not found.
 209                        }
 210                        language_model::AuthenticateError::ConnectionRefused => {
 211                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 212                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 213                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 214                        }
 215                        _ => {
 216                            // Some providers have noisy failure states that we
 217                            // don't want to spam the logs with every time the
 218                            // language model selector is initialized.
 219                            //
 220                            // Ideally these should have more clear failure modes
 221                            // that we know are safe to ignore here, like what we do
 222                            // with `CredentialsNotFound` above.
 223                            match provider_id.0.as_ref() {
 224                                "lmstudio" | "ollama" => {
 225                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 226                                    //
 227                                    // These fail noisily, so we don't log them.
 228                                }
 229                                "copilot_chat" => {
 230                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 231                                }
 232                                _ => {
 233                                    log::error!(
 234                                        "Failed to authenticate provider: {}: {err:#}",
 235                                        provider_name.0
 236                                    );
 237                                }
 238                            }
 239                        }
 240                    }
 241                }
 242            }
 243        })
 244    }
 245}
 246
/// The native agent: owns every active session, per-project state, and the cached model list.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Store handle; its use is not visible in this section.
    thread_store: Entity<ThreadStore>,
    /// Project-specific state keyed by project EntityId
    projects: HashMap<EntityId, ProjectState>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Optional prompt store supplying default user rules for project contexts.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle; its use is not visible in this section.
    fs: Arc<dyn Fs>,
    /// Subscriptions to the language-model registry and, when present, the prompt store.
    _subscriptions: Vec<Subscription>,
}
 261
 262impl NativeAgent {
 263    pub fn new(
 264        thread_store: Entity<ThreadStore>,
 265        templates: Arc<Templates>,
 266        prompt_store: Option<Entity<PromptStore>>,
 267        fs: Arc<dyn Fs>,
 268        cx: &mut App,
 269    ) -> Entity<NativeAgent> {
 270        log::debug!("Creating new NativeAgent");
 271
 272        cx.new(|cx| {
 273            let mut subscriptions = vec![cx.subscribe(
 274                &LanguageModelRegistry::global(cx),
 275                Self::handle_models_updated_event,
 276            )];
 277            if let Some(prompt_store) = prompt_store.as_ref() {
 278                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 279            }
 280
 281            Self {
 282                sessions: HashMap::default(),
 283                thread_store,
 284                projects: HashMap::default(),
 285                templates,
 286                models: LanguageModels::new(cx),
 287                prompt_store,
 288                fs,
 289                _subscriptions: subscriptions,
 290            }
 291        })
 292    }
 293
 294    fn new_session(
 295        &mut self,
 296        project: Entity<Project>,
 297        cx: &mut Context<Self>,
 298    ) -> Entity<AcpThread> {
 299        let project_id = self.get_or_create_project_state(&project, cx);
 300        let project_state = &self.projects[&project_id];
 301
 302        let registry = LanguageModelRegistry::read_global(cx);
 303        let available_count = registry.available_models(cx).count();
 304        log::debug!("Total available models: {}", available_count);
 305
 306        let default_model = registry.default_model().and_then(|default_model| {
 307            self.models
 308                .model_from_id(&LanguageModels::model_id(&default_model.model))
 309        });
 310        let thread = cx.new(|cx| {
 311            Thread::new(
 312                project,
 313                project_state.project_context.clone(),
 314                project_state.context_server_registry.clone(),
 315                self.templates.clone(),
 316                default_model,
 317                cx,
 318            )
 319        });
 320
 321        self.register_session(thread, project_id, cx)
 322    }
 323
    /// Wraps `thread_handle` in a new `AcpThread` and configures the thread's summarization
    /// model and default tools.
    ///
    /// It also subscribes to the thread's events, stores the pair as a session keyed by the
    /// thread's id, and refreshes the available commands for the session's project.
    ///
    /// Returns the newly created `AcpThread`.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        project_id: EntityId,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        // Copy out everything the AcpThread needs while we hold a read borrow of the thread.
        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let draft_prompt = thread.draft_prompt().map(Vec::from);
        let scroll_position = thread.ui_scroll_position();
        let token_usage = thread.latest_token_usage();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            let mut acp_thread = acp_thread::AcpThread::new(
                parent_session_id,
                title,
                None,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            );
            // Carry the thread's draft prompt, scroll position and token usage over to the UI side.
            acp_thread.set_draft_prompt(draft_prompt);
            acp_thread.set_ui_scroll_position(scroll_position);
            acp_thread.update_token_usage(token_usage, cx);
            acp_thread
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        let weak = cx.weak_entity();
        let weak_thread = thread_handle.downgrade();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            // The tool environment only holds weak handles to the ACP thread, the thread and
            // the agent (presumably to avoid reference cycles).
            thread.add_default_tools(
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    thread: weak_thread,
                    agent: weak,
                }) as _,
                cx,
            )
        });

        // Mirror title and token-usage events into the AcpThread, and save the thread whenever
        // it notifies observers.
        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                project_id,
                _subscriptions: subscriptions,
                pending_save: Task::ready(Ok(())),
            },
        );

        // Push the project's available commands to all of its sessions, including this new one.
        self.update_available_commands_for_project(project_id, cx);

        acp_thread
    }
 400
    /// Returns the agent's cached language-model information.
    pub fn models(&self) -> &LanguageModels {
        &self.models
    }
 404
 405    fn get_or_create_project_state(
 406        &mut self,
 407        project: &Entity<Project>,
 408        cx: &mut Context<Self>,
 409    ) -> EntityId {
 410        let project_id = project.entity_id();
 411        if self.projects.contains_key(&project_id) {
 412            return project_id;
 413        }
 414
 415        let project_context = cx.new(|_| ProjectContext::new(vec![], vec![]));
 416        self.register_project_with_initial_context(project.clone(), project_context, cx);
 417        if let Some(state) = self.projects.get_mut(&project_id) {
 418            state.project_context_needs_refresh.send(()).ok();
 419        }
 420        project_id
 421    }
 422
 423    fn register_project_with_initial_context(
 424        &mut self,
 425        project: Entity<Project>,
 426        project_context: Entity<ProjectContext>,
 427        cx: &mut Context<Self>,
 428    ) {
 429        let project_id = project.entity_id();
 430
 431        let context_server_store = project.read(cx).context_server_store();
 432        let context_server_registry =
 433            cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
 434
 435        let subscriptions = vec![
 436            cx.subscribe(&project, Self::handle_project_event),
 437            cx.subscribe(
 438                &context_server_store,
 439                Self::handle_context_server_store_updated,
 440            ),
 441            cx.subscribe(
 442                &context_server_registry,
 443                Self::handle_context_server_registry_event,
 444            ),
 445        ];
 446
 447        let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 448            watch::channel(());
 449
 450        self.projects.insert(
 451            project_id,
 452            ProjectState {
 453                project,
 454                project_context,
 455                project_context_needs_refresh: project_context_needs_refresh_tx,
 456                _maintain_project_context: cx.spawn(async move |this, cx| {
 457                    Self::maintain_project_context(
 458                        this,
 459                        project_id,
 460                        project_context_needs_refresh_rx,
 461                        cx,
 462                    )
 463                    .await
 464                }),
 465                context_server_registry,
 466                _subscriptions: subscriptions,
 467            },
 468        );
 469    }
 470
 471    fn session_project_state(&self, session_id: &acp::SessionId) -> Option<&ProjectState> {
 472        self.sessions
 473            .get(session_id)
 474            .and_then(|session| self.projects.get(&session.project_id))
 475    }
 476
 477    async fn maintain_project_context(
 478        this: WeakEntity<Self>,
 479        project_id: EntityId,
 480        mut needs_refresh: watch::Receiver<()>,
 481        cx: &mut AsyncApp,
 482    ) -> Result<()> {
 483        while needs_refresh.changed().await.is_ok() {
 484            let project_context = this
 485                .update(cx, |this, cx| {
 486                    let state = this
 487                        .projects
 488                        .get(&project_id)
 489                        .context("project state not found")?;
 490                    anyhow::Ok(Self::build_project_context(
 491                        &state.project,
 492                        this.prompt_store.as_ref(),
 493                        cx,
 494                    ))
 495                })??
 496                .await;
 497            this.update(cx, |this, cx| {
 498                if let Some(state) = this.projects.get(&project_id) {
 499                    state
 500                        .project_context
 501                        .update(cx, |current_project_context, _cx| {
 502                            *current_project_context = project_context;
 503                        });
 504                }
 505            })?;
 506        }
 507
 508        Ok(())
 509    }
 510
 511    fn build_project_context(
 512        project: &Entity<Project>,
 513        prompt_store: Option<&Entity<PromptStore>>,
 514        cx: &mut App,
 515    ) -> Task<ProjectContext> {
 516        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 517        let worktree_tasks = worktrees
 518            .into_iter()
 519            .map(|worktree| {
 520                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 521            })
 522            .collect::<Vec<_>>();
 523        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 524            prompt_store.read_with(cx, |prompt_store, cx| {
 525                let prompts = prompt_store.default_prompt_metadata();
 526                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 527                    let contents = prompt_store.load(prompt_metadata.id, cx);
 528                    async move { (contents.await, prompt_metadata) }
 529                });
 530                cx.background_spawn(future::join_all(load_tasks))
 531            })
 532        } else {
 533            Task::ready(vec![])
 534        };
 535
 536        cx.spawn(async move |_cx| {
 537            let (worktrees, default_user_rules) =
 538                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 539
 540            let worktrees = worktrees
 541                .into_iter()
 542                .map(|(worktree, _rules_error)| {
 543                    // TODO: show error message
 544                    // if let Some(rules_error) = rules_error {
 545                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 546                    // }
 547                    worktree
 548                })
 549                .collect::<Vec<_>>();
 550
 551            let default_user_rules = default_user_rules
 552                .into_iter()
 553                .flat_map(|(contents, prompt_metadata)| match contents {
 554                    Ok(contents) => Some(UserRulesContext {
 555                        uuid: prompt_metadata.id.as_user()?,
 556                        title: prompt_metadata.title.map(|title| title.to_string()),
 557                        contents,
 558                    }),
 559                    Err(_err) => {
 560                        // TODO: show error message
 561                        // this.update(cx, |_, cx| {
 562                        //     cx.emit(RulesLoadingError {
 563                        //         message: format!("{err:?}").into(),
 564                        //     });
 565                        // })
 566                        // .ok();
 567                        None
 568                    }
 569                })
 570                .collect::<Vec<_>>();
 571
 572            ProjectContext::new(worktrees, default_user_rules)
 573        })
 574    }
 575
 576    fn load_worktree_info_for_system_prompt(
 577        worktree: Entity<Worktree>,
 578        project: Entity<Project>,
 579        cx: &mut App,
 580    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 581        let tree = worktree.read(cx);
 582        let root_name = tree.root_name_str().into();
 583        let abs_path = tree.abs_path();
 584
 585        let mut context = WorktreeContext {
 586            root_name,
 587            abs_path,
 588            rules_file: None,
 589        };
 590
 591        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 592        let Some(rules_task) = rules_task else {
 593            return Task::ready((context, None));
 594        };
 595
 596        cx.spawn(async move |_| {
 597            let (rules_file, rules_file_error) = match rules_task.await {
 598                Ok(rules_file) => (Some(rules_file), None),
 599                Err(err) => (
 600                    None,
 601                    Some(RulesLoadingError {
 602                        message: format!("{err}").into(),
 603                    }),
 604                ),
 605            };
 606            context.rules_file = rules_file;
 607            (context, rules_file_error)
 608        })
 609    }
 610
 611    fn load_worktree_rules_file(
 612        worktree: Entity<Worktree>,
 613        project: Entity<Project>,
 614        cx: &mut App,
 615    ) -> Option<Task<Result<RulesFileContext>>> {
 616        let worktree = worktree.read(cx);
 617        let worktree_id = worktree.id();
 618        let selected_rules_file = RULES_FILE_NAMES
 619            .into_iter()
 620            .filter_map(|name| {
 621                worktree
 622                    .entry_for_path(RelPath::unix(name).unwrap())
 623                    .filter(|entry| entry.is_file())
 624                    .map(|entry| entry.path.clone())
 625            })
 626            .next();
 627
 628        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 629        // supported. This doesn't seem to occur often in GitHub repositories.
 630        selected_rules_file.map(|path_in_worktree| {
 631            let project_path = ProjectPath {
 632                worktree_id,
 633                path: path_in_worktree.clone(),
 634            };
 635            let buffer_task =
 636                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 637            let rope_task = cx.spawn(async move |cx| {
 638                let buffer = buffer_task.await?;
 639                let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
 640                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 641                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 642                })?;
 643                anyhow::Ok((project_entry_id, rope))
 644            });
 645            // Build a string from the rope on a background thread.
 646            cx.background_spawn(async move {
 647                let (project_entry_id, rope) = rope_task.await?;
 648                anyhow::Ok(RulesFileContext {
 649                    path_in_worktree,
 650                    text: rope.to_string().trim().to_string(),
 651                    project_entry_id: project_entry_id.to_usize(),
 652                })
 653            })
 654        })
 655    }
 656
 657    fn handle_thread_title_updated(
 658        &mut self,
 659        thread: Entity<Thread>,
 660        _: &TitleUpdated,
 661        cx: &mut Context<Self>,
 662    ) {
 663        let session_id = thread.read(cx).id();
 664        let Some(session) = self.sessions.get(session_id) else {
 665            return;
 666        };
 667
 668        let thread = thread.downgrade();
 669        let acp_thread = session.acp_thread.downgrade();
 670        cx.spawn(async move |_, cx| {
 671            let title = thread.read_with(cx, |thread, _| thread.title())?;
 672            if let Some(title) = title {
 673                let task =
 674                    acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 675                task.await?;
 676            }
 677            anyhow::Ok(())
 678        })
 679        .detach_and_log_err(cx);
 680    }
 681
 682    fn handle_thread_token_usage_updated(
 683        &mut self,
 684        thread: Entity<Thread>,
 685        usage: &TokenUsageUpdated,
 686        cx: &mut Context<Self>,
 687    ) {
 688        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 689            return;
 690        };
 691        session.acp_thread.update(cx, |acp_thread, cx| {
 692            acp_thread.update_token_usage(usage.0.clone(), cx);
 693        });
 694    }
 695
 696    fn handle_project_event(
 697        &mut self,
 698        project: Entity<Project>,
 699        event: &project::Event,
 700        _cx: &mut Context<Self>,
 701    ) {
 702        let project_id = project.entity_id();
 703        let Some(state) = self.projects.get_mut(&project_id) else {
 704            return;
 705        };
 706        match event {
 707            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 708                state.project_context_needs_refresh.send(()).ok();
 709            }
 710            project::Event::WorktreeUpdatedEntries(_, items) => {
 711                if items.iter().any(|(path, _, _)| {
 712                    RULES_FILE_NAMES
 713                        .iter()
 714                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 715                }) {
 716                    state.project_context_needs_refresh.send(()).ok();
 717                }
 718            }
 719            _ => {}
 720        }
 721    }
 722
 723    fn handle_prompts_updated_event(
 724        &mut self,
 725        _prompt_store: Entity<PromptStore>,
 726        _event: &prompt_store::PromptsUpdatedEvent,
 727        _cx: &mut Context<Self>,
 728    ) {
 729        for state in self.projects.values_mut() {
 730            state.project_context_needs_refresh.send(()).ok();
 731        }
 732    }
 733
 734    fn handle_models_updated_event(
 735        &mut self,
 736        _registry: Entity<LanguageModelRegistry>,
 737        event: &language_model::Event,
 738        cx: &mut Context<Self>,
 739    ) {
 740        self.models.refresh_list(cx);
 741
 742        let registry = LanguageModelRegistry::read_global(cx);
 743        let default_model = registry.default_model().map(|m| m.model);
 744        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 745
 746        for session in self.sessions.values_mut() {
 747            session.thread.update(cx, |thread, cx| {
 748                if thread.model().is_none()
 749                    && let Some(model) = default_model.clone()
 750                {
 751                    thread.set_model(model, cx);
 752                    cx.notify();
 753                }
 754                if let Some(model) = summarization_model.clone() {
 755                    if thread.summarization_model().is_none()
 756                        || matches!(event, language_model::Event::ThreadSummaryModelChanged)
 757                    {
 758                        thread.set_summarization_model(Some(model), cx);
 759                    }
 760                }
 761            });
 762        }
 763    }
 764
 765    fn handle_context_server_store_updated(
 766        &mut self,
 767        store: Entity<project::context_server_store::ContextServerStore>,
 768        _event: &project::context_server_store::ServerStatusChangedEvent,
 769        cx: &mut Context<Self>,
 770    ) {
 771        let project_id = self.projects.iter().find_map(|(id, state)| {
 772            if *state.context_server_registry.read(cx).server_store() == store {
 773                Some(*id)
 774            } else {
 775                None
 776            }
 777        });
 778        if let Some(project_id) = project_id {
 779            self.update_available_commands_for_project(project_id, cx);
 780        }
 781    }
 782
 783    fn handle_context_server_registry_event(
 784        &mut self,
 785        registry: Entity<ContextServerRegistry>,
 786        event: &ContextServerRegistryEvent,
 787        cx: &mut Context<Self>,
 788    ) {
 789        match event {
 790            ContextServerRegistryEvent::ToolsChanged => {}
 791            ContextServerRegistryEvent::PromptsChanged => {
 792                let project_id = self.projects.iter().find_map(|(id, state)| {
 793                    if state.context_server_registry == registry {
 794                        Some(*id)
 795                    } else {
 796                        None
 797                    }
 798                });
 799                if let Some(project_id) = project_id {
 800                    self.update_available_commands_for_project(project_id, cx);
 801                }
 802            }
 803        }
 804    }
 805
 806    fn update_available_commands_for_project(&self, project_id: EntityId, cx: &mut Context<Self>) {
 807        let available_commands =
 808            Self::build_available_commands_for_project(self.projects.get(&project_id), cx);
 809        for session in self.sessions.values() {
 810            if session.project_id != project_id {
 811                continue;
 812            }
 813            session.acp_thread.update(cx, |thread, cx| {
 814                thread
 815                    .handle_session_update(
 816                        acp::SessionUpdate::AvailableCommandsUpdate(
 817                            acp::AvailableCommandsUpdate::new(available_commands.clone()),
 818                        ),
 819                        cx,
 820                    )
 821                    .log_err();
 822            });
 823        }
 824    }
 825
 826    fn build_available_commands_for_project(
 827        project_state: Option<&ProjectState>,
 828        cx: &App,
 829    ) -> Vec<acp::AvailableCommand> {
 830        let Some(state) = project_state else {
 831            return vec![];
 832        };
 833        let registry = state.context_server_registry.read(cx);
 834
 835        let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
 836        for context_server_prompt in registry.prompts() {
 837            *prompt_name_counts
 838                .entry(context_server_prompt.prompt.name.as_str())
 839                .or_insert(0) += 1;
 840        }
 841
 842        registry
 843            .prompts()
 844            .flat_map(|context_server_prompt| {
 845                let prompt = &context_server_prompt.prompt;
 846
 847                let should_prefix = prompt_name_counts
 848                    .get(prompt.name.as_str())
 849                    .copied()
 850                    .unwrap_or(0)
 851                    > 1;
 852
 853                let name = if should_prefix {
 854                    format!("{}.{}", context_server_prompt.server_id, prompt.name)
 855                } else {
 856                    prompt.name.clone()
 857                };
 858
 859                let mut command = acp::AvailableCommand::new(
 860                    name,
 861                    prompt.description.clone().unwrap_or_default(),
 862                );
 863
 864                match prompt.arguments.as_deref() {
 865                    Some([arg]) => {
 866                        let hint = format!("<{}>", arg.name);
 867
 868                        command = command.input(acp::AvailableCommandInput::Unstructured(
 869                            acp::UnstructuredCommandInput::new(hint),
 870                        ));
 871                    }
 872                    Some([]) | None => {}
 873                    Some(_) => {
 874                        // skip >1 argument commands since we don't support them yet
 875                        return None;
 876                    }
 877                }
 878
 879                Some(command)
 880            })
 881            .collect()
 882    }
 883
 884    pub fn load_thread(
 885        &mut self,
 886        id: acp::SessionId,
 887        project: Entity<Project>,
 888        cx: &mut Context<Self>,
 889    ) -> Task<Result<Entity<Thread>>> {
 890        let database_future = ThreadsDatabase::connect(cx);
 891        cx.spawn(async move |this, cx| {
 892            let database = database_future.await.map_err(|err| anyhow!(err))?;
 893            let db_thread = database
 894                .load_thread(id.clone())
 895                .await?
 896                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 897
 898            this.update(cx, |this, cx| {
 899                let project_id = this.get_or_create_project_state(&project, cx);
 900                let project_state = this
 901                    .projects
 902                    .get(&project_id)
 903                    .context("project state not found")?;
 904                let summarization_model = LanguageModelRegistry::read_global(cx)
 905                    .thread_summary_model()
 906                    .map(|c| c.model);
 907
 908                Ok(cx.new(|cx| {
 909                    let mut thread = Thread::from_db(
 910                        id.clone(),
 911                        db_thread,
 912                        project_state.project.clone(),
 913                        project_state.project_context.clone(),
 914                        project_state.context_server_registry.clone(),
 915                        this.templates.clone(),
 916                        cx,
 917                    );
 918                    thread.set_summarization_model(summarization_model, cx);
 919                    thread
 920                }))
 921            })?
 922        })
 923    }
 924
 925    pub fn open_thread(
 926        &mut self,
 927        id: acp::SessionId,
 928        project: Entity<Project>,
 929        cx: &mut Context<Self>,
 930    ) -> Task<Result<Entity<AcpThread>>> {
 931        if let Some(session) = self.sessions.get(&id) {
 932            return Task::ready(Ok(session.acp_thread.clone()));
 933        }
 934
 935        let task = self.load_thread(id, project.clone(), cx);
 936        cx.spawn(async move |this, cx| {
 937            let thread = task.await?;
 938            let acp_thread = this.update(cx, |this, cx| {
 939                let project_id = this.get_or_create_project_state(&project, cx);
 940                this.register_session(thread.clone(), project_id, cx)
 941            })?;
 942            let events = thread.update(cx, |thread, cx| thread.replay(cx));
 943            cx.update(|cx| {
 944                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 945            })
 946            .await?;
 947            acp_thread.update(cx, |thread, cx| {
 948                thread.snapshot_completed_plan(cx);
 949            });
 950            Ok(acp_thread)
 951        })
 952    }
 953
 954    pub fn thread_summary(
 955        &mut self,
 956        id: acp::SessionId,
 957        project: Entity<Project>,
 958        cx: &mut Context<Self>,
 959    ) -> Task<Result<SharedString>> {
 960        let thread = self.open_thread(id.clone(), project, cx);
 961        cx.spawn(async move |this, cx| {
 962            let acp_thread = thread.await?;
 963            let result = this
 964                .update(cx, |this, cx| {
 965                    this.sessions
 966                        .get(&id)
 967                        .unwrap()
 968                        .thread
 969                        .update(cx, |thread, cx| thread.summary(cx))
 970                })?
 971                .await
 972                .context("Failed to generate summary")?;
 973            drop(acp_thread);
 974            Ok(result)
 975        })
 976    }
 977
    /// Persists `thread` to the threads database on a spawned task.
    ///
    /// Nothing is saved when:
    /// - the thread is empty;
    /// - no session is registered for its id;
    /// - that session's project state is missing.
    ///
    /// Before serializing, the session's current draft prompt (read from its
    /// `AcpThread`) is written onto the thread. The absolute paths of the project's
    /// visible worktrees are saved with it as its folder paths.
    ///
    /// The spawned task is stored in `session.pending_save`, overwriting the
    /// previous one. If the database cannot be opened, the error is logged and
    /// nothing is saved. Otherwise the save runs (errors logged, not propagated)
    /// and the thread store is reloaded.
    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
        if thread.read(cx).is_empty() {
            return;
        }

        let id = thread.read(cx).id().clone();
        let Some(session) = self.sessions.get_mut(&id) else {
            return;
        };

        let project_id = session.project_id;
        let Some(state) = self.projects.get(&project_id) else {
            return;
        };

        // Captured synchronously so the spawned task only moves owned data.
        let folder_paths = PathList::new(
            &state
                .project
                .read(cx)
                .visible_worktrees(cx)
                .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
                .collect::<Vec<_>>(),
        );

        let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
        let database_future = ThreadsDatabase::connect(cx);
        // The draft prompt is stored on the thread before `to_db` so it is included
        // in the serialized form.
        let db_thread = thread.update(cx, |thread, cx| {
            thread.set_draft_prompt(draft_prompt);
            thread.to_db(cx)
        });
        let thread_store = self.thread_store.clone();
        session.pending_save = cx.spawn(async move |_, cx| {
            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
                return Ok(());
            };
            let db_thread = db_thread.await;
            database
                .save_thread(id, db_thread, folder_paths)
                .await
                .log_err();
            thread_store.update(cx, |store, cx| store.reload(cx));
            Ok(())
        });
    }
1022
    /// Runs MCP prompt `prompt_name` from server `server_id` as a user turn in
    /// session `session_id`.
    ///
    /// Steps:
    /// 1. The prompt is fetched with `arguments`.
    /// 2. `original_content`, minus its first block (the slash-command text parsed
    ///    by the caller), is pushed onto the `Thread` as user message `message_id`.
    /// 3. Each message returned by the prompt is appended, by role, to both the
    ///    `AcpThread` (via the `*_with_indent` variants) and the `Thread`.
    /// 4. A turn is started: `send_existing` if the last message was a user message
    ///    or the prompt returned no messages; otherwise `resume`.
    ///
    /// The returned task resolves when the turn's event stream has been fully
    /// forwarded to the `AcpThread`.
    ///
    /// # Errors
    /// Fails if the session's project state or session is missing, or if the
    /// prompt cannot be fetched. Errors from starting the turn or from its event
    /// stream are also returned.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: acp::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let Some(state) = self.session_project_state(&session_id) else {
            return Task::ready(Err(anyhow!("Project state not found for session")));
        };
        let server_store = state
            .context_server_registry
            .read(cx)
            .server_store()
            .clone();
        let path_style = state.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Stays `true` when the prompt returns no messages, so the turn is then
            // started with `send_existing`.
            let mut last_is_user = true;

            // Record the user's own content, skipping the leading command block.
            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                match role {
                    context_server::types::Role::User => {
                        // Each prompt-provided user message gets a fresh id, shared
                        // by the AcpThread entry and the Thread entry.
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
1125}
1126
/// [`acp_thread::AgentConnection`] implementation backed by the in-process
/// [`NativeAgent`] entity held in field `0`.
///
/// Clones hold a handle to the same `NativeAgent` entity.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
1130
1131impl NativeAgentConnection {
1132    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
1133        self.0
1134            .read(cx)
1135            .sessions
1136            .get(session_id)
1137            .map(|session| session.thread.clone())
1138    }
1139
1140    pub fn load_thread(
1141        &self,
1142        id: acp::SessionId,
1143        project: Entity<Project>,
1144        cx: &mut App,
1145    ) -> Task<Result<Entity<Thread>>> {
1146        self.0
1147            .update(cx, |this, cx| this.load_thread(id, project, cx))
1148    }
1149
1150    fn run_turn(
1151        &self,
1152        session_id: acp::SessionId,
1153        cx: &mut App,
1154        f: impl 'static
1155        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1156    ) -> Task<Result<acp::PromptResponse>> {
1157        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1158            agent
1159                .sessions
1160                .get_mut(&session_id)
1161                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1162        }) else {
1163            return Task::ready(Err(anyhow!("Session not found")));
1164        };
1165        log::debug!("Found session for: {}", session_id);
1166
1167        let response_stream = match f(thread, cx) {
1168            Ok(stream) => stream,
1169            Err(err) => return Task::ready(Err(err)),
1170        };
1171        Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1172    }
1173
1174    fn handle_thread_events(
1175        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1176        acp_thread: WeakEntity<AcpThread>,
1177        cx: &App,
1178    ) -> Task<Result<acp::PromptResponse>> {
1179        cx.spawn(async move |cx| {
1180            // Handle response stream and forward to session.acp_thread
1181            while let Some(result) = events.next().await {
1182                match result {
1183                    Ok(event) => {
1184                        log::trace!("Received completion event: {:?}", event);
1185
1186                        match event {
1187                            ThreadEvent::UserMessage(message) => {
1188                                acp_thread.update(cx, |thread, cx| {
1189                                    for content in message.content {
1190                                        thread.push_user_content_block(
1191                                            Some(message.id.clone()),
1192                                            content.into(),
1193                                            cx,
1194                                        );
1195                                    }
1196                                })?;
1197                            }
1198                            ThreadEvent::AgentText(text) => {
1199                                acp_thread.update(cx, |thread, cx| {
1200                                    thread.push_assistant_content_block(text.into(), false, cx)
1201                                })?;
1202                            }
1203                            ThreadEvent::AgentThinking(text) => {
1204                                acp_thread.update(cx, |thread, cx| {
1205                                    thread.push_assistant_content_block(text.into(), true, cx)
1206                                })?;
1207                            }
1208                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1209                                tool_call,
1210                                options,
1211                                response,
1212                                context: _,
1213                            }) => {
1214                                let outcome_task = acp_thread.update(cx, |thread, cx| {
1215                                    thread.request_tool_call_authorization(tool_call, options, cx)
1216                                })??;
1217                                cx.background_spawn(async move {
1218                                    if let acp_thread::RequestPermissionOutcome::Selected(outcome) =
1219                                        outcome_task.await
1220                                    {
1221                                        response
1222                                            .send(outcome)
1223                                            .map(|_| anyhow!("authorization receiver was dropped"))
1224                                            .log_err();
1225                                    }
1226                                })
1227                                .detach();
1228                            }
1229                            ThreadEvent::ToolCall(tool_call) => {
1230                                acp_thread.update(cx, |thread, cx| {
1231                                    thread.upsert_tool_call(tool_call, cx)
1232                                })??;
1233                            }
1234                            ThreadEvent::ToolCallUpdate(update) => {
1235                                acp_thread.update(cx, |thread, cx| {
1236                                    thread.update_tool_call(update, cx)
1237                                })??;
1238                            }
1239                            ThreadEvent::Plan(plan) => {
1240                                acp_thread.update(cx, |thread, cx| thread.update_plan(plan, cx))?;
1241                            }
1242                            ThreadEvent::SubagentSpawned(session_id) => {
1243                                acp_thread.update(cx, |thread, cx| {
1244                                    thread.subagent_spawned(session_id, cx);
1245                                })?;
1246                            }
1247                            ThreadEvent::Retry(status) => {
1248                                acp_thread.update(cx, |thread, cx| {
1249                                    thread.update_retry_status(status, cx)
1250                                })?;
1251                            }
1252                            ThreadEvent::Stop(stop_reason) => {
1253                                log::debug!("Assistant message complete: {:?}", stop_reason);
1254                                return Ok(acp::PromptResponse::new(stop_reason));
1255                            }
1256                        }
1257                    }
1258                    Err(e) => {
1259                        log::error!("Error in model response stream: {:?}", e);
1260                        return Err(e);
1261                    }
1262                }
1263            }
1264
1265            log::debug!("Response stream completed");
1266            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1267        })
1268    }
1269}
1270
1271struct Command<'a> {
1272    prompt_name: &'a str,
1273    arg_value: &'a str,
1274    explicit_server_id: Option<&'a str>,
1275}
1276
1277impl<'a> Command<'a> {
1278    fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1279        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1280            return None;
1281        };
1282        let text = text_content.text.trim();
1283        let command = text.strip_prefix('/')?;
1284        let (command, arg_value) = command
1285            .split_once(char::is_whitespace)
1286            .unwrap_or((command, ""));
1287
1288        if let Some((server_id, prompt_name)) = command.split_once('.') {
1289            Some(Self {
1290                prompt_name,
1291                arg_value,
1292                explicit_server_id: Some(server_id),
1293            })
1294        } else {
1295            Some(Self {
1296                prompt_name: command,
1297                arg_value,
1298                explicit_server_id: None,
1299            })
1300        }
1301    }
1302}
1303
/// Per-session [`AgentModelSelector`] for the native agent.
///
/// Lists models from the agent's `models` collection. Reads and sets the model
/// of the [`Thread`] belonging to `session_id`.
struct NativeAgentModelSelector {
    /// Session whose thread's model this selector controls.
    session_id: acp::SessionId,
    /// Connection used to reach the agent's sessions, models and filesystem.
    connection: NativeAgentConnection,
}
1308
1309impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1310    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1311        log::debug!("NativeAgentConnection::list_models called");
1312        let list = self.connection.0.read(cx).models.model_list.clone();
1313        Task::ready(if list.is_empty() {
1314            Err(anyhow::anyhow!("No models available"))
1315        } else {
1316            Ok(list)
1317        })
1318    }
1319
1320    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1321        log::debug!(
1322            "Setting model for session {}: {}",
1323            self.session_id,
1324            model_id
1325        );
1326        let Some(thread) = self
1327            .connection
1328            .0
1329            .read(cx)
1330            .sessions
1331            .get(&self.session_id)
1332            .map(|session| session.thread.clone())
1333        else {
1334            return Task::ready(Err(anyhow!("Session not found")));
1335        };
1336
1337        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1338            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1339        };
1340
1341        // We want to reset the effort level when switching models, as the currently-selected effort level may
1342        // not be compatible.
1343        let effort = model
1344            .default_effort_level()
1345            .map(|effort_level| effort_level.value.to_string());
1346
1347        thread.update(cx, |thread, cx| {
1348            thread.set_model(model.clone(), cx);
1349            thread.set_thinking_effort(effort.clone(), cx);
1350            thread.set_thinking_enabled(model.supports_thinking(), cx);
1351        });
1352
1353        update_settings_file(
1354            self.connection.0.read(cx).fs.clone(),
1355            cx,
1356            move |settings, cx| {
1357                let provider = model.provider_id().0.to_string();
1358                let model = model.id().0.to_string();
1359                let enable_thinking = thread.read(cx).thinking_enabled();
1360                settings
1361                    .agent
1362                    .get_or_insert_default()
1363                    .set_model(LanguageModelSelection {
1364                        provider: provider.into(),
1365                        model,
1366                        enable_thinking,
1367                        effort,
1368                    });
1369            },
1370        );
1371
1372        Task::ready(Ok(()))
1373    }
1374
1375    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1376        let Some(thread) = self
1377            .connection
1378            .0
1379            .read(cx)
1380            .sessions
1381            .get(&self.session_id)
1382            .map(|session| session.thread.clone())
1383        else {
1384            return Task::ready(Err(anyhow!("Session not found")));
1385        };
1386        let Some(model) = thread.read(cx).model() else {
1387            return Task::ready(Err(anyhow!("Model not found")));
1388        };
1389        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1390        else {
1391            return Task::ready(Err(anyhow!("Provider not found")));
1392        };
1393        Task::ready(Ok(LanguageModels::map_language_model_to_info(
1394            model, &provider,
1395        )))
1396    }
1397
1398    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1399        Some(self.connection.0.read(cx).models.watch())
1400    }
1401
1402    fn should_render_footer(&self) -> bool {
1403        true
1404    }
1405}
1406
/// Agent id reported by [`NativeAgentConnection`], built lazily from `"Zed Agent"`.
pub static ZED_AGENT_ID: LazyLock<AgentId> = LazyLock::new(|| AgentId::new("Zed Agent"));
1408
1409impl acp_thread::AgentConnection for NativeAgentConnection {
1410    fn agent_id(&self) -> AgentId {
1411        ZED_AGENT_ID.clone()
1412    }
1413
1414    fn telemetry_id(&self) -> SharedString {
1415        "zed".into()
1416    }
1417
1418    fn new_session(
1419        self: Rc<Self>,
1420        project: Entity<Project>,
1421        work_dirs: PathList,
1422        cx: &mut App,
1423    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1424        log::debug!("Creating new thread for project at: {work_dirs:?}");
1425        Task::ready(Ok(self
1426            .0
1427            .update(cx, |agent, cx| agent.new_session(project, cx))))
1428    }
1429
1430    fn supports_load_session(&self) -> bool {
1431        true
1432    }
1433
1434    fn load_session(
1435        self: Rc<Self>,
1436        session_id: acp::SessionId,
1437        project: Entity<Project>,
1438        _work_dirs: PathList,
1439        _title: Option<SharedString>,
1440        cx: &mut App,
1441    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1442        self.0
1443            .update(cx, |agent, cx| agent.open_thread(session_id, project, cx))
1444    }
1445
1446    fn supports_close_session(&self) -> bool {
1447        true
1448    }
1449
1450    fn close_session(
1451        self: Rc<Self>,
1452        session_id: &acp::SessionId,
1453        cx: &mut App,
1454    ) -> Task<Result<()>> {
1455        self.0.update(cx, |agent, cx| {
1456            let thread = agent.sessions.get(session_id).map(|s| s.thread.clone());
1457            if let Some(thread) = thread {
1458                agent.save_thread(thread, cx);
1459            }
1460
1461            let Some(session) = agent.sessions.remove(session_id) else {
1462                return Task::ready(Ok(()));
1463            };
1464            let project_id = session.project_id;
1465
1466            let has_remaining = agent.sessions.values().any(|s| s.project_id == project_id);
1467            if !has_remaining {
1468                agent.projects.remove(&project_id);
1469            }
1470
1471            session.pending_save
1472        })
1473    }
1474
1475    fn auth_methods(&self) -> &[acp::AuthMethod] {
1476        &[] // No auth for in-process
1477    }
1478
1479    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1480        Task::ready(Ok(()))
1481    }
1482
1483    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1484        Some(Rc::new(NativeAgentModelSelector {
1485            session_id: session_id.clone(),
1486            connection: self.clone(),
1487        }) as Rc<dyn AgentModelSelector>)
1488    }
1489
1490    fn prompt(
1491        &self,
1492        id: Option<acp_thread::UserMessageId>,
1493        params: acp::PromptRequest,
1494        cx: &mut App,
1495    ) -> Task<Result<acp::PromptResponse>> {
1496        let id = id.expect("UserMessageId is required");
1497        let session_id = params.session_id.clone();
1498        log::info!("Received prompt request for session: {}", session_id);
1499        log::debug!("Prompt blocks count: {}", params.prompt.len());
1500
1501        let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
1502            return Task::ready(Err(anyhow::anyhow!("Session not found")));
1503        };
1504
1505        if let Some(parsed_command) = Command::parse(&params.prompt) {
1506            let registry = project_state.context_server_registry.read(cx);
1507
1508            let explicit_server_id = parsed_command
1509                .explicit_server_id
1510                .map(|server_id| ContextServerId(server_id.into()));
1511
1512            if let Some(prompt) =
1513                registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1514            {
1515                let arguments = if !parsed_command.arg_value.is_empty()
1516                    && let Some(arg_name) = prompt
1517                        .prompt
1518                        .arguments
1519                        .as_ref()
1520                        .and_then(|args| args.first())
1521                        .map(|arg| arg.name.clone())
1522                {
1523                    HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1524                } else {
1525                    Default::default()
1526                };
1527
1528                let prompt_name = prompt.prompt.name.clone();
1529                let server_id = prompt.server_id.clone();
1530
1531                return self.0.update(cx, |agent, cx| {
1532                    agent.send_mcp_prompt(
1533                        id,
1534                        session_id.clone(),
1535                        prompt_name,
1536                        server_id,
1537                        arguments,
1538                        params.prompt,
1539                        cx,
1540                    )
1541                });
1542            }
1543        };
1544
1545        let path_style = project_state.project.read(cx).path_style(cx);
1546
1547        self.run_turn(session_id, cx, move |thread, cx| {
1548            let content: Vec<UserMessageContent> = params
1549                .prompt
1550                .into_iter()
1551                .map(|block| UserMessageContent::from_content_block(block, path_style))
1552                .collect::<Vec<_>>();
1553            log::debug!("Converted prompt to message: {} chars", content.len());
1554            log::debug!("Message id: {:?}", id);
1555            log::debug!("Message content: {:?}", content);
1556
1557            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1558        })
1559    }
1560
1561    fn retry(
1562        &self,
1563        session_id: &acp::SessionId,
1564        _cx: &App,
1565    ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1566        Some(Rc::new(NativeAgentSessionRetry {
1567            connection: self.clone(),
1568            session_id: session_id.clone(),
1569        }) as _)
1570    }
1571
1572    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1573        log::info!("Cancelling on session: {}", session_id);
1574        self.0.update(cx, |agent, cx| {
1575            if let Some(session) = agent.sessions.get(session_id) {
1576                session
1577                    .thread
1578                    .update(cx, |thread, cx| thread.cancel(cx))
1579                    .detach();
1580            }
1581        });
1582    }
1583
1584    fn truncate(
1585        &self,
1586        session_id: &acp::SessionId,
1587        cx: &App,
1588    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1589        self.0.read_with(cx, |agent, _cx| {
1590            agent.sessions.get(session_id).map(|session| {
1591                Rc::new(NativeAgentSessionTruncate {
1592                    thread: session.thread.clone(),
1593                    acp_thread: session.acp_thread.downgrade(),
1594                }) as _
1595            })
1596        })
1597    }
1598
1599    fn set_title(
1600        &self,
1601        session_id: &acp::SessionId,
1602        cx: &App,
1603    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1604        self.0.read_with(cx, |agent, _cx| {
1605            agent
1606                .sessions
1607                .get(session_id)
1608                .filter(|s| !s.thread.read(cx).is_subagent())
1609                .map(|session| {
1610                    Rc::new(NativeAgentSessionSetTitle {
1611                        thread: session.thread.clone(),
1612                    }) as _
1613                })
1614        })
1615    }
1616
1617    fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1618        let thread_store = self.0.read(cx).thread_store.clone();
1619        Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1620    }
1621
1622    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1623        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1624    }
1625
1626    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1627        self
1628    }
1629}
1630
1631impl acp_thread::AgentTelemetry for NativeAgentConnection {
1632    fn thread_data(
1633        &self,
1634        session_id: &acp::SessionId,
1635        cx: &mut App,
1636    ) -> Task<Result<serde_json::Value>> {
1637        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1638            return Task::ready(Err(anyhow!("Session not found")));
1639        };
1640
1641        let task = session.thread.read(cx).to_db(cx);
1642        cx.background_spawn(async move {
1643            serde_json::to_value(task.await).context("Failed to serialize thread")
1644        })
1645    }
1646}
1647
/// Session list for the native agent, backed by a [`ThreadStore`].
///
/// Observed changes to the thread store are forwarded to watchers as
/// `SessionListUpdate::Refresh` messages.
pub struct NativeAgentSessionList {
    thread_store: Entity<ThreadStore>,
    // Sending half of the update channel; used by `notify_refresh` and by the
    // thread-store observer registered in `new`.
    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
    // Receiving half; clones of it are handed out by `watch`.
    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
    // Held for the lifetime of the list so the thread-store observer registered
    // in `new` stays attached.
    _subscription: Subscription,
}
1654
1655impl NativeAgentSessionList {
1656    fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1657        let (tx, rx) = smol::channel::unbounded();
1658        let this_tx = tx.clone();
1659        let subscription = cx.observe(&thread_store, move |_, _| {
1660            this_tx
1661                .try_send(acp_thread::SessionListUpdate::Refresh)
1662                .ok();
1663        });
1664        Self {
1665            thread_store,
1666            updates_tx: tx,
1667            updates_rx: rx,
1668            _subscription: subscription,
1669        }
1670    }
1671
1672    pub fn thread_store(&self) -> &Entity<ThreadStore> {
1673        &self.thread_store
1674    }
1675}
1676
1677impl AgentSessionList for NativeAgentSessionList {
1678    fn list_sessions(
1679        &self,
1680        _request: AgentSessionListRequest,
1681        cx: &mut App,
1682    ) -> Task<Result<AgentSessionListResponse>> {
1683        let sessions = self
1684            .thread_store
1685            .read(cx)
1686            .entries()
1687            .map(|entry| AgentSessionInfo::from(&entry))
1688            .collect();
1689        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1690    }
1691
1692    fn supports_delete(&self) -> bool {
1693        true
1694    }
1695
1696    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1697        self.thread_store
1698            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1699    }
1700
1701    fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1702        self.thread_store
1703            .update(cx, |store, cx| store.delete_threads(cx))
1704    }
1705
1706    fn watch(
1707        &self,
1708        _cx: &mut App,
1709    ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1710        Some(self.updates_rx.clone())
1711    }
1712
1713    fn notify_refresh(&self) {
1714        self.updates_tx
1715            .try_send(acp_thread::SessionListUpdate::Refresh)
1716            .ok();
1717    }
1718
1719    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1720        self
1721    }
1722}
1723
/// Truncation handle for a single native session, created by
/// `NativeAgentConnection::truncate`.
struct NativeAgentSessionTruncate {
    // Native thread whose history is truncated.
    thread: Entity<Thread>,
    // ACP view of the session; receives the updated token usage after truncation.
    // Weak, so the handle does not keep the ACP thread alive.
    acp_thread: WeakEntity<AcpThread>,
}
1728
1729impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1730    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1731        match self.thread.update(cx, |thread, cx| {
1732            thread.truncate(message_id.clone(), cx)?;
1733            Ok(thread.latest_token_usage())
1734        }) {
1735            Ok(usage) => {
1736                self.acp_thread
1737                    .update(cx, |thread, cx| {
1738                        thread.update_token_usage(usage, cx);
1739                    })
1740                    .ok();
1741                Task::ready(Ok(()))
1742            }
1743            Err(error) => Task::ready(Err(error)),
1744        }
1745    }
1746}
1747
/// Retry handle for a single native session, created by `NativeAgentConnection::retry`.
struct NativeAgentSessionRetry {
    // Connection used to start the retried turn.
    connection: NativeAgentConnection,
    // Session whose thread is resumed.
    session_id: acp::SessionId,
}
1752
1753impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1754    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1755        self.connection
1756            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1757                thread.update(cx, |thread, cx| thread.resume(cx))
1758            })
1759    }
1760}
1761
/// Handle created by `NativeAgentConnection::set_title` that renames one native thread.
struct NativeAgentSessionSetTitle {
    // Thread whose title is set.
    thread: Entity<Thread>,
}
1765
1766impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1767    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1768        self.thread
1769            .update(cx, |thread, cx| thread.set_title(title, cx));
1770        Task::ready(Ok(()))
1771    }
1772}
1773
/// Environment a native [`Thread`] uses to reach back into the agent and its ACP
/// thread: creating terminals and creating or resuming subagents.
pub struct NativeThreadEnvironment {
    // Agent on which subagent sessions are registered and looked up.
    agent: WeakEntity<NativeAgent>,
    // Thread this environment belongs to; it is the parent of any subagents created here.
    thread: WeakEntity<Thread>,
    // ACP thread through which terminals are created and released.
    acp_thread: WeakEntity<AcpThread>,
}
1779
1780impl NativeThreadEnvironment {
1781    pub(crate) fn create_subagent_thread(
1782        &self,
1783        label: String,
1784        cx: &mut App,
1785    ) -> Result<Rc<dyn SubagentHandle>> {
1786        let Some(parent_thread_entity) = self.thread.upgrade() else {
1787            anyhow::bail!("Parent thread no longer exists".to_string());
1788        };
1789        let parent_thread = parent_thread_entity.read(cx);
1790        let current_depth = parent_thread.depth();
1791        let parent_session_id = parent_thread.id().clone();
1792
1793        if current_depth >= MAX_SUBAGENT_DEPTH {
1794            return Err(anyhow!(
1795                "Maximum subagent depth ({}) reached",
1796                MAX_SUBAGENT_DEPTH
1797            ));
1798        }
1799
1800        let subagent_thread: Entity<Thread> = cx.new(|cx| {
1801            let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1802            thread.set_title(label.into(), cx);
1803            thread
1804        });
1805
1806        let session_id = subagent_thread.read(cx).id().clone();
1807
1808        let acp_thread = self
1809            .agent
1810            .update(cx, |agent, cx| -> Result<Entity<AcpThread>> {
1811                let project_id = agent
1812                    .sessions
1813                    .get(&parent_session_id)
1814                    .map(|s| s.project_id)
1815                    .context("parent session not found")?;
1816                Ok(agent.register_session(subagent_thread.clone(), project_id, cx))
1817            })??;
1818
1819        let depth = current_depth + 1;
1820
1821        telemetry::event!(
1822            "Subagent Started",
1823            session = parent_thread_entity.read(cx).id().to_string(),
1824            subagent_session = session_id.to_string(),
1825            depth,
1826            is_resumed = false,
1827        );
1828
1829        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1830    }
1831
1832    pub(crate) fn resume_subagent_thread(
1833        &self,
1834        session_id: acp::SessionId,
1835        cx: &mut App,
1836    ) -> Result<Rc<dyn SubagentHandle>> {
1837        let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1838            let session = agent
1839                .sessions
1840                .get(&session_id)
1841                .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1842            anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1843        })??;
1844
1845        let depth = subagent_thread.read(cx).depth();
1846
1847        if let Some(parent_thread_entity) = self.thread.upgrade() {
1848            telemetry::event!(
1849                "Subagent Started",
1850                session = parent_thread_entity.read(cx).id().to_string(),
1851                subagent_session = session_id.to_string(),
1852                depth,
1853                is_resumed = true,
1854            );
1855        }
1856
1857        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1858    }
1859
1860    fn prompt_subagent(
1861        &self,
1862        session_id: acp::SessionId,
1863        subagent_thread: Entity<Thread>,
1864        acp_thread: Entity<acp_thread::AcpThread>,
1865    ) -> Result<Rc<dyn SubagentHandle>> {
1866        let Some(parent_thread_entity) = self.thread.upgrade() else {
1867            anyhow::bail!("Parent thread no longer exists".to_string());
1868        };
1869        Ok(Rc::new(NativeSubagentHandle::new(
1870            session_id,
1871            subagent_thread,
1872            acp_thread,
1873            parent_thread_entity,
1874        )) as _)
1875    }
1876}
1877
1878impl ThreadEnvironment for NativeThreadEnvironment {
1879    fn create_terminal(
1880        &self,
1881        command: String,
1882        cwd: Option<PathBuf>,
1883        output_byte_limit: Option<u64>,
1884        cx: &mut AsyncApp,
1885    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1886        let task = self.acp_thread.update(cx, |thread, cx| {
1887            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1888        });
1889
1890        let acp_thread = self.acp_thread.clone();
1891        cx.spawn(async move |cx| {
1892            let terminal = task?.await?;
1893
1894            let (drop_tx, drop_rx) = oneshot::channel();
1895            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1896
1897            cx.spawn(async move |cx| {
1898                drop_rx.await.ok();
1899                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1900            })
1901            .detach();
1902
1903            let handle = AcpTerminalHandle {
1904                terminal,
1905                _drop_tx: Some(drop_tx),
1906            };
1907
1908            Ok(Rc::new(handle) as _)
1909        })
1910    }
1911
1912    fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1913        self.create_subagent_thread(label, cx)
1914    }
1915
1916    fn resume_subagent(
1917        &self,
1918        session_id: acp::SessionId,
1919        cx: &mut App,
1920    ) -> Result<Rc<dyn SubagentHandle>> {
1921        self.resume_subagent_thread(session_id, cx)
1922    }
1923}
1924
/// How a single prompt sent to a subagent ended, as observed by
/// `NativeSubagentHandle::send`.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The turn ended with `EndTurn` or any stop reason not mapped elsewhere; the
    /// answer is read from the subagent's last message.
    Completed,
    /// The turn stopped with `StopReason::Cancelled`.
    Cancelled,
    /// Token usage moved from `Normal` to `Warning` while the prompt was running;
    /// `send` cancels the subagent in this case.
    ContextWindowWarning,
    /// The turn failed or stopped for another reason; carries a human-readable message.
    Error(String),
}
1932
/// Handle through which a parent thread prompts one of its subagents.
pub struct NativeSubagentHandle {
    // Session id of the subagent.
    session_id: acp::SessionId,
    // Parent thread, held weakly; the subagent is registered and unregistered with
    // it as running around each `send`.
    parent_thread: WeakEntity<Thread>,
    // Native subagent thread, used for token-usage tracking and reading replies.
    subagent_thread: Entity<Thread>,
    // ACP view of the subagent; prompts are sent through it.
    acp_thread: Entity<acp_thread::AcpThread>,
}
1939
1940impl NativeSubagentHandle {
1941    fn new(
1942        session_id: acp::SessionId,
1943        subagent_thread: Entity<Thread>,
1944        acp_thread: Entity<acp_thread::AcpThread>,
1945        parent_thread_entity: Entity<Thread>,
1946    ) -> Self {
1947        NativeSubagentHandle {
1948            session_id,
1949            subagent_thread,
1950            parent_thread: parent_thread_entity.downgrade(),
1951            acp_thread,
1952        }
1953    }
1954}
1955
impl SubagentHandle for NativeSubagentHandle {
    /// Returns the subagent's session id.
    fn id(&self) -> acp::SessionId {
        self.session_id.clone()
    }

    /// Returns how many entries the subagent's ACP thread currently holds.
    fn num_entries(&self, cx: &App) -> usize {
        self.acp_thread.read(cx).entries().len()
    }

    /// Sends `message` to the subagent and resolves with its reply text.
    ///
    /// While the prompt runs, the subagent is registered as running on the parent
    /// thread. It is unregistered once a result has been produced.
    ///
    /// The reply is the text parts of the subagent's last message, joined with blank
    /// lines. The result is an error in these cases:
    /// - the turn is cancelled, errors, or stops for a limit or refusal;
    /// - no non-empty text reply exists;
    /// - token usage crosses from `Normal` to `Warning` mid-prompt. The subagent is
    ///   then cancelled and an explanatory error is returned.
    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
        let thread = self.subagent_thread.clone();
        let acp_thread = self.acp_thread.clone();
        let subagent_session_id = self.session_id.clone();
        let parent_thread = self.parent_thread.clone();

        cx.spawn(async move |cx| {
            // `_subscription` is bound (not discarded) so the token-usage subscription
            // stays alive until this future finishes awaiting `task`.
            let (task, _subscription) = cx.update(|cx| {
                // Snapshot the usage ratio before prompting, so only a transition caused
                // during this prompt triggers the context-window warning.
                let ratio_before_prompt = thread
                    .read(cx)
                    .latest_token_usage()
                    .map(|usage| usage.ratio());

                // Best-effort: if the parent is gone there is nothing to register with.
                parent_thread
                    .update(cx, |parent_thread, _cx| {
                        parent_thread.register_running_subagent(thread.downgrade())
                    })
                    .ok();

                let task = acp_thread.update(cx, |acp_thread, cx| {
                    acp_thread.send(vec![message.into()], cx)
                });

                // One-shot signal; the sender is taken on first use so it fires at most once.
                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
                let mut token_limit_tx = Some(token_limit_tx);

                let subscription = cx.subscribe(
                    &thread,
                    move |_thread, event: &TokenUsageUpdated, _cx| {
                        if let Some(usage) = &event.0 {
                            let old_ratio = ratio_before_prompt
                                .clone()
                                .unwrap_or(TokenUsageRatio::Normal);
                            let new_ratio = usage.ratio();
                            if old_ratio == TokenUsageRatio::Normal
                                && new_ratio == TokenUsageRatio::Warning
                            {
                                if let Some(tx) = token_limit_tx.take() {
                                    tx.send(()).ok();
                                }
                            }
                        }
                    },
                );

                // Race the prompt against the context-window signal; whichever resolves
                // first decides the outcome. NOTE(review): `_ =` also matches the receiver's
                // error case, i.e. the sender being dropped without sending.
                let wait_for_prompt = cx
                    .background_spawn(async move {
                        futures::select! {
                            response = task.fuse() => match response {
                                Ok(Some(response)) => {
                                    match response.stop_reason {
                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
                                    }
                                }
                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
                                Err(error) => SubagentPromptResult::Error(error.to_string()),
                            },
                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
                        }
                    });

                (wait_for_prompt, subscription)
            });

            let result = match task.await {
                // Join the text parts of the last agent message; a missing or empty reply
                // becomes an error.
                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
                    thread
                        .last_message()
                        .and_then(|message| {
                            let content = message.as_agent_message()?
                                .content
                                .iter()
                                .filter_map(|c| match c {
                                    AgentMessageContent::Text(text) => Some(text.as_str()),
                                    _ => None,
                                })
                                .join("\n\n");
                            if content.is_empty() {
                                None
                            } else {
                                Some( content)
                            }
                        })
                        .context("No response from subagent")
                }),
                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
                SubagentPromptResult::ContextWindowWarning => {
                    // Stop the subagent before reporting, so it does not keep consuming
                    // its remaining context.
                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
                    Err(anyhow!(
                        "The agent is nearing the end of its context window and has been \
                         stopped. You can prompt the thread again to have the agent wrap up \
                         or hand off its work."
                    ))
                }
            };

            // NOTE(review): this is skipped if the spawned task is dropped before reaching
            // here, which would leave the subagent registered as running on the parent.
            parent_thread
                .update(cx, |parent_thread, cx| {
                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
                })
                .ok();

            result
        })
    }
}
2076
/// A [`TerminalHandle`] backed by an ACP terminal entity.
///
/// Dropping the handle drops `_drop_tx`. That wakes the task spawned in
/// `NativeThreadEnvironment::create_terminal`, which then releases the terminal
/// from its ACP thread.
pub struct AcpTerminalHandle {
    terminal: Entity<acp_thread::Terminal>,
    // Not sent on in this file; held only so that dropping it signals release.
    _drop_tx: Option<oneshot::Sender<()>>,
}
2081
2082impl TerminalHandle for AcpTerminalHandle {
2083    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
2084        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
2085    }
2086
2087    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
2088        Ok(self
2089            .terminal
2090            .read_with(cx, |term, _cx| term.wait_for_exit()))
2091    }
2092
2093    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
2094        Ok(self
2095            .terminal
2096            .read_with(cx, |term, cx| term.current_output(cx)))
2097    }
2098
2099    fn kill(&self, cx: &AsyncApp) -> Result<()> {
2100        cx.update(|cx| {
2101            self.terminal.update(cx, |terminal, cx| {
2102                terminal.kill(cx);
2103            });
2104        });
2105        Ok(())
2106    }
2107
2108    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
2109        Ok(self
2110            .terminal
2111            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
2112    }
2113}
2114
2115#[cfg(test)]
2116mod internal_tests {
2117    use std::path::Path;
2118
2119    use super::*;
2120    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
2121    use fs::FakeFs;
2122    use gpui::TestAppContext;
2123    use indoc::formatdoc;
2124    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2125    use language_model::{
2126        LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
2127    };
2128    use serde_json::json;
2129    use settings::SettingsStore;
2130    use util::{path, rel_path::rel_path};
2131
    /// Verifies that the agent's per-project context and the session thread's view of
    /// it both track worktree additions and the creation of a `.rules` file.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));

        // Creating a session registers the project and triggers context building.
        let connection = NativeAgentConnection(agent.clone());
        let _acp_thread = cx
            .update(|cx| {
                Rc::new(connection).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/")]),
                    cx,
                )
            })
            .await
            .unwrap();
        cx.run_until_parked();

        let thread = agent.read_with(cx, |agent, _cx| {
            agent.sessions.values().next().unwrap().thread.clone()
        });

        // The project has no worktrees yet, so both contexts start out empty.
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            assert_eq!(state.project_context.read(cx).worktrees, vec![]);
            assert_eq!(thread.read(cx).project_context().read(cx).worktrees, vec![]);
        });

        // Adding a worktree shows up in both contexts, still without a rules file.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            let expected_worktrees = vec![WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file: None,
            }];
            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
            assert_eq!(
                thread.read(cx).project_context().read(cx).worktrees,
                expected_worktrees
            );
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            let expected_worktrees = vec![WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file: Some(RulesFileContext {
                    path_in_worktree: rel_path(".rules").into(),
                    text: "".into(),
                    project_entry_id: rules_entry.id.to_usize(),
                }),
            }];
            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
            assert_eq!(
                thread.read(cx).project_context().read(cx).worktrees,
                expected_worktrees
            );
        });
    }
2219
    /// Verifies that a new session's model selector lists the fake provider's single
    /// model, grouped under "Fake".
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let connection =
            NativeAgentConnection(cx.update(|cx| {
                NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)
            }));

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        // The selector is expected to return the grouped form.
        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(acp_thread::AgentModelIcon::Named(
                        ui::IconName::ZedAssistant
                    )),
                    is_latest: false,
                    cost: None,
                }]
            )])
        );
    }
2276
    /// Verifies that selecting a model through the session's model selector does two
    /// things: it updates the thread's model, and it writes the choice to the settings
    /// file. For a thinking model, that write includes `enable_thinking: true`.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed the settings with a different default model so the overwrite is observable.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));

        // Create the agent and connection
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let pending tasks (presumably including the settings-file write) finish
        // before reading the file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );

        // Register a thinking model and select it.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so it sees the newly registered provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // Verify enable_thinking was written to settings as true.
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
        assert_eq!(
            settings_json["agent"]["default_model"]["enable_thinking"],
            json!(true),
            "selecting a thinking model should persist enable_thinking: true to settings"
        );
    }
2388
2389    #[gpui::test]
2390    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2391        init_test(cx);
2392        let fs = FakeFs::new(cx.executor());
2393        fs.create_dir(paths::settings_file().parent().unwrap())
2394            .await
2395            .unwrap();
2396        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2397        let project = Project::test(fs.clone(), [], cx).await;
2398
2399        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2400        let agent =
2401            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2402        let connection = NativeAgentConnection(agent.clone());
2403
2404        let acp_thread = cx
2405            .update(|cx| {
2406                Rc::new(connection.clone()).new_session(
2407                    project.clone(),
2408                    PathList::new(&[Path::new("/a")]),
2409                    cx,
2410                )
2411            })
2412            .await
2413            .unwrap();
2414        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2415
2416        // Register a second provider with a thinking model.
2417        cx.update(|cx| {
2418            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2419                "fake-corp",
2420                "fake-thinking",
2421                "Fake Thinking",
2422                true,
2423            ));
2424            let thinking_provider = Arc::new(
2425                FakeLanguageModelProvider::new(
2426                    LanguageModelProviderId::from("fake-corp".to_string()),
2427                    LanguageModelProviderName::from("Fake Corp".to_string()),
2428                )
2429                .with_models(vec![thinking_model]),
2430            );
2431            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2432                registry.register_provider(thinking_provider, cx);
2433            });
2434        });
2435        // Refresh the agent's model list so it picks up the new provider.
2436        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2437
2438        // Thread starts with thinking_enabled = false (the default).
2439        agent.read_with(cx, |agent, _| {
2440            let session = agent.sessions.get(&session_id).unwrap();
2441            session.thread.read_with(cx, |thread, _| {
2442                assert!(!thread.thinking_enabled(), "thinking defaults to false");
2443            });
2444        });
2445
2446        // Select the thinking model via select_model.
2447        let selector = connection.model_selector(&session_id).unwrap();
2448        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2449            .await
2450            .unwrap();
2451
2452        // select_model should have enabled thinking based on the model's supports_thinking().
2453        agent.read_with(cx, |agent, _| {
2454            let session = agent.sessions.get(&session_id).unwrap();
2455            session.thread.read_with(cx, |thread, _| {
2456                assert!(
2457                    thread.thinking_enabled(),
2458                    "select_model should enable thinking when model supports it"
2459                );
2460            });
2461        });
2462
2463        // Switch back to the non-thinking model.
2464        let selector = connection.model_selector(&session_id).unwrap();
2465        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2466            .await
2467            .unwrap();
2468
2469        // select_model should have disabled thinking.
2470        agent.read_with(cx, |agent, _| {
2471            let session = agent.sessions.get(&session_id).unwrap();
2472            session.thread.read_with(cx, |thread, _| {
2473                assert!(
2474                    !thread.thinking_enabled(),
2475                    "select_model should disable thinking when model does not support it"
2476                );
2477            });
2478        });
2479    }
2480
2481    #[gpui::test]
2482    async fn test_summarization_model_survives_transient_registry_clearing(
2483        cx: &mut TestAppContext,
2484    ) {
2485        init_test(cx);
2486        let fs = FakeFs::new(cx.executor());
2487        fs.insert_tree("/", json!({ "a": {} })).await;
2488        let project = Project::test(fs.clone(), [], cx).await;
2489
2490        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2491        let agent =
2492            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2493        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2494
2495        let acp_thread = cx
2496            .update(|cx| {
2497                connection.clone().new_session(
2498                    project.clone(),
2499                    PathList::new(&[Path::new("/a")]),
2500                    cx,
2501                )
2502            })
2503            .await
2504            .unwrap();
2505        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2506
2507        let thread = agent.read_with(cx, |agent, _| {
2508            agent.sessions.get(&session_id).unwrap().thread.clone()
2509        });
2510
2511        thread.read_with(cx, |thread, _| {
2512            assert!(
2513                thread.summarization_model().is_some(),
2514                "session should have a summarization model from the test registry"
2515            );
2516        });
2517
2518        // Simulate what happens during a provider blip:
2519        // update_active_language_model_from_settings calls set_default_model(None)
2520        // when it can't resolve the model, clearing all fallbacks.
2521        cx.update(|cx| {
2522            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2523                registry.set_default_model(None, cx);
2524            });
2525        });
2526        cx.run_until_parked();
2527
2528        thread.read_with(cx, |thread, _| {
2529            assert!(
2530                thread.summarization_model().is_some(),
2531                "summarization model should survive a transient default model clearing"
2532            );
2533        });
2534    }
2535
    /// Verifies that `thinking_enabled` survives a save / close / reload round trip.
    ///
    /// A session selects a thinking-capable model, which turns thinking on. It then
    /// completes one turn so the thread is persisted, is closed, and is reopened with
    /// `open_thread`. The reloaded thread must still report thinking as enabled.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model.
        // The test keeps its own handle (`thinking_model.clone()` below) so it can
        // drive the fake model's completion stream later on.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Rebuild the agent's model list so the newly registered model is selectable.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        // The send is spawned so it stays in flight while the test feeds the fake
        // model's reply below.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Stream a short reply and end the stream so the turn completes.
        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the closed session's entities, then
        // confirm the agent no longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        // The reloaded AcpThread handle is held until here; presumably this keeps the
        // reloaded session alive while the checks above run.
        drop(reloaded_acp_thread);
    }
2638
    /// Verifies that a reloaded thread comes back with the model it was saved with.
    ///
    /// The fake model deliberately has an id (`custom-model-id`) that differs from its
    /// display name. The test persists a thread that uses it, closes and reopens the
    /// thread, and checks that the reloaded thread's model id is still `custom-model-id`.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        // The test keeps its own handle (`model.clone()` below) so it can drive the
        // fake model's completion stream later on.
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Rebuild the agent's model list so the newly registered model is selectable.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // The selector addresses the model as "<provider id>/<model id>".
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        // The send is spawned so it stays in flight while the test feeds the fake
        // model's reply below.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Stream a short reply and end the stream so the turn completes.
        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the closed session's entities, then
        // confirm the agent no longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        // The reloaded AcpThread handle is held until here; presumably this keeps the
        // reloaded session alive while the checks above run.
        drop(reloaded_acp_thread);
    }
2746
    /// End-to-end persistence test for a native thread.
    ///
    /// The test checks five things:
    /// - an empty thread is not written to the thread store, even after its models change;
    /// - after one turn, closing the session stores an entry whose title is the summary
    ///   model's output;
    /// - the reloaded thread renders the same markdown as before it was closed;
    /// - the draft prompt, token usage and UI scroll position round-trip through
    ///   save and reload;
    /// - the draft and scroll position are only captured if `close_session` itself saves.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        // Separate fake models are used for the conversation and for summarization,
        // so each stream can be driven independently below.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // The user message embeds a file mention so the markdown rendering of
        // resource links is covered as well.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Reply from the main model, including a usage update whose numbers are
        // checked again after the reload.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
            language_model::TokenUsage {
                input_tokens: 150,
                output_tokens: 75,
                ..Default::default()
            },
        ));
        model.end_last_completion_stream();
        cx.run_until_parked();
        // The summary model's output becomes the stored entry's title
        // (asserted via `thread_entries` further down).
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Set a draft prompt with rich content blocks and scroll position
        // AFTER run_until_parked, so the only save that captures these
        // changes is the one performed by close_session itself.
        let draft_blocks = vec![
            acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
            acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
            acp::ContentBlock::Text(acp::TextContent::new(" please")),
        ];
        acp_thread.update(cx, |thread, _cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()));
        });
        thread.update(cx, |thread, _cx| {
            thread.set_ui_scroll_position(Some(gpui::ListOffset {
                item_ix: 5,
                offset_in_item: gpui::px(12.5),
            }));
        });

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the closed session's entities, then
        // confirm the agent no longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // Ensure the draft prompt with rich content blocks survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
        });

        // Ensure token usage survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let usage = thread
                .token_usage()
                .expect("token usage should be restored after reload");
            assert_eq!(usage.input_tokens, 150);
            assert_eq!(usage.output_tokens, 75);
        });

        // Ensure scroll position survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let scroll = thread
                .ui_scroll_position()
                .expect("scroll position should be restored after reload");
            assert_eq!(scroll.item_ix, 5);
            assert_eq!(scroll.offset_in_item, gpui::px(12.5));
        });
    }
2928
    /// Verifies that `close_session` itself persists the thread.
    ///
    /// A draft prompt is set immediately before closing, with no `run_until_parked` in
    /// between, so no observer-triggered save can have captured it. After reopening the
    /// thread, the draft must still be present.
    #[gpui::test]
    async fn test_close_session_saves_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "file.txt": "hello"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Use a fake model whose completion stream the test drives by hand.
        let model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
        });

        // Send a message so the thread is non-empty (empty threads aren't saved).
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Finish the turn with a short reply.
        model.send_last_completion_stream_text_chunk("world");
        model.end_last_completion_stream();
        send.await.unwrap();
        cx.run_until_parked();

        // Set a draft prompt WITHOUT calling run_until_parked afterwards.
        // This means no observe-triggered save has run for this change.
        // The only way this data gets persisted is if close_session
        // itself performs the save.
        let draft_blocks = vec![acp::ContentBlock::Text(acp::TextContent::new(
            "unsaved draft",
        ))];
        acp_thread.update(cx, |thread, _cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()));
        });

        // Close the session immediately — no run_until_parked in between.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // Reopen and verify the draft prompt was saved.
        let reloaded = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        reloaded.read_with(cx, |thread, _| {
            assert_eq!(
                thread.draft_prompt(),
                Some(draft_blocks.as_slice()),
                "close_session must save the thread; draft prompt was lost"
            );
        });
    }
3009
3010    #[gpui::test]
3011    async fn test_rapid_title_changes_do_not_loop(cx: &mut TestAppContext) {
3012        // Regression test: rapid title changes must not cause a propagation loop
3013        // between Thread and AcpThread via handle_thread_title_updated.
3014        init_test(cx);
3015        let fs = FakeFs::new(cx.executor());
3016        fs.insert_tree("/", json!({ "a": {} })).await;
3017        let project = Project::test(fs.clone(), [], cx).await;
3018        let thread_store = cx.new(|cx| ThreadStore::new(cx));
3019        let agent = cx.update(|cx| {
3020            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
3021        });
3022        let connection = Rc::new(NativeAgentConnection(agent.clone()));
3023
3024        let acp_thread = cx
3025            .update(|cx| {
3026                connection
3027                    .clone()
3028                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
3029            })
3030            .await
3031            .unwrap();
3032
3033        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3034        let thread = agent.read_with(cx, |agent, _| {
3035            agent.sessions.get(&session_id).unwrap().thread.clone()
3036        });
3037
3038        let title_updated_count = Rc::new(std::cell::RefCell::new(0usize));
3039        cx.update(|cx| {
3040            let count = title_updated_count.clone();
3041            cx.subscribe(
3042                &thread,
3043                move |_entity: Entity<Thread>, _event: &TitleUpdated, _cx: &mut App| {
3044                    let new_count = {
3045                        let mut count = count.borrow_mut();
3046                        *count += 1;
3047                        *count
3048                    };
3049                    assert!(
3050                        new_count <= 2,
3051                        "TitleUpdated fired {new_count} times; \
3052                         title updates are looping"
3053                    );
3054                },
3055            )
3056            .detach();
3057        });
3058
3059        thread.update(cx, |thread, cx| thread.set_title("first".into(), cx));
3060        thread.update(cx, |thread, cx| thread.set_title("second".into(), cx));
3061
3062        cx.run_until_parked();
3063
3064        thread.read_with(cx, |thread, _| {
3065            assert_eq!(thread.title(), Some("second".into()));
3066        });
3067        acp_thread.read_with(cx, |acp_thread, _| {
3068            assert_eq!(acp_thread.title(), Some("second".into()));
3069        });
3070
3071        assert_eq!(*title_updated_count.borrow(), 2);
3072    }
3073
3074    fn thread_entries(
3075        thread_store: &Entity<ThreadStore>,
3076        cx: &mut TestAppContext,
3077    ) -> Vec<(acp::SessionId, String)> {
3078        thread_store.read_with(cx, |store, _| {
3079            store
3080                .entries()
3081                .map(|entry| (entry.id.clone(), entry.title.to_string()))
3082                .collect::<Vec<_>>()
3083        })
3084    }
3085
3086    fn init_test(cx: &mut TestAppContext) {
3087        env_logger::try_init().ok();
3088        cx.update(|cx| {
3089            let settings_store = SettingsStore::test(cx);
3090            cx.set_global(settings_store);
3091
3092            LanguageModelRegistry::test(cx);
3093        });
3094    }
3095}
3096
3097fn mcp_message_content_to_acp_content_block(
3098    content: context_server::types::MessageContent,
3099) -> acp::ContentBlock {
3100    match content {
3101        context_server::types::MessageContent::Text {
3102            text,
3103            annotations: _,
3104        } => text.into(),
3105        context_server::types::MessageContent::Image {
3106            data,
3107            mime_type,
3108            annotations: _,
3109        } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
3110        context_server::types::MessageContent::Audio {
3111            data,
3112            mime_type,
3113            annotations: _,
3114        } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
3115        context_server::types::MessageContent::Resource {
3116            resource,
3117            annotations: _,
3118        } => {
3119            let mut link =
3120                acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
3121            if let Some(mime_type) = resource.mime_type {
3122                link = link.mime_type(mime_type);
3123            }
3124            acp::ContentBlock::ResourceLink(link)
3125        }
3126    }
3127}