agent.rs

   1mod db;
   2mod edit_agent;
   3mod legacy_thread;
   4mod native_agent_server;
   5pub mod outline;
   6mod pattern_extraction;
   7mod templates;
   8#[cfg(test)]
   9mod tests;
  10mod thread;
  11mod thread_store;
  12mod tool_permissions;
  13mod tools;
  14
  15use context_server::ContextServerId;
  16pub use db::*;
  17use itertools::Itertools;
  18pub use native_agent_server::NativeAgentServer;
  19pub use pattern_extraction::*;
  20pub use shell_command_parser::extract_commands;
  21pub use templates::*;
  22pub use thread::*;
  23pub use thread_store::*;
  24pub use tool_permissions::*;
  25pub use tools::*;
  26
  27use acp_thread::{
  28    AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
  29    AgentSessionListResponse, TokenUsageRatio, UserMessageId,
  30};
  31use agent_client_protocol as acp;
  32use anyhow::{Context as _, Result, anyhow};
  33use chrono::{DateTime, Utc};
  34use collections::{HashMap, HashSet, IndexMap};
  35use fs::Fs;
  36use futures::channel::{mpsc, oneshot};
  37use futures::future::Shared;
  38use futures::{FutureExt as _, StreamExt as _, future};
  39use gpui::{
  40    App, AppContext, AsyncApp, Context, Entity, EntityId, SharedString, Subscription, Task,
  41    WeakEntity,
  42};
  43use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  44use project::{AgentId, Project, ProjectItem, ProjectPath, Worktree};
  45use prompt_store::{
  46    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  47    WorktreeContext,
  48};
  49use serde::{Deserialize, Serialize};
  50use settings::{LanguageModelSelection, Settings as _, update_settings_file};
  51use std::any::Any;
  52use std::path::PathBuf;
  53use std::rc::Rc;
  54use std::sync::{Arc, LazyLock};
  55use util::ResultExt;
  56use util::path_list::PathList;
  57use util::rel_path::RelPath;
  58
/// A serializable, point-in-time snapshot of a project's worktrees.
///
/// NOTE(review): the per-worktree entries use `project::telemetry_snapshot`
/// types, so this is presumably captured for telemetry/diagnostics — confirm
/// against callers.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Per-worktree snapshots.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// Timestamp associated with the snapshot, in UTC (presumably capture time).
    pub timestamp: DateTime<Utc>,
}
  64
/// Describes a failure to load a rules file while building the agent's
/// project context.
pub struct RulesLoadingError {
    /// Human-readable description of the underlying error.
    pub message: SharedString,
}
  68
/// Per-project state shared by every session opened on that project.
struct ProjectState {
    /// The project this state belongs to.
    project: Entity<Project>,
    /// Context (worktrees, rules files, default user rules) handed to threads.
    /// Its contents are replaced in place each time the context is rebuilt.
    project_context: Entity<ProjectContext>,
    /// Sending on this channel wakes `maintain_project_context`, which then
    /// rebuilds `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running the `maintain_project_context` loop for this project.
    _maintain_project_context: Task<Result<()>>,
    /// Registry built over the project's context-server store.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Subscriptions to project, context-server-store and registry events.
    _subscriptions: Vec<Subscription>,
}
  77
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Entity id of the project whose `ProjectState` this session uses.
    project_id: EntityId,
    /// Most recent save task for this session. It starts as an
    /// already-completed task when the session is registered.
    pending_save: Task<Result<()>>,
    /// Subscriptions to the thread's events, kept alive with the session.
    _subscriptions: Vec<Subscription>,
    /// Reference count supplied at registration (`1` for freshly created
    /// sessions). NOTE(review): presumably decremented as session handles are
    /// closed — confirm in the code outside this section.
    ref_count: usize,
}
  89
/// A session whose [`AcpThread`] has not finished being created yet.
///
/// NOTE(review): populated outside this section; presumably used while a
/// session is being loaded so concurrent requests can share one load.
struct PendingSession {
    /// The in-flight creation. It is `Shared` so several callers can await the
    /// same result. The error is wrapped in `Arc` because `Shared` requires a
    /// `Clone` output.
    task: Shared<Task<Result<Entity<AcpThread>, Arc<anyhow::Error>>>>,
    /// Number of callers waiting on this pending session (presumably) — confirm.
    ref_count: usize,
}
  94
/// Cached view of the language models the agent can offer, derived from the
/// authenticated, visible providers of the global `LanguageModelRegistry`.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver template handed out by `watch()`. It is notified after every
    /// refresh of the list.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender signalled at the end of each refresh of the list.
    refresh_models_tx: watch::Sender<()>,
    /// Background task that authenticates every visible provider.
    _authenticate_all_providers_task: Task<()>,
}
 104
 105impl LanguageModels {
 106    fn new(cx: &mut App) -> Self {
 107        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
 108
 109        let mut this = Self {
 110            models: HashMap::default(),
 111            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
 112            refresh_models_rx,
 113            refresh_models_tx,
 114            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
 115        };
 116        this.refresh_list(cx);
 117        this
 118    }
 119
 120    fn refresh_list(&mut self, cx: &App) {
 121        let providers = LanguageModelRegistry::global(cx)
 122            .read(cx)
 123            .visible_providers()
 124            .into_iter()
 125            .filter(|provider| provider.is_authenticated(cx))
 126            .collect::<Vec<_>>();
 127
 128        let mut language_model_list = IndexMap::default();
 129        let mut recommended_models = HashSet::default();
 130
 131        let mut recommended = Vec::new();
 132        for provider in &providers {
 133            for model in provider.recommended_models(cx) {
 134                recommended_models.insert((model.provider_id(), model.id()));
 135                recommended.push(Self::map_language_model_to_info(&model, provider));
 136            }
 137        }
 138        if !recommended.is_empty() {
 139            language_model_list.insert(
 140                acp_thread::AgentModelGroupName("Recommended".into()),
 141                recommended,
 142            );
 143        }
 144
 145        let mut models = HashMap::default();
 146        for provider in providers {
 147            let mut provider_models = Vec::new();
 148            for model in provider.provided_models(cx) {
 149                let model_info = Self::map_language_model_to_info(&model, &provider);
 150                let model_id = model_info.id.clone();
 151                provider_models.push(model_info);
 152                models.insert(model_id, model);
 153            }
 154            if !provider_models.is_empty() {
 155                language_model_list.insert(
 156                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 157                    provider_models,
 158                );
 159            }
 160        }
 161
 162        self.models = models;
 163        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 164        self.refresh_models_tx.send(()).ok();
 165    }
 166
 167    fn watch(&self) -> watch::Receiver<()> {
 168        self.refresh_models_rx.clone()
 169    }
 170
 171    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 172        self.models.get(model_id).cloned()
 173    }
 174
 175    fn map_language_model_to_info(
 176        model: &Arc<dyn LanguageModel>,
 177        provider: &Arc<dyn LanguageModelProvider>,
 178    ) -> acp_thread::AgentModelInfo {
 179        acp_thread::AgentModelInfo {
 180            id: Self::model_id(model),
 181            name: model.name().0,
 182            description: None,
 183            icon: Some(match provider.icon() {
 184                IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
 185                IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
 186            }),
 187            is_latest: model.is_latest(),
 188            cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
 189        }
 190    }
 191
 192    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 193        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 194    }
 195
 196    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 197        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 198            .read(cx)
 199            .visible_providers()
 200            .iter()
 201            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 202            .collect::<Vec<_>>();
 203
 204        cx.spawn(async move |cx| {
 205            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 206                if let Err(err) = authenticate_task.await {
 207                    match err {
 208                        language_model::AuthenticateError::CredentialsNotFound => {
 209                            // Since we're authenticating these providers in the
 210                            // background for the purposes of populating the
 211                            // language selector, we don't care about providers
 212                            // where the credentials are not found.
 213                        }
 214                        language_model::AuthenticateError::ConnectionRefused => {
 215                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 216                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 217                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 218                        }
 219                        _ => {
 220                            // Some providers have noisy failure states that we
 221                            // don't want to spam the logs with every time the
 222                            // language model selector is initialized.
 223                            //
 224                            // Ideally these should have more clear failure modes
 225                            // that we know are safe to ignore here, like what we do
 226                            // with `CredentialsNotFound` above.
 227                            match provider_id.0.as_ref() {
 228                                "lmstudio" | "ollama" => {
 229                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 230                                    //
 231                                    // These fail noisily, so we don't log them.
 232                                }
 233                                "copilot_chat" => {
 234                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 235                                }
 236                                _ => {
 237                                    log::error!(
 238                                        "Failed to authenticate provider: {}: {err:#}",
 239                                        provider_name.0
 240                                    );
 241                                }
 242                            }
 243                        }
 244                    }
 245                }
 246            }
 247
 248            cx.update(language_models::update_environment_fallback_model);
 249        })
 250    }
 251}
 252
