agent.rs

   1mod db;
   2mod edit_agent;
   3mod legacy_thread;
   4mod native_agent_server;
   5pub mod outline;
   6mod pattern_extraction;
   7mod templates;
   8#[cfg(test)]
   9mod tests;
  10mod thread;
  11mod thread_store;
  12mod tool_permissions;
  13mod tools;
  14
  15use context_server::ContextServerId;
  16pub use db::*;
  17use itertools::Itertools;
  18pub use native_agent_server::NativeAgentServer;
  19pub use pattern_extraction::*;
  20pub use shell_command_parser::extract_commands;
  21pub use templates::*;
  22pub use thread::*;
  23pub use thread_store::*;
  24pub use tool_permissions::*;
  25pub use tools::*;
  26
  27use acp_thread::{
  28    AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
  29    AgentSessionListResponse, TokenUsageRatio, UserMessageId,
  30};
  31use agent_client_protocol as acp;
  32use anyhow::{Context as _, Result, anyhow};
  33use chrono::{DateTime, Utc};
  34use collections::{HashMap, HashSet, IndexMap};
  35use fs::Fs;
  36use futures::channel::{mpsc, oneshot};
  37use futures::future::Shared;
  38use futures::{FutureExt as _, StreamExt as _, future};
  39use gpui::{
  40    App, AppContext, AsyncApp, Context, Entity, EntityId, SharedString, Subscription, Task,
  41    WeakEntity,
  42};
  43use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  44use project::{AgentId, Project, ProjectItem, ProjectPath, Worktree};
  45use prompt_store::{
  46    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  47    WorktreeContext,
  48};
  49use serde::{Deserialize, Serialize};
  50use settings::{LanguageModelSelection, update_settings_file};
  51use std::any::Any;
  52use std::path::PathBuf;
  53use std::rc::Rc;
  54use std::sync::{Arc, LazyLock};
  55use util::ResultExt;
  56use util::path_list::PathList;
  57use util::rel_path::RelPath;
  58
  59#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
  60pub struct ProjectSnapshot {
  61    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
  62    pub timestamp: DateTime<Utc>,
  63}
  64
  65pub struct RulesLoadingError {
  66    pub message: SharedString,
  67}
  68
  69struct ProjectState {
  70    project: Entity<Project>,
  71    project_context: Entity<ProjectContext>,
  72    project_context_needs_refresh: watch::Sender<()>,
  73    _maintain_project_context: Task<Result<()>>,
  74    context_server_registry: Entity<ContextServerRegistry>,
  75    _subscriptions: Vec<Subscription>,
  76}
  77
  78/// Holds both the internal Thread and the AcpThread for a session
  79struct Session {
  80    /// The internal thread that processes messages
  81    thread: Entity<Thread>,
  82    /// The ACP thread that handles protocol communication
  83    acp_thread: Entity<acp_thread::AcpThread>,
  84    project_id: EntityId,
  85    pending_save: Task<()>,
  86    _subscriptions: Vec<Subscription>,
  87}
  88
  89pub struct LanguageModels {
  90    /// Access language model by ID
  91    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
  92    /// Cached list for returning language model information
  93    model_list: acp_thread::AgentModelList,
  94    refresh_models_rx: watch::Receiver<()>,
  95    refresh_models_tx: watch::Sender<()>,
  96    _authenticate_all_providers_task: Task<()>,
  97}
  98
  99impl LanguageModels {
 100    fn new(cx: &mut App) -> Self {
 101        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
 102
 103        let mut this = Self {
 104            models: HashMap::default(),
 105            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
 106            refresh_models_rx,
 107            refresh_models_tx,
 108            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
 109        };
 110        this.refresh_list(cx);
 111        this
 112    }
 113
 114    fn refresh_list(&mut self, cx: &App) {
 115        let providers = LanguageModelRegistry::global(cx)
 116            .read(cx)
 117            .visible_providers()
 118            .into_iter()
 119            .filter(|provider| provider.is_authenticated(cx))
 120            .collect::<Vec<_>>();
 121
 122        let mut language_model_list = IndexMap::default();
 123        let mut recommended_models = HashSet::default();
 124
 125        let mut recommended = Vec::new();
 126        for provider in &providers {
 127            for model in provider.recommended_models(cx) {
 128                recommended_models.insert((model.provider_id(), model.id()));
 129                recommended.push(Self::map_language_model_to_info(&model, provider));
 130            }
 131        }
 132        if !recommended.is_empty() {
 133            language_model_list.insert(
 134                acp_thread::AgentModelGroupName("Recommended".into()),
 135                recommended,
 136            );
 137        }
 138
 139        let mut models = HashMap::default();
 140        for provider in providers {
 141            let mut provider_models = Vec::new();
 142            for model in provider.provided_models(cx) {
 143                let model_info = Self::map_language_model_to_info(&model, &provider);
 144                let model_id = model_info.id.clone();
 145                provider_models.push(model_info);
 146                models.insert(model_id, model);
 147            }
 148            if !provider_models.is_empty() {
 149                language_model_list.insert(
 150                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 151                    provider_models,
 152                );
 153            }
 154        }
 155
 156        self.models = models;
 157        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 158        self.refresh_models_tx.send(()).ok();
 159    }
 160
 161    fn watch(&self) -> watch::Receiver<()> {
 162        self.refresh_models_rx.clone()
 163    }
 164
 165    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 166        self.models.get(model_id).cloned()
 167    }
 168
 169    fn map_language_model_to_info(
 170        model: &Arc<dyn LanguageModel>,
 171        provider: &Arc<dyn LanguageModelProvider>,
 172    ) -> acp_thread::AgentModelInfo {
 173        acp_thread::AgentModelInfo {
 174            id: Self::model_id(model),
 175            name: model.name().0,
 176            description: None,
 177            icon: Some(match provider.icon() {
 178                IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
 179                IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
 180            }),
 181            is_latest: model.is_latest(),
 182            cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
 183        }
 184    }
 185
 186    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 187        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 188    }
 189
 190    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 191        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 192            .read(cx)
 193            .visible_providers()
 194            .iter()
 195            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 196            .collect::<Vec<_>>();
 197
 198        cx.background_spawn(async move {
 199            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 200                if let Err(err) = authenticate_task.await {
 201                    match err {
 202                        language_model::AuthenticateError::CredentialsNotFound => {
 203                            // Since we're authenticating these providers in the
 204                            // background for the purposes of populating the
 205                            // language selector, we don't care about providers
 206                            // where the credentials are not found.
 207                        }
 208                        language_model::AuthenticateError::ConnectionRefused => {
 209                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 210                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 211                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 212                        }
 213                        _ => {
 214                            // Some providers have noisy failure states that we
 215                            // don't want to spam the logs with every time the
 216                            // language model selector is initialized.
 217                            //
 218                            // Ideally these should have more clear failure modes
 219                            // that we know are safe to ignore here, like what we do
 220                            // with `CredentialsNotFound` above.
 221                            match provider_id.0.as_ref() {
 222                                "lmstudio" | "ollama" => {
 223                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 224                                    //
 225                                    // These fail noisily, so we don't log them.
 226                                }
 227                                "copilot_chat" => {
 228                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 229                                }
 230                                _ => {
 231                                    log::error!(
 232                                        "Failed to authenticate provider: {}: {err:#}",
 233                                        provider_name.0
 234                                    );
 235                                }
 236                            }
 237                        }
 238                    }
 239                }
 240            }
 241        })
 242    }
 243}
 244
 245pub struct NativeAgent {
 246    /// Session ID -> Session mapping
 247    sessions: HashMap<acp::SessionId, Session>,
 248    thread_store: Entity<ThreadStore>,
 249    /// Project-specific state keyed by project EntityId
 250    projects: HashMap<EntityId, ProjectState>,
 251    /// Shared templates for all threads
 252    templates: Arc<Templates>,
 253    /// Cached model information
 254    models: LanguageModels,
 255    prompt_store: Option<Entity<PromptStore>>,
 256    fs: Arc<dyn Fs>,
 257    _subscriptions: Vec<Subscription>,
 258}
 259
 260impl NativeAgent {
 261    pub fn new(
 262        thread_store: Entity<ThreadStore>,
 263        templates: Arc<Templates>,
 264        prompt_store: Option<Entity<PromptStore>>,
 265        fs: Arc<dyn Fs>,
 266        cx: &mut App,
 267    ) -> Entity<NativeAgent> {
 268        log::debug!("Creating new NativeAgent");
 269
 270        cx.new(|cx| {
 271            let mut subscriptions = vec![cx.subscribe(
 272                &LanguageModelRegistry::global(cx),
 273                Self::handle_models_updated_event,
 274            )];
 275            if let Some(prompt_store) = prompt_store.as_ref() {
 276                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 277            }
 278
 279            Self {
 280                sessions: HashMap::default(),
 281                thread_store,
 282                projects: HashMap::default(),
 283                templates,
 284                models: LanguageModels::new(cx),
 285                prompt_store,
 286                fs,
 287                _subscriptions: subscriptions,
 288            }
 289        })
 290    }
 291
 292    fn new_session(
 293        &mut self,
 294        project: Entity<Project>,
 295        cx: &mut Context<Self>,
 296    ) -> Entity<AcpThread> {
 297        let project_id = self.get_or_create_project_state(&project, cx);
 298        let project_state = &self.projects[&project_id];
 299
 300        let registry = LanguageModelRegistry::read_global(cx);
 301        let available_count = registry.available_models(cx).count();
 302        log::debug!("Total available models: {}", available_count);
 303
 304        let default_model = registry.default_model().and_then(|default_model| {
 305            self.models
 306                .model_from_id(&LanguageModels::model_id(&default_model.model))
 307        });
 308        let thread = cx.new(|cx| {
 309            Thread::new(
 310                project,
 311                project_state.project_context.clone(),
 312                project_state.context_server_registry.clone(),
 313                self.templates.clone(),
 314                default_model,
 315                cx,
 316            )
 317        });
 318
 319        self.register_session(thread, project_id, cx)
 320    }
 321
 322    fn register_session(
 323        &mut self,
 324        thread_handle: Entity<Thread>,
 325        project_id: EntityId,
 326        cx: &mut Context<Self>,
 327    ) -> Entity<AcpThread> {
 328        let connection = Rc::new(NativeAgentConnection(cx.entity()));
 329
 330        let thread = thread_handle.read(cx);
 331        let session_id = thread.id().clone();
 332        let parent_session_id = thread.parent_thread_id();
 333        let title = thread.title();
 334        let draft_prompt = thread.draft_prompt().map(Vec::from);
 335        let scroll_position = thread.ui_scroll_position();
 336        let token_usage = thread.latest_token_usage();
 337        let project = thread.project.clone();
 338        let action_log = thread.action_log.clone();
 339        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
 340        let acp_thread = cx.new(|cx| {
 341            let mut acp_thread = acp_thread::AcpThread::new(
 342                parent_session_id,
 343                title,
 344                None,
 345                connection,
 346                project.clone(),
 347                action_log.clone(),
 348                session_id.clone(),
 349                prompt_capabilities_rx,
 350                cx,
 351            );
 352            acp_thread.set_draft_prompt(draft_prompt);
 353            acp_thread.set_ui_scroll_position(scroll_position);
 354            acp_thread.update_token_usage(token_usage, cx);
 355            acp_thread
 356        });
 357
 358        let registry = LanguageModelRegistry::read_global(cx);
 359        let summarization_model = registry.thread_summary_model().map(|c| c.model);
 360
 361        let weak = cx.weak_entity();
 362        let weak_thread = thread_handle.downgrade();
 363        thread_handle.update(cx, |thread, cx| {
 364            thread.set_summarization_model(summarization_model, cx);
 365            thread.add_default_tools(
 366                Rc::new(NativeThreadEnvironment {
 367                    acp_thread: acp_thread.downgrade(),
 368                    thread: weak_thread,
 369                    agent: weak,
 370                }) as _,
 371                cx,
 372            )
 373        });
 374
 375        let subscriptions = vec![
 376            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
 377            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
 378            cx.observe(&thread_handle, move |this, thread, cx| {
 379                this.save_thread(thread, cx)
 380            }),
 381        ];
 382
 383        self.sessions.insert(
 384            session_id,
 385            Session {
 386                thread: thread_handle,
 387                acp_thread: acp_thread.clone(),
 388                project_id,
 389                _subscriptions: subscriptions,
 390                pending_save: Task::ready(()),
 391            },
 392        );
 393
 394        self.update_available_commands_for_project(project_id, cx);
 395
 396        acp_thread
 397    }
 398
 399    pub fn models(&self) -> &LanguageModels {
 400        &self.models
 401    }
 402
 403    fn get_or_create_project_state(
 404        &mut self,
 405        project: &Entity<Project>,
 406        cx: &mut Context<Self>,
 407    ) -> EntityId {
 408        let project_id = project.entity_id();
 409        if self.projects.contains_key(&project_id) {
 410            return project_id;
 411        }
 412
 413        let project_context = cx.new(|_| ProjectContext::new(vec![], vec![]));
 414        self.register_project_with_initial_context(project.clone(), project_context, cx);
 415        if let Some(state) = self.projects.get_mut(&project_id) {
 416            state.project_context_needs_refresh.send(()).ok();
 417        }
 418        project_id
 419    }
 420
 421    fn register_project_with_initial_context(
 422        &mut self,
 423        project: Entity<Project>,
 424        project_context: Entity<ProjectContext>,
 425        cx: &mut Context<Self>,
 426    ) {
 427        let project_id = project.entity_id();
 428
 429        let context_server_store = project.read(cx).context_server_store();
 430        let context_server_registry =
 431            cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
 432
 433        let subscriptions = vec![
 434            cx.subscribe(&project, Self::handle_project_event),
 435            cx.subscribe(
 436                &context_server_store,
 437                Self::handle_context_server_store_updated,
 438            ),
 439            cx.subscribe(
 440                &context_server_registry,
 441                Self::handle_context_server_registry_event,
 442            ),
 443        ];
 444
 445        let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 446            watch::channel(());
 447
 448        self.projects.insert(
 449            project_id,
 450            ProjectState {
 451                project,
 452                project_context,
 453                project_context_needs_refresh: project_context_needs_refresh_tx,
 454                _maintain_project_context: cx.spawn(async move |this, cx| {
 455                    Self::maintain_project_context(
 456                        this,
 457                        project_id,
 458                        project_context_needs_refresh_rx,
 459                        cx,
 460                    )
 461                    .await
 462                }),
 463                context_server_registry,
 464                _subscriptions: subscriptions,
 465            },
 466        );
 467    }
 468
 469    fn session_project_state(&self, session_id: &acp::SessionId) -> Option<&ProjectState> {
 470        self.sessions
 471            .get(session_id)
 472            .and_then(|session| self.projects.get(&session.project_id))
 473    }
 474
 475    async fn maintain_project_context(
 476        this: WeakEntity<Self>,
 477        project_id: EntityId,
 478        mut needs_refresh: watch::Receiver<()>,
 479        cx: &mut AsyncApp,
 480    ) -> Result<()> {
 481        while needs_refresh.changed().await.is_ok() {
 482            let project_context = this
 483                .update(cx, |this, cx| {
 484                    let state = this
 485                        .projects
 486                        .get(&project_id)
 487                        .context("project state not found")?;
 488                    anyhow::Ok(Self::build_project_context(
 489                        &state.project,
 490                        this.prompt_store.as_ref(),
 491                        cx,
 492                    ))
 493                })??
 494                .await;
 495            this.update(cx, |this, cx| {
 496                if let Some(state) = this.projects.get(&project_id) {
 497                    state
 498                        .project_context
 499                        .update(cx, |current_project_context, _cx| {
 500                            *current_project_context = project_context;
 501                        });
 502                }
 503            })?;
 504        }
 505
 506        Ok(())
 507    }
 508
 509    fn build_project_context(
 510        project: &Entity<Project>,
 511        prompt_store: Option<&Entity<PromptStore>>,
 512        cx: &mut App,
 513    ) -> Task<ProjectContext> {
 514        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 515        let worktree_tasks = worktrees
 516            .into_iter()
 517            .map(|worktree| {
 518                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 519            })
 520            .collect::<Vec<_>>();
 521        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 522            prompt_store.read_with(cx, |prompt_store, cx| {
 523                let prompts = prompt_store.default_prompt_metadata();
 524                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 525                    let contents = prompt_store.load(prompt_metadata.id, cx);
 526                    async move { (contents.await, prompt_metadata) }
 527                });
 528                cx.background_spawn(future::join_all(load_tasks))
 529            })
 530        } else {
 531            Task::ready(vec![])
 532        };
 533
 534        cx.spawn(async move |_cx| {
 535            let (worktrees, default_user_rules) =
 536                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 537
 538            let worktrees = worktrees
 539                .into_iter()
 540                .map(|(worktree, _rules_error)| {
 541                    // TODO: show error message
 542                    // if let Some(rules_error) = rules_error {
 543                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 544                    // }
 545                    worktree
 546                })
 547                .collect::<Vec<_>>();
 548
 549            let default_user_rules = default_user_rules
 550                .into_iter()
 551                .flat_map(|(contents, prompt_metadata)| match contents {
 552                    Ok(contents) => Some(UserRulesContext {
 553                        uuid: prompt_metadata.id.as_user()?,
 554                        title: prompt_metadata.title.map(|title| title.to_string()),
 555                        contents,
 556                    }),
 557                    Err(_err) => {
 558                        // TODO: show error message
 559                        // this.update(cx, |_, cx| {
 560                        //     cx.emit(RulesLoadingError {
 561                        //         message: format!("{err:?}").into(),
 562                        //     });
 563                        // })
 564                        // .ok();
 565                        None
 566                    }
 567                })
 568                .collect::<Vec<_>>();
 569
 570            ProjectContext::new(worktrees, default_user_rules)
 571        })
 572    }
 573
 574    fn load_worktree_info_for_system_prompt(
 575        worktree: Entity<Worktree>,
 576        project: Entity<Project>,
 577        cx: &mut App,
 578    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 579        let tree = worktree.read(cx);
 580        let root_name = tree.root_name_str().into();
 581        let abs_path = tree.abs_path();
 582
 583        let mut context = WorktreeContext {
 584            root_name,
 585            abs_path,
 586            rules_file: None,
 587        };
 588
 589        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 590        let Some(rules_task) = rules_task else {
 591            return Task::ready((context, None));
 592        };
 593
 594        cx.spawn(async move |_| {
 595            let (rules_file, rules_file_error) = match rules_task.await {
 596                Ok(rules_file) => (Some(rules_file), None),
 597                Err(err) => (
 598                    None,
 599                    Some(RulesLoadingError {
 600                        message: format!("{err}").into(),
 601                    }),
 602                ),
 603            };
 604            context.rules_file = rules_file;
 605            (context, rules_file_error)
 606        })
 607    }
 608
 609    fn load_worktree_rules_file(
 610        worktree: Entity<Worktree>,
 611        project: Entity<Project>,
 612        cx: &mut App,
 613    ) -> Option<Task<Result<RulesFileContext>>> {
 614        let worktree = worktree.read(cx);
 615        let worktree_id = worktree.id();
 616        let selected_rules_file = RULES_FILE_NAMES
 617            .into_iter()
 618            .filter_map(|name| {
 619                worktree
 620                    .entry_for_path(RelPath::unix(name).unwrap())
 621                    .filter(|entry| entry.is_file())
 622                    .map(|entry| entry.path.clone())
 623            })
 624            .next();
 625
 626        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 627        // supported. This doesn't seem to occur often in GitHub repositories.
 628        selected_rules_file.map(|path_in_worktree| {
 629            let project_path = ProjectPath {
 630                worktree_id,
 631                path: path_in_worktree.clone(),
 632            };
 633            let buffer_task =
 634                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 635            let rope_task = cx.spawn(async move |cx| {
 636                let buffer = buffer_task.await?;
 637                let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
 638                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 639                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 640                })?;
 641                anyhow::Ok((project_entry_id, rope))
 642            });
 643            // Build a string from the rope on a background thread.
 644            cx.background_spawn(async move {
 645                let (project_entry_id, rope) = rope_task.await?;
 646                anyhow::Ok(RulesFileContext {
 647                    path_in_worktree,
 648                    text: rope.to_string().trim().to_string(),
 649                    project_entry_id: project_entry_id.to_usize(),
 650                })
 651            })
 652        })
 653    }
 654
 655    fn handle_thread_title_updated(
 656        &mut self,
 657        thread: Entity<Thread>,
 658        _: &TitleUpdated,
 659        cx: &mut Context<Self>,
 660    ) {
 661        let session_id = thread.read(cx).id();
 662        let Some(session) = self.sessions.get(session_id) else {
 663            return;
 664        };
 665        let thread = thread.downgrade();
 666        let acp_thread = session.acp_thread.downgrade();
 667        cx.spawn(async move |_, cx| {
 668            let title = thread.read_with(cx, |thread, _| thread.title())?;
 669            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 670            task.await
 671        })
 672        .detach_and_log_err(cx);
 673    }
 674
 675    fn handle_thread_token_usage_updated(
 676        &mut self,
 677        thread: Entity<Thread>,
 678        usage: &TokenUsageUpdated,
 679        cx: &mut Context<Self>,
 680    ) {
 681        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 682            return;
 683        };
 684        session.acp_thread.update(cx, |acp_thread, cx| {
 685            acp_thread.update_token_usage(usage.0.clone(), cx);
 686        });
 687    }
 688
 689    fn handle_project_event(
 690        &mut self,
 691        project: Entity<Project>,
 692        event: &project::Event,
 693        _cx: &mut Context<Self>,
 694    ) {
 695        let project_id = project.entity_id();
 696        let Some(state) = self.projects.get_mut(&project_id) else {
 697            return;
 698        };
 699        match event {
 700            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 701                state.project_context_needs_refresh.send(()).ok();
 702            }
 703            project::Event::WorktreeUpdatedEntries(_, items) => {
 704                if items.iter().any(|(path, _, _)| {
 705                    RULES_FILE_NAMES
 706                        .iter()
 707                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 708                }) {
 709                    state.project_context_needs_refresh.send(()).ok();
 710                }
 711            }
 712            _ => {}
 713        }
 714    }
 715
 716    fn handle_prompts_updated_event(
 717        &mut self,
 718        _prompt_store: Entity<PromptStore>,
 719        _event: &prompt_store::PromptsUpdatedEvent,
 720        _cx: &mut Context<Self>,
 721    ) {
 722        for state in self.projects.values_mut() {
 723            state.project_context_needs_refresh.send(()).ok();
 724        }
 725    }
 726
 727    fn handle_models_updated_event(
 728        &mut self,
 729        _registry: Entity<LanguageModelRegistry>,
 730        _event: &language_model::Event,
 731        cx: &mut Context<Self>,
 732    ) {
 733        self.models.refresh_list(cx);
 734
 735        let registry = LanguageModelRegistry::read_global(cx);
 736        let default_model = registry.default_model().map(|m| m.model);
 737        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 738
 739        for session in self.sessions.values_mut() {
 740            session.thread.update(cx, |thread, cx| {
 741                if thread.model().is_none()
 742                    && let Some(model) = default_model.clone()
 743                {
 744                    thread.set_model(model, cx);
 745                    cx.notify();
 746                }
 747                thread.set_summarization_model(summarization_model.clone(), cx);
 748            });
 749        }
 750    }
 751
 752    fn handle_context_server_store_updated(
 753        &mut self,
 754        store: Entity<project::context_server_store::ContextServerStore>,
 755        _event: &project::context_server_store::ServerStatusChangedEvent,
 756        cx: &mut Context<Self>,
 757    ) {
 758        let project_id = self.projects.iter().find_map(|(id, state)| {
 759            if *state.context_server_registry.read(cx).server_store() == store {
 760                Some(*id)
 761            } else {
 762                None
 763            }
 764        });
 765        if let Some(project_id) = project_id {
 766            self.update_available_commands_for_project(project_id, cx);
 767        }
 768    }
 769
 770    fn handle_context_server_registry_event(
 771        &mut self,
 772        registry: Entity<ContextServerRegistry>,
 773        event: &ContextServerRegistryEvent,
 774        cx: &mut Context<Self>,
 775    ) {
 776        match event {
 777            ContextServerRegistryEvent::ToolsChanged => {}
 778            ContextServerRegistryEvent::PromptsChanged => {
 779                let project_id = self.projects.iter().find_map(|(id, state)| {
 780                    if state.context_server_registry == registry {
 781                        Some(*id)
 782                    } else {
 783                        None
 784                    }
 785                });
 786                if let Some(project_id) = project_id {
 787                    self.update_available_commands_for_project(project_id, cx);
 788                }
 789            }
 790        }
 791    }
 792
 793    fn update_available_commands_for_project(&self, project_id: EntityId, cx: &mut Context<Self>) {
 794        let available_commands =
 795            Self::build_available_commands_for_project(self.projects.get(&project_id), cx);
 796        for session in self.sessions.values() {
 797            if session.project_id != project_id {
 798                continue;
 799            }
 800            session.acp_thread.update(cx, |thread, cx| {
 801                thread
 802                    .handle_session_update(
 803                        acp::SessionUpdate::AvailableCommandsUpdate(
 804                            acp::AvailableCommandsUpdate::new(available_commands.clone()),
 805                        ),
 806                        cx,
 807                    )
 808                    .log_err();
 809            });
 810        }
 811    }
 812
 813    fn build_available_commands_for_project(
 814        project_state: Option<&ProjectState>,
 815        cx: &App,
 816    ) -> Vec<acp::AvailableCommand> {
 817        let Some(state) = project_state else {
 818            return vec![];
 819        };
 820        let registry = state.context_server_registry.read(cx);
 821
 822        let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
 823        for context_server_prompt in registry.prompts() {
 824            *prompt_name_counts
 825                .entry(context_server_prompt.prompt.name.as_str())
 826                .or_insert(0) += 1;
 827        }
 828
 829        registry
 830            .prompts()
 831            .flat_map(|context_server_prompt| {
 832                let prompt = &context_server_prompt.prompt;
 833
 834                let should_prefix = prompt_name_counts
 835                    .get(prompt.name.as_str())
 836                    .copied()
 837                    .unwrap_or(0)
 838                    > 1;
 839
 840                let name = if should_prefix {
 841                    format!("{}.{}", context_server_prompt.server_id, prompt.name)
 842                } else {
 843                    prompt.name.clone()
 844                };
 845
 846                let mut command = acp::AvailableCommand::new(
 847                    name,
 848                    prompt.description.clone().unwrap_or_default(),
 849                );
 850
 851                match prompt.arguments.as_deref() {
 852                    Some([arg]) => {
 853                        let hint = format!("<{}>", arg.name);
 854
 855                        command = command.input(acp::AvailableCommandInput::Unstructured(
 856                            acp::UnstructuredCommandInput::new(hint),
 857                        ));
 858                    }
 859                    Some([]) | None => {}
 860                    Some(_) => {
 861                        // skip >1 argument commands since we don't support them yet
 862                        return None;
 863                    }
 864                }
 865
 866                Some(command)
 867            })
 868            .collect()
 869    }
 870
 871    pub fn load_thread(
 872        &mut self,
 873        id: acp::SessionId,
 874        project: Entity<Project>,
 875        cx: &mut Context<Self>,
 876    ) -> Task<Result<Entity<Thread>>> {
 877        let database_future = ThreadsDatabase::connect(cx);
 878        cx.spawn(async move |this, cx| {
 879            let database = database_future.await.map_err(|err| anyhow!(err))?;
 880            let db_thread = database
 881                .load_thread(id.clone())
 882                .await?
 883                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 884
 885            this.update(cx, |this, cx| {
 886                let project_id = this.get_or_create_project_state(&project, cx);
 887                let project_state = this
 888                    .projects
 889                    .get(&project_id)
 890                    .context("project state not found")?;
 891                let summarization_model = LanguageModelRegistry::read_global(cx)
 892                    .thread_summary_model()
 893                    .map(|c| c.model);
 894
 895                Ok(cx.new(|cx| {
 896                    let mut thread = Thread::from_db(
 897                        id.clone(),
 898                        db_thread,
 899                        project_state.project.clone(),
 900                        project_state.project_context.clone(),
 901                        project_state.context_server_registry.clone(),
 902                        this.templates.clone(),
 903                        cx,
 904                    );
 905                    thread.set_summarization_model(summarization_model, cx);
 906                    thread
 907                }))
 908            })?
 909        })
 910    }
 911
 912    pub fn open_thread(
 913        &mut self,
 914        id: acp::SessionId,
 915        project: Entity<Project>,
 916        cx: &mut Context<Self>,
 917    ) -> Task<Result<Entity<AcpThread>>> {
 918        if let Some(session) = self.sessions.get(&id) {
 919            return Task::ready(Ok(session.acp_thread.clone()));
 920        }
 921
 922        let task = self.load_thread(id, project.clone(), cx);
 923        cx.spawn(async move |this, cx| {
 924            let thread = task.await?;
 925            let acp_thread = this.update(cx, |this, cx| {
 926                let project_id = this.get_or_create_project_state(&project, cx);
 927                this.register_session(thread.clone(), project_id, cx)
 928            })?;
 929            let events = thread.update(cx, |thread, cx| thread.replay(cx));
 930            cx.update(|cx| {
 931                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 932            })
 933            .await?;
 934            Ok(acp_thread)
 935        })
 936    }
 937
 938    pub fn thread_summary(
 939        &mut self,
 940        id: acp::SessionId,
 941        project: Entity<Project>,
 942        cx: &mut Context<Self>,
 943    ) -> Task<Result<SharedString>> {
 944        let thread = self.open_thread(id.clone(), project, cx);
 945        cx.spawn(async move |this, cx| {
 946            let acp_thread = thread.await?;
 947            let result = this
 948                .update(cx, |this, cx| {
 949                    this.sessions
 950                        .get(&id)
 951                        .unwrap()
 952                        .thread
 953                        .update(cx, |thread, cx| thread.summary(cx))
 954                })?
 955                .await
 956                .context("Failed to generate summary")?;
 957            drop(acp_thread);
 958            Ok(result)
 959        })
 960    }
 961
 962    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 963        if thread.read(cx).is_empty() {
 964            return;
 965        }
 966
 967        let id = thread.read(cx).id().clone();
 968        let Some(session) = self.sessions.get_mut(&id) else {
 969            return;
 970        };
 971
 972        let project_id = session.project_id;
 973        let Some(state) = self.projects.get(&project_id) else {
 974            return;
 975        };
 976
 977        let folder_paths = PathList::new(
 978            &state
 979                .project
 980                .read(cx)
 981                .visible_worktrees(cx)
 982                .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
 983                .collect::<Vec<_>>(),
 984        );
 985
 986        let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
 987        let database_future = ThreadsDatabase::connect(cx);
 988        let db_thread = thread.update(cx, |thread, cx| {
 989            thread.set_draft_prompt(draft_prompt);
 990            thread.to_db(cx)
 991        });
 992        let thread_store = self.thread_store.clone();
 993        session.pending_save = cx.spawn(async move |_, cx| {
 994            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 995                return;
 996            };
 997            let db_thread = db_thread.await;
 998            database
 999                .save_thread(id, db_thread, folder_paths)
1000                .await
1001                .log_err();
1002            thread_store.update(cx, |store, cx| store.reload(cx));
1003        });
1004    }
1005
1006    fn send_mcp_prompt(
1007        &self,
1008        message_id: UserMessageId,
1009        session_id: acp::SessionId,
1010        prompt_name: String,
1011        server_id: ContextServerId,
1012        arguments: HashMap<String, String>,
1013        original_content: Vec<acp::ContentBlock>,
1014        cx: &mut Context<Self>,
1015    ) -> Task<Result<acp::PromptResponse>> {
1016        let Some(state) = self.session_project_state(&session_id) else {
1017            return Task::ready(Err(anyhow!("Project state not found for session")));
1018        };
1019        let server_store = state
1020            .context_server_registry
1021            .read(cx)
1022            .server_store()
1023            .clone();
1024        let path_style = state.project.read(cx).path_style(cx);
1025
1026        cx.spawn(async move |this, cx| {
1027            let prompt =
1028                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
1029
1030            let (acp_thread, thread) = this.update(cx, |this, _cx| {
1031                let session = this
1032                    .sessions
1033                    .get(&session_id)
1034                    .context("Failed to get session")?;
1035                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
1036            })??;
1037
1038            let mut last_is_user = true;
1039
1040            thread.update(cx, |thread, cx| {
1041                thread.push_acp_user_block(
1042                    message_id,
1043                    original_content.into_iter().skip(1),
1044                    path_style,
1045                    cx,
1046                );
1047            });
1048
1049            for message in prompt.messages {
1050                let context_server::types::PromptMessage { role, content } = message;
1051                let block = mcp_message_content_to_acp_content_block(content);
1052
1053                match role {
1054                    context_server::types::Role::User => {
1055                        let id = acp_thread::UserMessageId::new();
1056
1057                        acp_thread.update(cx, |acp_thread, cx| {
1058                            acp_thread.push_user_content_block_with_indent(
1059                                Some(id.clone()),
1060                                block.clone(),
1061                                true,
1062                                cx,
1063                            );
1064                        });
1065
1066                        thread.update(cx, |thread, cx| {
1067                            thread.push_acp_user_block(id, [block], path_style, cx);
1068                        });
1069                    }
1070                    context_server::types::Role::Assistant => {
1071                        acp_thread.update(cx, |acp_thread, cx| {
1072                            acp_thread.push_assistant_content_block_with_indent(
1073                                block.clone(),
1074                                false,
1075                                true,
1076                                cx,
1077                            );
1078                        });
1079
1080                        thread.update(cx, |thread, cx| {
1081                            thread.push_acp_agent_block(block, cx);
1082                        });
1083                    }
1084                }
1085
1086                last_is_user = role == context_server::types::Role::User;
1087            }
1088
1089            let response_stream = thread.update(cx, |thread, cx| {
1090                if last_is_user {
1091                    thread.send_existing(cx)
1092                } else {
1093                    // Resume if MCP prompt did not end with a user message
1094                    thread.resume(cx)
1095                }
1096            })?;
1097
1098            cx.update(|cx| {
1099                NativeAgentConnection::handle_thread_events(
1100                    response_stream,
1101                    acp_thread.downgrade(),
1102                    cx,
1103                )
1104            })
1105            .await
1106        })
1107    }
1108}
1109
1110/// Wrapper struct that implements the AgentConnection trait
1111#[derive(Clone)]
1112pub struct NativeAgentConnection(pub Entity<NativeAgent>);
1113
1114impl NativeAgentConnection {
1115    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
1116        self.0
1117            .read(cx)
1118            .sessions
1119            .get(session_id)
1120            .map(|session| session.thread.clone())
1121    }
1122
1123    pub fn load_thread(
1124        &self,
1125        id: acp::SessionId,
1126        project: Entity<Project>,
1127        cx: &mut App,
1128    ) -> Task<Result<Entity<Thread>>> {
1129        self.0
1130            .update(cx, |this, cx| this.load_thread(id, project, cx))
1131    }
1132
1133    fn run_turn(
1134        &self,
1135        session_id: acp::SessionId,
1136        cx: &mut App,
1137        f: impl 'static
1138        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1139    ) -> Task<Result<acp::PromptResponse>> {
1140        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1141            agent
1142                .sessions
1143                .get_mut(&session_id)
1144                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1145        }) else {
1146            return Task::ready(Err(anyhow!("Session not found")));
1147        };
1148        log::debug!("Found session for: {}", session_id);
1149
1150        let response_stream = match f(thread, cx) {
1151            Ok(stream) => stream,
1152            Err(err) => return Task::ready(Err(err)),
1153        };
1154        Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1155    }
1156
1157    fn handle_thread_events(
1158        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1159        acp_thread: WeakEntity<AcpThread>,
1160        cx: &App,
1161    ) -> Task<Result<acp::PromptResponse>> {
1162        cx.spawn(async move |cx| {
1163            // Handle response stream and forward to session.acp_thread
1164            while let Some(result) = events.next().await {
1165                match result {
1166                    Ok(event) => {
1167                        log::trace!("Received completion event: {:?}", event);
1168
1169                        match event {
1170                            ThreadEvent::UserMessage(message) => {
1171                                acp_thread.update(cx, |thread, cx| {
1172                                    for content in message.content {
1173                                        thread.push_user_content_block(
1174                                            Some(message.id.clone()),
1175                                            content.into(),
1176                                            cx,
1177                                        );
1178                                    }
1179                                })?;
1180                            }
1181                            ThreadEvent::AgentText(text) => {
1182                                acp_thread.update(cx, |thread, cx| {
1183                                    thread.push_assistant_content_block(text.into(), false, cx)
1184                                })?;
1185                            }
1186                            ThreadEvent::AgentThinking(text) => {
1187                                acp_thread.update(cx, |thread, cx| {
1188                                    thread.push_assistant_content_block(text.into(), true, cx)
1189                                })?;
1190                            }
1191                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1192                                tool_call,
1193                                options,
1194                                response,
1195                                context: _,
1196                            }) => {
1197                                let outcome_task = acp_thread.update(cx, |thread, cx| {
1198                                    thread.request_tool_call_authorization(tool_call, options, cx)
1199                                })??;
1200                                cx.background_spawn(async move {
1201                                    if let acp::RequestPermissionOutcome::Selected(
1202                                        acp::SelectedPermissionOutcome { option_id, .. },
1203                                    ) = outcome_task.await
1204                                    {
1205                                        response
1206                                            .send(option_id)
1207                                            .map(|_| anyhow!("authorization receiver was dropped"))
1208                                            .log_err();
1209                                    }
1210                                })
1211                                .detach();
1212                            }
1213                            ThreadEvent::ToolCall(tool_call) => {
1214                                acp_thread.update(cx, |thread, cx| {
1215                                    thread.upsert_tool_call(tool_call, cx)
1216                                })??;
1217                            }
1218                            ThreadEvent::ToolCallUpdate(update) => {
1219                                acp_thread.update(cx, |thread, cx| {
1220                                    thread.update_tool_call(update, cx)
1221                                })??;
1222                            }
1223                            ThreadEvent::SubagentSpawned(session_id) => {
1224                                acp_thread.update(cx, |thread, cx| {
1225                                    thread.subagent_spawned(session_id, cx);
1226                                })?;
1227                            }
1228                            ThreadEvent::Retry(status) => {
1229                                acp_thread.update(cx, |thread, cx| {
1230                                    thread.update_retry_status(status, cx)
1231                                })?;
1232                            }
1233                            ThreadEvent::Stop(stop_reason) => {
1234                                log::debug!("Assistant message complete: {:?}", stop_reason);
1235                                return Ok(acp::PromptResponse::new(stop_reason));
1236                            }
1237                        }
1238                    }
1239                    Err(e) => {
1240                        log::error!("Error in model response stream: {:?}", e);
1241                        return Err(e);
1242                    }
1243                }
1244            }
1245
1246            log::debug!("Response stream completed");
1247            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1248        })
1249    }
1250}
1251
1252struct Command<'a> {
1253    prompt_name: &'a str,
1254    arg_value: &'a str,
1255    explicit_server_id: Option<&'a str>,
1256}
1257
1258impl<'a> Command<'a> {
1259    fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1260        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1261            return None;
1262        };
1263        let text = text_content.text.trim();
1264        let command = text.strip_prefix('/')?;
1265        let (command, arg_value) = command
1266            .split_once(char::is_whitespace)
1267            .unwrap_or((command, ""));
1268
1269        if let Some((server_id, prompt_name)) = command.split_once('.') {
1270            Some(Self {
1271                prompt_name,
1272                arg_value,
1273                explicit_server_id: Some(server_id),
1274            })
1275        } else {
1276            Some(Self {
1277                prompt_name: command,
1278                arg_value,
1279                explicit_server_id: None,
1280            })
1281        }
1282    }
1283}
1284
1285struct NativeAgentModelSelector {
1286    session_id: acp::SessionId,
1287    connection: NativeAgentConnection,
1288}
1289
1290impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1291    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1292        log::debug!("NativeAgentConnection::list_models called");
1293        let list = self.connection.0.read(cx).models.model_list.clone();
1294        Task::ready(if list.is_empty() {
1295            Err(anyhow::anyhow!("No models available"))
1296        } else {
1297            Ok(list)
1298        })
1299    }
1300
1301    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1302        log::debug!(
1303            "Setting model for session {}: {}",
1304            self.session_id,
1305            model_id
1306        );
1307        let Some(thread) = self
1308            .connection
1309            .0
1310            .read(cx)
1311            .sessions
1312            .get(&self.session_id)
1313            .map(|session| session.thread.clone())
1314        else {
1315            return Task::ready(Err(anyhow!("Session not found")));
1316        };
1317
1318        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1319            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1320        };
1321
1322        // We want to reset the effort level when switching models, as the currently-selected effort level may
1323        // not be compatible.
1324        let effort = model
1325            .default_effort_level()
1326            .map(|effort_level| effort_level.value.to_string());
1327
1328        thread.update(cx, |thread, cx| {
1329            thread.set_model(model.clone(), cx);
1330            thread.set_thinking_effort(effort.clone(), cx);
1331            thread.set_thinking_enabled(model.supports_thinking(), cx);
1332        });
1333
1334        update_settings_file(
1335            self.connection.0.read(cx).fs.clone(),
1336            cx,
1337            move |settings, cx| {
1338                let provider = model.provider_id().0.to_string();
1339                let model = model.id().0.to_string();
1340                let enable_thinking = thread.read(cx).thinking_enabled();
1341                settings
1342                    .agent
1343                    .get_or_insert_default()
1344                    .set_model(LanguageModelSelection {
1345                        provider: provider.into(),
1346                        model,
1347                        enable_thinking,
1348                        effort,
1349                    });
1350            },
1351        );
1352
1353        Task::ready(Ok(()))
1354    }
1355
1356    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1357        let Some(thread) = self
1358            .connection
1359            .0
1360            .read(cx)
1361            .sessions
1362            .get(&self.session_id)
1363            .map(|session| session.thread.clone())
1364        else {
1365            return Task::ready(Err(anyhow!("Session not found")));
1366        };
1367        let Some(model) = thread.read(cx).model() else {
1368            return Task::ready(Err(anyhow!("Model not found")));
1369        };
1370        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1371        else {
1372            return Task::ready(Err(anyhow!("Provider not found")));
1373        };
1374        Task::ready(Ok(LanguageModels::map_language_model_to_info(
1375            model, &provider,
1376        )))
1377    }
1378
1379    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1380        Some(self.connection.0.read(cx).models.watch())
1381    }
1382
1383    fn should_render_footer(&self) -> bool {
1384        true
1385    }
1386}
1387
1388pub static ZED_AGENT_ID: LazyLock<AgentId> = LazyLock::new(|| AgentId::new("Zed Agent"));
1389
1390impl acp_thread::AgentConnection for NativeAgentConnection {
1391    fn agent_id(&self) -> AgentId {
1392        ZED_AGENT_ID.clone()
1393    }
1394
1395    fn telemetry_id(&self) -> SharedString {
1396        "zed".into()
1397    }
1398
1399    fn new_session(
1400        self: Rc<Self>,
1401        project: Entity<Project>,
1402        work_dirs: PathList,
1403        cx: &mut App,
1404    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1405        log::debug!("Creating new thread for project at: {work_dirs:?}");
1406        Task::ready(Ok(self
1407            .0
1408            .update(cx, |agent, cx| agent.new_session(project, cx))))
1409    }
1410
1411    fn supports_load_session(&self) -> bool {
1412        true
1413    }
1414
1415    fn load_session(
1416        self: Rc<Self>,
1417        session_id: acp::SessionId,
1418        project: Entity<Project>,
1419        _work_dirs: PathList,
1420        _title: Option<SharedString>,
1421        cx: &mut App,
1422    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1423        self.0
1424            .update(cx, |agent, cx| agent.open_thread(session_id, project, cx))
1425    }
1426
1427    fn supports_close_session(&self) -> bool {
1428        true
1429    }
1430
1431    fn close_session(
1432        self: Rc<Self>,
1433        session_id: &acp::SessionId,
1434        cx: &mut App,
1435    ) -> Task<Result<()>> {
1436        self.0.update(cx, |agent, cx| {
1437            let Some(session) = agent.sessions.remove(session_id) else {
1438                return;
1439            };
1440            let project_id = session.project_id;
1441            agent.save_thread(session.thread, cx);
1442
1443            let has_remaining = agent.sessions.values().any(|s| s.project_id == project_id);
1444            if !has_remaining {
1445                agent.projects.remove(&project_id);
1446            }
1447        });
1448        Task::ready(Ok(()))
1449    }
1450
1451    fn auth_methods(&self) -> &[acp::AuthMethod] {
1452        &[] // No auth for in-process
1453    }
1454
1455    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1456        Task::ready(Ok(()))
1457    }
1458
1459    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1460        Some(Rc::new(NativeAgentModelSelector {
1461            session_id: session_id.clone(),
1462            connection: self.clone(),
1463        }) as Rc<dyn AgentModelSelector>)
1464    }
1465
1466    fn prompt(
1467        &self,
1468        id: Option<acp_thread::UserMessageId>,
1469        params: acp::PromptRequest,
1470        cx: &mut App,
1471    ) -> Task<Result<acp::PromptResponse>> {
1472        let id = id.expect("UserMessageId is required");
1473        let session_id = params.session_id.clone();
1474        log::info!("Received prompt request for session: {}", session_id);
1475        log::debug!("Prompt blocks count: {}", params.prompt.len());
1476
1477        let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
1478            return Task::ready(Err(anyhow::anyhow!("Session not found")));
1479        };
1480
1481        if let Some(parsed_command) = Command::parse(&params.prompt) {
1482            let registry = project_state.context_server_registry.read(cx);
1483
1484            let explicit_server_id = parsed_command
1485                .explicit_server_id
1486                .map(|server_id| ContextServerId(server_id.into()));
1487
1488            if let Some(prompt) =
1489                registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1490            {
1491                let arguments = if !parsed_command.arg_value.is_empty()
1492                    && let Some(arg_name) = prompt
1493                        .prompt
1494                        .arguments
1495                        .as_ref()
1496                        .and_then(|args| args.first())
1497                        .map(|arg| arg.name.clone())
1498                {
1499                    HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1500                } else {
1501                    Default::default()
1502                };
1503
1504                let prompt_name = prompt.prompt.name.clone();
1505                let server_id = prompt.server_id.clone();
1506
1507                return self.0.update(cx, |agent, cx| {
1508                    agent.send_mcp_prompt(
1509                        id,
1510                        session_id.clone(),
1511                        prompt_name,
1512                        server_id,
1513                        arguments,
1514                        params.prompt,
1515                        cx,
1516                    )
1517                });
1518            }
1519        };
1520
1521        let path_style = project_state.project.read(cx).path_style(cx);
1522
1523        self.run_turn(session_id, cx, move |thread, cx| {
1524            let content: Vec<UserMessageContent> = params
1525                .prompt
1526                .into_iter()
1527                .map(|block| UserMessageContent::from_content_block(block, path_style))
1528                .collect::<Vec<_>>();
1529            log::debug!("Converted prompt to message: {} chars", content.len());
1530            log::debug!("Message id: {:?}", id);
1531            log::debug!("Message content: {:?}", content);
1532
1533            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1534        })
1535    }
1536
1537    fn retry(
1538        &self,
1539        session_id: &acp::SessionId,
1540        _cx: &App,
1541    ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1542        Some(Rc::new(NativeAgentSessionRetry {
1543            connection: self.clone(),
1544            session_id: session_id.clone(),
1545        }) as _)
1546    }
1547
1548    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1549        log::info!("Cancelling on session: {}", session_id);
1550        self.0.update(cx, |agent, cx| {
1551            if let Some(session) = agent.sessions.get(session_id) {
1552                session
1553                    .thread
1554                    .update(cx, |thread, cx| thread.cancel(cx))
1555                    .detach();
1556            }
1557        });
1558    }
1559
1560    fn truncate(
1561        &self,
1562        session_id: &acp::SessionId,
1563        cx: &App,
1564    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1565        self.0.read_with(cx, |agent, _cx| {
1566            agent.sessions.get(session_id).map(|session| {
1567                Rc::new(NativeAgentSessionTruncate {
1568                    thread: session.thread.clone(),
1569                    acp_thread: session.acp_thread.downgrade(),
1570                }) as _
1571            })
1572        })
1573    }
1574
1575    fn set_title(
1576        &self,
1577        session_id: &acp::SessionId,
1578        cx: &App,
1579    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1580        self.0.read_with(cx, |agent, _cx| {
1581            agent
1582                .sessions
1583                .get(session_id)
1584                .filter(|s| !s.thread.read(cx).is_subagent())
1585                .map(|session| {
1586                    Rc::new(NativeAgentSessionSetTitle {
1587                        thread: session.thread.clone(),
1588                    }) as _
1589                })
1590        })
1591    }
1592
1593    fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1594        let thread_store = self.0.read(cx).thread_store.clone();
1595        Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1596    }
1597
1598    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1599        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1600    }
1601
1602    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1603        self
1604    }
1605}
1606
1607impl acp_thread::AgentTelemetry for NativeAgentConnection {
1608    fn thread_data(
1609        &self,
1610        session_id: &acp::SessionId,
1611        cx: &mut App,
1612    ) -> Task<Result<serde_json::Value>> {
1613        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1614            return Task::ready(Err(anyhow!("Session not found")));
1615        };
1616
1617        let task = session.thread.read(cx).to_db(cx);
1618        cx.background_spawn(async move {
1619            serde_json::to_value(task.await).context("Failed to serialize thread")
1620        })
1621    }
1622}
1623
/// Session list for the native agent, backed by a [`ThreadStore`] entity.
///
/// Owns both halves of an unbounded channel of [`acp_thread::SessionListUpdate`]s plus
/// the subscription that pushes a `Refresh` whenever the store is observed to change.
pub struct NativeAgentSessionList {
    /// Store whose entries are listed as sessions.
    thread_store: Entity<ThreadStore>,
    /// Sending half, used by `notify_refresh`.
    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
    /// Receiving half; clones of it are handed out by `watch`.
    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
    /// Held so the observer registered in `new` lives as long as this list.
    _subscription: Subscription,
}
1630
1631impl NativeAgentSessionList {
1632    fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1633        let (tx, rx) = smol::channel::unbounded();
1634        let this_tx = tx.clone();
1635        let subscription = cx.observe(&thread_store, move |_, _| {
1636            this_tx
1637                .try_send(acp_thread::SessionListUpdate::Refresh)
1638                .ok();
1639        });
1640        Self {
1641            thread_store,
1642            updates_tx: tx,
1643            updates_rx: rx,
1644            _subscription: subscription,
1645        }
1646    }
1647
1648    pub fn thread_store(&self) -> &Entity<ThreadStore> {
1649        &self.thread_store
1650    }
1651}
1652
1653impl AgentSessionList for NativeAgentSessionList {
1654    fn list_sessions(
1655        &self,
1656        _request: AgentSessionListRequest,
1657        cx: &mut App,
1658    ) -> Task<Result<AgentSessionListResponse>> {
1659        let sessions = self
1660            .thread_store
1661            .read(cx)
1662            .entries()
1663            .map(|entry| AgentSessionInfo::from(&entry))
1664            .collect();
1665        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1666    }
1667
1668    fn supports_delete(&self) -> bool {
1669        true
1670    }
1671
1672    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1673        self.thread_store
1674            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1675    }
1676
1677    fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1678        self.thread_store
1679            .update(cx, |store, cx| store.delete_threads(cx))
1680    }
1681
1682    fn watch(
1683        &self,
1684        _cx: &mut App,
1685    ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1686        Some(self.updates_rx.clone())
1687    }
1688
1689    fn notify_refresh(&self) {
1690        self.updates_tx
1691            .try_send(acp_thread::SessionListUpdate::Refresh)
1692            .ok();
1693    }
1694
1695    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1696        self
1697    }
1698}
1699
/// Truncation handle for a single native-agent session.
struct NativeAgentSessionTruncate {
    /// Native thread that gets truncated.
    thread: Entity<Thread>,
    /// ACP thread whose token usage is refreshed afterwards; held weakly.
    acp_thread: WeakEntity<AcpThread>,
}
1704
1705impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1706    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1707        match self.thread.update(cx, |thread, cx| {
1708            thread.truncate(message_id.clone(), cx)?;
1709            Ok(thread.latest_token_usage())
1710        }) {
1711            Ok(usage) => {
1712                self.acp_thread
1713                    .update(cx, |thread, cx| {
1714                        thread.update_token_usage(usage, cx);
1715                    })
1716                    .ok();
1717                Task::ready(Ok(()))
1718            }
1719            Err(error) => Task::ready(Err(error)),
1720        }
1721    }
1722}
1723
/// Retry handle for a single native-agent session.
struct NativeAgentSessionRetry {
    /// Connection used to run the retried turn.
    connection: NativeAgentConnection,
    /// Session whose thread is resumed.
    session_id: acp::SessionId,
}
1728
1729impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1730    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1731        self.connection
1732            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1733                thread.update(cx, |thread, cx| thread.resume(cx))
1734            })
1735    }
1736}
1737
/// Title-setting handle for a single native-agent session.
struct NativeAgentSessionSetTitle {
    /// Thread whose title is changed.
    thread: Entity<Thread>,
}
1741
1742impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1743    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1744        self.thread
1745            .update(cx, |thread, cx| thread.set_title(title, cx));
1746        Task::ready(Ok(()))
1747    }
1748}
1749
/// Environment through which a native thread creates terminals and subagents.
///
/// All three references are weak, so operations fail once the referenced entity is gone.
pub struct NativeThreadEnvironment {
    /// Agent that owns the session table.
    agent: WeakEntity<NativeAgent>,
    /// Thread this environment belongs to; it is the parent of any subagent created here.
    thread: WeakEntity<Thread>,
    /// ACP thread used to create and release terminals.
    acp_thread: WeakEntity<AcpThread>,
}
1755
1756impl NativeThreadEnvironment {
1757    pub(crate) fn create_subagent_thread(
1758        &self,
1759        label: String,
1760        cx: &mut App,
1761    ) -> Result<Rc<dyn SubagentHandle>> {
1762        let Some(parent_thread_entity) = self.thread.upgrade() else {
1763            anyhow::bail!("Parent thread no longer exists".to_string());
1764        };
1765        let parent_thread = parent_thread_entity.read(cx);
1766        let current_depth = parent_thread.depth();
1767        let parent_session_id = parent_thread.id().clone();
1768
1769        if current_depth >= MAX_SUBAGENT_DEPTH {
1770            return Err(anyhow!(
1771                "Maximum subagent depth ({}) reached",
1772                MAX_SUBAGENT_DEPTH
1773            ));
1774        }
1775
1776        let subagent_thread: Entity<Thread> = cx.new(|cx| {
1777            let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1778            thread.set_title(label.into(), cx);
1779            thread
1780        });
1781
1782        let session_id = subagent_thread.read(cx).id().clone();
1783
1784        let acp_thread = self
1785            .agent
1786            .update(cx, |agent, cx| -> Result<Entity<AcpThread>> {
1787                let project_id = agent
1788                    .sessions
1789                    .get(&parent_session_id)
1790                    .map(|s| s.project_id)
1791                    .context("parent session not found")?;
1792                Ok(agent.register_session(subagent_thread.clone(), project_id, cx))
1793            })??;
1794
1795        let depth = current_depth + 1;
1796
1797        telemetry::event!(
1798            "Subagent Started",
1799            session = parent_thread_entity.read(cx).id().to_string(),
1800            subagent_session = session_id.to_string(),
1801            depth,
1802            is_resumed = false,
1803        );
1804
1805        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1806    }
1807
1808    pub(crate) fn resume_subagent_thread(
1809        &self,
1810        session_id: acp::SessionId,
1811        cx: &mut App,
1812    ) -> Result<Rc<dyn SubagentHandle>> {
1813        let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1814            let session = agent
1815                .sessions
1816                .get(&session_id)
1817                .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1818            anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1819        })??;
1820
1821        let depth = subagent_thread.read(cx).depth();
1822
1823        if let Some(parent_thread_entity) = self.thread.upgrade() {
1824            telemetry::event!(
1825                "Subagent Started",
1826                session = parent_thread_entity.read(cx).id().to_string(),
1827                subagent_session = session_id.to_string(),
1828                depth,
1829                is_resumed = true,
1830            );
1831        }
1832
1833        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1834    }
1835
1836    fn prompt_subagent(
1837        &self,
1838        session_id: acp::SessionId,
1839        subagent_thread: Entity<Thread>,
1840        acp_thread: Entity<acp_thread::AcpThread>,
1841    ) -> Result<Rc<dyn SubagentHandle>> {
1842        let Some(parent_thread_entity) = self.thread.upgrade() else {
1843            anyhow::bail!("Parent thread no longer exists".to_string());
1844        };
1845        Ok(Rc::new(NativeSubagentHandle::new(
1846            session_id,
1847            subagent_thread,
1848            acp_thread,
1849            parent_thread_entity,
1850        )) as _)
1851    }
1852}
1853
1854impl ThreadEnvironment for NativeThreadEnvironment {
1855    fn create_terminal(
1856        &self,
1857        command: String,
1858        cwd: Option<PathBuf>,
1859        output_byte_limit: Option<u64>,
1860        cx: &mut AsyncApp,
1861    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1862        let task = self.acp_thread.update(cx, |thread, cx| {
1863            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1864        });
1865
1866        let acp_thread = self.acp_thread.clone();
1867        cx.spawn(async move |cx| {
1868            let terminal = task?.await?;
1869
1870            let (drop_tx, drop_rx) = oneshot::channel();
1871            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1872
1873            cx.spawn(async move |cx| {
1874                drop_rx.await.ok();
1875                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1876            })
1877            .detach();
1878
1879            let handle = AcpTerminalHandle {
1880                terminal,
1881                _drop_tx: Some(drop_tx),
1882            };
1883
1884            Ok(Rc::new(handle) as _)
1885        })
1886    }
1887
1888    fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1889        self.create_subagent_thread(label, cx)
1890    }
1891
1892    fn resume_subagent(
1893        &self,
1894        session_id: acp::SessionId,
1895        cx: &mut App,
1896    ) -> Result<Rc<dyn SubagentHandle>> {
1897        self.resume_subagent_thread(session_id, cx)
1898    }
1899}
1900
/// Outcome of a single subagent prompt, as classified by `NativeSubagentHandle::send`.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The prompt finished with a stop reason not treated as an error or cancellation.
    Completed,
    /// The prompt stopped with `acp::StopReason::Cancelled`.
    Cancelled,
    /// Token usage moved from `Normal` to `Warning` before the prompt finished.
    ContextWindowWarning,
    /// The prompt failed; the payload is the message reported to the caller.
    Error(String),
}
1908
/// Handle used to message a subagent session and collect its reply.
pub struct NativeSubagentHandle {
    /// Session id of the subagent.
    session_id: acp::SessionId,
    /// Parent thread, held weakly; it tracks which subagents are running.
    parent_thread: WeakEntity<Thread>,
    /// Native thread of the subagent.
    subagent_thread: Entity<Thread>,
    /// ACP thread through which messages are sent to the subagent.
    acp_thread: Entity<acp_thread::AcpThread>,
}
1915
1916impl NativeSubagentHandle {
1917    fn new(
1918        session_id: acp::SessionId,
1919        subagent_thread: Entity<Thread>,
1920        acp_thread: Entity<acp_thread::AcpThread>,
1921        parent_thread_entity: Entity<Thread>,
1922    ) -> Self {
1923        NativeSubagentHandle {
1924            session_id,
1925            subagent_thread,
1926            parent_thread: parent_thread_entity.downgrade(),
1927            acp_thread,
1928        }
1929    }
1930}
1931
1932impl SubagentHandle for NativeSubagentHandle {
1933    fn id(&self) -> acp::SessionId {
1934        self.session_id.clone()
1935    }
1936
1937    fn num_entries(&self, cx: &App) -> usize {
1938        self.acp_thread.read(cx).entries().len()
1939    }
1940
1941    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
1942        let thread = self.subagent_thread.clone();
1943        let acp_thread = self.acp_thread.clone();
1944        let subagent_session_id = self.session_id.clone();
1945        let parent_thread = self.parent_thread.clone();
1946
1947        cx.spawn(async move |cx| {
1948            let (task, _subscription) = cx.update(|cx| {
1949                let ratio_before_prompt = thread
1950                    .read(cx)
1951                    .latest_token_usage()
1952                    .map(|usage| usage.ratio());
1953
1954                parent_thread
1955                    .update(cx, |parent_thread, _cx| {
1956                        parent_thread.register_running_subagent(thread.downgrade())
1957                    })
1958                    .ok();
1959
1960                let task = acp_thread.update(cx, |acp_thread, cx| {
1961                    acp_thread.send(vec![message.into()], cx)
1962                });
1963
1964                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
1965                let mut token_limit_tx = Some(token_limit_tx);
1966
1967                let subscription = cx.subscribe(
1968                    &thread,
1969                    move |_thread, event: &TokenUsageUpdated, _cx| {
1970                        if let Some(usage) = &event.0 {
1971                            let old_ratio = ratio_before_prompt
1972                                .clone()
1973                                .unwrap_or(TokenUsageRatio::Normal);
1974                            let new_ratio = usage.ratio();
1975                            if old_ratio == TokenUsageRatio::Normal
1976                                && new_ratio == TokenUsageRatio::Warning
1977                            {
1978                                if let Some(tx) = token_limit_tx.take() {
1979                                    tx.send(()).ok();
1980                                }
1981                            }
1982                        }
1983                    },
1984                );
1985
1986                let wait_for_prompt = cx
1987                    .background_spawn(async move {
1988                        futures::select! {
1989                            response = task.fuse() => match response {
1990                                Ok(Some(response)) => {
1991                                    match response.stop_reason {
1992                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
1993                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
1994                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
1995                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
1996                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
1997                                    }
1998                                }
1999                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
2000                                Err(error) => SubagentPromptResult::Error(error.to_string()),
2001                            },
2002                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
2003                        }
2004                    });
2005
2006                (wait_for_prompt, subscription)
2007            });
2008
2009            let result = match task.await {
2010                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
2011                    thread
2012                        .last_message()
2013                        .and_then(|message| {
2014                            let content = message.as_agent_message()?
2015                                .content
2016                                .iter()
2017                                .filter_map(|c| match c {
2018                                    AgentMessageContent::Text(text) => Some(text.as_str()),
2019                                    _ => None,
2020                                })
2021                                .join("\n\n");
2022                            if content.is_empty() {
2023                                None
2024                            } else {
2025                                Some( content)
2026                            }
2027                        })
2028                        .context("No response from subagent")
2029                }),
2030                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
2031                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
2032                SubagentPromptResult::ContextWindowWarning => {
2033                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2034                    Err(anyhow!(
2035                        "The agent is nearing the end of its context window and has been \
2036                         stopped. You can prompt the thread again to have the agent wrap up \
2037                         or hand off its work."
2038                    ))
2039                }
2040            };
2041
2042            parent_thread
2043                .update(cx, |parent_thread, cx| {
2044                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
2045                })
2046                .ok();
2047
2048            result
2049        })
2050    }
2051}
2052
/// Terminal handle backed by an [`acp_thread::Terminal`] entity.
pub struct AcpTerminalHandle {
    /// The underlying terminal entity.
    terminal: Entity<acp_thread::Terminal>,
    /// Never sent on explicitly here; held so that dropping the handle closes the channel.
    _drop_tx: Option<oneshot::Sender<()>>,
}
2057
2058impl TerminalHandle for AcpTerminalHandle {
2059    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
2060        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
2061    }
2062
2063    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
2064        Ok(self
2065            .terminal
2066            .read_with(cx, |term, _cx| term.wait_for_exit()))
2067    }
2068
2069    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
2070        Ok(self
2071            .terminal
2072            .read_with(cx, |term, cx| term.current_output(cx)))
2073    }
2074
2075    fn kill(&self, cx: &AsyncApp) -> Result<()> {
2076        cx.update(|cx| {
2077            self.terminal.update(cx, |terminal, cx| {
2078                terminal.kill(cx);
2079            });
2080        });
2081        Ok(())
2082    }
2083
2084    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
2085        Ok(self
2086            .terminal
2087            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
2088    }
2089}
2090
2091#[cfg(test)]
2092mod internal_tests {
2093    use std::path::Path;
2094
2095    use super::*;
2096    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
2097    use fs::FakeFs;
2098    use gpui::TestAppContext;
2099    use indoc::formatdoc;
2100    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2101    use language_model::{
2102        LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
2103    };
2104    use serde_json::json;
2105    use settings::SettingsStore;
2106    use util::{path, rel_path::rel_path};
2107
2108    #[gpui::test]
2109    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
2110        init_test(cx);
2111        let fs = FakeFs::new(cx.executor());
2112        fs.insert_tree(
2113            "/",
2114            json!({
2115                "a": {}
2116            }),
2117        )
2118        .await;
2119        let project = Project::test(fs.clone(), [], cx).await;
2120        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2121        let agent =
2122            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2123
2124        // Creating a session registers the project and triggers context building.
2125        let connection = NativeAgentConnection(agent.clone());
2126        let _acp_thread = cx
2127            .update(|cx| {
2128                Rc::new(connection).new_session(
2129                    project.clone(),
2130                    PathList::new(&[Path::new("/")]),
2131                    cx,
2132                )
2133            })
2134            .await
2135            .unwrap();
2136        cx.run_until_parked();
2137
2138        let thread = agent.read_with(cx, |agent, _cx| {
2139            agent.sessions.values().next().unwrap().thread.clone()
2140        });
2141
2142        agent.read_with(cx, |agent, cx| {
2143            let project_id = project.entity_id();
2144            let state = agent.projects.get(&project_id).unwrap();
2145            assert_eq!(state.project_context.read(cx).worktrees, vec![]);
2146            assert_eq!(thread.read(cx).project_context().read(cx).worktrees, vec![]);
2147        });
2148
2149        let worktree = project
2150            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
2151            .await
2152            .unwrap();
2153        cx.run_until_parked();
2154        agent.read_with(cx, |agent, cx| {
2155            let project_id = project.entity_id();
2156            let state = agent.projects.get(&project_id).unwrap();
2157            let expected_worktrees = vec![WorktreeContext {
2158                root_name: "a".into(),
2159                abs_path: Path::new("/a").into(),
2160                rules_file: None,
2161            }];
2162            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
2163            assert_eq!(
2164                thread.read(cx).project_context().read(cx).worktrees,
2165                expected_worktrees
2166            );
2167        });
2168
2169        // Creating `/a/.rules` updates the project context.
2170        fs.insert_file("/a/.rules", Vec::new()).await;
2171        cx.run_until_parked();
2172        agent.read_with(cx, |agent, cx| {
2173            let project_id = project.entity_id();
2174            let state = agent.projects.get(&project_id).unwrap();
2175            let rules_entry = worktree
2176                .read(cx)
2177                .entry_for_path(rel_path(".rules"))
2178                .unwrap();
2179            let expected_worktrees = vec![WorktreeContext {
2180                root_name: "a".into(),
2181                abs_path: Path::new("/a").into(),
2182                rules_file: Some(RulesFileContext {
2183                    path_in_worktree: rel_path(".rules").into(),
2184                    text: "".into(),
2185                    project_entry_id: rules_entry.id.to_usize(),
2186                }),
2187            }];
2188            assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
2189            assert_eq!(
2190                thread.read(cx).project_context().read(cx).worktrees,
2191                expected_worktrees
2192            );
2193        });
2194    }
2195
2196    #[gpui::test]
2197    async fn test_listing_models(cx: &mut TestAppContext) {
2198        init_test(cx);
2199        let fs = FakeFs::new(cx.executor());
2200        fs.insert_tree("/", json!({ "a": {}  })).await;
2201        let project = Project::test(fs.clone(), [], cx).await;
2202        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2203        let connection =
2204            NativeAgentConnection(cx.update(|cx| {
2205                NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)
2206            }));
2207
2208        // Create a thread/session
2209        let acp_thread = cx
2210            .update(|cx| {
2211                Rc::new(connection.clone()).new_session(
2212                    project.clone(),
2213                    PathList::new(&[Path::new("/a")]),
2214                    cx,
2215                )
2216            })
2217            .await
2218            .unwrap();
2219
2220        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2221
2222        let models = cx
2223            .update(|cx| {
2224                connection
2225                    .model_selector(&session_id)
2226                    .unwrap()
2227                    .list_models(cx)
2228            })
2229            .await
2230            .unwrap();
2231
2232        let acp_thread::AgentModelList::Grouped(models) = models else {
2233            panic!("Unexpected model group");
2234        };
2235        assert_eq!(
2236            models,
2237            IndexMap::from_iter([(
2238                AgentModelGroupName("Fake".into()),
2239                vec![AgentModelInfo {
2240                    id: acp::ModelId::new("fake/fake"),
2241                    name: "Fake".into(),
2242                    description: None,
2243                    icon: Some(acp_thread::AgentModelIcon::Named(
2244                        ui::IconName::ZedAssistant
2245                    )),
2246                    is_latest: false,
2247                    cost: None,
2248                }]
2249            )])
2250        );
2251    }
2252
2253    #[gpui::test]
2254    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
2255        init_test(cx);
2256        let fs = FakeFs::new(cx.executor());
2257        fs.create_dir(paths::settings_file().parent().unwrap())
2258            .await
2259            .unwrap();
2260        fs.insert_file(
2261            paths::settings_file(),
2262            json!({
2263                "agent": {
2264                    "default_model": {
2265                        "provider": "foo",
2266                        "model": "bar"
2267                    }
2268                }
2269            })
2270            .to_string()
2271            .into_bytes(),
2272        )
2273        .await;
2274        let project = Project::test(fs.clone(), [], cx).await;
2275
2276        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2277
2278        // Create the agent and connection
2279        let agent =
2280            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2281        let connection = NativeAgentConnection(agent.clone());
2282
2283        // Create a thread/session
2284        let acp_thread = cx
2285            .update(|cx| {
2286                Rc::new(connection.clone()).new_session(
2287                    project.clone(),
2288                    PathList::new(&[Path::new("/a")]),
2289                    cx,
2290                )
2291            })
2292            .await
2293            .unwrap();
2294
2295        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2296
2297        // Select a model
2298        let selector = connection.model_selector(&session_id).unwrap();
2299        let model_id = acp::ModelId::new("fake/fake");
2300        cx.update(|cx| selector.select_model(model_id.clone(), cx))
2301            .await
2302            .unwrap();
2303
2304        // Verify the thread has the selected model
2305        agent.read_with(cx, |agent, _| {
2306            let session = agent.sessions.get(&session_id).unwrap();
2307            session.thread.read_with(cx, |thread, _| {
2308                assert_eq!(thread.model().unwrap().id().0, "fake");
2309            });
2310        });
2311
2312        cx.run_until_parked();
2313
2314        // Verify settings file was updated
2315        let settings_content = fs.load(paths::settings_file()).await.unwrap();
2316        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2317
2318        // Check that the agent settings contain the selected model
2319        assert_eq!(
2320            settings_json["agent"]["default_model"]["model"],
2321            json!("fake")
2322        );
2323        assert_eq!(
2324            settings_json["agent"]["default_model"]["provider"],
2325            json!("fake")
2326        );
2327
2328        // Register a thinking model and select it.
2329        cx.update(|cx| {
2330            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2331                "fake-corp",
2332                "fake-thinking",
2333                "Fake Thinking",
2334                true,
2335            ));
2336            let thinking_provider = Arc::new(
2337                FakeLanguageModelProvider::new(
2338                    LanguageModelProviderId::from("fake-corp".to_string()),
2339                    LanguageModelProviderName::from("Fake Corp".to_string()),
2340                )
2341                .with_models(vec![thinking_model]),
2342            );
2343            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2344                registry.register_provider(thinking_provider, cx);
2345            });
2346        });
2347        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2348
2349        let selector = connection.model_selector(&session_id).unwrap();
2350        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2351            .await
2352            .unwrap();
2353        cx.run_until_parked();
2354
2355        // Verify enable_thinking was written to settings as true.
2356        let settings_content = fs.load(paths::settings_file()).await.unwrap();
2357        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2358        assert_eq!(
2359            settings_json["agent"]["default_model"]["enable_thinking"],
2360            json!(true),
2361            "selecting a thinking model should persist enable_thinking: true to settings"
2362        );
2363    }
2364
2365    #[gpui::test]
2366    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2367        init_test(cx);
2368        let fs = FakeFs::new(cx.executor());
2369        fs.create_dir(paths::settings_file().parent().unwrap())
2370            .await
2371            .unwrap();
2372        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2373        let project = Project::test(fs.clone(), [], cx).await;
2374
2375        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2376        let agent =
2377            cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
2378        let connection = NativeAgentConnection(agent.clone());
2379
2380        let acp_thread = cx
2381            .update(|cx| {
2382                Rc::new(connection.clone()).new_session(
2383                    project.clone(),
2384                    PathList::new(&[Path::new("/a")]),
2385                    cx,
2386                )
2387            })
2388            .await
2389            .unwrap();
2390        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2391
2392        // Register a second provider with a thinking model.
2393        cx.update(|cx| {
2394            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2395                "fake-corp",
2396                "fake-thinking",
2397                "Fake Thinking",
2398                true,
2399            ));
2400            let thinking_provider = Arc::new(
2401                FakeLanguageModelProvider::new(
2402                    LanguageModelProviderId::from("fake-corp".to_string()),
2403                    LanguageModelProviderName::from("Fake Corp".to_string()),
2404                )
2405                .with_models(vec![thinking_model]),
2406            );
2407            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2408                registry.register_provider(thinking_provider, cx);
2409            });
2410        });
2411        // Refresh the agent's model list so it picks up the new provider.
2412        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2413
2414        // Thread starts with thinking_enabled = false (the default).
2415        agent.read_with(cx, |agent, _| {
2416            let session = agent.sessions.get(&session_id).unwrap();
2417            session.thread.read_with(cx, |thread, _| {
2418                assert!(!thread.thinking_enabled(), "thinking defaults to false");
2419            });
2420        });
2421
2422        // Select the thinking model via select_model.
2423        let selector = connection.model_selector(&session_id).unwrap();
2424        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2425            .await
2426            .unwrap();
2427
2428        // select_model should have enabled thinking based on the model's supports_thinking().
2429        agent.read_with(cx, |agent, _| {
2430            let session = agent.sessions.get(&session_id).unwrap();
2431            session.thread.read_with(cx, |thread, _| {
2432                assert!(
2433                    thread.thinking_enabled(),
2434                    "select_model should enable thinking when model supports it"
2435                );
2436            });
2437        });
2438
2439        // Switch back to the non-thinking model.
2440        let selector = connection.model_selector(&session_id).unwrap();
2441        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2442            .await
2443            .unwrap();
2444
2445        // select_model should have disabled thinking.
2446        agent.read_with(cx, |agent, _| {
2447            let session = agent.sessions.get(&session_id).unwrap();
2448            session.thread.read_with(cx, |thread, _| {
2449                assert!(
2450                    !thread.thinking_enabled(),
2451                    "select_model should disable thinking when model does not support it"
2452                );
2453            });
2454        });
2455    }
2456
2457    #[gpui::test]
2458    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
2459        init_test(cx);
2460        let fs = FakeFs::new(cx.executor());
2461        fs.insert_tree("/", json!({ "a": {} })).await;
2462        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2463        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2464        let agent = cx.update(|cx| {
2465            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
2466        });
2467        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2468
2469        // Register a thinking model.
2470        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2471            "fake-corp",
2472            "fake-thinking",
2473            "Fake Thinking",
2474            true,
2475        ));
2476        let thinking_provider = Arc::new(
2477            FakeLanguageModelProvider::new(
2478                LanguageModelProviderId::from("fake-corp".to_string()),
2479                LanguageModelProviderName::from("Fake Corp".to_string()),
2480            )
2481            .with_models(vec![thinking_model.clone()]),
2482        );
2483        cx.update(|cx| {
2484            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2485                registry.register_provider(thinking_provider, cx);
2486            });
2487        });
2488        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2489
2490        // Create a thread and select the thinking model.
2491        let acp_thread = cx
2492            .update(|cx| {
2493                connection.clone().new_session(
2494                    project.clone(),
2495                    PathList::new(&[Path::new("/a")]),
2496                    cx,
2497                )
2498            })
2499            .await
2500            .unwrap();
2501        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2502
2503        let selector = connection.model_selector(&session_id).unwrap();
2504        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2505            .await
2506            .unwrap();
2507
2508        // Verify thinking is enabled after selecting the thinking model.
2509        let thread = agent.read_with(cx, |agent, _| {
2510            agent.sessions.get(&session_id).unwrap().thread.clone()
2511        });
2512        thread.read_with(cx, |thread, _| {
2513            assert!(
2514                thread.thinking_enabled(),
2515                "thinking should be enabled after selecting thinking model"
2516            );
2517        });
2518
2519        // Send a message so the thread gets persisted.
2520        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2521        let send = cx.foreground_executor().spawn(send);
2522        cx.run_until_parked();
2523
2524        thinking_model.send_last_completion_stream_text_chunk("Response.");
2525        thinking_model.end_last_completion_stream();
2526
2527        send.await.unwrap();
2528        cx.run_until_parked();
2529
2530        // Close the session so it can be reloaded from disk.
2531        cx.update(|cx| connection.clone().close_session(&session_id, cx))
2532            .await
2533            .unwrap();
2534        drop(thread);
2535        drop(acp_thread);
2536        agent.read_with(cx, |agent, _| {
2537            assert!(agent.sessions.is_empty());
2538        });
2539
2540        // Reload the thread and verify thinking_enabled is still true.
2541        let reloaded_acp_thread = agent
2542            .update(cx, |agent, cx| {
2543                agent.open_thread(session_id.clone(), project.clone(), cx)
2544            })
2545            .await
2546            .unwrap();
2547        let reloaded_thread = agent.read_with(cx, |agent, _| {
2548            agent.sessions.get(&session_id).unwrap().thread.clone()
2549        });
2550        reloaded_thread.read_with(cx, |thread, _| {
2551            assert!(
2552                thread.thinking_enabled(),
2553                "thinking_enabled should be preserved when reloading a thread with a thinking model"
2554            );
2555        });
2556
2557        drop(reloaded_acp_thread);
2558    }
2559
2560    #[gpui::test]
2561    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
2562        init_test(cx);
2563        let fs = FakeFs::new(cx.executor());
2564        fs.insert_tree("/", json!({ "a": {} })).await;
2565        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2566        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2567        let agent = cx.update(|cx| {
2568            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
2569        });
2570        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2571
2572        // Register a model where id() != name(), like real Anthropic models
2573        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
2574        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2575            "fake-corp",
2576            "custom-model-id",
2577            "Custom Model Display Name",
2578            false,
2579        ));
2580        let provider = Arc::new(
2581            FakeLanguageModelProvider::new(
2582                LanguageModelProviderId::from("fake-corp".to_string()),
2583                LanguageModelProviderName::from("Fake Corp".to_string()),
2584            )
2585            .with_models(vec![model.clone()]),
2586        );
2587        cx.update(|cx| {
2588            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2589                registry.register_provider(provider, cx);
2590            });
2591        });
2592        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2593
2594        // Create a thread and select the model.
2595        let acp_thread = cx
2596            .update(|cx| {
2597                connection.clone().new_session(
2598                    project.clone(),
2599                    PathList::new(&[Path::new("/a")]),
2600                    cx,
2601                )
2602            })
2603            .await
2604            .unwrap();
2605        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2606
2607        let selector = connection.model_selector(&session_id).unwrap();
2608        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
2609            .await
2610            .unwrap();
2611
2612        let thread = agent.read_with(cx, |agent, _| {
2613            agent.sessions.get(&session_id).unwrap().thread.clone()
2614        });
2615        thread.read_with(cx, |thread, _| {
2616            assert_eq!(
2617                thread.model().unwrap().id().0.as_ref(),
2618                "custom-model-id",
2619                "model should be set before persisting"
2620            );
2621        });
2622
2623        // Send a message so the thread gets persisted.
2624        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2625        let send = cx.foreground_executor().spawn(send);
2626        cx.run_until_parked();
2627
2628        model.send_last_completion_stream_text_chunk("Response.");
2629        model.end_last_completion_stream();
2630
2631        send.await.unwrap();
2632        cx.run_until_parked();
2633
2634        // Close the session so it can be reloaded from disk.
2635        cx.update(|cx| connection.clone().close_session(&session_id, cx))
2636            .await
2637            .unwrap();
2638        drop(thread);
2639        drop(acp_thread);
2640        agent.read_with(cx, |agent, _| {
2641            assert!(agent.sessions.is_empty());
2642        });
2643
2644        // Reload the thread and verify the model was preserved.
2645        let reloaded_acp_thread = agent
2646            .update(cx, |agent, cx| {
2647                agent.open_thread(session_id.clone(), project.clone(), cx)
2648            })
2649            .await
2650            .unwrap();
2651        let reloaded_thread = agent.read_with(cx, |agent, _| {
2652            agent.sessions.get(&session_id).unwrap().thread.clone()
2653        });
2654        reloaded_thread.read_with(cx, |thread, _| {
2655            let reloaded_model = thread
2656                .model()
2657                .expect("model should be present after reload");
2658            assert_eq!(
2659                reloaded_model.id().0.as_ref(),
2660                "custom-model-id",
2661                "reloaded thread should have the same model, not fall back to the default"
2662            );
2663        });
2664
2665        drop(reloaded_acp_thread);
2666    }
2667
2668    #[gpui::test]
2669    async fn test_save_load_thread(cx: &mut TestAppContext) {
2670        init_test(cx);
2671        let fs = FakeFs::new(cx.executor());
2672        fs.insert_tree(
2673            "/",
2674            json!({
2675                "a": {
2676                    "b.md": "Lorem"
2677                }
2678            }),
2679        )
2680        .await;
2681        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2682        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2683        let agent = cx.update(|cx| {
2684            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
2685        });
2686        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2687
2688        let acp_thread = cx
2689            .update(|cx| {
2690                connection
2691                    .clone()
2692                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
2693            })
2694            .await
2695            .unwrap();
2696        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2697        let thread = agent.read_with(cx, |agent, _| {
2698            agent.sessions.get(&session_id).unwrap().thread.clone()
2699        });
2700
2701        // Ensure empty threads are not saved, even if they get mutated.
2702        let model = Arc::new(FakeLanguageModel::default());
2703        let summary_model = Arc::new(FakeLanguageModel::default());
2704        thread.update(cx, |thread, cx| {
2705            thread.set_model(model.clone(), cx);
2706            thread.set_summarization_model(Some(summary_model.clone()), cx);
2707        });
2708        cx.run_until_parked();
2709        assert_eq!(thread_entries(&thread_store, cx), vec![]);
2710
2711        let send = acp_thread.update(cx, |thread, cx| {
2712            thread.send(
2713                vec![
2714                    "What does ".into(),
2715                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
2716                        "b.md",
2717                        MentionUri::File {
2718                            abs_path: path!("/a/b.md").into(),
2719                        }
2720                        .to_uri()
2721                        .to_string(),
2722                    )),
2723                    " mean?".into(),
2724                ],
2725                cx,
2726            )
2727        });
2728        let send = cx.foreground_executor().spawn(send);
2729        cx.run_until_parked();
2730
2731        model.send_last_completion_stream_text_chunk("Lorem.");
2732        model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2733            language_model::TokenUsage {
2734                input_tokens: 150,
2735                output_tokens: 75,
2736                ..Default::default()
2737            },
2738        ));
2739        model.end_last_completion_stream();
2740        cx.run_until_parked();
2741        summary_model
2742            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
2743        summary_model.end_last_completion_stream();
2744
2745        send.await.unwrap();
2746        let uri = MentionUri::File {
2747            abs_path: path!("/a/b.md").into(),
2748        }
2749        .to_uri();
2750        acp_thread.read_with(cx, |thread, cx| {
2751            assert_eq!(
2752                thread.to_markdown(cx),
2753                formatdoc! {"
2754                    ## User
2755
2756                    What does [@b.md]({uri}) mean?
2757
2758                    ## Assistant
2759
2760                    Lorem.
2761
2762                "}
2763            )
2764        });
2765
2766        cx.run_until_parked();
2767
2768        // Set a draft prompt with rich content blocks before saving.
2769        let draft_blocks = vec![
2770            acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
2771            acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
2772            acp::ContentBlock::Text(acp::TextContent::new(" please")),
2773        ];
2774        acp_thread.update(cx, |thread, _cx| {
2775            thread.set_draft_prompt(Some(draft_blocks.clone()));
2776        });
2777        thread.update(cx, |thread, _cx| {
2778            thread.set_ui_scroll_position(Some(gpui::ListOffset {
2779                item_ix: 5,
2780                offset_in_item: gpui::px(12.5),
2781            }));
2782        });
2783        thread.update(cx, |_thread, cx| cx.notify());
2784        cx.run_until_parked();
2785
2786        // Close the session so it can be reloaded from disk.
2787        cx.update(|cx| connection.clone().close_session(&session_id, cx))
2788            .await
2789            .unwrap();
2790        drop(thread);
2791        drop(acp_thread);
2792        agent.read_with(cx, |agent, _| {
2793            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
2794        });
2795
2796        // Ensure the thread can be reloaded from disk.
2797        assert_eq!(
2798            thread_entries(&thread_store, cx),
2799            vec![(
2800                session_id.clone(),
2801                format!("Explaining {}", path!("/a/b.md"))
2802            )]
2803        );
2804        let acp_thread = agent
2805            .update(cx, |agent, cx| {
2806                agent.open_thread(session_id.clone(), project.clone(), cx)
2807            })
2808            .await
2809            .unwrap();
2810        acp_thread.read_with(cx, |thread, cx| {
2811            assert_eq!(
2812                thread.to_markdown(cx),
2813                formatdoc! {"
2814                    ## User
2815
2816                    What does [@b.md]({uri}) mean?
2817
2818                    ## Assistant
2819
2820                    Lorem.
2821
2822                "}
2823            )
2824        });
2825
2826        // Ensure the draft prompt with rich content blocks survived the round-trip.
2827        acp_thread.read_with(cx, |thread, _| {
2828            assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
2829        });
2830
2831        // Ensure token usage survived the round-trip.
2832        acp_thread.read_with(cx, |thread, _| {
2833            let usage = thread
2834                .token_usage()
2835                .expect("token usage should be restored after reload");
2836            assert_eq!(usage.input_tokens, 150);
2837            assert_eq!(usage.output_tokens, 75);
2838        });
2839
2840        // Ensure scroll position survived the round-trip.
2841        acp_thread.read_with(cx, |thread, _| {
2842            let scroll = thread
2843                .ui_scroll_position()
2844                .expect("scroll position should be restored after reload");
2845            assert_eq!(scroll.item_ix, 5);
2846            assert_eq!(scroll.offset_in_item, gpui::px(12.5));
2847        });
2848    }
2849
2850    fn thread_entries(
2851        thread_store: &Entity<ThreadStore>,
2852        cx: &mut TestAppContext,
2853    ) -> Vec<(acp::SessionId, String)> {
2854        thread_store.read_with(cx, |store, _| {
2855            store
2856                .entries()
2857                .map(|entry| (entry.id.clone(), entry.title.to_string()))
2858                .collect::<Vec<_>>()
2859        })
2860    }
2861
2862    fn init_test(cx: &mut TestAppContext) {
2863        env_logger::try_init().ok();
2864        cx.update(|cx| {
2865            let settings_store = SettingsStore::test(cx);
2866            cx.set_global(settings_store);
2867
2868            LanguageModelRegistry::test(cx);
2869        });
2870    }
2871}
2872
2873fn mcp_message_content_to_acp_content_block(
2874    content: context_server::types::MessageContent,
2875) -> acp::ContentBlock {
2876    match content {
2877        context_server::types::MessageContent::Text {
2878            text,
2879            annotations: _,
2880        } => text.into(),
2881        context_server::types::MessageContent::Image {
2882            data,
2883            mime_type,
2884            annotations: _,
2885        } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2886        context_server::types::MessageContent::Audio {
2887            data,
2888            mime_type,
2889            annotations: _,
2890        } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2891        context_server::types::MessageContent::Resource {
2892            resource,
2893            annotations: _,
2894        } => {
2895            let mut link =
2896                acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2897            if let Some(mime_type) = resource.mime_type {
2898                link = link.mime_type(mime_type);
2899            }
2900            acp::ContentBlock::ResourceLink(link)
2901        }
2902    }
2903}