agent.rs

   1mod db;
   2mod edit_agent;
   3mod legacy_thread;
   4mod native_agent_server;
   5pub mod outline;
   6mod pattern_extraction;
   7mod templates;
   8#[cfg(test)]
   9mod tests;
  10mod thread;
  11mod thread_store;
  12mod tool_permissions;
  13mod tools;
  14
  15use context_server::ContextServerId;
  16pub use db::*;
  17use itertools::Itertools;
  18pub use native_agent_server::NativeAgentServer;
  19pub use pattern_extraction::*;
  20pub use shell_command_parser::extract_commands;
  21pub use templates::*;
  22pub use thread::*;
  23pub use thread_store::*;
  24pub use tool_permissions::*;
  25pub use tools::*;
  26
  27use acp_thread::{
  28    AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
  29    AgentSessionListResponse, TokenUsageRatio, UserMessageId,
  30};
  31use agent_client_protocol as acp;
  32use anyhow::{Context as _, Result, anyhow};
  33use chrono::{DateTime, Utc};
  34use collections::{HashMap, HashSet, IndexMap};
  35use fs::Fs;
  36use futures::channel::{mpsc, oneshot};
  37use futures::future::Shared;
  38use futures::{FutureExt as _, StreamExt as _, future};
  39use gpui::{
  40    App, AppContext, AsyncApp, Context, Entity, EntityId, SharedString, Subscription, Task,
  41    WeakEntity,
  42};
  43use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  44use project::{AgentId, Project, ProjectItem, ProjectPath, Worktree};
  45use prompt_store::{
  46    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  47    WorktreeContext,
  48};
  49use serde::{Deserialize, Serialize};
  50use settings::{LanguageModelSelection, update_settings_file};
  51use std::any::Any;
  52use std::path::PathBuf;
  53use std::rc::Rc;
  54use std::sync::{Arc, LazyLock};
  55use util::ResultExt;
  56use util::path_list::PathList;
  57use util::rel_path::RelPath;
  58
  59#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
  60pub struct ProjectSnapshot {
  61    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
  62    pub timestamp: DateTime<Utc>,
  63}
  64
  65pub struct RulesLoadingError {
  66    pub message: SharedString,
  67}
  68
  69struct ProjectState {
  70    project: Entity<Project>,
  71    project_context: Entity<ProjectContext>,
  72    project_context_needs_refresh: watch::Sender<()>,
  73    _maintain_project_context: Task<Result<()>>,
  74    context_server_registry: Entity<ContextServerRegistry>,
  75    _subscriptions: Vec<Subscription>,
  76}
  77
  78/// Holds both the internal Thread and the AcpThread for a session
  79struct Session {
  80    /// The internal thread that processes messages
  81    thread: Entity<Thread>,
  82    /// The ACP thread that handles protocol communication
  83    acp_thread: Entity<acp_thread::AcpThread>,
  84    project_id: EntityId,
  85    pending_save: Task<()>,
  86    _subscriptions: Vec<Subscription>,
  87}
  88
  89pub struct LanguageModels {
  90    /// Access language model by ID
  91    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
  92    /// Cached list for returning language model information
  93    model_list: acp_thread::AgentModelList,
  94    refresh_models_rx: watch::Receiver<()>,
  95    refresh_models_tx: watch::Sender<()>,
  96    _authenticate_all_providers_task: Task<()>,
  97}
  98
  99impl LanguageModels {
 100    fn new(cx: &mut App) -> Self {
 101        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
 102
 103        let mut this = Self {
 104            models: HashMap::default(),
 105            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
 106            refresh_models_rx,
 107            refresh_models_tx,
 108            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
 109        };
 110        this.refresh_list(cx);
 111        this
 112    }
 113
 114    fn refresh_list(&mut self, cx: &App) {
 115        let providers = LanguageModelRegistry::global(cx)
 116            .read(cx)
 117            .visible_providers()
 118            .into_iter()
 119            .filter(|provider| provider.is_authenticated(cx))
 120            .collect::<Vec<_>>();
 121
 122        let mut language_model_list = IndexMap::default();
 123        let mut recommended_models = HashSet::default();
 124
 125        let mut recommended = Vec::new();
 126        for provider in &providers {
 127            for model in provider.recommended_models(cx) {
 128                recommended_models.insert((model.provider_id(), model.id()));
 129                recommended.push(Self::map_language_model_to_info(&model, provider));
 130            }
 131        }
 132        if !recommended.is_empty() {
 133            language_model_list.insert(
 134                acp_thread::AgentModelGroupName("Recommended".into()),
 135                recommended,
 136            );
 137        }
 138
 139        let mut models = HashMap::default();
 140        for provider in providers {
 141            let mut provider_models = Vec::new();
 142            for model in provider.provided_models(cx) {
 143                let model_info = Self::map_language_model_to_info(&model, &provider);
 144                let model_id = model_info.id.clone();
 145                provider_models.push(model_info);
 146                models.insert(model_id, model);
 147            }
 148            if !provider_models.is_empty() {
 149                language_model_list.insert(
 150                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 151                    provider_models,
 152                );
 153            }
 154        }
 155
 156        self.models = models;
 157        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 158        self.refresh_models_tx.send(()).ok();
 159    }
 160
 161    fn watch(&self) -> watch::Receiver<()> {
 162        self.refresh_models_rx.clone()
 163    }
 164
 165    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 166        self.models.get(model_id).cloned()
 167    }
 168
 169    fn map_language_model_to_info(
 170        model: &Arc<dyn LanguageModel>,
 171        provider: &Arc<dyn LanguageModelProvider>,
 172    ) -> acp_thread::AgentModelInfo {
 173        acp_thread::AgentModelInfo {
 174            id: Self::model_id(model),
 175            name: model.name().0,
 176            description: None,
 177            icon: Some(match provider.icon() {
 178                IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
 179                IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
 180            }),
 181            is_latest: model.is_latest(),
 182            cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
 183        }
 184    }
 185
 186    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 187        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 188    }
 189
 190    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 191        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 192            .read(cx)
 193            .visible_providers()
 194            .iter()
 195            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 196            .collect::<Vec<_>>();
 197
 198        cx.background_spawn(async move {
 199            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 200                if let Err(err) = authenticate_task.await {
 201                    match err {
 202                        language_model::AuthenticateError::CredentialsNotFound => {
 203                            // Since we're authenticating these providers in the
 204                            // background for the purposes of populating the
 205                            // language selector, we don't care about providers
 206                            // where the credentials are not found.
 207                        }
 208                        language_model::AuthenticateError::ConnectionRefused => {
 209                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 210                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 211                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 212                        }
 213                        _ => {
 214                            // Some providers have noisy failure states that we
 215                            // don't want to spam the logs with every time the
 216                            // language model selector is initialized.
 217                            //
 218                            // Ideally these should have more clear failure modes
 219                            // that we know are safe to ignore here, like what we do
 220                            // with `CredentialsNotFound` above.
 221                            match provider_id.0.as_ref() {
 222                                "lmstudio" | "ollama" => {
 223                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 224                                    //
 225                                    // These fail noisily, so we don't log them.
 226                                }
 227                                "copilot_chat" => {
 228                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 229                                }
 230                                _ => {
 231                                    log::error!(
 232                                        "Failed to authenticate provider: {}: {err:#}",
 233                                        provider_name.0
 234                                    );
 235                                }
 236                            }
 237                        }
 238                    }
 239                }
 240            }
 241        })
 242    }
 243}
 244
 245pub struct NativeAgent {
 246    /// Session ID -> Session mapping
 247    sessions: HashMap<acp::SessionId, Session>,
 248    thread_store: Entity<ThreadStore>,
 249    /// Project-specific state keyed by project EntityId
 250    projects: HashMap<EntityId, ProjectState>,
 251    /// Shared templates for all threads
 252    templates: Arc<Templates>,
 253    /// Cached model information
 254    models: LanguageModels,
 255    prompt_store: Option<Entity<PromptStore>>,
 256    fs: Arc<dyn Fs>,
 257    _subscriptions: Vec<Subscription>,
 258}
 259
 260impl NativeAgent {
 261    pub fn new(
 262        thread_store: Entity<ThreadStore>,
 263        templates: Arc<Templates>,
 264        prompt_store: Option<Entity<PromptStore>>,
 265        fs: Arc<dyn Fs>,
 266        cx: &mut App,
 267    ) -> Entity<NativeAgent> {
 268        log::debug!("Creating new NativeAgent");
 269
 270        cx.new(|cx| {
 271            let mut subscriptions = vec![cx.subscribe(
 272                &LanguageModelRegistry::global(cx),
 273                Self::handle_models_updated_event,
 274            )];
 275            if let Some(prompt_store) = prompt_store.as_ref() {
 276                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 277            }
 278
 279            Self {
 280                sessions: HashMap::default(),
 281                thread_store,
 282                projects: HashMap::default(),
 283                templates,
 284                models: LanguageModels::new(cx),
 285                prompt_store,
 286                fs,
 287                _subscriptions: subscriptions,
 288            }
 289        })
 290    }
 291
 292    fn new_session(
 293        &mut self,
 294        project: Entity<Project>,
 295        cx: &mut Context<Self>,
 296    ) -> Entity<AcpThread> {
 297        let project_id = self.get_or_create_project_state(&project, cx);
 298        let project_state = &self.projects[&project_id];
 299
 300        let registry = LanguageModelRegistry::read_global(cx);
 301        let available_count = registry.available_models(cx).count();
 302        log::debug!("Total available models: {}", available_count);
 303
 304        let default_model = registry.default_model().and_then(|default_model| {
 305            self.models
 306                .model_from_id(&LanguageModels::model_id(&default_model.model))
 307        });
 308        let thread = cx.new(|cx| {
 309            Thread::new(
 310                project,
 311                project_state.project_context.clone(),
 312                project_state.context_server_registry.clone(),
 313                self.templates.clone(),
 314                default_model,
 315                cx,
 316            )
 317        });
 318
 319        self.register_session(thread, project_id, cx)
 320    }
 321
 322    fn register_session(
 323        &mut self,
 324        thread_handle: Entity<Thread>,
 325        project_id: EntityId,
 326        cx: &mut Context<Self>,
 327    ) -> Entity<AcpThread> {
 328        let connection = Rc::new(NativeAgentConnection(cx.entity()));
 329
 330        let thread = thread_handle.read(cx);
 331        let session_id = thread.id().clone();
 332        let parent_session_id = thread.parent_thread_id();
 333        let title = thread.title();
 334        let draft_prompt = thread.draft_prompt().map(Vec::from);
 335        let scroll_position = thread.ui_scroll_position();
 336        let token_usage = thread.latest_token_usage();
 337        let project = thread.project.clone();
 338        let action_log = thread.action_log.clone();
 339        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
 340        let acp_thread = cx.new(|cx| {
 341            let mut acp_thread = acp_thread::AcpThread::new(
 342                parent_session_id,
 343                title,
 344                None,
 345                connection,
 346                project.clone(),
 347                action_log.clone(),
 348                session_id.clone(),
 349                prompt_capabilities_rx,
 350                cx,
 351            );
 352            acp_thread.set_draft_prompt(draft_prompt);
 353            acp_thread.set_ui_scroll_position(scroll_position);
 354            acp_thread.update_token_usage(token_usage, cx);
 355            acp_thread
 356        });
 357
 358        let registry = LanguageModelRegistry::read_global(cx);
 359        let summarization_model = registry.thread_summary_model().map(|c| c.model);
 360
 361        let weak = cx.weak_entity();
 362        let weak_thread = thread_handle.downgrade();
 363        thread_handle.update(cx, |thread, cx| {
 364            thread.set_summarization_model(summarization_model, cx);
 365            thread.add_default_tools(
 366                Rc::new(NativeThreadEnvironment {
 367                    acp_thread: acp_thread.downgrade(),
 368                    thread: weak_thread,
 369                    agent: weak,
 370                }) as _,
 371                cx,
 372            )
 373        });
 374
 375        let subscriptions = vec![
 376            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
 377            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
 378            cx.observe(&thread_handle, move |this, thread, cx| {
 379                this.save_thread(thread, cx)
 380            }),
 381        ];
 382
 383        self.sessions.insert(
 384            session_id,
 385            Session {
 386                thread: thread_handle,
 387                acp_thread: acp_thread.clone(),
 388                project_id,
 389                _subscriptions: subscriptions,
 390                pending_save: Task::ready(()),
 391            },
 392        );
 393
 394        self.update_available_commands_for_project(project_id, cx);
 395
 396        acp_thread
 397    }
 398
 399    pub fn models(&self) -> &LanguageModels {
 400        &self.models
 401    }
 402
 403    fn get_or_create_project_state(
 404        &mut self,
 405        project: &Entity<Project>,
 406        cx: &mut Context<Self>,
 407    ) -> EntityId {
 408        let project_id = project.entity_id();
 409        if self.projects.contains_key(&project_id) {
 410            return project_id;
 411        }
 412
 413        let project_context = cx.new(|_| ProjectContext::new(vec![], vec![]));
 414        self.register_project_with_initial_context(project.clone(), project_context, cx);
 415        if let Some(state) = self.projects.get_mut(&project_id) {
 416            state.project_context_needs_refresh.send(()).ok();
 417        }
 418        project_id
 419    }
 420
 421    fn register_project_with_initial_context(
 422        &mut self,
 423        project: Entity<Project>,
 424        project_context: Entity<ProjectContext>,
 425        cx: &mut Context<Self>,
 426    ) {
 427        let project_id = project.entity_id();
 428
 429        let context_server_store = project.read(cx).context_server_store();
 430        let context_server_registry =
 431            cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
 432
 433        let subscriptions = vec![
 434            cx.subscribe(&project, Self::handle_project_event),
 435            cx.subscribe(
 436                &context_server_store,
 437                Self::handle_context_server_store_updated,
 438            ),
 439            cx.subscribe(
 440                &context_server_registry,
 441                Self::handle_context_server_registry_event,
 442            ),
 443        ];
 444
 445        let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 446            watch::channel(());
 447
 448        self.projects.insert(
 449            project_id,
 450            ProjectState {
 451                project,
 452                project_context,
 453                project_context_needs_refresh: project_context_needs_refresh_tx,
 454                _maintain_project_context: cx.spawn(async move |this, cx| {
 455                    Self::maintain_project_context(
 456                        this,
 457                        project_id,
 458                        project_context_needs_refresh_rx,
 459                        cx,
 460                    )
 461                    .await
 462                }),
 463                context_server_registry,
 464                _subscriptions: subscriptions,
 465            },
 466        );
 467    }
 468
 469    fn session_project_state(&self, session_id: &acp::SessionId) -> Option<&ProjectState> {
 470        self.sessions
 471            .get(session_id)
 472            .and_then(|session| self.projects.get(&session.project_id))
 473    }
 474
 475    async fn maintain_project_context(
 476        this: WeakEntity<Self>,
 477        project_id: EntityId,
 478        mut needs_refresh: watch::Receiver<()>,
 479        cx: &mut AsyncApp,
 480    ) -> Result<()> {
 481        while needs_refresh.changed().await.is_ok() {
 482            let project_context = this
 483                .update(cx, |this, cx| {
 484                    let state = this
 485                        .projects
 486                        .get(&project_id)
 487                        .context("project state not found")?;
 488                    anyhow::Ok(Self::build_project_context(
 489                        &state.project,
 490                        this.prompt_store.as_ref(),
 491                        cx,
 492                    ))
 493                })??
 494                .await;
 495            this.update(cx, |this, cx| {
 496                if let Some(state) = this.projects.get_mut(&project_id) {
 497                    state.project_context = cx.new(|_| project_context);
 498                }
 499            })?;
 500        }
 501
 502        Ok(())
 503    }
 504
 505    fn build_project_context(
 506        project: &Entity<Project>,
 507        prompt_store: Option<&Entity<PromptStore>>,
 508        cx: &mut App,
 509    ) -> Task<ProjectContext> {
 510        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 511        let worktree_tasks = worktrees
 512            .into_iter()
 513            .map(|worktree| {
 514                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 515            })
 516            .collect::<Vec<_>>();
 517        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 518            prompt_store.read_with(cx, |prompt_store, cx| {
 519                let prompts = prompt_store.default_prompt_metadata();
 520                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 521                    let contents = prompt_store.load(prompt_metadata.id, cx);
 522                    async move { (contents.await, prompt_metadata) }
 523                });
 524                cx.background_spawn(future::join_all(load_tasks))
 525            })
 526        } else {
 527            Task::ready(vec![])
 528        };
 529
 530        cx.spawn(async move |_cx| {
 531            let (worktrees, default_user_rules) =
 532                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 533
 534            let worktrees = worktrees
 535                .into_iter()
 536                .map(|(worktree, _rules_error)| {
 537                    // TODO: show error message
 538                    // if let Some(rules_error) = rules_error {
 539                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 540                    // }
 541                    worktree
 542                })
 543                .collect::<Vec<_>>();
 544
 545            let default_user_rules = default_user_rules
 546                .into_iter()
 547                .flat_map(|(contents, prompt_metadata)| match contents {
 548                    Ok(contents) => Some(UserRulesContext {
 549                        uuid: prompt_metadata.id.as_user()?,
 550                        title: prompt_metadata.title.map(|title| title.to_string()),
 551                        contents,
 552                    }),
 553                    Err(_err) => {
 554                        // TODO: show error message
 555                        // this.update(cx, |_, cx| {
 556                        //     cx.emit(RulesLoadingError {
 557                        //         message: format!("{err:?}").into(),
 558                        //     });
 559                        // })
 560                        // .ok();
 561                        None
 562                    }
 563                })
 564                .collect::<Vec<_>>();
 565
 566            ProjectContext::new(worktrees, default_user_rules)
 567        })
 568    }
 569
 570    fn load_worktree_info_for_system_prompt(
 571        worktree: Entity<Worktree>,
 572        project: Entity<Project>,
 573        cx: &mut App,
 574    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 575        let tree = worktree.read(cx);
 576        let root_name = tree.root_name_str().into();
 577        let abs_path = tree.abs_path();
 578
 579        let mut context = WorktreeContext {
 580            root_name,
 581            abs_path,
 582            rules_file: None,
 583        };
 584
 585        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 586        let Some(rules_task) = rules_task else {
 587            return Task::ready((context, None));
 588        };
 589
 590        cx.spawn(async move |_| {
 591            let (rules_file, rules_file_error) = match rules_task.await {
 592                Ok(rules_file) => (Some(rules_file), None),
 593                Err(err) => (
 594                    None,
 595                    Some(RulesLoadingError {
 596                        message: format!("{err}").into(),
 597                    }),
 598                ),
 599            };
 600            context.rules_file = rules_file;
 601            (context, rules_file_error)
 602        })
 603    }
 604
 605    fn load_worktree_rules_file(
 606        worktree: Entity<Worktree>,
 607        project: Entity<Project>,
 608        cx: &mut App,
 609    ) -> Option<Task<Result<RulesFileContext>>> {
 610        let worktree = worktree.read(cx);
 611        let worktree_id = worktree.id();
 612        let selected_rules_file = RULES_FILE_NAMES
 613            .into_iter()
 614            .filter_map(|name| {
 615                worktree
 616                    .entry_for_path(RelPath::unix(name).unwrap())
 617                    .filter(|entry| entry.is_file())
 618                    .map(|entry| entry.path.clone())
 619            })
 620            .next();
 621
 622        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 623        // supported. This doesn't seem to occur often in GitHub repositories.
 624        selected_rules_file.map(|path_in_worktree| {
 625            let project_path = ProjectPath {
 626                worktree_id,
 627                path: path_in_worktree.clone(),
 628            };
 629            let buffer_task =
 630                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 631            let rope_task = cx.spawn(async move |cx| {
 632                let buffer = buffer_task.await?;
 633                let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
 634                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 635                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 636                })?;
 637                anyhow::Ok((project_entry_id, rope))
 638            });
 639            // Build a string from the rope on a background thread.
 640            cx.background_spawn(async move {
 641                let (project_entry_id, rope) = rope_task.await?;
 642                anyhow::Ok(RulesFileContext {
 643                    path_in_worktree,
 644                    text: rope.to_string().trim().to_string(),
 645                    project_entry_id: project_entry_id.to_usize(),
 646                })
 647            })
 648        })
 649    }
 650
 651    fn handle_thread_title_updated(
 652        &mut self,
 653        thread: Entity<Thread>,
 654        _: &TitleUpdated,
 655        cx: &mut Context<Self>,
 656    ) {
 657        let session_id = thread.read(cx).id();
 658        let Some(session) = self.sessions.get(session_id) else {
 659            return;
 660        };
 661        let thread = thread.downgrade();
 662        let acp_thread = session.acp_thread.downgrade();
 663        cx.spawn(async move |_, cx| {
 664            let title = thread.read_with(cx, |thread, _| thread.title())?;
 665            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 666            task.await
 667        })
 668        .detach_and_log_err(cx);
 669    }
 670
 671    fn handle_thread_token_usage_updated(
 672        &mut self,
 673        thread: Entity<Thread>,
 674        usage: &TokenUsageUpdated,
 675        cx: &mut Context<Self>,
 676    ) {
 677        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 678            return;
 679        };
 680        session.acp_thread.update(cx, |acp_thread, cx| {
 681            acp_thread.update_token_usage(usage.0.clone(), cx);
 682        });
 683    }
 684
 685    fn handle_project_event(
 686        &mut self,
 687        project: Entity<Project>,
 688        event: &project::Event,
 689        _cx: &mut Context<Self>,
 690    ) {
 691        let project_id = project.entity_id();
 692        let Some(state) = self.projects.get_mut(&project_id) else {
 693            return;
 694        };
 695        match event {
 696            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 697                state.project_context_needs_refresh.send(()).ok();
 698            }
 699            project::Event::WorktreeUpdatedEntries(_, items) => {
 700                if items.iter().any(|(path, _, _)| {
 701                    RULES_FILE_NAMES
 702                        .iter()
 703                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 704                }) {
 705                    state.project_context_needs_refresh.send(()).ok();
 706                }
 707            }
 708            _ => {}
 709        }
 710    }
 711
 712    fn handle_prompts_updated_event(
 713        &mut self,
 714        _prompt_store: Entity<PromptStore>,
 715        _event: &prompt_store::PromptsUpdatedEvent,
 716        _cx: &mut Context<Self>,
 717    ) {
 718        for state in self.projects.values_mut() {
 719            state.project_context_needs_refresh.send(()).ok();
 720        }
 721    }
 722
 723    fn handle_models_updated_event(
 724        &mut self,
 725        _registry: Entity<LanguageModelRegistry>,
 726        _event: &language_model::Event,
 727        cx: &mut Context<Self>,
 728    ) {
 729        self.models.refresh_list(cx);
 730
 731        let registry = LanguageModelRegistry::read_global(cx);
 732        let default_model = registry.default_model().map(|m| m.model);
 733        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 734
 735        for session in self.sessions.values_mut() {
 736            session.thread.update(cx, |thread, cx| {
 737                if thread.model().is_none()
 738                    && let Some(model) = default_model.clone()
 739                {
 740                    thread.set_model(model, cx);
 741                    cx.notify();
 742                }
 743                thread.set_summarization_model(summarization_model.clone(), cx);
 744            });
 745        }
 746    }
 747
 748    fn handle_context_server_store_updated(
 749        &mut self,
 750        store: Entity<project::context_server_store::ContextServerStore>,
 751        _event: &project::context_server_store::ServerStatusChangedEvent,
 752        cx: &mut Context<Self>,
 753    ) {
 754        let project_id = self.projects.iter().find_map(|(id, state)| {
 755            if *state.context_server_registry.read(cx).server_store() == store {
 756                Some(*id)
 757            } else {
 758                None
 759            }
 760        });
 761        if let Some(project_id) = project_id {
 762            self.update_available_commands_for_project(project_id, cx);
 763        }
 764    }
 765
 766    fn handle_context_server_registry_event(
 767        &mut self,
 768        registry: Entity<ContextServerRegistry>,
 769        event: &ContextServerRegistryEvent,
 770        cx: &mut Context<Self>,
 771    ) {
 772        match event {
 773            ContextServerRegistryEvent::ToolsChanged => {}
 774            ContextServerRegistryEvent::PromptsChanged => {
 775                let project_id = self.projects.iter().find_map(|(id, state)| {
 776                    if state.context_server_registry == registry {
 777                        Some(*id)
 778                    } else {
 779                        None
 780                    }
 781                });
 782                if let Some(project_id) = project_id {
 783                    self.update_available_commands_for_project(project_id, cx);
 784                }
 785            }
 786        }
 787    }
 788
 789    fn update_available_commands_for_project(&self, project_id: EntityId, cx: &mut Context<Self>) {
 790        let available_commands =
 791            Self::build_available_commands_for_project(self.projects.get(&project_id), cx);
 792        for session in self.sessions.values() {
 793            if session.project_id != project_id {
 794                continue;
 795            }
 796            session.acp_thread.update(cx, |thread, cx| {
 797                thread
 798                    .handle_session_update(
 799                        acp::SessionUpdate::AvailableCommandsUpdate(
 800                            acp::AvailableCommandsUpdate::new(available_commands.clone()),
 801                        ),
 802                        cx,
 803                    )
 804                    .log_err();
 805            });
 806        }
 807    }
 808
 809    fn build_available_commands_for_project(
 810        project_state: Option<&ProjectState>,
 811        cx: &App,
 812    ) -> Vec<acp::AvailableCommand> {
 813        let Some(state) = project_state else {
 814            return vec![];
 815        };
 816        let registry = state.context_server_registry.read(cx);
 817
 818        let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
 819        for context_server_prompt in registry.prompts() {
 820            *prompt_name_counts
 821                .entry(context_server_prompt.prompt.name.as_str())
 822                .or_insert(0) += 1;
 823        }
 824
 825        registry
 826            .prompts()
 827            .flat_map(|context_server_prompt| {
 828                let prompt = &context_server_prompt.prompt;
 829
 830                let should_prefix = prompt_name_counts
 831                    .get(prompt.name.as_str())
 832                    .copied()
 833                    .unwrap_or(0)
 834                    > 1;
 835
 836                let name = if should_prefix {
 837                    format!("{}.{}", context_server_prompt.server_id, prompt.name)
 838                } else {
 839                    prompt.name.clone()
 840                };
 841
 842                let mut command = acp::AvailableCommand::new(
 843                    name,
 844                    prompt.description.clone().unwrap_or_default(),
 845                );
 846
 847                match prompt.arguments.as_deref() {
 848                    Some([arg]) => {
 849                        let hint = format!("<{}>", arg.name);
 850
 851                        command = command.input(acp::AvailableCommandInput::Unstructured(
 852                            acp::UnstructuredCommandInput::new(hint),
 853                        ));
 854                    }
 855                    Some([]) | None => {}
 856                    Some(_) => {
 857                        // skip >1 argument commands since we don't support them yet
 858                        return None;
 859                    }
 860                }
 861
 862                Some(command)
 863            })
 864            .collect()
 865    }
 866
 867    pub fn load_thread(
 868        &mut self,
 869        id: acp::SessionId,
 870        project: Entity<Project>,
 871        cx: &mut Context<Self>,
 872    ) -> Task<Result<Entity<Thread>>> {
 873        let database_future = ThreadsDatabase::connect(cx);
 874        cx.spawn(async move |this, cx| {
 875            let database = database_future.await.map_err(|err| anyhow!(err))?;
 876            let db_thread = database
 877                .load_thread(id.clone())
 878                .await?
 879                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 880
 881            this.update(cx, |this, cx| {
 882                let project_id = this.get_or_create_project_state(&project, cx);
 883                let project_state = this
 884                    .projects
 885                    .get(&project_id)
 886                    .context("project state not found")?;
 887                let summarization_model = LanguageModelRegistry::read_global(cx)
 888                    .thread_summary_model()
 889                    .map(|c| c.model);
 890
 891                Ok(cx.new(|cx| {
 892                    let mut thread = Thread::from_db(
 893                        id.clone(),
 894                        db_thread,
 895                        project_state.project.clone(),
 896                        project_state.project_context.clone(),
 897                        project_state.context_server_registry.clone(),
 898                        this.templates.clone(),
 899                        cx,
 900                    );
 901                    thread.set_summarization_model(summarization_model, cx);
 902                    thread
 903                }))
 904            })?
 905        })
 906    }
 907
 908    pub fn open_thread(
 909        &mut self,
 910        id: acp::SessionId,
 911        project: Entity<Project>,
 912        cx: &mut Context<Self>,
 913    ) -> Task<Result<Entity<AcpThread>>> {
 914        if let Some(session) = self.sessions.get(&id) {
 915            return Task::ready(Ok(session.acp_thread.clone()));
 916        }
 917
 918        let task = self.load_thread(id, project.clone(), cx);
 919        cx.spawn(async move |this, cx| {
 920            let thread = task.await?;
 921            let acp_thread = this.update(cx, |this, cx| {
 922                let project_id = this.get_or_create_project_state(&project, cx);
 923                this.register_session(thread.clone(), project_id, cx)
 924            })?;
 925            let events = thread.update(cx, |thread, cx| thread.replay(cx));
 926            cx.update(|cx| {
 927                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 928            })
 929            .await?;
 930            Ok(acp_thread)
 931        })
 932    }
 933
 934    pub fn thread_summary(
 935        &mut self,
 936        id: acp::SessionId,
 937        project: Entity<Project>,
 938        cx: &mut Context<Self>,
 939    ) -> Task<Result<SharedString>> {
 940        let thread = self.open_thread(id.clone(), project, cx);
 941        cx.spawn(async move |this, cx| {
 942            let acp_thread = thread.await?;
 943            let result = this
 944                .update(cx, |this, cx| {
 945                    this.sessions
 946                        .get(&id)
 947                        .unwrap()
 948                        .thread
 949                        .update(cx, |thread, cx| thread.summary(cx))
 950                })?
 951                .await
 952                .context("Failed to generate summary")?;
 953            drop(acp_thread);
 954            Ok(result)
 955        })
 956    }
 957
 958    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 959        if thread.read(cx).is_empty() {
 960            return;
 961        }
 962
 963        let id = thread.read(cx).id().clone();
 964        let Some(session) = self.sessions.get_mut(&id) else {
 965            return;
 966        };
 967
 968        let project_id = session.project_id;
 969        let Some(state) = self.projects.get(&project_id) else {
 970            return;
 971        };
 972
 973        let folder_paths = PathList::new(
 974            &state
 975                .project
 976                .read(cx)
 977                .visible_worktrees(cx)
 978                .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
 979                .collect::<Vec<_>>(),
 980        );
 981
 982        let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
 983        let database_future = ThreadsDatabase::connect(cx);
 984        let db_thread = thread.update(cx, |thread, cx| {
 985            thread.set_draft_prompt(draft_prompt);
 986            thread.to_db(cx)
 987        });
 988        let thread_store = self.thread_store.clone();
 989        session.pending_save = cx.spawn(async move |_, cx| {
 990            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 991                return;
 992            };
 993            let db_thread = db_thread.await;
 994            database
 995                .save_thread(id, db_thread, folder_paths)
 996                .await
 997                .log_err();
 998            thread_store.update(cx, |store, cx| store.reload(cx));
 999        });