/// Agent implementation backed by this crate's [`Thread`]s.
///
/// It owns the live sessions, the per-project state they share, and the cached
/// language-model list.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Sessions whose creation is still in progress. NOTE(review): populated
    /// outside this section.
    pending_sessions: HashMap<acp::SessionId, PendingSession>,
    /// Thread store handed to the agent at construction. NOTE(review): usage
    /// is not visible in this section.
    thread_store: Entity<ThreadStore>,
    /// Project-specific state keyed by project EntityId
    projects: HashMap<EntityId, ProjectState>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Optional prompt store. When present, the agent subscribes to its updates
    /// and its default prompts are loaded as user rules into project contexts.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle. NOTE(review): usage is not visible in this section.
    fs: Arc<dyn Fs>,
    /// Subscriptions created at construction (model registry, prompt store).
    _subscriptions: Vec<Subscription>,
}
 268
 269impl NativeAgent {
 270    pub fn new(
 271        thread_store: Entity<ThreadStore>,
 272        templates: Arc<Templates>,
 273        prompt_store: Option<Entity<PromptStore>>,
 274        fs: Arc<dyn Fs>,
 275        cx: &mut App,
 276    ) -> Entity<NativeAgent> {
 277        log::debug!("Creating new NativeAgent");
 278
 279        cx.new(|cx| {
 280            let mut subscriptions = vec![cx.subscribe(
 281                &LanguageModelRegistry::global(cx),
 282                Self::handle_models_updated_event,
 283            )];
 284            if let Some(prompt_store) = prompt_store.as_ref() {
 285                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 286            }
 287
 288            Self {
 289                sessions: HashMap::default(),
 290                pending_sessions: HashMap::default(),
 291                thread_store,
 292                projects: HashMap::default(),
 293                templates,
 294                models: LanguageModels::new(cx),
 295                prompt_store,
 296                fs,
 297                _subscriptions: subscriptions,
 298            }
 299        })
 300    }
 301
 302    fn new_session(
 303        &mut self,
 304        project: Entity<Project>,
 305        cx: &mut Context<Self>,
 306    ) -> Entity<AcpThread> {
 307        let project_id = self.get_or_create_project_state(&project, cx);
 308        let project_state = &self.projects[&project_id];
 309
 310        let registry = LanguageModelRegistry::read_global(cx);
 311        let available_count = registry.available_models(cx).count();
 312        log::debug!("Total available models: {}", available_count);
 313
 314        let default_model = registry.default_model().and_then(|default_model| {
 315            self.models
 316                .model_from_id(&LanguageModels::model_id(&default_model.model))
 317        });
 318        let thread = cx.new(|cx| {
 319            Thread::new(
 320                project,
 321                project_state.project_context.clone(),
 322                project_state.context_server_registry.clone(),
 323                self.templates.clone(),
 324                default_model,
 325                cx,
 326            )
 327        });
 328
 329        self.register_session(thread, project_id, 1, cx)
 330    }
 331
    /// Creates the [`AcpThread`] for `thread_handle` and records both as a
    /// [`Session`] keyed by the thread's id. An existing entry with the same id
    /// is replaced.
    ///
    /// It also does the following:
    /// - configures the thread with the registry's summarization model;
    /// - installs default tools bound to a [`NativeThreadEnvironment`];
    /// - subscribes to the thread's title and token-usage events;
    /// - observes the thread so `save_thread` (defined elsewhere in this impl)
    ///   runs on every notification;
    /// - re-sends the project's available commands to every session of
    ///   `project_id`.
    ///
    /// `ref_count` is stored as-is on the new [`Session`]. Returns the new
    /// [`AcpThread`].
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        project_id: EntityId,
        ref_count: usize,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        // Copy out everything the AcpThread needs while `thread_handle` is read
        // through `cx`. That immutable borrow must end before `cx.new` below.
        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let draft_prompt = thread.draft_prompt().map(Vec::from);
        let scroll_position = thread.ui_scroll_position();
        let token_usage = thread.latest_token_usage();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            let mut acp_thread = acp_thread::AcpThread::new(
                parent_session_id,
                title,
                None,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            );
            // Carry the thread's draft prompt, scroll position and latest token
            // usage over to the ACP thread.
            acp_thread.set_draft_prompt(draft_prompt, cx);
            acp_thread.set_ui_scroll_position(scroll_position);
            acp_thread.update_token_usage(token_usage, cx);
            acp_thread
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model(cx).map(|c| c.model);

        // The tool environment gets only weak handles. Presumably this keeps the
        // tools from holding the agent and threads alive.
        let weak = cx.weak_entity();
        let weak_thread = thread_handle.downgrade();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            thread.add_default_tools(
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    thread: weak_thread,
                    agent: weak,
                }) as _,
                cx,
            )
        });

        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                project_id,
                _subscriptions: subscriptions,
                pending_save: Task::ready(Ok(())),
                ref_count,
            },
        );

        // The new session needs the project's slash commands too.
        self.update_available_commands_for_project(project_id, cx);

        acp_thread
    }
 410
    /// Returns the agent's cached language models (id lookup and grouped list).
    pub fn models(&self) -> &LanguageModels {
        &self.models
    }
 414
 415    fn get_or_create_project_state(
 416        &mut self,
 417        project: &Entity<Project>,
 418        cx: &mut Context<Self>,
 419    ) -> EntityId {
 420        let project_id = project.entity_id();
 421        if self.projects.contains_key(&project_id) {
 422            return project_id;
 423        }
 424
 425        let project_context = cx.new(|_| ProjectContext::new(vec![], vec![]));
 426        self.register_project_with_initial_context(project.clone(), project_context, cx);
 427        if let Some(state) = self.projects.get_mut(&project_id) {
 428            state.project_context_needs_refresh.send(()).ok();
 429        }
 430        project_id
 431    }
 432
 433    fn register_project_with_initial_context(
 434        &mut self,
 435        project: Entity<Project>,
 436        project_context: Entity<ProjectContext>,
 437        cx: &mut Context<Self>,
 438    ) {
 439        let project_id = project.entity_id();
 440
 441        let context_server_store = project.read(cx).context_server_store();
 442        let context_server_registry =
 443            cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
 444
 445        let subscriptions = vec![
 446            cx.subscribe(&project, Self::handle_project_event),
 447            cx.subscribe(
 448                &context_server_store,
 449                Self::handle_context_server_store_updated,
 450            ),
 451            cx.subscribe(
 452                &context_server_registry,
 453                Self::handle_context_server_registry_event,
 454            ),
 455        ];
 456
 457        let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 458            watch::channel(());
 459
 460        self.projects.insert(
 461            project_id,
 462            ProjectState {
 463                project,
 464                project_context,
 465                project_context_needs_refresh: project_context_needs_refresh_tx,
 466                _maintain_project_context: cx.spawn(async move |this, cx| {
 467                    Self::maintain_project_context(
 468                        this,
 469                        project_id,
 470                        project_context_needs_refresh_rx,
 471                        cx,
 472                    )
 473                    .await
 474                }),
 475                context_server_registry,
 476                _subscriptions: subscriptions,
 477            },
 478        );
 479    }
 480
 481    fn session_project_state(&self, session_id: &acp::SessionId) -> Option<&ProjectState> {
 482        self.sessions
 483            .get(session_id)
 484            .and_then(|session| self.projects.get(&session.project_id))
 485    }
 486
 487    async fn maintain_project_context(
 488        this: WeakEntity<Self>,
 489        project_id: EntityId,
 490        mut needs_refresh: watch::Receiver<()>,
 491        cx: &mut AsyncApp,
 492    ) -> Result<()> {
 493        while needs_refresh.changed().await.is_ok() {
 494            let project_context = this
 495                .update(cx, |this, cx| {
 496                    let state = this
 497                        .projects
 498                        .get(&project_id)
 499                        .context("project state not found")?;
 500                    anyhow::Ok(Self::build_project_context(
 501                        &state.project,
 502                        this.prompt_store.as_ref(),
 503                        cx,
 504                    ))
 505                })??
 506                .await;
 507            this.update(cx, |this, cx| {
 508                if let Some(state) = this.projects.get(&project_id) {
 509                    state
 510                        .project_context
 511                        .update(cx, |current_project_context, _cx| {
 512                            *current_project_context = project_context;
 513                        });
 514                }
 515            })?;
 516        }
 517
 518        Ok(())
 519    }
 520
 521    fn build_project_context(
 522        project: &Entity<Project>,
 523        prompt_store: Option<&Entity<PromptStore>>,
 524        cx: &mut App,
 525    ) -> Task<ProjectContext> {
 526        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 527        let worktree_tasks = worktrees
 528            .into_iter()
 529            .map(|worktree| {
 530                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 531            })
 532            .collect::<Vec<_>>();
 533        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 534            prompt_store.read_with(cx, |prompt_store, cx| {
 535                let prompts = prompt_store.default_prompt_metadata();
 536                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 537                    let contents = prompt_store.load(prompt_metadata.id, cx);
 538                    async move { (contents.await, prompt_metadata) }
 539                });
 540                cx.background_spawn(future::join_all(load_tasks))
 541            })
 542        } else {
 543            Task::ready(vec![])
 544        };
 545
 546        cx.spawn(async move |_cx| {
 547            let (worktrees, default_user_rules) =
 548                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 549
 550            let worktrees = worktrees
 551                .into_iter()
 552                .map(|(worktree, _rules_error)| {
 553                    // TODO: show error message
 554                    // if let Some(rules_error) = rules_error {
 555                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 556                    // }
 557                    worktree
 558                })
 559                .collect::<Vec<_>>();
 560
 561            let default_user_rules = default_user_rules
 562                .into_iter()
 563                .flat_map(|(contents, prompt_metadata)| match contents {
 564                    Ok(contents) => Some(UserRulesContext {
 565                        uuid: prompt_metadata.id.as_user()?,
 566                        title: prompt_metadata.title.map(|title| title.to_string()),
 567                        contents,
 568                    }),
 569                    Err(_err) => {
 570                        // TODO: show error message
 571                        // this.update(cx, |_, cx| {
 572                        //     cx.emit(RulesLoadingError {
 573                        //         message: format!("{err:?}").into(),
 574                        //     });
 575                        // })
 576                        // .ok();
 577                        None
 578                    }
 579                })
 580                .collect::<Vec<_>>();
 581
 582            ProjectContext::new(worktrees, default_user_rules)
 583        })
 584    }
 585
 586    fn load_worktree_info_for_system_prompt(
 587        worktree: Entity<Worktree>,
 588        project: Entity<Project>,
 589        cx: &mut App,
 590    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 591        let tree = worktree.read(cx);
 592        let root_name = tree.root_name_str().into();
 593        let abs_path = tree.abs_path();
 594
 595        let mut context = WorktreeContext {
 596            root_name,
 597            abs_path,
 598            rules_file: None,
 599        };
 600
 601        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 602        let Some(rules_task) = rules_task else {
 603            return Task::ready((context, None));
 604        };
 605
 606        cx.spawn(async move |_| {
 607            let (rules_file, rules_file_error) = match rules_task.await {
 608                Ok(rules_file) => (Some(rules_file), None),
 609                Err(err) => (
 610                    None,
 611                    Some(RulesLoadingError {
 612                        message: format!("{err}").into(),
 613                    }),
 614                ),
 615            };
 616            context.rules_file = rules_file;
 617            (context, rules_file_error)
 618        })
 619    }
 620
 621    fn load_worktree_rules_file(
 622        worktree: Entity<Worktree>,
 623        project: Entity<Project>,
 624        cx: &mut App,
 625    ) -> Option<Task<Result<RulesFileContext>>> {
 626        let worktree = worktree.read(cx);
 627        let worktree_id = worktree.id();
 628        let selected_rules_file = RULES_FILE_NAMES
 629            .into_iter()
 630            .filter_map(|name| {
 631                worktree
 632                    .entry_for_path(RelPath::unix(name).unwrap())
 633                    .filter(|entry| entry.is_file())
 634                    .map(|entry| entry.path.clone())
 635            })
 636            .next();
 637
 638        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 639        // supported. This doesn't seem to occur often in GitHub repositories.
 640        selected_rules_file.map(|path_in_worktree| {
 641            let project_path = ProjectPath {
 642                worktree_id,
 643                path: path_in_worktree.clone(),
 644            };
 645            let buffer_task =
 646                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 647            let rope_task = cx.spawn(async move |cx| {
 648                let buffer = buffer_task.await?;
 649                let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
 650                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 651                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 652                })?;
 653                anyhow::Ok((project_entry_id, rope))
 654            });
 655            // Build a string from the rope on a background thread.
 656            cx.background_spawn(async move {
 657                let (project_entry_id, rope) = rope_task.await?;
 658                anyhow::Ok(RulesFileContext {
 659                    path_in_worktree,
 660                    text: rope.to_string().trim().to_string(),
 661                    project_entry_id: project_entry_id.to_usize(),
 662                })
 663            })
 664        })
 665    }
 666
 667    fn handle_thread_title_updated(
 668        &mut self,
 669        thread: Entity<Thread>,
 670        _: &TitleUpdated,
 671        cx: &mut Context<Self>,
 672    ) {
 673        let session_id = thread.read(cx).id();
 674        let Some(session) = self.sessions.get(session_id) else {
 675            return;
 676        };
 677
 678        let thread = thread.downgrade();
 679        let acp_thread = session.acp_thread.downgrade();
 680        cx.spawn(async move |_, cx| {
 681            let title = thread.read_with(cx, |thread, _| thread.title())?;
 682            if let Some(title) = title {
 683                let task =
 684                    acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 685                task.await?;
 686            }
 687            anyhow::Ok(())
 688        })
 689        .detach_and_log_err(cx);
 690    }
 691
 692    fn handle_thread_token_usage_updated(
 693        &mut self,
 694        thread: Entity<Thread>,
 695        usage: &TokenUsageUpdated,
 696        cx: &mut Context<Self>,
 697    ) {
 698        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 699            return;
 700        };
 701        session.acp_thread.update(cx, |acp_thread, cx| {
 702            acp_thread.update_token_usage(usage.0.clone(), cx);
 703        });
 704    }
 705
 706    fn handle_project_event(
 707        &mut self,
 708        project: Entity<Project>,
 709        event: &project::Event,
 710        _cx: &mut Context<Self>,
 711    ) {
 712        let project_id = project.entity_id();
 713        let Some(state) = self.projects.get_mut(&project_id) else {
 714            return;
 715        };
 716        match event {
 717            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 718                state.project_context_needs_refresh.send(()).ok();
 719            }
 720            project::Event::WorktreeUpdatedEntries(_, items) => {
 721                if items.iter().any(|(path, _, _)| {
 722                    RULES_FILE_NAMES
 723                        .iter()
 724                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 725                }) {
 726                    state.project_context_needs_refresh.send(()).ok();
 727                }
 728            }
 729            _ => {}
 730        }
 731    }
 732
 733    fn handle_prompts_updated_event(
 734        &mut self,
 735        _prompt_store: Entity<PromptStore>,
 736        _event: &prompt_store::PromptsUpdatedEvent,
 737        _cx: &mut Context<Self>,
 738    ) {
 739        for state in self.projects.values_mut() {
 740            state.project_context_needs_refresh.send(()).ok();
 741        }
 742    }
 743
 744    fn handle_models_updated_event(
 745        &mut self,
 746        _registry: Entity<LanguageModelRegistry>,
 747        event: &language_model::Event,
 748        cx: &mut Context<Self>,
 749    ) {
 750        self.models.refresh_list(cx);
 751
 752        let registry = LanguageModelRegistry::read_global(cx);
 753        let default_model = registry.default_model().map(|m| m.model);
 754        let summarization_model = registry.thread_summary_model(cx).map(|m| m.model);
 755
 756        for session in self.sessions.values_mut() {
 757            session.thread.update(cx, |thread, cx| {
 758                let should_update_model = thread.model().is_none()
 759                    || (thread.is_empty()
 760                        && matches!(event, language_model::Event::DefaultModelChanged));
 761                if should_update_model && let Some(model) = default_model.clone() {
 762                    thread.set_model(model, cx);
 763                    cx.notify();
 764                }
 765                if let Some(model) = summarization_model.clone() {
 766                    if thread.summarization_model().is_none()
 767                        || matches!(event, language_model::Event::ThreadSummaryModelChanged)
 768                    {
 769                        thread.set_summarization_model(Some(model), cx);
 770                    }
 771                }
 772            });
 773        }
 774    }
 775
 776    fn handle_context_server_store_updated(
 777        &mut self,
 778        store: Entity<project::context_server_store::ContextServerStore>,
 779        _event: &project::context_server_store::ServerStatusChangedEvent,
 780        cx: &mut Context<Self>,
 781    ) {
 782        let project_id = self.projects.iter().find_map(|(id, state)| {
 783            if *state.context_server_registry.read(cx).server_store() == store {
 784                Some(*id)
 785            } else {
 786                None
 787            }
 788        });
 789        if let Some(project_id) = project_id {
 790            self.update_available_commands_for_project(project_id, cx);
 791        }
 792    }
 793
 794    fn handle_context_server_registry_event(
 795        &mut self,
 796        registry: Entity<ContextServerRegistry>,
 797        event: &ContextServerRegistryEvent,
 798        cx: &mut Context<Self>,
 799    ) {
 800        match event {
 801            ContextServerRegistryEvent::ToolsChanged => {}
 802            ContextServerRegistryEvent::PromptsChanged => {
 803                let project_id = self.projects.iter().find_map(|(id, state)| {
 804                    if state.context_server_registry == registry {
 805                        Some(*id)
 806                    } else {
 807                        None
 808                    }
 809                });
 810                if let Some(project_id) = project_id {
 811                    self.update_available_commands_for_project(project_id, cx);
 812                }
 813            }
 814        }
 815    }
 816
 817    fn update_available_commands_for_project(&self, project_id: EntityId, cx: &mut Context<Self>) {
 818        let available_commands =
 819            Self::build_available_commands_for_project(self.projects.get(&project_id), cx);
 820        for session in self.sessions.values() {
 821            if session.project_id != project_id {
 822                continue;
 823            }
 824            session.acp_thread.update(cx, |thread, cx| {
 825                thread
 826                    .handle_session_update(
 827                        acp::SessionUpdate::AvailableCommandsUpdate(
 828                            acp::AvailableCommandsUpdate::new(available_commands.clone()),
 829                        ),
 830                        cx,
 831                    )
 832                    .log_err();
 833            });
 834        }
 835    }
 836
 837    fn build_available_commands_for_project(
 838        project_state: Option<&ProjectState>,
 839        cx: &App,
 840    ) -> Vec<acp::AvailableCommand> {
 841        let Some(state) = project_state else {
 842            return vec![];
 843        };
 844        let registry = state.context_server_registry.read(cx);
 845
 846        let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
 847        for context_server_prompt in registry.prompts() {
 848            *prompt_name_counts
 849                .entry(context_server_prompt.prompt.name.as_str())
 850                .or_insert(0) += 1;
 851        }
 852
 853        registry
 854            .prompts()
 855            .flat_map(|context_server_prompt| {
 856                let prompt = &context_server_prompt.prompt;
 857
 858                let should_prefix = prompt_name_counts
 859                    .get(prompt.name.as_str())
 860                    .copied()
 861                    .unwrap_or(0)
 862                    > 1;
 863
 864                let name = if should_prefix {
 865                    format!("{}.{}", context_server_prompt.server_id, prompt.name)
 866                } else {
 867                    prompt.name.clone()
 868                };
 869
 870                let mut command = acp::AvailableCommand::new(
 871                    name,
 872                    prompt.description.clone().unwrap_or_default(),
 873                );
 874
 875                match prompt.arguments.as_deref() {
 876                    Some([arg]) => {
 877                        let hint = format!("<{}>", arg.name);
 878
 879                        command = command.input(acp::AvailableCommandInput::Unstructured(
 880                            acp::UnstructuredCommandInput::new(hint),
 881                        ));
 882                    }
 883                    Some([]) | None => {}
 884                    Some(_) => {
 885                        // skip >1 argument commands since we don't support them yet
 886                        return None;
 887                    }
 888                }
 889
 890                Some(command)
 891            })
 892            .collect()
 893    }
 894
 895    pub fn load_thread(
 896        &mut self,
 897        id: acp::SessionId,
 898        project: Entity<Project>,
 899        cx: &mut Context<Self>,
 900    ) -> Task<Result<Entity<Thread>>> {
 901        let database_future = ThreadsDatabase::connect(cx);
 902        cx.spawn(async move |this, cx| {
 903            let database = database_future.await.map_err(|err| anyhow!(err))?;
 904            let db_thread = database
 905                .load_thread(id.clone())
 906                .await?
 907                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 908
 909            this.update(cx, |this, cx| {
 910                let project_id = this.get_or_create_project_state(&project, cx);
 911                let project_state = this
 912                    .projects
 913                    .get(&project_id)
 914                    .context("project state not found")?;
 915                let summarization_model = LanguageModelRegistry::read_global(cx)
 916                    .thread_summary_model(cx)
 917                    .map(|c| c.model);
 918
 919                Ok(cx.new(|cx| {
 920                    let mut thread = Thread::from_db(
 921                        id.clone(),
 922                        db_thread,
 923                        project_state.project.clone(),
 924                        project_state.project_context.clone(),
 925                        project_state.context_server_registry.clone(),
 926                        this.templates.clone(),
 927                        cx,
 928                    );
 929                    thread.set_summarization_model(summarization_model, cx);
 930                    thread
 931                }))
 932            })?
 933        })
 934    }
 935
    /// Returns the `AcpThread` for session `id`, loading it from the threads
    /// database if it is not open yet.
    ///
    /// Sessions are reference counted. An already-open session, or a load that
    /// is already in flight, just has its `ref_count` incremented and is shared.
    /// Otherwise a new load is started and recorded in `pending_sessions`. When
    /// the load finishes, the thread is registered as a session (inheriting the
    /// ref count accumulated while pending). Its history is then replayed into
    /// the new `AcpThread`, and `snapshot_completed_plan` is called on it.
    ///
    /// NOTE(review): `close_session` only consults `sessions`, so a close issued
    /// while the load is still pending does not decrement the pending count —
    /// confirm callers always wait for the load first.
    pub fn open_thread(
        &mut self,
        id: acp::SessionId,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<AcpThread>>> {
        // Already open: hand out the existing thread and take another reference.
        if let Some(session) = self.sessions.get_mut(&id) {
            session.ref_count += 1;
            return Task::ready(Ok(session.acp_thread.clone()));
        }

        // A load is already in flight: join it instead of loading twice.
        if let Some(pending) = self.pending_sessions.get_mut(&id) {
            pending.ref_count += 1;
            let task = pending.task.clone();
            return cx.background_spawn(async move { task.await.map_err(|err| anyhow!(err)) });
        }

        let task = self.load_thread(id.clone(), project.clone(), cx);
        let shared_task = cx
            .spawn({
                let id = id.clone();
                async move |this, cx| {
                    let thread = match task.await {
                        Ok(thread) => thread,
                        Err(err) => {
                            // Forget the pending entry so a later call can retry.
                            // If the agent is already gone there is nothing to clean.
                            this.update(cx, |this, _cx| {
                                this.pending_sessions.remove(&id);
                            })
                            .ok();
                            // Errors are wrapped in `Arc` because `.shared()` below
                            // requires a cloneable output.
                            return Err(Arc::new(err));
                        }
                    };
                    let acp_thread = this
                        .update(cx, |this, cx| {
                            let project_id = this.get_or_create_project_state(&project, cx);
                            // Inherit the references taken while the load was
                            // pending (1 if the pending entry is gone).
                            let ref_count = this
                                .pending_sessions
                                .remove(&id)
                                .map_or(1, |pending| pending.ref_count);
                            this.register_session(thread.clone(), project_id, ref_count, cx)
                        })
                        .map_err(Arc::new)?;
                    // Replay the stored thread's events into the new AcpThread.
                    // NOTE(review): if this fails, the session stays registered with
                    // its ref count while callers receive an error — confirm they
                    // still close it.
                    let events = thread.update(cx, |thread, cx| thread.replay(cx));
                    cx.update(|cx| {
                        NativeAgentConnection::handle_thread_events(
                            events,
                            acp_thread.downgrade(),
                            cx,
                        )
                    })
                    .await
                    .map_err(Arc::new)?;
                    acp_thread.update(cx, |thread, cx| {
                        thread.snapshot_completed_plan(cx);
                    });
                    Ok(acp_thread)
                }
            })
            .shared();
        self.pending_sessions.insert(
            id,
            PendingSession {
                task: shared_task.clone(),
                ref_count: 1,
            },
        );

        cx.background_spawn(async move { shared_task.await.map_err(|err| anyhow!(err)) })
    }
1005
1006    pub fn thread_summary(
1007        &mut self,
1008        id: acp::SessionId,
1009        project: Entity<Project>,
1010        cx: &mut Context<Self>,
1011    ) -> Task<Result<SharedString>> {
1012        let thread = self.open_thread(id.clone(), project, cx);
1013        cx.spawn(async move |this, cx| {
1014            let acp_thread = thread.await?;
1015            let result = this
1016                .update(cx, |this, cx| {
1017                    this.sessions
1018                        .get(&id)
1019                        .unwrap()
1020                        .thread
1021                        .update(cx, |thread, cx| thread.summary(cx))
1022                })?
1023                .await
1024                .context("Failed to generate summary")?;
1025
1026            this.update(cx, |this, cx| this.close_session(&id, cx))?
1027                .await?;
1028            drop(acp_thread);
1029            Ok(result)
1030        })
1031    }
1032
1033    fn close_session(
1034        &mut self,
1035        session_id: &acp::SessionId,
1036        cx: &mut Context<Self>,
1037    ) -> Task<Result<()>> {
1038        let Some(session) = self.sessions.get_mut(session_id) else {
1039            return Task::ready(Ok(()));
1040        };
1041
1042        session.ref_count -= 1;
1043        if session.ref_count > 0 {
1044            return Task::ready(Ok(()));
1045        }
1046
1047        let thread = session.thread.clone();
1048        self.save_thread(thread, cx);
1049        let Some(session) = self.sessions.remove(session_id) else {
1050            return Task::ready(Ok(()));
1051        };
1052        let project_id = session.project_id;
1053
1054        let has_remaining = self.sessions.values().any(|s| s.project_id == project_id);
1055        if !has_remaining {
1056            self.projects.remove(&project_id);
1057        }
1058
1059        session.pending_save
1060    }
1061
    /// Persists `thread` to the threads database in the background.
    ///
    /// Does nothing when the thread is empty, has no open session, or its
    /// project state is missing. Otherwise it copies the session's draft prompt
    /// into the thread and serializes it. It then stores a task in the session's
    /// `pending_save` that writes the thread, together with the project's
    /// visible worktree paths, and afterwards reloads the thread store.
    ///
    /// Database connection and save errors are logged, not returned; the store
    /// is reloaded even if the save failed.
    ///
    /// NOTE(review): assigning `pending_save` drops any previous save task; if
    /// dropping a gpui `Task` cancels it, an in-flight save is abandoned in
    /// favor of this one — confirm that is intended.
    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
        if thread.read(cx).is_empty() {
            return;
        }

        let id = thread.read(cx).id().clone();
        let Some(session) = self.sessions.get_mut(&id) else {
            return;
        };

        let project_id = session.project_id;
        let Some(state) = self.projects.get(&project_id) else {
            return;
        };

        // The thread is stored with the absolute paths of the project's
        // visible worktrees.
        let folder_paths = PathList::new(
            &state
                .project
                .read(cx)
                .visible_worktrees(cx)
                .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
                .collect::<Vec<_>>(),
        );

        // Copy the unsent draft from the AcpThread into the Thread before
        // serializing, so it is saved along with the thread.
        let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
        let database_future = ThreadsDatabase::connect(cx);
        let db_thread = thread.update(cx, |thread, cx| {
            thread.set_draft_prompt(draft_prompt);
            thread.to_db(cx)
        });
        let thread_store = self.thread_store.clone();
        session.pending_save = cx.spawn(async move |_, cx| {
            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
                return Ok(());
            };
            let db_thread = db_thread.await;
            database
                .save_thread(id, db_thread, folder_paths)
                .await
                .log_err();
            thread_store.update(cx, |store, cx| store.reload(cx));
            Ok(())
        });
    }
