agent.rs

   1mod db;
   2mod edit_agent;
   3mod legacy_thread;
   4mod native_agent_server;
   5pub mod outline;
   6mod pattern_extraction;
   7mod templates;
   8#[cfg(test)]
   9mod tests;
  10mod thread;
  11mod thread_store;
  12mod tool_permissions;
  13mod tools;
  14
  15use context_server::ContextServerId;
  16pub use db::*;
  17use itertools::Itertools;
  18pub use native_agent_server::NativeAgentServer;
  19pub use pattern_extraction::*;
  20pub use shell_command_parser::extract_commands;
  21pub use templates::*;
  22pub use thread::*;
  23pub use thread_store::*;
  24pub use tool_permissions::*;
  25pub use tools::*;
  26
  27use acp_thread::{
  28    AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
  29    AgentSessionListResponse, TokenUsageRatio, UserMessageId,
  30};
  31use agent_client_protocol as acp;
  32use anyhow::{Context as _, Result, anyhow};
  33use chrono::{DateTime, Utc};
  34use collections::{HashMap, HashSet, IndexMap};
  35use fs::Fs;
  36use futures::channel::{mpsc, oneshot};
  37use futures::future::Shared;
  38use futures::{FutureExt as _, StreamExt as _, future};
  39use gpui::{
  40    App, AppContext, AsyncApp, Context, Entity, EntityId, SharedString, Subscription, Task,
  41    WeakEntity,
  42};
  43use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  44use project::{AgentId, Project, ProjectItem, ProjectPath, Worktree};
  45use prompt_store::{
  46    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  47    WorktreeContext,
  48};
  49use serde::{Deserialize, Serialize};
  50use settings::{LanguageModelSelection, update_settings_file};
  51use std::any::Any;
  52use std::path::PathBuf;
  53use std::rc::Rc;
  54use std::sync::{Arc, LazyLock};
  55use util::ResultExt;
  56use util::path_list::PathList;
  57use util::rel_path::RelPath;
  58
/// Serializable snapshot of a project.
///
/// Derives `Serialize`/`Deserialize`, so the field names and order below are
/// part of the serialized representation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Telemetry snapshots of the project's worktrees
    /// (presumably one entry per worktree — confirm against the producer).
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// UTC timestamp stored with the snapshot
    /// (presumably the capture time — confirm against the producer).
    pub timestamp: DateTime<Utc>,
}
  64
  65pub struct RulesLoadingError {
  66    pub message: SharedString,
  67}
  68
/// Per-project state tracked by `NativeAgent`, keyed by the project's
/// `EntityId` in `NativeAgent::projects`.
struct ProjectState {
    /// The project this state belongs to.
    project: Entity<Project>,
    /// Context (worktrees and default user rules) handed to threads of this
    /// project; overwritten in place by `maintain_project_context`.
    project_context: Entity<ProjectContext>,
    /// Sending `()` here asks `maintain_project_context` to rebuild
    /// `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context` for this project; stored here
    /// so it is owned by the project state.
    _maintain_project_context: Task<Result<()>>,
    /// Registry built from the project's context server store.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Subscriptions to the project, its context server store and the
    /// registry above; stored here so they are owned by the project state.
    _subscriptions: Vec<Subscription>,
}
  77
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Key into `NativeAgent::projects` for the project this session belongs to
    project_id: EntityId,
    /// Initialized to an already-completed task by `register_session`
    /// (presumably replaced by the in-flight save started from `save_thread` —
    /// confirm, that method is outside this view)
    pending_save: Task<Result<()>>,
    /// Subscriptions/observation on `thread` created in `register_session`
    _subscriptions: Vec<Subscription>,
}
  88
/// Cache of the language models offered by authenticated providers.
///
/// Rebuilt by `refresh_list`; interested parties can obtain a receiver via
/// `watch` to be told when a rebuild happened.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiving end of the refresh channel; cloned by `watch`
    refresh_models_rx: watch::Receiver<()>,
    /// `refresh_list` sends `()` on this after every rebuild
    refresh_models_tx: watch::Sender<()>,
    /// Background task started in `new` that authenticates every visible
    /// provider; stored here so it is owned by this struct
    _authenticate_all_providers_task: Task<()>,
}
  98
  99impl LanguageModels {
 100    fn new(cx: &mut App) -> Self {
 101        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
 102
 103        let mut this = Self {
 104            models: HashMap::default(),
 105            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
 106            refresh_models_rx,
 107            refresh_models_tx,
 108            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
 109        };
 110        this.refresh_list(cx);
 111        this
 112    }
 113
 114    fn refresh_list(&mut self, cx: &App) {
 115        let providers = LanguageModelRegistry::global(cx)
 116            .read(cx)
 117            .visible_providers()
 118            .into_iter()
 119            .filter(|provider| provider.is_authenticated(cx))
 120            .collect::<Vec<_>>();
 121
 122        let mut language_model_list = IndexMap::default();
 123        let mut recommended_models = HashSet::default();
 124
 125        let mut recommended = Vec::new();
 126        for provider in &providers {
 127            for model in provider.recommended_models(cx) {
 128                recommended_models.insert((model.provider_id(), model.id()));
 129                recommended.push(Self::map_language_model_to_info(&model, provider));
 130            }
 131        }
 132        if !recommended.is_empty() {
 133            language_model_list.insert(
 134                acp_thread::AgentModelGroupName("Recommended".into()),
 135                recommended,
 136            );
 137        }
 138
 139        let mut models = HashMap::default();
 140        for provider in providers {
 141            let mut provider_models = Vec::new();
 142            for model in provider.provided_models(cx) {
 143                let model_info = Self::map_language_model_to_info(&model, &provider);
 144                let model_id = model_info.id.clone();
 145                provider_models.push(model_info);
 146                models.insert(model_id, model);
 147            }
 148            if !provider_models.is_empty() {
 149                language_model_list.insert(
 150                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 151                    provider_models,
 152                );
 153            }
 154        }
 155
 156        self.models = models;
 157        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 158        self.refresh_models_tx.send(()).ok();
 159    }
 160
 161    fn watch(&self) -> watch::Receiver<()> {
 162        self.refresh_models_rx.clone()
 163    }
 164
 165    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 166        self.models.get(model_id).cloned()
 167    }
 168
 169    fn map_language_model_to_info(
 170        model: &Arc<dyn LanguageModel>,
 171        provider: &Arc<dyn LanguageModelProvider>,
 172    ) -> acp_thread::AgentModelInfo {
 173        acp_thread::AgentModelInfo {
 174            id: Self::model_id(model),
 175            name: model.name().0,
 176            description: None,
 177            icon: Some(match provider.icon() {
 178                IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
 179                IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
 180            }),
 181            is_latest: model.is_latest(),
 182            cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
 183        }
 184    }
 185
 186    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 187        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 188    }
 189
 190    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 191        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 192            .read(cx)
 193            .visible_providers()
 194            .iter()
 195            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 196            .collect::<Vec<_>>();
 197
 198        cx.background_spawn(async move {
 199            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 200                if let Err(err) = authenticate_task.await {
 201                    match err {
 202                        language_model::AuthenticateError::CredentialsNotFound => {
 203                            // Since we're authenticating these providers in the
 204                            // background for the purposes of populating the
 205                            // language selector, we don't care about providers
 206                            // where the credentials are not found.
 207                        }
 208                        language_model::AuthenticateError::ConnectionRefused => {
 209                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 210                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 211                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 212                        }
 213                        _ => {
 214                            // Some providers have noisy failure states that we
 215                            // don't want to spam the logs with every time the
 216                            // language model selector is initialized.
 217                            //
 218                            // Ideally these should have more clear failure modes
 219                            // that we know are safe to ignore here, like what we do
 220                            // with `CredentialsNotFound` above.
 221                            match provider_id.0.as_ref() {
 222                                "lmstudio" | "ollama" => {
 223                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 224                                    //
 225                                    // These fail noisily, so we don't log them.
 226                                }
 227                                "copilot_chat" => {
 228                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 229                                }
 230                                _ => {
 231                                    log::error!(
 232                                        "Failed to authenticate provider: {}: {err:#}",
 233                                        provider_name.0
 234                                    );
 235                                }
 236                            }
 237                        }
 238                    }
 239                }
 240            }
 241        })
 242    }
 243}
 244
