agent.rs

   1mod db;
   2mod edit_agent;
   3mod legacy_thread;
   4mod native_agent_server;
   5pub mod outline;
   6mod pattern_extraction;
   7mod templates;
   8#[cfg(test)]
   9mod tests;
  10mod thread;
  11mod thread_store;
  12mod tool_permissions;
  13mod tools;
  14
  15use context_server::ContextServerId;
  16pub use db::*;
  17use itertools::Itertools;
  18pub use native_agent_server::NativeAgentServer;
  19pub use pattern_extraction::*;
  20pub use shell_command_parser::extract_commands;
  21pub use templates::*;
  22pub use thread::*;
  23pub use thread_store::*;
  24pub use tool_permissions::*;
  25pub use tools::*;
  26
  27use acp_thread::{
  28    AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
  29    AgentSessionListResponse, TokenUsageRatio, UserMessageId,
  30};
  31use agent_client_protocol::schema as acp;
  32use anyhow::{Context as _, Result, anyhow};
  33use chrono::{DateTime, Utc};
  34use collections::{HashMap, HashSet, IndexMap};
  35use fs::Fs;
  36use futures::channel::{mpsc, oneshot};
  37use futures::future::Shared;
  38use futures::{FutureExt as _, StreamExt as _, future};
  39use gpui::{
  40    App, AppContext, AsyncApp, Context, Entity, EntityId, SharedString, Subscription, Task,
  41    WeakEntity,
  42};
  43use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  44use project::{AgentId, Project, ProjectItem, ProjectPath, Worktree};
  45use prompt_store::{
  46    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  47    WorktreeContext,
  48};
  49use serde::{Deserialize, Serialize};
  50use settings::{LanguageModelSelection, Settings as _, update_settings_file};
  51use std::any::Any;
  52use std::path::PathBuf;
  53use std::rc::Rc;
  54use std::sync::{Arc, LazyLock};
  55use util::ResultExt;
  56use util::path_list::PathList;
  57use util::rel_path::RelPath;
  58
/// A serializable capture of the project's worktrees (as telemetry worktree snapshots)
/// together with a UTC timestamp.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Telemetry snapshots of the project's worktrees.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// Timestamp associated with the snapshot (presumably when it was taken — confirm).
    pub timestamp: DateTime<Utc>,
}
  64
/// Error produced when a worktree's rules file exists but could not be loaded.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}
  68
/// Per-project state shared by every session opened on that project.
struct ProjectState {
    /// The project this state belongs to.
    project: Entity<Project>,
    /// Worktree and user-rules context handed to threads created for this project.
    project_context: Entity<ProjectContext>,
    /// Sending on this channel asks `_maintain_project_context` to rebuild `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background task that rebuilds `project_context` each time a refresh is requested.
    _maintain_project_context: Task<Result<()>>,
    /// Registry wrapping the project's context server store.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Subscriptions to project, context-server-store and registry events; dropped with the state.
    _subscriptions: Vec<Subscription>,
}
  77
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Key into `NativeAgent::projects` for the project this session runs in.
    project_id: EntityId,
    /// Save task for `thread`; starts out as an already-completed `Ok(())` task.
    /// NOTE(review): presumably replaced by `save_thread` on each save — confirm.
    pending_save: Task<Result<()>>,
    /// Subscriptions to `thread` (title, token usage, and notify-triggered saves).
    _subscriptions: Vec<Subscription>,
    /// Reference count supplied at registration (1 for sessions created by `new_session`).
    /// NOTE(review): presumably counts callers holding the session open — confirm.
    ref_count: usize,
}
  89
/// A session whose `AcpThread` is still being produced.
struct PendingSession {
    /// Shareable task resolving to the session's `AcpThread`; the error is wrapped in an
    /// `Arc` so the result can be cloned for every awaiter of the shared future.
    task: Shared<Task<Result<Entity<AcpThread>, Arc<anyhow::Error>>>>,
    /// NOTE(review): presumably the number of callers waiting on this session — confirm.
    ref_count: usize,
}
  94
/// Cache of the language models offered by authenticated, visible providers, keyed by
/// `"{provider_id}/{model_id}"` and grouped for display.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver cloned out by `watch()`; observes every list refresh.
    refresh_models_rx: watch::Receiver<()>,
    /// Signalled at the end of every `refresh_list` call.
    refresh_models_tx: watch::Sender<()>,
    /// Background task started in `new` that authenticates every visible provider.
    _authenticate_all_providers_task: Task<()>,
}
 104
 105impl LanguageModels {
 106    fn new(cx: &mut App) -> Self {
 107        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
 108
 109        let mut this = Self {
 110            models: HashMap::default(),
 111            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
 112            refresh_models_rx,
 113            refresh_models_tx,
 114            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
 115        };
 116        this.refresh_list(cx);
 117        this
 118    }
 119
 120    fn refresh_list(&mut self, cx: &App) {
 121        let providers = LanguageModelRegistry::global(cx)
 122            .read(cx)
 123            .visible_providers()
 124            .into_iter()
 125            .filter(|provider| provider.is_authenticated(cx))
 126            .collect::<Vec<_>>();
 127
 128        let mut language_model_list = IndexMap::default();
 129        let mut recommended_models = HashSet::default();
 130
 131        let mut recommended = Vec::new();
 132        for provider in &providers {
 133            for model in provider.recommended_models(cx) {
 134                recommended_models.insert((model.provider_id(), model.id()));
 135                recommended.push(Self::map_language_model_to_info(&model, provider));
 136            }
 137        }
 138        if !recommended.is_empty() {
 139            language_model_list.insert(
 140                acp_thread::AgentModelGroupName("Recommended".into()),
 141                recommended,
 142            );
 143        }
 144
 145        let mut models = HashMap::default();
 146        for provider in providers {
 147            let mut provider_models = Vec::new();
 148            for model in provider.provided_models(cx) {
 149                let model_info = Self::map_language_model_to_info(&model, &provider);
 150                let model_id = model_info.id.clone();
 151                provider_models.push(model_info);
 152                models.insert(model_id, model);
 153            }
 154            if !provider_models.is_empty() {
 155                language_model_list.insert(
 156                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 157                    provider_models,
 158                );
 159            }
 160        }
 161
 162        self.models = models;
 163        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 164        self.refresh_models_tx.send(()).ok();
 165    }
 166
 167    fn watch(&self) -> watch::Receiver<()> {
 168        self.refresh_models_rx.clone()
 169    }
 170
 171    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 172        self.models.get(model_id).cloned()
 173    }
 174
 175    fn map_language_model_to_info(
 176        model: &Arc<dyn LanguageModel>,
 177        provider: &Arc<dyn LanguageModelProvider>,
 178    ) -> acp_thread::AgentModelInfo {
 179        acp_thread::AgentModelInfo {
 180            id: Self::model_id(model),
 181            name: model.name().0,
 182            description: None,
 183            icon: Some(match provider.icon() {
 184                IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
 185                IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
 186            }),
 187            is_latest: model.is_latest(),
 188            cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
 189        }
 190    }
 191
 192    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 193        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 194    }
 195
 196    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 197        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 198            .read(cx)
 199            .visible_providers()
 200            .iter()
 201            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 202            .collect::<Vec<_>>();
 203
 204        cx.spawn(async move |cx| {
 205            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 206                if let Err(err) = authenticate_task.await {
 207                    match err {
 208                        language_model::AuthenticateError::CredentialsNotFound => {
 209                            // Since we're authenticating these providers in the
 210                            // background for the purposes of populating the
 211                            // language selector, we don't care about providers
 212                            // where the credentials are not found.
 213                        }
 214                        language_model::AuthenticateError::ConnectionRefused => {
 215                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 216                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 217                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 218                        }
 219                        _ => {
 220                            // Some providers have noisy failure states that we
 221                            // don't want to spam the logs with every time the
 222                            // language model selector is initialized.
 223                            //
 224                            // Ideally these should have more clear failure modes
 225                            // that we know are safe to ignore here, like what we do
 226                            // with `CredentialsNotFound` above.
 227                            match provider_id.0.as_ref() {
 228                                "lmstudio" | "ollama" => {
 229                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 230                                    //
 231                                    // These fail noisily, so we don't log them.
 232                                }
 233                                "copilot_chat" => {
 234                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 235                                }
 236                                _ => {
 237                                    log::error!(
 238                                        "Failed to authenticate provider: {}: {err:#}",
 239                                        provider_name.0
 240                                    );
 241                                }
 242                            }
 243                        }
 244                    }
 245                }
 246            }
 247
 248            cx.update(language_models::update_environment_fallback_model);
 249        })
 250    }
 251}
 252
