agent.rs

   1mod db;
   2mod edit_agent;
   3mod legacy_thread;
   4mod native_agent_server;
   5pub mod outline;
   6mod pattern_extraction;
   7mod templates;
   8#[cfg(test)]
   9mod tests;
  10mod thread;
  11mod thread_store;
  12mod tool_permissions;
  13mod tools;
  14
  15use context_server::ContextServerId;
  16pub use db::*;
  17use itertools::Itertools;
  18pub use native_agent_server::NativeAgentServer;
  19pub use pattern_extraction::*;
  20pub use shell_command_parser::extract_commands;
  21pub use templates::*;
  22pub use thread::*;
  23pub use thread_store::*;
  24pub use tool_permissions::*;
  25pub use tools::*;
  26
  27use acp_thread::{
  28    AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
  29    AgentSessionListResponse, ThreadId, TokenUsageRatio, UserMessageId,
  30};
  31use agent_client_protocol as acp;
  32use anyhow::{Context as _, Result, anyhow};
  33use chrono::{DateTime, Utc};
  34use collections::{HashMap, HashSet, IndexMap};
  35use fs::Fs;
  36use futures::channel::{mpsc, oneshot};
  37use futures::future::Shared;
  38use futures::{FutureExt as _, StreamExt as _, future};
  39use gpui::{
  40    App, AppContext, AsyncApp, Context, Entity, EntityId, SharedString, Subscription, Task,
  41    WeakEntity,
  42};
  43use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  44use project::{AgentId, Project, ProjectItem, ProjectPath, Worktree};
  45use prompt_store::{
  46    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  47    WorktreeContext,
  48};
  49use serde::{Deserialize, Serialize};
  50use settings::{LanguageModelSelection, update_settings_file};
  51use std::any::Any;
  52use std::path::PathBuf;
  53use std::rc::Rc;
  54use std::sync::{Arc, LazyLock};
  55use util::ResultExt;
  56use util::path_list::PathList;
  57use util::rel_path::RelPath;
  58
  59#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
  60pub struct ProjectSnapshot {
  61    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
  62    pub timestamp: DateTime<Utc>,
  63}
  64
  65pub struct RulesLoadingError {
  66    pub message: SharedString,
  67}
  68
  69struct ProjectState {
  70    project: Entity<Project>,
  71    project_context: Entity<ProjectContext>,
  72    project_context_needs_refresh: watch::Sender<()>,
  73    _maintain_project_context: Task<Result<()>>,
  74    context_server_registry: Entity<ContextServerRegistry>,
  75    _subscriptions: Vec<Subscription>,
  76}
  77
  78/// Holds both the internal Thread and the AcpThread for a session
  79struct Session {
  80    /// The internal thread that processes messages
  81    thread: Entity<Thread>,
  82    /// The ACP thread that handles protocol communication
  83    acp_thread: Entity<acp_thread::AcpThread>,
  84    project_id: EntityId,
  85    pending_save: Task<Result<()>>,
  86    _subscriptions: Vec<Subscription>,
  87}
  88
  89pub struct LanguageModels {
  90    /// Access language model by ID
  91    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
  92    /// Cached list for returning language model information
  93    model_list: acp_thread::AgentModelList,
  94    refresh_models_rx: watch::Receiver<()>,
  95    refresh_models_tx: watch::Sender<()>,
  96    _authenticate_all_providers_task: Task<()>,
  97}
  98
  99impl LanguageModels {
 100    fn new(cx: &mut App) -> Self {
 101        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
 102
 103        let mut this = Self {
 104            models: HashMap::default(),
 105            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
 106            refresh_models_rx,
 107            refresh_models_tx,
 108            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
 109        };
 110        this.refresh_list(cx);
 111        this
 112    }
 113
 114    fn refresh_list(&mut self, cx: &App) {
 115        let providers = LanguageModelRegistry::global(cx)
 116            .read(cx)
 117            .visible_providers()
 118            .into_iter()
 119            .filter(|provider| provider.is_authenticated(cx))
 120            .collect::<Vec<_>>();
 121
 122        let mut language_model_list = IndexMap::default();
 123        let mut recommended_models = HashSet::default();
 124
 125        let mut recommended = Vec::new();
 126        for provider in &providers {
 127            for model in provider.recommended_models(cx) {
 128                recommended_models.insert((model.provider_id(), model.id()));
 129                recommended.push(Self::map_language_model_to_info(&model, provider));
 130            }
 131        }
 132        if !recommended.is_empty() {
 133            language_model_list.insert(
 134                acp_thread::AgentModelGroupName("Recommended".into()),
 135                recommended,
 136            );
 137        }
 138
 139        let mut models = HashMap::default();
 140        for provider in providers {
 141            let mut provider_models = Vec::new();
 142            for model in provider.provided_models(cx) {
 143                let model_info = Self::map_language_model_to_info(&model, &provider);
 144                let model_id = model_info.id.clone();
 145                provider_models.push(model_info);
 146                models.insert(model_id, model);
 147            }
 148            if !provider_models.is_empty() {
 149                language_model_list.insert(
 150                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 151                    provider_models,
 152                );
 153            }
 154        }
 155
 156        self.models = models;
 157        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 158        self.refresh_models_tx.send(()).ok();
 159    }
 160
 161    fn watch(&self) -> watch::Receiver<()> {
 162        self.refresh_models_rx.clone()
 163    }
 164
 165    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 166        self.models.get(model_id).cloned()
 167    }
 168
 169    fn map_language_model_to_info(
 170        model: &Arc<dyn LanguageModel>,
 171        provider: &Arc<dyn LanguageModelProvider>,
 172    ) -> acp_thread::AgentModelInfo {
 173        acp_thread::AgentModelInfo {
 174            id: Self::model_id(model),
 175            name: model.name().0,
 176            description: None,
 177            icon: Some(match provider.icon() {
 178                IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
 179                IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
 180            }),
 181            is_latest: model.is_latest(),
 182            cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
 183        }
 184    }
 185
 186    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 187        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 188    }
 189
 190    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 191        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 192            .read(cx)
 193            .visible_providers()
 194            .iter()
 195            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 196            .collect::<Vec<_>>();
 197
 198        cx.background_spawn(async move {
 199            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 200                if let Err(err) = authenticate_task.await {
 201                    match err {
 202                        language_model::AuthenticateError::CredentialsNotFound => {
 203                            // Since we're authenticating these providers in the
 204                            // background for the purposes of populating the
 205                            // language selector, we don't care about providers
 206                            // where the credentials are not found.
 207                        }
 208                        language_model::AuthenticateError::ConnectionRefused => {
 209                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 210                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 211                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 212                        }
 213                        _ => {
 214                            // Some providers have noisy failure states that we
 215                            // don't want to spam the logs with every time the
 216                            // language model selector is initialized.
 217                            //
 218                            // Ideally these should have more clear failure modes
 219                            // that we know are safe to ignore here, like what we do
 220                            // with `CredentialsNotFound` above.
 221                            match provider_id.0.as_ref() {
 222                                "lmstudio" | "ollama" => {
 223                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 224                                    //
 225                                    // These fail noisily, so we don't log them.
 226                                }
 227                                "copilot_chat" => {
 228                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 229                                }
 230                                _ => {
 231                                    log::error!(
 232                                        "Failed to authenticate provider: {}: {err:#}",
 233                                        provider_name.0
 234                                    );
 235                                }
 236                            }
 237                        }
 238                    }
 239                }
 240            }
 241        })
 242    }
 243}
 244
 245pub struct NativeAgent {
 246    /// Session ID -> Session mapping
 247    sessions: HashMap<acp::SessionId, Session>,
 248    thread_store: Entity<ThreadStore>,
 249    /// Project-specific state keyed by project EntityId
 250    projects: HashMap<EntityId, ProjectState>,
 251    /// Shared templates for all threads
 252    templates: Arc<Templates>,
 253    /// Cached model information
 254    models: LanguageModels,
 255    prompt_store: Option<Entity<PromptStore>>,
 256    fs: Arc<dyn Fs>,
 257    _subscriptions: Vec<Subscription>,
 258}
 259
 260impl NativeAgent {
 261    pub fn new(
 262        thread_store: Entity<ThreadStore>,
 263        templates: Arc<Templates>,
 264        prompt_store: Option<Entity<PromptStore>>,
 265        fs: Arc<dyn Fs>,
 266        cx: &mut App,
 267    ) -> Entity<NativeAgent> {
 268        log::debug!("Creating new NativeAgent");
 269
 270        cx.new(|cx| {
 271            let mut subscriptions = vec![cx.subscribe(
 272                &LanguageModelRegistry::global(cx),
 273                Self::handle_models_updated_event,
 274            )];
 275            if let Some(prompt_store) = prompt_store.as_ref() {
 276                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 277            }
 278
 279            Self {
 280                sessions: HashMap::default(),
 281                thread_store,
 282                projects: HashMap::default(),
 283                templates,
 284                models: LanguageModels::new(cx),
 285                prompt_store,
 286                fs,
 287                _subscriptions: subscriptions,
 288            }
 289        })
 290    }
 291
 292    fn new_session(
 293        &mut self,
 294        project: Entity<Project>,
 295        cx: &mut Context<Self>,
 296    ) -> Entity<AcpThread> {
 297        let project_id = self.get_or_create_project_state(&project, cx);
 298        let project_state = &self.projects[&project_id];
 299
 300        let registry = LanguageModelRegistry::read_global(cx);
 301        let available_count = registry.available_models(cx).count();
 302        log::debug!("Total available models: {}", available_count);
 303
 304        let default_model = registry.default_model().and_then(|default_model| {
 305            self.models
 306                .model_from_id(&LanguageModels::model_id(&default_model.model))
 307        });
 308        let thread = cx.new(|cx| {
 309            Thread::new(
 310                project,
 311                project_state.project_context.clone(),
 312                project_state.context_server_registry.clone(),
 313                self.templates.clone(),
 314                default_model,
 315                cx,
 316            )
 317        });
 318
 319        self.register_session(thread, project_id, cx)
 320    }
 321
 322    fn register_session(
 323        &mut self,
 324        thread_handle: Entity<Thread>,
 325        project_id: EntityId,
 326        cx: &mut Context<Self>,
 327    ) -> Entity<AcpThread> {
 328        let connection = Rc::new(NativeAgentConnection(cx.entity()));
 329
 330        let thread = thread_handle.read(cx);
 331        let thread_id = thread.id();
 332        let session_id = thread.session_id().clone();
 333        let parent_thread_id = thread.parent_thread_id();
 334        let title = thread.title();
 335        let draft_prompt = thread.draft_prompt().map(Vec::from);
 336        let scroll_position = thread.ui_scroll_position();
 337        let token_usage = thread.latest_token_usage();
 338        let project = thread.project.clone();
 339        let action_log = thread.action_log.clone();
 340        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
 341        let acp_thread = cx.new(|cx| {
 342            let mut acp_thread = acp_thread::AcpThread::new(
 343                parent_thread_id,
 344                title,
 345                None,
 346                connection,
 347                project.clone(),
 348                action_log.clone(),
 349                thread_id,
 350                session_id.clone(),
 351                prompt_capabilities_rx,
 352                cx,
 353            );
 354            acp_thread.set_draft_prompt(draft_prompt);
 355            acp_thread.set_ui_scroll_position(scroll_position);
 356            acp_thread.update_token_usage(token_usage, cx);
 357            acp_thread
 358        });
 359
 360        let registry = LanguageModelRegistry::read_global(cx);
 361        let summarization_model = registry.thread_summary_model().map(|c| c.model);
 362
 363        let weak = cx.weak_entity();
 364        let weak_thread = thread_handle.downgrade();
 365        thread_handle.update(cx, |thread, cx| {
 366            thread.set_summarization_model(summarization_model, cx);
 367            thread.add_default_tools(
 368                Rc::new(NativeThreadEnvironment {
 369                    acp_thread: acp_thread.downgrade(),
 370                    thread: weak_thread,
 371                    agent: weak,
 372                }) as _,
 373                cx,
 374            )
 375        });
 376
 377        let subscriptions = vec![
 378            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
 379            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
 380            cx.observe(&thread_handle, move |this, thread, cx| {
 381                this.save_thread(thread, cx)
 382            }),
 383        ];
 384
 385        self.sessions.insert(
 386            session_id,
 387            Session {
 388                thread: thread_handle,
 389                acp_thread: acp_thread.clone(),
 390                project_id,
 391                _subscriptions: subscriptions,
 392                pending_save: Task::ready(Ok(())),
 393            },
 394        );
 395
 396        self.update_available_commands_for_project(project_id, cx);
 397
 398        acp_thread
 399    }
 400
 401    pub fn models(&self) -> &LanguageModels {
 402        &self.models
 403    }
 404
 405    fn get_or_create_project_state(
 406        &mut self,
 407        project: &Entity<Project>,
 408        cx: &mut Context<Self>,
 409    ) -> EntityId {
 410        let project_id = project.entity_id();
 411        if self.projects.contains_key(&project_id) {
 412            return project_id;
 413        }
 414
 415        let project_context = cx.new(|_| ProjectContext::new(vec![], vec![]));
 416        self.register_project_with_initial_context(project.clone(), project_context, cx);
 417        if let Some(state) = self.projects.get_mut(&project_id) {
 418            state.project_context_needs_refresh.send(()).ok();
 419        }
 420        project_id
 421    }
 422
 423    fn register_project_with_initial_context(
 424        &mut self,
 425        project: Entity<Project>,
 426        project_context: Entity<ProjectContext>,
 427        cx: &mut Context<Self>,
 428    ) {
 429        let project_id = project.entity_id();
 430
 431        let context_server_store = project.read(cx).context_server_store();
 432        let context_server_registry =
 433            cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
 434
 435        let subscriptions = vec![
 436            cx.subscribe(&project, Self::handle_project_event),
 437            cx.subscribe(
 438                &context_server_store,
 439                Self::handle_context_server_store_updated,
 440            ),
 441            cx.subscribe(
 442                &context_server_registry,
 443                Self::handle_context_server_registry_event,
 444            ),
 445        ];
 446
 447        let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 448            watch::channel(());
 449
 450        self.projects.insert(
 451            project_id,
 452            ProjectState {
 453                project,
 454                project_context,
 455                project_context_needs_refresh: project_context_needs_refresh_tx,
 456                _maintain_project_context: cx.spawn(async move |this, cx| {
 457                    Self::maintain_project_context(
 458                        this,
 459                        project_id,
 460                        project_context_needs_refresh_rx,
 461                        cx,
 462                    )
 463                    .await
 464                }),
 465                context_server_registry,
 466                _subscriptions: subscriptions,
 467            },
 468        );
 469    }
 470
 471    fn session_project_state(&self, session_id: &acp::SessionId) -> Option<&ProjectState> {
 472        self.sessions
 473            .get(session_id)
 474            .and_then(|session| self.projects.get(&session.project_id))
 475    }
 476
 477    async fn maintain_project_context(
 478        this: WeakEntity<Self>,
 479        project_id: EntityId,
 480        mut needs_refresh: watch::Receiver<()>,
 481        cx: &mut AsyncApp,
 482    ) -> Result<()> {
 483        while needs_refresh.changed().await.is_ok() {
 484            let project_context = this
 485                .update(cx, |this, cx| {
 486                    let state = this
 487                        .projects
 488                        .get(&project_id)
 489                        .context("project state not found")?;
 490                    anyhow::Ok(Self::build_project_context(
 491                        &state.project,
 492                        this.prompt_store.as_ref(),
 493                        cx,
 494                    ))
 495                })??
 496                .await;
 497            this.update(cx, |this, cx| {
 498                if let Some(state) = this.projects.get(&project_id) {
 499                    state
 500                        .project_context
 501                        .update(cx, |current_project_context, _cx| {
 502                            *current_project_context = project_context;
 503                        });
 504                }
 505            })?;
 506        }
 507
 508        Ok(())
 509    }
 510
 511    fn build_project_context(
 512        project: &Entity<Project>,
 513        prompt_store: Option<&Entity<PromptStore>>,
 514        cx: &mut App,
 515    ) -> Task<ProjectContext> {
 516        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 517        let worktree_tasks = worktrees
 518            .into_iter()
 519            .map(|worktree| {
 520                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 521            })
 522            .collect::<Vec<_>>();
 523        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 524            prompt_store.read_with(cx, |prompt_store, cx| {
 525                let prompts = prompt_store.default_prompt_metadata();
 526                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 527                    let contents = prompt_store.load(prompt_metadata.id, cx);
 528                    async move { (contents.await, prompt_metadata) }
 529                });
 530                cx.background_spawn(future::join_all(load_tasks))
 531            })
 532        } else {
 533            Task::ready(vec![])
 534        };
 535
 536        cx.spawn(async move |_cx| {
 537            let (worktrees, default_user_rules) =
 538                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 539
 540            let worktrees = worktrees
 541                .into_iter()
 542                .map(|(worktree, _rules_error)| {
 543                    // TODO: show error message
 544                    // if let Some(rules_error) = rules_error {
 545                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 546                    // }
 547                    worktree
 548                })
 549                .collect::<Vec<_>>();
 550
 551            let default_user_rules = default_user_rules
 552                .into_iter()
 553                .flat_map(|(contents, prompt_metadata)| match contents {
 554                    Ok(contents) => Some(UserRulesContext {
 555                        uuid: prompt_metadata.id.as_user()?,
 556                        title: prompt_metadata.title.map(|title| title.to_string()),
 557                        contents,
 558                    }),
 559                    Err(_err) => {
 560                        // TODO: show error message
 561                        // this.update(cx, |_, cx| {
 562                        //     cx.emit(RulesLoadingError {
 563                        //         message: format!("{err:?}").into(),
 564                        //     });
 565                        // })
 566                        // .ok();
 567                        None
 568                    }
 569                })
 570                .collect::<Vec<_>>();
 571
 572            ProjectContext::new(worktrees, default_user_rules)
 573        })
 574    }
 575
 576    fn load_worktree_info_for_system_prompt(
 577        worktree: Entity<Worktree>,
 578        project: Entity<Project>,
 579        cx: &mut App,
 580    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 581        let tree = worktree.read(cx);
 582        let root_name = tree.root_name_str().into();
 583        let abs_path = tree.abs_path();
 584
 585        let mut context = WorktreeContext {
 586            root_name,
 587            abs_path,
 588            rules_file: None,
 589        };
 590
 591        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 592        let Some(rules_task) = rules_task else {
 593            return Task::ready((context, None));
 594        };
 595
 596        cx.spawn(async move |_| {
 597            let (rules_file, rules_file_error) = match rules_task.await {
 598                Ok(rules_file) => (Some(rules_file), None),
 599                Err(err) => (
 600                    None,
 601                    Some(RulesLoadingError {
 602                        message: format!("{err}").into(),
 603                    }),
 604                ),
 605            };
 606            context.rules_file = rules_file;
 607            (context, rules_file_error)
 608        })
 609    }
 610
 611    fn load_worktree_rules_file(
 612        worktree: Entity<Worktree>,
 613        project: Entity<Project>,
 614        cx: &mut App,
 615    ) -> Option<Task<Result<RulesFileContext>>> {
 616        let worktree = worktree.read(cx);
 617        let worktree_id = worktree.id();
 618        let selected_rules_file = RULES_FILE_NAMES
 619            .into_iter()
 620            .filter_map(|name| {
 621                worktree
 622                    .entry_for_path(RelPath::unix(name).unwrap())
 623                    .filter(|entry| entry.is_file())
 624                    .map(|entry| entry.path.clone())
 625            })
 626            .next();
 627
 628        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 629        // supported. This doesn't seem to occur often in GitHub repositories.
 630        selected_rules_file.map(|path_in_worktree| {
 631            let project_path = ProjectPath {
 632                worktree_id,
 633                path: path_in_worktree.clone(),
 634            };
 635            let buffer_task =
 636                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 637            let rope_task = cx.spawn(async move |cx| {
 638                let buffer = buffer_task.await?;
 639                let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
 640                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 641                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 642                })?;
 643                anyhow::Ok((project_entry_id, rope))
 644            });
 645            // Build a string from the rope on a background thread.
 646            cx.background_spawn(async move {
 647                let (project_entry_id, rope) = rope_task.await?;
 648                anyhow::Ok(RulesFileContext {
 649                    path_in_worktree,
 650                    text: rope.to_string().trim().to_string(),
 651                    project_entry_id: project_entry_id.to_usize(),
 652                })
 653            })
 654        })
 655    }
 656
 657    fn handle_thread_title_updated(
 658        &mut self,
 659        thread: Entity<Thread>,
 660        _: &TitleUpdated,
 661        cx: &mut Context<Self>,
 662    ) {
 663        let session_id = thread.read(cx).session_id();
 664        let Some(session) = self.sessions.get(session_id) else {
 665            return;
 666        };
 667
 668        let thread = thread.downgrade();
 669        let acp_thread = session.acp_thread.downgrade();
 670        cx.spawn(async move |_, cx| {
 671            let title = thread.read_with(cx, |thread, _| thread.title())?;
 672            if let Some(title) = title {
 673                let task =
 674                    acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 675                task.await?;
 676            }
 677            anyhow::Ok(())
 678        })
 679        .detach_and_log_err(cx);
 680    }
 681
 682    fn handle_thread_token_usage_updated(
 683        &mut self,
 684        thread: Entity<Thread>,
 685        usage: &TokenUsageUpdated,
 686        cx: &mut Context<Self>,
 687    ) {
 688        let Some(session) = self.sessions.get(thread.read(cx).session_id()) else {
 689            return;
 690        };
 691        session.acp_thread.update(cx, |acp_thread, cx| {
 692            acp_thread.update_token_usage(usage.0.clone(), cx);
 693        });
 694    }
 695
 696    fn handle_project_event(
 697        &mut self,
 698        project: Entity<Project>,
 699        event: &project::Event,
 700        _cx: &mut Context<Self>,
 701    ) {
 702        let project_id = project.entity_id();
 703        let Some(state) = self.projects.get_mut(&project_id) else {
 704            return;
 705        };
 706        match event {
 707            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 708                state.project_context_needs_refresh.send(()).ok();
 709            }
 710            project::Event::WorktreeUpdatedEntries(_, items) => {
 711                if items.iter().any(|(path, _, _)| {
 712                    RULES_FILE_NAMES
 713                        .iter()
 714                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 715                }) {
 716                    state.project_context_needs_refresh.send(()).ok();
 717                }
 718            }
 719            _ => {}
 720        }
 721    }
 722
 723    fn handle_prompts_updated_event(
 724        &mut self,
 725        _prompt_store: Entity<PromptStore>,
 726        _event: &prompt_store::PromptsUpdatedEvent,
 727        _cx: &mut Context<Self>,
 728    ) {
 729        for state in self.projects.values_mut() {
 730            state.project_context_needs_refresh.send(()).ok();
 731        }
 732    }
 733
 734    fn handle_models_updated_event(
 735        &mut self,
 736        _registry: Entity<LanguageModelRegistry>,
 737        event: &language_model::Event,
 738        cx: &mut Context<Self>,
 739    ) {
 740        self.models.refresh_list(cx);
 741
 742        let registry = LanguageModelRegistry::read_global(cx);
 743        let default_model = registry.default_model().map(|m| m.model);
 744        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 745
 746        for session in self.sessions.values_mut() {
 747            session.thread.update(cx, |thread, cx| {
 748                if thread.model().is_none()
 749                    && let Some(model) = default_model.clone()
 750                {
 751                    thread.set_model(model, cx);
 752                    cx.notify();
 753                }
 754                if let Some(model) = summarization_model.clone() {
 755                    if thread.summarization_model().is_none()
 756                        || matches!(event, language_model::Event::ThreadSummaryModelChanged)
 757                    {
 758                        thread.set_summarization_model(Some(model), cx);
 759                    }
 760                }
 761            });
 762        }
 763    }
 764
 765    fn handle_context_server_store_updated(
 766        &mut self,
 767        store: Entity<project::context_server_store::ContextServerStore>,
 768        _event: &project::context_server_store::ServerStatusChangedEvent,
 769        cx: &mut Context<Self>,
 770    ) {
 771        let project_id = self.projects.iter().find_map(|(id, state)| {
 772            if *state.context_server_registry.read(cx).server_store() == store {
 773                Some(*id)
 774            } else {
 775                None
 776            }
 777        });
 778        if let Some(project_id) = project_id {
 779            self.update_available_commands_for_project(project_id, cx);
 780        }
 781    }
 782
 783    fn handle_context_server_registry_event(
 784        &mut self,
 785        registry: Entity<ContextServerRegistry>,
 786        event: &ContextServerRegistryEvent,
 787        cx: &mut Context<Self>,
 788    ) {
 789        match event {
 790            ContextServerRegistryEvent::ToolsChanged => {}
 791            ContextServerRegistryEvent::PromptsChanged => {
 792                let project_id = self.projects.iter().find_map(|(id, state)| {
 793                    if state.context_server_registry == registry {
 794                        Some(*id)
 795                    } else {
 796                        None
 797                    }
 798                });
 799                if let Some(project_id) = project_id {
 800                    self.update_available_commands_for_project(project_id, cx);
 801                }
 802            }
 803        }
 804    }
 805
 806    fn update_available_commands_for_project(&self, project_id: EntityId, cx: &mut Context<Self>) {
 807        let available_commands =
 808            Self::build_available_commands_for_project(self.projects.get(&project_id), cx);
 809        for session in self.sessions.values() {
 810            if session.project_id != project_id {
 811                continue;
 812            }
 813            session.acp_thread.update(cx, |thread, cx| {
 814                thread
 815                    .handle_session_update(
 816                        acp::SessionUpdate::AvailableCommandsUpdate(
 817                            acp::AvailableCommandsUpdate::new(available_commands.clone()),
 818                        ),
 819                        cx,
 820                    )
 821                    .log_err();
 822            });
 823        }
 824    }
 825
 826    fn build_available_commands_for_project(
 827        project_state: Option<&ProjectState>,
 828        cx: &App,
 829    ) -> Vec<acp::AvailableCommand> {
 830        let Some(state) = project_state else {
 831            return vec![];
 832        };
 833        let registry = state.context_server_registry.read(cx);
 834
 835        let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
 836        for context_server_prompt in registry.prompts() {
 837            *prompt_name_counts
 838                .entry(context_server_prompt.prompt.name.as_str())
 839                .or_insert(0) += 1;
 840        }
 841
 842        registry
 843            .prompts()
 844            .flat_map(|context_server_prompt| {
 845                let prompt = &context_server_prompt.prompt;
 846
 847                let should_prefix = prompt_name_counts
 848                    .get(prompt.name.as_str())
 849                    .copied()
 850                    .unwrap_or(0)
 851                    > 1;
 852
 853                let name = if should_prefix {
 854                    format!("{}.{}", context_server_prompt.server_id, prompt.name)
 855                } else {
 856                    prompt.name.clone()
 857                };
 858
 859                let mut command = acp::AvailableCommand::new(
 860                    name,
 861                    prompt.description.clone().unwrap_or_default(),
 862                );
 863
 864                match prompt.arguments.as_deref() {
 865                    Some([arg]) => {
 866                        let hint = format!("<{}>", arg.name);
 867
 868                        command = command.input(acp::AvailableCommandInput::Unstructured(
 869                            acp::UnstructuredCommandInput::new(hint),
 870                        ));
 871                    }
 872                    Some([]) | None => {}
 873                    Some(_) => {
 874                        // skip >1 argument commands since we don't support them yet
 875                        return None;
 876                    }
 877                }
 878
 879                Some(command)
 880            })
 881            .collect()
 882    }
 883
 884    pub fn load_thread(
 885        &mut self,
 886        id: acp::SessionId,
 887        project: Entity<Project>,
 888        cx: &mut Context<Self>,
 889    ) -> Task<Result<Entity<Thread>>> {
 890        let database_future = ThreadsDatabase::connect(cx);
 891        cx.spawn(async move |this, cx| {
 892            let database = database_future.await.map_err(|err| anyhow!(err))?;
 893            let db_thread = database
 894                .load_thread(id.clone())
 895                .await?
 896                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 897
 898            this.update(cx, |this, cx| {
 899                let project_id = this.get_or_create_project_state(&project, cx);
 900                let project_state = this
 901                    .projects
 902                    .get(&project_id)
 903                    .context("project state not found")?;
 904                let summarization_model = LanguageModelRegistry::read_global(cx)
 905                    .thread_summary_model()
 906                    .map(|c| c.model);
 907
 908                Ok(cx.new(|cx| {
 909                    let mut thread = Thread::from_db(
 910                        ThreadId::new(),
 911                        id.clone(),
 912                        db_thread,
 913                        project_state.project.clone(),
 914                        project_state.project_context.clone(),
 915                        project_state.context_server_registry.clone(),
 916                        this.templates.clone(),
 917                        cx,
 918                    );
 919                    thread.set_summarization_model(summarization_model, cx);
 920                    thread
 921                }))
 922            })?
 923        })
 924    }
 925
 926    pub fn open_thread(
 927        &mut self,
 928        id: acp::SessionId,
 929        project: Entity<Project>,
 930        cx: &mut Context<Self>,
 931    ) -> Task<Result<Entity<AcpThread>>> {
 932        if let Some(session) = self.sessions.get(&id) {
 933            return Task::ready(Ok(session.acp_thread.clone()));
 934        }
 935
 936        let task = self.load_thread(id, project.clone(), cx);
 937        cx.spawn(async move |this, cx| {
 938            let thread = task.await?;
 939            let acp_thread = this.update(cx, |this, cx| {
 940                let project_id = this.get_or_create_project_state(&project, cx);
 941                this.register_session(thread.clone(), project_id, cx)
 942            })?;
 943            let events = thread.update(cx, |thread, cx| thread.replay(cx));
 944            cx.update(|cx| {
 945                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 946            })
 947            .await?;
 948            acp_thread.update(cx, |thread, cx| {
 949                thread.snapshot_completed_plan(cx);
 950            });
 951            Ok(acp_thread)
 952        })
 953    }
 954
 955    pub fn thread_summary(
 956        &mut self,
 957        id: acp::SessionId,
 958        project: Entity<Project>,
 959        cx: &mut Context<Self>,
 960    ) -> Task<Result<SharedString>> {
 961        let thread = self.open_thread(id.clone(), project, cx);
 962        cx.spawn(async move |this, cx| {
 963            let acp_thread = thread.await?;
 964            let summary_task = this.update(cx, |this, cx| {
 965                let session = this.sessions.get(&id).context("session not found")?;
 966                anyhow::Ok(session.thread.update(cx, |thread, cx| thread.summary(cx)))
 967            })??;
 968            let result = summary_task.await.context("Failed to generate summary")?;
 969            drop(acp_thread);
 970            Ok(result)
 971        })
 972    }
 973
 974    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 975        if thread.read(cx).is_empty() {
 976            return;
 977        }
 978
 979        let session_id = thread.read(cx).session_id().clone();
 980        let Some(session) = self.sessions.get_mut(&session_id) else {
 981            return;
 982        };
 983
 984        let project_id = session.project_id;
 985        let Some(state) = self.projects.get(&project_id) else {
 986            return;
 987        };
 988
 989        let folder_paths = PathList::new(
 990            &state
 991                .project
 992                .read(cx)
 993                .visible_worktrees(cx)
 994                .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
 995                .collect::<Vec<_>>(),
 996        );
 997
 998        let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
 999        let database_future = ThreadsDatabase::connect(cx);