1000    }
1001
    /// Runs the MCP prompt `prompt_name` from context server `server_id` as a turn in
    /// session `session_id`.
    ///
    /// Steps:
    /// 1. Fetches the prompt's messages via `crate::get_prompt`, passing `arguments`.
    /// 2. Appends `original_content` to the [`Thread`] as the user block `message_id`,
    ///    skipping its first block (the text block the slash command was parsed from).
    /// 3. Replays each prompt message into both the [`AcpThread`] and the [`Thread`],
    ///    as user or assistant content depending on the message's role.
    /// 4. Starts generation: `send_existing` if the last replayed message was from the
    ///    user (or there were none), otherwise `resume`.
    /// 5. Forwards the resulting event stream to the `AcpThread` and resolves with the
    ///    turn's response.
    ///
    /// # Errors
    /// Fails if the session has no project state, the prompt cannot be fetched, the
    /// session is gone when the prompt resolves, or starting/streaming the turn fails.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: acp::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let Some(state) = self.session_project_state(&session_id) else {
            return Task::ready(Err(anyhow!("Project state not found for session")));
        };
        let server_store = state
            .context_server_registry
            .read(cx)
            .server_store()
            .clone();
        let path_style = state.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            // Re-resolve the session after the await; it may have been closed meanwhile.
            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // With no prompt messages, the user content pushed below is the last message,
            // so the turn is started with `send_existing`.
            let mut last_is_user = true;

            // Record the blocks that followed the slash-command block as this turn's
            // user message on the model-facing thread.
            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                // Each message is mirrored into both threads: the AcpThread for display and
                // the Thread for the model's conversation history.
                match role {
                    context_server::types::Role::User => {
                        // Every replayed user message gets its own fresh id.
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        // `false` matches the non-thinking flag passed for AgentText in
                        // `handle_thread_events`. NOTE(review): the trailing `true` is
                        // presumably the indent flag — confirm against AcpThread.
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
1104}
1105
/// Wrapper around the in-process [`NativeAgent`] entity that implements the
/// [`acp_thread::AgentConnection`] trait (as well as [`acp_thread::AgentTelemetry`]).
///
/// Cloning only clones the wrapped entity handle.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
1109
1110impl NativeAgentConnection {
1111    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
1112        self.0
1113            .read(cx)
1114            .sessions
1115            .get(session_id)
1116            .map(|session| session.thread.clone())
1117    }
1118
1119    pub fn load_thread(
1120        &self,
1121        id: acp::SessionId,
1122        project: Entity<Project>,
1123        cx: &mut App,
1124    ) -> Task<Result<Entity<Thread>>> {
1125        self.0
1126            .update(cx, |this, cx| this.load_thread(id, project, cx))
1127    }
1128
1129    fn run_turn(
1130        &self,
1131        session_id: acp::SessionId,
1132        cx: &mut App,
1133        f: impl 'static
1134        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1135    ) -> Task<Result<acp::PromptResponse>> {
1136        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1137            agent
1138                .sessions
1139                .get_mut(&session_id)
1140                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1141        }) else {
1142            return Task::ready(Err(anyhow!("Session not found")));
1143        };
1144        log::debug!("Found session for: {}", session_id);
1145
1146        let response_stream = match f(thread, cx) {
1147            Ok(stream) => stream,
1148            Err(err) => return Task::ready(Err(err)),
1149        };
1150        Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1151    }
1152
1153    fn handle_thread_events(
1154        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1155        acp_thread: WeakEntity<AcpThread>,
1156        cx: &App,
1157    ) -> Task<Result<acp::PromptResponse>> {
1158        cx.spawn(async move |cx| {
1159            // Handle response stream and forward to session.acp_thread
1160            while let Some(result) = events.next().await {
1161                match result {
1162                    Ok(event) => {
1163                        log::trace!("Received completion event: {:?}", event);
1164
1165                        match event {
1166                            ThreadEvent::UserMessage(message) => {
1167                                acp_thread.update(cx, |thread, cx| {
1168                                    for content in message.content {
1169                                        thread.push_user_content_block(
1170                                            Some(message.id.clone()),
1171                                            content.into(),
1172                                            cx,
1173                                        );
1174                                    }
1175                                })?;
1176                            }
1177                            ThreadEvent::AgentText(text) => {
1178                                acp_thread.update(cx, |thread, cx| {
1179                                    thread.push_assistant_content_block(text.into(), false, cx)
1180                                })?;
1181                            }
1182                            ThreadEvent::AgentThinking(text) => {
1183                                acp_thread.update(cx, |thread, cx| {
1184                                    thread.push_assistant_content_block(text.into(), true, cx)
1185                                })?;
1186                            }
1187                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1188                                tool_call,
1189                                options,
1190                                response,
1191                                context: _,
1192                            }) => {
1193                                let outcome_task = acp_thread.update(cx, |thread, cx| {
1194                                    thread.request_tool_call_authorization(tool_call, options, cx)
1195                                })??;
1196                                cx.background_spawn(async move {
1197                                    if let acp::RequestPermissionOutcome::Selected(
1198                                        acp::SelectedPermissionOutcome { option_id, .. },
1199                                    ) = outcome_task.await
1200                                    {
1201                                        response
1202                                            .send(option_id)
1203                                            .map(|_| anyhow!("authorization receiver was dropped"))
1204                                            .log_err();
1205                                    }
1206                                })
1207                                .detach();
1208                            }
1209                            ThreadEvent::ToolCall(tool_call) => {
1210                                acp_thread.update(cx, |thread, cx| {
1211                                    thread.upsert_tool_call(tool_call, cx)
1212                                })??;
1213                            }
1214                            ThreadEvent::ToolCallUpdate(update) => {
1215                                acp_thread.update(cx, |thread, cx| {
1216                                    thread.update_tool_call(update, cx)
1217                                })??;
1218                            }
1219                            ThreadEvent::SubagentSpawned(session_id) => {
1220                                acp_thread.update(cx, |thread, cx| {
1221                                    thread.subagent_spawned(session_id, cx);
1222                                })?;
1223                            }
1224                            ThreadEvent::Retry(status) => {
1225                                acp_thread.update(cx, |thread, cx| {
1226                                    thread.update_retry_status(status, cx)
1227                                })?;
1228                            }
1229                            ThreadEvent::Stop(stop_reason) => {
1230                                log::debug!("Assistant message complete: {:?}", stop_reason);
1231                                return Ok(acp::PromptResponse::new(stop_reason));
1232                            }
1233                        }
1234                    }
1235                    Err(e) => {
1236                        log::error!("Error in model response stream: {:?}", e);
1237                        return Err(e);
1238                    }
1239                }
1240            }
1241
1242            log::debug!("Response stream completed");
1243            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1244        })
1245    }
1246}
1247
1248struct Command<'a> {
1249    prompt_name: &'a str,
1250    arg_value: &'a str,
1251    explicit_server_id: Option<&'a str>,
1252}
1253
1254impl<'a> Command<'a> {
1255    fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1256        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1257            return None;
1258        };
1259        let text = text_content.text.trim();
1260        let command = text.strip_prefix('/')?;
1261        let (command, arg_value) = command
1262            .split_once(char::is_whitespace)
1263            .unwrap_or((command, ""));
1264
1265        if let Some((server_id, prompt_name)) = command.split_once('.') {
1266            Some(Self {
1267                prompt_name,
1268                arg_value,
1269                explicit_server_id: Some(server_id),
1270            })
1271        } else {
1272            Some(Self {
1273                prompt_name: command,
1274                arg_value,
1275                explicit_server_id: None,
1276            })
1277        }
1278    }
1279}
1280
/// Per-session model picker for the native agent.
struct NativeAgentModelSelector {
    /// Session whose thread's model is read and changed.
    session_id: acp::SessionId,
    /// Connection used to reach the agent's sessions, model list, and filesystem.
    connection: NativeAgentConnection,
}
1285
1286impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1287    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1288        log::debug!("NativeAgentConnection::list_models called");
1289        let list = self.connection.0.read(cx).models.model_list.clone();
1290        Task::ready(if list.is_empty() {
1291            Err(anyhow::anyhow!("No models available"))
1292        } else {
1293            Ok(list)
1294        })
1295    }
1296
1297    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1298        log::debug!(
1299            "Setting model for session {}: {}",
1300            self.session_id,
1301            model_id
1302        );
1303        let Some(thread) = self
1304            .connection
1305            .0
1306            .read(cx)
1307            .sessions
1308            .get(&self.session_id)
1309            .map(|session| session.thread.clone())
1310        else {
1311            return Task::ready(Err(anyhow!("Session not found")));
1312        };
1313
1314        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1315            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1316        };
1317
1318        // We want to reset the effort level when switching models, as the currently-selected effort level may
1319        // not be compatible.
1320        let effort = model
1321            .default_effort_level()
1322            .map(|effort_level| effort_level.value.to_string());
1323
1324        thread.update(cx, |thread, cx| {
1325            thread.set_model(model.clone(), cx);
1326            thread.set_thinking_effort(effort.clone(), cx);
1327            thread.set_thinking_enabled(model.supports_thinking(), cx);
1328        });
1329
1330        update_settings_file(
1331            self.connection.0.read(cx).fs.clone(),
1332            cx,
1333            move |settings, cx| {
1334                let provider = model.provider_id().0.to_string();
1335                let model = model.id().0.to_string();
1336                let enable_thinking = thread.read(cx).thinking_enabled();
1337                settings
1338                    .agent
1339                    .get_or_insert_default()
1340                    .set_model(LanguageModelSelection {
1341                        provider: provider.into(),
1342                        model,
1343                        enable_thinking,
1344                        effort,
1345                    });
1346            },
1347        );
1348
1349        Task::ready(Ok(()))
1350    }
1351
1352    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1353        let Some(thread) = self
1354            .connection
1355            .0
1356            .read(cx)
1357            .sessions
1358            .get(&self.session_id)
1359            .map(|session| session.thread.clone())
1360        else {
1361            return Task::ready(Err(anyhow!("Session not found")));
1362        };
1363        let Some(model) = thread.read(cx).model() else {
1364            return Task::ready(Err(anyhow!("Model not found")));
1365        };
1366        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1367        else {
1368            return Task::ready(Err(anyhow!("Provider not found")));
1369        };
1370        Task::ready(Ok(LanguageModels::map_language_model_to_info(
1371            model, &provider,
1372        )))
1373    }
1374
1375    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1376        Some(self.connection.0.read(cx).models.watch())
1377    }
1378
1379    fn should_render_footer(&self) -> bool {
1380        true
1381    }
1382}
1383
/// Identifier reported by [`NativeAgentConnection`] as its `agent_id`.
pub static ZED_AGENT_ID: LazyLock<AgentId> = LazyLock::new(|| AgentId::new("Zed Agent"));
1385
1386impl acp_thread::AgentConnection for NativeAgentConnection {
1387    fn agent_id(&self) -> AgentId {
1388        ZED_AGENT_ID.clone()
1389    }
1390
1391    fn telemetry_id(&self) -> SharedString {
1392        "zed".into()
1393    }
1394
1395    fn new_session(
1396        self: Rc<Self>,
1397        project: Entity<Project>,
1398        work_dirs: PathList,
1399        cx: &mut App,
1400    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1401        log::debug!("Creating new thread for project at: {work_dirs:?}");
1402        Task::ready(Ok(self
1403            .0
1404            .update(cx, |agent, cx| agent.new_session(project, cx))))
1405    }
1406
1407    fn supports_load_session(&self) -> bool {
1408        true
1409    }
1410
1411    fn load_session(
1412        self: Rc<Self>,
1413        session_id: acp::SessionId,
1414        project: Entity<Project>,
1415        _work_dirs: PathList,
1416        _title: Option<SharedString>,
1417        cx: &mut App,
1418    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1419        self.0
1420            .update(cx, |agent, cx| agent.open_thread(session_id, project, cx))
1421    }
1422
1423    fn supports_close_session(&self) -> bool {
1424        true
1425    }
1426
1427    fn close_session(
1428        self: Rc<Self>,
1429        session_id: &acp::SessionId,
1430        cx: &mut App,
1431    ) -> Task<Result<()>> {
1432        self.0.update(cx, |agent, _cx| {
1433            let project_id = agent.sessions.get(session_id).map(|s| s.project_id);
1434            agent.sessions.remove(session_id);
1435
1436            if let Some(project_id) = project_id {
1437                let has_remaining = agent.sessions.values().any(|s| s.project_id == project_id);
1438                if !has_remaining {
1439                    agent.projects.remove(&project_id);
1440                }
1441            }
1442        });
1443        Task::ready(Ok(()))
1444    }
1445
1446    fn auth_methods(&self) -> &[acp::AuthMethod] {
1447        &[] // No auth for in-process
1448    }
1449
1450    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1451        Task::ready(Ok(()))
1452    }
1453
1454    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1455        Some(Rc::new(NativeAgentModelSelector {
1456            session_id: session_id.clone(),
1457            connection: self.clone(),
1458        }) as Rc<dyn AgentModelSelector>)
1459    }
1460
1461    fn prompt(
1462        &self,
1463        id: Option<acp_thread::UserMessageId>,
1464        params: acp::PromptRequest,
1465        cx: &mut App,
1466    ) -> Task<Result<acp::PromptResponse>> {
1467        let id = id.expect("UserMessageId is required");
1468        let session_id = params.session_id.clone();
1469        log::info!("Received prompt request for session: {}", session_id);
1470        log::debug!("Prompt blocks count: {}", params.prompt.len());
1471
1472        let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
1473            return Task::ready(Err(anyhow::anyhow!("Session not found")));
1474        };
1475
1476        if let Some(parsed_command) = Command::parse(&params.prompt) {
1477            let registry = project_state.context_server_registry.read(cx);
1478
1479            let explicit_server_id = parsed_command
1480                .explicit_server_id
1481                .map(|server_id| ContextServerId(server_id.into()));
1482
1483            if let Some(prompt) =
1484                registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1485            {
1486                let arguments = if !parsed_command.arg_value.is_empty()
1487                    && let Some(arg_name) = prompt
1488                        .prompt
1489                        .arguments
1490                        .as_ref()
1491                        .and_then(|args| args.first())
1492                        .map(|arg| arg.name.clone())
1493                {
1494                    HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1495                } else {
1496                    Default::default()
1497                };
1498
1499                let prompt_name = prompt.prompt.name.clone();
1500                let server_id = prompt.server_id.clone();
1501
1502                return self.0.update(cx, |agent, cx| {
1503                    agent.send_mcp_prompt(
1504                        id,
1505                        session_id.clone(),
1506                        prompt_name,
1507                        server_id,
1508                        arguments,
1509                        params.prompt,
1510                        cx,
1511                    )
1512                });
1513            }
1514        };
1515
1516        let path_style = project_state.project.read(cx).path_style(cx);
1517
1518        self.run_turn(session_id, cx, move |thread, cx| {
1519            let content: Vec<UserMessageContent> = params
1520                .prompt
1521                .into_iter()
1522                .map(|block| UserMessageContent::from_content_block(block, path_style))
1523                .collect::<Vec<_>>();
1524            log::debug!("Converted prompt to message: {} chars", content.len());
1525            log::debug!("Message id: {:?}", id);
1526            log::debug!("Message content: {:?}", content);
1527
1528            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1529        })
1530    }
1531
1532    fn retry(
1533        &self,
1534        session_id: &acp::SessionId,
1535        _cx: &App,
1536    ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1537        Some(Rc::new(NativeAgentSessionRetry {
1538            connection: self.clone(),
1539            session_id: session_id.clone(),
1540        }) as _)
1541    }
1542
1543    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1544        log::info!("Cancelling on session: {}", session_id);
1545        self.0.update(cx, |agent, cx| {
1546            if let Some(session) = agent.sessions.get(session_id) {
1547                session
1548                    .thread
1549                    .update(cx, |thread, cx| thread.cancel(cx))
1550                    .detach();
1551            }
1552        });
1553    }
1554
1555    fn truncate(
1556        &self,
1557        session_id: &acp::SessionId,
1558        cx: &App,
1559    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1560        self.0.read_with(cx, |agent, _cx| {
1561            agent.sessions.get(session_id).map(|session| {
1562                Rc::new(NativeAgentSessionTruncate {
1563                    thread: session.thread.clone(),
1564                    acp_thread: session.acp_thread.downgrade(),
1565                }) as _
1566            })
1567        })
1568    }
1569
1570    fn set_title(
1571        &self,
1572        session_id: &acp::SessionId,
1573        cx: &App,
1574    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1575        self.0.read_with(cx, |agent, _cx| {
1576            agent
1577                .sessions
1578                .get(session_id)
1579                .filter(|s| !s.thread.read(cx).is_subagent())
1580                .map(|session| {
1581                    Rc::new(NativeAgentSessionSetTitle {
1582                        thread: session.thread.clone(),
1583                    }) as _
1584                })
1585        })
1586    }
1587
1588    fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1589        let thread_store = self.0.read(cx).thread_store.clone();
1590        Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1591    }
1592
1593    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1594        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1595    }
1596
1597    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1598        self
1599    }
1600}
1601
1602impl acp_thread::AgentTelemetry for NativeAgentConnection {
1603    fn thread_data(
1604        &self,
1605        session_id: &acp::SessionId,
1606        cx: &mut App,
1607    ) -> Task<Result<serde_json::Value>> {
1608        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1609            return Task::ready(Err(anyhow!("Session not found")));
1610        };
1611
1612        let task = session.thread.read(cx).to_db(cx);
1613        cx.background_spawn(async move {
1614            serde_json::to_value(task.await).context("Failed to serialize thread")
1615        })
1616    }
1617}
1618
/// Session list backed by the agent's [`ThreadStore`].
///
/// Every notification observed on the store is forwarded as a
/// `SessionListUpdate::Refresh` over an unbounded channel.
pub struct NativeAgentSessionList {
    thread_store: Entity<ThreadStore>,
    // Sender half, kept so `notify_refresh` can push a refresh on demand.
    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
    // Receiver half; a clone is handed to each caller of `watch`.
    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
    // Keeps the store observation registered for as long as the list exists.
    _subscription: Subscription,
}
1625
1626impl NativeAgentSessionList {
1627    fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1628        let (tx, rx) = smol::channel::unbounded();
1629        let this_tx = tx.clone();
1630        let subscription = cx.observe(&thread_store, move |_, _| {
1631            this_tx
1632                .try_send(acp_thread::SessionListUpdate::Refresh)
1633                .ok();
1634        });
1635        Self {
1636            thread_store,
1637            updates_tx: tx,
1638            updates_rx: rx,
1639            _subscription: subscription,
1640        }
1641    }
1642
1643    pub fn thread_store(&self) -> &Entity<ThreadStore> {
1644        &self.thread_store
1645    }
1646}
1647
1648impl AgentSessionList for NativeAgentSessionList {
1649    fn list_sessions(
1650        &self,
1651        _request: AgentSessionListRequest,
1652        cx: &mut App,
1653    ) -> Task<Result<AgentSessionListResponse>> {
1654        let sessions = self
1655            .thread_store
1656            .read(cx)
1657            .entries()
1658            .map(|entry| AgentSessionInfo::from(&entry))
1659            .collect();
1660        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1661    }
1662
1663    fn supports_delete(&self) -> bool {
1664        true
1665    }
1666
1667    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1668        self.thread_store
1669            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1670    }
1671
1672    fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1673        self.thread_store
1674            .update(cx, |store, cx| store.delete_threads(cx))
1675    }
1676
1677    fn watch(
1678        &self,
1679        _cx: &mut App,
1680    ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1681        Some(self.updates_rx.clone())
1682    }
1683
1684    fn notify_refresh(&self) {
1685        self.updates_tx
1686            .try_send(acp_thread::SessionListUpdate::Refresh)
1687            .ok();
1688    }
1689
1690    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1691        self
1692    }
1693}
1694
/// Truncation handle for a single native session.
///
/// Holds the native [`Thread`] strongly and the [`AcpThread`] weakly; the ACP
/// thread only receives a best-effort token-usage update after truncation.
struct NativeAgentSessionTruncate {
    thread: Entity<Thread>,
    acp_thread: WeakEntity<AcpThread>,
}
1699
1700impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1701    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1702        match self.thread.update(cx, |thread, cx| {
1703            thread.truncate(message_id.clone(), cx)?;
1704            Ok(thread.latest_token_usage())
1705        }) {
1706            Ok(usage) => {
1707                self.acp_thread
1708                    .update(cx, |thread, cx| {
1709                        thread.update_token_usage(usage, cx);
1710                    })
1711                    .ok();
1712                Task::ready(Ok(()))
1713            }
1714            Err(error) => Task::ready(Err(error)),
1715        }
1716    }
1717}
1718
/// Retry handle that starts a new turn on `session_id` through its owning connection.
struct NativeAgentSessionRetry {
    connection: NativeAgentConnection,
    session_id: acp::SessionId,
}
1723
1724impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1725    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1726        self.connection
1727            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1728                thread.update(cx, |thread, cx| thread.resume(cx))
1729            })
1730    }
1731}
1732
/// Title setter for a native session.
///
/// `NativeAgentConnection::set_title` only hands these out for threads that are
/// not subagents.
struct NativeAgentSessionSetTitle {
    thread: Entity<Thread>,
}
1736
1737impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1738    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1739        self.thread
1740            .update(cx, |thread, cx| thread.set_title(title, cx));
1741        Task::ready(Ok(()))
1742    }
1743}
1744
/// [`ThreadEnvironment`] that lets a native [`Thread`] create terminals and
/// spawn or resume subagents.
///
/// Every reference is weak, so the environment keeps neither the agent nor
/// either thread alive.
pub struct NativeThreadEnvironment {
    agent: WeakEntity<NativeAgent>,
    // The thread this environment serves; it becomes the parent of any subagent it creates.
    thread: WeakEntity<Thread>,
    // ACP counterpart of `thread`, used to create and release terminals.
    acp_thread: WeakEntity<AcpThread>,
}
1750
1751impl NativeThreadEnvironment {
1752    pub(crate) fn create_subagent_thread(
1753        &self,
1754        label: String,
1755        cx: &mut App,
1756    ) -> Result<Rc<dyn SubagentHandle>> {
1757        let Some(parent_thread_entity) = self.thread.upgrade() else {
1758            anyhow::bail!("Parent thread no longer exists".to_string());
1759        };
1760        let parent_thread = parent_thread_entity.read(cx);
1761        let current_depth = parent_thread.depth();
1762        let parent_session_id = parent_thread.id().clone();
1763
1764        if current_depth >= MAX_SUBAGENT_DEPTH {
1765            return Err(anyhow!(
1766                "Maximum subagent depth ({}) reached",
1767                MAX_SUBAGENT_DEPTH
1768            ));
1769        }
1770
1771        let subagent_thread: Entity<Thread> = cx.new(|cx| {
1772            let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1773            thread.set_title(label.into(), cx);
1774            thread
1775        });
1776
1777        let session_id = subagent_thread.read(cx).id().clone();
1778
1779        let acp_thread = self
1780            .agent
1781            .update(cx, |agent, cx| -> Result<Entity<AcpThread>> {
1782                let project_id = agent
1783                    .sessions
1784                    .get(&parent_session_id)
1785                    .map(|s| s.project_id)
1786                    .context("parent session not found")?;
1787                Ok(agent.register_session(subagent_thread.clone(), project_id, cx))
1788            })??;
1789
1790        let depth = current_depth + 1;
1791
1792        telemetry::event!(
1793            "Subagent Started",
1794            session = parent_thread_entity.read(cx).id().to_string(),
1795            subagent_session = session_id.to_string(),
1796            depth,
1797            is_resumed = false,
1798        );
1799
1800        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1801    }
1802
1803    pub(crate) fn resume_subagent_thread(
1804        &self,
1805        session_id: acp::SessionId,
1806        cx: &mut App,
1807    ) -> Result<Rc<dyn SubagentHandle>> {
1808        let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1809            let session = agent
1810                .sessions
1811                .get(&session_id)
1812                .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1813            anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1814        })??;
1815
1816        let depth = subagent_thread.read(cx).depth();
1817
1818        if let Some(parent_thread_entity) = self.thread.upgrade() {
1819            telemetry::event!(
1820                "Subagent Started",
1821                session = parent_thread_entity.read(cx).id().to_string(),
1822                subagent_session = session_id.to_string(),
1823                depth,
1824                is_resumed = true,
1825            );
1826        }
1827
1828        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1829    }
1830
1831    fn prompt_subagent(
1832        &self,
1833        session_id: acp::SessionId,
1834        subagent_thread: Entity<Thread>,
1835        acp_thread: Entity<acp_thread::AcpThread>,
1836    ) -> Result<Rc<dyn SubagentHandle>> {
1837        let Some(parent_thread_entity) = self.thread.upgrade() else {
1838            anyhow::bail!("Parent thread no longer exists".to_string());
1839        };
1840        Ok(Rc::new(NativeSubagentHandle::new(
1841            session_id,
1842            subagent_thread,
1843            acp_thread,
1844            parent_thread_entity,
1845        )) as _)
1846    }
1847}
1848
1849impl ThreadEnvironment for NativeThreadEnvironment {
1850    fn create_terminal(
1851        &self,
1852        command: String,
1853        cwd: Option<PathBuf>,
1854        output_byte_limit: Option<u64>,
1855        cx: &mut AsyncApp,
1856    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1857        let task = self.acp_thread.update(cx, |thread, cx| {
1858            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1859        });
1860
1861        let acp_thread = self.acp_thread.clone();
1862        cx.spawn(async move |cx| {
1863            let terminal = task?.await?;
1864
1865            let (drop_tx, drop_rx) = oneshot::channel();
1866            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1867
1868            cx.spawn(async move |cx| {
1869                drop_rx.await.ok();
1870                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1871            })
1872            .detach();
1873
1874            let handle = AcpTerminalHandle {
1875                terminal,
1876                _drop_tx: Some(drop_tx),
1877            };
1878
1879            Ok(Rc::new(handle) as _)
1880        })
1881    }
1882
1883    fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1884        self.create_subagent_thread(label, cx)
1885    }
1886
1887    fn resume_subagent(
1888        &self,
1889        session_id: acp::SessionId,
1890        cx: &mut App,
1891    ) -> Result<Rc<dyn SubagentHandle>> {
1892        self.resume_subagent_thread(session_id, cx)
1893    }
1894}
1895
/// Outcome of waiting on one subagent turn inside `NativeSubagentHandle::send`.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The turn ended normally; the reply is read from the thread's last message.
    Completed,
    /// The turn stopped with `StopReason::Cancelled`.
    Cancelled,
    /// Token usage moved from normal to warning during the turn, so the turn is stopped.
    ContextWindowWarning,
    /// The turn failed or stopped early; carries the message reported to the caller.
    Error(String),
}
1903
/// [`SubagentHandle`] for a subagent that runs as a native [`Thread`] session.
pub struct NativeSubagentHandle {
    // Session id of the subagent itself, not of its parent.
    session_id: acp::SessionId,
    // Weak, so the handle does not keep the parent alive; used to register and
    // unregister the subagent as running.
    parent_thread: WeakEntity<Thread>,
    subagent_thread: Entity<Thread>,
    acp_thread: Entity<acp_thread::AcpThread>,
}
1910
1911impl NativeSubagentHandle {
1912    fn new(
1913        session_id: acp::SessionId,
1914        subagent_thread: Entity<Thread>,
1915        acp_thread: Entity<acp_thread::AcpThread>,
1916        parent_thread_entity: Entity<Thread>,
1917    ) -> Self {
1918        NativeSubagentHandle {
1919            session_id,
1920            subagent_thread,
1921            parent_thread: parent_thread_entity.downgrade(),
1922            acp_thread,
1923        }
1924    }
1925}
1926
1927impl SubagentHandle for NativeSubagentHandle {
1928    fn id(&self) -> acp::SessionId {
1929        self.session_id.clone()
1930    }
1931
1932    fn num_entries(&self, cx: &App) -> usize {
1933        self.acp_thread.read(cx).entries().len()
1934    }
1935
1936    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
1937        let thread = self.subagent_thread.clone();
1938        let acp_thread = self.acp_thread.clone();
1939        let subagent_session_id = self.session_id.clone();
1940        let parent_thread = self.parent_thread.clone();
1941
1942        cx.spawn(async move |cx| {
1943            let (task, _subscription) = cx.update(|cx| {
1944                let ratio_before_prompt = thread
1945                    .read(cx)
1946                    .latest_token_usage()
1947                    .map(|usage| usage.ratio());
1948
1949                parent_thread
1950                    .update(cx, |parent_thread, _cx| {
1951                        parent_thread.register_running_subagent(thread.downgrade())
1952                    })
1953                    .ok();
1954
1955                let task = acp_thread.update(cx, |acp_thread, cx| {
1956                    acp_thread.send(vec![message.into()], cx)
1957                });
1958
1959                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
1960                let mut token_limit_tx = Some(token_limit_tx);
1961
1962                let subscription = cx.subscribe(
1963                    &thread,
1964                    move |_thread, event: &TokenUsageUpdated, _cx| {
1965                        if let Some(usage) = &event.0 {
1966                            let old_ratio = ratio_before_prompt
1967                                .clone()
1968                                .unwrap_or(TokenUsageRatio::Normal);
1969                            let new_ratio = usage.ratio();
1970                            if old_ratio == TokenUsageRatio::Normal
1971                                && new_ratio == TokenUsageRatio::Warning
1972                            {
1973                                if let Some(tx) = token_limit_tx.take() {
1974                                    tx.send(()).ok();
1975                                }
1976                            }
1977                        }
1978                    },
1979                );
1980
1981                let wait_for_prompt = cx
1982                    .background_spawn(async move {
1983                        futures::select! {
1984                            response = task.fuse() => match response {
1985                                Ok(Some(response)) => {
1986                                    match response.stop_reason {
1987                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
1988                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
1989                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
1990                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
1991                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
1992                                    }
1993                                }
1994                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
1995                                Err(error) => SubagentPromptResult::Error(error.to_string()),
1996                            },
1997                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
1998                        }
1999                    });
2000
2001                (wait_for_prompt, subscription)
2002            });
2003
2004            let result = match task.await {
2005                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
2006                    thread
2007                        .last_message()
2008                        .and_then(|message| {
2009                            let content = message.as_agent_message()?
2010                                .content
2011                                .iter()
2012                                .filter_map(|c| match c {
2013                                    AgentMessageContent::Text(text) => Some(text.as_str()),
2014                                    _ => None,
2015                                })
2016                                .join("\n\n");
2017                            if content.is_empty() {
2018                                None
2019                            } else {
2020                                Some( content)
2021                            }
2022                        })
2023                        .context("No response from subagent")
2024                }),
2025                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
2026                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
2027                SubagentPromptResult::ContextWindowWarning => {
2028                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2029                    Err(anyhow!(
2030                        "The agent is nearing the end of its context window and has been \
2031                         stopped. You can prompt the thread again to have the agent wrap up \
2032                         or hand off its work."
2033                    ))
2034                }
2035            };
2036
2037            parent_thread
2038                .update(cx, |parent_thread, cx| {
2039                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
2040                })
2041                .ok();
2042
2043            result
2044        })
2045    }
2046}
2047
/// [`TerminalHandle`] backed by a terminal entity created through an [`AcpThread`].
///
/// Dropping the handle drops `_drop_tx`. The task spawned in
/// `NativeThreadEnvironment::create_terminal` treats that as the signal to release
/// the terminal.
pub struct AcpTerminalHandle {
    terminal: Entity<acp_thread::Terminal>,
    // Never sent on; dropping it is the release signal.
    _drop_tx: Option<oneshot::Sender<()>>,
}
2052
2053impl TerminalHandle for AcpTerminalHandle {
2054    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
2055        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
2056    }
2057
2058    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
2059        Ok(self
2060            .terminal
2061            .read_with(cx, |term, _cx| term.wait_for_exit()))
2062    }
2063
2064    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
2065        Ok(self
2066            .terminal
2067            .read_with(cx, |term, cx| term.current_output(cx)))
2068    }
2069
2070    fn kill(&self, cx: &AsyncApp) -> Result<()> {
2071        cx.update(|cx| {
2072            self.terminal.update(cx, |terminal, cx| {
2073                terminal.kill(cx);
2074            });
2075        });
2076        Ok(())
2077    }
2078
2079    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
2080        Ok(self
2081            .terminal
2082            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
2083    }
2084}
2085
2086#[cfg(test)]
2087mod internal_tests {
2088    use std::path::Path;
2089
2090    use super::*;
2091    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
2092    use fs::FakeFs;
2093    use gpui::TestAppContext;
2094    use indoc::formatdoc;
2095    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2096    use language_model::{
2097        LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
2098    };
2099    use serde_json::json;
2100    use settings::SettingsStore;
2101    use util::{path, rel_path::rel_path};
2102
    // Verifies that the agent's per-project context tracks worktrees and their rules
    // files. It starts empty, gains a worktree entry when `/a` is added, and picks up
    // `/a/.rules` once that file is created.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));

        // Creating a session registers the project and triggers context building.
        let connection = NativeAgentConnection(agent.clone());
        let _acp_thread = cx
            .update(|cx| {
                Rc::new(connection).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/")]),
                    cx,
                )
            })
            .await
            .unwrap();
        cx.run_until_parked();

        // The project has no worktrees yet, so the context lists none.
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            assert_eq!(state.project_context.read(cx).worktrees, vec![])
        });

        // Adding `/a` as a worktree shows up in the context, still without a rules file.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            assert_eq!(
                state.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let project_id = project.entity_id();
            let state = agent.projects.get(&project_id).unwrap();
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                state.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }
2181
    // Verifies that the session's model selector lists exactly one group, "Fake",
    // containing the single `fake/fake` model with the expected metadata.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let connection =
            NativeAgentConnection(cx.update(|cx| {
                NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)
            }));

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        // The native agent is expected to return models grouped by provider.
        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(acp_thread::AgentModelIcon::Named(
                        ui::IconName::ZedAssistant
                    )),
                    is_latest: false,
                    cost: None,
                }]
            )])
        );
    }
