agent.rs

   1mod db;
   2mod edit_agent;
   3mod legacy_thread;
   4mod native_agent_server;
   5pub mod outline;
   6mod pattern_extraction;
   7mod templates;
   8#[cfg(test)]
   9mod tests;
  10mod thread;
  11mod thread_store;
  12mod tool_permissions;
  13mod tools;
  14
  15use context_server::ContextServerId;
  16pub use db::*;
  17use itertools::Itertools;
  18pub use native_agent_server::NativeAgentServer;
  19pub use pattern_extraction::*;
  20pub use shell_command_parser::extract_commands;
  21pub use templates::*;
  22pub use thread::*;
  23pub use thread_store::*;
  24pub use tool_permissions::*;
  25pub use tools::*;
  26
  27use acp_thread::{
  28    AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
  29    AgentSessionListResponse, TokenUsageRatio, UserMessageId,
  30};
  31use agent_client_protocol as acp;
  32use anyhow::{Context as _, Result, anyhow};
  33use chrono::{DateTime, Utc};
  34use collections::{HashMap, HashSet, IndexMap};
  35use fs::Fs;
  36use futures::channel::{mpsc, oneshot};
  37use futures::future::Shared;
  38use futures::{FutureExt as _, StreamExt as _, future};
  39use gpui::{
  40    App, AppContext, AsyncApp, Context, Entity, EntityId, SharedString, Subscription, Task,
  41    WeakEntity,
  42};
  43use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  44use project::{AgentId, Project, ProjectItem, ProjectPath, Worktree};
  45use prompt_store::{
  46    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  47    WorktreeContext,
  48};
  49use serde::{Deserialize, Serialize};
  50use settings::{LanguageModelSelection, update_settings_file};
  51use std::any::Any;
  52use std::path::PathBuf;
  53use std::rc::Rc;
  54use std::sync::{Arc, LazyLock};
  55use util::ResultExt;
  56use util::path_list::PathList;
  57use util::rel_path::RelPath;
  58
/// A serializable, point-in-time snapshot of a project's worktrees.
///
/// Field order is part of the serialized form (serde derives), so keep it stable.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Telemetry snapshots of the project's worktrees (presumably one per worktree — confirm
    /// at the construction site).
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// Timestamp associated with this snapshot, in UTC.
    pub timestamp: DateTime<Utc>,
}
  64
  65pub struct RulesLoadingError {
  66    pub message: SharedString,
  67}
  68
/// Per-project state shared by every session whose `project_id` points at the same project.
///
/// NOTE: fields are dropped in declaration order; keep the order stable.
struct ProjectState {
    /// The project this state belongs to.
    project: Entity<Project>,
    /// Worktree and rules context handed to threads; its contents are replaced in place
    /// each time `_maintain_project_context` finishes a rebuild.
    project_context: Entity<ProjectContext>,
    /// Sending on this asks `_maintain_project_context` to rebuild `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `NativeAgent::maintain_project_context` for this project; stored here to
    /// keep it alive for as long as the state exists.
    _maintain_project_context: Task<Result<()>>,
    /// Context server registry created for this project's context server store.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Subscriptions to the project, its context server store, and `context_server_registry`.
    _subscriptions: Vec<Subscription>,
}
  77
/// Holds both the internal Thread and the AcpThread for a session
///
/// NOTE: fields are dropped in declaration order; keep the order stable.
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Key into `NativeAgent::projects` for the project this session runs in.
    project_id: EntityId,
    /// Save task for `thread`; starts as an already-completed task when the session is
    /// registered (presumably replaced by `save_thread` — confirm there).
    pending_save: Task<Result<()>>,
    /// Subscriptions/observations on `thread` (title, token usage, saving).
    _subscriptions: Vec<Subscription>,
    /// Reference count for this session; initialized from `register_session`'s `ref_count`
    /// argument (1 for sessions created by `new_session`).
    ref_count: usize,
}
  89
/// A session whose `AcpThread` is still being produced by an in-flight task.
///
/// The task is `Shared`, so several waiters can await the same result (presumably
/// concurrent requests for the same session id — confirm at the call sites).
struct PendingSession {
    /// In-flight task resolving to the session's `AcpThread`. The error is wrapped in `Arc`
    /// because `Shared` requires a cloneable output and `anyhow::Error` is not `Clone`.
    task: Shared<Task<Result<Entity<AcpThread>, Arc<anyhow::Error>>>>,
    /// Number of outstanding users of this pending session (maintenance not visible here).
    ref_count: usize,
}
  94
/// Cache of language models built from the global `LanguageModelRegistry`, keyed by ACP
/// model id (`"{provider_id}/{model_id}"`, see `LanguageModels::model_id`).
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver cloned out by `watch()`; observes every signal sent by `refresh_list`.
    refresh_models_rx: watch::Receiver<()>,
    /// Signalled at the end of every `refresh_list`.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating every visible provider; stored to keep it alive.
    _authenticate_all_providers_task: Task<()>,
}
 104
 105impl LanguageModels {
 106    fn new(cx: &mut App) -> Self {
 107        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
 108
 109        let mut this = Self {
 110            models: HashMap::default(),
 111            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
 112            refresh_models_rx,
 113            refresh_models_tx,
 114            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
 115        };
 116        this.refresh_list(cx);
 117        this
 118    }
 119
 120    fn refresh_list(&mut self, cx: &App) {
 121        let providers = LanguageModelRegistry::global(cx)
 122            .read(cx)
 123            .visible_providers()
 124            .into_iter()
 125            .filter(|provider| provider.is_authenticated(cx))
 126            .collect::<Vec<_>>();
 127
 128        let mut language_model_list = IndexMap::default();
 129        let mut recommended_models = HashSet::default();
 130
 131        let mut recommended = Vec::new();
 132        for provider in &providers {
 133            for model in provider.recommended_models(cx) {
 134                recommended_models.insert((model.provider_id(), model.id()));
 135                recommended.push(Self::map_language_model_to_info(&model, provider));
 136            }
 137        }
 138        if !recommended.is_empty() {
 139            language_model_list.insert(
 140                acp_thread::AgentModelGroupName("Recommended".into()),
 141                recommended,
 142            );
 143        }
 144
 145        let mut models = HashMap::default();
 146        for provider in providers {
 147            let mut provider_models = Vec::new();
 148            for model in provider.provided_models(cx) {
 149                let model_info = Self::map_language_model_to_info(&model, &provider);
 150                let model_id = model_info.id.clone();
 151                provider_models.push(model_info);
 152                models.insert(model_id, model);
 153            }
 154            if !provider_models.is_empty() {
 155                language_model_list.insert(
 156                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 157                    provider_models,
 158                );
 159            }
 160        }
 161
 162        self.models = models;
 163        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 164        self.refresh_models_tx.send(()).ok();
 165    }
 166
 167    fn watch(&self) -> watch::Receiver<()> {
 168        self.refresh_models_rx.clone()
 169    }
 170
 171    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 172        self.models.get(model_id).cloned()
 173    }
 174
 175    fn map_language_model_to_info(
 176        model: &Arc<dyn LanguageModel>,
 177        provider: &Arc<dyn LanguageModelProvider>,
 178    ) -> acp_thread::AgentModelInfo {
 179        acp_thread::AgentModelInfo {
 180            id: Self::model_id(model),
 181            name: model.name().0,
 182            description: None,
 183            icon: Some(match provider.icon() {
 184                IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
 185                IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
 186            }),
 187            is_latest: model.is_latest(),
 188            cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
 189        }
 190    }
 191
 192    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 193        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 194    }
 195
 196    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 197        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 198            .read(cx)
 199            .visible_providers()
 200            .iter()
 201            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 202            .collect::<Vec<_>>();
 203
 204        cx.background_spawn(async move {
 205            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 206                if let Err(err) = authenticate_task.await {
 207                    match err {
 208                        language_model::AuthenticateError::CredentialsNotFound => {
 209                            // Since we're authenticating these providers in the
 210                            // background for the purposes of populating the
 211                            // language selector, we don't care about providers
 212                            // where the credentials are not found.
 213                        }
 214                        language_model::AuthenticateError::ConnectionRefused => {
 215                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 216                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 217                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 218                        }
 219                        _ => {
 220                            // Some providers have noisy failure states that we
 221                            // don't want to spam the logs with every time the
 222                            // language model selector is initialized.
 223                            //
 224                            // Ideally these should have more clear failure modes
 225                            // that we know are safe to ignore here, like what we do
 226                            // with `CredentialsNotFound` above.
 227                            match provider_id.0.as_ref() {
 228                                "lmstudio" | "ollama" => {
 229                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 230                                    //
 231                                    // These fail noisily, so we don't log them.
 232                                }
 233                                "copilot_chat" => {
 234                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 235                                }
 236                                _ => {
 237                                    log::error!(
 238                                        "Failed to authenticate provider: {}: {err:#}",
 239                                        provider_name.0
 240                                    );
 241                                }
 242                            }
 243                        }
 244                    }
 245                }
 246            }
 247        })
 248    }
 249}
 250