1000        let db_thread = thread.update(cx, |thread, cx| {
1001            thread.set_draft_prompt(draft_prompt);
1002            thread.to_db(cx)
1003        });
1004        let thread_store = self.thread_store.clone();
1005        session.pending_save = cx.spawn(async move |_, cx| {
1006            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
1007                return Ok(());
1008            };
1009            let db_thread = db_thread.await;
1010            database
1011                .save_thread(session_id, db_thread, folder_paths)
1012                .await
1013                .log_err();
1014            thread_store.update(cx, |store, cx| store.reload(cx));
1015            Ok(())
1016        });
1017    }
1018
    /// Runs an MCP prompt (invoked as a slash command) in session `session_id`.
    ///
    /// Fetches prompt `prompt_name` from server `server_id` with `arguments`, then:
    /// 1. records the rest of the user's input under `message_id` in the native thread
    ///    (every block of `original_content` except the first, which holds the command);
    /// 2. appends each message returned by the server to both the ACP thread and the
    ///    native thread, preserving its user/assistant role;
    /// 3. starts a completion: `send_existing` when the last appended message came
    ///    from the user (or the server returned none), otherwise `resume`.
    ///
    /// The returned task forwards the resulting thread events to the ACP thread and
    /// resolves with the turn's stop reason.
    ///
    /// # Errors
    /// Fails if the session or its project state is missing, if fetching the prompt
    /// fails, or if starting the completion fails.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: acp::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let Some(state) = self.session_project_state(&session_id) else {
            return Task::ready(Err(anyhow!("Project state not found for session")));
        };
        // Capture what the async task needs up front so it doesn't borrow `self`.
        let server_store = state
            .context_server_registry
            .read(cx)
            .server_store()
            .clone();
        let path_style = state.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            // The session may have been closed while the prompt was being fetched.
            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Stays `true` if the server returns no messages, so the user's own input
            // is what gets sent.
            let mut last_is_user = true;

            // Record the user's input in the native thread. The first block is skipped
            // because it is the slash-command text itself.
            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            // Mirror each prompt message into both the ACP thread (UI) and the native
            // thread (model context), using the `*_with_indent` variants for the ACP side.
            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                match role {
                    context_server::types::Role::User => {
                        // Each user message from the prompt gets its own fresh id.
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        // `false` here matches how regular (non-thinking) agent text is
                        // pushed in `handle_thread_events`.
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
1121}
1122
1123/// Wrapper struct that implements the AgentConnection trait
1124#[derive(Clone)]
1125pub struct NativeAgentConnection(pub Entity<NativeAgent>);
1126
1127impl NativeAgentConnection {
1128    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
1129        self.0
1130            .read(cx)
1131            .sessions
1132            .get(session_id)
1133            .map(|session| session.thread.clone())
1134    }
1135
1136    pub fn load_thread(
1137        &self,
1138        id: acp::SessionId,
1139        project: Entity<Project>,
1140        cx: &mut App,
1141    ) -> Task<Result<Entity<Thread>>> {
1142        self.0
1143            .update(cx, |this, cx| this.load_thread(id, project, cx))
1144    }
1145
1146    fn run_turn(
1147        &self,
1148        session_id: acp::SessionId,
1149        cx: &mut App,
1150        f: impl 'static
1151        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1152    ) -> Task<Result<acp::PromptResponse>> {
1153        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1154            agent
1155                .sessions
1156                .get(&session_id)
1157                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1158        }) else {
1159            return Task::ready(Err(anyhow!("Session not found")));
1160        };
1161        log::debug!("Found session for: {}", session_id);
1162
1163        let response_stream = match f(thread, cx) {
1164            Ok(stream) => stream,
1165            Err(err) => return Task::ready(Err(err)),
1166        };
1167        Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1168    }
1169
1170    fn handle_thread_events(
1171        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1172        acp_thread: WeakEntity<AcpThread>,
1173        cx: &App,
1174    ) -> Task<Result<acp::PromptResponse>> {
1175        cx.spawn(async move |cx| {
1176            // Handle response stream and forward to session.acp_thread
1177            while let Some(result) = events.next().await {
1178                match result {
1179                    Ok(event) => {
1180                        log::trace!("Received completion event: {:?}", event);
1181
1182                        match event {
1183                            ThreadEvent::UserMessage(message) => {
1184                                acp_thread.update(cx, |thread, cx| {
1185                                    for content in message.content {
1186                                        thread.push_user_content_block(
1187                                            Some(message.id.clone()),
1188                                            content.into(),
1189                                            cx,
1190                                        );
1191                                    }
1192                                })?;
1193                            }
1194                            ThreadEvent::AgentText(text) => {
1195                                acp_thread.update(cx, |thread, cx| {
1196                                    thread.push_assistant_content_block(text.into(), false, cx)
1197                                })?;
1198                            }
1199                            ThreadEvent::AgentThinking(text) => {
1200                                acp_thread.update(cx, |thread, cx| {
1201                                    thread.push_assistant_content_block(text.into(), true, cx)
1202                                })?;
1203                            }
1204                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1205                                tool_call,
1206                                options,
1207                                response,
1208                                context: _,
1209                            }) => {
1210                                let outcome_task = acp_thread.update(cx, |thread, cx| {
1211                                    thread.request_tool_call_authorization(tool_call, options, cx)
1212                                })??;
1213                                cx.background_spawn(async move {
1214                                    if let acp_thread::RequestPermissionOutcome::Selected(outcome) =
1215                                        outcome_task.await
1216                                    {
1217                                        response
1218                                            .send(outcome)
1219                                            .map(|_| anyhow!("authorization receiver was dropped"))
1220                                            .log_err();
1221                                    }
1222                                })
1223                                .detach();
1224                            }
1225                            ThreadEvent::ToolCall(tool_call) => {
1226                                acp_thread.update(cx, |thread, cx| {
1227                                    thread.upsert_tool_call(tool_call, cx)
1228                                })??;
1229                            }
1230                            ThreadEvent::ToolCallUpdate(update) => {
1231                                acp_thread.update(cx, |thread, cx| {
1232                                    thread.update_tool_call(update, cx)
1233                                })??;
1234                            }
1235                            ThreadEvent::Plan(plan) => {
1236                                acp_thread.update(cx, |thread, cx| thread.update_plan(plan, cx))?;
1237                            }
1238                            ThreadEvent::SubagentSpawned(session_id) => {
1239                                acp_thread.update(cx, |thread, cx| {
1240                                    thread.subagent_spawned(session_id, cx);
1241                                })?;
1242                            }
1243                            ThreadEvent::Retry(status) => {
1244                                acp_thread.update(cx, |thread, cx| {
1245                                    thread.update_retry_status(status, cx)
1246                                })?;
1247                            }
1248                            ThreadEvent::Stop(stop_reason) => {
1249                                log::debug!("Assistant message complete: {:?}", stop_reason);
1250                                return Ok(acp::PromptResponse::new(stop_reason));
1251                            }
1252                        }
1253                    }
1254                    Err(e) => {
1255                        log::error!("Error in model response stream: {:?}", e);
1256                        return Err(e);
1257                    }
1258                }
1259            }
1260
1261            log::debug!("Response stream completed");
1262            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1263        })
1264    }
1265}
1266
1267struct Command<'a> {
1268    prompt_name: &'a str,
1269    arg_value: &'a str,
1270    explicit_server_id: Option<&'a str>,
1271}
1272
1273impl<'a> Command<'a> {
1274    fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1275        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1276            return None;
1277        };
1278        let text = text_content.text.trim();
1279        let command = text.strip_prefix('/')?;
1280        let (command, arg_value) = command
1281            .split_once(char::is_whitespace)
1282            .unwrap_or((command, ""));
1283
1284        if let Some((server_id, prompt_name)) = command.split_once('.') {
1285            Some(Self {
1286                prompt_name,
1287                arg_value,
1288                explicit_server_id: Some(server_id),
1289            })
1290        } else {
1291            Some(Self {
1292                prompt_name: command,
1293                arg_value,
1294                explicit_server_id: None,
1295            })
1296        }
1297    }
1298}
1299
1300struct NativeAgentModelSelector {
1301    session_id: acp::SessionId,
1302    connection: NativeAgentConnection,
1303}
1304
1305impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1306    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1307        log::debug!("NativeAgentConnection::list_models called");
1308        let list = self.connection.0.read(cx).models.model_list.clone();
1309        Task::ready(if list.is_empty() {
1310            Err(anyhow::anyhow!("No models available"))
1311        } else {
1312            Ok(list)
1313        })
1314    }
1315
1316    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1317        log::debug!(
1318            "Setting model for session {}: {}",
1319            self.session_id,
1320            model_id
1321        );
1322        let Some(thread) = self
1323            .connection
1324            .0
1325            .read(cx)
1326            .sessions
1327            .get(&self.session_id)
1328            .map(|session| session.thread.clone())
1329        else {
1330            return Task::ready(Err(anyhow!("Session not found")));
1331        };
1332
1333        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1334            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1335        };
1336
1337        // We want to reset the effort level when switching models, as the currently-selected effort level may
1338        // not be compatible.
1339        let effort = model
1340            .default_effort_level()
1341            .map(|effort_level| effort_level.value.to_string());
1342
1343        thread.update(cx, |thread, cx| {
1344            thread.set_model(model.clone(), cx);
1345            thread.set_thinking_effort(effort.clone(), cx);
1346            thread.set_thinking_enabled(model.supports_thinking(), cx);
1347        });
1348
1349        update_settings_file(
1350            self.connection.0.read(cx).fs.clone(),
1351            cx,
1352            move |settings, cx| {
1353                let provider = model.provider_id().0.to_string();
1354                let model = model.id().0.to_string();
1355                let enable_thinking = thread.read(cx).thinking_enabled();
1356                let speed = thread.read(cx).speed();
1357                settings
1358                    .agent
1359                    .get_or_insert_default()
1360                    .set_model(LanguageModelSelection {
1361                        provider: provider.into(),
1362                        model,
1363                        enable_thinking,
1364                        effort,
1365                        speed,
1366                    });
1367            },
1368        );
1369
1370        Task::ready(Ok(()))
1371    }
1372
1373    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1374        let Some(thread) = self
1375            .connection
1376            .0
1377            .read(cx)
1378            .sessions
1379            .get(&self.session_id)
1380            .map(|session| session.thread.clone())
1381        else {
1382            return Task::ready(Err(anyhow!("Session not found")));
1383        };
1384        let Some(model) = thread.read(cx).model() else {
1385            return Task::ready(Err(anyhow!("Model not found")));
1386        };
1387        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1388        else {
1389            return Task::ready(Err(anyhow!("Provider not found")));
1390        };
1391        Task::ready(Ok(LanguageModels::map_language_model_to_info(
1392            model, &provider,
1393        )))
1394    }
1395
1396    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1397        Some(self.connection.0.read(cx).models.watch())
1398    }
1399
1400    fn should_render_footer(&self) -> bool {
1401        true
1402    }
1403}
1404
1405pub static ZED_AGENT_ID: LazyLock<AgentId> = LazyLock::new(|| AgentId::new("Zed Agent"));
1406
1407impl acp_thread::AgentConnection for NativeAgentConnection {
1408    fn agent_id(&self) -> AgentId {
1409        ZED_AGENT_ID.clone()
1410    }
1411
1412    fn telemetry_id(&self) -> SharedString {
1413        "zed".into()
1414    }
1415
1416    fn new_session(
1417        self: Rc<Self>,
1418        project: Entity<Project>,
1419        work_dirs: PathList,
1420        cx: &mut App,
1421    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1422        log::debug!("Creating new thread for project at: {work_dirs:?}");
1423        Task::ready(Ok(self
1424            .0
1425            .update(cx, |agent, cx| agent.new_session(project, cx))))
1426    }
1427
1428    fn supports_load_session(&self) -> bool {
1429        true
1430    }
1431
1432    fn load_session(
1433        self: Rc<Self>,
1434        session_id: acp::SessionId,
1435        project: Entity<Project>,
1436        _work_dirs: PathList,
1437        _title: Option<SharedString>,
1438        cx: &mut App,
1439    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1440        self.0
1441            .update(cx, |agent, cx| agent.open_thread(session_id, project, cx))
1442    }
1443
1444    fn supports_close_session(&self) -> bool {
1445        true
1446    }
1447
1448    fn close_session(
1449        self: Rc<Self>,
1450        session_id: &acp::SessionId,
1451        cx: &mut App,
1452    ) -> Task<Result<()>> {
1453        self.0.update(cx, |agent, cx| {
1454            let thread = agent.sessions.get(session_id).map(|s| s.thread.clone());
1455            if let Some(thread) = thread {
1456                agent.save_thread(thread, cx);
1457            }
1458
1459            let Some(session) = agent.sessions.remove(session_id) else {
1460                return Task::ready(Ok(()));
1461            };
1462            let project_id = session.project_id;
1463
1464            let has_remaining = agent.sessions.values().any(|s| s.project_id == project_id);
1465            if !has_remaining {
1466                agent.projects.remove(&project_id);
1467            }
1468
1469            session.pending_save
1470        })
1471    }
1472
1473    fn auth_methods(&self) -> &[acp::AuthMethod] {
1474        &[] // No auth for in-process
1475    }
1476
1477    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1478        Task::ready(Ok(()))
1479    }
1480
1481    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1482        Some(Rc::new(NativeAgentModelSelector {
1483            session_id: session_id.clone(),
1484            connection: self.clone(),
1485        }) as Rc<dyn AgentModelSelector>)
1486    }
1487
1488    fn prompt(
1489        &self,
1490        id: Option<acp_thread::UserMessageId>,
1491        params: acp::PromptRequest,
1492        cx: &mut App,
1493    ) -> Task<Result<acp::PromptResponse>> {
1494        let id = id.expect("UserMessageId is required");
1495        let session_id = params.session_id.clone();
1496        log::info!("Received prompt request for session: {}", session_id);
1497        log::debug!("Prompt blocks count: {}", params.prompt.len());
1498
1499        let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
1500            return Task::ready(Err(anyhow::anyhow!("Session not found")));
1501        };
1502
1503        if let Some(parsed_command) = Command::parse(&params.prompt) {
1504            let registry = project_state.context_server_registry.read(cx);
1505
1506            let explicit_server_id = parsed_command
1507                .explicit_server_id
1508                .map(|server_id| ContextServerId(server_id.into()));
1509
1510            if let Some(prompt) =
1511                registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1512            {
1513                let arguments = if !parsed_command.arg_value.is_empty()
1514                    && let Some(arg_name) = prompt
1515                        .prompt
1516                        .arguments
1517                        .as_ref()
1518                        .and_then(|args| args.first())
1519                        .map(|arg| arg.name.clone())
1520                {
1521                    HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1522                } else {
1523                    Default::default()
1524                };
1525
1526                let prompt_name = prompt.prompt.name.clone();
1527                let server_id = prompt.server_id.clone();
1528
1529                return self.0.update(cx, |agent, cx| {
1530                    agent.send_mcp_prompt(
1531                        id,
1532                        session_id.clone(),
1533                        prompt_name,
1534                        server_id,
1535                        arguments,
1536                        params.prompt,
1537                        cx,
1538                    )
1539                });
1540            }
1541        };
1542
1543        let path_style = project_state.project.read(cx).path_style(cx);
1544
1545        self.run_turn(session_id, cx, move |thread, cx| {
1546            let content: Vec<UserMessageContent> = params
1547                .prompt
1548                .into_iter()
1549                .map(|block| UserMessageContent::from_content_block(block, path_style))
1550                .collect::<Vec<_>>();
1551            log::debug!("Converted prompt to message: {} chars", content.len());
1552            log::debug!("Message id: {:?}", id);
1553            log::debug!("Message content: {:?}", content);
1554
1555            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1556        })
1557    }
1558
1559    fn retry(
1560        &self,
1561        session_id: &acp::SessionId,
1562        _cx: &App,
1563    ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1564        Some(Rc::new(NativeAgentSessionRetry {
1565            connection: self.clone(),
1566            session_id: session_id.clone(),
1567        }) as _)
1568    }
1569
1570    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1571        log::info!("Cancelling on session: {}", session_id);
1572        self.0.update(cx, |agent, cx| {
1573            if let Some(session) = agent.sessions.get(session_id) {
1574                session
1575                    .thread
1576                    .update(cx, |thread, cx| thread.cancel(cx))
1577                    .detach();
1578            }
1579        });
1580    }
1581
1582    fn truncate(
1583        &self,
1584        session_id: &acp::SessionId,
1585        cx: &App,
1586    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1587        self.0.read_with(cx, |agent, _cx| {
1588            agent.sessions.get(session_id).map(|session| {
1589                Rc::new(NativeAgentSessionTruncate {
1590                    thread: session.thread.clone(),
1591                    acp_thread: session.acp_thread.downgrade(),
1592                }) as _
1593            })
1594        })
1595    }
1596
1597    fn set_title(
1598        &self,
1599        session_id: &acp::SessionId,
1600        cx: &App,
1601    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1602        self.0.read_with(cx, |agent, _cx| {
1603            agent
1604                .sessions
1605                .get(session_id)
1606                .filter(|s| !s.thread.read(cx).is_subagent())
1607                .map(|session| {
1608                    Rc::new(NativeAgentSessionSetTitle {
1609                        thread: session.thread.clone(),
1610                    }) as _
1611                })
1612        })
1613    }
1614
1615    fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1616        let thread_store = self.0.read(cx).thread_store.clone();
1617        Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1618    }
1619
1620    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1621        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1622    }
1623
1624    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1625        self
1626    }
1627}
1628
1629impl acp_thread::AgentTelemetry for NativeAgentConnection {
1630    fn thread_data(
1631        &self,
1632        session_id: &acp::SessionId,
1633        cx: &mut App,
1634    ) -> Task<Result<serde_json::Value>> {
1635        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1636            return Task::ready(Err(anyhow!("Session not found")));
1637        };
1638
1639        let task = session.thread.read(cx).to_db(cx);
1640        cx.background_spawn(async move {
1641            serde_json::to_value(task.await).context("Failed to serialize thread")
1642        })
1643    }
1644}
1645
/// Session list for the native agent, backed by its [`ThreadStore`].
///
/// Refresh notifications flow through an unbounded channel. One is queued
/// whenever the thread store notifies its observers, and another on every
/// call to `notify_refresh`.
///
/// NOTE(review): smol channels are multi-consumer, so each message is delivered
/// to exactly one receiver. If `watch` is called more than once, the clones
/// split the updates between them. Confirm only a single watcher is expected.
pub struct NativeAgentSessionList {
    thread_store: Entity<ThreadStore>,
    // Sending half retained so `notify_refresh` can queue updates on demand.
    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
    // Receiving half; `watch` hands out clones of it.
    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
    // Keeps the thread-store observation alive as long as the list exists.
    _subscription: Subscription,
}
1652
1653impl NativeAgentSessionList {
1654    fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1655        let (tx, rx) = smol::channel::unbounded();
1656        let this_tx = tx.clone();
1657        let subscription = cx.observe(&thread_store, move |_, _| {
1658            this_tx
1659                .try_send(acp_thread::SessionListUpdate::Refresh)
1660                .ok();
1661        });
1662        Self {
1663            thread_store,
1664            updates_tx: tx,
1665            updates_rx: rx,
1666            _subscription: subscription,
1667        }
1668    }
1669
1670    pub fn thread_store(&self) -> &Entity<ThreadStore> {
1671        &self.thread_store
1672    }
1673}
1674
1675impl AgentSessionList for NativeAgentSessionList {
1676    fn list_sessions(
1677        &self,
1678        _request: AgentSessionListRequest,
1679        cx: &mut App,
1680    ) -> Task<Result<AgentSessionListResponse>> {
1681        let sessions = self
1682            .thread_store
1683            .read(cx)
1684            .entries()
1685            .map(|entry| AgentSessionInfo::from(&entry))
1686            .collect();
1687        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1688    }
1689
1690    fn supports_delete(&self) -> bool {
1691        true
1692    }
1693
1694    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1695        self.thread_store
1696            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1697    }
1698
1699    fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1700        self.thread_store
1701            .update(cx, |store, cx| store.delete_threads(cx))
1702    }
1703
1704    fn watch(
1705        &self,
1706        _cx: &mut App,
1707    ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1708        Some(self.updates_rx.clone())
1709    }
1710
1711    fn notify_refresh(&self) {
1712        self.updates_tx
1713            .try_send(acp_thread::SessionListUpdate::Refresh)
1714            .ok();
1715    }
1716
1717    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1718        self
1719    }
1720}
1721
/// Truncation handle for one native session, handed out by `truncate`.
struct NativeAgentSessionTruncate {
    // The native thread whose history gets truncated.
    thread: Entity<Thread>,
    // Held weakly; receives the post-truncation token usage if it still exists.
    acp_thread: WeakEntity<AcpThread>,
}
1726
1727impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1728    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1729        match self.thread.update(cx, |thread, cx| {
1730            thread.truncate(message_id.clone(), cx)?;
1731            Ok(thread.latest_token_usage())
1732        }) {
1733            Ok(usage) => {
1734                self.acp_thread
1735                    .update(cx, |thread, cx| {
1736                        thread.update_token_usage(usage, cx);
1737                    })
1738                    .ok();
1739                Task::ready(Ok(()))
1740            }
1741            Err(error) => Task::ready(Err(error)),
1742        }
1743    }
1744}
1745
/// Retry handle for one native session; running it resumes the session's thread
/// through the connection's `run_turn`.
struct NativeAgentSessionRetry {
    connection: NativeAgentConnection,
    session_id: acp::SessionId,
}
1750
1751impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1752    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1753        self.connection
1754            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1755                thread.update(cx, |thread, cx| thread.resume(cx))
1756            })
1757    }
1758}
1759
/// Rename handle for a native thread; `set_title` only hands these out for
/// threads that are not subagents.
struct NativeAgentSessionSetTitle {
    thread: Entity<Thread>,
}
1763
1764impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1765    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1766        self.thread
1767            .update(cx, |thread, cx| thread.set_title(title, cx));
1768        Task::ready(Ok(()))
1769    }
1770}
1771
/// Environment through which a native [`Thread`] creates terminals and
/// creates or resumes subagents.
///
/// Every reference is weak. Operations report an error, or skip work, once the
/// corresponding entity has been dropped.
pub struct NativeThreadEnvironment {
    // Used to look up and register subagent sessions.
    agent: WeakEntity<NativeAgent>,
    // The parent thread that new subagents are spawned under.
    thread: WeakEntity<Thread>,
    // ACP thread used to create and release terminals (presumably the one
    // paired with `thread` — confirm at the construction site).
    acp_thread: WeakEntity<AcpThread>,
}
1777
1778impl NativeThreadEnvironment {
1779    pub(crate) fn create_subagent_thread(
1780        &self,
1781        label: String,
1782        cx: &mut App,
1783    ) -> Result<Rc<dyn SubagentHandle>> {
1784        let Some(parent_thread_entity) = self.thread.upgrade() else {
1785            anyhow::bail!("Parent thread no longer exists".to_string());
1786        };
1787        let parent_thread = parent_thread_entity.read(cx);
1788        let current_depth = parent_thread.depth();
1789        let parent_session_id = parent_thread.session_id().clone();
1790
1791        if current_depth >= MAX_SUBAGENT_DEPTH {
1792            return Err(anyhow!(
1793                "Maximum subagent depth ({}) reached",
1794                MAX_SUBAGENT_DEPTH
1795            ));
1796        }
1797
1798        let subagent_thread: Entity<Thread> = cx.new(|cx| {
1799            let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1800            thread.set_title(label.into(), cx);
1801            thread
1802        });
1803
1804        let subagent_thread_id = subagent_thread.read(cx).id();
1805
1806        let acp_thread = self
1807            .agent
1808            .update(cx, |agent, cx| -> Result<Entity<AcpThread>> {
1809                let project_id = agent
1810                    .sessions
1811                    .get(&parent_session_id)
1812                    .map(|s| s.project_id)
1813                    .context("parent session not found")?;
1814                Ok(agent.register_session(subagent_thread.clone(), project_id, cx))
1815            })??;
1816
1817        let depth = current_depth + 1;
1818
1819        telemetry::event!(
1820            "Subagent Started",
1821            session = parent_thread_entity.read(cx).id().to_string(),
1822            subagent_session = subagent_thread_id.to_string(),
1823            depth,
1824            is_resumed = false,
1825        );
1826
1827        self.prompt_subagent(subagent_thread_id, subagent_thread, acp_thread)
1828    }
1829
1830    pub(crate) fn resume_subagent_thread(
1831        &self,
1832        session_id: acp::SessionId,
1833        cx: &mut App,
1834    ) -> Result<Rc<dyn SubagentHandle>> {
1835        let (subagent_thread, acp_thread, thread_id) = self.agent.update(cx, |agent, _cx| {
1836            let session = agent
1837                .sessions
1838                .get(&session_id)
1839                .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1840            let thread_id = session.thread.read(_cx).id();
1841            anyhow::Ok((
1842                session.thread.clone(),
1843                session.acp_thread.clone(),
1844                thread_id,
1845            ))
1846        })??;
1847
1848        let depth = subagent_thread.read(cx).depth();
1849
1850        if let Some(parent_thread_entity) = self.thread.upgrade() {
1851            telemetry::event!(
1852                "Subagent Started",
1853                session = parent_thread_entity.read(cx).id().to_string(),
1854                subagent_session = session_id.to_string(),
1855                depth,
1856                is_resumed = true,
1857            );
1858        }
1859
1860        self.prompt_subagent(thread_id, subagent_thread, acp_thread)
1861    }
1862
1863    fn prompt_subagent(
1864        &self,
1865        thread_id: ThreadId,
1866        subagent_thread: Entity<Thread>,
1867        acp_thread: Entity<acp_thread::AcpThread>,
1868    ) -> Result<Rc<dyn SubagentHandle>> {
1869        let Some(parent_thread_entity) = self.thread.upgrade() else {
1870            anyhow::bail!("Parent thread no longer exists".to_string());
1871        };
1872        Ok(Rc::new(NativeSubagentHandle::new(
1873            thread_id,
1874            subagent_thread,
1875            acp_thread,
1876            parent_thread_entity,
1877        )) as _)
1878    }
1879}
1880
1881impl ThreadEnvironment for NativeThreadEnvironment {
1882    fn create_terminal(
1883        &self,
1884        command: String,
1885        cwd: Option<PathBuf>,
1886        output_byte_limit: Option<u64>,
1887        cx: &mut AsyncApp,
1888    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1889        let task = self.acp_thread.update(cx, |thread, cx| {
1890            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1891        });
1892
1893        let acp_thread = self.acp_thread.clone();
1894        cx.spawn(async move |cx| {
1895            let terminal = task?.await?;
1896
1897            let (drop_tx, drop_rx) = oneshot::channel();
1898            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1899
1900            cx.spawn(async move |cx| {
1901                drop_rx.await.ok();
1902                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1903            })
1904            .detach();
1905
1906            let handle = AcpTerminalHandle {
1907                terminal,
1908                _drop_tx: Some(drop_tx),
1909            };
1910
1911            Ok(Rc::new(handle) as _)
1912        })
1913    }
1914
1915    fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1916        self.create_subagent_thread(label, cx)
1917    }
1918
1919    fn resume_subagent(
1920        &self,
1921        session_id: acp::SessionId,
1922        cx: &mut App,
1923    ) -> Result<Rc<dyn SubagentHandle>> {
1924        self.resume_subagent_thread(session_id, cx)
1925    }
1926}
1927
/// Outcome of waiting on a subagent prompt in `NativeSubagentHandle::send`.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The turn ended with `EndTurn` or another stop reason not mapped
    /// explicitly. The reply is then read from the thread's last message.
    Completed,
    /// The turn stopped with `StopReason::Cancelled`.
    Cancelled,
    /// Token usage moved from `Normal` to `Warning` while the prompt ran.
    ContextWindowWarning,
    /// The turn failed or stopped abnormally; the message is surfaced to the caller.
    Error(String),
}
1935
/// Handle to a subagent thread, returned by `NativeThreadEnvironment` when a
/// subagent is created or resumed.
pub struct NativeSubagentHandle {
    thread_id: ThreadId,
    // Held weakly so the handle does not keep the parent alive. Used to
    // register/unregister the subagent as running while a prompt is in flight.
    parent_thread: WeakEntity<Thread>,
    subagent_thread: Entity<Thread>,
    acp_thread: Entity<acp_thread::AcpThread>,
}
1942
1943impl NativeSubagentHandle {
1944    fn new(
1945        thread_id: ThreadId,
1946        subagent_thread: Entity<Thread>,
1947        acp_thread: Entity<acp_thread::AcpThread>,
1948        parent_thread_entity: Entity<Thread>,
1949    ) -> Self {
1950        NativeSubagentHandle {
1951            thread_id,
1952            subagent_thread,
1953            parent_thread: parent_thread_entity.downgrade(),
1954            acp_thread,
1955        }
1956    }
1957}
1958
impl SubagentHandle for NativeSubagentHandle {
    /// The subagent's thread id, fixed at construction.
    fn id(&self) -> ThreadId {
        self.thread_id
    }

