agent.rs

   1mod db;
   2mod edit_agent;
   3mod legacy_thread;
   4mod native_agent_server;
   5pub mod outline;
   6mod pattern_extraction;
   7mod templates;
   8#[cfg(test)]
   9mod tests;
  10mod thread;
  11mod thread_store;
  12mod tool_permissions;
  13mod tools;
  14
  15use context_server::ContextServerId;
  16pub use db::*;
  17use itertools::Itertools;
  18pub use native_agent_server::NativeAgentServer;
  19pub use pattern_extraction::*;
  20pub use shell_command_parser::extract_commands;
  21pub use templates::*;
  22pub use thread::*;
  23pub use thread_store::*;
  24pub use tool_permissions::*;
  25pub use tools::*;
  26
  27use acp_thread::{
  28    AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
  29    AgentSessionListResponse, TokenUsageRatio, UserMessageId,
  30};
  31use agent_client_protocol as acp;
  32use anyhow::{Context as _, Result, anyhow};
  33use chrono::{DateTime, Utc};
  34use collections::{HashMap, HashSet, IndexMap};
  35use fs::Fs;
  36use futures::channel::{mpsc, oneshot};
  37use futures::future::Shared;
  38use futures::{FutureExt as _, StreamExt as _, future};
  39use gpui::{
  40    App, AppContext, AsyncApp, Context, Entity, EntityId, SharedString, Subscription, Task,
  41    WeakEntity,
  42};
  43use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  44use project::{Project, ProjectItem, ProjectPath, Worktree};
  45use prompt_store::{
  46    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  47    WorktreeContext,
  48};
  49use serde::{Deserialize, Serialize};
  50use settings::{LanguageModelSelection, update_settings_file};
  51use std::any::Any;
  52use std::path::{Path, PathBuf};
  53use std::rc::Rc;
  54use std::sync::Arc;
  55use util::ResultExt;
  56use util::path_list::PathList;
  57use util::rel_path::RelPath;
  58
/// Serializable snapshot of a project's worktrees captured at `timestamp`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One telemetry snapshot per worktree in the project.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// When the snapshot was taken (UTC).
    pub timestamp: DateTime<Utc>,
}
  64
/// Error describing why a rules file could not be loaded.
pub struct RulesLoadingError {
    /// Human-readable error text (the `Display` form of the underlying error).
    pub message: SharedString,
}
  68
/// Per-project state shared by every session opened on that project.
struct ProjectState {
    /// The project this state belongs to.
    project: Entity<Project>,
    /// Context (worktree info, rules) handed to threads created for this project.
    project_context: Entity<ProjectContext>,
    /// Sending on this channel asks `_maintain_project_context` to rebuild
    /// `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background task that rebuilds the project context on request. Stored here
    /// so it lives as long as this state (presumably dropping the `Task` cancels it).
    _maintain_project_context: Task<Result<()>>,
    /// Registry built over the project's context server store.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Subscriptions to project, context-server-store, and registry events.
    _subscriptions: Vec<Subscription>,
}
  77
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Entity id of the project this session belongs to (key into `NativeAgent::projects`).
    project_id: EntityId,
    /// Most recent save of this session's thread; starts as an already-completed task.
    pending_save: Task<()>,
    /// Subscriptions to the thread's title, token-usage, and change notifications.
    _subscriptions: Vec<Subscription>,
}
  88
/// Cache of the language models offered by authenticated providers, both as an
/// id-keyed lookup table and as a grouped list for display.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiving half of the "model list refreshed" signal; cloned out by `watch`.
    refresh_models_rx: watch::Receiver<()>,
    /// Sending half of the same signal; fired at the end of each `refresh_list`.
    refresh_models_tx: watch::Sender<()>,
    /// Background task started in `new` that authenticates every visible provider.
    _authenticate_all_providers_task: Task<()>,
}
  98
  99impl LanguageModels {
 100    fn new(cx: &mut App) -> Self {
 101        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
 102
 103        let mut this = Self {
 104            models: HashMap::default(),
 105            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
 106            refresh_models_rx,
 107            refresh_models_tx,
 108            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
 109        };
 110        this.refresh_list(cx);
 111        this
 112    }
 113
 114    fn refresh_list(&mut self, cx: &App) {
 115        let providers = LanguageModelRegistry::global(cx)
 116            .read(cx)
 117            .visible_providers()
 118            .into_iter()
 119            .filter(|provider| provider.is_authenticated(cx))
 120            .collect::<Vec<_>>();
 121
 122        let mut language_model_list = IndexMap::default();
 123        let mut recommended_models = HashSet::default();
 124
 125        let mut recommended = Vec::new();
 126        for provider in &providers {
 127            for model in provider.recommended_models(cx) {
 128                recommended_models.insert((model.provider_id(), model.id()));
 129                recommended.push(Self::map_language_model_to_info(&model, provider));
 130            }
 131        }
 132        if !recommended.is_empty() {
 133            language_model_list.insert(
 134                acp_thread::AgentModelGroupName("Recommended".into()),
 135                recommended,
 136            );
 137        }
 138
 139        let mut models = HashMap::default();
 140        for provider in providers {
 141            let mut provider_models = Vec::new();
 142            for model in provider.provided_models(cx) {
 143                let model_info = Self::map_language_model_to_info(&model, &provider);
 144                let model_id = model_info.id.clone();
 145                provider_models.push(model_info);
 146                models.insert(model_id, model);
 147            }
 148            if !provider_models.is_empty() {
 149                language_model_list.insert(
 150                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 151                    provider_models,
 152                );
 153            }
 154        }
 155
 156        self.models = models;
 157        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 158        self.refresh_models_tx.send(()).ok();
 159    }
 160
 161    fn watch(&self) -> watch::Receiver<()> {
 162        self.refresh_models_rx.clone()
 163    }
 164
 165    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 166        self.models.get(model_id).cloned()
 167    }
 168
 169    fn map_language_model_to_info(
 170        model: &Arc<dyn LanguageModel>,
 171        provider: &Arc<dyn LanguageModelProvider>,
 172    ) -> acp_thread::AgentModelInfo {
 173        acp_thread::AgentModelInfo {
 174            id: Self::model_id(model),
 175            name: model.name().0,
 176            description: None,
 177            icon: Some(match provider.icon() {
 178                IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
 179                IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
 180            }),
 181            is_latest: model.is_latest(),
 182            cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
 183        }
 184    }
 185
 186    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 187        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 188    }
 189
 190    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 191        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 192            .read(cx)
 193            .visible_providers()
 194            .iter()
 195            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 196            .collect::<Vec<_>>();
 197
 198        cx.background_spawn(async move {
 199            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 200                if let Err(err) = authenticate_task.await {
 201                    match err {
 202                        language_model::AuthenticateError::CredentialsNotFound => {
 203                            // Since we're authenticating these providers in the
 204                            // background for the purposes of populating the
 205                            // language selector, we don't care about providers
 206                            // where the credentials are not found.
 207                        }
 208                        language_model::AuthenticateError::ConnectionRefused => {
 209                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 210                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 211                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 212                        }
 213                        _ => {
 214                            // Some providers have noisy failure states that we
 215                            // don't want to spam the logs with every time the
 216                            // language model selector is initialized.
 217                            //
 218                            // Ideally these should have more clear failure modes
 219                            // that we know are safe to ignore here, like what we do
 220                            // with `CredentialsNotFound` above.
 221                            match provider_id.0.as_ref() {
 222                                "lmstudio" | "ollama" => {
 223                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 224                                    //
 225                                    // These fail noisily, so we don't log them.
 226                                }
 227                                "copilot_chat" => {
 228                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 229                                }
 230                                _ => {
 231                                    log::error!(
 232                                        "Failed to authenticate provider: {}: {err:#}",
 233                                        provider_name.0
 234                                    );
 235                                }
 236                            }
 237                        }
 238                    }
 239                }
 240            }
 241        })
 242    }
 243}
 244