2238
    // Verifies that selecting a model through the selector updates the thread's model
    // and writes the choice into the user settings file, replacing the seeded
    // `foo`/`bar` default. Selecting a thinking-capable model also persists
    // `enable_thinking: true`.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed a different default model so the later settings update is observable.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));

        // Create the agent and connection
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Drain pending tasks (presumably including the settings write) before reading the file.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );

        // Register a thinking model and select it.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // The agent's model list must be refreshed before the new provider is selectable.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // Verify enable_thinking was written to settings as true.
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
        assert_eq!(
            settings_json["agent"]["default_model"]["enable_thinking"],
            json!(true),
            "selecting a thinking model should persist enable_thinking: true to settings"
        );
    }
2350
    // Verifies that `select_model` sets the thread's `thinking_enabled` flag to match
    // whether the newly selected model supports thinking: on for the thinking model,
    // off again after switching back to `fake/fake`.
    #[gpui::test]
    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Start from empty settings.
        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent =
            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
        let connection = NativeAgentConnection(agent.clone());

        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Register a second provider with a thinking model.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so it picks up the new provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Thread starts with thinking_enabled = false (the default).
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(!thread.thinking_enabled(), "thinking defaults to false");
            });
        });

        // Select the thinking model via select_model.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // select_model should have enabled thinking based on the model's supports_thinking().
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(
                    thread.thinking_enabled(),
                    "select_model should enable thinking when model supports it"
                );
            });
        });

        // Switch back to the non-thinking model.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
            .await
            .unwrap();

        // select_model should have disabled thinking.
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(
                    !thread.thinking_enabled(),
                    "select_model should disable thinking when model does not support it"
                );
            });
        });
    }