/// The native agent: owns all sessions, the per-project state they share,
/// and the cached list of language models.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Store handle passed in at construction
    /// (presumably where threads are persisted — confirm against `save_thread`)
    thread_store: Entity<ThreadStore>,
    /// Project-specific state keyed by project EntityId
    projects: HashMap<EntityId, ProjectState>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Source of default user rules; when `None`, project contexts are built
    /// without user rules
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle passed in at construction
    fs: Arc<dyn Fs>,
    /// Subscriptions to the language model registry and, if present, the
    /// prompt store
    _subscriptions: Vec<Subscription>,
}
 259
 260impl NativeAgent {
 261    pub fn new(
 262        thread_store: Entity<ThreadStore>,
 263        templates: Arc<Templates>,
 264        prompt_store: Option<Entity<PromptStore>>,
 265        fs: Arc<dyn Fs>,
 266        cx: &mut App,
 267    ) -> Entity<NativeAgent> {
 268        log::debug!("Creating new NativeAgent");
 269
 270        cx.new(|cx| {
 271            let mut subscriptions = vec![cx.subscribe(
 272                &LanguageModelRegistry::global(cx),
 273                Self::handle_models_updated_event,
 274            )];
 275            if let Some(prompt_store) = prompt_store.as_ref() {
 276                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 277            }
 278
 279            Self {
 280                sessions: HashMap::default(),
 281                thread_store,
 282                projects: HashMap::default(),
 283                templates,
 284                models: LanguageModels::new(cx),
 285                prompt_store,
 286                fs,
 287                _subscriptions: subscriptions,
 288            }
 289        })
 290    }
 291
 292    fn new_session(
 293        &mut self,
 294        project: Entity<Project>,
 295        cx: &mut Context<Self>,
 296    ) -> Entity<AcpThread> {
 297        let project_id = self.get_or_create_project_state(&project, cx);
 298        let project_state = &self.projects[&project_id];
 299
 300        let registry = LanguageModelRegistry::read_global(cx);
 301        let available_count = registry.available_models(cx).count();
 302        log::debug!("Total available models: {}", available_count);
 303
 304        let default_model = registry.default_model().and_then(|default_model| {
 305            self.models
 306                .model_from_id(&LanguageModels::model_id(&default_model.model))
 307        });
 308        let thread = cx.new(|cx| {
 309            Thread::new(
 310                project,
 311                project_state.project_context.clone(),
 312                project_state.context_server_registry.clone(),
 313                self.templates.clone(),
 314                default_model,
 315                cx,
 316            )
 317        });
 318
 319        self.register_session(thread, project_id, cx)
 320    }
 321
    /// Wraps an internal `Thread` in an `AcpThread`, wires the two together
    /// and records them as a session keyed by the thread's ID.
    ///
    /// Returns the newly created `AcpThread`. As a side effect, the available
    /// commands of every session in `project_id` are re-sent.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        project_id: EntityId,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        // Copy everything needed out of the thread first: `thread` borrows
        // `cx`, which must be free again before `cx.new` below.
        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let draft_prompt = thread.draft_prompt().map(Vec::from);
        let scroll_position = thread.ui_scroll_position();
        let token_usage = thread.latest_token_usage();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        // Build the protocol-facing thread and carry over the UI state
        // (draft prompt, scroll position, token usage) from the internal one.
        let acp_thread = cx.new(|cx| {
            let mut acp_thread = acp_thread::AcpThread::new(
                parent_session_id,
                title,
                None,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            );
            acp_thread.set_draft_prompt(draft_prompt);
            acp_thread.set_ui_scroll_position(scroll_position);
            acp_thread.update_token_usage(token_usage, cx);
            acp_thread
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        // The environment only holds weak handles to the agent and both
        // threads.
        let weak = cx.weak_entity();
        let weak_thread = thread_handle.downgrade();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            thread.add_default_tools(
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    thread: weak_thread,
                    agent: weak,
                }) as _,
                cx,
            )
        });

        // Mirror title and token usage changes onto the ACP thread, and save
        // the thread whenever it notifies.
        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                project_id,
                _subscriptions: subscriptions,
                pending_save: Task::ready(Ok(())),
            },
        );

        // Must run after the insert above so the new session receives the
        // command list too.
        self.update_available_commands_for_project(project_id, cx);

        acp_thread
    }
 398
 399    pub fn models(&self) -> &LanguageModels {
 400        &self.models
 401    }
 402
 403    fn get_or_create_project_state(
 404        &mut self,
 405        project: &Entity<Project>,
 406        cx: &mut Context<Self>,
 407    ) -> EntityId {
 408        let project_id = project.entity_id();
 409        if self.projects.contains_key(&project_id) {
 410            return project_id;
 411        }
 412
 413        let project_context = cx.new(|_| ProjectContext::new(vec![], vec![]));
 414        self.register_project_with_initial_context(project.clone(), project_context, cx);
 415        if let Some(state) = self.projects.get_mut(&project_id) {
 416            state.project_context_needs_refresh.send(()).ok();
 417        }
 418        project_id
 419    }
 420
 421    fn register_project_with_initial_context(
 422        &mut self,
 423        project: Entity<Project>,
 424        project_context: Entity<ProjectContext>,
 425        cx: &mut Context<Self>,
 426    ) {
 427        let project_id = project.entity_id();
 428
 429        let context_server_store = project.read(cx).context_server_store();
 430        let context_server_registry =
 431            cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
 432
 433        let subscriptions = vec![
 434            cx.subscribe(&project, Self::handle_project_event),
 435            cx.subscribe(
 436                &context_server_store,
 437                Self::handle_context_server_store_updated,
 438            ),
 439            cx.subscribe(
 440                &context_server_registry,
 441                Self::handle_context_server_registry_event,
 442            ),
 443        ];
 444
 445        let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 446            watch::channel(());
 447
 448        self.projects.insert(
 449            project_id,
 450            ProjectState {
 451                project,
 452                project_context,
 453                project_context_needs_refresh: project_context_needs_refresh_tx,
 454                _maintain_project_context: cx.spawn(async move |this, cx| {
 455                    Self::maintain_project_context(
 456                        this,
 457                        project_id,
 458                        project_context_needs_refresh_rx,
 459                        cx,
 460                    )
 461                    .await
 462                }),
 463                context_server_registry,
 464                _subscriptions: subscriptions,
 465            },
 466        );
 467    }
 468
 469    fn session_project_state(&self, session_id: &acp::SessionId) -> Option<&ProjectState> {
 470        self.sessions
 471            .get(session_id)
 472            .and_then(|session| self.projects.get(&session.project_id))
 473    }
 474
 475    async fn maintain_project_context(
 476        this: WeakEntity<Self>,
 477        project_id: EntityId,
 478        mut needs_refresh: watch::Receiver<()>,
 479        cx: &mut AsyncApp,
 480    ) -> Result<()> {
 481        while needs_refresh.changed().await.is_ok() {
 482            let project_context = this
 483                .update(cx, |this, cx| {
 484                    let state = this
 485                        .projects
 486                        .get(&project_id)
 487                        .context("project state not found")?;
 488                    anyhow::Ok(Self::build_project_context(
 489                        &state.project,
 490                        this.prompt_store.as_ref(),
 491                        cx,
 492                    ))
 493                })??
 494                .await;
 495            this.update(cx, |this, cx| {
 496                if let Some(state) = this.projects.get(&project_id) {
 497                    state
 498                        .project_context
 499                        .update(cx, |current_project_context, _cx| {
 500                            *current_project_context = project_context;
 501                        });
 502                }
 503            })?;
 504        }
 505
 506        Ok(())
 507    }
 508
 509    fn build_project_context(
 510        project: &Entity<Project>,
 511        prompt_store: Option<&Entity<PromptStore>>,
 512        cx: &mut App,
 513    ) -> Task<ProjectContext> {
 514        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 515        let worktree_tasks = worktrees
 516            .into_iter()
 517            .map(|worktree| {
 518                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 519            })
 520            .collect::<Vec<_>>();
 521        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 522            prompt_store.read_with(cx, |prompt_store, cx| {
 523                let prompts = prompt_store.default_prompt_metadata();
 524                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 525                    let contents = prompt_store.load(prompt_metadata.id, cx);
 526                    async move { (contents.await, prompt_metadata) }
 527                });
 528                cx.background_spawn(future::join_all(load_tasks))
 529            })
 530        } else {
 531            Task::ready(vec![])
 532        };
 533
 534        cx.spawn(async move |_cx| {
 535            let (worktrees, default_user_rules) =
 536                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 537
 538            let worktrees = worktrees
 539                .into_iter()
 540                .map(|(worktree, _rules_error)| {
 541                    // TODO: show error message
 542                    // if let Some(rules_error) = rules_error {
 543                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 544                    // }
 545                    worktree
 546                })
 547                .collect::<Vec<_>>();
 548
 549            let default_user_rules = default_user_rules
 550                .into_iter()
 551                .flat_map(|(contents, prompt_metadata)| match contents {
 552                    Ok(contents) => Some(UserRulesContext {
 553                        uuid: prompt_metadata.id.as_user()?,
 554                        title: prompt_metadata.title.map(|title| title.to_string()),
 555                        contents,
 556                    }),
 557                    Err(_err) => {
 558                        // TODO: show error message
 559                        // this.update(cx, |_, cx| {
 560                        //     cx.emit(RulesLoadingError {
 561                        //         message: format!("{err:?}").into(),
 562                        //     });
 563                        // })
 564                        // .ok();
 565                        None
 566                    }
 567                })
 568                .collect::<Vec<_>>();
 569
 570            ProjectContext::new(worktrees, default_user_rules)
 571        })
 572    }
 573
 574    fn load_worktree_info_for_system_prompt(
 575        worktree: Entity<Worktree>,
 576        project: Entity<Project>,
 577        cx: &mut App,
 578    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 579        let tree = worktree.read(cx);
 580        let root_name = tree.root_name_str().into();
 581        let abs_path = tree.abs_path();
 582
 583        let mut context = WorktreeContext {
 584            root_name,
 585            abs_path,
 586            rules_file: None,
 587        };
 588
 589        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 590        let Some(rules_task) = rules_task else {
 591            return Task::ready((context, None));
 592        };
 593
 594        cx.spawn(async move |_| {
 595            let (rules_file, rules_file_error) = match rules_task.await {
 596                Ok(rules_file) => (Some(rules_file), None),
 597                Err(err) => (
 598                    None,
 599                    Some(RulesLoadingError {
 600                        message: format!("{err}").into(),
 601                    }),
 602                ),
 603            };
 604            context.rules_file = rules_file;
 605            (context, rules_file_error)
 606        })
 607    }
 608
 609    fn load_worktree_rules_file(
 610        worktree: Entity<Worktree>,
 611        project: Entity<Project>,
 612        cx: &mut App,
 613    ) -> Option<Task<Result<RulesFileContext>>> {
 614        let worktree = worktree.read(cx);
 615        let worktree_id = worktree.id();
 616        let selected_rules_file = RULES_FILE_NAMES
 617            .into_iter()
 618            .filter_map(|name| {
 619                worktree
 620                    .entry_for_path(RelPath::unix(name).unwrap())
 621                    .filter(|entry| entry.is_file())
 622                    .map(|entry| entry.path.clone())
 623            })
 624            .next();
 625
 626        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 627        // supported. This doesn't seem to occur often in GitHub repositories.
 628        selected_rules_file.map(|path_in_worktree| {
 629            let project_path = ProjectPath {
 630                worktree_id,
 631                path: path_in_worktree.clone(),
 632            };
 633            let buffer_task =
 634                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 635            let rope_task = cx.spawn(async move |cx| {
 636                let buffer = buffer_task.await?;
 637                let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
 638                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 639                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 640                })?;
 641                anyhow::Ok((project_entry_id, rope))
 642            });
 643            // Build a string from the rope on a background thread.
 644            cx.background_spawn(async move {
 645                let (project_entry_id, rope) = rope_task.await?;
 646                anyhow::Ok(RulesFileContext {
 647                    path_in_worktree,
 648                    text: rope.to_string().trim().to_string(),
 649                    project_entry_id: project_entry_id.to_usize(),
 650                })
 651            })
 652        })
 653    }
 654
 655    fn handle_thread_title_updated(
 656        &mut self,
 657        thread: Entity<Thread>,
 658        _: &TitleUpdated,
 659        cx: &mut Context<Self>,
 660    ) {
 661        let session_id = thread.read(cx).id();
 662        let Some(session) = self.sessions.get(session_id) else {
 663            return;
 664        };
 665
 666        if let Some(title) = thread.read(cx).title() {
 667            let acp_thread = session.acp_thread.downgrade();
 668            cx.spawn(async move |_, cx| {
 669                let task =
 670                    acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 671                task.await
 672            })
 673            .detach_and_log_err(cx);
 674        }
 675    }
 676
 677    fn handle_thread_token_usage_updated(
 678        &mut self,
 679        thread: Entity<Thread>,
 680        usage: &TokenUsageUpdated,
 681        cx: &mut Context<Self>,
 682    ) {
 683        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 684            return;
 685        };
 686        session.acp_thread.update(cx, |acp_thread, cx| {
 687            acp_thread.update_token_usage(usage.0.clone(), cx);
 688        });
 689    }
 690
 691    fn handle_project_event(
 692        &mut self,
 693        project: Entity<Project>,
 694        event: &project::Event,
 695        _cx: &mut Context<Self>,
 696    ) {
 697        let project_id = project.entity_id();
 698        let Some(state) = self.projects.get_mut(&project_id) else {
 699            return;
 700        };
 701        match event {
 702            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 703                state.project_context_needs_refresh.send(()).ok();
 704            }
 705            project::Event::WorktreeUpdatedEntries(_, items) => {
 706                if items.iter().any(|(path, _, _)| {
 707                    RULES_FILE_NAMES
 708                        .iter()
 709                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 710                }) {
 711                    state.project_context_needs_refresh.send(()).ok();
 712                }
 713            }
 714            _ => {}
 715        }
 716    }
 717
 718    fn handle_prompts_updated_event(
 719        &mut self,
 720        _prompt_store: Entity<PromptStore>,
 721        _event: &prompt_store::PromptsUpdatedEvent,
 722        _cx: &mut Context<Self>,
 723    ) {
 724        for state in self.projects.values_mut() {
 725            state.project_context_needs_refresh.send(()).ok();
 726        }
 727    }
 728
 729    fn handle_models_updated_event(
 730        &mut self,
 731        _registry: Entity<LanguageModelRegistry>,
 732        event: &language_model::Event,
 733        cx: &mut Context<Self>,
 734    ) {
 735        self.models.refresh_list(cx);
 736
 737        let registry = LanguageModelRegistry::read_global(cx);
 738        let default_model = registry.default_model().map(|m| m.model);
 739        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 740
 741        for session in self.sessions.values_mut() {
 742            session.thread.update(cx, |thread, cx| {
 743                if thread.model().is_none()
 744                    && let Some(model) = default_model.clone()
 745                {
 746                    thread.set_model(model, cx);
 747                    cx.notify();
 748                }
 749                if let Some(model) = summarization_model.clone() {
 750                    if thread.summarization_model().is_none()
 751                        || matches!(event, language_model::Event::ThreadSummaryModelChanged)
 752                    {
 753                        thread.set_summarization_model(Some(model), cx);
 754                    }
 755                }
 756            });
 757        }
 758    }
 759
 760    fn handle_context_server_store_updated(
 761        &mut self,
 762        store: Entity<project::context_server_store::ContextServerStore>,
 763        _event: &project::context_server_store::ServerStatusChangedEvent,
 764        cx: &mut Context<Self>,
 765    ) {
 766        let project_id = self.projects.iter().find_map(|(id, state)| {
 767            if *state.context_server_registry.read(cx).server_store() == store {
 768                Some(*id)
 769            } else {
 770                None
 771            }
 772        });
 773        if let Some(project_id) = project_id {
 774            self.update_available_commands_for_project(project_id, cx);
 775        }
 776    }
 777
 778    fn handle_context_server_registry_event(
 779        &mut self,
 780        registry: Entity<ContextServerRegistry>,
 781        event: &ContextServerRegistryEvent,
 782        cx: &mut Context<Self>,
 783    ) {
 784        match event {
 785            ContextServerRegistryEvent::ToolsChanged => {}
 786            ContextServerRegistryEvent::PromptsChanged => {
 787                let project_id = self.projects.iter().find_map(|(id, state)| {
 788                    if state.context_server_registry == registry {
 789                        Some(*id)
 790                    } else {
 791                        None
 792                    }
 793                });
 794                if let Some(project_id) = project_id {
 795                    self.update_available_commands_for_project(project_id, cx);
 796                }
 797            }
 798        }
 799    }
 800
 801    fn update_available_commands_for_project(&self, project_id: EntityId, cx: &mut Context<Self>) {
 802        let available_commands =
 803            Self::build_available_commands_for_project(self.projects.get(&project_id), cx);
 804        for session in self.sessions.values() {
 805            if session.project_id != project_id {
 806                continue;
 807            }
 808            session.acp_thread.update(cx, |thread, cx| {
 809                thread
 810                    .handle_session_update(
 811                        acp::SessionUpdate::AvailableCommandsUpdate(
 812                            acp::AvailableCommandsUpdate::new(available_commands.clone()),
 813                        ),
 814                        cx,
 815                    )
 816                    .log_err();
 817            });
 818        }
 819    }
 820
 821    fn build_available_commands_for_project(
 822        project_state: Option<&ProjectState>,
 823        cx: &App,
 824    ) -> Vec<acp::AvailableCommand> {
 825        let Some(state) = project_state else {
 826            return vec![];
 827        };
 828        let registry = state.context_server_registry.read(cx);
 829
 830        let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
 831        for context_server_prompt in registry.prompts() {
 832            *prompt_name_counts
 833                .entry(context_server_prompt.prompt.name.as_str())
 834                .or_insert(0) += 1;
 835        }
 836
 837        registry
 838            .prompts()
 839            .flat_map(|context_server_prompt| {
 840                let prompt = &context_server_prompt.prompt;
 841
 842                let should_prefix = prompt_name_counts
 843                    .get(prompt.name.as_str())
 844                    .copied()
 845                    .unwrap_or(0)
 846                    > 1;
 847
 848                let name = if should_prefix {
 849                    format!("{}.{}", context_server_prompt.server_id, prompt.name)
 850                } else {
 851                    prompt.name.clone()
 852                };
 853
 854                let mut command = acp::AvailableCommand::new(
 855                    name,
 856                    prompt.description.clone().unwrap_or_default(),
 857                );
 858
 859                match prompt.arguments.as_deref() {
 860                    Some([arg]) => {
 861                        let hint = format!("<{}>", arg.name);
 862
 863                        command = command.input(acp::AvailableCommandInput::Unstructured(
 864                            acp::UnstructuredCommandInput::new(hint),
 865                        ));
 866                    }
 867                    Some([]) | None => {}
 868                    Some(_) => {
 869                        // skip >1 argument commands since we don't support them yet
 870                        return None;
 871                    }
 872                }
 873
 874                Some(command)
 875            })
 876            .collect()
 877    }
 878
 879    pub fn load_thread(
 880        &mut self,
 881        id: acp::SessionId,
 882        project: Entity<Project>,
 883        cx: &mut Context<Self>,
 884    ) -> Task<Result<Entity<Thread>>> {
 885        let database_future = ThreadsDatabase::connect(cx);
 886        cx.spawn(async move |this, cx| {
 887            let database = database_future.await.map_err(|err| anyhow!(err))?;
 888            let db_thread = database
 889                .load_thread(id.clone())
 890                .await?
 891                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 892
 893            this.update(cx, |this, cx| {
 894                let project_id = this.get_or_create_project_state(&project, cx);
 895                let project_state = this
 896                    .projects
 897                    .get(&project_id)
 898                    .context("project state not found")?;
 899                let summarization_model = LanguageModelRegistry::read_global(cx)
 900                    .thread_summary_model()
 901                    .map(|c| c.model);
 902
 903                Ok(cx.new(|cx| {
 904                    let mut thread = Thread::from_db(
 905                        id.clone(),
 906                        db_thread,
 907                        project_state.project.clone(),
 908                        project_state.project_context.clone(),
 909                        project_state.context_server_registry.clone(),
 910                        this.templates.clone(),
 911                        cx,
 912                    );
 913                    thread.set_summarization_model(summarization_model, cx);
 914                    thread
 915                }))
 916            })?
 917        })
 918    }
 919
 920    pub fn open_thread(
 921        &mut self,
 922        id: acp::SessionId,
 923        project: Entity<Project>,
 924        cx: &mut Context<Self>,
 925    ) -> Task<Result<Entity<AcpThread>>> {
 926        if let Some(session) = self.sessions.get(&id) {
 927            return Task::ready(Ok(session.acp_thread.clone()));
 928        }
 929
 930        let task = self.load_thread(id, project.clone(), cx);
 931        cx.spawn(async move |this, cx| {
 932            let thread = task.await?;
 933            let acp_thread = this.update(cx, |this, cx| {
 934                let project_id = this.get_or_create_project_state(&project, cx);
 935                this.register_session(thread.clone(), project_id, cx)
 936            })?;
 937            let events = thread.update(cx, |thread, cx| thread.replay(cx));
 938            cx.update(|cx| {
 939                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 940            })
 941            .await?;
 942            Ok(acp_thread)
 943        })
 944    }
 945
 946    pub fn thread_summary(
 947        &mut self,
 948        id: acp::SessionId,
 949        project: Entity<Project>,
 950        cx: &mut Context<Self>,
 951    ) -> Task<Result<SharedString>> {
 952        let thread = self.open_thread(id.clone(), project, cx);
 953        cx.spawn(async move |this, cx| {
 954            let acp_thread = thread.await?;
 955            let result = this
 956                .update(cx, |this, cx| {
 957                    this.sessions
 958                        .get(&id)
 959                        .unwrap()
 960                        .thread
 961                        .update(cx, |thread, cx| thread.summary(cx))
 962                })?
 963                .await
 964                .context("Failed to generate summary")?;
 965            drop(acp_thread);
 966            Ok(result)
 967        })
 968    }
 969
 970    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 971        if thread.read(cx).is_empty() {
 972            return;
 973        }
 974
 975        let id = thread.read(cx).id().clone();
 976        let Some(session) = self.sessions.get_mut(&id) else {
 977            return;
 978        };
 979
 980        let project_id = session.project_id;
 981        let Some(state) = self.projects.get(&project_id) else {
 982            return;
 983        };
 984
 985        let folder_paths = PathList::new(
 986            &state
 987                .project
 988                .read(cx)
 989                .visible_worktrees(cx)
 990                .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
 991                .collect::<Vec<_>>(),
 992        );
 993
 994        let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
 995        let database_future = ThreadsDatabase::connect(cx);
 996        let db_thread = thread.update(cx, |thread, cx| {
 997            thread.set_draft_prompt(draft_prompt);
 998            thread.to_db(cx)
 999        });