    /// The session id of the subagent thread.
    fn session_id(&self, cx: &App) -> acp::SessionId {
        self.subagent_thread.read(cx).session_id().clone()
    }

    /// Number of entries currently in the subagent's ACP thread.
    fn num_entries(&self, cx: &App) -> usize {
        self.acp_thread.read(cx).entries().len()
    }

    /// Sends `message` to the subagent and resolves with the text of its reply.
    ///
    /// While the prompt runs, the subagent is registered as running on the
    /// parent thread; it is unregistered afterwards regardless of outcome
    /// (both steps are skipped if the parent is gone).
    ///
    /// On normal completion, the reply is the `Text` parts of the thread's last
    /// agent message joined by blank lines. An empty or missing reply is an
    /// error. Cancellation, abnormal stop reasons and send failures are also
    /// errors.
    ///
    /// If token usage goes from `Normal` (or no usage recorded) before the
    /// prompt to `Warning` during it, the prompt is abandoned, the subagent
    /// thread is cancelled, and an error asks the caller to prompt again to wrap
    /// up. If usage was already past `Normal` before the prompt, this check
    /// never fires.
    ///
    /// NOTE(review): only a transition to exactly `Warning` is detected; a jump
    /// from `Normal` straight to any other ratio is not caught here.
    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
        let thread = self.subagent_thread.clone();
        let acp_thread = self.acp_thread.clone();
        let subagent_thread_id = self.thread_id;
        let parent_thread = self.parent_thread.clone();