/// The built-in agent: owns every open session, per-project context, and the cached
/// language model list.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Sessions still being produced, keyed by their eventual session ID.
    pending_sessions: HashMap<acp::SessionId, PendingSession>,
    /// Store of persisted threads.
    thread_store: Entity<ThreadStore>,
    /// Project-specific state keyed by project EntityId
    projects: HashMap<EntityId, ProjectState>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Source of the default user rules included in each project's context, if available.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle supplied at construction.
    fs: Arc<dyn Fs>,
    /// Subscriptions to the language model registry and (if present) the prompt store.
    _subscriptions: Vec<Subscription>,
}
 268
 269impl NativeAgent {
 270    pub fn new(
 271        thread_store: Entity<ThreadStore>,
 272        templates: Arc<Templates>,
 273        prompt_store: Option<Entity<PromptStore>>,
 274        fs: Arc<dyn Fs>,
 275        cx: &mut App,
 276    ) -> Entity<NativeAgent> {
 277        log::debug!("Creating new NativeAgent");
 278
 279        cx.new(|cx| {
 280            let mut subscriptions = vec![cx.subscribe(
 281                &LanguageModelRegistry::global(cx),
 282                Self::handle_models_updated_event,
 283            )];
 284            if let Some(prompt_store) = prompt_store.as_ref() {
 285                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 286            }
 287
 288            Self {
 289                sessions: HashMap::default(),
 290                pending_sessions: HashMap::default(),
 291                thread_store,
 292                projects: HashMap::default(),
 293                templates,
 294                models: LanguageModels::new(cx),
 295                prompt_store,
 296                fs,
 297                _subscriptions: subscriptions,
 298            }
 299        })
 300    }
 301
 302    fn new_session(
 303        &mut self,
 304        project: Entity<Project>,
 305        cx: &mut Context<Self>,
 306    ) -> Entity<AcpThread> {
 307        let project_id = self.get_or_create_project_state(&project, cx);
 308        let project_state = &self.projects[&project_id];
 309
 310        let registry = LanguageModelRegistry::read_global(cx);
 311        let available_count = registry.available_models(cx).count();
 312        log::debug!("Total available models: {}", available_count);
 313
 314        let default_model = registry.default_model().and_then(|default_model| {
 315            self.models
 316                .model_from_id(&LanguageModels::model_id(&default_model.model))
 317        });
 318        let thread = cx.new(|cx| {
 319            Thread::new(
 320                project,
 321                project_state.project_context.clone(),
 322                project_state.context_server_registry.clone(),
 323                self.templates.clone(),
 324                default_model,
 325                cx,
 326            )
 327        });
 328
 329        self.register_session(thread, project_id, 1, cx)
 330    }
 331
    /// Wraps `thread_handle` in a new [`AcpThread`] and records the pair as a session.
    ///
    /// The `AcpThread` is seeded from the thread's parent id, title, draft prompt, scroll
    /// position and latest token usage. The thread then receives the registry's
    /// thread-summary model and its default tools. The agent subscribes to the thread's
    /// title and token-usage events and calls `save_thread` whenever the thread notifies.
    /// `ref_count` becomes the session's initial reference count. Finally, the project's
    /// available commands are pushed to every session on `project_id`, including this one.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        project_id: EntityId,
        ref_count: usize,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        // Copy out everything the AcpThread needs while only a shared borrow of `cx` is held.
        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let draft_prompt = thread.draft_prompt().map(Vec::from);
        let scroll_position = thread.ui_scroll_position();
        let token_usage = thread.latest_token_usage();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            let mut acp_thread = acp_thread::AcpThread::new(
                parent_session_id,
                title,
                None,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            );
            // Restore UI state carried by the thread.
            acp_thread.set_draft_prompt(draft_prompt, cx);
            acp_thread.set_ui_scroll_position(scroll_position);
            acp_thread.update_token_usage(token_usage, cx);
            acp_thread
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model(cx).map(|c| c.model);

        // Tools receive weak handles to the agent, the thread and the AcpThread
        // (presumably to avoid reference cycles through the thread's tool set).
        let weak = cx.weak_entity();
        let weak_thread = thread_handle.downgrade();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            thread.add_default_tools(
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    thread: weak_thread,
                    agent: weak,
                }) as _,
                cx,
            )
        });

        // Forward title and token-usage changes to the AcpThread, and save on every notify.
        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                project_id,
                _subscriptions: subscriptions,
                pending_save: Task::ready(Ok(())),
                ref_count,
            },
        );

        // The new session needs the project's current slash-command list.
        self.update_available_commands_for_project(project_id, cx);

        acp_thread
    }
 410
 411    pub fn models(&self) -> &LanguageModels {
 412        &self.models
 413    }
 414
 415    fn get_or_create_project_state(
 416        &mut self,
 417        project: &Entity<Project>,
 418        cx: &mut Context<Self>,
 419    ) -> EntityId {
 420        let project_id = project.entity_id();
 421        if self.projects.contains_key(&project_id) {
 422            return project_id;
 423        }
 424
 425        let project_context = cx.new(|_| ProjectContext::new(vec![], vec![]));
 426        self.register_project_with_initial_context(project.clone(), project_context, cx);
 427        if let Some(state) = self.projects.get_mut(&project_id) {
 428            state.project_context_needs_refresh.send(()).ok();
 429        }
 430        project_id
 431    }
 432
 433    fn register_project_with_initial_context(
 434        &mut self,
 435        project: Entity<Project>,
 436        project_context: Entity<ProjectContext>,
 437        cx: &mut Context<Self>,
 438    ) {
 439        let project_id = project.entity_id();
 440
 441        let context_server_store = project.read(cx).context_server_store();
 442        let context_server_registry =
 443            cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
 444
 445        let subscriptions = vec![
 446            cx.subscribe(&project, Self::handle_project_event),
 447            cx.subscribe(
 448                &context_server_store,
 449                Self::handle_context_server_store_updated,
 450            ),
 451            cx.subscribe(
 452                &context_server_registry,
 453                Self::handle_context_server_registry_event,
 454            ),
 455        ];
 456
 457        let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 458            watch::channel(());
 459
 460        self.projects.insert(
 461            project_id,
 462            ProjectState {
 463                project,
 464                project_context,
 465                project_context_needs_refresh: project_context_needs_refresh_tx,
 466                _maintain_project_context: cx.spawn(async move |this, cx| {
 467                    Self::maintain_project_context(
 468                        this,
 469                        project_id,
 470                        project_context_needs_refresh_rx,
 471                        cx,
 472                    )
 473                    .await
 474                }),
 475                context_server_registry,
 476                _subscriptions: subscriptions,
 477            },
 478        );
 479    }
 480
 481    fn session_project_state(&self, session_id: &acp::SessionId) -> Option<&ProjectState> {
 482        self.sessions
 483            .get(session_id)
 484            .and_then(|session| self.projects.get(&session.project_id))
 485    }
 486
 487    async fn maintain_project_context(
 488        this: WeakEntity<Self>,
 489        project_id: EntityId,
 490        mut needs_refresh: watch::Receiver<()>,
 491        cx: &mut AsyncApp,
 492    ) -> Result<()> {
 493        while needs_refresh.changed().await.is_ok() {
 494            let project_context = this
 495                .update(cx, |this, cx| {
 496                    let state = this
 497                        .projects
 498                        .get(&project_id)
 499                        .context("project state not found")?;
 500                    anyhow::Ok(Self::build_project_context(
 501                        &state.project,
 502                        this.prompt_store.as_ref(),
 503                        cx,
 504                    ))
 505                })??
 506                .await;
 507            this.update(cx, |this, cx| {
 508                if let Some(state) = this.projects.get(&project_id) {
 509                    state
 510                        .project_context
 511                        .update(cx, |current_project_context, _cx| {
 512                            *current_project_context = project_context;
 513                        });
 514                }
 515            })?;
 516        }
 517
 518        Ok(())
 519    }
 520
 521    fn build_project_context(
 522        project: &Entity<Project>,
 523        prompt_store: Option<&Entity<PromptStore>>,
 524        cx: &mut App,
 525    ) -> Task<ProjectContext> {
 526        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 527        let worktree_tasks = worktrees
 528            .into_iter()
 529            .map(|worktree| {
 530                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 531            })
 532            .collect::<Vec<_>>();
 533        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 534            prompt_store.read_with(cx, |prompt_store, cx| {
 535                let prompts = prompt_store.default_prompt_metadata();
 536                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 537                    let contents = prompt_store.load(prompt_metadata.id, cx);
 538                    async move { (contents.await, prompt_metadata) }
 539                });
 540                cx.background_spawn(future::join_all(load_tasks))
 541            })
 542        } else {
 543            Task::ready(vec![])
 544        };
 545
 546        cx.spawn(async move |_cx| {
 547            let (worktrees, default_user_rules) =
 548                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 549
 550            let worktrees = worktrees
 551                .into_iter()
 552                .map(|(worktree, _rules_error)| {
 553                    // TODO: show error message
 554                    // if let Some(rules_error) = rules_error {
 555                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 556                    // }
 557                    worktree
 558                })
 559                .collect::<Vec<_>>();
 560
 561            let default_user_rules = default_user_rules
 562                .into_iter()
 563                .flat_map(|(contents, prompt_metadata)| match contents {
 564                    Ok(contents) => Some(UserRulesContext {
 565                        uuid: prompt_metadata.id.as_user()?,
 566                        title: prompt_metadata.title.map(|title| title.to_string()),
 567                        contents,
 568                    }),
 569                    Err(_err) => {
 570                        // TODO: show error message
 571                        // this.update(cx, |_, cx| {
 572                        //     cx.emit(RulesLoadingError {
 573                        //         message: format!("{err:?}").into(),
 574                        //     });
 575                        // })
 576                        // .ok();
 577                        None
 578                    }
 579                })
 580                .collect::<Vec<_>>();
 581
 582            ProjectContext::new(worktrees, default_user_rules)
 583        })
 584    }
 585
 586    fn load_worktree_info_for_system_prompt(
 587        worktree: Entity<Worktree>,
 588        project: Entity<Project>,
 589        cx: &mut App,
 590    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 591        let tree = worktree.read(cx);
 592        let root_name = tree.root_name_str().into();
 593        let abs_path = tree.abs_path();
 594        let scan_complete = tree.as_local().map(|local| local.scan_complete());
 595
 596        let mut context = WorktreeContext {
 597            root_name,
 598            abs_path,
 599            rules_file: None,
 600        };
 601
 602        cx.spawn(async move |cx| {
 603            if let Some(scan_complete) = scan_complete {
 604                scan_complete.await;
 605            }
 606
 607            let rules_task = cx.update(|cx| Self::load_worktree_rules_file(worktree, project, cx));
 608
 609            let (rules_file, rules_file_error) = match rules_task {
 610                Some(rules_task) => match rules_task.await {
 611                    Ok(rules_file) => (Some(rules_file), None),
 612                    Err(err) => (
 613                        None,
 614                        Some(RulesLoadingError {
 615                            message: format!("{err}").into(),
 616                        }),
 617                    ),
 618                },
 619                None => (None, None),
 620            };
 621            context.rules_file = rules_file;
 622            (context, rules_file_error)
 623        })
 624    }
 625
 626    fn load_worktree_rules_file(
 627        worktree: Entity<Worktree>,
 628        project: Entity<Project>,
 629        cx: &mut App,
 630    ) -> Option<Task<Result<RulesFileContext>>> {
 631        let worktree = worktree.read(cx);
 632        let worktree_id = worktree.id();
 633        let selected_rules_file = RULES_FILE_NAMES
 634            .into_iter()
 635            .filter_map(|name| {
 636                worktree
 637                    .entry_for_path(RelPath::unix(name).unwrap())
 638                    .filter(|entry| entry.is_file())
 639                    .map(|entry| entry.path.clone())
 640            })
 641            .next();
 642
 643        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 644        // supported. This doesn't seem to occur often in GitHub repositories.
 645        selected_rules_file.map(|path_in_worktree| {
 646            let project_path = ProjectPath {
 647                worktree_id,
 648                path: path_in_worktree.clone(),
 649            };
 650            let buffer_task =
 651                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 652            let rope_task = cx.spawn(async move |cx| {
 653                let buffer = buffer_task.await?;
 654                let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
 655                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 656                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 657                })?;
 658                anyhow::Ok((project_entry_id, rope))
 659            });
 660            // Build a string from the rope on a background thread.
 661            cx.background_spawn(async move {
 662                let (project_entry_id, rope) = rope_task.await?;
 663                anyhow::Ok(RulesFileContext {
 664                    path_in_worktree,
 665                    text: rope.to_string().trim().to_string(),
 666                    project_entry_id: project_entry_id.to_usize(),
 667                })
 668            })
 669        })
 670    }
 671
 672    fn handle_thread_title_updated(
 673        &mut self,
 674        thread: Entity<Thread>,
 675        _: &TitleUpdated,
 676        cx: &mut Context<Self>,
 677    ) {
 678        let session_id = thread.read(cx).id();
 679        let Some(session) = self.sessions.get(session_id) else {
 680            return;
 681        };
 682
 683        let thread = thread.downgrade();
 684        let acp_thread = session.acp_thread.downgrade();
 685        cx.spawn(async move |_, cx| {
 686            let title = thread.read_with(cx, |thread, _| thread.title())?;
 687            if let Some(title) = title {
 688                let task =
 689                    acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 690                task.await?;
 691            }
 692            anyhow::Ok(())
 693        })
 694        .detach_and_log_err(cx);
 695    }
 696
 697    fn handle_thread_token_usage_updated(
 698        &mut self,
 699        thread: Entity<Thread>,
 700        usage: &TokenUsageUpdated,
 701        cx: &mut Context<Self>,
 702    ) {
 703        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 704            return;
 705        };
 706        session.acp_thread.update(cx, |acp_thread, cx| {
 707            acp_thread.update_token_usage(usage.0.clone(), cx);
 708        });
 709    }
 710
 711    fn handle_project_event(
 712        &mut self,
 713        project: Entity<Project>,
 714        event: &project::Event,
 715        _cx: &mut Context<Self>,
 716    ) {
 717        let project_id = project.entity_id();
 718        let Some(state) = self.projects.get_mut(&project_id) else {
 719            return;
 720        };
 721        match event {
 722            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 723                state.project_context_needs_refresh.send(()).ok();
 724            }
 725            project::Event::WorktreeUpdatedEntries(_, items) => {
 726                if items.iter().any(|(path, _, _)| {
 727                    RULES_FILE_NAMES
 728                        .iter()
 729                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 730                }) {
 731                    state.project_context_needs_refresh.send(()).ok();
 732                }
 733            }
 734            _ => {}
 735        }
 736    }
 737
 738    fn handle_prompts_updated_event(
 739        &mut self,
 740        _prompt_store: Entity<PromptStore>,
 741        _event: &prompt_store::PromptsUpdatedEvent,
 742        _cx: &mut Context<Self>,
 743    ) {
 744        for state in self.projects.values_mut() {
 745            state.project_context_needs_refresh.send(()).ok();
 746        }
 747    }
 748
 749    fn handle_models_updated_event(
 750        &mut self,
 751        _registry: Entity<LanguageModelRegistry>,
 752        event: &language_model::Event,
 753        cx: &mut Context<Self>,
 754    ) {
 755        self.models.refresh_list(cx);
 756
 757        let registry = LanguageModelRegistry::read_global(cx);
 758        let default_model = registry.default_model().map(|m| m.model);
 759        let summarization_model = registry.thread_summary_model(cx).map(|m| m.model);
 760
 761        for session in self.sessions.values_mut() {
 762            session.thread.update(cx, |thread, cx| {
 763                if thread.model().is_none()
 764                    && let Some(model) = default_model.clone()
 765                {
 766                    thread.set_model(model, cx);
 767                    cx.notify();
 768                }
 769                if let Some(model) = summarization_model.clone() {
 770                    if thread.summarization_model().is_none()
 771                        || matches!(event, language_model::Event::ThreadSummaryModelChanged)
 772                    {
 773                        thread.set_summarization_model(Some(model), cx);
 774                    }
 775                }
 776            });
 777        }
 778    }
 779
 780    fn handle_context_server_store_updated(
 781        &mut self,
 782        store: Entity<project::context_server_store::ContextServerStore>,
 783        _event: &project::context_server_store::ServerStatusChangedEvent,
 784        cx: &mut Context<Self>,
 785    ) {
 786        let project_id = self.projects.iter().find_map(|(id, state)| {
 787            if *state.context_server_registry.read(cx).server_store() == store {
 788                Some(*id)
 789            } else {
 790                None
 791            }
 792        });
 793        if let Some(project_id) = project_id {
 794            self.update_available_commands_for_project(project_id, cx);
 795        }
 796    }
 797
 798    fn handle_context_server_registry_event(
 799        &mut self,
 800        registry: Entity<ContextServerRegistry>,
 801        event: &ContextServerRegistryEvent,
 802        cx: &mut Context<Self>,
 803    ) {
 804        match event {
 805            ContextServerRegistryEvent::ToolsChanged => {}
 806            ContextServerRegistryEvent::PromptsChanged => {
 807                let project_id = self.projects.iter().find_map(|(id, state)| {
 808                    if state.context_server_registry == registry {
 809                        Some(*id)
 810                    } else {
 811                        None
 812                    }
 813                });
 814                if let Some(project_id) = project_id {
 815                    self.update_available_commands_for_project(project_id, cx);
 816                }
 817            }
 818        }
 819    }
 820
 821    fn update_available_commands_for_project(&self, project_id: EntityId, cx: &mut Context<Self>) {
 822        let available_commands =
 823            Self::build_available_commands_for_project(self.projects.get(&project_id), cx);
 824        for session in self.sessions.values() {
 825            if session.project_id != project_id {
 826                continue;
 827            }
 828            session.acp_thread.update(cx, |thread, cx| {
 829                thread
 830                    .handle_session_update(
 831                        acp::SessionUpdate::AvailableCommandsUpdate(
 832                            acp::AvailableCommandsUpdate::new(available_commands.clone()),
 833                        ),
 834                        cx,
 835                    )
 836                    .log_err();
 837            });
 838        }
 839    }
 840
 841    fn build_available_commands_for_project(
 842        project_state: Option<&ProjectState>,
 843        cx: &App,
 844    ) -> Vec<acp::AvailableCommand> {
 845        let Some(state) = project_state else {
 846            return vec![];
 847        };
 848        let registry = state.context_server_registry.read(cx);
 849
 850        let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
 851        for context_server_prompt in registry.prompts() {
 852            *prompt_name_counts
 853                .entry(context_server_prompt.prompt.name.as_str())
 854                .or_insert(0) += 1;
 855        }
 856
 857        registry
 858            .prompts()
 859            .flat_map(|context_server_prompt| {
 860                let prompt = &context_server_prompt.prompt;
 861
 862                let should_prefix = prompt_name_counts
 863                    .get(prompt.name.as_str())
 864                    .copied()
 865                    .unwrap_or(0)
 866                    > 1;
 867
 868                let name = if should_prefix {
 869                    format!("{}.{}", context_server_prompt.server_id, prompt.name)
 870                } else {
 871                    prompt.name.clone()
 872                };
 873
 874                let mut command = acp::AvailableCommand::new(
 875                    name,
 876                    prompt.description.clone().unwrap_or_default(),
 877                );
 878
 879                match prompt.arguments.as_deref() {
 880                    Some([arg]) => {
 881                        let hint = format!("<{}>", arg.name);
 882
 883                        command = command.input(acp::AvailableCommandInput::Unstructured(
 884                            acp::UnstructuredCommandInput::new(hint),
 885                        ));
 886                    }
 887                    Some([]) | None => {}
 888                    Some(_) => {
 889                        // skip >1 argument commands since we don't support them yet
 890                        return None;
 891                    }
 892                }
 893
 894                Some(command)
 895            })
 896            .collect()
 897    }
 898
 899    pub fn load_thread(
 900        &mut self,
 901        id: acp::SessionId,
 902        project: Entity<Project>,
 903        cx: &mut Context<Self>,
 904    ) -> Task<Result<Entity<Thread>>> {
 905        let database_future = ThreadsDatabase::connect(cx);
 906        cx.spawn(async move |this, cx| {
 907            let database = database_future.await.map_err(|err| anyhow!(err))?;
 908            let db_thread = database
 909                .load_thread(id.clone())
 910                .await?
 911                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 912
 913            this.update(cx, |this, cx| {
 914                let project_id = this.get_or_create_project_state(&project, cx);
 915                let project_state = this
 916                    .projects
 917                    .get(&project_id)
 918                    .context("project state not found")?;
 919                let summarization_model = LanguageModelRegistry::read_global(cx)
 920                    .thread_summary_model(cx)
 921                    .map(|c| c.model);
 922
 923                Ok(cx.new(|cx| {
 924                    let mut thread = Thread::from_db(
 925                        id.clone(),
 926                        db_thread,
 927                        project_state.project.clone(),
 928                        project_state.project_context.clone(),
 929                        project_state.context_server_registry.clone(),
 930                        this.templates.clone(),
 931                        cx,
 932                    );
 933                    thread.set_summarization_model(summarization_model, cx);
 934                    thread
 935                }))
 936            })?
 937        })
 938    }
 939
    /// Opens session `id` in `project` and returns its [`AcpThread`].
    ///
    /// Each successful call takes one reference on the session, which the
    /// caller releases with `close_session`. There are three paths:
    /// - The session is already registered: bump its `ref_count` and return it.
    /// - A load for `id` is already in flight: bump the pending `ref_count` and
    ///   await the same shared task.
    /// - Otherwise: start `load_thread` and record the task in
    ///   `pending_sessions`. On success, register the session with the
    ///   reference count accumulated while it was pending. Then feed the
    ///   events from `Thread::replay` through `handle_thread_events` into the
    ///   new `AcpThread`.
    ///
    /// The load future is `.shared()` so several callers can await it. Its
    /// errors are wrapped in `Arc` because a shared future's output must be
    /// `Clone`. They are converted back to `anyhow::Error` for each caller.
    pub fn open_thread(
        &mut self,
        id: acp::SessionId,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<AcpThread>>> {
        // Already open: just take another reference.
        if let Some(session) = self.sessions.get_mut(&id) {
            session.ref_count += 1;
            return Task::ready(Ok(session.acp_thread.clone()));
        }

        // Load in progress: take a reference on the pending entry and share its result.
        if let Some(pending) = self.pending_sessions.get_mut(&id) {
            pending.ref_count += 1;
            let task = pending.task.clone();
            return cx.background_spawn(async move { task.await.map_err(|err| anyhow!(err)) });
        }

        let task = self.load_thread(id.clone(), project.clone(), cx);
        let shared_task = cx
            .spawn({
                let id = id.clone();
                async move |this, cx| {
                    let thread = match task.await {
                        Ok(thread) => thread,
                        Err(err) => {
                            // Forget the pending entry so a later open starts a fresh load.
                            this.update(cx, |this, _cx| {
                                this.pending_sessions.remove(&id);
                            })
                            .ok();
                            return Err(Arc::new(err));
                        }
                    };
                    let acp_thread = this
                        .update(cx, |this, cx| {
                            let project_id = this.get_or_create_project_state(&project, cx);
                            // Adopt every reference taken while the load was pending;
                            // fall back to 1 if the pending entry is already gone.
                            let ref_count = this
                                .pending_sessions
                                .remove(&id)
                                .map_or(1, |pending| pending.ref_count);
                            this.register_session(thread.clone(), project_id, ref_count, cx)
                        })
                        .map_err(Arc::new)?;
                    // NOTE(review): if the replay below fails, the session stays
                    // registered with `ref_count` references while every caller
                    // receives an error. Callers presumably won't call
                    // `close_session` after an error, so the session may never be
                    // released. Verify.
                    let events = thread.update(cx, |thread, cx| thread.replay(cx));
                    cx.update(|cx| {
                        NativeAgentConnection::handle_thread_events(
                            events,
                            acp_thread.downgrade(),
                            cx,
                        )
                    })
                    .await
                    .map_err(Arc::new)?;
                    acp_thread.update(cx, |thread, cx| {
                        thread.snapshot_completed_plan(cx);
                    });
                    Ok(acp_thread)
                }
            })
            .shared();
        // The first caller holds the initial reference; later callers add to it above.
        self.pending_sessions.insert(
            id,
            PendingSession {
                task: shared_task.clone(),
                ref_count: 1,
            },
        );

        cx.background_spawn(async move { shared_task.await.map_err(|err| anyhow!(err)) })
    }
1009
1010    pub fn thread_summary(
1011        &mut self,
1012        id: acp::SessionId,
1013        project: Entity<Project>,
1014        cx: &mut Context<Self>,
1015    ) -> Task<Result<SharedString>> {
1016        let thread = self.open_thread(id.clone(), project, cx);
1017        cx.spawn(async move |this, cx| {
1018            let acp_thread = thread.await?;
1019            let result = this
1020                .update(cx, |this, cx| {
1021                    this.sessions
1022                        .get(&id)
1023                        .unwrap()
1024                        .thread
1025                        .update(cx, |thread, cx| thread.summary(cx))
1026                })?
1027                .await
1028                .context("Failed to generate summary")?;
1029
1030            this.update(cx, |this, cx| this.close_session(&id, cx))?
1031                .await?;
1032            drop(acp_thread);
1033            Ok(result)
1034        })
1035    }
1036
1037    fn close_session(
1038        &mut self,
1039        session_id: &acp::SessionId,
1040        cx: &mut Context<Self>,
1041    ) -> Task<Result<()>> {
1042        let Some(session) = self.sessions.get_mut(session_id) else {
1043            return Task::ready(Ok(()));
1044        };
1045
1046        session.ref_count -= 1;
1047        if session.ref_count > 0 {
1048            return Task::ready(Ok(()));
1049        }
1050
1051        let thread = session.thread.clone();
1052        self.save_thread(thread, cx);
1053        let Some(session) = self.sessions.remove(session_id) else {
1054            return Task::ready(Ok(()));
1055        };
1056        let project_id = session.project_id;
1057
1058        let has_remaining = self.sessions.values().any(|s| s.project_id == project_id);
1059        if !has_remaining {
1060            self.projects.remove(&project_id);
1061        }
1062
1063        session.pending_save
1064    }
1065
1066    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
1067        if thread.read(cx).is_empty() {
1068            return;
1069        }
1070
1071        let id = thread.read(cx).id().clone();
1072        let Some(session) = self.sessions.get_mut(&id) else {
1073            return;
1074        };
1075
1076        let project_id = session.project_id;
1077        let Some(state) = self.projects.get(&project_id) else {
1078            return;
1079        };
1080
1081        let folder_paths = PathList::new(
1082            &state
1083                .project
1084                .read(cx)
1085                .visible_worktrees(cx)
1086                .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
1087                .collect::<Vec<_>>(),
1088        );
1089
1090        let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
1091        let database_future = ThreadsDatabase::connect(cx);
1092        let db_thread = thread.update(cx, |thread, cx| {
1093            thread.set_draft_prompt(draft_prompt);
1094            thread.to_db(cx)
1095        });
1096        let thread_store = self.thread_store.clone();
1097        session.pending_save = cx.spawn(async move |_, cx| {
1098            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
1099                return Ok(());
1100            };
1101            let db_thread = db_thread.await;
1102            database
1103                .save_thread(id, db_thread, folder_paths)
1104                .await
1105                .log_err();
1106            thread_store.update(cx, |store, cx| store.reload(cx));
1107            Ok(())
1108        });
1109    }
1110
    /// Runs the MCP prompt `prompt_name` from context server `server_id` as a
    /// turn in session `session_id`.
    ///
    /// Steps:
    /// 1. Fetch the prompt with `arguments` via `crate::get_prompt`.
    /// 2. Push `original_content` into the native thread as user message
    ///    `message_id`, skipping its first block (presumably the slash-command
    ///    text itself; verify against the caller).
    /// 3. Append every prompt message to both the `AcpThread` and the native
    ///    `Thread`, keeping their roles.
    /// 4. If the last message was from the user (or there were none), send it
    ///    with `send_existing`; otherwise call `resume`. The resulting events
    ///    are forwarded through `handle_thread_events`.
    ///
    /// # Errors
    /// Fails if the session has no project state or is missing, if fetching
    /// the prompt fails, or if starting or running the turn fails.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: acp::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let Some(state) = self.session_project_state(&session_id) else {
            return Task::ready(Err(anyhow!("Project state not found for session")));
        };
        let server_store = state
            .context_server_registry
            .read(cx)
            .server_store()
            .clone();
        let path_style = state.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // An empty prompt leaves this `true`, so the turn is sent as-is.
            let mut last_is_user = true;

            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                match role {
                    context_server::types::Role::User => {
                        // Each user message from the prompt gets its own message ID,
                        // shared between the ACP view and the native thread.
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        // `false` marks the block as regular text, not thinking. The
                        // trailing `true` is presumably the indent flag, as for the
                        // user branch above.
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
1213}
1214
/// Adapter that exposes a [`NativeAgent`] entity through the
/// `acp_thread::AgentConnection` trait.
///
/// Cloning copies the `Entity` handle, not the agent itself.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
1218
1219impl NativeAgentConnection {
1220    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
1221        self.0
1222            .read(cx)
1223            .sessions
1224            .get(session_id)
1225            .map(|session| session.thread.clone())
1226    }
1227
1228    pub fn load_thread(
1229        &self,
1230        id: acp::SessionId,
1231        project: Entity<Project>,
1232        cx: &mut App,
1233    ) -> Task<Result<Entity<Thread>>> {
1234        self.0
1235            .update(cx, |this, cx| this.load_thread(id, project, cx))
1236    }
1237
1238    fn run_turn(
1239        &self,
1240        session_id: acp::SessionId,
1241        cx: &mut App,
1242        f: impl 'static
1243        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1244    ) -> Task<Result<acp::PromptResponse>> {
1245        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1246            agent
1247                .sessions
1248                .get_mut(&session_id)
1249                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1250        }) else {
1251            log::error!("Session not found in run_turn: {}", session_id);
1252            return Task::ready(Err(anyhow!("Session not found")));
1253        };
1254        log::debug!("Found session for: {}", session_id);
1255
1256        let response_stream = match f(thread, cx) {
1257            Ok(stream) => stream,
1258            Err(err) => return Task::ready(Err(err)),
1259        };
1260        Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1261    }
1262
1263    fn handle_thread_events(
1264        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1265        acp_thread: WeakEntity<AcpThread>,
1266        cx: &App,
1267    ) -> Task<Result<acp::PromptResponse>> {
1268        cx.spawn(async move |cx| {
1269            // Handle response stream and forward to session.acp_thread
1270            while let Some(result) = events.next().await {
1271                match result {
1272                    Ok(event) => {
1273                        log::trace!("Received completion event: {:?}", event);
1274
1275                        match event {
1276                            ThreadEvent::UserMessage(message) => {
1277                                acp_thread.update(cx, |thread, cx| {
1278                                    for content in message.content {
1279                                        thread.push_user_content_block(
1280                                            Some(message.id.clone()),
1281                                            content.into(),
1282                                            cx,
1283                                        );
1284                                    }
1285                                })?;
1286                            }
1287                            ThreadEvent::AgentText(text) => {
1288                                acp_thread.update(cx, |thread, cx| {
1289                                    thread.push_assistant_content_block(text.into(), false, cx)
1290                                })?;
1291                            }
1292                            ThreadEvent::AgentThinking(text) => {
1293                                acp_thread.update(cx, |thread, cx| {
1294                                    thread.push_assistant_content_block(text.into(), true, cx)
1295                                })?;
1296                            }
1297                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1298                                tool_call,
1299                                options,
1300                                response,
1301                                context: _,
1302                            }) => {
1303                                let outcome_task = acp_thread.update(cx, |thread, cx| {
1304                                    thread.request_tool_call_authorization(tool_call, options, cx)
1305                                })??;
1306                                cx.background_spawn(async move {
1307                                    if let acp_thread::RequestPermissionOutcome::Selected(outcome) =
1308                                        outcome_task.await
1309                                    {
1310                                        response
1311                                            .send(outcome)
1312                                            .map_err(|_| {
1313                                                anyhow!("authorization receiver was dropped")
1314                                            })
1315                                            .log_err();
1316                                    }
1317                                })
1318                                .detach();
1319                            }
1320                            ThreadEvent::ToolCall(tool_call) => {
1321                                acp_thread.update(cx, |thread, cx| {
1322                                    thread.upsert_tool_call(tool_call, cx)
1323                                })??;
1324                            }
1325                            ThreadEvent::ToolCallUpdate(update) => {
1326                                acp_thread.update(cx, |thread, cx| {
1327                                    thread.update_tool_call(update, cx)
1328                                })??;
1329                            }
1330                            ThreadEvent::Plan(plan) => {
1331                                acp_thread.update(cx, |thread, cx| thread.update_plan(plan, cx))?;
1332                            }
1333                            ThreadEvent::SubagentSpawned(session_id) => {
1334                                acp_thread.update(cx, |thread, cx| {
1335                                    thread.subagent_spawned(session_id, cx);
1336                                })?;
1337                            }
1338                            ThreadEvent::Retry(status) => {
1339                                acp_thread.update(cx, |thread, cx| {
1340                                    thread.update_retry_status(status, cx)
1341                                })?;
1342                            }
1343                            ThreadEvent::Stop(stop_reason) => {
1344                                log::debug!("Assistant message complete: {:?}", stop_reason);
1345                                return Ok(acp::PromptResponse::new(stop_reason));
1346                            }
1347                        }
1348                    }
1349                    Err(e) => {
1350                        log::error!("Error in model response stream: {:?}", e);
1351                        return Err(e);
1352                    }
1353                }
1354            }
1355
1356            log::debug!("Response stream completed");
1357            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1358        })
1359    }
1360}
1361
1362struct Command<'a> {
1363    prompt_name: &'a str,
1364    arg_value: &'a str,
1365    explicit_server_id: Option<&'a str>,
1366}
1367
1368impl<'a> Command<'a> {
1369    fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1370        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1371            return None;
1372        };
1373        let text = text_content.text.trim();
1374        let command = text.strip_prefix('/')?;
1375        let (command, arg_value) = command
1376            .split_once(char::is_whitespace)
1377            .unwrap_or((command, ""));
1378
1379        if let Some((server_id, prompt_name)) = command.split_once('.') {
1380            Some(Self {
1381                prompt_name,
1382                arg_value,
1383                explicit_server_id: Some(server_id),
1384            })
1385        } else {
1386            Some(Self {
1387                prompt_name: command,
1388                arg_value,
1389                explicit_server_id: None,
1390            })
1391        }
1392    }
1393}
1394
/// [`AgentModelSelector`] for a single session of the native agent.
///
/// Lists the agent's models and reads or changes the model of the thread
/// registered under `session_id`.
struct NativeAgentModelSelector {
    /// Session whose thread's model is read and changed.
    session_id: acp::SessionId,
    /// Connection used to reach the agent's sessions, models, and filesystem.
    connection: NativeAgentConnection,
}
1399
1400impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1401    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1402        log::debug!("NativeAgentConnection::list_models called");
1403        let list = self.connection.0.read(cx).models.model_list.clone();
1404        Task::ready(if list.is_empty() {
1405            Err(anyhow::anyhow!("No models available"))
1406        } else {
1407            Ok(list)
1408        })
1409    }
1410
1411    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1412        log::debug!(
1413            "Setting model for session {}: {}",
1414            self.session_id,
1415            model_id
1416        );
1417        let Some(thread) = self
1418            .connection
1419            .0
1420            .read(cx)
1421            .sessions
1422            .get(&self.session_id)
1423            .map(|session| session.thread.clone())
1424        else {
1425            return Task::ready(Err(anyhow!("Session not found")));
1426        };
1427
1428        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1429            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1430        };
1431
1432        let favorite = agent_settings::AgentSettings::get_global(cx)
1433            .favorite_models
1434            .iter()
1435            .find(|favorite| {
1436                favorite.provider.0 == model.provider_id().0.as_ref()
1437                    && favorite.model == model.id().0.as_ref()
1438            })
1439            .cloned();
1440
1441        let LanguageModelSelection {
1442            enable_thinking,
1443            effort,
1444            speed,
1445            ..
1446        } = agent_settings::language_model_to_selection(&model, favorite.as_ref());
1447
1448        thread.update(cx, |thread, cx| {
1449            thread.set_model(model.clone(), cx);
1450            thread.set_thinking_effort(effort.clone(), cx);
1451            thread.set_thinking_enabled(enable_thinking, cx);
1452            if let Some(speed) = speed {
1453                thread.set_speed(speed, cx);
1454            }
1455        });
1456
1457        update_settings_file(
1458            self.connection.0.read(cx).fs.clone(),
1459            cx,
1460            move |settings, cx| {
1461                let provider = model.provider_id().0.to_string();
1462                let model = model.id().0.to_string();
1463                let enable_thinking = thread.read(cx).thinking_enabled();
1464                let speed = thread.read(cx).speed();
1465                settings
1466                    .agent
1467                    .get_or_insert_default()
1468                    .set_model(LanguageModelSelection {
1469                        provider: provider.into(),
1470                        model,
1471                        enable_thinking,
1472                        effort,
1473                        speed,
1474                    });
1475            },
1476        );
1477
1478        Task::ready(Ok(()))
1479    }
1480
1481    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1482        let Some(thread) = self
1483            .connection
1484            .0
1485            .read(cx)
1486            .sessions
1487            .get(&self.session_id)
1488            .map(|session| session.thread.clone())
1489        else {
1490            return Task::ready(Err(anyhow!("Session not found")));
1491        };
1492        let Some(model) = thread.read(cx).model() else {
1493            return Task::ready(Err(anyhow!("Model not found")));
1494        };
1495        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1496        else {
1497            return Task::ready(Err(anyhow!("Provider not found")));
1498        };
1499        Task::ready(Ok(LanguageModels::map_language_model_to_info(
1500            model, &provider,
1501        )))
1502    }
1503
1504    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1505        Some(self.connection.0.read(cx).models.watch())
1506    }
1507
1508    fn should_render_footer(&self) -> bool {
1509        true
1510    }
1511}
1512
/// Agent identifier returned by `NativeAgentConnection::agent_id` for the built-in Zed agent.
pub static ZED_AGENT_ID: LazyLock<AgentId> = LazyLock::new(|| AgentId::new("Zed Agent"));
1514
1515impl acp_thread::AgentConnection for NativeAgentConnection {
1516    fn agent_id(&self) -> AgentId {
1517        ZED_AGENT_ID.clone()
1518    }
1519
1520    fn telemetry_id(&self) -> SharedString {
1521        "zed".into()
1522    }
1523
1524    fn new_session(
1525        self: Rc<Self>,
1526        project: Entity<Project>,
1527        work_dirs: PathList,
1528        cx: &mut App,
1529    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1530        log::debug!("Creating new thread for project at: {work_dirs:?}");
1531        Task::ready(Ok(self
1532            .0
1533            .update(cx, |agent, cx| agent.new_session(project, cx))))
1534    }
1535
1536    fn supports_load_session(&self) -> bool {
1537        true
1538    }
1539
1540    fn load_session(
1541        self: Rc<Self>,
1542        session_id: acp::SessionId,
1543        project: Entity<Project>,
1544        _work_dirs: PathList,
1545        _title: Option<SharedString>,
1546        cx: &mut App,
1547    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1548        self.0
1549            .update(cx, |agent, cx| agent.open_thread(session_id, project, cx))
1550    }
1551
1552    fn supports_close_session(&self) -> bool {
1553        true
1554    }
1555
1556    fn close_session(
1557        self: Rc<Self>,
1558        session_id: &acp::SessionId,
1559        cx: &mut App,
1560    ) -> Task<Result<()>> {
1561        self.0
1562            .update(cx, |agent, cx| agent.close_session(session_id, cx))
1563    }
1564
1565    fn auth_methods(&self) -> &[acp::AuthMethod] {
1566        &[] // No auth for in-process
1567    }
1568
1569    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1570        Task::ready(Ok(()))
1571    }
1572
1573    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1574        Some(Rc::new(NativeAgentModelSelector {
1575            session_id: session_id.clone(),
1576            connection: self.clone(),
1577        }) as Rc<dyn AgentModelSelector>)
1578    }
1579
1580    fn prompt(
1581        &self,
1582        id: acp_thread::UserMessageId,
1583        params: acp::PromptRequest,
1584        cx: &mut App,
1585    ) -> Task<Result<acp::PromptResponse>> {
1586        let session_id = params.session_id.clone();
1587        log::info!("Received prompt request for session: {}", session_id);
1588        log::debug!("Prompt blocks count: {}", params.prompt.len());
1589
1590        let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
1591            log::error!("Session not found in prompt: {}", session_id);
1592            if self.0.read(cx).sessions.contains_key(&session_id) {
1593                log::error!(
1594                    "Session found in sessions map, but not in project state: {}",
1595                    session_id
1596                );
1597            }
1598            return Task::ready(Err(anyhow::anyhow!("Session not found")));
1599        };
1600
1601        if let Some(parsed_command) = Command::parse(&params.prompt) {
1602            let registry = project_state.context_server_registry.read(cx);
1603
1604            let explicit_server_id = parsed_command
1605                .explicit_server_id
1606                .map(|server_id| ContextServerId(server_id.into()));
1607
1608            if let Some(prompt) =
1609                registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1610            {
1611                let arguments = if !parsed_command.arg_value.is_empty()
1612                    && let Some(arg_name) = prompt
1613                        .prompt
1614                        .arguments
1615                        .as_ref()
1616                        .and_then(|args| args.first())
1617                        .map(|arg| arg.name.clone())
1618                {
1619                    HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1620                } else {
1621                    Default::default()
1622                };
1623
1624                let prompt_name = prompt.prompt.name.clone();
1625                let server_id = prompt.server_id.clone();
1626
1627                return self.0.update(cx, |agent, cx| {
1628                    agent.send_mcp_prompt(
1629                        id,
1630                        session_id.clone(),
1631                        prompt_name,
1632                        server_id,
1633                        arguments,
1634                        params.prompt,
1635                        cx,
1636                    )
1637                });
1638            }
1639        };
1640
1641        let path_style = project_state.project.read(cx).path_style(cx);
1642
1643        self.run_turn(session_id, cx, move |thread, cx| {
1644            let content: Vec<UserMessageContent> = params
1645                .prompt
1646                .into_iter()
1647                .map(|block| UserMessageContent::from_content_block(block, path_style))
1648                .collect::<Vec<_>>();
1649            log::debug!("Converted prompt to message: {} chars", content.len());
1650            log::debug!("Message id: {:?}", id);
1651            log::debug!("Message content: {:?}", content);
1652
1653            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1654        })
1655    }
1656
1657    fn retry(
1658        &self,
1659        session_id: &acp::SessionId,
1660        _cx: &App,
1661    ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1662        Some(Rc::new(NativeAgentSessionRetry {
1663            connection: self.clone(),
1664            session_id: session_id.clone(),
1665        }) as _)
1666    }
1667
1668    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1669        log::info!("Cancelling on session: {}", session_id);
1670        self.0.update(cx, |agent, cx| {
1671            if let Some(session) = agent.sessions.get(session_id) {
1672                session
1673                    .thread
1674                    .update(cx, |thread, cx| thread.cancel(cx))
1675                    .detach();
1676            }
1677        });
1678    }
1679
1680    fn truncate(
1681        &self,
1682        session_id: &acp::SessionId,
1683        cx: &App,
1684    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1685        self.0.read_with(cx, |agent, _cx| {
1686            agent.sessions.get(session_id).map(|session| {
1687                Rc::new(NativeAgentSessionTruncate {
1688                    thread: session.thread.clone(),
1689                    acp_thread: session.acp_thread.downgrade(),
1690                }) as _
1691            })
1692        })
1693    }
1694
1695    fn set_title(
1696        &self,
1697        session_id: &acp::SessionId,
1698        cx: &App,
1699    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1700        self.0.read_with(cx, |agent, _cx| {
1701            agent
1702                .sessions
1703                .get(session_id)
1704                .filter(|s| !s.thread.read(cx).is_subagent())
1705                .map(|session| {
1706                    Rc::new(NativeAgentSessionSetTitle {
1707                        thread: session.thread.clone(),
1708                    }) as _
1709                })
1710        })
1711    }
1712
1713    fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1714        let thread_store = self.0.read(cx).thread_store.clone();
1715        Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1716    }
1717
1718    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1719        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1720    }
1721
    /// Returns this connection as `Rc<dyn Any>`.
    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
        self
    }