1000        let thread_store = self.thread_store.clone();
1001        session.pending_save = cx.spawn(async move |_, cx| {
1002            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
1003                return Ok(());
1004            };
1005            let db_thread = db_thread.await;
1006            database
1007                .save_thread(id, db_thread, folder_paths)
1008                .await
1009                .log_err();
1010            thread_store.update(cx, |store, cx| store.reload(cx));
1011            Ok(())
1012        });
1013    }
1014
1015    fn send_mcp_prompt(
1016        &self,
1017        message_id: UserMessageId,
1018        session_id: acp::SessionId,
1019        prompt_name: String,
1020        server_id: ContextServerId,
1021        arguments: HashMap<String, String>,
1022        original_content: Vec<acp::ContentBlock>,
1023        cx: &mut Context<Self>,
1024    ) -> Task<Result<acp::PromptResponse>> {
1025        let Some(state) = self.session_project_state(&session_id) else {
1026            return Task::ready(Err(anyhow!("Project state not found for session")));
1027        };
1028        let server_store = state
1029            .context_server_registry
1030            .read(cx)
1031            .server_store()
1032            .clone();
1033        let path_style = state.project.read(cx).path_style(cx);
1034
1035        cx.spawn(async move |this, cx| {
1036            let prompt =
1037                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
1038
1039            let (acp_thread, thread) = this.update(cx, |this, _cx| {
1040                let session = this
1041                    .sessions
1042                    .get(&session_id)
1043                    .context("Failed to get session")?;
1044                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
1045            })??;
1046
1047            let mut last_is_user = true;
1048
1049            thread.update(cx, |thread, cx| {
1050                thread.push_acp_user_block(
1051                    message_id,
1052                    original_content.into_iter().skip(1),
1053                    path_style,
1054                    cx,
1055                );
1056            });
1057
1058            for message in prompt.messages {
1059                let context_server::types::PromptMessage { role, content } = message;
1060                let block = mcp_message_content_to_acp_content_block(content);
1061
1062                match role {
1063                    context_server::types::Role::User => {
1064                        let id = acp_thread::UserMessageId::new();
1065
1066                        acp_thread.update(cx, |acp_thread, cx| {
1067                            acp_thread.push_user_content_block_with_indent(
1068                                Some(id.clone()),
1069                                block.clone(),
1070                                true,
1071                                cx,
1072                            );
1073                        });
1074
1075                        thread.update(cx, |thread, cx| {
1076                            thread.push_acp_user_block(id, [block], path_style, cx);
1077                        });
1078                    }
1079                    context_server::types::Role::Assistant => {
1080                        acp_thread.update(cx, |acp_thread, cx| {
1081                            acp_thread.push_assistant_content_block_with_indent(
1082                                block.clone(),
1083                                false,
1084                                true,
1085                                cx,
1086                            );
1087                        });
1088
1089                        thread.update(cx, |thread, cx| {
1090                            thread.push_acp_agent_block(block, cx);
1091                        });
1092                    }
1093                }
1094
1095                last_is_user = role == context_server::types::Role::User;
1096            }
1097
1098            let response_stream = thread.update(cx, |thread, cx| {
1099                if last_is_user {
1100                    thread.send_existing(cx)
1101                } else {
1102                    // Resume if MCP prompt did not end with a user message
1103                    thread.resume(cx)
1104                }
1105            })?;
1106
1107            cx.update(|cx| {
1108                NativeAgentConnection::handle_thread_events(
1109                    response_stream,
1110                    acp_thread.downgrade(),
1111                    cx,
1112                )
1113            })
1114            .await
1115        })
1116    }
1117}
1118
/// Wrapper struct that implements the AgentConnection trait.
///
/// It holds the [`NativeAgent`] entity that all connection methods delegate to;
/// cloning the connection clones that entity handle.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
1122
1123impl NativeAgentConnection {
1124    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
1125        self.0
1126            .read(cx)
1127            .sessions
1128            .get(session_id)
1129            .map(|session| session.thread.clone())
1130    }
1131
1132    pub fn load_thread(
1133        &self,
1134        id: acp::SessionId,
1135        project: Entity<Project>,
1136        cx: &mut App,
1137    ) -> Task<Result<Entity<Thread>>> {
1138        self.0
1139            .update(cx, |this, cx| this.load_thread(id, project, cx))
1140    }
1141
1142    fn run_turn(
1143        &self,
1144        session_id: acp::SessionId,
1145        cx: &mut App,
1146        f: impl 'static
1147        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1148    ) -> Task<Result<acp::PromptResponse>> {
1149        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1150            agent
1151                .sessions
1152                .get_mut(&session_id)
1153                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1154        }) else {
1155            return Task::ready(Err(anyhow!("Session not found")));
1156        };
1157        log::debug!("Found session for: {}", session_id);
1158
1159        let response_stream = match f(thread, cx) {
1160            Ok(stream) => stream,
1161            Err(err) => return Task::ready(Err(err)),
1162        };
1163        Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1164    }
1165
1166    fn handle_thread_events(
1167        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1168        acp_thread: WeakEntity<AcpThread>,
1169        cx: &App,
1170    ) -> Task<Result<acp::PromptResponse>> {
1171        cx.spawn(async move |cx| {
1172            // Handle response stream and forward to session.acp_thread
1173            while let Some(result) = events.next().await {
1174                match result {
1175                    Ok(event) => {
1176                        log::trace!("Received completion event: {:?}", event);
1177
1178                        match event {
1179                            ThreadEvent::UserMessage(message) => {
1180                                acp_thread.update(cx, |thread, cx| {
1181                                    for content in message.content {
1182                                        thread.push_user_content_block(
1183                                            Some(message.id.clone()),
1184                                            content.into(),
1185                                            cx,
1186                                        );
1187                                    }
1188                                })?;
1189                            }
1190                            ThreadEvent::AgentText(text) => {
1191                                acp_thread.update(cx, |thread, cx| {
1192                                    thread.push_assistant_content_block(text.into(), false, cx)
1193                                })?;
1194                            }
1195                            ThreadEvent::AgentThinking(text) => {
1196                                acp_thread.update(cx, |thread, cx| {
1197                                    thread.push_assistant_content_block(text.into(), true, cx)
1198                                })?;
1199                            }
1200                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1201                                tool_call,
1202                                options,
1203                                response,
1204                                context: _,
1205                            }) => {
1206                                let outcome_task = acp_thread.update(cx, |thread, cx| {
1207                                    thread.request_tool_call_authorization(tool_call, options, cx)
1208                                })??;
1209                                cx.background_spawn(async move {
1210                                    if let acp_thread::RequestPermissionOutcome::Selected(outcome) =
1211                                        outcome_task.await
1212                                    {
1213                                        response
1214                                            .send(outcome)
1215                                            .map(|_| anyhow!("authorization receiver was dropped"))
1216                                            .log_err();
1217                                    }
1218                                })
1219                                .detach();
1220                            }
1221                            ThreadEvent::ToolCall(tool_call) => {
1222                                acp_thread.update(cx, |thread, cx| {
1223                                    thread.upsert_tool_call(tool_call, cx)
1224                                })??;
1225                            }
1226                            ThreadEvent::ToolCallUpdate(update) => {
1227                                acp_thread.update(cx, |thread, cx| {
1228                                    thread.update_tool_call(update, cx)
1229                                })??;
1230                            }
1231                            ThreadEvent::Plan(plan) => {
1232                                acp_thread.update(cx, |thread, cx| thread.update_plan(plan, cx))?;
1233                            }
1234                            ThreadEvent::SubagentSpawned(session_id) => {
1235                                acp_thread.update(cx, |thread, cx| {
1236                                    thread.subagent_spawned(session_id, cx);
1237                                })?;
1238                            }
1239                            ThreadEvent::Retry(status) => {
1240                                acp_thread.update(cx, |thread, cx| {
1241                                    thread.update_retry_status(status, cx)
1242                                })?;
1243                            }
1244                            ThreadEvent::Stop(stop_reason) => {
1245                                log::debug!("Assistant message complete: {:?}", stop_reason);
1246                                return Ok(acp::PromptResponse::new(stop_reason));
1247                            }
1248                        }
1249                    }
1250                    Err(e) => {
1251                        log::error!("Error in model response stream: {:?}", e);
1252                        return Err(e);
1253                    }
1254                }
1255            }
1256
1257            log::debug!("Response stream completed");
1258            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1259        })
1260    }
1261}
1262
/// A `/command` invocation parsed from the first text block of a prompt.
struct Command<'a> {
    /// Name of the prompt to run, without any leading `server_id.` qualifier.
    prompt_name: &'a str,
    /// Everything after the first whitespace following the command; empty
    /// when the command was given on its own.
    arg_value: &'a str,
    /// Server id written before the first `.` of the command, if any.
    explicit_server_id: Option<&'a str>,
}
1268
1269impl<'a> Command<'a> {
1270    fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1271        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1272            return None;
1273        };
1274        let text = text_content.text.trim();
1275        let command = text.strip_prefix('/')?;
1276        let (command, arg_value) = command
1277            .split_once(char::is_whitespace)
1278            .unwrap_or((command, ""));
1279
1280        if let Some((server_id, prompt_name)) = command.split_once('.') {
1281            Some(Self {
1282                prompt_name,
1283                arg_value,
1284                explicit_server_id: Some(server_id),
1285            })
1286        } else {
1287            Some(Self {
1288                prompt_name: command,
1289                arg_value,
1290                explicit_server_id: None,
1291            })
1292        }
1293    }
1294}
1295
/// Model selector for a single session of the native agent.
struct NativeAgentModelSelector {
    /// Session whose thread the selected model is read from and applied to.
    session_id: acp::SessionId,
    /// Connection used to reach the agent's sessions and model list.
    connection: NativeAgentConnection,
}
1300
1301impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1302    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1303        log::debug!("NativeAgentConnection::list_models called");
1304        let list = self.connection.0.read(cx).models.model_list.clone();
1305        Task::ready(if list.is_empty() {
1306            Err(anyhow::anyhow!("No models available"))
1307        } else {
1308            Ok(list)
1309        })
1310    }
1311
1312    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1313        log::debug!(
1314            "Setting model for session {}: {}",
1315            self.session_id,
1316            model_id
1317        );
1318        let Some(thread) = self
1319            .connection
1320            .0
1321            .read(cx)
1322            .sessions
1323            .get(&self.session_id)
1324            .map(|session| session.thread.clone())
1325        else {
1326            return Task::ready(Err(anyhow!("Session not found")));
1327        };
1328
1329        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1330            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1331        };
1332
1333        // We want to reset the effort level when switching models, as the currently-selected effort level may
1334        // not be compatible.
1335        let effort = model
1336            .default_effort_level()
1337            .map(|effort_level| effort_level.value.to_string());
1338
1339        thread.update(cx, |thread, cx| {
1340            thread.set_model(model.clone(), cx);
1341            thread.set_thinking_effort(effort.clone(), cx);
1342            thread.set_thinking_enabled(model.supports_thinking(), cx);
1343        });
1344
1345        update_settings_file(
1346            self.connection.0.read(cx).fs.clone(),
1347            cx,
1348            move |settings, cx| {
1349                let provider = model.provider_id().0.to_string();
1350                let model = model.id().0.to_string();
1351                let enable_thinking = thread.read(cx).thinking_enabled();
1352                settings
1353                    .agent
1354                    .get_or_insert_default()
1355                    .set_model(LanguageModelSelection {
1356                        provider: provider.into(),
1357                        model,
1358                        enable_thinking,
1359                        effort,
1360                    });
1361            },
1362        );
1363
1364        Task::ready(Ok(()))
1365    }
1366
1367    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1368        let Some(thread) = self
1369            .connection
1370            .0
1371            .read(cx)
1372            .sessions
1373            .get(&self.session_id)
1374            .map(|session| session.thread.clone())
1375        else {
1376            return Task::ready(Err(anyhow!("Session not found")));
1377        };
1378        let Some(model) = thread.read(cx).model() else {
1379            return Task::ready(Err(anyhow!("Model not found")));
1380        };
1381        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1382        else {
1383            return Task::ready(Err(anyhow!("Provider not found")));
1384        };
1385        Task::ready(Ok(LanguageModels::map_language_model_to_info(
1386            model, &provider,
1387        )))
1388    }
1389
1390    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1391        Some(self.connection.0.read(cx).models.watch())
1392    }
1393
1394    fn should_render_footer(&self) -> bool {
1395        true
1396    }
1397}
1398
/// Agent id of the native agent ("Zed Agent"), created lazily on first use.
pub static ZED_AGENT_ID: LazyLock<AgentId> = LazyLock::new(|| AgentId::new("Zed Agent"));
1400
1401impl acp_thread::AgentConnection for NativeAgentConnection {
1402    fn agent_id(&self) -> AgentId {
1403        ZED_AGENT_ID.clone()
1404    }
1405
1406    fn telemetry_id(&self) -> SharedString {
1407        "zed".into()
1408    }
1409
1410    fn new_session(
1411        self: Rc<Self>,
1412        project: Entity<Project>,
1413        work_dirs: PathList,
1414        cx: &mut App,
1415    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1416        log::debug!("Creating new thread for project at: {work_dirs:?}");
1417        Task::ready(Ok(self
1418            .0
1419            .update(cx, |agent, cx| agent.new_session(project, cx))))
1420    }
1421
1422    fn supports_load_session(&self) -> bool {
1423        true
1424    }
1425
1426    fn load_session(
1427        self: Rc<Self>,
1428        session_id: acp::SessionId,
1429        project: Entity<Project>,
1430        _work_dirs: PathList,
1431        _title: Option<SharedString>,
1432        cx: &mut App,
1433    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1434        self.0
1435            .update(cx, |agent, cx| agent.open_thread(session_id, project, cx))
1436    }
1437
1438    fn supports_close_session(&self) -> bool {
1439        true
1440    }
1441
1442    fn close_session(
1443        self: Rc<Self>,
1444        session_id: &acp::SessionId,
1445        cx: &mut App,
1446    ) -> Task<Result<()>> {
1447        self.0.update(cx, |agent, cx| {
1448            let thread = agent.sessions.get(session_id).map(|s| s.thread.clone());
1449            if let Some(thread) = thread {
1450                agent.save_thread(thread, cx);
1451            }
1452
1453            let Some(session) = agent.sessions.remove(session_id) else {
1454                return Task::ready(Ok(()));
1455            };
1456            let project_id = session.project_id;
1457
1458            let has_remaining = agent.sessions.values().any(|s| s.project_id == project_id);
1459            if !has_remaining {
1460                agent.projects.remove(&project_id);
1461            }
1462
1463            session.pending_save
1464        })
1465    }
1466
1467    fn auth_methods(&self) -> &[acp::AuthMethod] {
1468        &[] // No auth for in-process
1469    }
1470
1471    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1472        Task::ready(Ok(()))
1473    }
1474
1475    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1476        Some(Rc::new(NativeAgentModelSelector {
1477            session_id: session_id.clone(),
1478            connection: self.clone(),
1479        }) as Rc<dyn AgentModelSelector>)
1480    }
1481
1482    fn prompt(
1483        &self,
1484        id: Option<acp_thread::UserMessageId>,
1485        params: acp::PromptRequest,
1486        cx: &mut App,
1487    ) -> Task<Result<acp::PromptResponse>> {
1488        let id = id.expect("UserMessageId is required");
1489        let session_id = params.session_id.clone();
1490        log::info!("Received prompt request for session: {}", session_id);
1491        log::debug!("Prompt blocks count: {}", params.prompt.len());
1492
1493        let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
1494            return Task::ready(Err(anyhow::anyhow!("Session not found")));
1495        };
1496
1497        if let Some(parsed_command) = Command::parse(&params.prompt) {
1498            let registry = project_state.context_server_registry.read(cx);
1499
1500            let explicit_server_id = parsed_command
1501                .explicit_server_id
1502                .map(|server_id| ContextServerId(server_id.into()));
1503
1504            if let Some(prompt) =
1505                registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1506            {
1507                let arguments = if !parsed_command.arg_value.is_empty()
1508                    && let Some(arg_name) = prompt
1509                        .prompt
1510                        .arguments
1511                        .as_ref()
1512                        .and_then(|args| args.first())
1513                        .map(|arg| arg.name.clone())
1514                {
1515                    HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1516                } else {
1517                    Default::default()
1518                };
1519
1520                let prompt_name = prompt.prompt.name.clone();
1521                let server_id = prompt.server_id.clone();
1522
1523                return self.0.update(cx, |agent, cx| {
1524                    agent.send_mcp_prompt(
1525                        id,
1526                        session_id.clone(),
1527                        prompt_name,
1528                        server_id,
1529                        arguments,
1530                        params.prompt,
1531                        cx,
1532                    )
1533                });
1534            }
1535        };
1536
1537        let path_style = project_state.project.read(cx).path_style(cx);
1538
1539        self.run_turn(session_id, cx, move |thread, cx| {
1540            let content: Vec<UserMessageContent> = params
1541                .prompt
1542                .into_iter()
1543                .map(|block| UserMessageContent::from_content_block(block, path_style))
1544                .collect::<Vec<_>>();
1545            log::debug!("Converted prompt to message: {} chars", content.len());
1546            log::debug!("Message id: {:?}", id);
1547            log::debug!("Message content: {:?}", content);
1548
1549            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1550        })
1551    }
1552
1553    fn retry(
1554        &self,
1555        session_id: &acp::SessionId,
1556        _cx: &App,
1557    ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1558        Some(Rc::new(NativeAgentSessionRetry {
1559            connection: self.clone(),
1560            session_id: session_id.clone(),
1561        }) as _)
1562    }
1563
1564    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1565        log::info!("Cancelling on session: {}", session_id);
1566        self.0.update(cx, |agent, cx| {
1567            if let Some(session) = agent.sessions.get(session_id) {
1568                session
1569                    .thread
1570                    .update(cx, |thread, cx| thread.cancel(cx))
1571                    .detach();
1572            }
1573        });
1574    }
1575
1576    fn truncate(
1577        &self,
1578        session_id: &acp::SessionId,
1579        cx: &App,
1580    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1581        self.0.read_with(cx, |agent, _cx| {
1582            agent.sessions.get(session_id).map(|session| {
1583                Rc::new(NativeAgentSessionTruncate {
1584                    thread: session.thread.clone(),
1585                    acp_thread: session.acp_thread.downgrade(),
1586                }) as _
1587            })
1588        })
1589    }
1590
1591    fn set_title(
1592        &self,
1593        session_id: &acp::SessionId,
1594        cx: &App,
1595    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1596        self.0.read_with(cx, |agent, _cx| {
1597            agent
1598                .sessions
1599                .get(session_id)
1600                .filter(|s| !s.thread.read(cx).is_subagent())
1601                .map(|session| {
1602                    Rc::new(NativeAgentSessionSetTitle {
1603                        thread: session.thread.clone(),
1604                    }) as _
1605                })
1606        })
1607    }
1608
1609    fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1610        let thread_store = self.0.read(cx).thread_store.clone();
1611        Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1612    }
1613
1614    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1615        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1616    }
1617
    /// Converts this reference-counted connection into an `Rc<dyn Any>`.
    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
        self
    }