        cx.spawn(async move |cx| {
            // `_subscription` is bound (not dropped) until this async block ends,
            // which keeps the token-usage watcher, and the oneshot sender it
            // owns, alive for the whole wait below.
            let (task, _subscription) = cx.update(|cx| {
                // Baseline ratio; a transition away from it is what triggers the warning.
                let ratio_before_prompt = thread
                    .read(cx)
                    .latest_token_usage()
                    .map(|usage| usage.ratio());

                parent_thread
                    .update(cx, |parent_thread, _cx| {
                        parent_thread.register_running_subagent(thread.downgrade())
                    })
                    .ok();

                let task = acp_thread.update(cx, |acp_thread, cx| {
                    acp_thread.send(vec![message.into()], cx)
                });

                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
                // Wrapped in an Option so the signal is sent at most once.
                let mut token_limit_tx = Some(token_limit_tx);

                let subscription = cx.subscribe(
                    &thread,
                    move |_thread, event: &TokenUsageUpdated, _cx| {
                        if let Some(usage) = &event.0 {
                            let old_ratio = ratio_before_prompt
                                .clone()
                                .unwrap_or(TokenUsageRatio::Normal);
                            let new_ratio = usage.ratio();
                            if old_ratio == TokenUsageRatio::Normal
                                && new_ratio == TokenUsageRatio::Warning
                            {
                                if let Some(tx) = token_limit_tx.take() {
                                    tx.send(()).ok();
                                }
                            }
                        }
                    },
                );

                // Race the prompt against the token-limit signal. If the signal
                // wins, the prompt future is dropped. The `_ =` arm also matches
                // a cancelled receiver, which is why the sender (inside the
                // subscription) is kept alive for the duration of the wait.
                let wait_for_prompt = cx
                    .background_spawn(async move {
                        futures::select! {
                            response = task.fuse() => match response {
                                Ok(Some(response)) => {
                                    match response.stop_reason {
                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
                                    }
                                }
                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
                                Err(error) => SubagentPromptResult::Error(error.to_string()),
                            },
                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
                        }
                    });

                (wait_for_prompt, subscription)
            });

            let result = match task.await {
                // Collect the text parts of the last agent message; empty means no reply.
                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
                    thread
                        .last_message()
                        .and_then(|message| {
                            let content = message.as_agent_message()?
                                .content
                                .iter()
                                .filter_map(|c| match c {
                                    AgentMessageContent::Text(text) => Some(text.as_str()),
                                    _ => None,
                                })
                                .join("\n\n");
                            if content.is_empty() {
                                None
                            } else {
                                Some( content)
                            }
                        })
                        .context("No response from subagent")
                }),
                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
                SubagentPromptResult::ContextWindowWarning => {
                    // The prompt future was abandoned by `select!`; stop the thread explicitly.
                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
                    Err(anyhow!(
                        "The agent is nearing the end of its context window and has been \
                         stopped. You can prompt the thread again to have the agent wrap up \
                         or hand off its work."
                    ))
                }
            };