1106
    /// Runs MCP prompt `prompt_name` from context server `server_id` as a turn
    /// in session `session_id`.
    ///
    /// Steps:
    /// - Fetch the prompt with `arguments`.
    /// - Append the user's `original_content`, minus its first block (the slash
    ///   command that `Command::parse` consumed), to the `Thread` under
    ///   `message_id`.
    /// - Mirror every returned prompt message into both the `AcpThread` and the
    ///   `Thread`.
    /// - Start the model: `send_existing` if the last message was from the user
    ///   (or there were none), `resume` otherwise.
    ///
    /// Resolves to the turn's `PromptResponse`. Errors if the session or its
    /// project state is missing, the prompt cannot be fetched, or starting the
    /// turn fails.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: acp::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let Some(state) = self.session_project_state(&session_id) else {
            return Task::ready(Err(anyhow!("Project state not found for session")));
        };
        let server_store = state
            .context_server_registry
            .read(cx)
            .server_store()
            .clone();
        let path_style = state.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Stays true when the prompt returns no messages, so the turn is
            // started with `send_existing` below.
            let mut last_is_user = true;

            // Only the Thread receives the user's original content here.
            // Presumably the AcpThread already shows it — TODO confirm.
            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                // Mirror each prompt message into both the UI-facing AcpThread
                // and the model-facing Thread.
                match role {
                    context_server::types::Role::User => {
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
1209}
1210
1211/// Wrapper struct that implements the AgentConnection trait
1212#[derive(Clone)]
1213pub struct NativeAgentConnection(pub Entity<NativeAgent>);
1214
1215impl NativeAgentConnection {
1216    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
1217        self.0
1218            .read(cx)
1219            .sessions
1220            .get(session_id)
1221            .map(|session| session.thread.clone())
1222    }
1223
1224    pub fn load_thread(
1225        &self,
1226        id: acp::SessionId,
1227        project: Entity<Project>,
1228        cx: &mut App,
1229    ) -> Task<Result<Entity<Thread>>> {
1230        self.0
1231            .update(cx, |this, cx| this.load_thread(id, project, cx))
1232    }
1233
1234    fn run_turn(
1235        &self,
1236        session_id: acp::SessionId,
1237        cx: &mut App,
1238        f: impl 'static
1239        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1240    ) -> Task<Result<acp::PromptResponse>> {
1241        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1242            agent
1243                .sessions
1244                .get_mut(&session_id)
1245                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1246        }) else {
1247            log::error!("Session not found in run_turn: {}", session_id);
1248            return Task::ready(Err(anyhow!("Session not found")));
1249        };
1250        log::debug!("Found session for: {}", session_id);
1251
1252        let response_stream = match f(thread, cx) {
1253            Ok(stream) => stream,
1254            Err(err) => return Task::ready(Err(err)),
1255        };
1256        Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1257    }
1258
1259    fn handle_thread_events(
1260        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1261        acp_thread: WeakEntity<AcpThread>,
1262        cx: &App,
1263    ) -> Task<Result<acp::PromptResponse>> {
1264        cx.spawn(async move |cx| {
1265            // Handle response stream and forward to session.acp_thread
1266            while let Some(result) = events.next().await {
1267                match result {
1268                    Ok(event) => {
1269                        log::trace!("Received completion event: {:?}", event);
1270
1271                        match event {
1272                            ThreadEvent::UserMessage(message) => {
1273                                acp_thread.update(cx, |thread, cx| {
1274                                    for content in message.content {
1275                                        thread.push_user_content_block(
1276                                            Some(message.id.clone()),
1277                                            content.into(),
1278                                            cx,
1279                                        );
1280                                    }
1281                                })?;
1282                            }
1283                            ThreadEvent::AgentText(text) => {
1284                                acp_thread.update(cx, |thread, cx| {
1285                                    thread.push_assistant_content_block(text.into(), false, cx)
1286                                })?;
1287                            }
1288                            ThreadEvent::AgentThinking(text) => {
1289                                acp_thread.update(cx, |thread, cx| {
1290                                    thread.push_assistant_content_block(text.into(), true, cx)
1291                                })?;
1292                            }
1293                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1294                                tool_call,
1295                                options,
1296                                response,
1297                                context: _,
1298                            }) => {
1299                                let outcome_task = acp_thread.update(cx, |thread, cx| {
1300                                    thread.request_tool_call_authorization(tool_call, options, cx)
1301                                })??;
1302                                cx.background_spawn(async move {
1303                                    if let acp_thread::RequestPermissionOutcome::Selected(outcome) =
1304                                        outcome_task.await
1305                                    {
1306                                        response
1307                                            .send(outcome)
1308                                            .map(|_| anyhow!("authorization receiver was dropped"))
1309                                            .log_err();
1310                                    }
1311                                })
1312                                .detach();
1313                            }
1314                            ThreadEvent::ToolCall(tool_call) => {
1315                                acp_thread.update(cx, |thread, cx| {
1316                                    thread.upsert_tool_call(tool_call, cx)
1317                                })??;
1318                            }
1319                            ThreadEvent::ToolCallUpdate(update) => {
1320                                acp_thread.update(cx, |thread, cx| {
1321                                    thread.update_tool_call(update, cx)
1322                                })??;
1323                            }
1324                            ThreadEvent::Plan(plan) => {
1325                                acp_thread.update(cx, |thread, cx| thread.update_plan(plan, cx))?;
1326                            }
1327                            ThreadEvent::SubagentSpawned(session_id) => {
1328                                acp_thread.update(cx, |thread, cx| {
1329                                    thread.subagent_spawned(session_id, cx);
1330                                })?;
1331                            }
1332                            ThreadEvent::Retry(status) => {
1333                                acp_thread.update(cx, |thread, cx| {
1334                                    thread.update_retry_status(status, cx)
1335                                })?;
1336                            }
1337                            ThreadEvent::Stop(stop_reason) => {
1338                                log::debug!("Assistant message complete: {:?}", stop_reason);
1339                                return Ok(acp::PromptResponse::new(stop_reason));
1340                            }
1341                        }
1342                    }
1343                    Err(e) => {
1344                        log::error!("Error in model response stream: {:?}", e);
1345                        return Err(e);
1346                    }
1347                }
1348            }
1349
1350            log::debug!("Response stream completed");
1351            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1352        })
1353    }
1354}
1355
1356struct Command<'a> {
1357    prompt_name: &'a str,
1358    arg_value: &'a str,
1359    explicit_server_id: Option<&'a str>,
1360}
1361
1362impl<'a> Command<'a> {
1363    fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1364        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1365            return None;
1366        };
1367        let text = text_content.text.trim();
1368        let command = text.strip_prefix('/')?;
1369        let (command, arg_value) = command
1370            .split_once(char::is_whitespace)
1371            .unwrap_or((command, ""));
1372
1373        if let Some((server_id, prompt_name)) = command.split_once('.') {
1374            Some(Self {
1375                prompt_name,
1376                arg_value,
1377                explicit_server_id: Some(server_id),
1378            })
1379        } else {
1380            Some(Self {
1381                prompt_name: command,
1382                arg_value,
1383                explicit_server_id: None,
1384            })
1385        }
1386    }
1387}
1388
1389struct NativeAgentModelSelector {
1390    session_id: acp::SessionId,
1391    connection: NativeAgentConnection,
1392}
1393
1394impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1395    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1396        log::debug!("NativeAgentConnection::list_models called");
1397        let list = self.connection.0.read(cx).models.model_list.clone();
1398        Task::ready(if list.is_empty() {
1399            Err(anyhow::anyhow!("No models available"))
1400        } else {
1401            Ok(list)
1402        })
1403    }
1404
1405    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1406        log::debug!(
1407            "Setting model for session {}: {}",
1408            self.session_id,
1409            model_id
1410        );
1411        let Some(thread) = self
1412            .connection
1413            .0
1414            .read(cx)
1415            .sessions
1416            .get(&self.session_id)
1417            .map(|session| session.thread.clone())
1418        else {
1419            return Task::ready(Err(anyhow!("Session not found")));
1420        };
1421
1422        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1423            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1424        };
1425
1426        let favorite = agent_settings::AgentSettings::get_global(cx)
1427            .favorite_models
1428            .iter()
1429            .find(|favorite| {
1430                favorite.provider.0 == model.provider_id().0.as_ref()
1431                    && favorite.model == model.id().0.as_ref()
1432            })
1433            .cloned();
1434
1435        let LanguageModelSelection {
1436            enable_thinking,
1437            effort,
1438            speed,
1439            ..
1440        } = agent_settings::language_model_to_selection(&model, favorite.as_ref());
1441
1442        thread.update(cx, |thread, cx| {
1443            thread.set_model(model.clone(), cx);
1444            thread.set_thinking_effort(effort.clone(), cx);
1445            thread.set_thinking_enabled(enable_thinking, cx);
1446            if let Some(speed) = speed {
1447                thread.set_speed(speed, cx);
1448            }
1449        });
1450
1451        update_settings_file(
1452            self.connection.0.read(cx).fs.clone(),
1453            cx,
1454            move |settings, cx| {
1455                let provider = model.provider_id().0.to_string();
1456                let model = model.id().0.to_string();
1457                let enable_thinking = thread.read(cx).thinking_enabled();
1458                let speed = thread.read(cx).speed();
1459                settings
1460                    .agent
1461                    .get_or_insert_default()
1462                    .set_model(LanguageModelSelection {
1463                        provider: provider.into(),
1464                        model,
1465                        enable_thinking,
1466                        effort,
1467                        speed,
1468                    });
1469            },
1470        );
1471
1472        Task::ready(Ok(()))
1473    }
1474
1475    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1476        let Some(thread) = self
1477            .connection
1478            .0
1479            .read(cx)
1480            .sessions
1481            .get(&self.session_id)
1482            .map(|session| session.thread.clone())
1483        else {
1484            return Task::ready(Err(anyhow!("Session not found")));
1485        };
1486        let Some(model) = thread.read(cx).model() else {
1487            return Task::ready(Err(anyhow!("Model not found")));
1488        };
1489        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1490        else {
1491            return Task::ready(Err(anyhow!("Provider not found")));
1492        };
1493        Task::ready(Ok(LanguageModels::map_language_model_to_info(
1494            model, &provider,
1495        )))
1496    }
1497
1498    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1499        Some(self.connection.0.read(cx).models.watch())
1500    }
1501
1502    fn should_render_footer(&self) -> bool {
1503        true
1504    }
1505}
1506
/// Identifier of the built-in, in-process Zed agent, returned by
/// `NativeAgentConnection::agent_id`.
pub static ZED_AGENT_ID: LazyLock<AgentId> = LazyLock::new(|| AgentId::new("Zed Agent"));
1508
1509impl acp_thread::AgentConnection for NativeAgentConnection {
    /// Identifies this connection as the built-in Zed agent.
    fn agent_id(&self) -> AgentId {
        ZED_AGENT_ID.clone()
    }
1513
    /// Telemetry identifier for this agent; always `"zed"`.
    fn telemetry_id(&self) -> SharedString {
        "zed".into()
    }
1517
1518    fn new_session(
1519        self: Rc<Self>,
1520        project: Entity<Project>,
1521        work_dirs: PathList,
1522        cx: &mut App,
1523    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1524        log::debug!("Creating new thread for project at: {work_dirs:?}");
1525        Task::ready(Ok(self
1526            .0
1527            .update(cx, |agent, cx| agent.new_session(project, cx))))
1528    }
1529
    /// Saved threads can be reopened through `load_session`.
    fn supports_load_session(&self) -> bool {
        true
    }
1533
1534    fn load_session(
1535        self: Rc<Self>,
1536        session_id: acp::SessionId,
1537        project: Entity<Project>,
1538        _work_dirs: PathList,
1539        _title: Option<SharedString>,
1540        cx: &mut App,
1541    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1542        self.0
1543            .update(cx, |agent, cx| agent.open_thread(session_id, project, cx))
1544    }
1545
    /// Sessions are reference counted and can be released via `close_session`.
    fn supports_close_session(&self) -> bool {
        true
    }
1549
1550    fn close_session(
1551        self: Rc<Self>,
1552        session_id: &acp::SessionId,
1553        cx: &mut App,
1554    ) -> Task<Result<()>> {
1555        self.0
1556            .update(cx, |agent, cx| agent.close_session(session_id, cx))
1557    }
1558
    /// The in-process agent offers no authentication methods.
    fn auth_methods(&self) -> &[acp::AuthMethod] {
        &[] // No auth for in-process
    }
1562
    /// No-op: always succeeds, since the in-process agent needs no authentication.
    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
1566
1567    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1568        Some(Rc::new(NativeAgentModelSelector {
1569            session_id: session_id.clone(),
1570            connection: self.clone(),
1571        }) as Rc<dyn AgentModelSelector>)
1572    }
1573
1574    fn prompt(
1575        &self,
1576        id: acp_thread::UserMessageId,
1577        params: acp::PromptRequest,
1578        cx: &mut App,
1579    ) -> Task<Result<acp::PromptResponse>> {
1580        let session_id = params.session_id.clone();
1581        log::info!("Received prompt request for session: {}", session_id);
1582        log::debug!("Prompt blocks count: {}", params.prompt.len());
1583
1584        let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
1585            log::error!("Session not found in prompt: {}", session_id);
1586            if self.0.read(cx).sessions.contains_key(&session_id) {
1587                log::error!(
1588                    "Session found in sessions map, but not in project state: {}",
1589                    session_id
1590                );
1591            }
1592            return Task::ready(Err(anyhow::anyhow!("Session not found")));
1593        };
1594
1595        if let Some(parsed_command) = Command::parse(&params.prompt) {
1596            let registry = project_state.context_server_registry.read(cx);
1597
1598            let explicit_server_id = parsed_command
1599                .explicit_server_id
1600                .map(|server_id| ContextServerId(server_id.into()));
1601
1602            if let Some(prompt) =
1603                registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1604            {
1605                let arguments = if !parsed_command.arg_value.is_empty()
1606                    && let Some(arg_name) = prompt
1607                        .prompt
1608                        .arguments
1609                        .as_ref()
1610                        .and_then(|args| args.first())
1611                        .map(|arg| arg.name.clone())
1612                {
1613                    HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1614                } else {
1615                    Default::default()
1616                };
1617
1618                let prompt_name = prompt.prompt.name.clone();
1619                let server_id = prompt.server_id.clone();
1620
1621                return self.0.update(cx, |agent, cx| {
1622                    agent.send_mcp_prompt(
1623                        id,
1624                        session_id.clone(),
1625                        prompt_name,
1626                        server_id,
1627                        arguments,
1628                        params.prompt,
1629                        cx,
1630                    )
1631                });
1632            }
1633        };
1634
1635        let path_style = project_state.project.read(cx).path_style(cx);
1636
1637        self.run_turn(session_id, cx, move |thread, cx| {
1638            let content: Vec<UserMessageContent> = params
1639                .prompt
1640                .into_iter()
1641                .map(|block| UserMessageContent::from_content_block(block, path_style))
1642                .collect::<Vec<_>>();
1643            log::debug!("Converted prompt to message: {} chars", content.len());
1644            log::debug!("Message id: {:?}", id);
1645            log::debug!("Message content: {:?}", content);
1646
1647            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1648        })
1649    }
1650
1651    fn retry(
1652        &self,
1653        session_id: &acp::SessionId,
1654        _cx: &App,
1655    ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1656        Some(Rc::new(NativeAgentSessionRetry {
1657            connection: self.clone(),
1658            session_id: session_id.clone(),
1659        }) as _)
1660    }
1661
1662    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1663        log::info!("Cancelling on session: {}", session_id);
1664        self.0.update(cx, |agent, cx| {
1665            if let Some(session) = agent.sessions.get(session_id) {
1666                session
1667                    .thread
1668                    .update(cx, |thread, cx| thread.cancel(cx))
1669                    .detach();
1670            }
1671        });
1672    }
1673
1674    fn truncate(
1675        &self,
1676        session_id: &acp::SessionId,
1677        cx: &App,
1678    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1679        self.0.read_with(cx, |agent, _cx| {
1680            agent.sessions.get(session_id).map(|session| {
1681                Rc::new(NativeAgentSessionTruncate {
1682                    thread: session.thread.clone(),
1683                    acp_thread: session.acp_thread.downgrade(),
1684                }) as _
1685            })
1686        })
1687    }
1688
1689    fn set_title(
1690        &self,
1691        session_id: &acp::SessionId,
1692        cx: &App,
1693    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1694        self.0.read_with(cx, |agent, _cx| {
1695            agent
1696                .sessions
1697                .get(session_id)
1698                .filter(|s| !s.thread.read(cx).is_subagent())
1699                .map(|session| {
1700                    Rc::new(NativeAgentSessionSetTitle {
1701                        thread: session.thread.clone(),
1702                    }) as _
1703                })
1704        })
1705    }
1706
1707    fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1708        let thread_store = self.0.read(cx).thread_store.clone();
1709        Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1710    }
1711
1712    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1713        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1714    }
1715
1716    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1717        self
1718    }
1719}
1720
1721impl acp_thread::AgentTelemetry for NativeAgentConnection {
1722    fn thread_data(
1723        &self,
1724        session_id: &acp::SessionId,
1725        cx: &mut App,
1726    ) -> Task<Result<serde_json::Value>> {
1727        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1728            return Task::ready(Err(anyhow!("Session not found")));
1729        };
1730
1731        let task = session.thread.read(cx).to_db(cx);
1732        cx.background_spawn(async move {
1733            serde_json::to_value(task.await).context("Failed to serialize thread")
1734        })
1735    }
1736}
1737
1738pub struct NativeAgentSessionList {
1739    thread_store: Entity<ThreadStore>,
1740    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1741    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1742    _subscription: Subscription,
1743}
1744
1745impl NativeAgentSessionList {
1746    fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1747        let (tx, rx) = smol::channel::unbounded();
1748        let this_tx = tx.clone();
1749        let subscription = cx.observe(&thread_store, move |_, _| {
1750            this_tx
1751                .try_send(acp_thread::SessionListUpdate::Refresh)
1752                .ok();
1753        });
1754        Self {
1755            thread_store,
1756            updates_tx: tx,
1757            updates_rx: rx,
1758            _subscription: subscription,
1759        }
1760    }
1761
1762    pub fn thread_store(&self) -> &Entity<ThreadStore> {
1763        &self.thread_store
1764    }
1765}
1766
1767impl AgentSessionList for NativeAgentSessionList {
1768    fn list_sessions(
1769        &self,
1770        _request: AgentSessionListRequest,
1771        cx: &mut App,
1772    ) -> Task<Result<AgentSessionListResponse>> {
1773        let sessions = self
1774            .thread_store
1775            .read(cx)
1776            .entries()
1777            .map(|entry| AgentSessionInfo::from(&entry))
1778            .collect();
1779        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1780    }
1781
1782    fn supports_delete(&self) -> bool {
1783        true
1784    }
1785
1786    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1787        self.thread_store
1788            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1789    }
1790
1791    fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1792        self.thread_store
1793            .update(cx, |store, cx| store.delete_threads(cx))
1794    }
1795
1796    fn watch(
1797        &self,
1798        _cx: &mut App,
1799    ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1800        Some(self.updates_rx.clone())
1801    }
1802
1803    fn notify_refresh(&self) {
1804        self.updates_tx
1805            .try_send(acp_thread::SessionListUpdate::Refresh)
1806            .ok();
1807    }
1808
1809    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1810        self
1811    }
1812}
1813
1814struct NativeAgentSessionTruncate {
1815    thread: Entity<Thread>,
1816    acp_thread: WeakEntity<AcpThread>,
1817}
1818
1819impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1820    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1821        match self.thread.update(cx, |thread, cx| {
1822            thread.truncate(message_id.clone(), cx)?;
1823            Ok(thread.latest_token_usage())
1824        }) {
1825            Ok(usage) => {
1826                self.acp_thread
1827                    .update(cx, |thread, cx| {
1828                        thread.update_token_usage(usage, cx);
1829                    })
1830                    .ok();
1831                Task::ready(Ok(()))
1832            }
1833            Err(error) => Task::ready(Err(error)),
1834        }
1835    }
1836}
1837
1838struct NativeAgentSessionRetry {
1839    connection: NativeAgentConnection,
1840    session_id: acp::SessionId,
1841}
1842
1843impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1844    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1845        self.connection
1846            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1847                thread.update(cx, |thread, cx| thread.resume(cx))
1848            })
1849    }
1850}
1851
1852struct NativeAgentSessionSetTitle {
1853    thread: Entity<Thread>,
1854}
1855
1856impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1857    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1858        self.thread
1859            .update(cx, |thread, cx| thread.set_title(title, cx));
1860        Task::ready(Ok(()))
1861    }
1862}
1863
1864pub struct NativeThreadEnvironment {
1865    agent: WeakEntity<NativeAgent>,
1866    thread: WeakEntity<Thread>,
1867    acp_thread: WeakEntity<AcpThread>,
1868}
1869
1870impl NativeThreadEnvironment {
1871    pub(crate) fn create_subagent_thread(
1872        &self,
1873        label: String,
1874        cx: &mut App,
1875    ) -> Result<Rc<dyn SubagentHandle>> {
1876        let Some(parent_thread_entity) = self.thread.upgrade() else {
1877            anyhow::bail!("Parent thread no longer exists".to_string());
1878        };
1879        let parent_thread = parent_thread_entity.read(cx);
1880        let current_depth = parent_thread.depth();
1881        let parent_session_id = parent_thread.id().clone();
1882
1883        if current_depth >= MAX_SUBAGENT_DEPTH {
1884            return Err(anyhow!(
1885                "Maximum subagent depth ({}) reached",
1886                MAX_SUBAGENT_DEPTH
1887            ));
1888        }
1889
1890        let subagent_thread: Entity<Thread> = cx.new(|cx| {
1891            let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1892            thread.set_title(label.into(), cx);
1893            thread
1894        });
1895
1896        let session_id = subagent_thread.read(cx).id().clone();
1897
1898        let acp_thread = self
1899            .agent
1900            .update(cx, |agent, cx| -> Result<Entity<AcpThread>> {
1901                let project_id = agent
1902                    .sessions
1903                    .get(&parent_session_id)
1904                    .map(|s| s.project_id)
1905                    .context("parent session not found")?;
1906                Ok(agent.register_session(subagent_thread.clone(), project_id, 1, cx))
1907            })??;
1908
1909        let depth = current_depth + 1;
1910
1911        telemetry::event!(
1912            "Subagent Started",
1913            session = parent_thread_entity.read(cx).id().to_string(),
1914            subagent_session = session_id.to_string(),
1915            depth,
1916            is_resumed = false,
1917        );
1918
1919        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1920    }
1921
1922    pub(crate) fn resume_subagent_thread(
1923        &self,
1924        session_id: acp::SessionId,
1925        cx: &mut App,
1926    ) -> Result<Rc<dyn SubagentHandle>> {
1927        let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1928            let session = agent
1929                .sessions
1930                .get(&session_id)
1931                .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1932            anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1933        })??;
1934
1935        let depth = subagent_thread.read(cx).depth();
1936
1937        if let Some(parent_thread_entity) = self.thread.upgrade() {
1938            telemetry::event!(
1939                "Subagent Started",
1940                session = parent_thread_entity.read(cx).id().to_string(),
1941                subagent_session = session_id.to_string(),
1942                depth,
1943                is_resumed = true,
1944            );
1945        }
1946
1947        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1948    }
1949
1950    fn prompt_subagent(
1951        &self,
1952        session_id: acp::SessionId,
1953        subagent_thread: Entity<Thread>,
1954        acp_thread: Entity<acp_thread::AcpThread>,
1955    ) -> Result<Rc<dyn SubagentHandle>> {
1956        let Some(parent_thread_entity) = self.thread.upgrade() else {
1957            anyhow::bail!("Parent thread no longer exists".to_string());
1958        };
1959        Ok(Rc::new(NativeSubagentHandle::new(
1960            session_id,
1961            subagent_thread,
1962            acp_thread,
1963            parent_thread_entity,
1964        )) as _)
1965    }
1966}
1967
1968impl ThreadEnvironment for NativeThreadEnvironment {
1969    fn create_terminal(
1970        &self,
1971        command: String,
1972        cwd: Option<PathBuf>,
1973        output_byte_limit: Option<u64>,
1974        cx: &mut AsyncApp,
1975    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1976        let task = self.acp_thread.update(cx, |thread, cx| {
1977            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1978        });
1979
1980        let acp_thread = self.acp_thread.clone();
1981        cx.spawn(async move |cx| {
1982            let terminal = task?.await?;
1983
1984            let (drop_tx, drop_rx) = oneshot::channel();
1985            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1986
1987            cx.spawn(async move |cx| {
1988                drop_rx.await.ok();
1989                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1990            })
1991            .detach();
1992
1993            let handle = AcpTerminalHandle {
1994                terminal,
1995                _drop_tx: Some(drop_tx),
1996            };
1997
1998            Ok(Rc::new(handle) as _)
1999        })
2000    }
2001
2002    fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
2003        self.create_subagent_thread(label, cx)
2004    }
2005
2006    fn resume_subagent(
2007        &self,
2008        session_id: acp::SessionId,
2009        cx: &mut App,
2010    ) -> Result<Rc<dyn SubagentHandle>> {
2011        self.resume_subagent_thread(session_id, cx)
2012    }
2013}
2014
/// Outcome of waiting on a single subagent prompt.
#[derive(Clone, Debug)]
enum SubagentPromptResult {
    /// The turn ended normally; the reply is read from the thread's last message.
    Completed,
    /// The turn stopped with `StopReason::Cancelled`.
    Cancelled,
    /// Token usage crossed from `Normal` into `Warning` during the turn.
    ContextWindowWarning,
    /// Any other failure, carrying a user-facing message.
    Error(String),
}
2022
2023pub struct NativeSubagentHandle {
2024    session_id: acp::SessionId,
2025    parent_thread: WeakEntity<Thread>,
2026    subagent_thread: Entity<Thread>,
2027    acp_thread: Entity<acp_thread::AcpThread>,
2028}
2029
2030impl NativeSubagentHandle {
2031    fn new(
2032        session_id: acp::SessionId,
2033        subagent_thread: Entity<Thread>,
2034        acp_thread: Entity<acp_thread::AcpThread>,
2035        parent_thread_entity: Entity<Thread>,
2036    ) -> Self {
2037        NativeSubagentHandle {
2038            session_id,
2039            subagent_thread,
2040            parent_thread: parent_thread_entity.downgrade(),
2041            acp_thread,
2042        }
2043    }
2044}
2045
2046impl SubagentHandle for NativeSubagentHandle {
2047    fn id(&self) -> acp::SessionId {
2048        self.session_id.clone()
2049    }
2050
2051    fn num_entries(&self, cx: &App) -> usize {
2052        self.acp_thread.read(cx).entries().len()
2053    }
2054
2055    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
2056        let thread = self.subagent_thread.clone();
2057        let acp_thread = self.acp_thread.clone();
2058        let subagent_session_id = self.session_id.clone();
2059        let parent_thread = self.parent_thread.clone();
2060
2061        cx.spawn(async move |cx| {
2062            let (task, _subscription) = cx.update(|cx| {
2063                let ratio_before_prompt = thread
2064                    .read(cx)
2065                    .latest_token_usage()
2066                    .map(|usage| usage.ratio());
2067
2068                parent_thread
2069                    .update(cx, |parent_thread, _cx| {
2070                        parent_thread.register_running_subagent(thread.downgrade())
2071                    })
2072                    .ok();
2073
2074                let task = acp_thread.update(cx, |acp_thread, cx| {
2075                    acp_thread.send(vec![message.into()], cx)
2076                });
2077
2078                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
2079                let mut token_limit_tx = Some(token_limit_tx);
2080
2081                let subscription = cx.subscribe(
2082                    &thread,
2083                    move |_thread, event: &TokenUsageUpdated, _cx| {
2084                        if let Some(usage) = &event.0 {
2085                            let old_ratio = ratio_before_prompt
2086                                .clone()
2087                                .unwrap_or(TokenUsageRatio::Normal);
2088                            let new_ratio = usage.ratio();
2089                            if old_ratio == TokenUsageRatio::Normal
2090                                && new_ratio == TokenUsageRatio::Warning
2091                            {
2092                                if let Some(tx) = token_limit_tx.take() {
2093                                    tx.send(()).ok();
2094                                }
2095                            }
2096                        }
2097                    },
2098                );
2099
2100                let wait_for_prompt = cx
2101                    .background_spawn(async move {
2102                        futures::select! {
2103                            response = task.fuse() => match response {
2104                                Ok(Some(response)) => {
2105                                    match response.stop_reason {
2106                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
2107                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
2108                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
2109                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
2110                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
2111                                    }
2112                                }
2113                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
2114                                Err(error) => SubagentPromptResult::Error(error.to_string()),
2115                            },
2116                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
2117                        }
2118                    });
2119
2120                (wait_for_prompt, subscription)
2121            });
2122
2123            let result = match task.await {
2124                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
2125                    thread
2126                        .last_message()
2127                        .and_then(|message| {
2128                            let content = message.as_agent_message()?
2129                                .content
2130                                .iter()
2131                                .filter_map(|c| match c {
2132                                    AgentMessageContent::Text(text) => Some(text.as_str()),
2133                                    _ => None,
2134                                })
2135                                .join("\n\n");
2136                            if content.is_empty() {
2137                                None
2138                            } else {
2139                                Some( content)
2140                            }
2141                        })
2142                        .context("No response from subagent")
2143                }),
2144                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
2145                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
2146                SubagentPromptResult::ContextWindowWarning => {
2147                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2148                    Err(anyhow!(
2149                        "The agent is nearing the end of its context window and has been \
2150                         stopped. You can prompt the thread again to have the agent wrap up \
2151                         or hand off its work."
2152                    ))
2153                }
2154            };
2155
2156            parent_thread
2157                .update(cx, |parent_thread, cx| {
2158                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
2159                })
2160                .ok();
2161
2162            result
2163        })
2164    }
2165}
2166
2167pub struct AcpTerminalHandle {
2168    terminal: Entity<acp_thread::Terminal>,
2169    _drop_tx: Option<oneshot::Sender<()>>,
2170}
2171
2172impl TerminalHandle for AcpTerminalHandle {
2173    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
2174        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
2175    }
2176
2177    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
2178        Ok(self
2179            .terminal
2180            .read_with(cx, |term, _cx| term.wait_for_exit()))
2181    }
2182
2183    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
2184        Ok(self
2185            .terminal
2186            .read_with(cx, |term, cx| term.current_output(cx)))
2187    }
2188
2189    fn kill(&self, cx: &AsyncApp) -> Result<()> {
2190        cx.update(|cx| {
2191            self.terminal.update(cx, |terminal, cx| {
2192                terminal.kill(cx);
2193            });
2194        });
2195        Ok(())
2196    }
2197
2198    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
2199        Ok(self
2200            .terminal
2201            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
2202    }
2203}
2204
2205#[cfg(test)]
2206mod internal_tests {
2207    use std::path::Path;
2208
2209    use super::*;
2210    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
2211    use fs::FakeFs;
2212    use gpui::TestAppContext;
2213    use indoc::formatdoc;
2214    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2215    use language_model::{
2216        LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
2217    };
2218    use serde_json::json;
2219    use settings::SettingsStore;
2220    use util::{path, rel_path::rel_path};
2221
    // Verifies that the agent's per-project context, and the thread's view of it,
    // track worktree additions and creation of a `.rules` file in a worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        // The project starts with no worktrees; `/a` is added later.
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));

        // Creating a session registers the project and triggers context building.
        let connection = NativeAgentConnection(agent.clone());
        let _acp_thread = cx
            .update(|cx| {
                Rc::new(connection).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/")]),
                    cx,
                )
            })
            .await
            .unwrap();
        cx.run_until_parked();

        // Only one session exists, so the first one is the thread under test.
        let thread = agent.read_with(cx, |agent, _cx| {
            agent.sessions.values().next().unwrap().thread.clone()
        });

        // With no worktrees, both the agent's and the thread's contexts are empty.
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            assert_eq!(state.project_context.read(cx).worktrees, vec![]);
            assert_eq!(thread.read(cx).project_context().read(cx).worktrees, vec![]);
        });

        // Adding worktree `/a` should surface it (without a rules file) in both contexts.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            let expected_worktrees = vec![WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file: None,
            }];
            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
            assert_eq!(
                thread.read(cx).project_context().read(cx).worktrees,
                expected_worktrees
            );
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            // The expected `project_entry_id` comes from the worktree's entry for `.rules`.
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            let expected_worktrees = vec![WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file: Some(RulesFileContext {
                    path_in_worktree: rel_path(".rules").into(),
                    text: "".into(),
                    project_entry_id: rules_entry.id.to_usize(),
                }),
            }];
            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
            assert_eq!(
                thread.read(cx).project_context().read(cx).worktrees,
                expected_worktrees
            );
        });
    }