2442
    // Verifies that a thread whose selected model supports thinking still reports
    // `thinking_enabled() == true` after the thread is persisted, its session is closed,
    // and it is reopened through `NativeAgent::open_thread`.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        // In-memory filesystem with a single empty directory `/a` used as the worktree root.
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model.
        // The trailing `true` presumably marks the fake model as supporting thinking (the
        // assertions below rely on that) — confirm against `with_id_and_thinking`.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        // A clone of the model is handed to the provider; the test keeps its own handle so it
        // can drive the model's completion stream later.
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so it sees the provider registered above.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // The model id is written as "<provider id>/<model id>" ("fake-corp" / "fake-thinking").
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future is spawned so it can make progress while the test feeds the fake
        // model's completion stream below; it is awaited once the stream has been ended.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the closed thread before checking that the
        // agent no longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        // Explicitly drop the reloaded AcpThread at the end (presumably to keep it alive
        // through the assertions above).
        drop(reloaded_acp_thread);
    }
2545
    // Verifies that the model selected for a thread is restored when the thread is reopened
    // from disk, using a model whose id ("custom-model-id") differs from its display name.
    // The final assertion checks that the reloaded thread does not fall back to another model.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        // In-memory filesystem with a single empty directory `/a` used as the worktree root.
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        // The trailing `false` presumably disables thinking support for this fake model —
        // confirm against `with_id_and_thinking`.
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        // The test keeps its own handle to `model` so it can drive the completion stream later.
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Refresh the agent's model list so it sees the provider registered above.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new("/a")]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // Selection uses "<provider id>/<model id>", i.e. the id rather than the display name.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        // Sanity check: the selection took effect before anything is persisted.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future is spawned so it can make progress while the test feeds the fake
        // model's completion stream below; it is awaited once the stream has been ended.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the closed thread before checking that the
        // agent no longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        // Explicitly drop the reloaded AcpThread at the end (presumably to keep it alive
        // through the assertions above).
        drop(reloaded_acp_thread);
    }