1725}
1726
1727impl acp_thread::AgentTelemetry for NativeAgentConnection {
1728    fn thread_data(
1729        &self,
1730        session_id: &acp::SessionId,
1731        cx: &mut App,
1732    ) -> Task<Result<serde_json::Value>> {
1733        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1734            return Task::ready(Err(anyhow!("Session not found")));
1735        };
1736
1737        let task = session.thread.read(cx).to_db(cx);
1738        cx.background_spawn(async move {
1739            serde_json::to_value(task.await).context("Failed to serialize thread")
1740        })
1741    }
1742}
1743
/// `AgentSessionList` backed by the native agent's `ThreadStore`.
pub struct NativeAgentSessionList {
    /// Source of the listed sessions.
    thread_store: Entity<ThreadStore>,
    /// Sending side of the update channel; used by `notify_refresh`.
    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
    /// Receiving side of the update channel; `watch` hands out clones of it.
    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
    /// Keeps the `ThreadStore` observation alive; the observer sends a
    /// `Refresh` update whenever the store notifies.
    _subscription: Subscription,
}
1750
1751impl NativeAgentSessionList {
1752    fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1753        let (tx, rx) = smol::channel::unbounded();
1754        let this_tx = tx.clone();
1755        let subscription = cx.observe(&thread_store, move |_, _| {
1756            this_tx
1757                .try_send(acp_thread::SessionListUpdate::Refresh)
1758                .ok();
1759        });
1760        Self {
1761            thread_store,
1762            updates_tx: tx,
1763            updates_rx: rx,
1764            _subscription: subscription,
1765        }
1766    }
1767
1768    pub fn thread_store(&self) -> &Entity<ThreadStore> {
1769        &self.thread_store
1770    }
1771}
1772
1773impl AgentSessionList for NativeAgentSessionList {
1774    fn list_sessions(
1775        &self,
1776        _request: AgentSessionListRequest,
1777        cx: &mut App,
1778    ) -> Task<Result<AgentSessionListResponse>> {
1779        let sessions = self
1780            .thread_store
1781            .read(cx)
1782            .entries()
1783            .map(|entry| AgentSessionInfo::from(&entry))
1784            .collect();
1785        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1786    }
1787
1788    fn supports_delete(&self) -> bool {
1789        true
1790    }
1791
1792    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1793        self.thread_store
1794            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1795    }
1796
1797    fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1798        self.thread_store
1799            .update(cx, |store, cx| store.delete_threads(cx))
1800    }
1801
1802    fn watch(
1803        &self,
1804        _cx: &mut App,
1805    ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1806        Some(self.updates_rx.clone())
1807    }
1808
1809    fn notify_refresh(&self) {
1810        self.updates_tx
1811            .try_send(acp_thread::SessionListUpdate::Refresh)
1812            .ok();
1813    }
1814
1815    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1816        self
1817    }
1818}
1819
/// Truncation handle for one native session, produced by the connection's
/// `truncate` method.
struct NativeAgentSessionTruncate {
    /// Native thread that gets truncated.
    thread: Entity<Thread>,
    /// Weak reference to the session's `AcpThread`, which receives the
    /// post-truncation token usage if it is still alive.
    acp_thread: WeakEntity<AcpThread>,
}
1824
1825impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1826    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1827        match self.thread.update(cx, |thread, cx| {
1828            thread.truncate(message_id.clone(), cx)?;
1829            Ok(thread.latest_token_usage())
1830        }) {
1831            Ok(usage) => {
1832                self.acp_thread
1833                    .update(cx, |thread, cx| {
1834                        thread.update_token_usage(usage, cx);
1835                    })
1836                    .ok();
1837                Task::ready(Ok(()))
1838            }
1839            Err(error) => Task::ready(Err(error)),
1840        }
1841    }
1842}
1843
/// Retry handle for one native session; running it resumes the session's
/// thread through the connection's `run_turn`.
struct NativeAgentSessionRetry {
    connection: NativeAgentConnection,
    session_id: acp::SessionId,
}
1848
1849impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1850    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1851        self.connection
1852            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1853                thread.update(cx, |thread, cx| thread.resume(cx))
1854            })
1855    }
1856}
1857
/// Title setter for a native session. The connection's `set_title` only
/// creates one for non-subagent threads.
struct NativeAgentSessionSetTitle {
    thread: Entity<Thread>,
}
1861
1862impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1863    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1864        self.thread
1865            .update(cx, |thread, cx| thread.set_title(title, cx));
1866        Task::ready(Ok(()))
1867    }
1868}
1869
/// `ThreadEnvironment` for a native `Thread`.
///
/// It creates terminals through the session's `AcpThread` and creates or
/// resumes subagents through the `NativeAgent`. All references are weak.
pub struct NativeThreadEnvironment {
    agent: WeakEntity<NativeAgent>,
    /// The thread this environment serves; used as the parent of subagents.
    thread: WeakEntity<Thread>,
    acp_thread: WeakEntity<AcpThread>,
}
1875
1876impl NativeThreadEnvironment {
1877    pub(crate) fn create_subagent_thread(
1878        &self,
1879        label: String,
1880        cx: &mut App,
1881    ) -> Result<Rc<dyn SubagentHandle>> {
1882        let Some(parent_thread_entity) = self.thread.upgrade() else {
1883            anyhow::bail!("Parent thread no longer exists".to_string());
1884        };
1885        let parent_thread = parent_thread_entity.read(cx);
1886        let current_depth = parent_thread.depth();
1887        let parent_session_id = parent_thread.id().clone();
1888
1889        if current_depth >= MAX_SUBAGENT_DEPTH {
1890            return Err(anyhow!(
1891                "Maximum subagent depth ({}) reached",
1892                MAX_SUBAGENT_DEPTH
1893            ));
1894        }
1895
1896        let subagent_thread: Entity<Thread> = cx.new(|cx| {
1897            let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1898            thread.set_title(label.into(), cx);
1899            thread
1900        });
1901
1902        let session_id = subagent_thread.read(cx).id().clone();
1903
1904        let acp_thread = self
1905            .agent
1906            .update(cx, |agent, cx| -> Result<Entity<AcpThread>> {
1907                let project_id = agent
1908                    .sessions
1909                    .get(&parent_session_id)
1910                    .map(|s| s.project_id)
1911                    .context("parent session not found")?;
1912                Ok(agent.register_session(subagent_thread.clone(), project_id, 1, cx))
1913            })??;
1914
1915        let depth = current_depth + 1;
1916
1917        telemetry::event!(
1918            "Subagent Started",
1919            session = parent_thread_entity.read(cx).id().to_string(),
1920            subagent_session = session_id.to_string(),
1921            depth,
1922            is_resumed = false,
1923        );
1924
1925        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1926    }
1927
1928    pub(crate) fn resume_subagent_thread(
1929        &self,
1930        session_id: acp::SessionId,
1931        cx: &mut App,
1932    ) -> Result<Rc<dyn SubagentHandle>> {
1933        let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1934            let session = agent
1935                .sessions
1936                .get(&session_id)
1937                .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1938            anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1939        })??;
1940
1941        let depth = subagent_thread.read(cx).depth();
1942
1943        if let Some(parent_thread_entity) = self.thread.upgrade() {
1944            telemetry::event!(
1945                "Subagent Started",
1946                session = parent_thread_entity.read(cx).id().to_string(),
1947                subagent_session = session_id.to_string(),
1948                depth,
1949                is_resumed = true,
1950            );
1951        }
1952
1953        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1954    }
1955
1956    fn prompt_subagent(
1957        &self,
1958        session_id: acp::SessionId,
1959        subagent_thread: Entity<Thread>,
1960        acp_thread: Entity<acp_thread::AcpThread>,
1961    ) -> Result<Rc<dyn SubagentHandle>> {
1962        let Some(parent_thread_entity) = self.thread.upgrade() else {
1963            anyhow::bail!("Parent thread no longer exists".to_string());
1964        };
1965        Ok(Rc::new(NativeSubagentHandle::new(
1966            session_id,
1967            subagent_thread,
1968            acp_thread,
1969            parent_thread_entity,
1970        )) as _)
1971    }
1972}
1973
1974impl ThreadEnvironment for NativeThreadEnvironment {
1975    fn create_terminal(
1976        &self,
1977        command: String,
1978        cwd: Option<PathBuf>,
1979        output_byte_limit: Option<u64>,
1980        cx: &mut AsyncApp,
1981    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1982        let task = self.acp_thread.update(cx, |thread, cx| {
1983            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1984        });
1985
1986        let acp_thread = self.acp_thread.clone();
1987        cx.spawn(async move |cx| {
1988            let terminal = task?.await?;
1989
1990            let (drop_tx, drop_rx) = oneshot::channel();
1991            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1992
1993            cx.spawn(async move |cx| {
1994                drop_rx.await.ok();
1995                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1996            })
1997            .detach();
1998
1999            let handle = AcpTerminalHandle {
2000                terminal,
2001                _drop_tx: Some(drop_tx),
2002            };
2003
2004            Ok(Rc::new(handle) as _)
2005        })
2006    }
2007
2008    fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
2009        self.create_subagent_thread(label, cx)
2010    }
2011
2012    fn resume_subagent(
2013        &self,
2014        session_id: acp::SessionId,
2015        cx: &mut App,
2016    ) -> Result<Rc<dyn SubagentHandle>> {
2017        self.resume_subagent_thread(session_id, cx)
2018    }
2019}
2020
/// Outcome of one prompt sent to a subagent by `NativeSubagentHandle::send`.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The turn ended with `EndTurn` or any stop reason not mapped elsewhere.
    Completed,
    /// The turn stopped with `StopReason::Cancelled`.
    Cancelled,
    /// Token usage went from `Normal` to `Warning` during the prompt; the
    /// caller then cancels the subagent.
    ContextWindowWarning,
    /// The turn failed; carries the message returned to the caller.
    Error(String),
}
2028
/// Handle to a subagent session, returned by `NativeThreadEnvironment`.
pub struct NativeSubagentHandle {
    session_id: acp::SessionId,
    /// Weak reference to the parent thread. While a prompt is in flight, the
    /// subagent is registered with it as running.
    parent_thread: WeakEntity<Thread>,
    subagent_thread: Entity<Thread>,
    acp_thread: Entity<acp_thread::AcpThread>,
}
2035
2036impl NativeSubagentHandle {
2037    fn new(
2038        session_id: acp::SessionId,
2039        subagent_thread: Entity<Thread>,
2040        acp_thread: Entity<acp_thread::AcpThread>,
2041        parent_thread_entity: Entity<Thread>,
2042    ) -> Self {
2043        NativeSubagentHandle {
2044            session_id,
2045            subagent_thread,
2046            parent_thread: parent_thread_entity.downgrade(),
2047            acp_thread,
2048        }
2049    }
2050}
2051
2052impl SubagentHandle for NativeSubagentHandle {
2053    fn id(&self) -> acp::SessionId {
2054        self.session_id.clone()
2055    }
2056
2057    fn num_entries(&self, cx: &App) -> usize {
2058        self.acp_thread.read(cx).entries().len()
2059    }
2060
2061    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
2062        let thread = self.subagent_thread.clone();
2063        let acp_thread = self.acp_thread.clone();
2064        let subagent_session_id = self.session_id.clone();
2065        let parent_thread = self.parent_thread.clone();
2066
2067        cx.spawn(async move |cx| {
2068            let (task, _subscription) = cx.update(|cx| {
2069                let ratio_before_prompt = thread
2070                    .read(cx)
2071                    .latest_token_usage()
2072                    .map(|usage| usage.ratio());
2073
2074                parent_thread
2075                    .update(cx, |parent_thread, _cx| {
2076                        parent_thread.register_running_subagent(thread.downgrade())
2077                    })
2078                    .ok();
2079
2080                let task = acp_thread.update(cx, |acp_thread, cx| {
2081                    acp_thread.send(vec![message.into()], cx)
2082                });
2083
2084                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
2085                let mut token_limit_tx = Some(token_limit_tx);
2086
2087                let subscription = cx.subscribe(
2088                    &thread,
2089                    move |_thread, event: &TokenUsageUpdated, _cx| {
2090                        if let Some(usage) = &event.0 {
2091                            let old_ratio = ratio_before_prompt
2092                                .clone()
2093                                .unwrap_or(TokenUsageRatio::Normal);
2094                            let new_ratio = usage.ratio();
2095                            if old_ratio == TokenUsageRatio::Normal
2096                                && new_ratio == TokenUsageRatio::Warning
2097                            {
2098                                if let Some(tx) = token_limit_tx.take() {
2099                                    tx.send(()).ok();
2100                                }
2101                            }
2102                        }
2103                    },
2104                );
2105
2106                let wait_for_prompt = cx
2107                    .background_spawn(async move {
2108                        futures::select! {
2109                            response = task.fuse() => match response {
2110                                Ok(Some(response)) => {
2111                                    match response.stop_reason {
2112                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
2113                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
2114                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
2115                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
2116                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
2117                                    }
2118                                }
2119                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
2120                                Err(error) => SubagentPromptResult::Error(error.to_string()),
2121                            },
2122                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
2123                        }
2124                    });
2125
2126                (wait_for_prompt, subscription)
2127            });
2128
2129            let result = match task.await {
2130                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
2131                    thread
2132                        .last_message()
2133                        .and_then(|message| {
2134                            let content = message.as_agent_message()?
2135                                .content
2136                                .iter()
2137                                .filter_map(|c| match c {
2138                                    AgentMessageContent::Text(text) => Some(text.as_str()),
2139                                    _ => None,
2140                                })
2141                                .join("\n\n");
2142                            if content.is_empty() {
2143                                None
2144                            } else {
2145                                Some( content)
2146                            }
2147                        })
2148                        .context("No response from subagent")
2149                }),
2150                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
2151                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
2152                SubagentPromptResult::ContextWindowWarning => {
2153                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2154                    Err(anyhow!(
2155                        "The agent is nearing the end of its context window and has been \
2156                         stopped. You can prompt the thread again to have the agent wrap up \
2157                         or hand off its work."
2158                    ))
2159                }
2160            };
2161
2162            parent_thread
2163                .update(cx, |parent_thread, cx| {
2164                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
2165                })
2166                .ok();
2167
2168            result
2169        })
2170    }
2171}
2172
/// `TerminalHandle` over an `acp_thread::Terminal` entity.
pub struct AcpTerminalHandle {
    terminal: Entity<acp_thread::Terminal>,
    /// Never read. Dropping it together with the handle signals the task spawned
    /// in `NativeThreadEnvironment::create_terminal` to release the terminal.
    _drop_tx: Option<oneshot::Sender<()>>,
}
2177
2178impl TerminalHandle for AcpTerminalHandle {
2179    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
2180        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
2181    }
2182
2183    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
2184        Ok(self
2185            .terminal
2186            .read_with(cx, |term, _cx| term.wait_for_exit()))
2187    }
2188
2189    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
2190        Ok(self
2191            .terminal
2192            .read_with(cx, |term, cx| term.current_output(cx)))
2193    }
2194
2195    fn kill(&self, cx: &AsyncApp) -> Result<()> {
2196        cx.update(|cx| {
2197            self.terminal.update(cx, |terminal, cx| {
2198                terminal.kill(cx);
2199            });
2200        });
2201        Ok(())
2202    }
2203
2204    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
2205        Ok(self
2206            .terminal
2207            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
2208    }
2209}
2210
2211#[cfg(test)]
2212mod internal_tests {
2213    use std::path::Path;
2214
2215    use super::*;
2216    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
2217    use fs::FakeFs;
2218    use gpui::TestAppContext;
2219    use indoc::formatdoc;
2220    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2221    use language_model::{
2222        LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
2223    };
2224    use serde_json::json;
2225    use settings::SettingsStore;
2226    use util::{path, rel_path::rel_path};
2227
2228    #[gpui::test]
2229    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
2230        init_test(cx);
2231        let fs = FakeFs::new(cx.executor());
2232        fs.insert_tree(
2233            "/",
2234            json!({
2235                "a": {}
2236            }),
2237        )
2238        .await;
2239        let project = Project::test(fs.clone(), [], cx).await;
2240        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2241        let agent =
2242            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2243
2244        // Creating a session registers the project and triggers context building.
2245        let connection = NativeAgentConnection(agent.clone());
2246        let _acp_thread = cx
2247            .update(|cx| {
2248                Rc::new(connection).new_session(
2249                    project.clone(),
2250                    PathList::new(&[Path::new("/")]),
2251                    cx,
2252                )
2253            })
2254            .await
2255            .unwrap();
2256        cx.run_until_parked();
2257
2258        let thread = agent.read_with(cx, |agent, _cx| {
2259            agent.sessions.values().next().unwrap().thread.clone()
2260        });
2261
2262        agent.read_with(cx, |agent, cx| {
2263            let project_id = project.entity_id();
2264            let state = agent.projects.get(&project_id).unwrap();
2265            assert_eq!(state.project_context.read(cx).worktrees, vec![]);
2266            assert_eq!(thread.read(cx).project_context().read(cx).worktrees, vec![]);
2267        });
2268
2269        let worktree = project
2270            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
2271            .await
2272            .unwrap();
2273        cx.run_until_parked();
2274        agent.read_with(cx, |agent, cx| {
2275            let project_id = project.entity_id();
2276            let state = agent.projects.get(&project_id).unwrap();
2277            let expected_worktrees = vec![WorktreeContext {
2278                root_name: "a".into(),
2279                abs_path: Path::new("/a").into(),
2280                rules_file: None,
2281            }];
2282            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
2283            assert_eq!(
2284                thread.read(cx).project_context().read(cx).worktrees,
2285                expected_worktrees
2286            );
2287        });
2288
2289        // Creating `/a/.rules` updates the project context.
2290        fs.insert_file("/a/.rules", Vec::new()).await;
2291        cx.run_until_parked();
2292        agent.read_with(cx, |agent, cx| {
2293            let project_id = project.entity_id();
2294            let state = agent.projects.get(&project_id).unwrap();
2295            let rules_entry = worktree
2296                .read(cx)
2297                .entry_for_path(rel_path(".rules"))
2298                .unwrap();
2299            let expected_worktrees = vec![WorktreeContext {
2300                root_name: "a".into(),
2301                abs_path: Path::new("/a").into(),
2302                rules_file: Some(RulesFileContext {
2303                    path_in_worktree: rel_path(".rules").into(),
2304                    text: "".into(),
2305                    project_entry_id: rules_entry.id.to_usize(),
2306                }),
2307            }];
2308            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
2309            assert_eq!(
2310                thread.read(cx).project_context().read(cx).worktrees,
2311                expected_worktrees
2312            );
2313        });
2314    }
2315
2316    #[gpui::test]
2317    async fn test_listing_models(cx: &mut TestAppContext) {
2318        init_test(cx);
2319        let fs = FakeFs::new(cx.executor());
2320        fs.insert_tree("/", json!({ "a": {}  })).await;
2321        let project = Project::test(fs.clone(), [], cx).await;
2322        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2323        let connection =
2324            NativeAgentConnection(cx.update(|cx| {
2325                NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)
2326            }));
2327
2328        // Create a thread/session
2329        let acp_thread = cx
2330            .update(|cx| {
2331                Rc::new(connection.clone()).new_session(
2332                    project.clone(),
2333                    PathList::new(&[Path::new("/a")]),
2334                    cx,
2335                )
2336            })
2337            .await
2338            .unwrap();
2339
2340        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2341
2342        let models = cx
2343            .update(|cx| {
2344                connection
2345                    .model_selector(&session_id)
2346                    .unwrap()
2347                    .list_models(cx)
2348            })
2349            .await
2350            .unwrap();
2351
2352        let acp_thread::AgentModelList::Grouped(models) = models else {
2353            panic!("Unexpected model group");
2354        };
2355        assert_eq!(
2356            models,
2357            IndexMap::from_iter([(
2358                AgentModelGroupName("Fake".into()),
2359                vec![AgentModelInfo {
2360                    id: acp::ModelId::new("fake/fake"),
2361                    name: "Fake".into(),
2362                    description: None,
2363                    icon: Some(acp_thread::AgentModelIcon::Named(
2364                        ui::IconName::ZedAssistant
2365                    )),
2366                    is_latest: false,
2367                    cost: None,
2368                }]
2369            )])
2370        );
2371    }
2372
2373    #[gpui::test]
2374    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
2375        init_test(cx);
2376        let fs = FakeFs::new(cx.executor());
2377        fs.create_dir(paths::settings_file().parent().unwrap())
2378            .await
2379            .unwrap();
2380        fs.insert_file(
2381            paths::settings_file(),
2382            json!({
2383                "agent": {
2384                    "default_model": {
2385                        "provider": "foo",
2386                        "model": "bar"
2387                    }
2388                }
2389            })
2390            .to_string()
2391            .into_bytes(),
2392        )
2393        .await;
2394        let project = Project::test(fs.clone(), [], cx).await;
2395
2396        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2397
2398        // Create the agent and connection
2399        let agent =
2400            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2401        let connection = NativeAgentConnection(agent.clone());
2402
2403        // Create a thread/session
2404        let acp_thread = cx
2405            .update(|cx| {
2406                Rc::new(connection.clone()).new_session(
2407                    project.clone(),
2408                    PathList::new(&[Path::new("/a")]),
2409                    cx,
2410                )
2411            })
2412            .await
2413            .unwrap();
2414
2415        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2416
2417        // Select a model
2418        let selector = connection.model_selector(&session_id).unwrap();
2419        let model_id = acp::ModelId::new("fake/fake");
2420        cx.update(|cx| selector.select_model(model_id.clone(), cx))
2421            .await
2422            .unwrap();
2423
2424        // Verify the thread has the selected model
2425        agent.read_with(cx, |agent, _| {
2426            let session = agent.sessions.get(&session_id).unwrap();
2427            session.thread.read_with(cx, |thread, _| {
2428                assert_eq!(thread.model().unwrap().id().0, "fake");
2429            });
2430        });
2431
2432        cx.run_until_parked();
2433
2434        // Verify settings file was updated
2435        let settings_content = fs.load(paths::settings_file()).await.unwrap();
2436        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2437
2438        // Check that the agent settings contain the selected model
2439        assert_eq!(
2440            settings_json["agent"]["default_model"]["model"],
2441            json!("fake")
2442        );
2443        assert_eq!(
2444            settings_json["agent"]["default_model"]["provider"],
2445            json!("fake")
2446        );
2447
2448        // Register a thinking model and select it.
2449        cx.update(|cx| {
2450            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2451                "fake-corp",
2452                "fake-thinking",
2453                "Fake Thinking",
2454                true,
2455            ));
2456            let thinking_provider = Arc::new(
2457                FakeLanguageModelProvider::new(
2458                    LanguageModelProviderId::from("fake-corp".to_string()),
2459                    LanguageModelProviderName::from("Fake Corp".to_string()),
2460                )
2461                .with_models(vec![thinking_model]),
2462            );
2463            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2464                registry.register_provider(thinking_provider, cx);
2465            });
2466        });
2467        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2468
2469        let selector = connection.model_selector(&session_id).unwrap();
2470        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2471            .await
2472            .unwrap();
2473        cx.run_until_parked();
2474
2475        // Verify enable_thinking was written to settings as true.
2476        let settings_content = fs.load(paths::settings_file()).await.unwrap();
2477        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2478        assert_eq!(
2479            settings_json["agent"]["default_model"]["enable_thinking"],
2480            json!(true),
2481            "selecting a thinking model should persist enable_thinking: true to settings"
2482        );
2483    }
2484
2485    #[gpui::test]
2486    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2487        init_test(cx);
2488        let fs = FakeFs::new(cx.executor());
2489        fs.create_dir(paths::settings_file().parent().unwrap())
2490            .await
2491            .unwrap();
2492        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2493        let project = Project::test(fs.clone(), [], cx).await;
2494
2495        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2496        let agent =
2497            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2498        let connection = NativeAgentConnection(agent.clone());
2499
2500        let acp_thread = cx
2501            .update(|cx| {
2502                Rc::new(connection.clone()).new_session(
2503                    project.clone(),
2504                    PathList::new(&[Path::new("/a")]),
2505                    cx,
2506                )
2507            })
2508            .await
2509            .unwrap();
2510        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2511
2512        // Register a second provider with a thinking model.
2513        cx.update(|cx| {
2514            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2515                "fake-corp",
2516                "fake-thinking",
2517                "Fake Thinking",
2518                true,
2519            ));
2520            let thinking_provider = Arc::new(
2521                FakeLanguageModelProvider::new(
2522                    LanguageModelProviderId::from("fake-corp".to_string()),
2523                    LanguageModelProviderName::from("Fake Corp".to_string()),
2524                )
2525                .with_models(vec![thinking_model]),
2526            );
2527            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2528                registry.register_provider(thinking_provider, cx);
2529            });
2530        });
2531        // Refresh the agent's model list so it picks up the new provider.
2532        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2533
2534        // Thread starts with thinking_enabled = false (the default).
2535        agent.read_with(cx, |agent, _| {
2536            let session = agent.sessions.get(&session_id).unwrap();
2537            session.thread.read_with(cx, |thread, _| {
2538                assert!(!thread.thinking_enabled(), "thinking defaults to false");
2539            });
2540        });
2541
2542        // Select the thinking model via select_model.
2543        let selector = connection.model_selector(&session_id).unwrap();
2544        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2545            .await
2546            .unwrap();
2547
2548        // select_model should have enabled thinking based on the model's supports_thinking().
2549        agent.read_with(cx, |agent, _| {
2550            let session = agent.sessions.get(&session_id).unwrap();
2551            session.thread.read_with(cx, |thread, _| {
2552                assert!(
2553                    thread.thinking_enabled(),
2554                    "select_model should enable thinking when model supports it"
2555                );
2556            });
2557        });
2558
2559        // Switch back to the non-thinking model.
2560        let selector = connection.model_selector(&session_id).unwrap();
2561        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2562            .await
2563            .unwrap();
2564
2565        // select_model should have disabled thinking.
2566        agent.read_with(cx, |agent, _| {
2567            let session = agent.sessions.get(&session_id).unwrap();
2568            session.thread.read_with(cx, |thread, _| {
2569                assert!(
2570                    !thread.thinking_enabled(),
2571                    "select_model should disable thinking when model does not support it"
2572                );
2573            });
2574        });
2575    }
2576
2577    #[gpui::test]
2578    async fn test_summarization_model_survives_transient_registry_clearing(
2579        cx: &mut TestAppContext,
2580    ) {
2581        init_test(cx);
2582        let fs = FakeFs::new(cx.executor());
2583        fs.insert_tree("/", json!({ "a": {} })).await;
2584        let project = Project::test(fs.clone(), [], cx).await;
2585
2586        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2587        let agent =
2588            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2589        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2590
2591        let acp_thread = cx
2592            .update(|cx| {
2593                connection.clone().new_session(
2594                    project.clone(),
2595                    PathList::new(&[Path::new("/a")]),
2596                    cx,
2597                )
2598            })
2599            .await
2600            .unwrap();
2601        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2602
2603        let thread = agent.read_with(cx, |agent, _| {
2604            agent.sessions.get(&session_id).unwrap().thread.clone()
2605        });
2606
2607        thread.read_with(cx, |thread, _| {
2608            assert!(
2609                thread.summarization_model().is_some(),
2610                "session should have a summarization model from the test registry"
2611            );
2612        });
2613
2614        // Simulate what happens during a provider blip:
2615        // update_active_language_model_from_settings calls set_default_model(None)
2616        // when it can't resolve the model, clearing all fallbacks.
2617        cx.update(|cx| {
2618            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2619                registry.set_default_model(None, cx);
2620            });
2621        });
2622        cx.run_until_parked();
2623
2624        thread.read_with(cx, |thread, _| {
2625            assert!(
2626                thread.summarization_model().is_some(),
2627                "summarization model should survive a transient default model clearing"
2628            );
2629        });
2630    }
2631
    /// A thread that had thinking enabled (by selecting a thinking-capable model)
    /// must still report `thinking_enabled() == true` after it is persisted,
    /// closed, and reopened via `open_thread`.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model.
        // The test keeps its own handle (`thinking_model`) so it can drive the
        // model's completion stream by hand further down.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Make the agent's model list pick up the newly registered provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thinking_enabled_placeholder_never_used_marker_is_not_code(),
                ""
            );
        });
    }