/// The built-in agent: owns sessions, per-project state, and the cached language model
/// list, and creates native `Thread`s paired with `AcpThread`s.
///
/// NOTE: fields are dropped in declaration order; keep the order stable.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Sessions whose `AcpThread` is still being produced, keyed by session ID.
    pending_sessions: HashMap<acp::SessionId, PendingSession>,
    /// Thread store handle (presumably used for persisting/listing threads — usage is not
    /// visible in this section).
    thread_store: Entity<ThreadStore>,
    /// Project-specific state keyed by project EntityId
    projects: HashMap<EntityId, ProjectState>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Optional prompt store; its default prompts become user rules in each project context.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle (usage is not visible in this section).
    fs: Arc<dyn Fs>,
    /// Subscriptions to the language model registry and, if present, the prompt store.
    _subscriptions: Vec<Subscription>,
}
 266
 267impl NativeAgent {
 268    pub fn new(
 269        thread_store: Entity<ThreadStore>,
 270        templates: Arc<Templates>,
 271        prompt_store: Option<Entity<PromptStore>>,
 272        fs: Arc<dyn Fs>,
 273        cx: &mut App,
 274    ) -> Entity<NativeAgent> {
 275        log::debug!("Creating new NativeAgent");
 276
 277        cx.new(|cx| {
 278            let mut subscriptions = vec![cx.subscribe(
 279                &LanguageModelRegistry::global(cx),
 280                Self::handle_models_updated_event,
 281            )];
 282            if let Some(prompt_store) = prompt_store.as_ref() {
 283                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 284            }
 285
 286            Self {
 287                sessions: HashMap::default(),
 288                pending_sessions: HashMap::default(),
 289                thread_store,
 290                projects: HashMap::default(),
 291                templates,
 292                models: LanguageModels::new(cx),
 293                prompt_store,
 294                fs,
 295                _subscriptions: subscriptions,
 296            }
 297        })
 298    }
 299
 300    fn new_session(
 301        &mut self,
 302        project: Entity<Project>,
 303        cx: &mut Context<Self>,
 304    ) -> Entity<AcpThread> {
 305        let project_id = self.get_or_create_project_state(&project, cx);
 306        let project_state = &self.projects[&project_id];
 307
 308        let registry = LanguageModelRegistry::read_global(cx);
 309        let available_count = registry.available_models(cx).count();
 310        log::debug!("Total available models: {}", available_count);
 311
 312        let default_model = registry.default_model().and_then(|default_model| {
 313            self.models
 314                .model_from_id(&LanguageModels::model_id(&default_model.model))
 315        });
 316        let thread = cx.new(|cx| {
 317            Thread::new(
 318                project,
 319                project_state.project_context.clone(),
 320                project_state.context_server_registry.clone(),
 321                self.templates.clone(),
 322                default_model,
 323                cx,
 324            )
 325        });
 326
 327        self.register_session(thread, project_id, 1, cx)
 328    }
 329
    /// Wraps `thread_handle` in a new `AcpThread` and records both as a session.
    ///
    /// The `AcpThread` is seeded from the thread's current title, draft prompt, scroll
    /// position, and token usage. The thread then receives the registry's summarization model
    /// and its default tools, the agent subscribes to its title/token-usage events and saves
    /// it whenever it notifies, and available commands are pushed for `project_id`.
    ///
    /// An existing session with the same id is replaced (`HashMap::insert`).
    ///
    /// NOTE: statement order matters here — the thread is read (borrowing `cx`) before any
    /// entity is created, and tools are installed before the session is inserted.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        project_id: EntityId,
        ref_count: usize,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        // Snapshot everything the AcpThread needs from the thread before creating entities.
        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let draft_prompt = thread.draft_prompt().map(Vec::from);
        let scroll_position = thread.ui_scroll_position();
        let token_usage = thread.latest_token_usage();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            let mut acp_thread = acp_thread::AcpThread::new(
                parent_session_id,
                title,
                None,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            );
            // Restore UI state carried by the thread.
            acp_thread.set_draft_prompt(draft_prompt, cx);
            acp_thread.set_ui_scroll_position(scroll_position);
            acp_thread.update_token_usage(token_usage, cx);
            acp_thread
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        // Tools get weak handles so they don't keep the agent or threads alive.
        let weak = cx.weak_entity();
        let weak_thread = thread_handle.downgrade();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            thread.add_default_tools(
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    thread: weak_thread,
                    agent: weak,
                }) as _,
                cx,
            )
        });

        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            // Persist the thread every time it notifies.
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                project_id,
                _subscriptions: subscriptions,
                pending_save: Task::ready(Ok(())),
                ref_count,
            },
        );

        // The new session needs the project's current command list.
        self.update_available_commands_for_project(project_id, cx);

        acp_thread
    }
 408
 409    pub fn models(&self) -> &LanguageModels {
 410        &self.models
 411    }
 412
 413    fn get_or_create_project_state(
 414        &mut self,
 415        project: &Entity<Project>,
 416        cx: &mut Context<Self>,
 417    ) -> EntityId {
 418        let project_id = project.entity_id();
 419        if self.projects.contains_key(&project_id) {
 420            return project_id;
 421        }
 422
 423        let project_context = cx.new(|_| ProjectContext::new(vec![], vec![]));
 424        self.register_project_with_initial_context(project.clone(), project_context, cx);
 425        if let Some(state) = self.projects.get_mut(&project_id) {
 426            state.project_context_needs_refresh.send(()).ok();
 427        }
 428        project_id
 429    }
 430
 431    fn register_project_with_initial_context(
 432        &mut self,
 433        project: Entity<Project>,
 434        project_context: Entity<ProjectContext>,
 435        cx: &mut Context<Self>,
 436    ) {
 437        let project_id = project.entity_id();
 438
 439        let context_server_store = project.read(cx).context_server_store();
 440        let context_server_registry =
 441            cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
 442
 443        let subscriptions = vec![
 444            cx.subscribe(&project, Self::handle_project_event),
 445            cx.subscribe(
 446                &context_server_store,
 447                Self::handle_context_server_store_updated,
 448            ),
 449            cx.subscribe(
 450                &context_server_registry,
 451                Self::handle_context_server_registry_event,
 452            ),
 453        ];
 454
 455        let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 456            watch::channel(());
 457
 458        self.projects.insert(
 459            project_id,
 460            ProjectState {
 461                project,
 462                project_context,
 463                project_context_needs_refresh: project_context_needs_refresh_tx,
 464                _maintain_project_context: cx.spawn(async move |this, cx| {
 465                    Self::maintain_project_context(
 466                        this,
 467                        project_id,
 468                        project_context_needs_refresh_rx,
 469                        cx,
 470                    )
 471                    .await
 472                }),
 473                context_server_registry,
 474                _subscriptions: subscriptions,
 475            },
 476        );
 477    }
 478
 479    fn session_project_state(&self, session_id: &acp::SessionId) -> Option<&ProjectState> {
 480        self.sessions
 481            .get(session_id)
 482            .and_then(|session| self.projects.get(&session.project_id))
 483    }
 484
 485    async fn maintain_project_context(
 486        this: WeakEntity<Self>,
 487        project_id: EntityId,
 488        mut needs_refresh: watch::Receiver<()>,
 489        cx: &mut AsyncApp,
 490    ) -> Result<()> {
 491        while needs_refresh.changed().await.is_ok() {
 492            let project_context = this
 493                .update(cx, |this, cx| {
 494                    let state = this
 495                        .projects
 496                        .get(&project_id)
 497                        .context("project state not found")?;
 498                    anyhow::Ok(Self::build_project_context(
 499                        &state.project,
 500                        this.prompt_store.as_ref(),
 501                        cx,
 502                    ))
 503                })??
 504                .await;
 505            this.update(cx, |this, cx| {
 506                if let Some(state) = this.projects.get(&project_id) {
 507                    state
 508                        .project_context
 509                        .update(cx, |current_project_context, _cx| {
 510                            *current_project_context = project_context;
 511                        });
 512                }
 513            })?;
 514        }
 515
 516        Ok(())
 517    }
 518
 519    fn build_project_context(
 520        project: &Entity<Project>,
 521        prompt_store: Option<&Entity<PromptStore>>,
 522        cx: &mut App,
 523    ) -> Task<ProjectContext> {
 524        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 525        let worktree_tasks = worktrees
 526            .into_iter()
 527            .map(|worktree| {
 528                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 529            })
 530            .collect::<Vec<_>>();
 531        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 532            prompt_store.read_with(cx, |prompt_store, cx| {
 533                let prompts = prompt_store.default_prompt_metadata();
 534                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 535                    let contents = prompt_store.load(prompt_metadata.id, cx);
 536                    async move { (contents.await, prompt_metadata) }
 537                });
 538                cx.background_spawn(future::join_all(load_tasks))
 539            })
 540        } else {
 541            Task::ready(vec![])
 542        };
 543
 544        cx.spawn(async move |_cx| {
 545            let (worktrees, default_user_rules) =
 546                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 547
 548            let worktrees = worktrees
 549                .into_iter()
 550                .map(|(worktree, _rules_error)| {
 551                    // TODO: show error message
 552                    // if let Some(rules_error) = rules_error {
 553                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 554                    // }
 555                    worktree
 556                })
 557                .collect::<Vec<_>>();
 558
 559            let default_user_rules = default_user_rules
 560                .into_iter()
 561                .flat_map(|(contents, prompt_metadata)| match contents {
 562                    Ok(contents) => Some(UserRulesContext {
 563                        uuid: prompt_metadata.id.as_user()?,
 564                        title: prompt_metadata.title.map(|title| title.to_string()),
 565                        contents,
 566                    }),
 567                    Err(_err) => {
 568                        // TODO: show error message
 569                        // this.update(cx, |_, cx| {
 570                        //     cx.emit(RulesLoadingError {
 571                        //         message: format!("{err:?}").into(),
 572                        //     });
 573                        // })
 574                        // .ok();
 575                        None
 576                    }
 577                })
 578                .collect::<Vec<_>>();
 579
 580            ProjectContext::new(worktrees, default_user_rules)
 581        })
 582    }
 583
 584    fn load_worktree_info_for_system_prompt(
 585        worktree: Entity<Worktree>,
 586        project: Entity<Project>,
 587        cx: &mut App,
 588    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 589        let tree = worktree.read(cx);
 590        let root_name = tree.root_name_str().into();
 591        let abs_path = tree.abs_path();
 592
 593        let mut context = WorktreeContext {
 594            root_name,
 595            abs_path,
 596            rules_file: None,
 597        };
 598
 599        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 600        let Some(rules_task) = rules_task else {
 601            return Task::ready((context, None));
 602        };
 603
 604        cx.spawn(async move |_| {
 605            let (rules_file, rules_file_error) = match rules_task.await {
 606                Ok(rules_file) => (Some(rules_file), None),
 607                Err(err) => (
 608                    None,
 609                    Some(RulesLoadingError {
 610                        message: format!("{err}").into(),
 611                    }),
 612                ),
 613            };
 614            context.rules_file = rules_file;
 615            (context, rules_file_error)
 616        })
 617    }
 618
 619    fn load_worktree_rules_file(
 620        worktree: Entity<Worktree>,
 621        project: Entity<Project>,
 622        cx: &mut App,
 623    ) -> Option<Task<Result<RulesFileContext>>> {
 624        let worktree = worktree.read(cx);
 625        let worktree_id = worktree.id();
 626        let selected_rules_file = RULES_FILE_NAMES
 627            .into_iter()
 628            .filter_map(|name| {
 629                worktree
 630                    .entry_for_path(RelPath::unix(name).unwrap())
 631                    .filter(|entry| entry.is_file())
 632                    .map(|entry| entry.path.clone())
 633            })
 634            .next();
 635
 636        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 637        // supported. This doesn't seem to occur often in GitHub repositories.
 638        selected_rules_file.map(|path_in_worktree| {
 639            let project_path = ProjectPath {
 640                worktree_id,
 641                path: path_in_worktree.clone(),
 642            };
 643            let buffer_task =
 644                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 645            let rope_task = cx.spawn(async move |cx| {
 646                let buffer = buffer_task.await?;
 647                let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
 648                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 649                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 650                })?;
 651                anyhow::Ok((project_entry_id, rope))
 652            });
 653            // Build a string from the rope on a background thread.
 654            cx.background_spawn(async move {
 655                let (project_entry_id, rope) = rope_task.await?;
 656                anyhow::Ok(RulesFileContext {
 657                    path_in_worktree,
 658                    text: rope.to_string().trim().to_string(),
 659                    project_entry_id: project_entry_id.to_usize(),
 660                })
 661            })
 662        })
 663    }
 664
 665    fn handle_thread_title_updated(
 666        &mut self,
 667        thread: Entity<Thread>,
 668        _: &TitleUpdated,
 669        cx: &mut Context<Self>,
 670    ) {
 671        let session_id = thread.read(cx).id();
 672        let Some(session) = self.sessions.get(session_id) else {
 673            return;
 674        };
 675
 676        let thread = thread.downgrade();
 677        let acp_thread = session.acp_thread.downgrade();
 678        cx.spawn(async move |_, cx| {
 679            let title = thread.read_with(cx, |thread, _| thread.title())?;
 680            if let Some(title) = title {
 681                let task =
 682                    acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 683                task.await?;
 684            }
 685            anyhow::Ok(())
 686        })
 687        .detach_and_log_err(cx);
 688    }
 689
 690    fn handle_thread_token_usage_updated(
 691        &mut self,
 692        thread: Entity<Thread>,
 693        usage: &TokenUsageUpdated,
 694        cx: &mut Context<Self>,
 695    ) {
 696        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 697            return;
 698        };
 699        session.acp_thread.update(cx, |acp_thread, cx| {
 700            acp_thread.update_token_usage(usage.0.clone(), cx);
 701        });
 702    }
 703
 704    fn handle_project_event(
 705        &mut self,
 706        project: Entity<Project>,
 707        event: &project::Event,
 708        _cx: &mut Context<Self>,
 709    ) {
 710        let project_id = project.entity_id();
 711        let Some(state) = self.projects.get_mut(&project_id) else {
 712            return;
 713        };
 714        match event {
 715            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 716                state.project_context_needs_refresh.send(()).ok();
 717            }
 718            project::Event::WorktreeUpdatedEntries(_, items) => {
 719                if items.iter().any(|(path, _, _)| {
 720                    RULES_FILE_NAMES
 721                        .iter()
 722                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 723                }) {
 724                    state.project_context_needs_refresh.send(()).ok();
 725                }
 726            }
 727            _ => {}
 728        }
 729    }
 730
 731    fn handle_prompts_updated_event(
 732        &mut self,
 733        _prompt_store: Entity<PromptStore>,
 734        _event: &prompt_store::PromptsUpdatedEvent,
 735        _cx: &mut Context<Self>,
 736    ) {
 737        for state in self.projects.values_mut() {
 738            state.project_context_needs_refresh.send(()).ok();
 739        }
 740    }
 741
 742    fn handle_models_updated_event(
 743        &mut self,
 744        _registry: Entity<LanguageModelRegistry>,
 745        event: &language_model::Event,
 746        cx: &mut Context<Self>,
 747    ) {
 748        self.models.refresh_list(cx);
 749
 750        let registry = LanguageModelRegistry::read_global(cx);
 751        let default_model = registry.default_model().map(|m| m.model);
 752        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 753
 754        for session in self.sessions.values_mut() {
 755            session.thread.update(cx, |thread, cx| {
 756                if thread.model().is_none()
 757                    && let Some(model) = default_model.clone()
 758                {
 759                    thread.set_model(model, cx);
 760                    cx.notify();
 761                }
 762                if let Some(model) = summarization_model.clone() {
 763                    if thread.summarization_model().is_none()
 764                        || matches!(event, language_model::Event::ThreadSummaryModelChanged)
 765                    {
 766                        thread.set_summarization_model(Some(model), cx);
 767                    }
 768                }
 769            });
 770        }
 771    }
 772
 773    fn handle_context_server_store_updated(
 774        &mut self,
 775        store: Entity<project::context_server_store::ContextServerStore>,
 776        _event: &project::context_server_store::ServerStatusChangedEvent,
 777        cx: &mut Context<Self>,
 778    ) {
 779        let project_id = self.projects.iter().find_map(|(id, state)| {
 780            if *state.context_server_registry.read(cx).server_store() == store {
 781                Some(*id)
 782            } else {
 783                None
 784            }
 785        });
 786        if let Some(project_id) = project_id {
 787            self.update_available_commands_for_project(project_id, cx);
 788        }
 789    }
 790
 791    fn handle_context_server_registry_event(
 792        &mut self,
 793        registry: Entity<ContextServerRegistry>,
 794        event: &ContextServerRegistryEvent,
 795        cx: &mut Context<Self>,
 796    ) {
 797        match event {
 798            ContextServerRegistryEvent::ToolsChanged => {}
 799            ContextServerRegistryEvent::PromptsChanged => {
 800                let project_id = self.projects.iter().find_map(|(id, state)| {
 801                    if state.context_server_registry == registry {
 802                        Some(*id)
 803                    } else {
 804                        None
 805                    }
 806                });
 807                if let Some(project_id) = project_id {
 808                    self.update_available_commands_for_project(project_id, cx);
 809                }
 810            }
 811        }
 812    }
 813
 814    fn update_available_commands_for_project(&self, project_id: EntityId, cx: &mut Context<Self>) {
 815        let available_commands =
 816            Self::build_available_commands_for_project(self.projects.get(&project_id), cx);
 817        for session in self.sessions.values() {
 818            if session.project_id != project_id {
 819                continue;
 820            }
 821            session.acp_thread.update(cx, |thread, cx| {
 822                thread
 823                    .handle_session_update(
 824                        acp::SessionUpdate::AvailableCommandsUpdate(
 825                            acp::AvailableCommandsUpdate::new(available_commands.clone()),
 826                        ),
 827                        cx,
 828                    )
 829                    .log_err();
 830            });
 831        }
 832    }
 833
 834    fn build_available_commands_for_project(
 835        project_state: Option<&ProjectState>,
 836        cx: &App,
 837    ) -> Vec<acp::AvailableCommand> {
 838        let Some(state) = project_state else {
 839            return vec![];
 840        };
 841        let registry = state.context_server_registry.read(cx);
 842
 843        let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
 844        for context_server_prompt in registry.prompts() {
 845            *prompt_name_counts
 846                .entry(context_server_prompt.prompt.name.as_str())
 847                .or_insert(0) += 1;
 848        }
 849
 850        registry
 851            .prompts()
 852            .flat_map(|context_server_prompt| {
 853                let prompt = &context_server_prompt.prompt;
 854
 855                let should_prefix = prompt_name_counts
 856                    .get(prompt.name.as_str())
 857                    .copied()
 858                    .unwrap_or(0)
 859                    > 1;
 860
 861                let name = if should_prefix {
 862                    format!("{}.{}", context_server_prompt.server_id, prompt.name)
 863                } else {
 864                    prompt.name.clone()
 865                };
 866
 867                let mut command = acp::AvailableCommand::new(
 868                    name,
 869                    prompt.description.clone().unwrap_or_default(),
 870                );
 871
 872                match prompt.arguments.as_deref() {
 873                    Some([arg]) => {
 874                        let hint = format!("<{}>", arg.name);
 875
 876                        command = command.input(acp::AvailableCommandInput::Unstructured(
 877                            acp::UnstructuredCommandInput::new(hint),
 878                        ));
 879                    }
 880                    Some([]) | None => {}
 881                    Some(_) => {
 882                        // skip >1 argument commands since we don't support them yet
 883                        return None;
 884                    }
 885                }
 886
 887                Some(command)
 888            })
 889            .collect()
 890    }
 891
 892    pub fn load_thread(
 893        &mut self,
 894        id: acp::SessionId,
 895        project: Entity<Project>,
 896        cx: &mut Context<Self>,
 897    ) -> Task<Result<Entity<Thread>>> {
 898        let database_future = ThreadsDatabase::connect(cx);
 899        cx.spawn(async move |this, cx| {
 900            let database = database_future.await.map_err(|err| anyhow!(err))?;
 901            let db_thread = database
 902                .load_thread(id.clone())
 903                .await?
 904                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 905
 906            this.update(cx, |this, cx| {
 907                let project_id = this.get_or_create_project_state(&project, cx);
 908                let project_state = this
 909                    .projects
 910                    .get(&project_id)
 911                    .context("project state not found")?;
 912                let summarization_model = LanguageModelRegistry::read_global(cx)
 913                    .thread_summary_model()
 914                    .map(|c| c.model);
 915
 916                Ok(cx.new(|cx| {
 917                    let mut thread = Thread::from_db(
 918                        id.clone(),
 919                        db_thread,
 920                        project_state.project.clone(),
 921                        project_state.project_context.clone(),
 922                        project_state.context_server_registry.clone(),
 923                        this.templates.clone(),
 924                        cx,
 925                    );
 926                    thread.set_summarization_model(summarization_model, cx);
 927                    thread
 928                }))
 929            })?
 930        })
 931    }
 932
    /// Opens the persisted session `id` as an [`AcpThread`]. Work is shared with
    /// any caller that has already opened, or is currently opening, the same
    /// session.
    ///
    /// - If the session is already registered, its reference count is bumped and
    ///   the existing `AcpThread` is returned.
    /// - If a load is already in flight, the pending entry's reference count is
    ///   bumped and the caller awaits the same shared task.
    /// - Otherwise the thread is loaded via [`Self::load_thread`] and registered
    ///   with the reference count accumulated while it was pending. Its stored
    ///   events are then replayed into the new `AcpThread`, and the thread's
    ///   completed plan is snapshotted.
    ///
    /// If loading fails, the pending entry is removed and the error is returned
    /// to every waiting caller.
    pub fn open_thread(
        &mut self,
        id: acp::SessionId,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<AcpThread>>> {
        if let Some(session) = self.sessions.get_mut(&id) {
            session.ref_count += 1;
            return Task::ready(Ok(session.acp_thread.clone()));
        }

        if let Some(pending) = self.pending_sessions.get_mut(&id) {
            // Join the in-flight load; this caller's reference is counted now and
            // transferred to the session when it gets registered.
            pending.ref_count += 1;
            let task = pending.task.clone();
            return cx.background_spawn(async move { task.await.map_err(|err| anyhow!(err)) });
        }

        let task = self.load_thread(id.clone(), project.clone(), cx);
        // Errors are wrapped in `Arc` because the output of a `Shared` future must
        // be `Clone`, which `anyhow::Error` is not.
        let shared_task = cx
            .spawn({
                let id = id.clone();
                async move |this, cx| {
                    let thread = match task.await {
                        Ok(thread) => thread,
                        Err(err) => {
                            this.update(cx, |this, _cx| {
                                this.pending_sessions.remove(&id);
                            })
                            .ok();
                            return Err(Arc::new(err));
                        }
                    };
                    let acp_thread = this
                        .update(cx, |this, cx| {
                            let project_id = this.get_or_create_project_state(&project, cx);
                            // Hand over every reference taken while the load was pending.
                            let ref_count = this
                                .pending_sessions
                                .remove(&id)
                                .map_or(1, |pending| pending.ref_count);
                            this.register_session(thread.clone(), project_id, ref_count, cx)
                        })
                        .map_err(Arc::new)?;
                    // NOTE(review): if replay fails below, the session stays registered
                    // with `ref_count` references while callers receive an error and
                    // never close it — confirm this is intended.
                    let events = thread.update(cx, |thread, cx| thread.replay(cx));
                    cx.update(|cx| {
                        NativeAgentConnection::handle_thread_events(
                            events,
                            acp_thread.downgrade(),
                            cx,
                        )
                    })
                    .await
                    .map_err(Arc::new)?;
                    acp_thread.update(cx, |thread, cx| {
                        thread.snapshot_completed_plan(cx);
                    });
                    Ok(acp_thread)
                }
            })
            .shared();
        // Register the in-flight load so concurrent callers take the
        // `pending_sessions` branch above. The spawned task removes this entry
        // when loading finishes.
        self.pending_sessions.insert(
            id,
            PendingSession {
                task: shared_task.clone(),
                ref_count: 1,
            },
        );

        cx.background_spawn(async move { shared_task.await.map_err(|err| anyhow!(err)) })
    }
1002
1003    pub fn thread_summary(
1004        &mut self,
1005        id: acp::SessionId,
1006        project: Entity<Project>,
1007        cx: &mut Context<Self>,
1008    ) -> Task<Result<SharedString>> {
1009        let thread = self.open_thread(id.clone(), project, cx);
1010        cx.spawn(async move |this, cx| {
1011            let acp_thread = thread.await?;
1012            let result = this
1013                .update(cx, |this, cx| {
1014                    this.sessions
1015                        .get(&id)
1016                        .unwrap()
1017                        .thread
1018                        .update(cx, |thread, cx| thread.summary(cx))
1019                })?
1020                .await
1021                .context("Failed to generate summary")?;
1022
1023            this.update(cx, |this, cx| this.close_session(&id, cx))?
1024                .await?;
1025            drop(acp_thread);
1026            Ok(result)
1027        })
1028    }
1029
    /// Releases one reference to the session `session_id`.
    ///
    /// When the last reference is released, three things happen:
    /// - the thread is saved;
    /// - the session is removed;
    /// - the project's state is dropped if no remaining session uses it.
    ///
    /// The returned task is the session's pending save. It is an already-resolved
    /// `Ok(())` if the session is unknown or still has other references.
    fn close_session(
        &mut self,
        session_id: &acp::SessionId,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(session) = self.sessions.get_mut(session_id) else {
            return Task::ready(Ok(()));
        };

        session.ref_count -= 1;
        if session.ref_count > 0 {
            return Task::ready(Ok(()));
        }

        // Save before removing: `save_thread` looks the session up in
        // `self.sessions` to store the new `pending_save` task on it.
        let thread = session.thread.clone();
        self.save_thread(thread, cx);
        let Some(session) = self.sessions.remove(session_id) else {
            return Task::ready(Ok(()));
        };
        let project_id = session.project_id;

        // Drop per-project state once its last session is gone.
        let has_remaining = self.sessions.values().any(|s| s.project_id == project_id);
        if !has_remaining {
            self.projects.remove(&project_id);
        }

        session.pending_save
    }
1058
    /// Persists `thread` to the threads database in the background.
    ///
    /// Does nothing in these cases:
    /// - the thread is empty;
    /// - it has no registered session;
    /// - its session's project state is missing.
    ///
    /// Otherwise it copies the session's current draft prompt onto the thread,
    /// serializes the thread, and spawns a task that does the following:
    /// - writes the thread together with the project's visible worktree paths;
    /// - reloads the thread store.
    ///
    /// The spawned task is stored as the session's `pending_save`. Database
    /// errors are logged rather than propagated.
    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
        if thread.read(cx).is_empty() {
            return;
        }

        let id = thread.read(cx).id().clone();
        let Some(session) = self.sessions.get_mut(&id) else {
            return;
        };

        let project_id = session.project_id;
        let Some(state) = self.projects.get(&project_id) else {
            return;
        };

        // Absolute paths of the project's visible worktrees, stored with the thread.
        let folder_paths = PathList::new(
            &state
                .project
                .read(cx)
                .visible_worktrees(cx)
                .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
                .collect::<Vec<_>>(),
        );

        // Carry the unsent draft from the UI thread into the persisted thread.
        let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
        let database_future = ThreadsDatabase::connect(cx);
        let db_thread = thread.update(cx, |thread, cx| {
            thread.set_draft_prompt(draft_prompt);
            thread.to_db(cx)
        });
        let thread_store = self.thread_store.clone();
        session.pending_save = cx.spawn(async move |_, cx| {
            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
                return Ok(());
            };
            let db_thread = db_thread.await;
            database
                .save_thread(id, db_thread, folder_paths)
                .await
                .log_err();
            thread_store.update(cx, |store, cx| store.reload(cx));
            Ok(())
        });
    }