2653
    // End-to-end persistence round-trip for a native agent thread. It checks that:
    // - a thread with no messages is not written to the thread store, even after it is mutated;
    // - once a message has been exchanged, the thread is stored under its session id, with the
    //   summarization model's output as its title;
    // - after the session is closed and reopened, the transcript, draft prompt (including
    //   resource-link blocks), token usage, and UI scroll position are all restored.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        // In-memory filesystem: `/a` is the worktree root and `/a/b.md` is the file that the
        // user message mentions.
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = cx.update(|cx| {
            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
        });
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        // Handle to the agent-side thread entity that backs this session.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        // Two separate fake models: `model` answers the conversation and `summary_model`
        // produces the thread title that is checked after reload.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // The user message mixes plain text with a resource link that mentions `/a/b.md`.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // The send future is spawned so it can make progress while the test feeds the fake
        // models' completion streams below; it is awaited once both streams have been ended.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Assistant reply plus a usage update; the token counts are asserted after reload.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
            language_model::TokenUsage {
                input_tokens: 150,
                output_tokens: 75,
                ..Default::default()
            },
        ));
        model.end_last_completion_stream();
        cx.run_until_parked();
        // The summary text becomes the stored thread title (see the `thread_entries` check below).
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        // Transcript before saving; the same markdown is expected after reload.
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Set a draft prompt with rich content blocks before saving.
        let draft_blocks = vec![
            acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
            acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
            acp::ContentBlock::Text(acp::TextContent::new(" please")),
        ];
        acp_thread.update(cx, |thread, _cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()));
        });
        // Also record a UI scroll position so its persistence can be checked after reload.
        thread.update(cx, |thread, _cx| {
            thread.set_ui_scroll_position(Some(gpui::ListOffset {
                item_ix: 5,
                offset_in_item: gpui::px(12.5),
            }));
        });
        // Notify observers of the thread entity. This is presumably what causes the draft and
        // scroll-position changes above to be saved — confirm against the agent's observers.
        thread.update(cx, |_thread, cx| cx.notify());
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the closed thread before checking that the
        // agent no longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        // Exactly one entry is stored: this session, titled with the summary model's output.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| {
                agent.open_thread(session_id.clone(), project.clone(), cx)
            })
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // Ensure the draft prompt with rich content blocks survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
        });

        // Ensure token usage survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let usage = thread
                .token_usage()
                .expect("token usage should be restored after reload");
            assert_eq!(usage.input_tokens, 150);
            assert_eq!(usage.output_tokens, 75);
        });

        // Ensure scroll position survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let scroll = thread
                .ui_scroll_position()
                .expect("scroll position should be restored after reload");
            assert_eq!(scroll.item_ix, 5);
            assert_eq!(scroll.offset_in_item, gpui::px(12.5));
        });
    }