1621}
1622
1623impl acp_thread::AgentTelemetry for NativeAgentConnection {
1624    fn thread_data(
1625        &self,
1626        session_id: &acp::SessionId,
1627        cx: &mut App,
1628    ) -> Task<Result<serde_json::Value>> {
1629        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1630            return Task::ready(Err(anyhow!("Session not found")));
1631        };
1632
1633        let task = session.thread.read(cx).to_db(cx);
1634        cx.background_spawn(async move {
1635            serde_json::to_value(task.await).context("Failed to serialize thread")
1636        })
1637    }
1638}
1639
/// Session list backed by a [`ThreadStore`].
///
/// Besides listing and deleting sessions, it owns a channel on which
/// `SessionListUpdate::Refresh` is pushed whenever the store notifies its observers
/// or `notify_refresh` is called.
pub struct NativeAgentSessionList {
    /// Store whose entries are exposed as sessions.
    thread_store: Entity<ThreadStore>,
    /// Sending half of the update channel; used by `notify_refresh`.
    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
    /// Receiving half of the update channel; clones are handed out by `watch`.
    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
    /// Handle returned by `cx.observe` in `new`; stored for the lifetime of the list.
    _subscription: Subscription,
}
1646
1647impl NativeAgentSessionList {
1648    fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1649        let (tx, rx) = smol::channel::unbounded();
1650        let this_tx = tx.clone();
1651        let subscription = cx.observe(&thread_store, move |_, _| {
1652            this_tx
1653                .try_send(acp_thread::SessionListUpdate::Refresh)
1654                .ok();
1655        });
1656        Self {
1657            thread_store,
1658            updates_tx: tx,
1659            updates_rx: rx,
1660            _subscription: subscription,
1661        }
1662    }
1663
1664    pub fn thread_store(&self) -> &Entity<ThreadStore> {
1665        &self.thread_store
1666    }
1667}
1668
1669impl AgentSessionList for NativeAgentSessionList {
1670    fn list_sessions(
1671        &self,
1672        _request: AgentSessionListRequest,
1673        cx: &mut App,
1674    ) -> Task<Result<AgentSessionListResponse>> {
1675        let sessions = self
1676            .thread_store
1677            .read(cx)
1678            .entries()
1679            .map(|entry| AgentSessionInfo::from(&entry))
1680            .collect();
1681        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1682    }
1683
1684    fn supports_delete(&self) -> bool {
1685        true
1686    }
1687
1688    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1689        self.thread_store
1690            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1691    }
1692
1693    fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1694        self.thread_store
1695            .update(cx, |store, cx| store.delete_threads(cx))
1696    }
1697
1698    fn watch(
1699        &self,
1700        _cx: &mut App,
1701    ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1702        Some(self.updates_rx.clone())
1703    }
1704
1705    fn notify_refresh(&self) {
1706        self.updates_tx
1707            .try_send(acp_thread::SessionListUpdate::Refresh)
1708            .ok();
1709    }
1710
1711    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1712        self
1713    }
1714}
1715
/// Truncate handle for a single session, created by `NativeAgentConnection::truncate`.
struct NativeAgentSessionTruncate {
    /// Thread that is truncated.
    thread: Entity<Thread>,
    /// Weak reference to the session's `AcpThread`, whose token usage is updated after a
    /// successful truncation.
    acp_thread: WeakEntity<AcpThread>,
}
1720
1721impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1722    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1723        match self.thread.update(cx, |thread, cx| {
1724            thread.truncate(message_id.clone(), cx)?;
1725            Ok(thread.latest_token_usage())
1726        }) {
1727            Ok(usage) => {
1728                self.acp_thread
1729                    .update(cx, |thread, cx| {
1730                        thread.update_token_usage(usage, cx);
1731                    })
1732                    .ok();
1733                Task::ready(Ok(()))
1734            }
1735            Err(error) => Task::ready(Err(error)),
1736        }
1737    }
1738}
1739
/// Retry handle for a single session, created by `NativeAgentConnection::retry`.
struct NativeAgentSessionRetry {
    /// Connection used to run the retried turn.
    connection: NativeAgentConnection,
    /// Session whose thread is resumed.
    session_id: acp::SessionId,
}
1744
1745impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1746    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1747        self.connection
1748            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1749                thread.update(cx, |thread, cx| thread.resume(cx))
1750            })
1751    }
1752}
1753
/// Set-title handle for a single session, created by `NativeAgentConnection::set_title`.
struct NativeAgentSessionSetTitle {
    /// Thread whose title is changed.
    thread: Entity<Thread>,
}
1757
1758impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1759    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1760        self.thread
1761            .update(cx, |thread, cx| thread.set_title(title, cx));
1762        Task::ready(Ok(()))
1763    }
1764}
1765
/// Environment through which a thread creates terminals and subagents.
///
/// All three references are weak, so this type does not keep the agent or threads alive.
pub struct NativeThreadEnvironment {
    /// Agent that owns the session registry; used to register and look up subagent sessions.
    agent: WeakEntity<NativeAgent>,
    /// The thread this environment belongs to; it becomes the parent of created subagents.
    thread: WeakEntity<Thread>,
    /// `AcpThread` on which terminals are created and released.
    acp_thread: WeakEntity<AcpThread>,
}
1771
1772impl NativeThreadEnvironment {
1773    pub(crate) fn create_subagent_thread(
1774        &self,
1775        label: String,
1776        cx: &mut App,
1777    ) -> Result<Rc<dyn SubagentHandle>> {
1778        let Some(parent_thread_entity) = self.thread.upgrade() else {
1779            anyhow::bail!("Parent thread no longer exists".to_string());
1780        };
1781        let parent_thread = parent_thread_entity.read(cx);
1782        let current_depth = parent_thread.depth();
1783        let parent_session_id = parent_thread.id().clone();
1784
1785        if current_depth >= MAX_SUBAGENT_DEPTH {
1786            return Err(anyhow!(
1787                "Maximum subagent depth ({}) reached",
1788                MAX_SUBAGENT_DEPTH
1789            ));
1790        }
1791
1792        let subagent_thread: Entity<Thread> = cx.new(|cx| {
1793            let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1794            thread.set_title(label.into(), cx);
1795            thread
1796        });
1797
1798        let session_id = subagent_thread.read(cx).id().clone();
1799
1800        let acp_thread = self
1801            .agent
1802            .update(cx, |agent, cx| -> Result<Entity<AcpThread>> {
1803                let project_id = agent
1804                    .sessions
1805                    .get(&parent_session_id)
1806                    .map(|s| s.project_id)
1807                    .context("parent session not found")?;
1808                Ok(agent.register_session(subagent_thread.clone(), project_id, cx))
1809            })??;
1810
1811        let depth = current_depth + 1;
1812
1813        telemetry::event!(
1814            "Subagent Started",
1815            session = parent_thread_entity.read(cx).id().to_string(),
1816            subagent_session = session_id.to_string(),
1817            depth,
1818            is_resumed = false,
1819        );
1820
1821        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1822    }
1823
1824    pub(crate) fn resume_subagent_thread(
1825        &self,
1826        session_id: acp::SessionId,
1827        cx: &mut App,
1828    ) -> Result<Rc<dyn SubagentHandle>> {
1829        let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1830            let session = agent
1831                .sessions
1832                .get(&session_id)
1833                .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1834            anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1835        })??;
1836
1837        let depth = subagent_thread.read(cx).depth();
1838
1839        if let Some(parent_thread_entity) = self.thread.upgrade() {
1840            telemetry::event!(
1841                "Subagent Started",
1842                session = parent_thread_entity.read(cx).id().to_string(),
1843                subagent_session = session_id.to_string(),
1844                depth,
1845                is_resumed = true,
1846            );
1847        }
1848
1849        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1850    }
1851
1852    fn prompt_subagent(
1853        &self,
1854        session_id: acp::SessionId,
1855        subagent_thread: Entity<Thread>,
1856        acp_thread: Entity<acp_thread::AcpThread>,
1857    ) -> Result<Rc<dyn SubagentHandle>> {
1858        let Some(parent_thread_entity) = self.thread.upgrade() else {
1859            anyhow::bail!("Parent thread no longer exists".to_string());
1860        };
1861        Ok(Rc::new(NativeSubagentHandle::new(
1862            session_id,
1863            subagent_thread,
1864            acp_thread,
1865            parent_thread_entity,
1866        )) as _)
1867    }
1868}
1869
1870impl ThreadEnvironment for NativeThreadEnvironment {
1871    fn create_terminal(
1872        &self,
1873        command: String,
1874        cwd: Option<PathBuf>,
1875        output_byte_limit: Option<u64>,
1876        cx: &mut AsyncApp,
1877    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1878        let task = self.acp_thread.update(cx, |thread, cx| {
1879            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1880        });
1881
1882        let acp_thread = self.acp_thread.clone();
1883        cx.spawn(async move |cx| {
1884            let terminal = task?.await?;
1885
1886            let (drop_tx, drop_rx) = oneshot::channel();
1887            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1888
1889            cx.spawn(async move |cx| {
1890                drop_rx.await.ok();
1891                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1892            })
1893            .detach();
1894
1895            let handle = AcpTerminalHandle {
1896                terminal,
1897                _drop_tx: Some(drop_tx),
1898            };
1899
1900            Ok(Rc::new(handle) as _)
1901        })
1902    }
1903
1904    fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1905        self.create_subagent_thread(label, cx)
1906    }
1907
1908    fn resume_subagent(
1909        &self,
1910        session_id: acp::SessionId,
1911        cx: &mut App,
1912    ) -> Result<Rc<dyn SubagentHandle>> {
1913        self.resume_subagent_thread(session_id, cx)
1914    }
1915}
1916
/// Outcome of a single subagent prompt, as classified by `NativeSubagentHandle::send`.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The prompt finished with a stop reason that is not treated as an error or a
    /// cancellation.
    Completed,
    /// The prompt stopped with `StopReason::Cancelled`.
    Cancelled,
    /// The token-limit signal fired before the prompt finished.
    ContextWindowWarning,
    /// The prompt failed; the string is the message reported to the caller.
    Error(String),
}
1924
/// Handle used to prompt a subagent session.
pub struct NativeSubagentHandle {
    /// Session id of the subagent.
    session_id: acp::SessionId,
    /// Weak reference to the parent thread, on which running subagents are
    /// registered and unregistered.
    parent_thread: WeakEntity<Thread>,
    /// The subagent's thread.
    subagent_thread: Entity<Thread>,
    /// The subagent's `AcpThread`, through which messages are sent.
    acp_thread: Entity<acp_thread::AcpThread>,
}
1931
1932impl NativeSubagentHandle {
1933    fn new(
1934        session_id: acp::SessionId,
1935        subagent_thread: Entity<Thread>,
1936        acp_thread: Entity<acp_thread::AcpThread>,
1937        parent_thread_entity: Entity<Thread>,
1938    ) -> Self {
1939        NativeSubagentHandle {
1940            session_id,
1941            subagent_thread,
1942            parent_thread: parent_thread_entity.downgrade(),
1943            acp_thread,
1944        }
1945    }
1946}
1947
impl SubagentHandle for NativeSubagentHandle {
    /// Returns a clone of the subagent's session id.
    fn id(&self) -> acp::SessionId {
        self.session_id.clone()
    }