1103
    /// Runs the MCP prompt `prompt_name` from `server_id` as the next turn of
    /// the session `session_id`.
    ///
    /// The prompt is fetched with `arguments`, and then:
    /// 1. `original_content`, minus its first block, is pushed onto the
    ///    `Thread` as user message `message_id`. The first block is skipped
    ///    because it holds the slash command that selected this prompt.
    /// 2. Each prompt message is mirrored into both the `AcpThread` and the
    ///    `Thread` as a user or assistant block.
    /// 3. The turn is sent if the last prompt message was from the user (or
    ///    if there were no messages). Otherwise it is resumed.
    ///
    /// The resulting events are forwarded to the `AcpThread`. Fails if the
    /// session's project state or the session itself is missing, or if fetching
    /// the prompt fails.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: acp::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let Some(state) = self.session_project_state(&session_id) else {
            return Task::ready(Err(anyhow!("Project state not found for session")));
        };
        let server_store = state
            .context_server_registry
            .read(cx)
            .server_store()
            .clone();
        let path_style = state.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Starts `true` so that a prompt with no messages is sent as-is.
            let mut last_is_user = true;

            // Skip the first block: it is the slash command that selected this prompt.
            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                // Each message is pushed to both the UI thread (`acp_thread`) and the
                // model-facing `thread`. The `true` flags passed to the `_with_indent`
                // calls are presumably the indent flag — TODO confirm.
                match role {
                    context_server::types::Role::User => {
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
1206}
1207
/// Connection wrapper around the shared [`NativeAgent`] entity.
///
/// Implements the `AgentConnection` trait on top of the wrapped entity.
/// Cloning this value clones the entity handle, not the agent.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
1211
1212impl NativeAgentConnection {
1213    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
1214        self.0
1215            .read(cx)
1216            .sessions
1217            .get(session_id)
1218            .map(|session| session.thread.clone())
1219    }
1220
1221    pub fn load_thread(
1222        &self,
1223        id: acp::SessionId,
1224        project: Entity<Project>,
1225        cx: &mut App,
1226    ) -> Task<Result<Entity<Thread>>> {
1227        self.0
1228            .update(cx, |this, cx| this.load_thread(id, project, cx))
1229    }
1230
1231    fn run_turn(
1232        &self,
1233        session_id: acp::SessionId,
1234        cx: &mut App,
1235        f: impl 'static
1236        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1237    ) -> Task<Result<acp::PromptResponse>> {
1238        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1239            agent
1240                .sessions
1241                .get_mut(&session_id)
1242                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1243        }) else {
1244            log::error!("Session not found in run_turn: {}", session_id);
1245            return Task::ready(Err(anyhow!("Session not found")));
1246        };
1247        log::debug!("Found session for: {}", session_id);
1248
1249        let response_stream = match f(thread, cx) {
1250            Ok(stream) => stream,
1251            Err(err) => return Task::ready(Err(err)),
1252        };
1253        Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1254    }
1255
1256    fn handle_thread_events(
1257        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1258        acp_thread: WeakEntity<AcpThread>,
1259        cx: &App,
1260    ) -> Task<Result<acp::PromptResponse>> {
1261        cx.spawn(async move |cx| {
1262            // Handle response stream and forward to session.acp_thread
1263            while let Some(result) = events.next().await {
1264                match result {
1265                    Ok(event) => {
1266                        log::trace!("Received completion event: {:?}", event);
1267
1268                        match event {
1269                            ThreadEvent::UserMessage(message) => {
1270                                acp_thread.update(cx, |thread, cx| {
1271                                    for content in message.content {
1272                                        thread.push_user_content_block(
1273                                            Some(message.id.clone()),
1274                                            content.into(),
1275                                            cx,
1276                                        );
1277                                    }
1278                                })?;
1279                            }
1280                            ThreadEvent::AgentText(text) => {
1281                                acp_thread.update(cx, |thread, cx| {
1282                                    thread.push_assistant_content_block(text.into(), false, cx)
1283                                })?;
1284                            }
1285                            ThreadEvent::AgentThinking(text) => {
1286                                acp_thread.update(cx, |thread, cx| {
1287                                    thread.push_assistant_content_block(text.into(), true, cx)
1288                                })?;
1289                            }
1290                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1291                                tool_call,
1292                                options,
1293                                response,
1294                                context: _,
1295                            }) => {
1296                                let outcome_task = acp_thread.update(cx, |thread, cx| {
1297                                    thread.request_tool_call_authorization(tool_call, options, cx)
1298                                })??;
1299                                cx.background_spawn(async move {
1300                                    if let acp_thread::RequestPermissionOutcome::Selected(outcome) =
1301                                        outcome_task.await
1302                                    {
1303                                        response
1304                                            .send(outcome)
1305                                            .map(|_| anyhow!("authorization receiver was dropped"))
1306                                            .log_err();
1307                                    }
1308                                })
1309                                .detach();
1310                            }
1311                            ThreadEvent::ToolCall(tool_call) => {
1312                                acp_thread.update(cx, |thread, cx| {
1313                                    thread.upsert_tool_call(tool_call, cx)
1314                                })??;
1315                            }
1316                            ThreadEvent::ToolCallUpdate(update) => {
1317                                acp_thread.update(cx, |thread, cx| {
1318                                    thread.update_tool_call(update, cx)
1319                                })??;
1320                            }
1321                            ThreadEvent::Plan(plan) => {
1322                                acp_thread.update(cx, |thread, cx| thread.update_plan(plan, cx))?;
1323                            }
1324                            ThreadEvent::SubagentSpawned(session_id) => {
1325                                acp_thread.update(cx, |thread, cx| {
1326                                    thread.subagent_spawned(session_id, cx);
1327                                })?;
1328                            }
1329                            ThreadEvent::Retry(status) => {
1330                                acp_thread.update(cx, |thread, cx| {
1331                                    thread.update_retry_status(status, cx)
1332                                })?;
1333                            }
1334                            ThreadEvent::Stop(stop_reason) => {
1335                                log::debug!("Assistant message complete: {:?}", stop_reason);
1336                                return Ok(acp::PromptResponse::new(stop_reason));
1337                            }
1338                        }
1339                    }
1340                    Err(e) => {
1341                        log::error!("Error in model response stream: {:?}", e);
1342                        return Err(e);
1343                    }
1344                }
1345            }
1346
1347            log::debug!("Response stream completed");
1348            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1349        })
1350    }
1351}
1352
/// A slash command parsed from the first block of a prompt, such as
/// `/server.prompt argument text` or `/prompt`.
struct Command<'a> {
    /// Prompt name, without the leading `/` and without any `server.` prefix.
    prompt_name: &'a str,
    /// Text after the first whitespace character that follows the command
    /// name. Empty when there is none.
    arg_value: &'a str,
    /// Server ID given explicitly before the first `.` of the command name, if any.
    explicit_server_id: Option<&'a str>,
}
1358
1359impl<'a> Command<'a> {
1360    fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1361        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1362            return None;
1363        };
1364        let text = text_content.text.trim();
1365        let command = text.strip_prefix('/')?;
1366        let (command, arg_value) = command
1367            .split_once(char::is_whitespace)
1368            .unwrap_or((command, ""));
1369
1370        if let Some((server_id, prompt_name)) = command.split_once('.') {
1371            Some(Self {
1372                prompt_name,
1373                arg_value,
1374                explicit_server_id: Some(server_id),
1375            })
1376        } else {
1377            Some(Self {
1378                prompt_name: command,
1379                arg_value,
1380                explicit_server_id: None,
1381            })
1382        }
1383    }
1384}
1385
/// [`AgentModelSelector`] scoped to a single native-agent session.
struct NativeAgentModelSelector {
    /// Session whose thread's model is read and changed.
    session_id: acp::SessionId,
    /// Connection used to reach the agent's sessions, model list and filesystem.
    connection: NativeAgentConnection,
}
1390
1391impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1392    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1393        log::debug!("NativeAgentConnection::list_models called");
1394        let list = self.connection.0.read(cx).models.model_list.clone();
1395        Task::ready(if list.is_empty() {
1396            Err(anyhow::anyhow!("No models available"))
1397        } else {
1398            Ok(list)
1399        })
1400    }
1401
1402    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1403        log::debug!(
1404            "Setting model for session {}: {}",
1405            self.session_id,
1406            model_id
1407        );
1408        let Some(thread) = self
1409            .connection
1410            .0
1411            .read(cx)
1412            .sessions
1413            .get(&self.session_id)
1414            .map(|session| session.thread.clone())
1415        else {
1416            return Task::ready(Err(anyhow!("Session not found")));
1417        };
1418
1419        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1420            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1421        };
1422
1423        // We want to reset the effort level when switching models, as the currently-selected effort level may
1424        // not be compatible.
1425        let effort = model
1426            .default_effort_level()
1427            .map(|effort_level| effort_level.value.to_string());
1428
1429        thread.update(cx, |thread, cx| {
1430            thread.set_model(model.clone(), cx);
1431            thread.set_thinking_effort(effort.clone(), cx);
1432            thread.set_thinking_enabled(model.supports_thinking(), cx);
1433        });
1434
1435        update_settings_file(
1436            self.connection.0.read(cx).fs.clone(),
1437            cx,
1438            move |settings, cx| {
1439                let provider = model.provider_id().0.to_string();
1440                let model = model.id().0.to_string();
1441                let enable_thinking = thread.read(cx).thinking_enabled();
1442                let speed = thread.read(cx).speed();
1443                settings
1444                    .agent
1445                    .get_or_insert_default()
1446                    .set_model(LanguageModelSelection {
1447                        provider: provider.into(),
1448                        model,
1449                        enable_thinking,
1450                        effort,
1451                        speed,
1452                    });
1453            },
1454        );
1455
1456        Task::ready(Ok(()))
1457    }
1458
1459    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1460        let Some(thread) = self
1461            .connection
1462            .0
1463            .read(cx)
1464            .sessions
1465            .get(&self.session_id)
1466            .map(|session| session.thread.clone())
1467        else {
1468            return Task::ready(Err(anyhow!("Session not found")));
1469        };
1470        let Some(model) = thread.read(cx).model() else {
1471            return Task::ready(Err(anyhow!("Model not found")));
1472        };
1473        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1474        else {
1475            return Task::ready(Err(anyhow!("Provider not found")));
1476        };
1477        Task::ready(Ok(LanguageModels::map_language_model_to_info(
1478            model, &provider,
1479        )))
1480    }
1481
1482    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1483        Some(self.connection.0.read(cx).models.watch())
1484    }
1485
1486    fn should_render_footer(&self) -> bool {
1487        true
1488    }
1489}
1490
/// Agent identifier ("Zed Agent") reported by [`NativeAgentConnection`].
pub static ZED_AGENT_ID: LazyLock<AgentId> = LazyLock::new(|| AgentId::new("Zed Agent"));
1492
1493impl acp_thread::AgentConnection for NativeAgentConnection {
1494    fn agent_id(&self) -> AgentId {
1495        ZED_AGENT_ID.clone()
1496    }
1497
1498    fn telemetry_id(&self) -> SharedString {
1499        "zed".into()
1500    }
1501
1502    fn new_session(
1503        self: Rc<Self>,
1504        project: Entity<Project>,
1505        work_dirs: PathList,
1506        cx: &mut App,
1507    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1508        log::debug!("Creating new thread for project at: {work_dirs:?}");
1509        Task::ready(Ok(self
1510            .0
1511            .update(cx, |agent, cx| agent.new_session(project, cx))))
1512    }
1513
1514    fn supports_load_session(&self) -> bool {
1515        true
1516    }
1517
1518    fn load_session(
1519        self: Rc<Self>,
1520        session_id: acp::SessionId,
1521        project: Entity<Project>,
1522        _work_dirs: PathList,
1523        _title: Option<SharedString>,
1524        cx: &mut App,
1525    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1526        self.0
1527            .update(cx, |agent, cx| agent.open_thread(session_id, project, cx))
1528    }
1529
1530    fn supports_close_session(&self) -> bool {
1531        true
1532    }
1533
1534    fn close_session(
1535        self: Rc<Self>,
1536        session_id: &acp::SessionId,
1537        cx: &mut App,
1538    ) -> Task<Result<()>> {
1539        self.0
1540            .update(cx, |agent, cx| agent.close_session(session_id, cx))
1541    }
1542
1543    fn auth_methods(&self) -> &[acp::AuthMethod] {
1544        &[] // No auth for in-process
1545    }
1546
1547    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1548        Task::ready(Ok(()))
1549    }
1550
1551    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1552        Some(Rc::new(NativeAgentModelSelector {
1553            session_id: session_id.clone(),
1554            connection: self.clone(),
1555        }) as Rc<dyn AgentModelSelector>)
1556    }
1557
1558    fn prompt(
1559        &self,
1560        id: acp_thread::UserMessageId,
1561        params: acp::PromptRequest,
1562        cx: &mut App,
1563    ) -> Task<Result<acp::PromptResponse>> {
1564        let session_id = params.session_id.clone();
1565        log::info!("Received prompt request for session: {}", session_id);
1566        log::debug!("Prompt blocks count: {}", params.prompt.len());
1567
1568        let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
1569            log::error!("Session not found in prompt: {}", session_id);
1570            if self.0.read(cx).sessions.contains_key(&session_id) {
1571                log::error!(
1572                    "Session found in sessions map, but not in project state: {}",
1573                    session_id
1574                );
1575            }
1576            return Task::ready(Err(anyhow::anyhow!("Session not found")));
1577        };
1578
1579        if let Some(parsed_command) = Command::parse(&params.prompt) {
1580            let registry = project_state.context_server_registry.read(cx);
1581
1582            let explicit_server_id = parsed_command
1583                .explicit_server_id
1584                .map(|server_id| ContextServerId(server_id.into()));
1585
1586            if let Some(prompt) =
1587                registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1588            {
1589                let arguments = if !parsed_command.arg_value.is_empty()
1590                    && let Some(arg_name) = prompt
1591                        .prompt
1592                        .arguments
1593                        .as_ref()
1594                        .and_then(|args| args.first())
1595                        .map(|arg| arg.name.clone())
1596                {
1597                    HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1598                } else {
1599                    Default::default()
1600                };
1601
1602                let prompt_name = prompt.prompt.name.clone();
1603                let server_id = prompt.server_id.clone();
1604
1605                return self.0.update(cx, |agent, cx| {
1606                    agent.send_mcp_prompt(
1607                        id,
1608                        session_id.clone(),
1609                        prompt_name,
1610                        server_id,
1611                        arguments,
1612                        params.prompt,
1613                        cx,
1614                    )
1615                });
1616            }
1617        };
1618
1619        let path_style = project_state.project.read(cx).path_style(cx);
1620
1621        self.run_turn(session_id, cx, move |thread, cx| {
1622            let content: Vec<UserMessageContent> = params
1623                .prompt
1624                .into_iter()
1625                .map(|block| UserMessageContent::from_content_block(block, path_style))
1626                .collect::<Vec<_>>();
1627            log::debug!("Converted prompt to message: {} chars", content.len());
1628            log::debug!("Message id: {:?}", id);
1629            log::debug!("Message content: {:?}", content);
1630
1631            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1632        })
1633    }
1634
1635    fn retry(
1636        &self,
1637        session_id: &acp::SessionId,
1638        _cx: &App,
1639    ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1640        Some(Rc::new(NativeAgentSessionRetry {
1641            connection: self.clone(),
1642            session_id: session_id.clone(),
1643        }) as _)
1644    }
1645
1646    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1647        log::info!("Cancelling on session: {}", session_id);
1648        self.0.update(cx, |agent, cx| {
1649            if let Some(session) = agent.sessions.get(session_id) {
1650                session
1651                    .thread
1652                    .update(cx, |thread, cx| thread.cancel(cx))
1653                    .detach();
1654            }
1655        });
1656    }
1657
1658    fn truncate(
1659        &self,
1660        session_id: &acp::SessionId,
1661        cx: &App,
1662    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1663        self.0.read_with(cx, |agent, _cx| {
1664            agent.sessions.get(session_id).map(|session| {
1665                Rc::new(NativeAgentSessionTruncate {
1666                    thread: session.thread.clone(),
1667                    acp_thread: session.acp_thread.downgrade(),
1668                }) as _
1669            })
1670        })
1671    }
1672
1673    fn set_title(
1674        &self,
1675        session_id: &acp::SessionId,
1676        cx: &App,
1677    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1678        self.0.read_with(cx, |agent, _cx| {
1679            agent
1680                .sessions
1681                .get(session_id)
1682                .filter(|s| !s.thread.read(cx).is_subagent())
1683                .map(|session| {
1684                    Rc::new(NativeAgentSessionSetTitle {
1685                        thread: session.thread.clone(),
1686                    }) as _
1687                })
1688        })
1689    }
1690
1691    fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1692        let thread_store = self.0.read(cx).thread_store.clone();
1693        Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1694    }
1695
1696    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1697        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1698    }
1699
1700    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1701        self
1702    }
1703}
1704
1705impl acp_thread::AgentTelemetry for NativeAgentConnection {
1706    fn thread_data(
1707        &self,
1708        session_id: &acp::SessionId,
1709        cx: &mut App,
1710    ) -> Task<Result<serde_json::Value>> {
1711        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1712            return Task::ready(Err(anyhow!("Session not found")));
1713        };
1714
1715        let task = session.thread.read(cx).to_db(cx);
1716        cx.background_spawn(async move {
1717            serde_json::to_value(task.await).context("Failed to serialize thread")
1718        })
1719    }
1720}
1721
1722pub struct NativeAgentSessionList {
1723    thread_store: Entity<ThreadStore>,
1724    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1725    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1726    _subscription: Subscription,
1727}
1728
1729impl NativeAgentSessionList {
1730    fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1731        let (tx, rx) = smol::channel::unbounded();
1732        let this_tx = tx.clone();
1733        let subscription = cx.observe(&thread_store, move |_, _| {
1734            this_tx
1735                .try_send(acp_thread::SessionListUpdate::Refresh)
1736                .ok();
1737        });
1738        Self {
1739            thread_store,
1740            updates_tx: tx,
1741            updates_rx: rx,
1742            _subscription: subscription,
1743        }
1744    }
1745
1746    pub fn thread_store(&self) -> &Entity<ThreadStore> {
1747        &self.thread_store
1748    }
1749}
1750
1751impl AgentSessionList for NativeAgentSessionList {
1752    fn list_sessions(
1753        &self,
1754        _request: AgentSessionListRequest,
1755        cx: &mut App,
1756    ) -> Task<Result<AgentSessionListResponse>> {
1757        let sessions = self
1758            .thread_store
1759            .read(cx)
1760            .entries()
1761            .map(|entry| AgentSessionInfo::from(&entry))
1762            .collect();
1763        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1764    }
1765
1766    fn supports_delete(&self) -> bool {
1767        true
1768    }
1769
1770    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1771        self.thread_store
1772            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1773    }
1774
1775    fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1776        self.thread_store
1777            .update(cx, |store, cx| store.delete_threads(cx))
1778    }
1779
1780    fn watch(
1781        &self,
1782        _cx: &mut App,
1783    ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1784        Some(self.updates_rx.clone())
1785    }
1786
1787    fn notify_refresh(&self) {
1788        self.updates_tx
1789            .try_send(acp_thread::SessionListUpdate::Refresh)
1790            .ok();
1791    }
1792
1793    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1794        self
1795    }
1796}
1797
1798struct NativeAgentSessionTruncate {
1799    thread: Entity<Thread>,
1800    acp_thread: WeakEntity<AcpThread>,
1801}
1802
1803impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1804    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1805        match self.thread.update(cx, |thread, cx| {
1806            thread.truncate(message_id.clone(), cx)?;
1807            Ok(thread.latest_token_usage())
1808        }) {
1809            Ok(usage) => {
1810                self.acp_thread
1811                    .update(cx, |thread, cx| {
1812                        thread.update_token_usage(usage, cx);
1813                    })
1814                    .ok();
1815                Task::ready(Ok(()))
1816            }
1817            Err(error) => Task::ready(Err(error)),
1818        }
1819    }
1820}
1821
1822struct NativeAgentSessionRetry {
1823    connection: NativeAgentConnection,
1824    session_id: acp::SessionId,
1825}
1826
1827impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1828    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1829        self.connection
1830            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1831                thread.update(cx, |thread, cx| thread.resume(cx))
1832            })
1833    }
1834}
1835
1836struct NativeAgentSessionSetTitle {
1837    thread: Entity<Thread>,
1838}
1839
1840impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1841    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1842        self.thread
1843            .update(cx, |thread, cx| thread.set_title(title, cx));
1844        Task::ready(Ok(()))
1845    }
1846}
1847
1848pub struct NativeThreadEnvironment {
1849    agent: WeakEntity<NativeAgent>,
1850    thread: WeakEntity<Thread>,
1851    acp_thread: WeakEntity<AcpThread>,
1852}
1853
1854impl NativeThreadEnvironment {
1855    pub(crate) fn create_subagent_thread(
1856        &self,
1857        label: String,
1858        cx: &mut App,
1859    ) -> Result<Rc<dyn SubagentHandle>> {
1860        let Some(parent_thread_entity) = self.thread.upgrade() else {
1861            anyhow::bail!("Parent thread no longer exists".to_string());
1862        };
1863        let parent_thread = parent_thread_entity.read(cx);
1864        let current_depth = parent_thread.depth();
1865        let parent_session_id = parent_thread.id().clone();
1866
1867        if current_depth >= MAX_SUBAGENT_DEPTH {
1868            return Err(anyhow!(
1869                "Maximum subagent depth ({}) reached",
1870                MAX_SUBAGENT_DEPTH
1871            ));
1872        }
1873
1874        let subagent_thread: Entity<Thread> = cx.new(|cx| {
1875            let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1876            thread.set_title(label.into(), cx);
1877            thread
1878        });
1879
1880        let session_id = subagent_thread.read(cx).id().clone();
1881
1882        let acp_thread = self
1883            .agent
1884            .update(cx, |agent, cx| -> Result<Entity<AcpThread>> {
1885                let project_id = agent
1886                    .sessions
1887                    .get(&parent_session_id)
1888                    .map(|s| s.project_id)
1889                    .context("parent session not found")?;
1890                Ok(agent.register_session(subagent_thread.clone(), project_id, 1, cx))
1891            })??;
1892
1893        let depth = current_depth + 1;
1894
1895        telemetry::event!(
1896            "Subagent Started",
1897            session = parent_thread_entity.read(cx).id().to_string(),
1898            subagent_session = session_id.to_string(),
1899            depth,
1900            is_resumed = false,
1901        );
1902
1903        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1904    }
1905
1906    pub(crate) fn resume_subagent_thread(
1907        &self,
1908        session_id: acp::SessionId,
1909        cx: &mut App,
1910    ) -> Result<Rc<dyn SubagentHandle>> {
1911        let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1912            let session = agent
1913                .sessions
1914                .get(&session_id)
1915                .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1916            anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1917        })??;
1918
1919        let depth = subagent_thread.read(cx).depth();
1920
1921        if let Some(parent_thread_entity) = self.thread.upgrade() {
1922            telemetry::event!(
1923                "Subagent Started",
1924                session = parent_thread_entity.read(cx).id().to_string(),
1925                subagent_session = session_id.to_string(),
1926                depth,
1927                is_resumed = true,
1928            );
1929        }
1930
1931        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1932    }
1933
1934    fn prompt_subagent(
1935        &self,
1936        session_id: acp::SessionId,
1937        subagent_thread: Entity<Thread>,
1938        acp_thread: Entity<acp_thread::AcpThread>,
1939    ) -> Result<Rc<dyn SubagentHandle>> {
1940        let Some(parent_thread_entity) = self.thread.upgrade() else {
1941            anyhow::bail!("Parent thread no longer exists".to_string());
1942        };
1943        Ok(Rc::new(NativeSubagentHandle::new(
1944            session_id,
1945            subagent_thread,
1946            acp_thread,
1947            parent_thread_entity,
1948        )) as _)
1949    }
1950}
1951
1952impl ThreadEnvironment for NativeThreadEnvironment {
1953    fn create_terminal(
1954        &self,
1955        command: String,
1956        cwd: Option<PathBuf>,
1957        output_byte_limit: Option<u64>,
1958        cx: &mut AsyncApp,
1959    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1960        let task = self.acp_thread.update(cx, |thread, cx| {
1961            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1962        });
1963
1964        let acp_thread = self.acp_thread.clone();
1965        cx.spawn(async move |cx| {
1966            let terminal = task?.await?;
1967
1968            let (drop_tx, drop_rx) = oneshot::channel();
1969            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1970
1971            cx.spawn(async move |cx| {
1972                drop_rx.await.ok();
1973                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1974            })
1975            .detach();
1976
1977            let handle = AcpTerminalHandle {
1978                terminal,
1979                _drop_tx: Some(drop_tx),
1980            };
1981
1982            Ok(Rc::new(handle) as _)
1983        })
1984    }
1985
1986    fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1987        self.create_subagent_thread(label, cx)
1988    }
1989
1990    fn resume_subagent(
1991        &self,
1992        session_id: acp::SessionId,
1993        cx: &mut App,
1994    ) -> Result<Rc<dyn SubagentHandle>> {
1995        self.resume_subagent_thread(session_id, cx)
1996    }
1997}
1998
/// Outcome of a single subagent prompt, as computed in `NativeSubagentHandle::send`.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The turn ended with a stop reason that is not mapped to cancellation or an error.
    Completed,
    /// The turn stopped with `acp::StopReason::Cancelled`.
    Cancelled,
    /// The subagent's token-usage ratio moved from `Normal` to `Warning` during the turn.
    ContextWindowWarning,
    /// The turn failed; the string is the message returned to the caller.
    Error(String),
}
2006
2007pub struct NativeSubagentHandle {
2008    session_id: acp::SessionId,
2009    parent_thread: WeakEntity<Thread>,
2010    subagent_thread: Entity<Thread>,
2011    acp_thread: Entity<acp_thread::AcpThread>,
2012}
2013
2014impl NativeSubagentHandle {
2015    fn new(
2016        session_id: acp::SessionId,
2017        subagent_thread: Entity<Thread>,
2018        acp_thread: Entity<acp_thread::AcpThread>,
2019        parent_thread_entity: Entity<Thread>,
2020    ) -> Self {
2021        NativeSubagentHandle {
2022            session_id,
2023            subagent_thread,
2024            parent_thread: parent_thread_entity.downgrade(),
2025            acp_thread,
2026        }
2027    }
2028}
2029
2030impl SubagentHandle for NativeSubagentHandle {
2031    fn id(&self) -> acp::SessionId {
2032        self.session_id.clone()
2033    }
2034
2035    fn num_entries(&self, cx: &App) -> usize {
2036        self.acp_thread.read(cx).entries().len()
2037    }
2038
2039    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
2040        let thread = self.subagent_thread.clone();
2041        let acp_thread = self.acp_thread.clone();
2042        let subagent_session_id = self.session_id.clone();
2043        let parent_thread = self.parent_thread.clone();
2044
2045        cx.spawn(async move |cx| {
2046            let (task, _subscription) = cx.update(|cx| {
2047                let ratio_before_prompt = thread
2048                    .read(cx)
2049                    .latest_token_usage()
2050                    .map(|usage| usage.ratio());
2051
2052                parent_thread
2053                    .update(cx, |parent_thread, _cx| {
2054                        parent_thread.register_running_subagent(thread.downgrade())
2055                    })
2056                    .ok();
2057
2058                let task = acp_thread.update(cx, |acp_thread, cx| {
2059                    acp_thread.send(vec![message.into()], cx)
2060                });
2061
2062                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
2063                let mut token_limit_tx = Some(token_limit_tx);
2064
2065                let subscription = cx.subscribe(
2066                    &thread,
2067                    move |_thread, event: &TokenUsageUpdated, _cx| {
2068                        if let Some(usage) = &event.0 {
2069                            let old_ratio = ratio_before_prompt
2070                                .clone()
2071                                .unwrap_or(TokenUsageRatio::Normal);
2072                            let new_ratio = usage.ratio();
2073                            if old_ratio == TokenUsageRatio::Normal
2074                                && new_ratio == TokenUsageRatio::Warning
2075                            {
2076                                if let Some(tx) = token_limit_tx.take() {
2077                                    tx.send(()).ok();
2078                                }
2079                            }
2080                        }
2081                    },
2082                );
2083
2084                let wait_for_prompt = cx
2085                    .background_spawn(async move {
2086                        futures::select! {
2087                            response = task.fuse() => match response {
2088                                Ok(Some(response)) => {
2089                                    match response.stop_reason {
2090                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
2091                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
2092                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
2093                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
2094                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
2095                                    }
2096                                }
2097                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
2098                                Err(error) => SubagentPromptResult::Error(error.to_string()),
2099                            },
2100                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
2101                        }
2102                    });
2103
2104                (wait_for_prompt, subscription)
2105            });
2106
2107            let result = match task.await {
2108                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
2109                    thread
2110                        .last_message()
2111                        .and_then(|message| {
2112                            let content = message.as_agent_message()?
2113                                .content
2114                                .iter()
2115                                .filter_map(|c| match c {
2116                                    AgentMessageContent::Text(text) => Some(text.as_str()),
2117                                    _ => None,
2118                                })
2119                                .join("\n\n");
2120                            if content.is_empty() {
2121                                None
2122                            } else {
2123                                Some( content)
2124                            }
2125                        })
2126                        .context("No response from subagent")
2127                }),
2128                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
2129                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
2130                SubagentPromptResult::ContextWindowWarning => {
2131                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2132                    Err(anyhow!(
2133                        "The agent is nearing the end of its context window and has been \
2134                         stopped. You can prompt the thread again to have the agent wrap up \
2135                         or hand off its work."
2136                    ))
2137                }
2138            };
2139
2140            parent_thread
2141                .update(cx, |parent_thread, cx| {
2142                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
2143                })
2144                .ok();
2145
2146            result
2147        })
2148    }
2149}
2150
2151pub struct AcpTerminalHandle {
2152    terminal: Entity<acp_thread::Terminal>,
2153    _drop_tx: Option<oneshot::Sender<()>>,
2154}
2155
2156impl TerminalHandle for AcpTerminalHandle {
2157    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
2158        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
2159    }
2160
2161    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
2162        Ok(self
2163            .terminal
2164            .read_with(cx, |term, _cx| term.wait_for_exit()))
2165    }
2166
2167    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
2168        Ok(self
2169            .terminal
2170            .read_with(cx, |term, cx| term.current_output(cx)))
2171    }
2172
2173    fn kill(&self, cx: &AsyncApp) -> Result<()> {
2174        cx.update(|cx| {
2175            self.terminal.update(cx, |terminal, cx| {
2176                terminal.kill(cx);
2177            });
2178        });
2179        Ok(())
2180    }
2181
2182    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
2183        Ok(self
2184            .terminal
2185            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
2186    }
2187}
2188
2189#[cfg(test)]
2190mod internal_tests {
2191    use std::path::Path;
2192
2193    use super::*;
2194    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
2195    use fs::FakeFs;
2196    use gpui::TestAppContext;
2197    use indoc::formatdoc;
2198    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2199    use language_model::{
2200        LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
2201    };
2202    use serde_json::json;
2203    use settings::SettingsStore;
2204    use util::{path, rel_path::rel_path};
2205
2206    #[gpui::test]
2207    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
2208        init_test(cx);
2209        let fs = FakeFs::new(cx.executor());
2210        fs.insert_tree(
2211            "/",
2212            json!({
2213                "a": {}
2214            }),
2215        )
2216        .await;
2217        let project = Project::test(fs.clone(), [], cx).await;
2218        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2219        let agent =
2220            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2221
2222        // Creating a session registers the project and triggers context building.
2223        let connection = NativeAgentConnection(agent.clone());
2224        let _acp_thread = cx
2225            .update(|cx| {
2226                Rc::new(connection).new_session(
2227                    project.clone(),
2228                    PathList::new(&[Path::new("/")]),
2229                    cx,
2230                )
2231            })
2232            .await
2233            .unwrap();
2234        cx.run_until_parked();
2235
2236        let thread = agent.read_with(cx, |agent, _cx| {
2237            agent.sessions.values().next().unwrap().thread.clone()
2238        });
2239
2240        agent.read_with(cx, |agent, cx| {
2241            let project_id = project.entity_id();
2242            let state = agent.projects.get(&project_id).unwrap();
2243            assert_eq!(state.project_context.read(cx).worktrees, vec![]);
2244            assert_eq!(thread.read(cx).project_context().read(cx).worktrees, vec![]);
2245        });
2246
2247        let worktree = project
2248            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
2249            .await
2250            .unwrap();
2251        cx.run_until_parked();
2252        agent.read_with(cx, |agent, cx| {
2253            let project_id = project.entity_id();
2254            let state = agent.projects.get(&project_id).unwrap();
2255            let expected_worktrees = vec![WorktreeContext {
2256                root_name: "a".into(),
2257                abs_path: Path::new("/a").into(),
2258                rules_file: None,
2259            }];
2260            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
2261            assert_eq!(
2262                thread.read(cx).project_context().read(cx).worktrees,
2263                expected_worktrees
2264            );
2265        });
2266
2267        // Creating `/a/.rules` updates the project context.
2268        fs.insert_file("/a/.rules", Vec::new()).await;
2269        cx.run_until_parked();
2270        agent.read_with(cx, |agent, cx| {
2271            let project_id = project.entity_id();
2272            let state = agent.projects.get(&project_id).unwrap();
2273            let rules_entry = worktree
2274                .read(cx)
2275                .entry_for_path(rel_path(".rules"))
2276                .unwrap();
2277            let expected_worktrees = vec![WorktreeContext {
2278                root_name: "a".into(),
2279                abs_path: Path::new("/a").into(),
2280                rules_file: Some(RulesFileContext {
2281                    path_in_worktree: rel_path(".rules").into(),
2282                    text: "".into(),
2283                    project_entry_id: rules_entry.id.to_usize(),
2284                }),
2285            }];
2286            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
2287            assert_eq!(
2288                thread.read(cx).project_context().read(cx).worktrees,
2289                expected_worktrees
2290            );
2291        });
2292    }
2293
2294    #[gpui::test]
2295    async fn test_listing_models(cx: &mut TestAppContext) {
2296        init_test(cx);
2297        let fs = FakeFs::new(cx.executor());
2298        fs.insert_tree("/", json!({ "a": {}  })).await;
2299        let project = Project::test(fs.clone(), [], cx).await;
2300        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2301        let connection =
2302            NativeAgentConnection(cx.update(|cx| {
2303                NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)
2304            }));
2305
2306        // Create a thread/session
2307        let acp_thread = cx
2308            .update(|cx| {
2309                Rc::new(connection.clone()).new_session(
2310                    project.clone(),
2311                    PathList::new(&[Path::new("/a")]),
2312                    cx,
2313                )
2314            })
2315            .await
2316            .unwrap();
2317
2318        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2319
2320        let models = cx
2321            .update(|cx| {
2322                connection
2323                    .model_selector(&session_id)
2324                    .unwrap()
2325                    .list_models(cx)
2326            })
2327            .await
2328            .unwrap();
2329
2330        let acp_thread::AgentModelList::Grouped(models) = models else {
2331            panic!("Unexpected model group");
2332        };
2333        assert_eq!(
2334            models,
2335            IndexMap::from_iter([(
2336                AgentModelGroupName("Fake".into()),
2337                vec![AgentModelInfo {
2338                    id: acp::ModelId::new("fake/fake"),
2339                    name: "Fake".into(),
2340                    description: None,
2341                    icon: Some(acp_thread::AgentModelIcon::Named(
2342                        ui::IconName::ZedAssistant
2343                    )),
2344                    is_latest: false,
2345                    cost: None,
2346                }]
2347            )])
2348        );
2349    }
2350
2351    #[gpui::test]
2352    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
2353        init_test(cx);
2354        let fs = FakeFs::new(cx.executor());
2355        fs.create_dir(paths::settings_file().parent().unwrap())
2356            .await
2357            .unwrap();
2358        fs.insert_file(
2359            paths::settings_file(),
2360            json!({
2361                "agent": {
2362                    "default_model": {
2363                        "provider": "foo",
2364                        "model": "bar"
2365                    }
2366                }
2367            })
2368            .to_string()
2369            .into_bytes(),
2370        )
2371        .await;
2372        let project = Project::test(fs.clone(), [], cx).await;
2373
2374        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2375
2376        // Create the agent and connection
2377        let agent =
2378            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2379        let connection = NativeAgentConnection(agent.clone());
2380
2381        // Create a thread/session
2382        let acp_thread = cx
2383            .update(|cx| {
2384                Rc::new(connection.clone()).new_session(
2385                    project.clone(),
2386                    PathList::new(&[Path::new("/a")]),
2387                    cx,
2388                )
2389            })
2390            .await
2391            .unwrap();
2392
2393        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2394
2395        // Select a model
2396        let selector = connection.model_selector(&session_id).unwrap();
2397        let model_id = acp::ModelId::new("fake/fake");
2398        cx.update(|cx| selector.select_model(model_id.clone(), cx))
2399            .await
2400            .unwrap();
2401
2402        // Verify the thread has the selected model
2403        agent.read_with(cx, |agent, _| {
2404            let session = agent.sessions.get(&session_id).unwrap();
2405            session.thread.read_with(cx, |thread, _| {
2406                assert_eq!(thread.model().unwrap().id().0, "fake");
2407            });
2408        });
2409
2410        cx.run_until_parked();
2411
2412        // Verify settings file was updated
2413        let settings_content = fs.load(paths::settings_file()).await.unwrap();
2414        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2415
2416        // Check that the agent settings contain the selected model
2417        assert_eq!(
2418            settings_json["agent"]["default_model"]["model"],
2419            json!("fake")
2420        );
2421        assert_eq!(
2422            settings_json["agent"]["default_model"]["provider"],
2423            json!("fake")
2424        );
2425
2426        // Register a thinking model and select it.
2427        cx.update(|cx| {
2428            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2429                "fake-corp",
2430                "fake-thinking",
2431                "Fake Thinking",
2432                true,
2433            ));
2434            let thinking_provider = Arc::new(
2435                FakeLanguageModelProvider::new(
2436                    LanguageModelProviderId::from("fake-corp".to_string()),
2437                    LanguageModelProviderName::from("Fake Corp".to_string()),
2438                )
2439                .with_models(vec![thinking_model]),
2440            );
2441            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2442                registry.register_provider(thinking_provider, cx);
2443            });
2444        });
2445        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2446
2447        let selector = connection.model_selector(&session_id).unwrap();
2448        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2449            .await
2450            .unwrap();
2451        cx.run_until_parked();
2452
2453        // Verify enable_thinking was written to settings as true.
2454        let settings_content = fs.load(paths::settings_file()).await.unwrap();
2455        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2456        assert_eq!(
2457            settings_json["agent"]["default_model"]["enable_thinking"],
2458            json!(true),
2459            "selecting a thinking model should persist enable_thinking: true to settings"
2460        );
2461    }
2462
2463    #[gpui::test]
2464    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2465        init_test(cx);
2466        let fs = FakeFs::new(cx.executor());
2467        fs.create_dir(paths::settings_file().parent().unwrap())
2468            .await
2469            .unwrap();
2470        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2471        let project = Project::test(fs.clone(), [], cx).await;
2472
2473        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2474        let agent =
2475            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2476        let connection = NativeAgentConnection(agent.clone());
2477
2478        let acp_thread = cx
2479            .update(|cx| {
2480                Rc::new(connection.clone()).new_session(
2481                    project.clone(),
2482                    PathList::new(&[Path::new("/a")]),
2483                    cx,
2484                )
2485            })
2486            .await
2487            .unwrap();
2488        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2489
2490        // Register a second provider with a thinking model.
2491        cx.update(|cx| {
2492            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2493                "fake-corp",
2494                "fake-thinking",
2495                "Fake Thinking",
2496                true,
2497            ));
2498            let thinking_provider = Arc::new(
2499                FakeLanguageModelProvider::new(
2500                    LanguageModelProviderId::from("fake-corp".to_string()),
2501                    LanguageModelProviderName::from("Fake Corp".to_string()),
2502                )
2503                .with_models(vec![thinking_model]),
2504            );
2505            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2506                registry.register_provider(thinking_provider, cx);
2507            });
2508        });
2509        // Refresh the agent's model list so it picks up the new provider.
2510        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2511
2512        // Thread starts with thinking_enabled = false (the default).
2513        agent.read_with(cx, |agent, _| {
2514            let session = agent.sessions.get(&session_id).unwrap();
2515            session.thread.read_with(cx, |thread, _| {
2516                assert!(!thread.thinking_enabled(), "thinking defaults to false");
2517            });
2518        });
2519
2520        // Select the thinking model via select_model.
2521        let selector = connection.model_selector(&session_id).unwrap();
2522        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2523            .await
2524            .unwrap();
2525
2526        // select_model should have enabled thinking based on the model's supports_thinking().
2527        agent.read_with(cx, |agent, _| {
2528            let session = agent.sessions.get(&session_id).unwrap();
2529            session.thread.read_with(cx, |thread, _| {
2530                assert!(
2531                    thread.thinking_enabled(),
2532                    "select_model should enable thinking when model supports it"
2533                );
2534            });
2535        });
2536
2537        // Switch back to the non-thinking model.
2538        let selector = connection.model_selector(&session_id).unwrap();
2539        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2540            .await
2541            .unwrap();
2542
2543        // select_model should have disabled thinking.
2544        agent.read_with(cx, |agent, _| {
2545            let session = agent.sessions.get(&session_id).unwrap();
2546            session.thread.read_with(cx, |thread, _| {
2547                assert!(
2548                    !thread.thinking_enabled(),
2549                    "select_model should disable thinking when model does not support it"
2550                );
2551            });
2552        });
2553    }
2554
2555    #[gpui::test]
2556    async fn test_summarization_model_survives_transient_registry_clearing(
2557        cx: &mut TestAppContext,
2558    ) {
2559        init_test(cx);
2560        let fs = FakeFs::new(cx.executor());
2561        fs.insert_tree("/", json!({ "a": {} })).await;
2562        let project = Project::test(fs.clone(), [], cx).await;
2563
2564        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2565        let agent =
2566            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2567        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2568
2569        let acp_thread = cx
2570            .update(|cx| {
2571                connection.clone().new_session(
2572                    project.clone(),
2573                    PathList::new(&[Path::new("/a")]),
2574                    cx,
2575                )
2576            })
2577            .await
2578            .unwrap();
2579        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2580
2581        let thread = agent.read_with(cx, |agent, _| {
2582            agent.sessions.get(&session_id).unwrap().thread.clone()
2583        });
2584
2585        thread.read_with(cx, |thread, _| {
2586            assert!(
2587                thread.summarization_model().is_some(),
2588                "session should have a summarization model from the test registry"
2589            );
2590        });
2591
2592        // Simulate what happens during a provider blip:
2593        // update_active_language_model_from_settings calls set_default_model(None)
2594        // when it can't resolve the model, clearing all fallbacks.
2595        cx.update(|cx| {
2596            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2597                registry.set_default_model(None, cx);
2598            });
2599        });
2600        cx.run_until_parked();
2601
2602        thread.read_with(cx, |thread, _| {
2603            assert!(
2604                thread.summarization_model().is_some(),
2605                "summarization model should survive a transient default model clearing"
2606            );
2607        });
2608    }
2609
2610    #[gpui::test]
2611    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
2612        init_test(cx);
2613        let fs = FakeFs::new(cx.executor());
2614        fs.insert_tree("/", json!({ "a": {} })).await;
2615        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2616        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2617        let agent = cx.update(|cx| {
2618            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
2619        });
2620        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2621
2622        // Register a thinking model.
2623        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2624            "fake-corp",
2625            "fake-thinking",
2626            "Fake Thinking",
2627            true,
2628        ));
2629        let thinking_provider = Arc::new(
2630            FakeLanguageModelProvider::new(
2631                LanguageModelProviderId::from("fake-corp".to_string()),
2632                LanguageModelProviderName::from("Fake Corp".to_string()),
2633            )
2634            .with_models(vec![thinking_model.clone()]),
2635        );
2636        cx.update(|cx| {
2637            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2638                registry.register_provider(thinking_provider, cx);
2639            });
2640        });
2641        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2642
2643        // Create a thread and select the thinking model.
2644        let acp_thread = cx
2645            .update(|cx| {
2646                connection.clone().new_session(
2647                    project.clone(),
2648                    PathList::new(&[Path::new("/a")]),
2649                    cx,
2650                )
2651            })
2652            .await
2653            .unwrap();
2654        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2655
2656        let selector = connection.model_selector(&session_id).unwrap();
2657        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2658            .await
2659            .unwrap();
2660
2661        // Verify thinking is enabled after selecting the thinking model.
2662        let thread = agent.read_with(cx, |agent, _| {
2663            agent.sessions.get(&session_id).unwrap().thread.clone()
2664        });
2665        thread.read_with(cx, |thread, _| {
2666            assert!(
2667                thread.thinking_enabled(),
2668                "thinking should be enabled after selecting thinking model"
2669            );
2670        });
2671
2672        // Send a message so the thread gets persisted.
2673        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2674        let send = cx.foreground_executor().spawn(send);
2675        cx.run_until_parked();
2676
2677        thinking_model.send_last_completion_stream_text_chunk("Response.");
2678        thinking_model.end_last_completion_stream();
2679
2680        send.await.unwrap();
2681        cx.run_until_parked();
2682
2683        // Close the session so it can be reloaded from disk.
2684        cx.update(|cx| connection.clone().close_session(&session_id, cx))
2685            .await
2686            .unwrap();
2687        drop(thread);
2688        drop(acp_thread);
2689        agent.read_with(cx, |agent, _| {
2690            assert!(agent.sessions.is_empty());
2691        });
2692
2693        // Reload the thread and verify thinking_enabled is still true.
2694        let reloaded_acp_thread = agent
2695            .update(cx, |agent, cx| {
2696                agent.open_thread(session_id.clone(), project.clone(), cx)
2697            })
2698            .await
2699            .unwrap();
2700        let reloaded_thread = agent.read_with(cx, |agent, _| {
2701            agent.sessions.get(&session_id).unwrap().thread.clone()
2702        });
2703        reloaded_thread.read_with(cx, |thread, _| {
2704            assert!(
2705                thread.thinking_enabled(),
2706                "thinking_enabled should be preserved when reloading a thread with a thinking model"
2707            );
2708        });
2709
2710        drop(reloaded_acp_thread);
2711    }
2712
2713    #[gpui::test]
2714    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
2715        init_test(cx);
2716        let fs = FakeFs::new(cx.executor());
2717        fs.insert_tree("/", json!({ "a": {} })).await;
2718        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2719        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2720        let agent = cx.update(|cx| {
2721            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
2722        });
2723        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2724
2725        // Register a model where id() != name(), like real Anthropic models
2726        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
2727        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2728            "fake-corp",
2729            "custom-model-id",
2730            "Custom Model Display Name",
2731            false,
2732        ));
2733        let provider = Arc::new(
2734            FakeLanguageModelProvider::new(
2735                LanguageModelProviderId::from("fake-corp".to_string()),
2736                LanguageModelProviderName::from("Fake Corp".to_string()),
2737            )
2738            .with_models(vec![model.clone()]),
2739        );
2740        cx.update(|cx| {
2741            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2742                registry.register_provider(provider, cx);
2743            });
2744        });
2745        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2746
2747        // Create a thread and select the model.
2748        let acp_thread = cx
2749            .update(|cx| {
2750                connection.clone().new_session(
2751                    project.clone(),
2752                    PathList::new(&[Path::new("/a")]),
2753                    cx,
2754                )
2755            })
2756            .await
2757            .unwrap();
2758        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2759
2760        let selector = connection.model_selector(&session_id).unwrap();
2761        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
2762            .await
2763            .unwrap();
2764
2765        let thread = agent.read_with(cx, |agent, _| {
2766            agent.sessions.get(&session_id).unwrap().thread.clone()
2767        });
2768        thread.read_with(cx, |thread, _| {
2769            assert_eq!(
2770                thread.model().unwrap().id().0.as_ref(),
2771                "custom-model-id",
2772                "model should be set before persisting"
2773            );
2774        });
2775
2776        // Send a message so the thread gets persisted.
2777        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2778        let send = cx.foreground_executor().spawn(send);
2779        cx.run_until_parked();
2780
2781        model.send_last_completion_stream_text_chunk("Response.");
2782        model.end_last_completion_stream();
2783
2784        send.await.unwrap();
2785        cx.run_until_parked();
2786
2787        // Close the session so it can be reloaded from disk.
2788        cx.update(|cx| connection.clone().close_session(&session_id, cx))
2789            .await
2790            .unwrap();
2791        drop(thread);
2792        drop(acp_thread);
2793        agent.read_with(cx, |agent, _| {
2794            assert!(agent.sessions.is_empty());
2795        });
2796
2797        // Reload the thread and verify the model was preserved.
2798        let reloaded_acp_thread = agent
2799            .update(cx, |agent, cx| {
2800                agent.open_thread(session_id.clone(), project.clone(), cx)
2801            })
2802            .await
2803            .unwrap();
2804        let reloaded_thread = agent.read_with(cx, |agent, _| {
2805            agent.sessions.get(&session_id).unwrap().thread.clone()
2806        });
2807        reloaded_thread.read_with(cx, |thread, _| {
2808            let reloaded_model = thread
2809                .model()
2810                .expect("model should be present after reload");
2811            assert_eq!(
2812                reloaded_model.id().0.as_ref(),
2813                "custom-model-id",
2814                "reloaded thread should have the same model, not fall back to the default"
2815            );
2816        });
2817
2818        drop(reloaded_acp_thread);
2819    }
2820
    /// Round-trips a thread through storage. It sends one message, closes the session,
    /// reopens it, and checks that the transcript, draft prompt, token usage, and UI
    /// scroll position all survive.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        // Worktree rooted at /a containing the single file the user message links to.
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Start a new session and grab the native thread that backs it.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a user message that embeds a resource link to /a/b.md.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Stream the assistant reply plus a usage update (150 input / 75 output tokens);
        // the token-usage assertions after reload expect exactly these numbers.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
            language_model::TokenUsage {
                input_tokens: 150,
                output_tokens: 75,
                ..Default::default()
            },
        ));
        model.end_last_completion_stream();
        cx.run_until_parked();
        // The summarizer's reply is the text the `thread_entries` assertion further down
        // expects to see for this session.
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        // URI of the mentioned file; reused in the expected markdown and in the draft prompt.
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        // Snapshot of the transcript before saving; the reloaded thread must match it.
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Set a draft prompt with rich content blocks and scroll position
        // AFTER run_until_parked, so the only save that captures these
        // changes is the one performed by close_session itself.
        let draft_blocks = vec![
            acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
            acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
            acp::ContentBlock::Text(acp::TextContent::new(" please")),
        ];
        acp_thread.update(cx, |thread, cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()), cx);
        });
        thread.update(cx, |thread, _cx| {
            thread.set_ui_scroll_position(Some(gpui::ListOffset {
                item_ix: 5,
                offset_in_item: gpui::px(12.5),
            }));
        });

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        // The reloaded transcript must match the snapshot taken before closing.
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // Ensure the draft prompt with rich content blocks survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
        });

        // Ensure token usage survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let usage = thread
                .token_usage()
                .expect("token usage should be restored after reload");
            assert_eq!(usage.input_tokens, 150);
            assert_eq!(usage.output_tokens, 75);
        });

        // Ensure scroll position survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let scroll = thread
                .ui_scroll_position()
                .expect("scroll position should be restored after reload");
            assert_eq!(scroll.item_ix, 5);
            assert_eq!(scroll.offset_in_item, gpui::px(12.5));
        });
    }