2734
2735    #[gpui::test]
2736    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
2737        init_test(cx);
2738        let fs = FakeFs::new(cx.executor());
2739        fs.insert_tree("/", json!({ "a": {} })).await;
2740        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2741        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2742        let agent = cx.update(|cx| {
2743            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
2744        });
2745        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2746
2747        // Register a model where id() != name(), like real Anthropic models
2748        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
2749        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2750            "fake-corp",
2751            "custom-model-id",
2752            "Custom Model Display Name",
2753            false,
2754        ));
2755        let provider = Arc::new(
2756            FakeLanguageModelProvider::new(
2757                LanguageModelProviderId::from("fake-corp".to_string()),
2758                LanguageModelProviderName::from("Fake Corp".to_string()),
2759            )
2760            .with_models(vec![model.clone()]),
2761        );
2762        cx.update(|cx| {
2763            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2764                registry.register_provider(provider, cx);
2765            });
2766        });
2767        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2768
2769        // Create a thread and select the model.
2770        let acp_thread = cx
2771            .update(|cx| {
2772                connection.clone().new_session(
2773                    project.clone(),
2774                    PathList::new(&[Path::new("/a")]),
2775                    cx,
2776                )
2777            })
2778            .await
2779            .unwrap();
2780        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2781
2782        let selector = connection.model_selector(&session_id).unwrap();
2783        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
2784            .await
2785            .unwrap();
2786
2787        let thread = agent.read_with(cx, |agent, _| {
2788            agent.sessions.get(&session_id).unwrap().thread.clone()
2789        });
2790        thread.read_with(cx, |thread, _| {
2791            assert_eq!(
2792                thread.model().unwrap().id().0.as_ref(),
2793                "custom-model-id",
2794                "model should be set before persisting"
2795            );
2796        });
2797
2798        // Send a message so the thread gets persisted.
2799        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2800        let send = cx.foreground_executor().spawn(send);
2801        cx.run_until_parked();
2802
2803        model.send_last_completion_stream_text_chunk("Response.");
2804        model.end_last_completion_stream();
2805
2806        send.await.unwrap();
2807        cx.run_until_parked();
2808
2809        // Close the session so it can be reloaded from disk.
2810        cx.update(|cx| connection.clone().close_session(&session_id, cx))
2811            .await
2812            .unwrap();
2813        drop(thread);
2814        drop(acp_thread);
2815        agent.read_with(cx, |agent, _| {
2816            assert!(agent.sessions.is_empty());
2817        });
2818
2819        // Reload the thread and verify the model was preserved.
2820        let reloaded_acp_thread = agent
2821            .update(cx, |agent, cx| {
2822                agent.open_thread(session_id.clone(), project.clone(), cx)
2823            })
2824            .await
2825            .unwrap();
2826        let reloaded_thread = agent.read_with(cx, |agent, _| {
2827            agent.sessions.get(&session_id).unwrap().thread.clone()
2828        });
2829        reloaded_thread.read_with(cx, |thread, _| {
2830            let reloaded_model = thread
2831                .model()
2832                .expect("model should be present after reload");
2833            assert_eq!(
2834                reloaded_model.id().0.as_ref(),
2835                "custom-model-id",
2836                "reloaded thread should have the same model, not fall back to the default"
2837            );
2838        });
2839
2840        drop(reloaded_acp_thread);
2841    }
2842
2843    #[gpui::test]
2844    async fn test_save_load_thread(cx: &mut TestAppContext) {
2845        init_test(cx);
2846        let fs = FakeFs::new(cx.executor());
2847        fs.insert_tree(
2848            "/",
2849            json!({
2850                "a": {
2851                    "b.md": "Lorem"
2852                }
2853            }),
2854        )
2855        .await;
2856        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2857        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2858        let agent = cx.update(|cx| {
2859            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
2860        });
2861        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2862
2863        let acp_thread = cx
2864            .update(|cx| {
2865                connection
2866                    .clone()
2867                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
2868            })
2869            .await
2870            .unwrap();
2871        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2872        let thread = agent.read_with(cx, |agent, _| {
2873            agent.sessions.get(&session_id).unwrap().thread.clone()
2874        });
2875
2876        // Ensure empty threads are not saved, even if they get mutated.
2877        let model = Arc::new(FakeLanguageModel::default());
2878        let summary_model = Arc::new(FakeLanguageModel::default());
2879        thread.update(cx, |thread, cx| {
2880            thread.set_model(model.clone(), cx);
2881            thread.set_summarization_model(Some(summary_model.clone()), cx);
2882        });
2883        cx.run_until_parked();
2884        assert_eq!(thread_entries(&thread_store, cx), vec![]);
2885
2886        let send = acp_thread.update(cx, |thread, cx| {
2887            thread.send(
2888                vec![
2889                    "What does ".into(),
2890                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
2891                        "b.md",
2892                        MentionUri::File {
2893                            abs_path: path!("/a/b.md").into(),
2894                        }
2895                        .to_uri()
2896                        .to_string(),
2897                    )),
2898                    " mean?".into(),
2899                ],
2900                cx,
2901            )
2902        });
2903        let send = cx.foreground_executor().spawn(send);
2904        cx.run_until_parked();
2905
2906        model.send_last_completion_stream_text_chunk("Lorem.");
2907        model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2908            language_model::TokenUsage {
2909                input_tokens: 150,
2910                output_tokens: 75,
2911                ..Default::default()
2912            },
2913        ));
2914        model.end_last_completion_stream();
2915        cx.run_until_parked();
2916        summary_model
2917            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
2918        summary_model.end_last_completion_stream();
2919
2920        send.await.unwrap();
2921        let uri = MentionUri::File {
2922            abs_path: path!("/a/b.md").into(),
2923        }
2924        .to_uri();
2925        acp_thread.read_with(cx, |thread, cx| {
2926            assert_eq!(
2927                thread.to_markdown(cx),
2928                formatdoc! {"
2929                    ## User
2930
2931                    What does [@b.md]({uri}) mean?
2932
2933                    ## Assistant
2934
2935                    Lorem.
2936
2937                "}
2938            )
2939        });
2940
2941        cx.run_until_parked();
2942
2943        // Set a draft prompt with rich content blocks and scroll position
2944        // AFTER run_until_parked, so the only save that captures these
2945        // changes is the one performed by close_session itself.
2946        let draft_blocks = vec![
2947            acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
2948            acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
2949            acp::ContentBlock::Text(acp::TextContent::new(" please")),
2950        ];
2951        acp_thread.update(cx, |thread, cx| {
2952            thread.set_draft_prompt(Some(draft_blocks.clone()), cx);
2953        });
2954        thread.update(cx, |thread, _cx| {
2955            thread.set_ui_scroll_position(Some(gpui::ListOffset {
2956                item_ix: 5,
2957                offset_in_item: gpui::px(12.5),
2958            }));
2959        });
2960
2961        // Close the session so it can be reloaded from disk.
2962        cx.update(|cx| connection.clone().close_session(&session_id, cx))
2963            .await
2964            .unwrap();
2965        drop(thread);
2966        drop(acp_thread);
2967        agent.read_with(cx, |agent, _| {
2968            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
2969        });
2970
2971        // Ensure the thread can be reloaded from disk.
2972        assert_eq!(
2973            thread_entries(&thread_store, cx),
2974            vec![(
2975                session_id.clone(),
2976                format!("Explaining {}", path!("/a/b.md"))
2977            )]
2978        );
2979        let acp_thread = agent
2980            .update(cx, |agent, cx| {
2981                agent.open_thread(session_id.clone(), project.clone(), cx)
2982            })
2983            .await
2984            .unwrap();
2985        acp_thread.read_with(cx, |thread, cx| {
2986            assert_eq!(
2987                thread.to_markdown(cx),
2988                formatdoc! {"
2989                    ## User
2990
2991                    What does [@b.md]({uri}) mean?
2992
2993                    ## Assistant
2994
2995                    Lorem.
2996
2997                "}
2998            )
2999        });
3000
3001        // Ensure the draft prompt with rich content blocks survived the round-trip.
3002        acp_thread.read_with(cx, |thread, _| {
3003            assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
3004        });
3005
3006        // Ensure token usage survived the round-trip.
3007        acp_thread.read_with(cx, |thread, _| {
3008            let usage = thread
3009                .token_usage()
3010                .expect("token usage should be restored after reload");
3011            assert_eq!(usage.input_tokens, 150);
3012            assert_eq!(usage.output_tokens, 75);
3013        });
3014
3015        // Ensure scroll position survived the round-trip.
3016        acp_thread.read_with(cx, |thread, _| {
3017            let scroll = thread
3018                .ui_scroll_position()
3019                .expect("scroll position should be restored after reload");
3020            assert_eq!(scroll.item_ix, 5);
3021            assert_eq!(scroll.offset_in_item, gpui::px(12.5));
3022        });
3023    }
3024
3025    #[gpui::test]
3026    async fn test_close_session_saves_thread(cx: &mut TestAppContext) {
3027        init_test(cx);
3028        let fs = FakeFs::new(cx.executor());
3029        fs.insert_tree(
3030            "/",
3031            json!({
3032                "a": {
3033                    "file.txt": "hello"
3034                }
3035            }),
3036        )
3037        .await;
3038        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
3039        let thread_store = cx.new(|cx| ThreadStore::new(cx));
3040        let agent = cx.update(|cx| {
3041            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
3042        });
3043        let connection = Rc::new(NativeAgentConnection(agent.clone()));
3044
3045        let acp_thread = cx
3046            .update(|cx| {
3047                connection
3048                    .clone()
3049                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
3050            })
3051            .await
3052            .unwrap();
3053        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3054        let thread = agent.read_with(cx, |agent, _| {
3055            agent.sessions.get(&session_id).unwrap().thread.clone()
3056        });
3057
3058        let model = Arc::new(FakeLanguageModel::default());
3059        thread.update(cx, |thread, cx| {
3060            thread.set_model(model.clone(), cx);
3061        });
3062
3063        // Send a message so the thread is non-empty (empty threads aren't saved).
3064        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
3065        let send = cx.foreground_executor().spawn(send);
3066        cx.run_until_parked();
3067
3068        model.send_last_completion_stream_text_chunk("world");
3069        model.end_last_completion_stream();
3070        send.await.unwrap();
3071        cx.run_until_parked();
3072
3073        // Set a draft prompt WITHOUT calling run_until_parked afterwards.
3074        // This means no observe-triggered save has run for this change.
3075        // The only way this data gets persisted is if close_session
3076        // itself performs the save.
3077        let draft_blocks = vec![acp::ContentBlock::Text(acp::TextContent::new(
3078            "unsaved draft",
3079        ))];
3080        acp_thread.update(cx, |thread, cx| {
3081            thread.set_draft_prompt(Some(draft_blocks.clone()), cx);
3082        });
3083
3084        // Close the session immediately — no run_until_parked in between.
3085        cx.update(|cx| connection.clone().close_session(&session_id, cx))
3086            .await
3087            .unwrap();
3088        cx.run_until_parked();
3089
3090        // Reopen and verify the draft prompt was saved.
3091        let reloaded = agent
3092            .update(cx, |agent, cx| {
3093                agent.open_thread(session_id.clone(), project.clone(), cx)
3094            })
3095            .await
3096            .unwrap();
3097        reloaded.read_with(cx, |thread, _| {
3098            assert_eq!(
3099                thread.draft_prompt(),
3100                Some(draft_blocks.as_slice()),
3101                "close_session must save the thread; draft prompt was lost"
3102            );
3103        });
3104    }
3105
3106    #[gpui::test]
3107    async fn test_thread_summary_releases_loaded_session(cx: &mut TestAppContext) {
3108        init_test(cx);
3109        let fs = FakeFs::new(cx.executor());
3110        fs.insert_tree(
3111            "/",
3112            json!({
3113                "a": {
3114                    "file.txt": "hello"
3115                }
3116            }),
3117        )
3118        .await;
3119        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
3120        let thread_store = cx.new(|cx| ThreadStore::new(cx));
3121        let agent = cx.update(|cx| {
3122            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
3123        });
3124        let connection = Rc::new(NativeAgentConnection(agent.clone()));
3125
3126        let acp_thread = cx
3127            .update(|cx| {
3128                connection
3129                    .clone()
3130                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
3131            })
3132            .await
3133            .unwrap();
3134        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3135        let thread = agent.read_with(cx, |agent, _| {
3136            agent.sessions.get(&session_id).unwrap().thread.clone()
3137        });
3138
3139        let model = Arc::new(FakeLanguageModel::default());
3140        let summary_model = Arc::new(FakeLanguageModel::default());
3141        thread.update(cx, |thread, cx| {
3142            thread.set_model(model.clone(), cx);
3143            thread.set_summarization_model(Some(summary_model.clone()), cx);
3144        });
3145
3146        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
3147        let send = cx.foreground_executor().spawn(send);
3148        cx.run_until_parked();
3149
3150        model.send_last_completion_stream_text_chunk("world");
3151        model.end_last_completion_stream();
3152        send.await.unwrap();
3153        cx.run_until_parked();
3154
3155        let summary = agent.update(cx, |agent, cx| {
3156            agent.thread_summary(session_id.clone(), project.clone(), cx)
3157        });
3158        cx.run_until_parked();
3159
3160        summary_model.send_last_completion_stream_text_chunk("summary");
3161        summary_model.end_last_completion_stream();
3162
3163        assert_eq!(summary.await.unwrap(), "summary");
3164        cx.run_until_parked();
3165
3166        agent.read_with(cx, |agent, _| {
3167            let session = agent
3168                .sessions
3169                .get(&session_id)
3170                .expect("thread_summary should not close the active session");
3171            assert_eq!(
3172                session.ref_count, 1,
3173                "thread_summary should release its temporary session reference"
3174            );
3175        });
3176
3177        cx.update(|cx| connection.clone().close_session(&session_id, cx))
3178            .await
3179            .unwrap();
3180        cx.run_until_parked();
3181
3182        agent.read_with(cx, |agent, _| {
3183            assert!(
3184                agent.sessions.is_empty(),
3185                "closing the active session after thread_summary should unload it"
3186            );
3187        });
3188    }
3189
    // Exercises the unload/reload lifecycle of a native session:
    // 1. run one exchange ("hello" -> "world") in a fresh session, then close it so the
    //    agent holds no sessions;
    // 2. start two `load_session` calls for the same session id before awaiting either,
    //    and check that both resolve to the same `AcpThread` entity;
    // 3. close once and check that the agent still holds the session, and that the other
    //    handle can run a follow-up turn on top of the restored history;
    // 4. close again and check that the agent's session map is empty.
    #[gpui::test]
    async fn test_loaded_sessions_keep_state_until_last_close(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "file.txt": "hello"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // The registry's default model is expected to be a fake test model (see the `expect`
        // message). `as_fake()` returns a handle used below to script completion responses.
        let model = cx.update(|cx| {
            LanguageModelRegistry::read_global(cx)
                .default_model()
                .map(|default_model| default_model.model)
                .expect("default test model should be available")
        });
        let fake_model = model.as_fake();
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
        });

        // First turn: send, let the request reach the model, then stream and finish the reply.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        fake_model.send_last_completion_stream_text_chunk("world");
        fake_model.end_last_completion_stream();
        send.await.unwrap();
        cx.run_until_parked();

        // Close the only reference to the session and release the test's handles to the
        // original entities. After that, the agent should have no sessions loaded.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Issue both loads before awaiting either one so that the two loads overlap.
        let first_loaded_thread = cx.update(|cx| {
            connection.clone().load_session(
                session_id.clone(),
                project.clone(),
                PathList::new(&[Path::new("")]),
                None,
                cx,
            )
        });
        let second_loaded_thread = cx.update(|cx| {
            connection.clone().load_session(
                session_id.clone(),
                project.clone(),
                PathList::new(&[Path::new("")]),
                None,
                cx,
            )
        });

        let first_loaded_thread = first_loaded_thread.await.unwrap();
        let second_loaded_thread = second_loaded_thread.await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            first_loaded_thread.entity_id(),
            second_loaded_thread.entity_id(),
            "concurrent loads for the same session should share one AcpThread"
        );

        // Close one of the two references. The session must stay loaded because the other
        // load still refers to it.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();

        agent.read_with(cx, |agent, _| {
            assert!(
                agent.sessions.contains_key(&session_id),
                "closing one loaded session should not drop shared session state"
            );
        });

        // The session must still be usable after the partial close. The reply is scripted
        // through `fake_model`.
        // NOTE(review): this assumes the reloaded thread uses the same default model;
        // confirm that reload selects the registry default.
        let follow_up = second_loaded_thread.update(cx, |thread, cx| {
            thread.send(vec!["still there?".into()], cx)
        });
        let follow_up = cx.foreground_executor().spawn(follow_up);
        cx.run_until_parked();

        fake_model.send_last_completion_stream_text_chunk("yes");
        fake_model.end_last_completion_stream();
        follow_up.await.unwrap();
        cx.run_until_parked();

        // The transcript must contain the exchange from before the unload, followed by the
        // follow-up exchange made after the reload.
        second_loaded_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    hello

                    ## Assistant

                    world

                    ## User

                    still there?

                    ## Assistant

                    yes

                "}
            );
        });

        // Close the remaining reference. Once the test's handles are dropped, no session
        // state should remain in the agent.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();

        cx.run_until_parked();

        drop(first_loaded_thread);
        drop(second_loaded_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });
    }