    /// Returns the number of entries currently in the subagent's `AcpThread`.
    fn num_entries(&self, cx: &App) -> usize {
        self.acp_thread.read(cx).entries().len()
    }

    /// Sends `message` to the subagent and resolves to the text of its last agent message.
    ///
    /// While the prompt runs, the subagent is registered as running on the parent thread
    /// (if the parent still exists); it is unregistered after the prompt outcome has been
    /// handled.
    ///
    /// # Errors
    ///
    /// Resolves to an error when the prompt is cancelled, stops for one of the error stop
    /// reasons, yields no response, fails outright, produces no text in the last agent
    /// message, or when the token-limit signal fires first (in which case the subagent
    /// thread is cancelled).
    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
        let thread = self.subagent_thread.clone();
        let acp_thread = self.acp_thread.clone();
        let subagent_session_id = self.session_id.clone();
        let parent_thread = self.parent_thread.clone();

        cx.spawn(async move |cx| {
            // Binding `_subscription` here keeps the subscription value alive until this
            // async block finishes.
            let (task, _subscription) = cx.update(|cx| {
                // Usage ratio before the prompt; `None` when no usage has been recorded.
                let ratio_before_prompt = thread
                    .read(cx)
                    .latest_token_usage()
                    .map(|usage| usage.ratio());

                // Best effort: a missing parent thread is ignored.
                parent_thread
                    .update(cx, |parent_thread, _cx| {
                        parent_thread.register_running_subagent(thread.downgrade())
                    })
                    .ok();

                let task = acp_thread.update(cx, |acp_thread, cx| {
                    acp_thread.send(vec![message.into()], cx)
                });

                // The sender is wrapped in an `Option` so the signal is sent at most once.
                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
                let mut token_limit_tx = Some(token_limit_tx);

                // NOTE(review): only a transition from `Normal` (before the prompt) to
                // exactly `Warning` signals; any other new ratio does not — confirm that
                // this is intended.
                let subscription = cx.subscribe(
                    &thread,
                    move |_thread, event: &TokenUsageUpdated, _cx| {
                        if let Some(usage) = &event.0 {
                            let old_ratio = ratio_before_prompt
                                .clone()
                                .unwrap_or(TokenUsageRatio::Normal);
                            let new_ratio = usage.ratio();
                            if old_ratio == TokenUsageRatio::Normal
                                && new_ratio == TokenUsageRatio::Warning
                            {
                                if let Some(tx) = token_limit_tx.take() {
                                    tx.send(()).ok();
                                }
                            }
                        }
                    },
                );

                // Resolves with whichever happens first: the prompt finishing or the
                // token-limit signal.
                let wait_for_prompt = cx
                    .background_spawn(async move {
                        futures::select! {
                            response = task.fuse() => match response {
                                Ok(Some(response)) => {
                                    match response.stop_reason {
                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
                                    }
                                }
                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
                                Err(error) => SubagentPromptResult::Error(error.to_string()),
                            },
                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
                        }
                    });

                (wait_for_prompt, subscription)
            });

            let result = match task.await {
                // Join the text parts of the last agent message with blank lines; a
                // missing agent message or empty text is reported as an error.
                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
                    thread
                        .last_message()
                        .and_then(|message| {
                            let content = message.as_agent_message()?
                                .content
                                .iter()
                                .filter_map(|c| match c {
                                    AgentMessageContent::Text(text) => Some(text.as_str()),
                                    _ => None,
                                })
                                .join("\n\n");
                            if content.is_empty() {
                                None
                            } else {
                                Some( content)
                            }
                        })
                        .context("No response from subagent")
                }),
                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
                SubagentPromptResult::ContextWindowWarning => {
                    // The prompt is still in flight here, so cancel the thread and wait
                    // for the cancellation before reporting.
                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
                    Err(anyhow!(
                        "The agent is nearing the end of its context window and has been \
                         stopped. You can prompt the thread again to have the agent wrap up \
                         or hand off its work."
                    ))
                }
            };