3002
3003    #[gpui::test]
3004    async fn test_close_session_saves_thread(cx: &mut TestAppContext) {
3005        init_test(cx);
3006        let fs = FakeFs::new(cx.executor());
3007        fs.insert_tree(
3008            "/",
3009            json!({
3010                "a": {
3011                    "file.txt": "hello"
3012                }
3013            }),
3014        )
3015        .await;
3016        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
3017        let thread_store = cx.new(|cx| ThreadStore::new(cx));
3018        let agent = cx.update(|cx| {
3019            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
3020        });
3021        let connection = Rc::new(NativeAgentConnection(agent.clone()));
3022
3023        let acp_thread = cx
3024            .update(|cx| {
3025                connection
3026                    .clone()
3027                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
3028            })
3029            .await
3030            .unwrap();
3031        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3032        let thread = agent.read_with(cx, |agent, _| {
3033            agent.sessions.get(&session_id).unwrap().thread.clone()
3034        });
3035
3036        let model = Arc::new(FakeLanguageModel::default());
3037        thread.update(cx, |thread, cx| {
3038            thread.set_model(model.clone(), cx);
3039        });
3040
3041        // Send a message so the thread is non-empty (empty threads aren't saved).
3042        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
3043        let send = cx.foreground_executor().spawn(send);
3044        cx.run_until_parked();
3045
3046        model.send_last_completion_stream_text_chunk("world");
3047        model.end_last_completion_stream();
3048        send.await.unwrap();
3049        cx.run_until_parked();
3050
3051        // Set a draft prompt WITHOUT calling run_until_parked afterwards.
3052        // This means no observe-triggered save has run for this change.
3053        // The only way this data gets persisted is if close_session
3054        // itself performs the save.
3055        let draft_blocks = vec![acp::ContentBlock::Text(acp::TextContent::new(
3056            "unsaved draft",
3057        ))];
3058        acp_thread.update(cx, |thread, cx| {
3059            thread.set_draft_prompt(Some(draft_blocks.clone()), cx);
3060        });
3061
3062        // Close the session immediately — no run_until_parked in between.
3063        cx.update(|cx| connection.clone().close_session(&session_id, cx))
3064            .await
3065            .unwrap();
3066        cx.run_until_parked();
3067
3068        // Reopen and verify the draft prompt was saved.
3069        let reloaded = agent
3070            .update(cx, |agent, cx| {
3071                agent.open_thread(session_id.clone(), project.clone(), cx)
3072            })
3073            .await
3074            .unwrap();
3075        reloaded.read_with(cx, |thread, _| {
3076            assert_eq!(
3077                thread.draft_prompt(),
3078                Some(draft_blocks.as_slice()),
3079                "close_session must save the thread; draft prompt was lost"
3080            );
3081        });
3082    }
3083
3084    #[gpui::test]
3085    async fn test_thread_summary_releases_loaded_session(cx: &mut TestAppContext) {
3086        init_test(cx);
3087        let fs = FakeFs::new(cx.executor());
3088        fs.insert_tree(
3089            "/",
3090            json!({
3091                "a": {
3092                    "file.txt": "hello"
3093                }
3094            }),
3095        )
3096        .await;
3097        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
3098        let thread_store = cx.new(|cx| ThreadStore::new(cx));
3099        let agent = cx.update(|cx| {
3100            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
3101        });
3102        let connection = Rc::new(NativeAgentConnection(agent.clone()));
3103
3104        let acp_thread = cx
3105            .update(|cx| {
3106                connection
3107                    .clone()
3108                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
3109            })
3110            .await
3111            .unwrap();
3112        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3113        let thread = agent.read_with(cx, |agent, _| {
3114            agent.sessions.get(&session_id).unwrap().thread.clone()
3115        });
3116
3117        let model = Arc::new(FakeLanguageModel::default());
3118        let summary_model = Arc::new(FakeLanguageModel::default());
3119        thread.update(cx, |thread, cx| {
3120            thread.set_model(model.clone(), cx);
3121            thread.set_summarization_model(Some(summary_model.clone()), cx);
3122        });
3123
3124        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
3125        let send = cx.foreground_executor().spawn(send);
3126        cx.run_until_parked();
3127
3128        model.send_last_completion_stream_text_chunk("world");
3129        model.end_last_completion_stream();
3130        send.await.unwrap();
3131        cx.run_until_parked();
3132
3133        let summary = agent.update(cx, |agent, cx| {
3134            agent.thread_summary(session_id.clone(), project.clone(), cx)
3135        });
3136        cx.run_until_parked();
3137
3138        summary_model.send_last_completion_stream_text_chunk("summary");
3139        summary_model.end_last_completion_stream();
3140
3141        assert_eq!(summary.await.unwrap(), "summary");
3142        cx.run_until_parked();
3143
3144        agent.read_with(cx, |agent, _| {
3145            let session = agent
3146                .sessions
3147                .get(&session_id)
3148                .expect("thread_summary should not close the active session");
3149            assert_eq!(
3150                session.ref_count, 1,
3151                "thread_summary should release its temporary session reference"
3152            );
3153        });
3154
3155        cx.update(|cx| connection.clone().close_session(&session_id, cx))
3156            .await
3157            .unwrap();
3158        cx.run_until_parked();
3159
3160        agent.read_with(cx, |agent, _| {
3161            assert!(
3162                agent.sessions.is_empty(),
3163                "closing the active session after thread_summary should unload it"
3164            );
3165        });
3166    }
3167
    // Checks session sharing across loads of one persisted session:
    // - two concurrent `load_session` calls must resolve to the same `AcpThread`;
    // - closing once must keep the session registered in `agent.sessions`;
    // - the surviving handle must still work and show the restored history;
    // - only the final close may remove the session.
    #[gpui::test]
    async fn test_loaded_sessions_keep_state_until_last_close(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "file.txt": "hello"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Phase 1: create a brand-new session and record one user/assistant
        // exchange, so the later loads have history to restore.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Route completions through the registry's default model. Its fake
        // handle lets the test script the assistant's replies.
        let model = cx.update(|cx| {
            LanguageModelRegistry::read_global(cx)
                .default_model()
                .map(|default_model| default_model.model)
                .expect("default test model should be available")
        });
        let fake_model = model.as_fake();
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
        });

        // `send` yields a future. Spawn it and park, so the completion request
        // reaches the fake model before the reply is streamed below.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        fake_model.send_last_completion_stream_text_chunk("world");
        fake_model.end_last_completion_stream();
        send.await.unwrap();
        cx.run_until_parked();

        // Close the only open session and release the local handles. The agent
        // must then no longer track the session.
        // NOTE(review): the handles are presumably dropped so the loads below
        // cannot reuse live entities — confirm.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Phase 2: start both loads before awaiting either, so they race for
        // the same session id.
        let first_loaded_thread = cx.update(|cx| {
            connection.clone().load_session(
                session_id.clone(),
                project.clone(),
                PathList::new(&[Path::new("")]),
                None,
                cx,
            )
        });
        let second_loaded_thread = cx.update(|cx| {
            connection.clone().load_session(
                session_id.clone(),
                project.clone(),
                PathList::new(&[Path::new("")]),
                None,
                cx,
            )
        });

        let first_loaded_thread = first_loaded_thread.await.unwrap();
        let second_loaded_thread = second_loaded_thread.await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            first_loaded_thread.entity_id(),
            second_loaded_thread.entity_id(),
            "concurrent loads for the same session should share one AcpThread"
        );

        // A single close after two loads must leave the session registered.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();

        agent.read_with(cx, |agent, _| {
            assert!(
                agent.sessions.contains_key(&session_id),
                "closing one loaded session should not drop shared session state"
            );
        });

        // The remaining handle must still work: the fake model answers a
        // follow-up turn, and the transcript includes the exchange recorded
        // before the first close.
        let follow_up = second_loaded_thread.update(cx, |thread, cx| {
            thread.send(vec!["still there?".into()], cx)
        });
        let follow_up = cx.foreground_executor().spawn(follow_up);
        cx.run_until_parked();

        fake_model.send_last_completion_stream_text_chunk("yes");
        fake_model.end_last_completion_stream();
        follow_up.await.unwrap();
        cx.run_until_parked();

        second_loaded_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    hello

                    ## Assistant

                    world

                    ## User

                    still there?

                    ## Assistant

                    yes

                "}
            );
        });

        // The second close matches the second load and must unload the session.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();

        cx.run_until_parked();

        drop(first_loaded_thread);
        drop(second_loaded_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });
    }