3340
    // Checks that two back-to-back `Thread::set_title` calls reach the session's `AcpThread`,
    // and that the `Thread` emits exactly one `TitleUpdated` per call, with no echoes.
    #[gpui::test]
    async fn test_rapid_title_changes_do_not_loop(cx: &mut TestAppContext) {
        // Regression test: rapid title changes must not cause a propagation loop
        // between Thread and AcpThread via handle_thread_title_updated.
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        // The project is opened with no worktrees; only the session plumbing is exercised.
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();

        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Count every `TitleUpdated` event emitted by the `Thread`. The assertion inside the
        // callback fails as soon as a third event arrives, so a loop is reported at the
        // first extra event rather than only by the final count check.
        let title_updated_count = Rc::new(std::cell::RefCell::new(0usize));
        cx.update(|cx| {
            let count = title_updated_count.clone();
            cx.subscribe(
                &thread,
                move |_entity: Entity<Thread>, _event: &TitleUpdated, _cx: &mut App| {
                    let new_count = {
                        let mut count = count.borrow_mut();
                        *count += 1;
                        *count
                    };
                    assert!(
                        new_count <= 2,
                        "TitleUpdated fired {new_count} times; \
                         title updates are looping"
                    );
                },
            )
            .detach();
        });

        // Two title changes in quick succession, issued before the executor is run to
        // quiescence.
        thread.update(cx, |thread, cx| thread.set_title("first".into(), cx));
        thread.update(cx, |thread, cx| thread.set_title("second".into(), cx));

        cx.run_until_parked();

        // Both entities must settle on the most recent title.
        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.title(), Some("second".into()));
        });
        acp_thread.read_with(cx, |acp_thread, _| {
            assert_eq!(acp_thread.title(), Some("second".into()));
        });

        // Exactly one event per explicit `set_title` call.
        assert_eq!(*title_updated_count.borrow(), 2);
    }