2309
    // Verifies that a session's model selector lists exactly one model, "fake/fake",
    // grouped under "Fake" (presumably the fake provider registered by `init_test`).
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let connection =
            NativeAgentConnection(cx.update(|cx| {
                NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)
            }));

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Ask the session's selector for the available models.
        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        // Models are expected in the grouped form, keyed by provider group name.
        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(acp_thread::AgentModelIcon::Named(
                        ui::IconName::ZedAssistant
                    )),
                    is_latest: false,
                    cost: None,
                }]
            )])
        );
    }
2366
2367    #[gpui::test]
2368    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
2369        init_test(cx);
2370        let fs = FakeFs::new(cx.executor());
2371        fs.create_dir(paths::settings_file().parent().unwrap())
2372            .await
2373            .unwrap();
2374        fs.insert_file(
2375            paths::settings_file(),
2376            json!({
2377                "agent": {
2378                    "default_model": {
2379                        "provider": "foo",
2380                        "model": "bar"
2381                    }
2382                }
2383            })
2384            .to_string()
2385            .into_bytes(),
2386        )
2387        .await;
2388        let project = Project::test(fs.clone(), [], cx).await;
2389
2390        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2391
2392        // Create the agent and connection
2393        let agent =
2394            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2395        let connection = NativeAgentConnection(agent.clone());
2396
2397        // Create a thread/session
2398        let acp_thread = cx
2399            .update(|cx| {
2400                Rc::new(connection.clone()).new_session(
2401                    project.clone(),
2402                    PathList::new(&[Path::new("/a")]),
2403                    cx,
2404                )
2405            })
2406            .await
2407            .unwrap();
2408
2409        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2410
2411        // Select a model
2412        let selector = connection.model_selector(&session_id).unwrap();
2413        let model_id = acp::ModelId::new("fake/fake");
2414        cx.update(|cx| selector.select_model(model_id.clone(), cx))
2415            .await
2416            .unwrap();
2417
2418        // Verify the thread has the selected model
2419        agent.read_with(cx, |agent, _| {
2420            let session = agent.sessions.get(&session_id).unwrap();
2421            session.thread.read_with(cx, |thread, _| {
2422                assert_eq!(thread.model().unwrap().id().0, "fake");
2423            });
2424        });
2425
2426        cx.run_until_parked();
2427
2428        // Verify settings file was updated
2429        let settings_content = fs.load(paths::settings_file()).await.unwrap();
2430        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2431
2432        // Check that the agent settings contain the selected model
2433        assert_eq!(
2434            settings_json["agent"]["default_model"]["model"],
2435            json!("fake")
2436        );
2437        assert_eq!(
2438            settings_json["agent"]["default_model"]["provider"],
2439            json!("fake")
2440        );
2441
2442        // Register a thinking model and select it.
2443        cx.update(|cx| {
2444            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2445                "fake-corp",
2446                "fake-thinking",
2447                "Fake Thinking",
2448                true,
2449            ));
2450            let thinking_provider = Arc::new(
2451                FakeLanguageModelProvider::new(
2452                    LanguageModelProviderId::from("fake-corp".to_string()),
2453                    LanguageModelProviderName::from("Fake Corp".to_string()),
2454                )
2455                .with_models(vec![thinking_model]),
2456            );
2457            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2458                registry.register_provider(thinking_provider, cx);
2459            });
2460        });
2461        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2462
2463        let selector = connection.model_selector(&session_id).unwrap();
2464        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2465            .await
2466            .unwrap();
2467        cx.run_until_parked();
2468
2469        // Verify enable_thinking was written to settings as true.
2470        let settings_content = fs.load(paths::settings_file()).await.unwrap();
2471        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2472        assert_eq!(
2473            settings_json["agent"]["default_model"]["enable_thinking"],
2474            json!(true),
2475            "selecting a thinking model should persist enable_thinking: true to settings"
2476        );
2477    }
2478
2479    #[gpui::test]
2480    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2481        init_test(cx);
2482        let fs = FakeFs::new(cx.executor());
2483        fs.create_dir(paths::settings_file().parent().unwrap())
2484            .await
2485            .unwrap();
2486        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2487        let project = Project::test(fs.clone(), [], cx).await;
2488
2489        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2490        let agent =
2491            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2492        let connection = NativeAgentConnection(agent.clone());
2493
2494        let acp_thread = cx
2495            .update(|cx| {
2496                Rc::new(connection.clone()).new_session(
2497                    project.clone(),
2498                    PathList::new(&[Path::new("/a")]),
2499                    cx,
2500                )
2501            })
2502            .await
2503            .unwrap();
2504        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2505
2506        // Register a second provider with a thinking model.
2507        cx.update(|cx| {
2508            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2509                "fake-corp",
2510                "fake-thinking",
2511                "Fake Thinking",
2512                true,
2513            ));
2514            let thinking_provider = Arc::new(
2515                FakeLanguageModelProvider::new(
2516                    LanguageModelProviderId::from("fake-corp".to_string()),
2517                    LanguageModelProviderName::from("Fake Corp".to_string()),
2518                )
2519                .with_models(vec![thinking_model]),
2520            );
2521            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2522                registry.register_provider(thinking_provider, cx);
2523            });
2524        });
2525        // Refresh the agent's model list so it picks up the new provider.
2526        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2527
2528        // Thread starts with thinking_enabled = false (the default).
2529        agent.read_with(cx, |agent, _| {
2530            let session = agent.sessions.get(&session_id).unwrap();
2531            session.thread.read_with(cx, |thread, _| {
2532                assert!(!thread.thinking_enabled(), "thinking defaults to false");
2533            });
2534        });
2535
2536        // Select the thinking model via select_model.
2537        let selector = connection.model_selector(&session_id).unwrap();
2538        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2539            .await
2540            .unwrap();
2541
2542        // select_model should have enabled thinking based on the model's supports_thinking().
2543        agent.read_with(cx, |agent, _| {
2544            let session = agent.sessions.get(&session_id).unwrap();
2545            session.thread.read_with(cx, |thread, _| {
2546                assert!(
2547                    thread.thinking_enabled(),
2548                    "select_model should enable thinking when model supports it"
2549                );
2550            });
2551        });
2552
2553        // Switch back to the non-thinking model.
2554        let selector = connection.model_selector(&session_id).unwrap();
2555        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2556            .await
2557            .unwrap();
2558
2559        // select_model should have disabled thinking.
2560        agent.read_with(cx, |agent, _| {
2561            let session = agent.sessions.get(&session_id).unwrap();
2562            session.thread.read_with(cx, |thread, _| {
2563                assert!(
2564                    !thread.thinking_enabled(),
2565                    "select_model should disable thinking when model does not support it"
2566                );
2567            });
2568        });
2569    }
2570
2571    #[gpui::test]
2572    async fn test_summarization_model_survives_transient_registry_clearing(
2573        cx: &mut TestAppContext,
2574    ) {
2575        init_test(cx);
2576        let fs = FakeFs::new(cx.executor());
2577        fs.insert_tree("/", json!({ "a": {} })).await;
2578        let project = Project::test(fs.clone(), [], cx).await;
2579
2580        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2581        let agent =
2582            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2583        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2584
2585        let acp_thread = cx
2586            .update(|cx| {
2587                connection.clone().new_session(
2588                    project.clone(),
2589                    PathList::new(&[Path::new("/a")]),
2590                    cx,
2591                )
2592            })
2593            .await
2594            .unwrap();
2595        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2596
2597        let thread = agent.read_with(cx, |agent, _| {
2598            agent.sessions.get(&session_id).unwrap().thread.clone()
2599        });
2600
2601        thread.read_with(cx, |thread, _| {
2602            assert!(
2603                thread.summarization_model().is_some(),
2604                "session should have a summarization model from the test registry"
2605            );
2606        });
2607
2608        // Simulate what happens during a provider blip:
2609        // update_active_language_model_from_settings calls set_default_model(None)
2610        // when it can't resolve the model, clearing all fallbacks.
2611        cx.update(|cx| {
2612            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2613                registry.set_default_model(None, cx);
2614            });
2615        });
2616        cx.run_until_parked();
2617
2618        thread.read_with(cx, |thread, _| {
2619            assert!(
2620                thread.summarization_model().is_some(),
2621                "summarization model should survive a transient default model clearing"
2622            );
2623        });
2624    }
2625
    /// A thread that had thinking enabled, by selecting a thinking-capable
    /// model, must still have thinking enabled after it is closed and
    /// reloaded from the thread store.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model. A handle is kept so the test can drive its
        // completion stream below.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so the new provider's model is selectable.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thinking_enabled_guard(thread),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        // Let the request reach the fake model before streaming its reply.
        cx.run_until_parked();

        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        // Confirm nothing is left in memory, so the reload below must read the saved thread.
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        drop(reloaded_acp_thread);
    }