3318
    // Sets the `Thread` title twice in quick succession. Checks that exactly
    // two `TitleUpdated` events are emitted and that both the `Thread` and its
    // `AcpThread` end up with the final title.
    #[gpui::test]
    async fn test_rapid_title_changes_do_not_loop(cx: &mut TestAppContext) {
        // Regression test: rapid title changes must not cause a propagation loop
        // between Thread and AcpThread via handle_thread_title_updated.
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();

        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Count every `TitleUpdated` the Thread emits. The handler panics as
        // soon as the count exceeds the two updates made below, so a feedback
        // loop fails the test at the point it starts instead of hanging.
        let title_updated_count = Rc::new(std::cell::RefCell::new(0usize));
        cx.update(|cx| {
            let count = title_updated_count.clone();
            cx.subscribe(
                &thread,
                move |_entity: Entity<Thread>, _event: &TitleUpdated, _cx: &mut App| {
                    let new_count = {
                        let mut count = count.borrow_mut();
                        *count += 1;
                        *count
                    };
                    assert!(
                        new_count <= 2,
                        "TitleUpdated fired {new_count} times; \
                         title updates are looping"
                    );
                },
            )
            .detach();
        });

        // Two updates with no parking in between, to mimic rapid changes.
        thread.update(cx, |thread, cx| thread.set_title("first".into(), cx));
        thread.update(cx, |thread, cx| thread.set_title("second".into(), cx));

        cx.run_until_parked();

        // The last title wins on both sides of the Thread/AcpThread pairing.
        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.title(), Some("second".into()));
        });
        acp_thread.read_with(cx, |acp_thread, _| {
            assert_eq!(acp_thread.title(), Some("second".into()));
        });

        assert_eq!(*title_updated_count.borrow(), 2);
    }