2835
2836    fn thread_entries(
2837        thread_store: &Entity<ThreadStore>,
2838        cx: &mut TestAppContext,
2839    ) -> Vec<(acp::SessionId, String)> {
2840        thread_store.read_with(cx, |store, _| {
2841            store
2842                .entries()
2843                .map(|entry| (entry.id.clone(), entry.title.to_string()))
2844                .collect::<Vec<_>>()
2845        })
2846    }
2847
2848    fn init_test(cx: &mut TestAppContext) {
2849        env_logger::try_init().ok();
2850        cx.update(|cx| {
2851            let settings_store = SettingsStore::test(cx);
2852            cx.set_global(settings_store);
2853
2854            LanguageModelRegistry::test(cx);
2855        });
2856    }
2857}
2858
2859fn mcp_message_content_to_acp_content_block(
2860    content: context_server::types::MessageContent,
2861) -> acp::ContentBlock {
2862    match content {
2863        context_server::types::MessageContent::Text {
2864            text,
2865            annotations: _,
2866        } => text.into(),
2867        context_server::types::MessageContent::Image {
2868            data,
2869            mime_type,
2870            annotations: _,
2871        } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2872        context_server::types::MessageContent::Audio {
2873            data,
2874            mime_type,
2875            annotations: _,
2876        } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2877        context_server::types::MessageContent::Resource {
2878            resource,
2879            annotations: _,
2880        } => {
2881            let mut link =
2882                acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2883            if let Some(mime_type) = resource.mime_type {
2884                link = link.mime_type(mime_type);
2885            }
2886            acp::ContentBlock::ResourceLink(link)
2887        }
2888    }
2889}