            // Runs on every path above, so the parent never keeps a stale
            // "running" entry for this subagent.
            parent_thread
                .update(cx, |parent_thread, cx| {
                    parent_thread.unregister_running_subagent(subagent_thread_id, cx)
                })
                .ok();

            result
        })
    }
}
2083
/// [`TerminalHandle`] backed by an ACP terminal entity.
pub struct AcpTerminalHandle {
    terminal: Entity<acp_thread::Terminal>,
    // Not sent on anywhere in this file. Dropping the handle drops this sender,
    // which wakes the task spawned in `create_terminal` so it can release the terminal.
    _drop_tx: Option<oneshot::Sender<()>>,
}
2088
2089impl TerminalHandle for AcpTerminalHandle {
2090    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
2091        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
2092    }
2093
2094    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
2095        Ok(self
2096            .terminal
2097            .read_with(cx, |term, _cx| term.wait_for_exit()))
2098    }
2099
2100    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
2101        Ok(self
2102            .terminal
2103            .read_with(cx, |term, cx| term.current_output(cx)))
2104    }
2105
2106    fn kill(&self, cx: &AsyncApp) -> Result<()> {
2107        cx.update(|cx| {
2108            self.terminal.update(cx, |terminal, cx| {
2109                terminal.kill(cx);
2110            });
2111        });
2112        Ok(())
2113    }
2114
2115    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
2116        Ok(self
2117            .terminal
2118            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
2119    }
2120}
2121
2122#[cfg(test)]
2123mod internal_tests {
2124    use std::path::Path;
2125
2126    use super::*;
2127    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
2128    use fs::FakeFs;
2129    use gpui::TestAppContext;
2130    use indoc::formatdoc;
2131    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2132    use language_model::{
2133        LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
2134    };
2135    use serde_json::json;
2136    use settings::SettingsStore;
2137    use util::{path, rel_path::rel_path};
2138
    /// Checks that the agent's per-project context, and the context seen by a
    /// session's thread, follow worktree additions and `.rules` file creation.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        // The project starts with no worktrees; `/a` is added further down.
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));

        // Creating a session registers the project and triggers context building.
        let connection = NativeAgentConnection(agent.clone());
        let _acp_thread = cx
            .update(|cx| {
                Rc::new(connection).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/")]),
                    cx,
                )
            })
            .await
            .unwrap();
        cx.run_until_parked();

        // Grab the single session's thread to compare its context with the agent's.
        let thread = agent.read_with(cx, |agent, _cx| {
            agent.sessions.values().next().unwrap().thread.clone()
        });

        // No worktrees yet: both contexts must be empty.
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            assert_eq!(state.project_context.read(cx).worktrees, vec![]);
            assert_eq!(thread.read(cx).project_context().read(cx).worktrees, vec![]);
        });

        // Adding worktree `/a` must show up in both contexts, without a rules file.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            let expected_worktrees = vec![WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file: None,
            }];
            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
            assert_eq!(
                thread.read(cx).project_context().read(cx).worktrees,
                expected_worktrees
            );
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            // The expected entry id comes from the worktree's view of the new file.
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            let expected_worktrees = vec![WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file: Some(RulesFileContext {
                    path_in_worktree: rel_path(".rules").into(),
                    text: "".into(),
                    project_entry_id: rules_entry.id.to_usize(),
                }),
            }];
            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
            assert_eq!(
                thread.read(cx).project_context().read(cx).worktrees,
                expected_worktrees
            );
        });
    }