2728
    /// A reloaded thread must come back with the model it was saved with, not
    /// the default model. The model here deliberately has an id that differs
    /// from its display name, and the restored model is checked by id.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Refresh the agent's model list so the new provider's model is selectable.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // Select by "<provider>/<model id>", not by display name.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        // Let the request reach the fake model before streaming its reply.
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        // Confirm nothing is left in memory, so the reload below must read the saved thread.
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        drop(reloaded_acp_thread);
    }
2836
    /// Round-trips a thread through close and reopen. The test checks that:
    /// - an empty thread is not saved, even after its models are changed;
    /// - the saved entry uses the summarization model's output;
    /// - messages, the draft prompt, token usage, and the UI scroll position
    ///   all survive the reload.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a message that mixes plain text with a file mention.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        // Let the request reach the fake model before streaming its reply.
        cx.run_until_parked();

        // Stream the reply along with a usage update that is checked after reload.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
            language_model::TokenUsage {
                input_tokens: 150,
                output_tokens: 75,
                ..Default::default()
            },
        ));
        model.end_last_completion_stream();
        cx.run_until_parked();
        // Summary text; the assertion on thread_entries below expects it as the
        // saved entry's text.
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Set a draft prompt with rich content blocks and scroll position
        // AFTER run_until_parked, so the only save that captures these
        // changes is the one performed by close_session itself.
        let draft_blocks = vec![
            acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
            acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
            acp::ContentBlock::Text(acp::TextContent::new(" please")),
        ];
        acp_thread.update(cx, |thread, cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()), cx);
        });
        thread.update(cx, |thread, _cx| {
            thread.set_ui_scroll_position(Some(gpui::ListOffset {
                item_ix: 5,
                offset_in_item: gpui::px(12.5),
            }));
        });

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        // Confirm nothing is left in memory, so the reload below must read the saved thread.
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // Ensure the draft prompt with rich content blocks survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
        });

        // Ensure token usage survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let usage = thread
                .token_usage()
                .expect("token usage should be restored after reload");
            assert_eq!(usage.input_tokens, 150);
            assert_eq!(usage.output_tokens, 75);
        });

        // Ensure scroll position survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let scroll = thread
                .ui_scroll_position()
                .expect("scroll position should be restored after reload");
            assert_eq!(scroll.item_ix, 5);
            assert_eq!(scroll.offset_in_item, gpui::px(12.5));
        });
    }