/// The built-in agent: owns every native session, the per-project state those
/// sessions share, and the cached language-model list.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Store of persisted threads.
    thread_store: Entity<ThreadStore>,
    /// Project-specific state keyed by project EntityId
    projects: HashMap<EntityId, ProjectState>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Optional prompt store supplying default user rules for project contexts.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle.
    fs: Arc<dyn Fs>,
    /// Subscriptions to the language-model registry and (if present) the prompt store.
    _subscriptions: Vec<Subscription>,
}
 259
 260impl NativeAgent {
 261    pub fn new(
 262        thread_store: Entity<ThreadStore>,
 263        templates: Arc<Templates>,
 264        prompt_store: Option<Entity<PromptStore>>,
 265        fs: Arc<dyn Fs>,
 266        cx: &mut App,
 267    ) -> Entity<NativeAgent> {
 268        log::debug!("Creating new NativeAgent");
 269
 270        cx.new(|cx| {
 271            let mut subscriptions = vec![cx.subscribe(
 272                &LanguageModelRegistry::global(cx),
 273                Self::handle_models_updated_event,
 274            )];
 275            if let Some(prompt_store) = prompt_store.as_ref() {
 276                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 277            }
 278
 279            Self {
 280                sessions: HashMap::default(),
 281                thread_store,
 282                projects: HashMap::default(),
 283                templates,
 284                models: LanguageModels::new(cx),
 285                prompt_store,
 286                fs,
 287                _subscriptions: subscriptions,
 288            }
 289        })
 290    }
 291
 292    fn new_session(
 293        &mut self,
 294        project: Entity<Project>,
 295        cx: &mut Context<Self>,
 296    ) -> Entity<AcpThread> {
 297        let project_id = self.get_or_create_project_state(&project, cx);
 298        let project_state = &self.projects[&project_id];
 299
 300        let registry = LanguageModelRegistry::read_global(cx);
 301        let available_count = registry.available_models(cx).count();
 302        log::debug!("Total available models: {}", available_count);
 303
 304        let default_model = registry.default_model().and_then(|default_model| {
 305            self.models
 306                .model_from_id(&LanguageModels::model_id(&default_model.model))
 307        });
 308        let thread = cx.new(|cx| {
 309            Thread::new(
 310                project,
 311                project_state.project_context.clone(),
 312                project_state.context_server_registry.clone(),
 313                self.templates.clone(),
 314                default_model,
 315                cx,
 316            )
 317        });
 318
 319        self.register_session(thread, project_id, cx)
 320    }
 321
    /// Wraps `thread_handle` in a new [`AcpThread`] and registers the pair as a
    /// session keyed by the thread's id.
    ///
    /// Copies UI state (title, draft prompt, provisional title, scroll position,
    /// token usage) from the thread onto the new ACP thread, installs the
    /// registry's summarization model and the default tools on the thread, and
    /// subscribes to the thread so title/token-usage events are forwarded and
    /// every change notification triggers `save_thread`. Finally pushes the
    /// project's available commands to its sessions.
    ///
    /// Returns the newly created ACP thread.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        project_id: EntityId,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        // Copy everything needed out of the thread first, so the shared borrow
        // of `cx` taken by `read` ends before `cx.new` below.
        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let draft_prompt = thread.draft_prompt().map(Vec::from);
        let provisional_title = thread.provisional_title().cloned();
        let scroll_position = thread.ui_scroll_position();
        let token_usage = thread.latest_token_usage();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            let mut acp_thread = acp_thread::AcpThread::new(
                parent_session_id,
                title,
                None,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            );
            acp_thread.set_draft_prompt(draft_prompt);
            acp_thread.set_ui_scroll_position(scroll_position);
            acp_thread.update_token_usage(token_usage, cx);
            acp_thread
        });

        if let Some(provisional_title) = provisional_title {
            acp_thread.update(cx, |acp_thread, cx| {
                acp_thread.set_provisional_title(provisional_title, cx);
            });
        }

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        // The tool environment only holds weak handles back to the ACP thread,
        // the thread, and the agent.
        let weak = cx.weak_entity();
        let weak_thread = thread_handle.downgrade();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            thread.add_default_tools(
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    thread: weak_thread,
                    agent: weak,
                }) as _,
                cx,
            )
        });

        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            // Persist the thread whenever it notifies of a change.
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                project_id,
                _subscriptions: subscriptions,
                pending_save: Task::ready(()),
            },
        );

        self.update_available_commands_for_project(project_id, cx);

        acp_thread
    }
 405
    /// Returns the agent's cached language-model information.
    pub fn models(&self) -> &LanguageModels {
        &self.models
    }
 409
 410    fn get_or_create_project_state(
 411        &mut self,
 412        project: &Entity<Project>,
 413        cx: &mut Context<Self>,
 414    ) -> EntityId {
 415        let project_id = project.entity_id();
 416        if self.projects.contains_key(&project_id) {
 417            return project_id;
 418        }
 419
 420        let project_context = cx.new(|_| ProjectContext::new(vec![], vec![]));
 421        self.register_project_with_initial_context(project.clone(), project_context, cx);
 422        if let Some(state) = self.projects.get_mut(&project_id) {
 423            state.project_context_needs_refresh.send(()).ok();
 424        }
 425        project_id
 426    }
 427
    /// Registers `project` with the agent, using `project_context` as its
    /// initial context.
    ///
    /// Creates a [`ContextServerRegistry`] over the project's context server
    /// store, subscribes to project, context-server-store, and registry events,
    /// and spawns the task that rebuilds the project context whenever a refresh
    /// is requested. Any existing state for the same project id is replaced.
    fn register_project_with_initial_context(
        &mut self,
        project: Entity<Project>,
        project_context: Entity<ProjectContext>,
        cx: &mut Context<Self>,
    ) {
        let project_id = project.entity_id();

        let context_server_store = project.read(cx).context_server_store();
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));

        let subscriptions = vec![
            cx.subscribe(&project, Self::handle_project_event),
            cx.subscribe(
                &context_server_store,
                Self::handle_context_server_store_updated,
            ),
            cx.subscribe(
                &context_server_registry,
                Self::handle_context_server_registry_event,
            ),
        ];

        // Refresh signal: senders live in `ProjectState`, and the receiver drives
        // `maintain_project_context`.
        let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
            watch::channel(());

        self.projects.insert(
            project_id,
            ProjectState {
                project,
                project_context,
                project_context_needs_refresh: project_context_needs_refresh_tx,
                _maintain_project_context: cx.spawn(async move |this, cx| {
                    Self::maintain_project_context(
                        this,
                        project_id,
                        project_context_needs_refresh_rx,
                        cx,
                    )
                    .await
                }),
                context_server_registry,
                _subscriptions: subscriptions,
            },
        );
    }
 475
 476    fn session_project_state(&self, session_id: &acp::SessionId) -> Option<&ProjectState> {
 477        self.sessions
 478            .get(session_id)
 479            .and_then(|session| self.projects.get(&session.project_id))
 480    }
 481
 482    async fn maintain_project_context(
 483        this: WeakEntity<Self>,
 484        project_id: EntityId,
 485        mut needs_refresh: watch::Receiver<()>,
 486        cx: &mut AsyncApp,
 487    ) -> Result<()> {
 488        while needs_refresh.changed().await.is_ok() {
 489            let project_context = this
 490                .update(cx, |this, cx| {
 491                    let state = this
 492                        .projects
 493                        .get(&project_id)
 494                        .context("project state not found")?;
 495                    anyhow::Ok(Self::build_project_context(
 496                        &state.project,
 497                        this.prompt_store.as_ref(),
 498                        cx,
 499                    ))
 500                })??
 501                .await;
 502            this.update(cx, |this, cx| {
 503                if let Some(state) = this.projects.get_mut(&project_id) {
 504                    state.project_context = cx.new(|_| project_context);
 505                }
 506            })?;
 507        }
 508
 509        Ok(())
 510    }
 511
 512    fn build_project_context(
 513        project: &Entity<Project>,
 514        prompt_store: Option<&Entity<PromptStore>>,
 515        cx: &mut App,
 516    ) -> Task<ProjectContext> {
 517        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 518        let worktree_tasks = worktrees
 519            .into_iter()
 520            .map(|worktree| {
 521                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 522            })
 523            .collect::<Vec<_>>();
 524        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 525            prompt_store.read_with(cx, |prompt_store, cx| {
 526                let prompts = prompt_store.default_prompt_metadata();
 527                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 528                    let contents = prompt_store.load(prompt_metadata.id, cx);
 529                    async move { (contents.await, prompt_metadata) }
 530                });
 531                cx.background_spawn(future::join_all(load_tasks))
 532            })
 533        } else {
 534            Task::ready(vec![])
 535        };
 536
 537        cx.spawn(async move |_cx| {
 538            let (worktrees, default_user_rules) =
 539                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 540
 541            let worktrees = worktrees
 542                .into_iter()
 543                .map(|(worktree, _rules_error)| {
 544                    // TODO: show error message
 545                    // if let Some(rules_error) = rules_error {
 546                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 547                    // }
 548                    worktree
 549                })
 550                .collect::<Vec<_>>();
 551
 552            let default_user_rules = default_user_rules
 553                .into_iter()
 554                .flat_map(|(contents, prompt_metadata)| match contents {
 555                    Ok(contents) => Some(UserRulesContext {
 556                        uuid: prompt_metadata.id.as_user()?,
 557                        title: prompt_metadata.title.map(|title| title.to_string()),
 558                        contents,
 559                    }),
 560                    Err(_err) => {
 561                        // TODO: show error message
 562                        // this.update(cx, |_, cx| {
 563                        //     cx.emit(RulesLoadingError {
 564                        //         message: format!("{err:?}").into(),
 565                        //     });
 566                        // })
 567                        // .ok();
 568                        None
 569                    }
 570                })
 571                .collect::<Vec<_>>();
 572
 573            ProjectContext::new(worktrees, default_user_rules)
 574        })
 575    }
 576
 577    fn load_worktree_info_for_system_prompt(
 578        worktree: Entity<Worktree>,
 579        project: Entity<Project>,
 580        cx: &mut App,
 581    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 582        let tree = worktree.read(cx);
 583        let root_name = tree.root_name_str().into();
 584        let abs_path = tree.abs_path();
 585
 586        let mut context = WorktreeContext {
 587            root_name,
 588            abs_path,
 589            rules_file: None,
 590        };
 591
 592        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 593        let Some(rules_task) = rules_task else {
 594            return Task::ready((context, None));
 595        };
 596
 597        cx.spawn(async move |_| {
 598            let (rules_file, rules_file_error) = match rules_task.await {
 599                Ok(rules_file) => (Some(rules_file), None),
 600                Err(err) => (
 601                    None,
 602                    Some(RulesLoadingError {
 603                        message: format!("{err}").into(),
 604                    }),
 605                ),
 606            };
 607            context.rules_file = rules_file;
 608            (context, rules_file_error)
 609        })
 610    }
 611
 612    fn load_worktree_rules_file(
 613        worktree: Entity<Worktree>,
 614        project: Entity<Project>,
 615        cx: &mut App,
 616    ) -> Option<Task<Result<RulesFileContext>>> {
 617        let worktree = worktree.read(cx);
 618        let worktree_id = worktree.id();
 619        let selected_rules_file = RULES_FILE_NAMES
 620            .into_iter()
 621            .filter_map(|name| {
 622                worktree
 623                    .entry_for_path(RelPath::unix(name).unwrap())
 624                    .filter(|entry| entry.is_file())
 625                    .map(|entry| entry.path.clone())
 626            })
 627            .next();
 628
 629        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 630        // supported. This doesn't seem to occur often in GitHub repositories.
 631        selected_rules_file.map(|path_in_worktree| {
 632            let project_path = ProjectPath {
 633                worktree_id,
 634                path: path_in_worktree.clone(),
 635            };
 636            let buffer_task =
 637                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 638            let rope_task = cx.spawn(async move |cx| {
 639                let buffer = buffer_task.await?;
 640                let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
 641                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 642                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 643                })?;
 644                anyhow::Ok((project_entry_id, rope))
 645            });
 646            // Build a string from the rope on a background thread.
 647            cx.background_spawn(async move {
 648                let (project_entry_id, rope) = rope_task.await?;
 649                anyhow::Ok(RulesFileContext {
 650                    path_in_worktree,
 651                    text: rope.to_string().trim().to_string(),
 652                    project_entry_id: project_entry_id.to_usize(),
 653                })
 654            })
 655        })
 656    }
 657
 658    fn handle_thread_title_updated(
 659        &mut self,
 660        thread: Entity<Thread>,
 661        _: &TitleUpdated,
 662        cx: &mut Context<Self>,
 663    ) {
 664        let session_id = thread.read(cx).id();
 665        let Some(session) = self.sessions.get(session_id) else {
 666            return;
 667        };
 668        let thread = thread.downgrade();
 669        let acp_thread = session.acp_thread.downgrade();
 670        cx.spawn(async move |_, cx| {
 671            let title = thread.read_with(cx, |thread, _| thread.title())?;
 672            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 673            task.await
 674        })
 675        .detach_and_log_err(cx);
 676    }
 677
 678    fn handle_thread_token_usage_updated(
 679        &mut self,
 680        thread: Entity<Thread>,
 681        usage: &TokenUsageUpdated,
 682        cx: &mut Context<Self>,
 683    ) {
 684        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 685            return;
 686        };
 687        session.acp_thread.update(cx, |acp_thread, cx| {
 688            acp_thread.update_token_usage(usage.0.clone(), cx);
 689        });
 690    }
 691
 692    fn handle_project_event(
 693        &mut self,
 694        project: Entity<Project>,
 695        event: &project::Event,
 696        _cx: &mut Context<Self>,
 697    ) {
 698        let project_id = project.entity_id();
 699        let Some(state) = self.projects.get_mut(&project_id) else {
 700            return;
 701        };
 702        match event {
 703            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 704                state.project_context_needs_refresh.send(()).ok();
 705            }
 706            project::Event::WorktreeUpdatedEntries(_, items) => {
 707                if items.iter().any(|(path, _, _)| {
 708                    RULES_FILE_NAMES
 709                        .iter()
 710                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 711                }) {
 712                    state.project_context_needs_refresh.send(()).ok();
 713                }
 714            }
 715            _ => {}
 716        }
 717    }
 718
 719    fn handle_prompts_updated_event(
 720        &mut self,
 721        _prompt_store: Entity<PromptStore>,
 722        _event: &prompt_store::PromptsUpdatedEvent,
 723        _cx: &mut Context<Self>,
 724    ) {
 725        for state in self.projects.values_mut() {
 726            state.project_context_needs_refresh.send(()).ok();
 727        }
 728    }
 729
 730    fn handle_models_updated_event(
 731        &mut self,
 732        _registry: Entity<LanguageModelRegistry>,
 733        _event: &language_model::Event,
 734        cx: &mut Context<Self>,
 735    ) {
 736        self.models.refresh_list(cx);
 737
 738        let registry = LanguageModelRegistry::read_global(cx);
 739        let default_model = registry.default_model().map(|m| m.model);
 740        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 741
 742        for session in self.sessions.values_mut() {
 743            session.thread.update(cx, |thread, cx| {
 744                if thread.model().is_none()
 745                    && let Some(model) = default_model.clone()
 746                {
 747                    thread.set_model(model, cx);
 748                    cx.notify();
 749                }
 750                thread.set_summarization_model(summarization_model.clone(), cx);
 751            });
 752        }
 753    }
 754
 755    fn handle_context_server_store_updated(
 756        &mut self,
 757        store: Entity<project::context_server_store::ContextServerStore>,
 758        _event: &project::context_server_store::ServerStatusChangedEvent,
 759        cx: &mut Context<Self>,
 760    ) {
 761        let project_id = self.projects.iter().find_map(|(id, state)| {
 762            if *state.context_server_registry.read(cx).server_store() == store {
 763                Some(*id)
 764            } else {
 765                None
 766            }
 767        });
 768        if let Some(project_id) = project_id {
 769            self.update_available_commands_for_project(project_id, cx);
 770        }
 771    }
 772
 773    fn handle_context_server_registry_event(
 774        &mut self,
 775        registry: Entity<ContextServerRegistry>,
 776        event: &ContextServerRegistryEvent,
 777        cx: &mut Context<Self>,
 778    ) {
 779        match event {
 780            ContextServerRegistryEvent::ToolsChanged => {}
 781            ContextServerRegistryEvent::PromptsChanged => {
 782                let project_id = self.projects.iter().find_map(|(id, state)| {
 783                    if state.context_server_registry == registry {
 784                        Some(*id)
 785                    } else {
 786                        None
 787                    }
 788                });
 789                if let Some(project_id) = project_id {
 790                    self.update_available_commands_for_project(project_id, cx);
 791                }
 792            }
 793        }
 794    }
 795
 796    fn update_available_commands_for_project(&self, project_id: EntityId, cx: &mut Context<Self>) {
 797        let available_commands =
 798            Self::build_available_commands_for_project(self.projects.get(&project_id), cx);
 799        for session in self.sessions.values() {
 800            if session.project_id != project_id {
 801                continue;
 802            }
 803            session.acp_thread.update(cx, |thread, cx| {
 804                thread
 805                    .handle_session_update(
 806                        acp::SessionUpdate::AvailableCommandsUpdate(
 807                            acp::AvailableCommandsUpdate::new(available_commands.clone()),
 808                        ),
 809                        cx,
 810                    )
 811                    .log_err();
 812            });
 813        }
 814    }
 815
 816    fn build_available_commands_for_project(
 817        project_state: Option<&ProjectState>,
 818        cx: &App,
 819    ) -> Vec<acp::AvailableCommand> {
 820        let Some(state) = project_state else {
 821            return vec![];
 822        };
 823        let registry = state.context_server_registry.read(cx);
 824
 825        let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
 826        for context_server_prompt in registry.prompts() {
 827            *prompt_name_counts
 828                .entry(context_server_prompt.prompt.name.as_str())
 829                .or_insert(0) += 1;
 830        }
 831
 832        registry
 833            .prompts()
 834            .flat_map(|context_server_prompt| {
 835                let prompt = &context_server_prompt.prompt;
 836
 837                let should_prefix = prompt_name_counts
 838                    .get(prompt.name.as_str())
 839                    .copied()
 840                    .unwrap_or(0)
 841                    > 1;
 842
 843                let name = if should_prefix {
 844                    format!("{}.{}", context_server_prompt.server_id, prompt.name)
 845                } else {
 846                    prompt.name.clone()
 847                };
 848
 849                let mut command = acp::AvailableCommand::new(
 850                    name,
 851                    prompt.description.clone().unwrap_or_default(),
 852                );
 853
 854                match prompt.arguments.as_deref() {
 855                    Some([arg]) => {
 856                        let hint = format!("<{}>", arg.name);
 857
 858                        command = command.input(acp::AvailableCommandInput::Unstructured(
 859                            acp::UnstructuredCommandInput::new(hint),
 860                        ));
 861                    }
 862                    Some([]) | None => {}
 863                    Some(_) => {
 864                        // skip >1 argument commands since we don't support them yet
 865                        return None;
 866                    }
 867                }
 868
 869                Some(command)
 870            })
 871            .collect()
 872    }
 873
 874    pub fn load_thread(
 875        &mut self,
 876        id: acp::SessionId,
 877        project: Entity<Project>,
 878        cx: &mut Context<Self>,
 879    ) -> Task<Result<Entity<Thread>>> {
 880        let project_id = self.get_or_create_project_state(&project, cx);
 881        let database_future = ThreadsDatabase::connect(cx);
 882        cx.spawn(async move |this, cx| {
 883            let database = database_future.await.map_err(|err| anyhow!(err))?;
 884            let db_thread = database
 885                .load_thread(id.clone())
 886                .await?
 887                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 888
 889            this.update(cx, |this, cx| {
 890                let project_state = this
 891                    .projects
 892                    .get(&project_id)
 893                    .context("project state not found")?;
 894                let summarization_model = LanguageModelRegistry::read_global(cx)
 895                    .thread_summary_model()
 896                    .map(|c| c.model);
 897
 898                Ok(cx.new(|cx| {
 899                    let mut thread = Thread::from_db(
 900                        id.clone(),
 901                        db_thread,
 902                        project_state.project.clone(),
 903                        project_state.project_context.clone(),
 904                        project_state.context_server_registry.clone(),
 905                        this.templates.clone(),
 906                        cx,
 907                    );
 908                    thread.set_summarization_model(summarization_model, cx);
 909                    thread
 910                }))
 911            })?
 912        })
 913    }
 914
 915    pub fn open_thread(
 916        &mut self,
 917        id: acp::SessionId,
 918        project: Entity<Project>,
 919        cx: &mut Context<Self>,
 920    ) -> Task<Result<Entity<AcpThread>>> {
 921        if let Some(session) = self.sessions.get(&id) {
 922            return Task::ready(Ok(session.acp_thread.clone()));
 923        }
 924
 925        let project_id = self.get_or_create_project_state(&project, cx);
 926        let task = self.load_thread(id, project, cx);
 927        cx.spawn(async move |this, cx| {
 928            let thread = task.await?;
 929            let acp_thread = this.update(cx, |this, cx| {
 930                this.register_session(thread.clone(), project_id, cx)
 931            })?;
 932            let events = thread.update(cx, |thread, cx| thread.replay(cx));
 933            cx.update(|cx| {
 934                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 935            })
 936            .await?;
 937            Ok(acp_thread)
 938        })
 939    }
 940
 941    pub fn thread_summary(
 942        &mut self,
 943        id: acp::SessionId,
 944        project: Entity<Project>,
 945        cx: &mut Context<Self>,
 946    ) -> Task<Result<SharedString>> {
 947        let thread = self.open_thread(id.clone(), project, cx);
 948        cx.spawn(async move |this, cx| {
 949            let acp_thread = thread.await?;
 950            let result = this
 951                .update(cx, |this, cx| {
 952                    this.sessions
 953                        .get(&id)
 954                        .unwrap()
 955                        .thread
 956                        .update(cx, |thread, cx| thread.summary(cx))
 957                })?
 958                .await
 959                .context("Failed to generate summary")?;
 960            drop(acp_thread);
 961            Ok(result)
 962        })
 963    }
 964
 965    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 966        if thread.read(cx).is_empty() {
 967            return;
 968        }
 969
 970        let id = thread.read(cx).id().clone();
 971        let Some(session) = self.sessions.get_mut(&id) else {
 972            return;
 973        };
 974
 975        let project_id = session.project_id;
 976        let Some(state) = self.projects.get(&project_id) else {
 977            return;
 978        };
 979
 980        let folder_paths = PathList::new(
 981            &state
 982                .project
 983                .read(cx)
 984                .visible_worktrees(cx)
 985                .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
 986                .collect::<Vec<_>>(),
 987        );
 988
 989        let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
 990        let provisional_title = session.acp_thread.read(cx).provisional_title().cloned();
 991        let database_future = ThreadsDatabase::connect(cx);
 992        let db_thread = thread.update(cx, |thread, cx| {
 993            thread.set_draft_prompt(draft_prompt);
 994            thread.set_provisional_title(provisional_title);
 995            thread.to_db(cx)
 996        });
 997        let thread_store = self.thread_store.clone();
 998        session.pending_save = cx.spawn(async move |_, cx| {
 999            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
1000                return;
1001            };
1002            let db_thread = db_thread.await;
1003            database
1004                .save_thread(id, db_thread, folder_paths)
1005                .await
1006                .log_err();
1007            thread_store.update(cx, |store, cx| store.reload(cx));
1008        });
1009    }
1010
    /// Runs the MCP prompt `prompt_name` from context server `server_id` as a turn of
    /// session `session_id`.
    ///
    /// The prompt is fetched with `arguments`. The user's original blocks, except the first
    /// one (which carried the `/command`), are appended to the native thread under
    /// `message_id`. Each message the server returns is then mirrored into both the
    /// session's `AcpThread` and its native `Thread`, as a user or assistant block according
    /// to its role. The `AcpThread` side uses the `*_with_indent` variants with `true`,
    /// presumably so the prompt output renders nested under the command (TODO confirm).
    /// Finally the thread is driven with `send_existing` if the last message was from the
    /// user, or `resume` otherwise, and the resulting events are forwarded to the
    /// `AcpThread`.
    ///
    /// # Errors
    /// Fails if the session's project state or the session itself is missing, if fetching
    /// the prompt fails, or if starting or running the turn fails.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: acp::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let Some(state) = self.session_project_state(&session_id) else {
            return Task::ready(Err(anyhow!("Project state not found for session")));
        };
        let server_store = state
            .context_server_registry
            .read(cx)
            .server_store()
            .clone();
        let path_style = state.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Role of the last prompt message; stays `true` when the prompt returned no
            // messages, so the thread is sent rather than resumed.
            let mut last_is_user = true;

            // Skip the first block: it is the `/command` text that selected this prompt.
            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                // Each message is written to both the UI-facing `AcpThread` and the native
                // `Thread` so the two stay in sync.
                match role {
                    context_server::types::Role::User => {
                        // Each user message from the prompt gets its own fresh id.
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
1113}
1114
/// Cloneable handle to the [`NativeAgent`] entity.
///
/// Implements [`acp_thread::AgentConnection`] and [`acp_thread::AgentTelemetry`] for the
/// in-process (native) agent. The only state is the wrapped entity, so cloning is cheap.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
1118
1119impl NativeAgentConnection {
1120    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
1121        self.0
1122            .read(cx)
1123            .sessions
1124            .get(session_id)
1125            .map(|session| session.thread.clone())
1126    }
1127
1128    pub fn load_thread(
1129        &self,
1130        id: acp::SessionId,
1131        project: Entity<Project>,
1132        cx: &mut App,
1133    ) -> Task<Result<Entity<Thread>>> {
1134        self.0
1135            .update(cx, |this, cx| this.load_thread(id, project, cx))
1136    }
1137
1138    fn run_turn(
1139        &self,
1140        session_id: acp::SessionId,
1141        cx: &mut App,
1142        f: impl 'static
1143        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1144    ) -> Task<Result<acp::PromptResponse>> {
1145        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1146            agent
1147                .sessions
1148                .get_mut(&session_id)
1149                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1150        }) else {
1151            return Task::ready(Err(anyhow!("Session not found")));
1152        };
1153        log::debug!("Found session for: {}", session_id);
1154
1155        let response_stream = match f(thread, cx) {
1156            Ok(stream) => stream,
1157            Err(err) => return Task::ready(Err(err)),
1158        };
1159        Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1160    }
1161
1162    fn handle_thread_events(
1163        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1164        acp_thread: WeakEntity<AcpThread>,
1165        cx: &App,
1166    ) -> Task<Result<acp::PromptResponse>> {
1167        cx.spawn(async move |cx| {
1168            // Handle response stream and forward to session.acp_thread
1169            while let Some(result) = events.next().await {
1170                match result {
1171                    Ok(event) => {
1172                        log::trace!("Received completion event: {:?}", event);
1173
1174                        match event {
1175                            ThreadEvent::UserMessage(message) => {
1176                                acp_thread.update(cx, |thread, cx| {
1177                                    for content in message.content {
1178                                        thread.push_user_content_block(
1179                                            Some(message.id.clone()),
1180                                            content.into(),
1181                                            cx,
1182                                        );
1183                                    }
1184                                })?;
1185                            }
1186                            ThreadEvent::AgentText(text) => {
1187                                acp_thread.update(cx, |thread, cx| {
1188                                    thread.push_assistant_content_block(text.into(), false, cx)
1189                                })?;
1190                            }
1191                            ThreadEvent::AgentThinking(text) => {
1192                                acp_thread.update(cx, |thread, cx| {
1193                                    thread.push_assistant_content_block(text.into(), true, cx)
1194                                })?;
1195                            }
1196                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1197                                tool_call,
1198                                options,
1199                                response,
1200                                context: _,
1201                            }) => {
1202                                let outcome_task = acp_thread.update(cx, |thread, cx| {
1203                                    thread.request_tool_call_authorization(tool_call, options, cx)
1204                                })??;
1205                                cx.background_spawn(async move {
1206                                    if let acp::RequestPermissionOutcome::Selected(
1207                                        acp::SelectedPermissionOutcome { option_id, .. },
1208                                    ) = outcome_task.await
1209                                    {
1210                                        response
1211                                            .send(option_id)
1212                                            .map(|_| anyhow!("authorization receiver was dropped"))
1213                                            .log_err();
1214                                    }
1215                                })
1216                                .detach();
1217                            }
1218                            ThreadEvent::ToolCall(tool_call) => {
1219                                acp_thread.update(cx, |thread, cx| {
1220                                    thread.upsert_tool_call(tool_call, cx)
1221                                })??;
1222                            }
1223                            ThreadEvent::ToolCallUpdate(update) => {
1224                                acp_thread.update(cx, |thread, cx| {
1225                                    thread.update_tool_call(update, cx)
1226                                })??;
1227                            }
1228                            ThreadEvent::SubagentSpawned(session_id) => {
1229                                acp_thread.update(cx, |thread, cx| {
1230                                    thread.subagent_spawned(session_id, cx);
1231                                })?;
1232                            }
1233                            ThreadEvent::Retry(status) => {
1234                                acp_thread.update(cx, |thread, cx| {
1235                                    thread.update_retry_status(status, cx)
1236                                })?;
1237                            }
1238                            ThreadEvent::Stop(stop_reason) => {
1239                                log::debug!("Assistant message complete: {:?}", stop_reason);
1240                                return Ok(acp::PromptResponse::new(stop_reason));
1241                            }
1242                        }
1243                    }
1244                    Err(e) => {
1245                        log::error!("Error in model response stream: {:?}", e);
1246                        return Err(e);
1247                    }
1248                }
1249            }
1250
1251            log::debug!("Response stream completed");
1252            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1253        })
1254    }
1255}
1256
/// A `/command` parsed from the first block of a prompt by `Command::parse`.
struct Command<'a> {
    /// Prompt name; the part after the first `.` when a server id is given.
    prompt_name: &'a str,
    /// Text after the first whitespace character following the command word; may be empty.
    arg_value: &'a str,
    /// Server id from a `server.prompt` command word, if one was given.
    explicit_server_id: Option<&'a str>,
}
1262
1263impl<'a> Command<'a> {
1264    fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1265        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1266            return None;
1267        };
1268        let text = text_content.text.trim();
1269        let command = text.strip_prefix('/')?;
1270        let (command, arg_value) = command
1271            .split_once(char::is_whitespace)
1272            .unwrap_or((command, ""));
1273
1274        if let Some((server_id, prompt_name)) = command.split_once('.') {
1275            Some(Self {
1276                prompt_name,
1277                arg_value,
1278                explicit_server_id: Some(server_id),
1279            })
1280        } else {
1281            Some(Self {
1282                prompt_name: command,
1283                arg_value,
1284                explicit_server_id: None,
1285            })
1286        }
1287    }
1288}
1289
/// Per-session model picker backed by the native agent's model list.
struct NativeAgentModelSelector {
    /// Session whose thread's model this selector reads and changes.
    session_id: acp::SessionId,
    /// Connection used to reach the agent's sessions, models, and `fs`.
    connection: NativeAgentConnection,
}
1294
1295impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1296    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1297        log::debug!("NativeAgentConnection::list_models called");
1298        let list = self.connection.0.read(cx).models.model_list.clone();
1299        Task::ready(if list.is_empty() {
1300            Err(anyhow::anyhow!("No models available"))
1301        } else {
1302            Ok(list)
1303        })
1304    }
1305
1306    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1307        log::debug!(
1308            "Setting model for session {}: {}",
1309            self.session_id,
1310            model_id
1311        );
1312        let Some(thread) = self
1313            .connection
1314            .0
1315            .read(cx)
1316            .sessions
1317            .get(&self.session_id)
1318            .map(|session| session.thread.clone())
1319        else {
1320            return Task::ready(Err(anyhow!("Session not found")));
1321        };
1322
1323        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1324            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1325        };
1326
1327        // We want to reset the effort level when switching models, as the currently-selected effort level may
1328        // not be compatible.
1329        let effort = model
1330            .default_effort_level()
1331            .map(|effort_level| effort_level.value.to_string());
1332
1333        thread.update(cx, |thread, cx| {
1334            thread.set_model(model.clone(), cx);
1335            thread.set_thinking_effort(effort.clone(), cx);
1336            thread.set_thinking_enabled(model.supports_thinking(), cx);
1337        });
1338
1339        update_settings_file(
1340            self.connection.0.read(cx).fs.clone(),
1341            cx,
1342            move |settings, cx| {
1343                let provider = model.provider_id().0.to_string();
1344                let model = model.id().0.to_string();
1345                let enable_thinking = thread.read(cx).thinking_enabled();
1346                settings
1347                    .agent
1348                    .get_or_insert_default()
1349                    .set_model(LanguageModelSelection {
1350                        provider: provider.into(),
1351                        model,
1352                        enable_thinking,
1353                        effort,
1354                    });
1355            },
1356        );
1357
1358        Task::ready(Ok(()))
1359    }
1360
1361    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1362        let Some(thread) = self
1363            .connection
1364            .0
1365            .read(cx)
1366            .sessions
1367            .get(&self.session_id)
1368            .map(|session| session.thread.clone())
1369        else {
1370            return Task::ready(Err(anyhow!("Session not found")));
1371        };
1372        let Some(model) = thread.read(cx).model() else {
1373            return Task::ready(Err(anyhow!("Model not found")));
1374        };
1375        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1376        else {
1377            return Task::ready(Err(anyhow!("Provider not found")));
1378        };
1379        Task::ready(Ok(LanguageModels::map_language_model_to_info(
1380            model, &provider,
1381        )))
1382    }
1383
1384    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1385        Some(self.connection.0.read(cx).models.watch())
1386    }
1387
1388    fn should_render_footer(&self) -> bool {
1389        true
1390    }
1391}
1392
1393impl acp_thread::AgentConnection for NativeAgentConnection {
1394    fn telemetry_id(&self) -> SharedString {
1395        "zed".into()
1396    }
1397
1398    fn new_session(
1399        self: Rc<Self>,
1400        project: Entity<Project>,
1401        cwd: &Path,
1402        cx: &mut App,
1403    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1404        log::debug!("Creating new thread for project at: {cwd:?}");
1405        Task::ready(Ok(self
1406            .0
1407            .update(cx, |agent, cx| agent.new_session(project, cx))))
1408    }
1409
1410    fn supports_load_session(&self) -> bool {
1411        true
1412    }
1413
1414    fn load_session(
1415        self: Rc<Self>,
1416        session_id: acp::SessionId,
1417        project: Entity<Project>,
1418        _cwd: &Path,
1419        _title: Option<SharedString>,
1420        cx: &mut App,
1421    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1422        self.0
1423            .update(cx, |agent, cx| agent.open_thread(session_id, project, cx))
1424    }
1425
1426    fn supports_close_session(&self) -> bool {
1427        true
1428    }
1429
1430    fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1431        self.0.update(cx, |agent, _cx| {
1432            let project_id = agent.sessions.get(session_id).map(|s| s.project_id);
1433            agent.sessions.remove(session_id);
1434
1435            if let Some(project_id) = project_id {
1436                let has_remaining = agent.sessions.values().any(|s| s.project_id == project_id);
1437                if !has_remaining {
1438                    agent.projects.remove(&project_id);
1439                }
1440            }
1441        });
1442        Task::ready(Ok(()))
1443    }
1444
1445    fn auth_methods(&self) -> &[acp::AuthMethod] {
1446        &[] // No auth for in-process
1447    }
1448
1449    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1450        Task::ready(Ok(()))
1451    }
1452
1453    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1454        Some(Rc::new(NativeAgentModelSelector {
1455            session_id: session_id.clone(),
1456            connection: self.clone(),
1457        }) as Rc<dyn AgentModelSelector>)
1458    }
1459
1460    fn prompt(
1461        &self,
1462        id: Option<acp_thread::UserMessageId>,
1463        params: acp::PromptRequest,
1464        cx: &mut App,
1465    ) -> Task<Result<acp::PromptResponse>> {
1466        let id = id.expect("UserMessageId is required");
1467        let session_id = params.session_id.clone();
1468        log::info!("Received prompt request for session: {}", session_id);
1469        log::debug!("Prompt blocks count: {}", params.prompt.len());
1470
1471        let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
1472            return Task::ready(Err(anyhow::anyhow!("Session not found")));
1473        };
1474
1475        if let Some(parsed_command) = Command::parse(&params.prompt) {
1476            let registry = project_state.context_server_registry.read(cx);
1477
1478            let explicit_server_id = parsed_command
1479                .explicit_server_id
1480                .map(|server_id| ContextServerId(server_id.into()));
1481
1482            if let Some(prompt) =
1483                registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1484            {
1485                let arguments = if !parsed_command.arg_value.is_empty()
1486                    && let Some(arg_name) = prompt
1487                        .prompt
1488                        .arguments
1489                        .as_ref()
1490                        .and_then(|args| args.first())
1491                        .map(|arg| arg.name.clone())
1492                {
1493                    HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1494                } else {
1495                    Default::default()
1496                };
1497
1498                let prompt_name = prompt.prompt.name.clone();
1499                let server_id = prompt.server_id.clone();
1500
1501                return self.0.update(cx, |agent, cx| {
1502                    agent.send_mcp_prompt(
1503                        id,
1504                        session_id.clone(),
1505                        prompt_name,
1506                        server_id,
1507                        arguments,
1508                        params.prompt,
1509                        cx,
1510                    )
1511                });
1512            }
1513        };
1514
1515        let path_style = project_state.project.read(cx).path_style(cx);
1516
1517        self.run_turn(session_id, cx, move |thread, cx| {
1518            let content: Vec<UserMessageContent> = params
1519                .prompt
1520                .into_iter()
1521                .map(|block| UserMessageContent::from_content_block(block, path_style))
1522                .collect::<Vec<_>>();
1523            log::debug!("Converted prompt to message: {} chars", content.len());
1524            log::debug!("Message id: {:?}", id);
1525            log::debug!("Message content: {:?}", content);
1526
1527            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1528        })
1529    }
1530
1531    fn retry(
1532        &self,
1533        session_id: &acp::SessionId,
1534        _cx: &App,
1535    ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1536        Some(Rc::new(NativeAgentSessionRetry {
1537            connection: self.clone(),
1538            session_id: session_id.clone(),
1539        }) as _)
1540    }
1541
1542    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1543        log::info!("Cancelling on session: {}", session_id);
1544        self.0.update(cx, |agent, cx| {
1545            if let Some(session) = agent.sessions.get(session_id) {
1546                session
1547                    .thread
1548                    .update(cx, |thread, cx| thread.cancel(cx))
1549                    .detach();
1550            }
1551        });
1552    }
1553
1554    fn truncate(
1555        &self,
1556        session_id: &acp::SessionId,
1557        cx: &App,
1558    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1559        self.0.read_with(cx, |agent, _cx| {
1560            agent.sessions.get(session_id).map(|session| {
1561                Rc::new(NativeAgentSessionTruncate {
1562                    thread: session.thread.clone(),
1563                    acp_thread: session.acp_thread.downgrade(),
1564                }) as _
1565            })
1566        })
1567    }
1568
1569    fn set_title(
1570        &self,
1571        session_id: &acp::SessionId,
1572        cx: &App,
1573    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1574        self.0.read_with(cx, |agent, _cx| {
1575            agent
1576                .sessions
1577                .get(session_id)
1578                .filter(|s| !s.thread.read(cx).is_subagent())
1579                .map(|session| {
1580                    Rc::new(NativeAgentSessionSetTitle {
1581                        thread: session.thread.clone(),
1582                    }) as _
1583                })
1584        })
1585    }
1586
1587    fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1588        let thread_store = self.0.read(cx).thread_store.clone();
1589        Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1590    }
1591
1592    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1593        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1594    }
1595
1596    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1597        self
1598    }
1599}
1600
1601impl acp_thread::AgentTelemetry for NativeAgentConnection {
1602    fn thread_data(
1603        &self,
1604        session_id: &acp::SessionId,
1605        cx: &mut App,
1606    ) -> Task<Result<serde_json::Value>> {
1607        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1608            return Task::ready(Err(anyhow!("Session not found")));
1609        };
1610
1611        let task = session.thread.read(cx).to_db(cx);
1612        cx.background_spawn(async move {
1613            serde_json::to_value(task.await).context("Failed to serialize thread")
1614        })
1615    }
1616}
1617
/// Session list for the native agent, backed by a [`ThreadStore`].
///
/// Change notifications from the store are forwarded as
/// [`acp_thread::SessionListUpdate::Refresh`] messages on an unbounded channel.
pub struct NativeAgentSessionList {
    thread_store: Entity<ThreadStore>,
    // Sending half; `notify_refresh` pushes manual refreshes through it.
    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
    // Receiving half; `watch` hands out clones of it.
    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
    // Keeps the store observation registered for as long as the list exists.
    _subscription: Subscription,
}
1624
1625impl NativeAgentSessionList {
1626    fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1627        let (tx, rx) = smol::channel::unbounded();
1628        let this_tx = tx.clone();
1629        let subscription = cx.observe(&thread_store, move |_, _| {
1630            this_tx
1631                .try_send(acp_thread::SessionListUpdate::Refresh)
1632                .ok();
1633        });
1634        Self {
1635            thread_store,
1636            updates_tx: tx,
1637            updates_rx: rx,
1638            _subscription: subscription,
1639        }
1640    }
1641
1642    pub fn thread_store(&self) -> &Entity<ThreadStore> {
1643        &self.thread_store
1644    }
1645}
1646
1647impl AgentSessionList for NativeAgentSessionList {
1648    fn list_sessions(
1649        &self,
1650        _request: AgentSessionListRequest,
1651        cx: &mut App,
1652    ) -> Task<Result<AgentSessionListResponse>> {
1653        let sessions = self
1654            .thread_store
1655            .read(cx)
1656            .entries()
1657            .map(|entry| AgentSessionInfo::from(&entry))
1658            .collect();
1659        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1660    }
1661
1662    fn supports_delete(&self) -> bool {
1663        true
1664    }
1665
1666    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1667        self.thread_store
1668            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1669    }
1670
1671    fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1672        self.thread_store
1673            .update(cx, |store, cx| store.delete_threads(cx))
1674    }
1675
1676    fn watch(
1677        &self,
1678        _cx: &mut App,
1679    ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1680        Some(self.updates_rx.clone())
1681    }
1682
1683    fn notify_refresh(&self) {
1684        self.updates_tx
1685            .try_send(acp_thread::SessionListUpdate::Refresh)
1686            .ok();
1687    }
1688
1689    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1690        self
1691    }
1692}
1693
1694struct NativeAgentSessionTruncate {
1695    thread: Entity<Thread>,
1696    acp_thread: WeakEntity<AcpThread>,
1697}
1698
1699impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1700    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1701        match self.thread.update(cx, |thread, cx| {
1702            thread.truncate(message_id.clone(), cx)?;
1703            Ok(thread.latest_token_usage())
1704        }) {
1705            Ok(usage) => {
1706                self.acp_thread
1707                    .update(cx, |thread, cx| {
1708                        thread.update_token_usage(usage, cx);
1709                    })
1710                    .ok();
1711                Task::ready(Ok(()))
1712            }
1713            Err(error) => Task::ready(Err(error)),
1714        }
1715    }
1716}
1717
1718struct NativeAgentSessionRetry {
1719    connection: NativeAgentConnection,
1720    session_id: acp::SessionId,
1721}
1722
1723impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1724    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1725        self.connection
1726            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1727                thread.update(cx, |thread, cx| thread.resume(cx))
1728            })
1729    }
1730}
1731
1732struct NativeAgentSessionSetTitle {
1733    thread: Entity<Thread>,
1734}
1735
1736impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1737    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1738        self.thread
1739            .update(cx, |thread, cx| thread.set_title(title, cx));
1740        Task::ready(Ok(()))
1741    }
1742}
1743
/// [`ThreadEnvironment`] for a native [`Thread`].
///
/// Terminals are created through the owning `AcpThread`, and subagent threads are registered
/// as sessions on the owning `NativeAgent`. All three references are weak.
pub struct NativeThreadEnvironment {
    agent: WeakEntity<NativeAgent>,
    // The thread this environment serves; it becomes the parent of any subagents created here.
    thread: WeakEntity<Thread>,
    acp_thread: WeakEntity<AcpThread>,
}
1749
1750impl NativeThreadEnvironment {
1751    pub(crate) fn create_subagent_thread(
1752        &self,
1753        label: String,
1754        cx: &mut App,
1755    ) -> Result<Rc<dyn SubagentHandle>> {
1756        let Some(parent_thread_entity) = self.thread.upgrade() else {
1757            anyhow::bail!("Parent thread no longer exists".to_string());
1758        };
1759        let parent_thread = parent_thread_entity.read(cx);
1760        let current_depth = parent_thread.depth();
1761        let parent_session_id = parent_thread.id().clone();
1762
1763        if current_depth >= MAX_SUBAGENT_DEPTH {
1764            return Err(anyhow!(
1765                "Maximum subagent depth ({}) reached",
1766                MAX_SUBAGENT_DEPTH
1767            ));
1768        }
1769
1770        let subagent_thread: Entity<Thread> = cx.new(|cx| {
1771            let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1772            thread.set_title(label.into(), cx);
1773            thread
1774        });
1775
1776        let session_id = subagent_thread.read(cx).id().clone();
1777
1778        let acp_thread = self
1779            .agent
1780            .update(cx, |agent, cx| -> Result<Entity<AcpThread>> {
1781                let project_id = agent
1782                    .sessions
1783                    .get(&parent_session_id)
1784                    .map(|s| s.project_id)
1785                    .context("parent session not found")?;
1786                Ok(agent.register_session(subagent_thread.clone(), project_id, cx))
1787            })??;
1788
1789        let depth = current_depth + 1;
1790
1791        telemetry::event!(
1792            "Subagent Started",
1793            session = parent_thread_entity.read(cx).id().to_string(),
1794            subagent_session = session_id.to_string(),
1795            depth,
1796            is_resumed = false,
1797        );
1798
1799        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1800    }
1801
1802    pub(crate) fn resume_subagent_thread(
1803        &self,
1804        session_id: acp::SessionId,
1805        cx: &mut App,
1806    ) -> Result<Rc<dyn SubagentHandle>> {
1807        let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1808            let session = agent
1809                .sessions
1810                .get(&session_id)
1811                .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1812            anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1813        })??;
1814
1815        let depth = subagent_thread.read(cx).depth();
1816
1817        if let Some(parent_thread_entity) = self.thread.upgrade() {
1818            telemetry::event!(
1819                "Subagent Started",
1820                session = parent_thread_entity.read(cx).id().to_string(),
1821                subagent_session = session_id.to_string(),
1822                depth,
1823                is_resumed = true,
1824            );
1825        }
1826
1827        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1828    }
1829
1830    fn prompt_subagent(
1831        &self,
1832        session_id: acp::SessionId,
1833        subagent_thread: Entity<Thread>,
1834        acp_thread: Entity<acp_thread::AcpThread>,
1835    ) -> Result<Rc<dyn SubagentHandle>> {
1836        let Some(parent_thread_entity) = self.thread.upgrade() else {
1837            anyhow::bail!("Parent thread no longer exists".to_string());
1838        };
1839        Ok(Rc::new(NativeSubagentHandle::new(
1840            session_id,
1841            subagent_thread,
1842            acp_thread,
1843            parent_thread_entity,
1844        )) as _)
1845    }
1846}
1847
1848impl ThreadEnvironment for NativeThreadEnvironment {
1849    fn create_terminal(
1850        &self,
1851        command: String,
1852        cwd: Option<PathBuf>,
1853        output_byte_limit: Option<u64>,
1854        cx: &mut AsyncApp,
1855    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1856        let task = self.acp_thread.update(cx, |thread, cx| {
1857            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1858        });
1859
1860        let acp_thread = self.acp_thread.clone();
1861        cx.spawn(async move |cx| {
1862            let terminal = task?.await?;
1863
1864            let (drop_tx, drop_rx) = oneshot::channel();
1865            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1866
1867            cx.spawn(async move |cx| {
1868                drop_rx.await.ok();
1869                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1870            })
1871            .detach();
1872
1873            let handle = AcpTerminalHandle {
1874                terminal,
1875                _drop_tx: Some(drop_tx),
1876            };
1877
1878            Ok(Rc::new(handle) as _)
1879        })
1880    }
1881
1882    fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1883        self.create_subagent_thread(label, cx)
1884    }
1885
1886    fn resume_subagent(
1887        &self,
1888        session_id: acp::SessionId,
1889        cx: &mut App,
1890    ) -> Result<Rc<dyn SubagentHandle>> {
1891        self.resume_subagent_thread(session_id, cx)
1892    }
1893}
1894
/// Outcome of a single prompt sent to a subagent, as computed in `NativeSubagentHandle::send`.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The turn ended with a stop reason that is not mapped to any other variant.
    Completed,
    /// The turn stopped with `StopReason::Cancelled`.
    Cancelled,
    /// Token usage crossed the context-window warning threshold during the turn.
    ContextWindowWarning,
    /// The turn failed; the message is surfaced to the caller.
    Error(String),
}
1902
/// [`SubagentHandle`] for a subagent that runs as a native [`Thread`].
pub struct NativeSubagentHandle {
    session_id: acp::SessionId,
    // Held weakly. `send` uses it to register the subagent as running on the parent and to
    // unregister it afterwards.
    parent_thread: WeakEntity<Thread>,
    subagent_thread: Entity<Thread>,
    acp_thread: Entity<acp_thread::AcpThread>,
}
1909
1910impl NativeSubagentHandle {
1911    fn new(
1912        session_id: acp::SessionId,
1913        subagent_thread: Entity<Thread>,
1914        acp_thread: Entity<acp_thread::AcpThread>,
1915        parent_thread_entity: Entity<Thread>,
1916    ) -> Self {
1917        NativeSubagentHandle {
1918            session_id,
1919            subagent_thread,
1920            parent_thread: parent_thread_entity.downgrade(),
1921            acp_thread,
1922        }
1923    }
1924}
1925
1926impl SubagentHandle for NativeSubagentHandle {
1927    fn id(&self) -> acp::SessionId {
1928        self.session_id.clone()
1929    }
1930
1931    fn num_entries(&self, cx: &App) -> usize {
1932        self.acp_thread.read(cx).entries().len()
1933    }
1934
1935    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
1936        let thread = self.subagent_thread.clone();
1937        let acp_thread = self.acp_thread.clone();
1938        let subagent_session_id = self.session_id.clone();
1939        let parent_thread = self.parent_thread.clone();
1940
1941        cx.spawn(async move |cx| {
1942            let (task, _subscription) = cx.update(|cx| {
1943                let ratio_before_prompt = thread
1944                    .read(cx)
1945                    .latest_token_usage()
1946                    .map(|usage| usage.ratio());
1947
1948                parent_thread
1949                    .update(cx, |parent_thread, _cx| {
1950                        parent_thread.register_running_subagent(thread.downgrade())
1951                    })
1952                    .ok();
1953
1954                let task = acp_thread.update(cx, |acp_thread, cx| {
1955                    acp_thread.send(vec![message.into()], cx)
1956                });
1957
1958                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
1959                let mut token_limit_tx = Some(token_limit_tx);
1960
1961                let subscription = cx.subscribe(
1962                    &thread,
1963                    move |_thread, event: &TokenUsageUpdated, _cx| {
1964                        if let Some(usage) = &event.0 {
1965                            let old_ratio = ratio_before_prompt
1966                                .clone()
1967                                .unwrap_or(TokenUsageRatio::Normal);
1968                            let new_ratio = usage.ratio();
1969                            if old_ratio == TokenUsageRatio::Normal
1970                                && new_ratio == TokenUsageRatio::Warning
1971                            {
1972                                if let Some(tx) = token_limit_tx.take() {
1973                                    tx.send(()).ok();
1974                                }
1975                            }
1976                        }
1977                    },
1978                );
1979
1980                let wait_for_prompt = cx
1981                    .background_spawn(async move {
1982                        futures::select! {
1983                            response = task.fuse() => match response {
1984                                Ok(Some(response)) => {
1985                                    match response.stop_reason {
1986                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
1987                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
1988                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
1989                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
1990                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
1991                                    }
1992                                }
1993                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
1994                                Err(error) => SubagentPromptResult::Error(error.to_string()),
1995                            },
1996                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
1997                        }
1998                    });
1999
2000                (wait_for_prompt, subscription)
2001            });
2002
2003            let result = match task.await {
2004                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
2005                    thread
2006                        .last_message()
2007                        .and_then(|message| {
2008                            let content = message.as_agent_message()?
2009                                .content
2010                                .iter()
2011                                .filter_map(|c| match c {
2012                                    AgentMessageContent::Text(text) => Some(text.as_str()),
2013                                    _ => None,
2014                                })
2015                                .join("\n\n");
2016                            if content.is_empty() {
2017                                None
2018                            } else {
2019                                Some( content)
2020                            }
2021                        })
2022                        .context("No response from subagent")
2023                }),
2024                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
2025                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
2026                SubagentPromptResult::ContextWindowWarning => {
2027                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2028                    Err(anyhow!(
2029                        "The agent is nearing the end of its context window and has been \
2030                         stopped. You can prompt the thread again to have the agent wrap up \
2031                         or hand off its work."
2032                    ))
2033                }
2034            };
2035
2036            parent_thread
2037                .update(cx, |parent_thread, cx| {
2038                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
2039                })
2040                .ok();
2041
2042            result
2043        })
2044    }
2045}
2046
/// [`TerminalHandle`] backed by a terminal owned by an ACP thread.
pub struct AcpTerminalHandle {
    terminal: Entity<acp_thread::Terminal>,
    // Dropping this sender, together with the handle, completes the receiver awaited by the
    // task spawned in `NativeThreadEnvironment::create_terminal`. That task then releases the
    // terminal from the ACP thread.
    _drop_tx: Option<oneshot::Sender<()>>,
}
2051
2052impl TerminalHandle for AcpTerminalHandle {
2053    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
2054        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
2055    }
2056
2057    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
2058        Ok(self
2059            .terminal
2060            .read_with(cx, |term, _cx| term.wait_for_exit()))
2061    }
2062
2063    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
2064        Ok(self
2065            .terminal
2066            .read_with(cx, |term, cx| term.current_output(cx)))
2067    }
2068
2069    fn kill(&self, cx: &AsyncApp) -> Result<()> {
2070        cx.update(|cx| {
2071            self.terminal.update(cx, |terminal, cx| {
2072                terminal.kill(cx);
2073            });
2074        });
2075        Ok(())
2076    }
2077
2078    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
2079        Ok(self
2080            .terminal
2081            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
2082    }
2083}
2084
2085#[cfg(test)]
2086mod internal_tests {
2087    use super::*;
2088    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
2089    use fs::FakeFs;
2090    use gpui::TestAppContext;
2091    use indoc::formatdoc;
2092    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2093    use language_model::{
2094        LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
2095    };
2096    use serde_json::json;
2097    use settings::SettingsStore;
2098    use util::{path, rel_path::rel_path};
2099
    // Verifies that the agent keeps a project's `ProjectContext` in sync:
    // - it starts with no worktrees;
    // - it picks up a worktree added after the session was created;
    // - it then picks up a `.rules` file created inside that worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));

        // Creating a session registers the project and triggers context building.
        let connection = NativeAgentConnection(agent.clone());
        let _acp_thread = cx
            .update(|cx| Rc::new(connection).new_session(project.clone(), Path::new("/"), cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // The project was created with no worktrees, so the context has none either.
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            assert_eq!(state.project_context.read(cx).worktrees, vec![])
        });

        // Adding worktree `/a` is reflected in the context, with no rules file yet.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            assert_eq!(
                state.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                state.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }
2172
    // Verifies that a fresh session's model selector returns a grouped list containing exactly
    // one group ("Fake") with the `fake/fake` model. The fake provider is presumably registered
    // by `init_test`; confirm there if this test changes.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let connection =
            NativeAgentConnection(cx.update(|cx| {
                NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)
            }));

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        // The native agent reports models grouped by provider.
        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(acp_thread::AgentModelIcon::Named(
                        ui::IconName::ZedAssistant
                    )),
                    is_latest: false,
                    cost: None,
                }]
            )])
        );
    }
2225
    // Verifies that selecting a model through the ACP model selector does three things:
    // - updates the thread's model;
    // - overwrites the `agent.default_model` entry in the settings file (which starts as
    //   `foo`/`bar`);
    // - persists `enable_thinking: true` when a thinking-capable model is selected.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));

        // Create the agent and connection
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let the settings-file write complete before reading the file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );

        // Register a thinking model and select it.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // The agent's model list must be refreshed before the new model can be selected.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // Verify enable_thinking was written to settings as true.
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
        assert_eq!(
            settings_json["agent"]["default_model"]["enable_thinking"],
            json!(true),
            "selecting a thinking model should persist enable_thinking: true to settings"
        );
    }
2333
    // Verifies that `select_model` toggles the thread's `thinking_enabled` flag. Selecting a
    // thinking-capable model turns it on; switching back to a non-thinking model turns it off.
    #[gpui::test]
    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
        let connection = NativeAgentConnection(agent.clone());

        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Register a second provider with a thinking model.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so it picks up the new provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Thread starts with thinking_enabled = false (the default).
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(!thread.thinking_enabled(), "thinking defaults to false");
            });
        });

        // Select the thinking model via select_model.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // select_model should have enabled thinking based on the model's supports_thinking().
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(
                    thread.thinking_enabled(),
                    "select_model should enable thinking when model supports it"
                );
            });
        });

        // Switch back to the non-thinking model.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
            .await
            .unwrap();

        // select_model should have disabled thinking.
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(
                    !thread.thinking_enabled(),
                    "select_model should disable thinking when model does not support it"
                );
            });
        });
    }