3382
3383    fn thread_entries(
3384        thread_store: &Entity<ThreadStore>,
3385        cx: &mut TestAppContext,
3386    ) -> Vec<(acp::SessionId, String)> {
3387        thread_store.read_with(cx, |store, _| {
3388            store
3389                .entries()
3390                .map(|entry| (entry.id.clone(), entry.title.to_string()))
3391                .collect::<Vec<_>>()
3392        })
3393    }
3394
3395    fn init_test(cx: &mut TestAppContext) {
3396        env_logger::try_init().ok();
3397        cx.update(|cx| {
3398            let settings_store = SettingsStore::test(cx);
3399            cx.set_global(settings_store);
3400
3401            LanguageModelRegistry::test(cx);
3402        });
3403    }
3404}
3405
3406fn mcp_message_content_to_acp_content_block(
3407    content: context_server::types::MessageContent,
3408) -> acp::ContentBlock {
3409    match content {
3410        context_server::types::MessageContent::Text {
3411            text,
3412            annotations: _,
3413        } => text.into(),
3414        context_server::types::MessageContent::Image {
3415            data,
3416            mime_type,
3417            annotations: _,
3418        } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
3419        context_server::types::MessageContent::Audio {
3420            data,
3421            mime_type,
3422            annotations: _,
3423        } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
3424        context_server::types::MessageContent::Resource {
3425            resource,
3426            annotations: _,
3427        } => {
3428            let mut link =
3429                acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
3430            if let Some(mime_type) = resource.mime_type {
3431                link = link.mime_type(mime_type);
3432            }
3433            acp::ContentBlock::ResourceLink(link)
3434        }
3435    }
3436}