            // Best effort: a missing parent thread is ignored.
            parent_thread
                .update(cx, |parent_thread, cx| {
                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
                })
                .ok();

            result
        })
    }
}
2068
/// Handle to a terminal created on an `AcpThread`.
pub struct AcpTerminalHandle {
    /// The terminal entity all handle methods operate on.
    terminal: Entity<acp_thread::Terminal>,
    /// Never sent on in the code visible here; its matching receiver is awaited by the
    /// task spawned in `NativeThreadEnvironment::create_terminal`, which then releases
    /// the terminal.
    _drop_tx: Option<oneshot::Sender<()>>,
}
2073
2074impl TerminalHandle for AcpTerminalHandle {
2075    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
2076        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
2077    }
2078
2079    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
2080        Ok(self
2081            .terminal
2082            .read_with(cx, |term, _cx| term.wait_for_exit()))
2083    }
2084
2085    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
2086        Ok(self
2087            .terminal
2088            .read_with(cx, |term, cx| term.current_output(cx)))
2089    }
2090
2091    fn kill(&self, cx: &AsyncApp) -> Result<()> {
2092        cx.update(|cx| {
2093            self.terminal.update(cx, |terminal, cx| {
2094                terminal.kill(cx);
2095            });
2096        });
2097        Ok(())
2098    }
2099
2100    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
2101        Ok(self
2102            .terminal
2103            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
2104    }
2105}
2106
2107#[cfg(test)]
2108mod internal_tests {
2109    use std::path::Path;
2110
2111    use super::*;
2112    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
2113    use fs::FakeFs;
2114    use gpui::TestAppContext;
2115    use indoc::formatdoc;
2116    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2117    use language_model::{
2118        LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
2119    };
2120    use serde_json::json;
2121    use settings::SettingsStore;
2122    use util::{path, rel_path::rel_path};
2123
    /// Checks that the agent's project context, and the one seen by the session's thread,
    /// follow the project: empty at first, then listing a newly created worktree, then
    /// including that worktree's `.rules` file once it exists.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        // The project starts without any worktrees.
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));

        // Creating a session registers the project and triggers context building.
        let connection = NativeAgentConnection(agent.clone());
        let _acp_thread = cx
            .update(|cx| {
                Rc::new(connection).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/")]),
                    cx,
                )
            })
            .await
            .unwrap();
        cx.run_until_parked();

        // Only one session exists, so the first value is the session created above.
        let thread = agent.read_with(cx, |agent, _cx| {
            agent.sessions.values().next().unwrap().thread.clone()
        });

        // With no worktrees, both views of the project context are empty.
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            assert_eq!(state.project_context.read(cx).worktrees, vec![]);
            assert_eq!(thread.read(cx).project_context().read(cx).worktrees, vec![]);
        });

        // Adding a worktree for `/a` shows up in the project context.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            let expected_worktrees = vec![WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file: None,
            }];
            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
            assert_eq!(
                thread.read(cx).project_context().read(cx).worktrees,
                expected_worktrees
            );
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            // The file was inserted empty, so the expected rules text is empty too.
            let expected_worktrees = vec![WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file: Some(RulesFileContext {
                    path_in_worktree: rel_path(".rules").into(),
                    text: "".into(),
                    project_entry_id: rules_entry.id.to_usize(),
                }),
            }];
            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
            assert_eq!(
                thread.read(cx).project_context().read(cx).worktrees,
                expected_worktrees
            );
        });
    }