2421
    // Regression test: a thread whose selected model supports thinking must still
    // report `thinking_enabled() == true` after it has been persisted, closed, and
    // reopened through `NativeAgent::open_thread`.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model.
        // The model handle is kept (cloned into the provider) so the test can drive
        // its completion stream later on.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so it sees the newly registered provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future is spawned so it keeps running while the test feeds the
        // fake model's completion stream below.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the old thread entities.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        drop(reloaded_acp_thread);
    }
2522
    // Regression test: a persisted thread must reload with the same model it was
    // saved with, even when the model's `id()` differs from its display `name()`,
    // rather than falling back to the default model.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Refresh the agent's model list so it sees the newly registered provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // The selector is addressed by "<provider>/<model id>", not by display name.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        // Sanity check: the selection took effect before anything is saved.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future is spawned so it keeps running while the test feeds the
        // fake model's completion stream below.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the old thread entities.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        drop(reloaded_acp_thread);
    }
2628
    // End-to-end persistence test. It checks that an empty thread is not saved (even
    // after being mutated), and that after one turn the following all survive closing
    // and reopening the session:
    // - the title produced by the summarization model
    // - the markdown transcript
    // - a rich draft prompt
    // - the token usage
    // - the UI scroll position
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        // Separate fake models are used for the main completion and for the
        // summarization (title) request, so each stream can be driven independently.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // The user message mixes plain text with a file mention, so the transcript
        // checks below also cover how resource links are rendered and restored.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Stream the assistant reply plus a usage update, then the summary that
        // becomes the thread title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
            language_model::TokenUsage {
                input_tokens: 150,
                output_tokens: 75,
                ..Default::default()
            },
        ));
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Set a draft prompt with rich content blocks before saving.
        let draft_blocks = vec![
            acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
            acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
            acp::ContentBlock::Text(acp::TextContent::new(" please")),
        ];
        acp_thread.update(cx, |thread, _cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()));
        });
        thread.update(cx, |thread, _cx| {
            thread.set_ui_scroll_position(Some(gpui::ListOffset {
                item_ix: 5,
                offset_in_item: gpui::px(12.5),
            }));
        });
        // NOTE(review): this bare notify appears to be what prompts the draft and
        // scroll changes above to be persisted — confirm against the save trigger.
        thread.update(cx, |_thread, cx| cx.notify());
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the old thread entities.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        // The stored title is the text streamed by the summarization model.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // Ensure the draft prompt with rich content blocks survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
        });

        // Ensure token usage survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let usage = thread
                .token_usage()
                .expect("token usage should be restored after reload");
            assert_eq!(usage.input_tokens, 150);
            assert_eq!(usage.output_tokens, 75);
        });

        // Ensure scroll position survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let scroll = thread
                .ui_scroll_position()
                .expect("scroll position should be restored after reload");
            assert_eq!(scroll.item_ix, 5);
            assert_eq!(scroll.offset_in_item, gpui::px(12.5));
        });
    }