3404
3405    fn thread_entries(
3406        thread_store: &Entity<ThreadStore>,
3407        cx: &mut TestAppContext,
3408    ) -> Vec<(acp::SessionId, String)> {
3409        thread_store.read_with(cx, |store, _| {
3410            store
3411                .entries()
3412                .map(|entry| (entry.id.clone(), entry.title.to_string()))
3413                .collect::<Vec<_>>()
3414        })
3415    }
3416
3417    fn init_test(cx: &mut TestAppContext) {
3418        env_logger::try_init().ok();
3419        cx.update(|cx| {
3420            let settings_store = SettingsStore::test(cx);
3421            cx.set_global(settings_store);
3422
3423            LanguageModelRegistry::test(cx);
3424        });
3425    }
3426}
3427
3428fn mcp_message_content_to_acp_content_block(
3429    content: context_server::types::MessageContent,
3430) -> acp::ContentBlock {
3431    match content {
3432        context_server::types::MessageContent::Text {
3433            text,
3434            annotations: _,
3435        } => text.into(),
3436        context_server::types::MessageContent::Image {
3437            data,
3438            mime_type,
3439            annotations: _,
3440        } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
3441        context_server::types::MessageContent::Audio {
3442            data,
3443            mime_type,
3444            annotations: _,
3445        } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
3446        context_server::types::MessageContent::Resource {
3447            resource,
3448            annotations: _,
3449        } => {
3450            let mut link =
3451                acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
3452            if let Some(mime_type) = resource.mime_type {
3453                link = link.mime_type(mime_type);
3454            }
3455            acp::ContentBlock::ResourceLink(link)
3456        }
3457    }
3458}