2211
    /// Checks that the model selector of a new session lists exactly one group, "Fake",
    /// containing the single `fake/fake` model.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let connection =
            NativeAgentConnection(cx.update(|cx| {
                NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)
            }));

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        // The list is expected to be grouped; any other shape fails the test.
        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(acp_thread::AgentModelIcon::Named(
                        ui::IconName::ZedAssistant
                    )),
                    is_latest: false,
                    cost: None,
                }]
            )])
        );
    }
2268
    /// Checks that selecting a model through the session's model selector updates the
    /// thread's model and rewrites `agent.default_model` in the settings file, and that
    /// selecting a thinking model also writes `enable_thinking: true`.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed the settings file with a default model that differs from the one selected
        // below, so the later assertions can only pass if the file was rewritten.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));

        // Create the agent and connection
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let pending work, including the settings write, run to completion.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );

        // Register a thinking model and select it.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list after registering the new provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // Verify enable_thinking was written to settings as true.
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
        assert_eq!(
            settings_json["agent"]["default_model"]["enable_thinking"],
            json!(true),
            "selecting a thinking model should persist enable_thinking: true to settings"
        );
    }
2380
2381    #[gpui::test]
2382    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2383        init_test(cx);
2384        let fs = FakeFs::new(cx.executor());
2385        fs.create_dir(paths::settings_file().parent().unwrap())
2386            .await
2387            .unwrap();
2388        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2389        let project = Project::test(fs.clone(), [], cx).await;
2390
2391        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2392        let agent =
2393            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2394        let connection = NativeAgentConnection(agent.clone());
2395
2396        let acp_thread = cx
2397            .update(|cx| {
2398                Rc::new(connection.clone()).new_session(
2399                    project.clone(),
2400                    PathList::new(&[Path::new("/a")]),
2401                    cx,
2402                )
2403            })
2404            .await
2405            .unwrap();
2406        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2407
2408        // Register a second provider with a thinking model.
2409        cx.update(|cx| {
2410            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2411                "fake-corp",
2412                "fake-thinking",
2413                "Fake Thinking",
2414                true,
2415            ));
2416            let thinking_provider = Arc::new(
2417                FakeLanguageModelProvider::new(
2418                    LanguageModelProviderId::from("fake-corp".to_string()),
2419                    LanguageModelProviderName::from("Fake Corp".to_string()),
2420                )
2421                .with_models(vec![thinking_model]),
2422            );
2423            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2424                registry.register_provider(thinking_provider, cx);
2425            });
2426        });
2427        // Refresh the agent's model list so it picks up the new provider.
2428        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2429
2430        // Thread starts with thinking_enabled = false (the default).
2431        agent.read_with(cx, |agent, _| {
2432            let session = agent.sessions.get(&session_id).unwrap();
2433            session.thread.read_with(cx, |thread, _| {
2434                assert!(!thread.thinking_enabled(), "thinking defaults to false");
2435            });
2436        });
2437
2438        // Select the thinking model via select_model.
2439        let selector = connection.model_selector(&session_id).unwrap();
2440        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2441            .await
2442            .unwrap();
2443
2444        // select_model should have enabled thinking based on the model's supports_thinking().
2445        agent.read_with(cx, |agent, _| {
2446            let session = agent.sessions.get(&session_id).unwrap();
2447            session.thread.read_with(cx, |thread, _| {
2448                assert!(
2449                    thread.thinking_enabled(),
2450                    "select_model should enable thinking when model supports it"
2451                );
2452            });
2453        });
2454
2455        // Switch back to the non-thinking model.
2456        let selector = connection.model_selector(&session_id).unwrap();
2457        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2458            .await
2459            .unwrap();
2460
2461        // select_model should have disabled thinking.
2462        agent.read_with(cx, |agent, _| {
2463            let session = agent.sessions.get(&session_id).unwrap();
2464            session.thread.read_with(cx, |thread, _| {
2465                assert!(
2466                    !thread.thinking_enabled(),
2467                    "select_model should disable thinking when model does not support it"
2468                );
2469            });
2470        });
2471    }
2472
2473    #[gpui::test]
2474    async fn test_summarization_model_survives_transient_registry_clearing(
2475        cx: &mut TestAppContext,
2476    ) {
2477        init_test(cx);
2478        let fs = FakeFs::new(cx.executor());
2479        fs.insert_tree("/", json!({ "a": {} })).await;
2480        let project = Project::test(fs.clone(), [], cx).await;
2481
2482        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2483        let agent =
2484            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2485        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2486
2487        let acp_thread = cx
2488            .update(|cx| {
2489                connection.clone().new_session(
2490                    project.clone(),
2491                    PathList::new(&[Path::new("/a")]),
2492                    cx,
2493                )
2494            })
2495            .await
2496            .unwrap();
2497        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2498
2499        let thread = agent.read_with(cx, |agent, _| {
2500            agent.sessions.get(&session_id).unwrap().thread.clone()
2501        });
2502
2503        thread.read_with(cx, |thread, _| {
2504            assert!(
2505                thread.summarization_model().is_some(),
2506                "session should have a summarization model from the test registry"
2507            );
2508        });
2509
2510        // Simulate what happens during a provider blip:
2511        // update_active_language_model_from_settings calls set_default_model(None)
2512        // when it can't resolve the model, clearing all fallbacks.
2513        cx.update(|cx| {
2514            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2515                registry.set_default_model(None, cx);
2516            });
2517        });
2518        cx.run_until_parked();
2519
2520        thread.read_with(cx, |thread, _| {
2521            assert!(
2522                thread.summarization_model().is_some(),
2523                "summarization model should survive a transient default model clearing"
2524            );
2525        });
2526    }
2527
2528    #[gpui::test]
2529    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
2530        init_test(cx);
2531        let fs = FakeFs::new(cx.executor());
2532        fs.insert_tree("/", json!({ "a": {} })).await;
2533        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2534        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2535        let agent = cx.update(|cx| {
2536            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
2537        });
2538        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2539
2540        // Register a thinking model.
2541        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2542            "fake-corp",
2543            "fake-thinking",
2544            "Fake Thinking",
2545            true,
2546        ));
2547        let thinking_provider = Arc::new(
2548            FakeLanguageModelProvider::new(
2549                LanguageModelProviderId::from("fake-corp".to_string()),
2550                LanguageModelProviderName::from("Fake Corp".to_string()),
2551            )
2552            .with_models(vec![thinking_model.clone()]),
2553        );
2554        cx.update(|cx| {
2555            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2556                registry.register_provider(thinking_provider, cx);
2557            });
2558        });
2559        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2560
2561        // Create a thread and select the thinking model.
2562        let acp_thread = cx
2563            .update(|cx| {
2564                connection.clone().new_session(
2565                    project.clone(),
2566                    PathList::new(&[Path::new("/a")]),
2567                    cx,
2568                )
2569            })
2570            .await
2571            .unwrap();
2572        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2573
2574        let selector = connection.model_selector(&session_id).unwrap();
2575        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2576            .await
2577            .unwrap();
2578
2579        // Verify thinking is enabled after selecting the thinking model.
2580        let thread = agent.read_with(cx, |agent, _| {
2581            agent.sessions.get(&session_id).unwrap().thread.clone()
2582        });
2583        thread.read_with(cx, |thread, _| {
2584            assert!(
2585                thread.thinking_enabled(),
2586                "thinking should be enabled after selecting thinking model"
2587            );
2588        });
2589
2590        // Send a message so the thread gets persisted.
2591        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2592        let send = cx.foreground_executor().spawn(send);
2593        cx.run_until_parked();
2594
2595        thinking_model.send_last_completion_stream_text_chunk("Response.");
2596        thinking_model.end_last_completion_stream();
2597
2598        send.await.unwrap();
2599        cx.run_until_parked();
2600
2601        // Close the session so it can be reloaded from disk.
2602        cx.update(|cx| connection.clone().close_session(&session_id, cx))
2603            .await
2604            .unwrap();
2605        drop(thread);
2606        drop(acp_thread);
2607        agent.read_with(cx, |agent, _| {
2608            assert!(agent.sessions.is_empty());
2609        });
2610
2611        // Reload the thread and verify thinking_enabled is still true.
2612        let reloaded_acp_thread = agent
2613            .update(cx, |agent, cx| {
2614                agent.open_thread(session_id.clone(), project.clone(), cx)
2615            })
2616            .await
2617            .unwrap();
2618        let reloaded_thread = agent.read_with(cx, |agent, _| {
2619            agent.sessions.get(&session_id).unwrap().thread.clone()
2620        });
2621        reloaded_thread.read_with(cx, |thread, _| {
2622            assert!(
2623                thread.thinking_enabled(),
2624                "thinking_enabled should be preserved when reloading a thread with a thinking model"
2625            );
2626        });
2627
2628        drop(reloaded_acp_thread);
2629    }
2630
2631    #[gpui::test]
2632    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
2633        init_test(cx);
2634        let fs = FakeFs::new(cx.executor());
2635        fs.insert_tree("/", json!({ "a": {} })).await;
2636        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2637        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2638        let agent = cx.update(|cx| {
2639            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
2640        });
2641        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2642
2643        // Register a model where id() != name(), like real Anthropic models
2644        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
2645        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2646            "fake-corp",
2647            "custom-model-id",
2648            "Custom Model Display Name",
2649            false,
2650        ));
2651        let provider = Arc::new(
2652            FakeLanguageModelProvider::new(
2653                LanguageModelProviderId::from("fake-corp".to_string()),
2654                LanguageModelProviderName::from("Fake Corp".to_string()),
2655            )
2656            .with_models(vec![model.clone()]),
2657        );
2658        cx.update(|cx| {
2659            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2660                registry.register_provider(provider, cx);
2661            });
2662        });
2663        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2664
2665        // Create a thread and select the model.
2666        let acp_thread = cx
2667            .update(|cx| {
2668                connection.clone().new_session(
2669                    project.clone(),
2670                    PathList::new(&[Path::new("/a")]),
2671                    cx,
2672                )
2673            })
2674            .await
2675            .unwrap();
2676        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2677
2678        let selector = connection.model_selector(&session_id).unwrap();
2679        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
2680            .await
2681            .unwrap();
2682
2683        let thread = agent.read_with(cx, |agent, _| {
2684            agent.sessions.get(&session_id).unwrap().thread.clone()
2685        });
2686        thread.read_with(cx, |thread, _| {
2687            assert_eq!(
2688                thread.model().unwrap().id().0.as_ref(),
2689                "custom-model-id",
2690                "model should be set before persisting"
2691            );
2692        });
2693
2694        // Send a message so the thread gets persisted.
2695        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2696        let send = cx.foreground_executor().spawn(send);
2697        cx.run_until_parked();
2698
2699        model.send_last_completion_stream_text_chunk("Response.");
2700        model.end_last_completion_stream();
2701
2702        send.await.unwrap();
2703        cx.run_until_parked();
2704
2705        // Close the session so it can be reloaded from disk.
2706        cx.update(|cx| connection.clone().close_session(&session_id, cx))
2707            .await
2708            .unwrap();
2709        drop(thread);
2710        drop(acp_thread);
2711        agent.read_with(cx, |agent, _| {
2712            assert!(agent.sessions.is_empty());
2713        });
2714
2715        // Reload the thread and verify the model was preserved.
2716        let reloaded_acp_thread = agent
2717            .update(cx, |agent, cx| {
2718                agent.open_thread(session_id.clone(), project.clone(), cx)
2719            })
2720            .await
2721            .unwrap();
2722        let reloaded_thread = agent.read_with(cx, |agent, _| {
2723            agent.sessions.get(&session_id).unwrap().thread.clone()
2724        });
2725        reloaded_thread.read_with(cx, |thread, _| {
2726            let reloaded_model = thread
2727                .model()
2728                .expect("model should be present after reload");
2729            assert_eq!(
2730                reloaded_model.id().0.as_ref(),
2731                "custom-model-id",
2732                "reloaded thread should have the same model, not fall back to the default"
2733            );
2734        });
2735
2736        drop(reloaded_acp_thread);
2737    }
2738
    // End-to-end persistence check: sends one message in a new session, closes
    // the session, reopens the thread with `open_thread`, and asserts that the
    // markdown transcript, draft prompt, token usage and UI scroll position all
    // come back unchanged.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        // Keep a handle to the agent-side thread; it is dropped explicitly after
        // the session is closed below.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a prompt made of text plus a resource link that mentions /a/b.md.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // Spawn the send so the fake models' completion streams can be driven
        // from the test while the request is still in flight.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Drive the main model: one text chunk, a usage update (asserted after
        // the reload below), then end of stream.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
            language_model::TokenUsage {
                input_tokens: 150,
                output_tokens: 75,
                ..Default::default()
            },
        ));
        model.end_last_completion_stream();
        cx.run_until_parked();
        // Drive the summarization model; its text is what the stored entry title
        // is asserted to equal further down.
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        // Transcript as rendered by the live (not yet persisted) thread.
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Set a draft prompt with rich content blocks and scroll position
        // AFTER run_until_parked, so the only save that captures these
        // changes is the one performed by close_session itself.
        let draft_blocks = vec![
            acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
            acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
            acp::ContentBlock::Text(acp::TextContent::new(" please")),
        ];
        acp_thread.update(cx, |thread, _cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()));
        });
        thread.update(cx, |thread, _cx| {
            thread.set_ui_scroll_position(Some(gpui::ListOffset {
                item_ix: 5,
                offset_in_item: gpui::px(12.5),
            }));
        });

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        // After closing, the agent must no longer track any session.
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        // The reloaded transcript must match what the live thread rendered.
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // Ensure the draft prompt with rich content blocks survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
        });

        // Ensure token usage survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let usage = thread
                .token_usage()
                .expect("token usage should be restored after reload");
            assert_eq!(usage.input_tokens, 150);
            assert_eq!(usage.output_tokens, 75);
        });

        // Ensure scroll position survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let scroll = thread
                .ui_scroll_position()
                .expect("scroll position should be restored after reload");
            assert_eq!(scroll.item_ix, 5);
            assert_eq!(scroll.offset_in_item, gpui::px(12.5));
        });
    }