3018
3019    #[gpui::test]
3020    async fn test_close_session_saves_thread(cx: &mut TestAppContext) {
3021        init_test(cx);
3022        let fs = FakeFs::new(cx.executor());
3023        fs.insert_tree(
3024            "/",
3025            json!({
3026                "a": {
3027                    "file.txt": "hello"
3028                }
3029            }),
3030        )
3031        .await;
3032        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
3033        let thread_store = cx.new(|cx| ThreadStore::new(cx));
3034        let agent = cx.update(|cx| {
3035            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
3036        });
3037        let connection = Rc::new(NativeAgentConnection(agent.clone()));
3038
3039        let acp_thread = cx
3040            .update(|cx| {
3041                connection
3042                    .clone()
3043                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
3044            })
3045            .await
3046            .unwrap();
3047        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3048        let thread = agent.read_with(cx, |agent, _| {
3049            agent.sessions.get(&session_id).unwrap().thread.clone()
3050        });
3051
3052        let model = Arc::new(FakeLanguageModel::default());
3053        thread.update(cx, |thread, cx| {
3054            thread.set_model(model.clone(), cx);
3055        });
3056
3057        // Send a message so the thread is non-empty (empty threads aren't saved).
3058        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
3059        let send = cx.foreground_executor().spawn(send);
3060        cx.run_until_parked();
3061
3062        model.send_last_completion_stream_text_chunk("world");
3063        model.end_last_completion_stream();
3064        send.await.unwrap();
3065        cx.run_until_parked();
3066
3067        // Set a draft prompt WITHOUT calling run_until_parked afterwards.
3068        // This means no observe-triggered save has run for this change.
3069        // The only way this data gets persisted is if close_session
3070        // itself performs the save.
3071        let draft_blocks = vec![acp::ContentBlock::Text(acp::TextContent::new(
3072            "unsaved draft",
3073        ))];
3074        acp_thread.update(cx, |thread, cx| {
3075            thread.set_draft_prompt(Some(draft_blocks.clone()), cx);
3076        });
3077
3078        // Close the session immediately — no run_until_parked in between.
3079        cx.update(|cx| connection.clone().close_session(&session_id, cx))
3080            .await
3081            .unwrap();
3082        cx.run_until_parked();
3083
3084        // Reopen and verify the draft prompt was saved.
3085        let reloaded = agent
3086            .update(cx, |agent, cx| {
3087                agent.open_thread(session_id.clone(), project.clone(), cx)
3088            })
3089            .await
3090            .unwrap();
3091        reloaded.read_with(cx, |thread, _| {
3092            assert_eq!(
3093                thread.draft_prompt(),
3094                Some(draft_blocks.as_slice()),
3095                "close_session must save the thread; draft prompt was lost"
3096            );
3097        });
3098    }
3099
    /// Asking for a summary of an already-open session must not leave an extra
    /// reference behind. After `thread_summary` completes, the session's
    /// `ref_count` is back to 1, and a single `close_session` unloads it.
    #[gpui::test]
    async fn test_thread_summary_releases_loaded_session(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "file.txt": "hello"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Use separate fake models for the chat reply and the summary so each
        // stream can be driven independently.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });

        // Produce one user/assistant exchange so the thread has content to summarize.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("world");
        model.end_last_completion_stream();
        send.await.unwrap();
        cx.run_until_parked();

        let summary = agent.update(cx, |agent, cx| {
            agent.thread_summary(session_id.clone(), project.clone(), cx)
        });
        // Let the summary request reach the summarization model before streaming its reply.
        cx.run_until_parked();

        summary_model.send_last_completion_stream_text_chunk("summary");
        summary_model.end_last_completion_stream();

        assert_eq!(summary.await.unwrap(), "summary");
        cx.run_until_parked();

        // The session opened by new_session is still held once; thread_summary's
        // own reference must have been released.
        agent.read_with(cx, |agent, _| {
            let session = agent
                .sessions
                .get(&session_id)
                .expect("thread_summary should not close the active session");
            assert_eq!(
                session.ref_count, 1,
                "thread_summary should release its temporary session reference"
            );
        });

        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // One close balances the single remaining reference, so the session unloads.
        agent.read_with(cx, |agent, _| {
            assert!(
                agent.sessions.is_empty(),
                "closing the active session after thread_summary should unload it"
            );
        });
    }