2810
2811    fn thread_entries(
2812        thread_store: &Entity<ThreadStore>,
2813        cx: &mut TestAppContext,
2814    ) -> Vec<(acp::SessionId, String)> {
2815        thread_store.read_with(cx, |store, _| {
2816            store
2817                .entries()
2818                .map(|entry| (entry.id.clone(), entry.title.to_string()))
2819                .collect::<Vec<_>>()
2820        })
2821    }
2822
2823    fn init_test(cx: &mut TestAppContext) {
2824        env_logger::try_init().ok();
2825        cx.update(|cx| {
2826            let settings_store = SettingsStore::test(cx);
2827            cx.set_global(settings_store);
2828
2829            LanguageModelRegistry::test(cx);
2830        });
2831    }
2832}
2833
2834fn mcp_message_content_to_acp_content_block(
2835    content: context_server::types::MessageContent,
2836) -> acp::ContentBlock {
2837    match content {
2838        context_server::types::MessageContent::Text {
2839            text,
2840            annotations: _,
2841        } => text.into(),
2842        context_server::types::MessageContent::Image {
2843            data,
2844            mime_type,
2845            annotations: _,
2846        } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2847        context_server::types::MessageContent::Audio {
2848            data,
2849            mime_type,
2850            annotations: _,
2851        } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2852        context_server::types::MessageContent::Resource {
2853            resource,
2854            annotations: _,
2855        } => {
2856            let mut link =
2857                acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2858            if let Some(mime_type) = resource.mime_type {
2859                link = link.mime_type(mime_type);
2860            }
2861            acp::ContentBlock::ResourceLink(link)
2862        }
2863    }
2864}