2226
    /// Checks that a new session's model selector lists exactly one model,
    /// "fake/fake", under a "Fake" group.
    ///
    /// The fake provider is presumably registered by `init_test`, which is not
    /// visible here — confirm.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let connection =
            NativeAgentConnection(cx.update(|cx| {
                NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)
            }));

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        // The native agent is expected to return models grouped by provider.
        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(acp_thread::AgentModelIcon::Named(
                        ui::IconName::ZedAssistant
                    )),
                    is_latest: false,
                    cost: None,
                }]
            )])
        );
    }
2283
    /// Checks that selecting a model updates the session's thread and rewrites
    /// `agent.default_model` in the settings file.
    ///
    /// It also checks that selecting a thinking-capable model persists
    /// `enable_thinking: true`.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed the settings with a different default model so the overwrite is observable.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));

        // Create the agent and connection
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let the settings-file write complete before reading it back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );

        // Register a thinking model and select it.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so the new provider's model is selectable.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // Verify enable_thinking was written to settings as true.
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
        assert_eq!(
            settings_json["agent"]["default_model"]["enable_thinking"],
            json!(true),
            "selecting a thinking model should persist enable_thinking: true to settings"
        );
    }
2395
2396    #[gpui::test]
2397    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2398        init_test(cx);
2399        let fs = FakeFs::new(cx.executor());
2400        fs.create_dir(paths::settings_file().parent().unwrap())
2401            .await
2402            .unwrap();
2403        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2404        let project = Project::test(fs.clone(), [], cx).await;
2405
2406        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2407        let agent =
2408            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2409        let connection = NativeAgentConnection(agent.clone());
2410
2411        let acp_thread = cx
2412            .update(|cx| {
2413                Rc::new(connection.clone()).new_session(
2414                    project.clone(),
2415                    PathList::new(&[Path::new("/a")]),
2416                    cx,
2417                )
2418            })
2419            .await
2420            .unwrap();
2421        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2422
2423        // Register a second provider with a thinking model.
2424        cx.update(|cx| {
2425            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2426                "fake-corp",
2427                "fake-thinking",
2428                "Fake Thinking",
2429                true,
2430            ));
2431            let thinking_provider = Arc::new(
2432                FakeLanguageModelProvider::new(
2433                    LanguageModelProviderId::from("fake-corp".to_string()),
2434                    LanguageModelProviderName::from("Fake Corp".to_string()),
2435                )
2436                .with_models(vec![thinking_model]),
2437            );
2438            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2439                registry.register_provider(thinking_provider, cx);
2440            });
2441        });
2442        // Refresh the agent's model list so it picks up the new provider.
2443        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2444
2445        // Thread starts with thinking_enabled = false (the default).
2446        agent.read_with(cx, |agent, _| {
2447            let session = agent.sessions.get(&session_id).unwrap();
2448            session.thread.read_with(cx, |thread, _| {
2449                assert!(!thread.thinking_enabled(), "thinking defaults to false");
2450            });
2451        });
2452
2453        // Select the thinking model via select_model.
2454        let selector = connection.model_selector(&session_id).unwrap();
2455        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2456            .await
2457            .unwrap();
2458
2459        // select_model should have enabled thinking based on the model's supports_thinking().
2460        agent.read_with(cx, |agent, _| {
2461            let session = agent.sessions.get(&session_id).unwrap();
2462            session.thread.read_with(cx, |thread, _| {
2463                assert!(
2464                    thread.thinking_enabled(),
2465                    "select_model should enable thinking when model supports it"
2466                );
2467            });
2468        });
2469
2470        // Switch back to the non-thinking model.
2471        let selector = connection.model_selector(&session_id).unwrap();
2472        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2473            .await
2474            .unwrap();
2475
2476        // select_model should have disabled thinking.
2477        agent.read_with(cx, |agent, _| {
2478            let session = agent.sessions.get(&session_id).unwrap();
2479            session.thread.read_with(cx, |thread, _| {
2480                assert!(
2481                    !thread.thinking_enabled(),
2482                    "select_model should disable thinking when model does not support it"
2483                );
2484            });
2485        });
2486    }
2487
2488    #[gpui::test]
2489    async fn test_summarization_model_survives_transient_registry_clearing(
2490        cx: &mut TestAppContext,
2491    ) {
2492        init_test(cx);
2493        let fs = FakeFs::new(cx.executor());
2494        fs.insert_tree("/", json!({ "a": {} })).await;
2495        let project = Project::test(fs.clone(), [], cx).await;
2496
2497        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2498        let agent =
2499            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2500        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2501
2502        let acp_thread = cx
2503            .update(|cx| {
2504                connection.clone().new_session(
2505                    project.clone(),
2506                    PathList::new(&[Path::new("/a")]),
2507                    cx,
2508                )
2509            })
2510            .await
2511            .unwrap();
2512        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2513
2514        let thread = agent.read_with(cx, |agent, _| {
2515            agent.sessions.get(&session_id).unwrap().thread.clone()
2516        });
2517
2518        thread.read_with(cx, |thread, _| {
2519            assert!(
2520                thread.summarization_model().is_some(),
2521                "session should have a summarization model from the test registry"
2522            );
2523        });
2524
2525        // Simulate what happens during a provider blip:
2526        // update_active_language_model_from_settings calls set_default_model(None)
2527        // when it can't resolve the model, clearing all fallbacks.
2528        cx.update(|cx| {
2529            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2530                registry.set_default_model(None, cx);
2531            });
2532        });
2533        cx.run_until_parked();
2534
2535        thread.read_with(cx, |thread, _| {
2536            assert!(
2537                thread.summarization_model().is_some(),
2538                "summarization model should survive a transient default model clearing"
2539            );
2540        });
2541    }
2542
    /// Checks that `thinking_enabled` survives persistence: a thread that selected a
    /// thinking-capable model, exchanged one message, and was closed must still report
    /// thinking enabled after `open_thread` reloads it.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model.
        // A handle to the model is kept so the test can drive its completion stream later.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Make the agent's model list pick up the newly registered provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        // The send is spawned so the fake model's stream can be fed below while the
        // request is still in flight; `run_until_parked` lets the request reach the model.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the closed session's entities, then
        // confirm the agent no longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        // The reloaded AcpThread handle is held until here; presumably this keeps the
        // reopened session alive while the assertion above runs. TODO confirm.
        drop(reloaded_acp_thread);
    }