3183
    /// Checks that loading the same session twice yields one shared `AcpThread`
    /// whose agent-side state survives until the last holder closes it.
    ///
    /// Flow:
    /// 1. Create a session, run one turn, then close it. The agent should unload it.
    /// 2. Request two `load_session`s for the same id before awaiting either.
    /// 3. Both loads must resolve to the same entity.
    /// 4. The first `close_session` must leave the entry in `agent.sessions`.
    /// 5. A follow-up turn on the remaining handle must see the earlier history.
    /// 6. The second close must finally unload the session.
    #[gpui::test]
    async fn test_loaded_sessions_keep_state_until_last_close(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "file.txt": "hello"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Start a brand-new session and look up the native `Thread` the agent
        // tracks for it.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Use the registry's default model. In tests it is a fake (`as_fake`),
        // so completions can be streamed by hand below.
        let model = cx.update(|cx| {
            LanguageModelRegistry::read_global(cx)
                .default_model()
                .map(|default_model| default_model.model)
                .expect("default test model should be available")
        });
        let fake_model = model.as_fake();
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
        });

        // Run one full turn ("hello" -> "world") so the session has history to
        // show when it is loaded again. The send is spawned so the fake stream
        // can be fed while the request is pending.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        fake_model.send_last_completion_stream_text_chunk("world");
        fake_model.end_last_completion_stream();
        send.await.unwrap();
        cx.run_until_parked();

        // Close the only open handle and drop the local entity handles. The
        // agent should no longer track the session.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Request two loads of the same session before awaiting either, so both
        // are in flight together (the "concurrent loads" checked below).
        let first_loaded_thread = cx.update(|cx| {
            connection.clone().load_session(
                session_id.clone(),
                project.clone(),
                PathList::new(&[Path::new("")]),
                None,
                cx,
            )
        });
        let second_loaded_thread = cx.update(|cx| {
            connection.clone().load_session(
                session_id.clone(),
                project.clone(),
                PathList::new(&[Path::new("")]),
                None,
                cx,
            )
        });

        let first_loaded_thread = first_loaded_thread.await.unwrap();
        let second_loaded_thread = second_loaded_thread.await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            first_loaded_thread.entity_id(),
            second_loaded_thread.entity_id(),
            "concurrent loads for the same session should share one AcpThread"
        );

        // First close: the other load still holds the session, so its entry
        // must remain in `agent.sessions`.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();

        agent.read_with(cx, |agent, _| {
            assert!(
                agent.sessions.contains_key(&session_id),
                "closing one loaded session should not drop shared session state"
            );
        });

        // The remaining handle is still usable. The new turn is appended after
        // the "hello"/"world" exchange from the first turn.
        let follow_up = second_loaded_thread.update(cx, |thread, cx| {
            thread.send(vec!["still there?".into()], cx)
        });
        let follow_up = cx.foreground_executor().spawn(follow_up);
        cx.run_until_parked();

        fake_model.send_last_completion_stream_text_chunk("yes");
        fake_model.end_last_completion_stream();
        follow_up.await.unwrap();
        cx.run_until_parked();

        second_loaded_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    hello

                    ## Assistant

                    world

                    ## User

                    still there?

                    ## Assistant

                    yes

                "}
            );
        });

        // Second (last) close: once the local handles are dropped, the agent
        // should have unloaded the session entirely.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();

        cx.run_until_parked();

        drop(first_loaded_thread);
        drop(second_loaded_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });
    }