2920
    // Verifies that `close_session` itself persists the thread: a draft prompt
    // is set immediately before closing, with no `run_until_parked` in between,
    // and must still be present after the thread is reopened.
    #[gpui::test]
    async fn test_close_session_saves_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "file.txt": "hello"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Install a fake model so the test controls the completion stream.
        let model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
        });

        // Send a message so the thread is non-empty (empty threads aren't saved).
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Finish the fake completion so the pending send can resolve.
        model.send_last_completion_stream_text_chunk("world");
        model.end_last_completion_stream();
        send.await.unwrap();
        cx.run_until_parked();

        // Set a draft prompt WITHOUT calling run_until_parked afterwards.
        // This means no observe-triggered save has run for this change.
        // The only way this data gets persisted is if close_session
        // itself performs the save.
        let draft_blocks = vec![acp::ContentBlock::Text(acp::TextContent::new(
            "unsaved draft",
        ))];
        acp_thread.update(cx, |thread, _cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()));
        });

        // Close the session immediately — no run_until_parked in between.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // Reopen and verify the draft prompt was saved.
        let reloaded = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        reloaded.read_with(cx, |thread, _| {
            assert_eq!(
                thread.draft_prompt(),
                Some(draft_blocks.as_slice()),
                "close_session must save the thread; draft prompt was lost"
            );
        });
    }
3001
3002    fn thread_entries(
3003        thread_store: &Entity<ThreadStore>,
3004        cx: &mut TestAppContext,
3005    ) -> Vec<(acp::SessionId, String)> {
3006        thread_store.read_with(cx, |store, _| {
3007            store
3008                .entries()
3009                .map(|entry| (entry.id.clone(), entry.title.to_string()))
3010                .collect::<Vec<_>>()
3011        })
3012    }
3013
3014    fn init_test(cx: &mut TestAppContext) {
3015        env_logger::try_init().ok();
3016        cx.update(|cx| {
3017            let settings_store = SettingsStore::test(cx);
3018            cx.set_global(settings_store);
3019
3020            LanguageModelRegistry::test(cx);
3021        });
3022    }
3023}
3024
3025fn mcp_message_content_to_acp_content_block(
3026    content: context_server::types::MessageContent,
3027) -> acp::ContentBlock {
3028    match content {
3029        context_server::types::MessageContent::Text {
3030            text,
3031            annotations: _,
3032        } => text.into(),
3033        context_server::types::MessageContent::Image {
3034            data,
3035            mime_type,
3036            annotations: _,
3037        } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
3038        context_server::types::MessageContent::Audio {
3039            data,
3040            mime_type,
3041            annotations: _,
3042        } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
3043        context_server::types::MessageContent::Resource {
3044            resource,
3045            annotations: _,
3046        } => {
3047            let mut link =
3048                acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
3049            if let Some(mime_type) = resource.mime_type {
3050                link = link.mime_type(mime_type);
3051            }
3052            acp::ContentBlock::ResourceLink(link)
3053        }
3054    }
3055}