2645
    /// Checks that a thread persisted with a model whose `id()` differs from its display
    /// `name()` still has a model with id `custom-model-id` after being closed and
    /// reloaded with `open_thread`, rather than falling back to the default model.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Make the agent's model list pick up the newly registered provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        // The selector is addressed by "<provider-id>/<model-id>", not by display name.
        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        // The send is spawned so the fake model's stream can be fed below while the
        // request is still in flight.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the closed session's entities, then
        // confirm the agent no longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        // The reloaded AcpThread handle is held until here; presumably this keeps the
        // reopened session alive while the assertion above runs. TODO confirm.
        drop(reloaded_acp_thread);
    }
2753
    /// Round-trips a thread through `close_session` and `open_thread`. It checks that
    /// empty threads are not stored and that the transcript is restored. It also checks
    /// the stored title produced by the summarization model, the rich draft prompt,
    /// token usage, and the UI scroll position.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        // Separate fake models are used for the conversation and for summarization so
        // each stream can be driven independently below.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // The prompt mixes plain text with a resource link to /a/b.md so that the
        // markdown rendering of mentions is covered by the round-trip as well.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // Spawn the send so the fake model streams can be fed below while it is pending.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Lorem.");
        model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
            language_model::TokenUsage {
                input_tokens: 150,
                output_tokens: 75,
                ..Default::default()
            },
        ));
        model.end_last_completion_stream();
        cx.run_until_parked();
        // The summary model's output is what later shows up as the stored thread title.
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Set a draft prompt with rich content blocks and scroll position
        // AFTER run_until_parked, so the only save that captures these
        // changes is the one performed by close_session itself.
        let draft_blocks = vec![
            acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
            acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
            acp::ContentBlock::Text(acp::TextContent::new(" please")),
        ];
        acp_thread.update(cx, |thread, _cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()));
        });
        thread.update(cx, |thread, _cx| {
            thread.set_ui_scroll_position(Some(gpui::ListOffset {
                item_ix: 5,
                offset_in_item: gpui::px(12.5),
            }));
        });

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles, then confirm the agent holds no sessions.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // Ensure the draft prompt with rich content blocks survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
        });

        // Ensure token usage survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let usage = thread
                .token_usage()
                .expect("token usage should be restored after reload");
            assert_eq!(usage.input_tokens, 150);
            assert_eq!(usage.output_tokens, 75);
        });

        // Ensure scroll position survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let scroll = thread
                .ui_scroll_position()
                .expect("scroll position should be restored after reload");
            assert_eq!(scroll.item_ix, 5);
            assert_eq!(scroll.offset_in_item, gpui::px(12.5));
        });
    }