3334
    /// Sets two titles back-to-back on a native `Thread`. Checks that both the
    /// `Thread` and its `AcpThread` end up with the last title, and that exactly
    /// one `TitleUpdated` event is emitted per `set_title` call.
    #[gpui::test]
    async fn test_rapid_title_changes_do_not_loop(cx: &mut TestAppContext) {
        // Regression test: rapid title changes must not cause a propagation loop
        // between Thread and AcpThread via handle_thread_title_updated.
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();

        // Look up the native `Thread` the agent tracks for this session.
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Count `TitleUpdated` events emitted by the native thread. The callback
        // panics as soon as a third event arrives, so a feedback loop fails the
        // test instead of spinning.
        let title_updated_count = Rc::new(std::cell::RefCell::new(0usize));
        cx.update(|cx| {
            let count = title_updated_count.clone();
            cx.subscribe(
                &thread,
                move |_entity: Entity<Thread>, _event: &TitleUpdated, _cx: &mut App| {
                    let new_count = {
                        let mut count = count.borrow_mut();
                        *count += 1;
                        *count
                    };
                    assert!(
                        new_count <= 2,
                        "TitleUpdated fired {new_count} times; \
                         title updates are looping"
                    );
                },
            )
            .detach();
        });

        // Two title changes in a row, with no parking in between.
        thread.update(cx, |thread, cx| thread.set_title("first".into(), cx));
        thread.update(cx, |thread, cx| thread.set_title("second".into(), cx));

        cx.run_until_parked();

        // Both sides settle on the most recent title.
        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.title(), Some("second".into()));
        });
        acp_thread.read_with(cx, |acp_thread, _| {
            assert_eq!(acp_thread.title(), Some("second".into()));
        });

        // Exactly one event per `set_title` call.
        assert_eq!(*title_updated_count.borrow(), 2);
    }
3398
3399    fn thread_entries(
3400        thread_store: &Entity<ThreadStore>,
3401        cx: &mut TestAppContext,
3402    ) -> Vec<(acp::SessionId, String)> {
3403        thread_store.read_with(cx, |store, _| {
3404            store
3405                .entries()
3406                .map(|entry| (entry.id.clone(), entry.title.to_string()))
3407                .collect::<Vec<_>>()
3408        })
3409    }
3410
3411    fn init_test(cx: &mut TestAppContext) {
3412        env_logger::try_init().ok();
3413        cx.update(|cx| {
3414            let settings_store = SettingsStore::test(cx);
3415            cx.set_global(settings_store);
3416
3417            LanguageModelRegistry::test(cx);
3418        });
3419    }
3420}
3421
3422fn mcp_message_content_to_acp_content_block(
3423    content: context_server::types::MessageContent,
3424) -> acp::ContentBlock {
3425    match content {
3426        context_server::types::MessageContent::Text {
3427            text,
3428            annotations: _,
3429        } => text.into(),
3430        context_server::types::MessageContent::Image {
3431            data,
3432            mime_type,
3433            annotations: _,
3434        } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
3435        context_server::types::MessageContent::Audio {
3436            data,
3437            mime_type,
3438            annotations: _,
3439        } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
3440        context_server::types::MessageContent::Resource {
3441            resource,
3442            annotations: _,
3443        } => {
3444            let mut link =
3445                acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
3446            if let Some(mime_type) = resource.mime_type {
3447                link = link.mime_type(mime_type);
3448            }
3449            acp::ContentBlock::ResourceLink(link)
3450        }
3451    }
3452}