2935
    /// Checks that `close_session` itself saves the thread. A draft prompt is set
    /// immediately before closing, with no `run_until_parked` in between. It must still
    /// be present after the thread is reopened with `open_thread`.
    #[gpui::test]
    async fn test_close_session_saves_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "file.txt": "hello"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Use a fake model so the test controls the assistant's reply.
        let model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
        });

        // Send a message so the thread is non-empty (empty threads aren't saved).
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("world");
        model.end_last_completion_stream();
        send.await.unwrap();
        cx.run_until_parked();

        // Set a draft prompt WITHOUT calling run_until_parked afterwards.
        // This means no observe-triggered save has run for this change.
        // The only way this data gets persisted is if close_session
        // itself performs the save.
        let draft_blocks = vec![acp::ContentBlock::Text(acp::TextContent::new(
            "unsaved draft",
        ))];
        acp_thread.update(cx, |thread, _cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()));
        });

        // Close the session immediately — no run_until_parked in between.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // Reopen and verify the draft prompt was saved.
        let reloaded = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        reloaded.read_with(cx, |thread, _| {
            assert_eq!(
                thread.draft_prompt(),
                Some(draft_blocks.as_slice()),
                "close_session must save the thread; draft prompt was lost"
            );
        });
    }
3016
3017    #[gpui::test]
3018    async fn test_rapid_title_changes_do_not_loop(cx: &mut TestAppContext) {
3019        // Regression test: rapid title changes must not cause a propagation loop
3020        // between Thread and AcpThread via handle_thread_title_updated.
3021        init_test(cx);
3022        let fs = FakeFs::new(cx.executor());
3023        fs.insert_tree("/", json!({ "a": {} })).await;
3024        let project = Project::test(fs.clone(), [], cx).await;
3025        let thread_store = cx.new(|cx| ThreadStore::new(cx));
3026        let agent = cx.update(|cx| {
3027            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
3028        });
3029        let connection = Rc::new(NativeAgentConnection(agent.clone()));
3030
3031        let acp_thread = cx
3032            .update(|cx| {
3033                connection
3034                    .clone()
3035                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
3036            })
3037            .await
3038            .unwrap();
3039
3040        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3041        let thread = agent.read_with(cx, |agent, _| {
3042            agent.sessions.get(&session_id).unwrap().thread.clone()
3043        });
3044
3045        let title_updated_count = Rc::new(std::cell::RefCell::new(0usize));
3046        cx.update(|cx| {
3047            let count = title_updated_count.clone();
3048            cx.subscribe(
3049                &thread,
3050                move |_entity: Entity<Thread>, _event: &TitleUpdated, _cx: &mut App| {
3051                    let new_count = {
3052                        let mut count = count.borrow_mut();
3053                        *count += 1;
3054                        *count
3055                    };
3056                    assert!(
3057                        new_count <= 2,
3058                        "TitleUpdated fired {new_count} times; \
3059                         title updates are looping"
3060                    );
3061                },
3062            )
3063            .detach();
3064        });
3065
3066        thread.update(cx, |thread, cx| thread.set_title("first".into(), cx));
3067        thread.update(cx, |thread, cx| thread.set_title("second".into(), cx));
3068
3069        cx.run_until_parked();
3070
3071        thread.read_with(cx, |thread, _| {
3072            assert_eq!(thread.title(), Some("second".into()));
3073        });
3074        acp_thread.read_with(cx, |acp_thread, _| {
3075            assert_eq!(acp_thread.title(), Some("second".into()));
3076        });
3077
3078        assert_eq!(*title_updated_count.borrow(), 2);
3079    }
3080
3081    fn thread_entries(
3082        thread_store: &Entity<ThreadStore>,
3083        cx: &mut TestAppContext,
3084    ) -> Vec<(acp::SessionId, String)> {
3085        thread_store.read_with(cx, |store, _| {
3086            store
3087                .entries()
3088                .map(|entry| (entry.id.clone(), entry.title.to_string()))
3089                .collect::<Vec<_>>()
3090        })
3091    }
3092
3093    fn init_test(cx: &mut TestAppContext) {
3094        env_logger::try_init().ok();
3095        cx.update(|cx| {
3096            let settings_store = SettingsStore::test(cx);
3097            cx.set_global(settings_store);
3098
3099            LanguageModelRegistry::test(cx);
3100        });
3101    }
3102}
3103
3104fn mcp_message_content_to_acp_content_block(
3105    content: context_server::types::MessageContent,
3106) -> acp::ContentBlock {
3107    match content {
3108        context_server::types::MessageContent::Text {
3109            text,
3110            annotations: _,
3111        } => text.into(),
3112        context_server::types::MessageContent::Image {
3113            data,
3114            mime_type,
3115            annotations: _,
3116        } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
3117        context_server::types::MessageContent::Audio {
3118            data,
3119            mime_type,
3120            annotations: _,
3121        } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
3122        context_server::types::MessageContent::Resource {
3123            resource,
3124            annotations: _,
3125        } => {
3126            let mut link =
3127                acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
3128            if let Some(mime_type) = resource.mime_type {
3129                link = link.mime_type(mime_type);
3130            }
3131            acp::ContentBlock::ResourceLink(link)
3132        }
3133    }
3134}