agent.rs

   1mod db;
   2mod edit_agent;
   3mod legacy_thread;
   4mod native_agent_server;
   5pub mod outline;
   6mod pattern_extraction;
   7mod templates;
   8#[cfg(test)]
   9mod tests;
  10mod thread;
  11mod thread_store;
  12mod tool_permissions;
  13mod tools;
  14
  15use context_server::ContextServerId;
  16pub use db::*;
  17use itertools::Itertools;
  18pub use native_agent_server::NativeAgentServer;
  19pub use pattern_extraction::*;
  20pub use shell_command_parser::extract_commands;
  21pub use templates::*;
  22pub use thread::*;
  23pub use thread_store::*;
  24pub use tool_permissions::*;
  25pub use tools::*;
  26
  27use acp_thread::{
  28    AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
  29    AgentSessionListResponse, TokenUsageRatio, UserMessageId,
  30};
  31use agent_client_protocol as acp;
  32use anyhow::{Context as _, Result, anyhow};
  33use chrono::{DateTime, Utc};
  34use collections::{HashMap, HashSet, IndexMap};
  35use fs::Fs;
  36use futures::channel::{mpsc, oneshot};
  37use futures::future::Shared;
  38use futures::{FutureExt as _, StreamExt as _, future};
  39use gpui::{
  40    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  41};
  42use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  43use project::{Project, ProjectItem, ProjectPath, Worktree};
  44use prompt_store::{
  45    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  46    WorktreeContext,
  47};
  48use serde::{Deserialize, Serialize};
  49use settings::{LanguageModelSelection, update_settings_file};
  50use std::any::Any;
  51use std::path::{Path, PathBuf};
  52use std::rc::Rc;
  53use std::sync::Arc;
  54use util::ResultExt;
  55use util::path_list::PathList;
  56use util::rel_path::RelPath;
  57
/// Point-in-time description of a project's worktrees.
///
/// Holds one telemetry snapshot per worktree plus the moment the snapshot was
/// taken. Because this type derives `Serialize`/`Deserialize`, the field
/// declaration order is also the serialized field order.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One telemetry snapshot per worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// When the snapshot was taken, in UTC.
    pub timestamp: DateTime<Utc>,
}
  63
/// Error describing why a worktree's rules file could not be loaded.
///
/// Constructed by `NativeAgent::load_worktree_info_for_system_prompt` when
/// reading the rules file fails.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the underlying error's
    /// `Display` output).
    pub message: SharedString,
}
  67
/// Holds both the internal Thread and the AcpThread for a session.
///
/// Rust drops fields in declaration order, so the subscriptions below are
/// released last.
struct Session {
    /// The internal thread that processes messages.
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Save task for this session; starts as an already-completed task.
    /// NOTE(review): presumably replaced by `save_thread` (not shown here) on
    /// each save — confirm.
    pending_save: Task<()>,
    /// Title/token-usage subscriptions and the save-on-notify observation on
    /// `thread`, kept alive for as long as the session exists.
    _subscriptions: Vec<Subscription>,
}
  77
/// Catalogue of the language models offered by authenticated providers.
///
/// Rebuilt by `refresh_list`; every receiver obtained from `watch` is
/// signalled after each rebuild.
pub struct LanguageModels {
    /// Access language model by ID (`"{provider_id}/{model_id}"`).
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information, grouped by
    /// "Recommended" and then by provider name.
    model_list: acp_thread::AgentModelList,
    /// Receiver handed out (cloned) by `watch`.
    refresh_models_rx: watch::Receiver<()>,
    /// Signalled at the end of every `refresh_list`.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating all visible providers; held so it stays
    /// alive as long as this catalogue does.
    _authenticate_all_providers_task: Task<()>,
}
  87
  88impl LanguageModels {
  89    fn new(cx: &mut App) -> Self {
  90        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  91
  92        let mut this = Self {
  93            models: HashMap::default(),
  94            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  95            refresh_models_rx,
  96            refresh_models_tx,
  97            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
  98        };
  99        this.refresh_list(cx);
 100        this
 101    }
 102
 103    fn refresh_list(&mut self, cx: &App) {
 104        let providers = LanguageModelRegistry::global(cx)
 105            .read(cx)
 106            .visible_providers()
 107            .into_iter()
 108            .filter(|provider| provider.is_authenticated(cx))
 109            .collect::<Vec<_>>();
 110
 111        let mut language_model_list = IndexMap::default();
 112        let mut recommended_models = HashSet::default();
 113
 114        let mut recommended = Vec::new();
 115        for provider in &providers {
 116            for model in provider.recommended_models(cx) {
 117                recommended_models.insert((model.provider_id(), model.id()));
 118                recommended.push(Self::map_language_model_to_info(&model, provider));
 119            }
 120        }
 121        if !recommended.is_empty() {
 122            language_model_list.insert(
 123                acp_thread::AgentModelGroupName("Recommended".into()),
 124                recommended,
 125            );
 126        }
 127
 128        let mut models = HashMap::default();
 129        for provider in providers {
 130            let mut provider_models = Vec::new();
 131            for model in provider.provided_models(cx) {
 132                let model_info = Self::map_language_model_to_info(&model, &provider);
 133                let model_id = model_info.id.clone();
 134                provider_models.push(model_info);
 135                models.insert(model_id, model);
 136            }
 137            if !provider_models.is_empty() {
 138                language_model_list.insert(
 139                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 140                    provider_models,
 141                );
 142            }
 143        }
 144
 145        self.models = models;
 146        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 147        self.refresh_models_tx.send(()).ok();
 148    }
 149
 150    fn watch(&self) -> watch::Receiver<()> {
 151        self.refresh_models_rx.clone()
 152    }
 153
 154    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 155        self.models.get(model_id).cloned()
 156    }
 157
 158    fn map_language_model_to_info(
 159        model: &Arc<dyn LanguageModel>,
 160        provider: &Arc<dyn LanguageModelProvider>,
 161    ) -> acp_thread::AgentModelInfo {
 162        acp_thread::AgentModelInfo {
 163            id: Self::model_id(model),
 164            name: model.name().0,
 165            description: None,
 166            icon: Some(match provider.icon() {
 167                IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
 168                IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
 169            }),
 170            is_latest: model.is_latest(),
 171            cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
 172        }
 173    }
 174
 175    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 176        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 177    }
 178
 179    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 180        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 181            .read(cx)
 182            .visible_providers()
 183            .iter()
 184            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 185            .collect::<Vec<_>>();
 186
 187        cx.background_spawn(async move {
 188            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 189                if let Err(err) = authenticate_task.await {
 190                    match err {
 191                        language_model::AuthenticateError::CredentialsNotFound => {
 192                            // Since we're authenticating these providers in the
 193                            // background for the purposes of populating the
 194                            // language selector, we don't care about providers
 195                            // where the credentials are not found.
 196                        }
 197                        language_model::AuthenticateError::ConnectionRefused => {
 198                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 199                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 200                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 201                        }
 202                        _ => {
 203                            // Some providers have noisy failure states that we
 204                            // don't want to spam the logs with every time the
 205                            // language model selector is initialized.
 206                            //
 207                            // Ideally these should have more clear failure modes
 208                            // that we know are safe to ignore here, like what we do
 209                            // with `CredentialsNotFound` above.
 210                            match provider_id.0.as_ref() {
 211                                "lmstudio" | "ollama" => {
 212                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 213                                    //
 214                                    // These fail noisily, so we don't log them.
 215                                }
 216                                "copilot_chat" => {
 217                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 218                                }
 219                                _ => {
 220                                    log::error!(
 221                                        "Failed to authenticate provider: {}: {err:#}",
 222                                        provider_name.0
 223                                    );
 224                                }
 225                            }
 226                        }
 227                    }
 228                }
 229            }
 230        })
 231    }
 232}
 233
/// The built-in agent.
///
/// Owns every live session and the state that all of its threads are built
/// from: project context, context-server registry, templates and the model
/// catalogue. Rust drops fields in declaration order, so the subscriptions
/// below are released last.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread store handle.
    /// NOTE(review): not used in the visible portion of this file.
    thread_store: Entity<ThreadStore>,
    /// Project context handed to every thread this agent creates or loads.
    project_context: Entity<ProjectContext>,
    /// Signalled when the project context must be rebuilt:
    /// - a worktree is added or removed,
    /// - a rules file changes,
    /// - the prompt store updates.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background loop (`maintain_project_context`) that rebuilds the project
    /// context each time a refresh is signalled.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context servers, passed to every thread.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent operates on.
    project: Entity<Project>,
    /// Optional prompt store whose default prompts become user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle.
    /// NOTE(review): not used in the visible portion of this file.
    fs: Arc<dyn Fs>,
    /// Event subscriptions registered in `new`, kept alive for the agent's lifetime.
    _subscriptions: Vec<Subscription>,
}
 252
 253impl NativeAgent {
 254    pub async fn new(
 255        project: Entity<Project>,
 256        thread_store: Entity<ThreadStore>,
 257        templates: Arc<Templates>,
 258        prompt_store: Option<Entity<PromptStore>>,
 259        fs: Arc<dyn Fs>,
 260        cx: &mut AsyncApp,
 261    ) -> Result<Entity<NativeAgent>> {
 262        log::debug!("Creating new NativeAgent");
 263
 264        let project_context = cx
 265            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
 266            .await;
 267
 268        Ok(cx.new(|cx| {
 269            let context_server_store = project.read(cx).context_server_store();
 270            let context_server_registry =
 271                cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
 272
 273            let mut subscriptions = vec![
 274                cx.subscribe(&project, Self::handle_project_event),
 275                cx.subscribe(
 276                    &LanguageModelRegistry::global(cx),
 277                    Self::handle_models_updated_event,
 278                ),
 279                cx.subscribe(
 280                    &context_server_store,
 281                    Self::handle_context_server_store_updated,
 282                ),
 283                cx.subscribe(
 284                    &context_server_registry,
 285                    Self::handle_context_server_registry_event,
 286                ),
 287            ];
 288            if let Some(prompt_store) = prompt_store.as_ref() {
 289                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 290            }
 291
 292            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 293                watch::channel(());
 294            Self {
 295                sessions: HashMap::default(),
 296                thread_store,
 297                project_context: cx.new(|_| project_context),
 298                project_context_needs_refresh: project_context_needs_refresh_tx,
 299                _maintain_project_context: cx.spawn(async move |this, cx| {
 300                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 301                }),
 302                context_server_registry,
 303                templates,
 304                models: LanguageModels::new(cx),
 305                project,
 306                prompt_store,
 307                fs,
 308                _subscriptions: subscriptions,
 309            }
 310        }))
 311    }
 312
 313    fn new_session(
 314        &mut self,
 315        project: Entity<Project>,
 316        cx: &mut Context<Self>,
 317    ) -> Entity<AcpThread> {
 318        // Create Thread
 319        // Fetch default model from registry settings
 320        let registry = LanguageModelRegistry::read_global(cx);
 321        // Log available models for debugging
 322        let available_count = registry.available_models(cx).count();
 323        log::debug!("Total available models: {}", available_count);
 324
 325        let default_model = registry.default_model().and_then(|default_model| {
 326            self.models
 327                .model_from_id(&LanguageModels::model_id(&default_model.model))
 328        });
 329        let thread = cx.new(|cx| {
 330            Thread::new(
 331                project.clone(),
 332                self.project_context.clone(),
 333                self.context_server_registry.clone(),
 334                self.templates.clone(),
 335                default_model,
 336                cx,
 337            )
 338        });
 339
 340        self.register_session(thread, cx)
 341    }
 342
    /// Registers `thread_handle` as a live session and returns its [`AcpThread`].
    ///
    /// In order, this:
    /// 1. creates an `AcpThread` mirroring the thread's id, parent id, title,
    ///    draft prompt, UI scroll position and latest token usage;
    /// 2. gives the thread the registry's current thread-summary model and its
    ///    default tools, which receive weak handles to the ACP thread, the
    ///    thread and this agent;
    /// 3. subscribes to the thread's title and token-usage events and saves
    ///    the thread whenever it notifies;
    /// 4. stores the session (replacing any existing session with the same id)
    ///    and re-sends the available commands to every session.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        // Copy out everything the AcpThread needs while the thread is only
        // borrowed for reading.
        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let draft_prompt = thread.draft_prompt().map(Vec::from);
        let scroll_position = thread.ui_scroll_position();
        let token_usage = thread.latest_token_usage();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            let mut acp_thread = acp_thread::AcpThread::new(
                parent_session_id,
                title,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            );
            // Restore UI state carried by the thread (relevant for threads
            // loaded from the database).
            acp_thread.set_draft_prompt(draft_prompt);
            acp_thread.set_ui_scroll_position(scroll_position);
            acp_thread.update_token_usage(token_usage, cx);
            acp_thread
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        // The tool environment only gets weak handles, presumably so the
        // thread's tools do not keep the agent or the threads alive — confirm.
        let weak = cx.weak_entity();
        let weak_thread = thread_handle.downgrade();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            thread.add_default_tools(
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    thread: weak_thread,
                    agent: weak,
                }) as _,
                cx,
            )
        });

        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            // Persist the thread every time it notifies.
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        // `insert` replaces any session already registered under this id.
        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                _subscriptions: subscriptions,
                pending_save: Task::ready(()),
            },
        );

        self.update_available_commands(cx);

        acp_thread
    }
 416
 417    pub fn models(&self) -> &LanguageModels {
 418        &self.models
 419    }
 420
 421    async fn maintain_project_context(
 422        this: WeakEntity<Self>,
 423        mut needs_refresh: watch::Receiver<()>,
 424        cx: &mut AsyncApp,
 425    ) -> Result<()> {
 426        while needs_refresh.changed().await.is_ok() {
 427            let project_context = this
 428                .update(cx, |this, cx| {
 429                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 430                })?
 431                .await;
 432            this.update(cx, |this, cx| {
 433                this.project_context = cx.new(|_| project_context);
 434            })?;
 435        }
 436
 437        Ok(())
 438    }
 439
 440    fn build_project_context(
 441        project: &Entity<Project>,
 442        prompt_store: Option<&Entity<PromptStore>>,
 443        cx: &mut App,
 444    ) -> Task<ProjectContext> {
 445        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 446        let worktree_tasks = worktrees
 447            .into_iter()
 448            .map(|worktree| {
 449                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 450            })
 451            .collect::<Vec<_>>();
 452        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 453            prompt_store.read_with(cx, |prompt_store, cx| {
 454                let prompts = prompt_store.default_prompt_metadata();
 455                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 456                    let contents = prompt_store.load(prompt_metadata.id, cx);
 457                    async move { (contents.await, prompt_metadata) }
 458                });
 459                cx.background_spawn(future::join_all(load_tasks))
 460            })
 461        } else {
 462            Task::ready(vec![])
 463        };
 464
 465        cx.spawn(async move |_cx| {
 466            let (worktrees, default_user_rules) =
 467                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 468
 469            let worktrees = worktrees
 470                .into_iter()
 471                .map(|(worktree, _rules_error)| {
 472                    // TODO: show error message
 473                    // if let Some(rules_error) = rules_error {
 474                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 475                    // }
 476                    worktree
 477                })
 478                .collect::<Vec<_>>();
 479
 480            let default_user_rules = default_user_rules
 481                .into_iter()
 482                .flat_map(|(contents, prompt_metadata)| match contents {
 483                    Ok(contents) => Some(UserRulesContext {
 484                        uuid: prompt_metadata.id.as_user()?,
 485                        title: prompt_metadata.title.map(|title| title.to_string()),
 486                        contents,
 487                    }),
 488                    Err(_err) => {
 489                        // TODO: show error message
 490                        // this.update(cx, |_, cx| {
 491                        //     cx.emit(RulesLoadingError {
 492                        //         message: format!("{err:?}").into(),
 493                        //     });
 494                        // })
 495                        // .ok();
 496                        None
 497                    }
 498                })
 499                .collect::<Vec<_>>();
 500
 501            ProjectContext::new(worktrees, default_user_rules)
 502        })
 503    }
 504
 505    fn load_worktree_info_for_system_prompt(
 506        worktree: Entity<Worktree>,
 507        project: Entity<Project>,
 508        cx: &mut App,
 509    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 510        let tree = worktree.read(cx);
 511        let root_name = tree.root_name_str().into();
 512        let abs_path = tree.abs_path();
 513
 514        let mut context = WorktreeContext {
 515            root_name,
 516            abs_path,
 517            rules_file: None,
 518        };
 519
 520        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 521        let Some(rules_task) = rules_task else {
 522            return Task::ready((context, None));
 523        };
 524
 525        cx.spawn(async move |_| {
 526            let (rules_file, rules_file_error) = match rules_task.await {
 527                Ok(rules_file) => (Some(rules_file), None),
 528                Err(err) => (
 529                    None,
 530                    Some(RulesLoadingError {
 531                        message: format!("{err}").into(),
 532                    }),
 533                ),
 534            };
 535            context.rules_file = rules_file;
 536            (context, rules_file_error)
 537        })
 538    }
 539
 540    fn load_worktree_rules_file(
 541        worktree: Entity<Worktree>,
 542        project: Entity<Project>,
 543        cx: &mut App,
 544    ) -> Option<Task<Result<RulesFileContext>>> {
 545        let worktree = worktree.read(cx);
 546        let worktree_id = worktree.id();
 547        let selected_rules_file = RULES_FILE_NAMES
 548            .into_iter()
 549            .filter_map(|name| {
 550                worktree
 551                    .entry_for_path(RelPath::unix(name).unwrap())
 552                    .filter(|entry| entry.is_file())
 553                    .map(|entry| entry.path.clone())
 554            })
 555            .next();
 556
 557        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 558        // supported. This doesn't seem to occur often in GitHub repositories.
 559        selected_rules_file.map(|path_in_worktree| {
 560            let project_path = ProjectPath {
 561                worktree_id,
 562                path: path_in_worktree.clone(),
 563            };
 564            let buffer_task =
 565                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 566            let rope_task = cx.spawn(async move |cx| {
 567                let buffer = buffer_task.await?;
 568                let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
 569                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 570                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 571                })?;
 572                anyhow::Ok((project_entry_id, rope))
 573            });
 574            // Build a string from the rope on a background thread.
 575            cx.background_spawn(async move {
 576                let (project_entry_id, rope) = rope_task.await?;
 577                anyhow::Ok(RulesFileContext {
 578                    path_in_worktree,
 579                    text: rope.to_string().trim().to_string(),
 580                    project_entry_id: project_entry_id.to_usize(),
 581                })
 582            })
 583        })
 584    }
 585
 586    fn handle_thread_title_updated(
 587        &mut self,
 588        thread: Entity<Thread>,
 589        _: &TitleUpdated,
 590        cx: &mut Context<Self>,
 591    ) {
 592        let session_id = thread.read(cx).id();
 593        let Some(session) = self.sessions.get(session_id) else {
 594            return;
 595        };
 596        let thread = thread.downgrade();
 597        let acp_thread = session.acp_thread.downgrade();
 598        cx.spawn(async move |_, cx| {
 599            let title = thread.read_with(cx, |thread, _| thread.title())?;
 600            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 601            task.await
 602        })
 603        .detach_and_log_err(cx);
 604    }
 605
 606    fn handle_thread_token_usage_updated(
 607        &mut self,
 608        thread: Entity<Thread>,
 609        usage: &TokenUsageUpdated,
 610        cx: &mut Context<Self>,
 611    ) {
 612        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 613            return;
 614        };
 615        session.acp_thread.update(cx, |acp_thread, cx| {
 616            acp_thread.update_token_usage(usage.0.clone(), cx);
 617        });
 618    }
 619
 620    fn handle_project_event(
 621        &mut self,
 622        _project: Entity<Project>,
 623        event: &project::Event,
 624        _cx: &mut Context<Self>,
 625    ) {
 626        match event {
 627            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 628                self.project_context_needs_refresh.send(()).ok();
 629            }
 630            project::Event::WorktreeUpdatedEntries(_, items) => {
 631                if items.iter().any(|(path, _, _)| {
 632                    RULES_FILE_NAMES
 633                        .iter()
 634                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 635                }) {
 636                    self.project_context_needs_refresh.send(()).ok();
 637                }
 638            }
 639            _ => {}
 640        }
 641    }
 642
 643    fn handle_prompts_updated_event(
 644        &mut self,
 645        _prompt_store: Entity<PromptStore>,
 646        _event: &prompt_store::PromptsUpdatedEvent,
 647        _cx: &mut Context<Self>,
 648    ) {
 649        self.project_context_needs_refresh.send(()).ok();
 650    }
 651
 652    fn handle_models_updated_event(
 653        &mut self,
 654        _registry: Entity<LanguageModelRegistry>,
 655        _event: &language_model::Event,
 656        cx: &mut Context<Self>,
 657    ) {
 658        self.models.refresh_list(cx);
 659
 660        let registry = LanguageModelRegistry::read_global(cx);
 661        let default_model = registry.default_model().map(|m| m.model);
 662        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 663
 664        for session in self.sessions.values_mut() {
 665            session.thread.update(cx, |thread, cx| {
 666                if thread.model().is_none()
 667                    && let Some(model) = default_model.clone()
 668                {
 669                    thread.set_model(model, cx);
 670                    cx.notify();
 671                }
 672                thread.set_summarization_model(summarization_model.clone(), cx);
 673            });
 674        }
 675    }
 676
 677    fn handle_context_server_store_updated(
 678        &mut self,
 679        _store: Entity<project::context_server_store::ContextServerStore>,
 680        _event: &project::context_server_store::ServerStatusChangedEvent,
 681        cx: &mut Context<Self>,
 682    ) {
 683        self.update_available_commands(cx);
 684    }
 685
 686    fn handle_context_server_registry_event(
 687        &mut self,
 688        _registry: Entity<ContextServerRegistry>,
 689        event: &ContextServerRegistryEvent,
 690        cx: &mut Context<Self>,
 691    ) {
 692        match event {
 693            ContextServerRegistryEvent::ToolsChanged => {}
 694            ContextServerRegistryEvent::PromptsChanged => {
 695                self.update_available_commands(cx);
 696            }
 697        }
 698    }
 699
 700    fn update_available_commands(&self, cx: &mut Context<Self>) {
 701        let available_commands = self.build_available_commands(cx);
 702        for session in self.sessions.values() {
 703            session.acp_thread.update(cx, |thread, cx| {
 704                thread
 705                    .handle_session_update(
 706                        acp::SessionUpdate::AvailableCommandsUpdate(
 707                            acp::AvailableCommandsUpdate::new(available_commands.clone()),
 708                        ),
 709                        cx,
 710                    )
 711                    .log_err();
 712            });
 713        }
 714    }
 715
 716    fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
 717        let registry = self.context_server_registry.read(cx);
 718
 719        let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
 720        for context_server_prompt in registry.prompts() {
 721            *prompt_name_counts
 722                .entry(context_server_prompt.prompt.name.as_str())
 723                .or_insert(0) += 1;
 724        }
 725
 726        registry
 727            .prompts()
 728            .flat_map(|context_server_prompt| {
 729                let prompt = &context_server_prompt.prompt;
 730
 731                let should_prefix = prompt_name_counts
 732                    .get(prompt.name.as_str())
 733                    .copied()
 734                    .unwrap_or(0)
 735                    > 1;
 736
 737                let name = if should_prefix {
 738                    format!("{}.{}", context_server_prompt.server_id, prompt.name)
 739                } else {
 740                    prompt.name.clone()
 741                };
 742
 743                let mut command = acp::AvailableCommand::new(
 744                    name,
 745                    prompt.description.clone().unwrap_or_default(),
 746                );
 747
 748                match prompt.arguments.as_deref() {
 749                    Some([arg]) => {
 750                        let hint = format!("<{}>", arg.name);
 751
 752                        command = command.input(acp::AvailableCommandInput::Unstructured(
 753                            acp::UnstructuredCommandInput::new(hint),
 754                        ));
 755                    }
 756                    Some([]) | None => {}
 757                    Some(_) => {
 758                        // skip >1 argument commands since we don't support them yet
 759                        return None;
 760                    }
 761                }
 762
 763                Some(command)
 764            })
 765            .collect()
 766    }
 767
 768    pub fn load_thread(
 769        &mut self,
 770        id: acp::SessionId,
 771        cx: &mut Context<Self>,
 772    ) -> Task<Result<Entity<Thread>>> {
 773        let database_future = ThreadsDatabase::connect(cx);
 774        cx.spawn(async move |this, cx| {
 775            let database = database_future.await.map_err(|err| anyhow!(err))?;
 776            let db_thread = database
 777                .load_thread(id.clone())
 778                .await?
 779                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 780
 781            this.update(cx, |this, cx| {
 782                let summarization_model = LanguageModelRegistry::read_global(cx)
 783                    .thread_summary_model()
 784                    .map(|c| c.model);
 785
 786                cx.new(|cx| {
 787                    let mut thread = Thread::from_db(
 788                        id.clone(),
 789                        db_thread,
 790                        this.project.clone(),
 791                        this.project_context.clone(),
 792                        this.context_server_registry.clone(),
 793                        this.templates.clone(),
 794                        cx,
 795                    );
 796                    thread.set_summarization_model(summarization_model, cx);
 797                    thread
 798                })
 799            })
 800        })
 801    }
 802
 803    pub fn open_thread(
 804        &mut self,
 805        id: acp::SessionId,
 806        cx: &mut Context<Self>,
 807    ) -> Task<Result<Entity<AcpThread>>> {
 808        if let Some(session) = self.sessions.get(&id) {
 809            return Task::ready(Ok(session.acp_thread.clone()));
 810        }
 811
 812        let task = self.load_thread(id, cx);
 813        cx.spawn(async move |this, cx| {
 814            let thread = task.await?;
 815            let acp_thread =
 816                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
 817            let events = thread.update(cx, |thread, cx| thread.replay(cx));
 818            cx.update(|cx| {
 819                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 820            })
 821            .await?;
 822            Ok(acp_thread)
 823        })
 824    }
 825
 826    pub fn thread_summary(
 827        &mut self,
 828        id: acp::SessionId,
 829        cx: &mut Context<Self>,
 830    ) -> Task<Result<SharedString>> {
 831        let thread = self.open_thread(id.clone(), cx);
 832        cx.spawn(async move |this, cx| {
 833            let acp_thread = thread.await?;
 834            let result = this
 835                .update(cx, |this, cx| {
 836                    this.sessions
 837                        .get(&id)
 838                        .unwrap()
 839                        .thread
 840                        .update(cx, |thread, cx| thread.summary(cx))
 841                })?
 842                .await
 843                .context("Failed to generate summary")?;
 844            drop(acp_thread);
 845            Ok(result)
 846        })
 847    }
 848
 849    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 850        if thread.read(cx).is_empty() {
 851            return;
 852        }
 853
 854        let id = thread.read(cx).id().clone();
 855        let Some(session) = self.sessions.get_mut(&id) else {
 856            return;
 857        };
 858
 859        let folder_paths = PathList::new(
 860            &self
 861                .project
 862                .read(cx)
 863                .visible_worktrees(cx)
 864                .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
 865                .collect::<Vec<_>>(),
 866        );
 867
 868        let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
 869        let database_future = ThreadsDatabase::connect(cx);
 870        let db_thread = thread.update(cx, |thread, cx| {
 871            thread.set_draft_prompt(draft_prompt);
 872            thread.to_db(cx)
 873        });
 874        let thread_store = self.thread_store.clone();
 875        session.pending_save = cx.spawn(async move |_, cx| {
 876            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 877                return;
 878            };
 879            let db_thread = db_thread.await;
 880            database
 881                .save_thread(id, db_thread, folder_paths)
 882                .await
 883                .log_err();
 884            thread_store.update(cx, |store, cx| store.reload(cx));
 885        });
 886    }
 887
    /// Expands the MCP prompt `prompt_name` from context server `server_id` with `arguments`,
    /// appends the resulting messages to the session's `AcpThread` and `Thread`, then starts a
    /// turn and forwards its events to the `AcpThread`.
    ///
    /// `original_content` is the user's prompt as sent. Its first block is the slash-command
    /// text that selected this MCP prompt (see `Command::parse`); only the remaining blocks are
    /// recorded on the `Thread` under `message_id`.
    ///
    /// # Errors
    /// Fails if the prompt cannot be fetched, the session is not registered, the agent was
    /// released, or starting/streaming the turn fails.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: agent_client_protocol::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let server_store = self.context_server_registry.read(cx).server_store().clone();
        let path_style = self.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Tracks whether the last appended message came from the user. It starts as `true`
            // because the user block pushed just below counts when the MCP prompt has no
            // messages.
            let mut last_is_user = true;

            // Record the user's own message, minus the leading slash-command block.
            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            // Mirror every expanded prompt message into both the UI thread and the model thread.
            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                match role {
                    context_server::types::Role::User => {
                        // Each user message from the prompt gets its own fresh id, shared by
                        // both threads.
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
 983}
 984
 985/// Wrapper struct that implements the AgentConnection trait
 986#[derive(Clone)]
 987pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 988
 989impl NativeAgentConnection {
 990    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 991        self.0
 992            .read(cx)
 993            .sessions
 994            .get(session_id)
 995            .map(|session| session.thread.clone())
 996    }
 997
 998    pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
 999        self.0.update(cx, |this, cx| this.load_thread(id, cx))
1000    }
1001
1002    fn run_turn(
1003        &self,
1004        session_id: acp::SessionId,
1005        cx: &mut App,
1006        f: impl 'static
1007        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1008    ) -> Task<Result<acp::PromptResponse>> {
1009        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1010            agent
1011                .sessions
1012                .get_mut(&session_id)
1013                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1014        }) else {
1015            return Task::ready(Err(anyhow!("Session not found")));
1016        };
1017        log::debug!("Found session for: {}", session_id);
1018
1019        let response_stream = match f(thread, cx) {
1020            Ok(stream) => stream,
1021            Err(err) => return Task::ready(Err(err)),
1022        };
1023        Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1024    }
1025
1026    fn handle_thread_events(
1027        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1028        acp_thread: WeakEntity<AcpThread>,
1029        cx: &App,
1030    ) -> Task<Result<acp::PromptResponse>> {
1031        cx.spawn(async move |cx| {
1032            // Handle response stream and forward to session.acp_thread
1033            while let Some(result) = events.next().await {
1034                match result {
1035                    Ok(event) => {
1036                        log::trace!("Received completion event: {:?}", event);
1037
1038                        match event {
1039                            ThreadEvent::UserMessage(message) => {
1040                                acp_thread.update(cx, |thread, cx| {
1041                                    for content in message.content {
1042                                        thread.push_user_content_block(
1043                                            Some(message.id.clone()),
1044                                            content.into(),
1045                                            cx,
1046                                        );
1047                                    }
1048                                })?;
1049                            }
1050                            ThreadEvent::AgentText(text) => {
1051                                acp_thread.update(cx, |thread, cx| {
1052                                    thread.push_assistant_content_block(text.into(), false, cx)
1053                                })?;
1054                            }
1055                            ThreadEvent::AgentThinking(text) => {
1056                                acp_thread.update(cx, |thread, cx| {
1057                                    thread.push_assistant_content_block(text.into(), true, cx)
1058                                })?;
1059                            }
1060                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1061                                tool_call,
1062                                options,
1063                                response,
1064                                context: _,
1065                            }) => {
1066                                let outcome_task = acp_thread.update(cx, |thread, cx| {
1067                                    thread.request_tool_call_authorization(tool_call, options, cx)
1068                                })??;
1069                                cx.background_spawn(async move {
1070                                    if let acp::RequestPermissionOutcome::Selected(
1071                                        acp::SelectedPermissionOutcome { option_id, .. },
1072                                    ) = outcome_task.await
1073                                    {
1074                                        response
1075                                            .send(option_id)
1076                                            .map(|_| anyhow!("authorization receiver was dropped"))
1077                                            .log_err();
1078                                    }
1079                                })
1080                                .detach();
1081                            }
1082                            ThreadEvent::ToolCall(tool_call) => {
1083                                acp_thread.update(cx, |thread, cx| {
1084                                    thread.upsert_tool_call(tool_call, cx)
1085                                })??;
1086                            }
1087                            ThreadEvent::ToolCallUpdate(update) => {
1088                                acp_thread.update(cx, |thread, cx| {
1089                                    thread.update_tool_call(update, cx)
1090                                })??;
1091                            }
1092                            ThreadEvent::SubagentSpawned(session_id) => {
1093                                acp_thread.update(cx, |thread, cx| {
1094                                    thread.subagent_spawned(session_id, cx);
1095                                })?;
1096                            }
1097                            ThreadEvent::Retry(status) => {
1098                                acp_thread.update(cx, |thread, cx| {
1099                                    thread.update_retry_status(status, cx)
1100                                })?;
1101                            }
1102                            ThreadEvent::Stop(stop_reason) => {
1103                                log::debug!("Assistant message complete: {:?}", stop_reason);
1104                                return Ok(acp::PromptResponse::new(stop_reason));
1105                            }
1106                        }
1107                    }
1108                    Err(e) => {
1109                        log::error!("Error in model response stream: {:?}", e);
1110                        return Err(e);
1111                    }
1112                }
1113            }
1114
1115            log::debug!("Response stream completed");
1116            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1117        })
1118    }
1119}
1120
1121struct Command<'a> {
1122    prompt_name: &'a str,
1123    arg_value: &'a str,
1124    explicit_server_id: Option<&'a str>,
1125}
1126
1127impl<'a> Command<'a> {
1128    fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1129        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1130            return None;
1131        };
1132        let text = text_content.text.trim();
1133        let command = text.strip_prefix('/')?;
1134        let (command, arg_value) = command
1135            .split_once(char::is_whitespace)
1136            .unwrap_or((command, ""));
1137
1138        if let Some((server_id, prompt_name)) = command.split_once('.') {
1139            Some(Self {
1140                prompt_name,
1141                arg_value,
1142                explicit_server_id: Some(server_id),
1143            })
1144        } else {
1145            Some(Self {
1146                prompt_name: command,
1147                arg_value,
1148                explicit_server_id: None,
1149            })
1150        }
1151    }
1152}
1153
1154struct NativeAgentModelSelector {
1155    session_id: acp::SessionId,
1156    connection: NativeAgentConnection,
1157}
1158
1159impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1160    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1161        log::debug!("NativeAgentConnection::list_models called");
1162        let list = self.connection.0.read(cx).models.model_list.clone();
1163        Task::ready(if list.is_empty() {
1164            Err(anyhow::anyhow!("No models available"))
1165        } else {
1166            Ok(list)
1167        })
1168    }
1169
1170    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1171        log::debug!(
1172            "Setting model for session {}: {}",
1173            self.session_id,
1174            model_id
1175        );
1176        let Some(thread) = self
1177            .connection
1178            .0
1179            .read(cx)
1180            .sessions
1181            .get(&self.session_id)
1182            .map(|session| session.thread.clone())
1183        else {
1184            return Task::ready(Err(anyhow!("Session not found")));
1185        };
1186
1187        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1188            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1189        };
1190
1191        // We want to reset the effort level when switching models, as the currently-selected effort level may
1192        // not be compatible.
1193        let effort = model
1194            .default_effort_level()
1195            .map(|effort_level| effort_level.value.to_string());
1196
1197        thread.update(cx, |thread, cx| {
1198            thread.set_model(model.clone(), cx);
1199            thread.set_thinking_effort(effort.clone(), cx);
1200            thread.set_thinking_enabled(model.supports_thinking(), cx);
1201        });
1202
1203        update_settings_file(
1204            self.connection.0.read(cx).fs.clone(),
1205            cx,
1206            move |settings, cx| {
1207                let provider = model.provider_id().0.to_string();
1208                let model = model.id().0.to_string();
1209                let enable_thinking = thread.read(cx).thinking_enabled();
1210                settings
1211                    .agent
1212                    .get_or_insert_default()
1213                    .set_model(LanguageModelSelection {
1214                        provider: provider.into(),
1215                        model,
1216                        enable_thinking,
1217                        effort,
1218                    });
1219            },
1220        );
1221
1222        Task::ready(Ok(()))
1223    }
1224
1225    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1226        let Some(thread) = self
1227            .connection
1228            .0
1229            .read(cx)
1230            .sessions
1231            .get(&self.session_id)
1232            .map(|session| session.thread.clone())
1233        else {
1234            return Task::ready(Err(anyhow!("Session not found")));
1235        };
1236        let Some(model) = thread.read(cx).model() else {
1237            return Task::ready(Err(anyhow!("Model not found")));
1238        };
1239        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1240        else {
1241            return Task::ready(Err(anyhow!("Provider not found")));
1242        };
1243        Task::ready(Ok(LanguageModels::map_language_model_to_info(
1244            model, &provider,
1245        )))
1246    }
1247
1248    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1249        Some(self.connection.0.read(cx).models.watch())
1250    }
1251
1252    fn should_render_footer(&self) -> bool {
1253        true
1254    }
1255}
1256
1257impl acp_thread::AgentConnection for NativeAgentConnection {
1258    fn telemetry_id(&self) -> SharedString {
1259        "zed".into()
1260    }
1261
1262    fn new_session(
1263        self: Rc<Self>,
1264        project: Entity<Project>,
1265        cwd: &Path,
1266        cx: &mut App,
1267    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1268        log::debug!("Creating new thread for project at: {cwd:?}");
1269        Task::ready(Ok(self
1270            .0
1271            .update(cx, |agent, cx| agent.new_session(project, cx))))
1272    }
1273
1274    fn supports_load_session(&self) -> bool {
1275        true
1276    }
1277
1278    fn load_session(
1279        self: Rc<Self>,
1280        session: AgentSessionInfo,
1281        _project: Entity<Project>,
1282        _cwd: &Path,
1283        cx: &mut App,
1284    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1285        self.0
1286            .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1287    }
1288
1289    fn supports_close_session(&self) -> bool {
1290        true
1291    }
1292
1293    fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1294        self.0.update(cx, |agent, _cx| {
1295            agent.sessions.remove(session_id);
1296        });
1297        Task::ready(Ok(()))
1298    }
1299
1300    fn auth_methods(&self) -> &[acp::AuthMethod] {
1301        &[] // No auth for in-process
1302    }
1303
1304    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1305        Task::ready(Ok(()))
1306    }
1307
1308    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1309        Some(Rc::new(NativeAgentModelSelector {
1310            session_id: session_id.clone(),
1311            connection: self.clone(),
1312        }) as Rc<dyn AgentModelSelector>)
1313    }
1314
1315    fn prompt(
1316        &self,
1317        id: Option<acp_thread::UserMessageId>,
1318        params: acp::PromptRequest,
1319        cx: &mut App,
1320    ) -> Task<Result<acp::PromptResponse>> {
1321        let id = id.expect("UserMessageId is required");
1322        let session_id = params.session_id.clone();
1323        log::info!("Received prompt request for session: {}", session_id);
1324        log::debug!("Prompt blocks count: {}", params.prompt.len());
1325
1326        if let Some(parsed_command) = Command::parse(&params.prompt) {
1327            let registry = self.0.read(cx).context_server_registry.read(cx);
1328
1329            let explicit_server_id = parsed_command
1330                .explicit_server_id
1331                .map(|server_id| ContextServerId(server_id.into()));
1332
1333            if let Some(prompt) =
1334                registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1335            {
1336                let arguments = if !parsed_command.arg_value.is_empty()
1337                    && let Some(arg_name) = prompt
1338                        .prompt
1339                        .arguments
1340                        .as_ref()
1341                        .and_then(|args| args.first())
1342                        .map(|arg| arg.name.clone())
1343                {
1344                    HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1345                } else {
1346                    Default::default()
1347                };
1348
1349                let prompt_name = prompt.prompt.name.clone();
1350                let server_id = prompt.server_id.clone();
1351
1352                return self.0.update(cx, |agent, cx| {
1353                    agent.send_mcp_prompt(
1354                        id,
1355                        session_id.clone(),
1356                        prompt_name,
1357                        server_id,
1358                        arguments,
1359                        params.prompt,
1360                        cx,
1361                    )
1362                });
1363            };
1364        };
1365
1366        let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1367
1368        self.run_turn(session_id, cx, move |thread, cx| {
1369            let content: Vec<UserMessageContent> = params
1370                .prompt
1371                .into_iter()
1372                .map(|block| UserMessageContent::from_content_block(block, path_style))
1373                .collect::<Vec<_>>();
1374            log::debug!("Converted prompt to message: {} chars", content.len());
1375            log::debug!("Message id: {:?}", id);
1376            log::debug!("Message content: {:?}", content);
1377
1378            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1379        })
1380    }
1381
1382    fn retry(
1383        &self,
1384        session_id: &acp::SessionId,
1385        _cx: &App,
1386    ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1387        Some(Rc::new(NativeAgentSessionRetry {
1388            connection: self.clone(),
1389            session_id: session_id.clone(),
1390        }) as _)
1391    }
1392
1393    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1394        log::info!("Cancelling on session: {}", session_id);
1395        self.0.update(cx, |agent, cx| {
1396            if let Some(session) = agent.sessions.get(session_id) {
1397                session
1398                    .thread
1399                    .update(cx, |thread, cx| thread.cancel(cx))
1400                    .detach();
1401            }
1402        });
1403    }
1404
1405    fn truncate(
1406        &self,
1407        session_id: &agent_client_protocol::SessionId,
1408        cx: &App,
1409    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1410        self.0.read_with(cx, |agent, _cx| {
1411            agent.sessions.get(session_id).map(|session| {
1412                Rc::new(NativeAgentSessionTruncate {
1413                    thread: session.thread.clone(),
1414                    acp_thread: session.acp_thread.downgrade(),
1415                }) as _
1416            })
1417        })
1418    }
1419
1420    fn set_title(
1421        &self,
1422        session_id: &acp::SessionId,
1423        cx: &App,
1424    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1425        self.0.read_with(cx, |agent, _cx| {
1426            agent
1427                .sessions
1428                .get(session_id)
1429                .filter(|s| !s.thread.read(cx).is_subagent())
1430                .map(|session| {
1431                    Rc::new(NativeAgentSessionSetTitle {
1432                        thread: session.thread.clone(),
1433                    }) as _
1434                })
1435        })
1436    }
1437
1438    fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1439        let thread_store = self.0.read(cx).thread_store.clone();
1440        Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1441    }
1442
1443    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1444        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1445    }
1446
1447    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1448        self
1449    }
1450}
1451
1452impl acp_thread::AgentTelemetry for NativeAgentConnection {
1453    fn thread_data(
1454        &self,
1455        session_id: &acp::SessionId,
1456        cx: &mut App,
1457    ) -> Task<Result<serde_json::Value>> {
1458        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1459            return Task::ready(Err(anyhow!("Session not found")));
1460        };
1461
1462        let task = session.thread.read(cx).to_db(cx);
1463        cx.background_spawn(async move {
1464            serde_json::to_value(task.await).context("Failed to serialize thread")
1465        })
1466    }
1467}
1468
1469pub struct NativeAgentSessionList {
1470    thread_store: Entity<ThreadStore>,
1471    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1472    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1473    _subscription: Subscription,
1474}
1475
1476impl NativeAgentSessionList {
1477    fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1478        let (tx, rx) = smol::channel::unbounded();
1479        let this_tx = tx.clone();
1480        let subscription = cx.observe(&thread_store, move |_, _| {
1481            this_tx
1482                .try_send(acp_thread::SessionListUpdate::Refresh)
1483                .ok();
1484        });
1485        Self {
1486            thread_store,
1487            updates_tx: tx,
1488            updates_rx: rx,
1489            _subscription: subscription,
1490        }
1491    }
1492
1493    fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1494        AgentSessionInfo {
1495            session_id: entry.id,
1496            cwd: None,
1497            title: Some(entry.title),
1498            updated_at: Some(entry.updated_at),
1499            meta: None,
1500        }
1501    }
1502
1503    pub fn thread_store(&self) -> &Entity<ThreadStore> {
1504        &self.thread_store
1505    }
1506}
1507
1508impl AgentSessionList for NativeAgentSessionList {
1509    fn list_sessions(
1510        &self,
1511        _request: AgentSessionListRequest,
1512        cx: &mut App,
1513    ) -> Task<Result<AgentSessionListResponse>> {
1514        let sessions = self
1515            .thread_store
1516            .read(cx)
1517            .entries()
1518            .map(Self::to_session_info)
1519            .collect();
1520        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1521    }
1522
1523    fn supports_delete(&self) -> bool {
1524        true
1525    }
1526
1527    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1528        self.thread_store
1529            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1530    }
1531
1532    fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1533        self.thread_store
1534            .update(cx, |store, cx| store.delete_threads(cx))
1535    }
1536
1537    fn watch(
1538        &self,
1539        _cx: &mut App,
1540    ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1541        Some(self.updates_rx.clone())
1542    }
1543
1544    fn notify_refresh(&self) {
1545        self.updates_tx
1546            .try_send(acp_thread::SessionListUpdate::Refresh)
1547            .ok();
1548    }
1549
1550    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1551        self
1552    }
1553}
1554
1555struct NativeAgentSessionTruncate {
1556    thread: Entity<Thread>,
1557    acp_thread: WeakEntity<AcpThread>,
1558}
1559
1560impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1561    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1562        match self.thread.update(cx, |thread, cx| {
1563            thread.truncate(message_id.clone(), cx)?;
1564            Ok(thread.latest_token_usage())
1565        }) {
1566            Ok(usage) => {
1567                self.acp_thread
1568                    .update(cx, |thread, cx| {
1569                        thread.update_token_usage(usage, cx);
1570                    })
1571                    .ok();
1572                Task::ready(Ok(()))
1573            }
1574            Err(error) => Task::ready(Err(error)),
1575        }
1576    }
1577}
1578
1579struct NativeAgentSessionRetry {
1580    connection: NativeAgentConnection,
1581    session_id: acp::SessionId,
1582}
1583
1584impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1585    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1586        self.connection
1587            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1588                thread.update(cx, |thread, cx| thread.resume(cx))
1589            })
1590    }
1591}
1592
1593struct NativeAgentSessionSetTitle {
1594    thread: Entity<Thread>,
1595}
1596
1597impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1598    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1599        self.thread
1600            .update(cx, |thread, cx| thread.set_title(title, cx));
1601        Task::ready(Ok(()))
1602    }
1603}
1604
1605pub struct NativeThreadEnvironment {
1606    agent: WeakEntity<NativeAgent>,
1607    thread: WeakEntity<Thread>,
1608    acp_thread: WeakEntity<AcpThread>,
1609}
1610
1611impl NativeThreadEnvironment {
1612    pub(crate) fn create_subagent_thread(
1613        &self,
1614        label: String,
1615        cx: &mut App,
1616    ) -> Result<Rc<dyn SubagentHandle>> {
1617        let Some(parent_thread_entity) = self.thread.upgrade() else {
1618            anyhow::bail!("Parent thread no longer exists".to_string());
1619        };
1620        let parent_thread = parent_thread_entity.read(cx);
1621        let current_depth = parent_thread.depth();
1622
1623        if current_depth >= MAX_SUBAGENT_DEPTH {
1624            return Err(anyhow!(
1625                "Maximum subagent depth ({}) reached",
1626                MAX_SUBAGENT_DEPTH
1627            ));
1628        }
1629
1630        let subagent_thread: Entity<Thread> = cx.new(|cx| {
1631            let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1632            thread.set_title(label.into(), cx);
1633            thread
1634        });
1635
1636        let session_id = subagent_thread.read(cx).id().clone();
1637
1638        let acp_thread = self.agent.update(cx, |agent, cx| {
1639            agent.register_session(subagent_thread.clone(), cx)
1640        })?;
1641
1642        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1643    }
1644
1645    pub(crate) fn resume_subagent_thread(
1646        &self,
1647        session_id: acp::SessionId,
1648        cx: &mut App,
1649    ) -> Result<Rc<dyn SubagentHandle>> {
1650        let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1651            let session = agent
1652                .sessions
1653                .get(&session_id)
1654                .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1655            anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1656        })??;
1657
1658        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1659    }
1660
1661    fn prompt_subagent(
1662        &self,
1663        session_id: acp::SessionId,
1664        subagent_thread: Entity<Thread>,
1665        acp_thread: Entity<acp_thread::AcpThread>,
1666    ) -> Result<Rc<dyn SubagentHandle>> {
1667        let Some(parent_thread_entity) = self.thread.upgrade() else {
1668            anyhow::bail!("Parent thread no longer exists".to_string());
1669        };
1670        Ok(Rc::new(NativeSubagentHandle::new(
1671            session_id,
1672            subagent_thread,
1673            acp_thread,
1674            parent_thread_entity,
1675        )) as _)
1676    }
1677}
1678
1679impl ThreadEnvironment for NativeThreadEnvironment {
1680    fn create_terminal(
1681        &self,
1682        command: String,
1683        cwd: Option<PathBuf>,
1684        output_byte_limit: Option<u64>,
1685        cx: &mut AsyncApp,
1686    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1687        let task = self.acp_thread.update(cx, |thread, cx| {
1688            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1689        });
1690
1691        let acp_thread = self.acp_thread.clone();
1692        cx.spawn(async move |cx| {
1693            let terminal = task?.await?;
1694
1695            let (drop_tx, drop_rx) = oneshot::channel();
1696            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1697
1698            cx.spawn(async move |cx| {
1699                drop_rx.await.ok();
1700                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1701            })
1702            .detach();
1703
1704            let handle = AcpTerminalHandle {
1705                terminal,
1706                _drop_tx: Some(drop_tx),
1707            };
1708
1709            Ok(Rc::new(handle) as _)
1710        })
1711    }
1712
1713    fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1714        self.create_subagent_thread(label, cx)
1715    }
1716
1717    fn resume_subagent(
1718        &self,
1719        session_id: acp::SessionId,
1720        cx: &mut App,
1721    ) -> Result<Rc<dyn SubagentHandle>> {
1722        self.resume_subagent_thread(session_id, cx)
1723    }
1724}
1725
1726#[derive(Debug, Clone)]
1727enum SubagentPromptResult {
1728    Completed,
1729    Cancelled,
1730    ContextWindowWarning,
1731    Error(String),
1732}
1733
1734pub struct NativeSubagentHandle {
1735    session_id: acp::SessionId,
1736    parent_thread: WeakEntity<Thread>,
1737    subagent_thread: Entity<Thread>,
1738    acp_thread: Entity<acp_thread::AcpThread>,
1739}
1740
1741impl NativeSubagentHandle {
1742    fn new(
1743        session_id: acp::SessionId,
1744        subagent_thread: Entity<Thread>,
1745        acp_thread: Entity<acp_thread::AcpThread>,
1746        parent_thread_entity: Entity<Thread>,
1747    ) -> Self {
1748        NativeSubagentHandle {
1749            session_id,
1750            subagent_thread,
1751            parent_thread: parent_thread_entity.downgrade(),
1752            acp_thread,
1753        }
1754    }
1755}
1756
1757impl SubagentHandle for NativeSubagentHandle {
1758    fn id(&self) -> acp::SessionId {
1759        self.session_id.clone()
1760    }
1761
1762    fn num_entries(&self, cx: &App) -> usize {
1763        self.acp_thread.read(cx).entries().len()
1764    }
1765
1766    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
1767        let thread = self.subagent_thread.clone();
1768        let acp_thread = self.acp_thread.clone();
1769        let subagent_session_id = self.session_id.clone();
1770        let parent_thread = self.parent_thread.clone();
1771
1772        cx.spawn(async move |cx| {
1773            let (task, _subscription) = cx.update(|cx| {
1774                let ratio_before_prompt = thread
1775                    .read(cx)
1776                    .latest_token_usage()
1777                    .map(|usage| usage.ratio());
1778
1779                parent_thread
1780                    .update(cx, |parent_thread, _cx| {
1781                        parent_thread.register_running_subagent(thread.downgrade())
1782                    })
1783                    .ok();
1784
1785                let task = acp_thread.update(cx, |acp_thread, cx| {
1786                    acp_thread.send(vec![message.into()], cx)
1787                });
1788
1789                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
1790                let mut token_limit_tx = Some(token_limit_tx);
1791
1792                let subscription = cx.subscribe(
1793                    &thread,
1794                    move |_thread, event: &TokenUsageUpdated, _cx| {
1795                        if let Some(usage) = &event.0 {
1796                            let old_ratio = ratio_before_prompt
1797                                .clone()
1798                                .unwrap_or(TokenUsageRatio::Normal);
1799                            let new_ratio = usage.ratio();
1800                            if old_ratio == TokenUsageRatio::Normal
1801                                && new_ratio == TokenUsageRatio::Warning
1802                            {
1803                                if let Some(tx) = token_limit_tx.take() {
1804                                    tx.send(()).ok();
1805                                }
1806                            }
1807                        }
1808                    },
1809                );
1810
1811                let wait_for_prompt = cx
1812                    .background_spawn(async move {
1813                        futures::select! {
1814                            response = task.fuse() => match response {
1815                                Ok(Some(response)) => {
1816                                    match response.stop_reason {
1817                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
1818                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
1819                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
1820                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
1821                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
1822                                    }
1823                                }
1824                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
1825                                Err(error) => SubagentPromptResult::Error(error.to_string()),
1826                            },
1827                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
1828                        }
1829                    });
1830
1831                (wait_for_prompt, subscription)
1832            });
1833
1834            let result = match task.await {
1835                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
1836                    thread
1837                        .last_message()
1838                        .and_then(|message| {
1839                            let content = message.as_agent_message()?
1840                                .content
1841                                .iter()
1842                                .filter_map(|c| match c {
1843                                    AgentMessageContent::Text(text) => Some(text.as_str()),
1844                                    _ => None,
1845                                })
1846                                .join("\n\n");
1847                            if content.is_empty() {
1848                                None
1849                            } else {
1850                                Some( content)
1851                            }
1852                        })
1853                        .context("No response from subagent")
1854                }),
1855                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
1856                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
1857                SubagentPromptResult::ContextWindowWarning => {
1858                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1859                    Err(anyhow!(
1860                        "The agent is nearing the end of its context window and has been \
1861                         stopped. You can prompt the thread again to have the agent wrap up \
1862                         or hand off its work."
1863                    ))
1864                }
1865            };
1866
1867            parent_thread
1868                .update(cx, |parent_thread, cx| {
1869                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
1870                })
1871                .ok();
1872
1873            result
1874        })
1875    }
1876}
1877
1878pub struct AcpTerminalHandle {
1879    terminal: Entity<acp_thread::Terminal>,
1880    _drop_tx: Option<oneshot::Sender<()>>,
1881}
1882
1883impl TerminalHandle for AcpTerminalHandle {
1884    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1885        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1886    }
1887
1888    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1889        Ok(self
1890            .terminal
1891            .read_with(cx, |term, _cx| term.wait_for_exit()))
1892    }
1893
1894    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1895        Ok(self
1896            .terminal
1897            .read_with(cx, |term, cx| term.current_output(cx)))
1898    }
1899
1900    fn kill(&self, cx: &AsyncApp) -> Result<()> {
1901        cx.update(|cx| {
1902            self.terminal.update(cx, |terminal, cx| {
1903                terminal.kill(cx);
1904            });
1905        });
1906        Ok(())
1907    }
1908
1909    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1910        Ok(self
1911            .terminal
1912            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1913    }
1914}
1915
1916#[cfg(test)]
1917mod internal_tests {
1918    use super::*;
1919    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1920    use fs::FakeFs;
1921    use gpui::TestAppContext;
1922    use indoc::formatdoc;
1923    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1924    use language_model::{
1925        LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
1926    };
1927    use serde_json::json;
1928    use settings::SettingsStore;
1929    use util::{path, rel_path::rel_path};
1930
    // Verifies that the agent's `ProjectContext` mirrors the project's worktrees
    // and picks up a `.rules` file created in a worktree after the fact.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        // The project starts with no worktrees, so the context should be empty.
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        // Adding a worktree adds a matching context entry with no rules file yet.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            // The expected entry id comes from the worktree's own entry for `.rules`.
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }
1996
    // Verifies that a session's model selector lists models grouped by provider.
    // With this test setup, the expected result is a single "Fake" group holding
    // the `fake/fake` model.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                thread_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        // Only the grouped form of the model list is expected here.
        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(acp_thread::AgentModelIcon::Named(
                        ui::IconName::ZedAssistant
                    )),
                    is_latest: false,
                    cost: None,
                }]
            )])
        );
    }
2057
    // Verifies that selecting a model through the session's model selector:
    // - updates the thread's model;
    // - rewrites `agent.default_model` in the settings file;
    // - persists `enable_thinking: true` when the selected model supports thinking.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Seed the settings file with a different default model so the overwrite is observable.
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let the settings write settle before reading the file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );

        // Register a thinking model and select it.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // The agent's model list must be refreshed to see the newly registered provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // Verify enable_thinking was written to settings as true.
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
        assert_eq!(
            settings_json["agent"]["default_model"]["enable_thinking"],
            json!(true),
            "selecting a thinking model should persist enable_thinking: true to settings"
        );
    }
2173
    // Verifies that `select_model` toggles the thread's `thinking_enabled` flag:
    // it is enabled for a model constructed with thinking support and disabled
    // again when switching back to `fake/fake`.
    #[gpui::test]
    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Start from an empty settings file.
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Register a second provider with a thinking model.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so it picks up the new provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Thread starts with thinking_enabled = false (the default).
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(!thread.thinking_enabled(), "thinking defaults to false");
            });
        });

        // Select the thinking model via select_model.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // select_model should have enabled thinking based on the model's supports_thinking().
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(
                    thread.thinking_enabled(),
                    "select_model should enable thinking when model supports it"
                );
            });
        });

        // Switch back to the non-thinking model.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
            .await
            .unwrap();

        // select_model should have disabled thinking.
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(
                    !thread.thinking_enabled(),
                    "select_model should disable thinking when model does not support it"
                );
            });
        });
    }
2269
    // Verifies that `thinking_enabled` survives a round trip through storage:
    // 1. select a thinking model;
    // 2. send a message so the thread is persisted;
    // 3. close the session;
    // 4. reopen it with `open_thread` and check the flag is still set.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        // A clone of `thinking_model` is kept so the test can drive its completion stream below.
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        // The send is spawned so the fake model's stream can be completed before awaiting it.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        // Local strong handles are dropped as well before asserting the session map is empty.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        drop(reloaded_acp_thread);
    }
2375
    /// Checks that a thread's selected model survives a close/reopen round-trip
    /// through the thread store instead of reverting to the default model.
    ///
    /// The fake model is registered with an id ("custom-model-id") that differs
    /// from its display name ("Custom Model Display Name"). The reload assertion
    /// compares ids, so it would presumably fail if the model were persisted or
    /// looked up by display name rather than by id.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Refresh the agent's model list after registering the provider,
        // presumably so the selector below can resolve the fake model.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // The selector is addressed with "<provider id>/<model id>".
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future is spawned so it can make progress while the fake
        // model's completion stream is fed and closed below.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the old session's entities before
        // checking that the agent no longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        // Keep the reloaded thread alive until the assertions above have run.
        drop(reloaded_acp_thread);
    }
2486
    /// Round-trips a thread through the thread store and checks what survives.
    ///
    /// First confirms that a thread with no messages is not saved, even after its
    /// model and summarization model are changed. It then does the following:
    /// - sends one user message containing a file mention;
    /// - completes the message with the fake model, which also reports token usage;
    /// - lets the fake summarization model produce the thread title;
    /// - sets a draft prompt and a UI scroll position;
    /// - closes the session and reopens it from the store.
    ///
    /// After the reload it asserts that the stored title, rendered markdown,
    /// draft prompt, token usage and scroll position all match what was set.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a message that mentions /a/b.md via a resource link. The send
        // future is spawned so it can progress while the fake models are driven.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Reply with text and a usage update; the usage figures are checked after reload.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
            language_model::TokenUsage {
                input_tokens: 150,
                output_tokens: 75,
                ..Default::default()
            },
        ));
        model.end_last_completion_stream();
        cx.run_until_parked();
        // The summarization model's output is the title asserted via `thread_entries` below.
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Set a draft prompt with rich content blocks before saving.
        let draft_blocks = vec![
            acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
            acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
            acp::ContentBlock::Text(acp::TextContent::new(" please")),
        ];
        acp_thread.update(cx, |thread, _cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()));
        });
        thread.update(cx, |thread, _cx| {
            thread.set_ui_scroll_position(Some(gpui::ListOffset {
                item_ix: 5,
                offset_in_item: gpui::px(12.5),
            }));
        });
        // Notify observers of the native thread. NOTE(review): presumably this
        // is what makes the draft/scroll changes get saved before close — confirm.
        thread.update(cx, |_thread, cx| cx.notify());
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // Ensure the draft prompt with rich content blocks survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
        });

        // Ensure token usage survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let usage = thread
                .token_usage()
                .expect("token usage should be restored after reload");
            assert_eq!(usage.input_tokens, 150);
            assert_eq!(usage.output_tokens, 75);
        });

        // Ensure scroll position survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let scroll = thread
                .ui_scroll_position()
                .expect("scroll position should be restored after reload");
            assert_eq!(scroll.item_ix, 5);
            assert_eq!(scroll.offset_in_item, gpui::px(12.5));
        });
    }
2673
2674    fn thread_entries(
2675        thread_store: &Entity<ThreadStore>,
2676        cx: &mut TestAppContext,
2677    ) -> Vec<(acp::SessionId, String)> {
2678        thread_store.read_with(cx, |store, _| {
2679            store
2680                .entries()
2681                .map(|entry| (entry.id.clone(), entry.title.to_string()))
2682                .collect::<Vec<_>>()
2683        })
2684    }
2685
2686    fn init_test(cx: &mut TestAppContext) {
2687        env_logger::try_init().ok();
2688        cx.update(|cx| {
2689            let settings_store = SettingsStore::test(cx);
2690            cx.set_global(settings_store);
2691
2692            LanguageModelRegistry::test(cx);
2693        });
2694    }
2695}
2696
2697fn mcp_message_content_to_acp_content_block(
2698    content: context_server::types::MessageContent,
2699) -> acp::ContentBlock {
2700    match content {
2701        context_server::types::MessageContent::Text {
2702            text,
2703            annotations: _,
2704        } => text.into(),
2705        context_server::types::MessageContent::Image {
2706            data,
2707            mime_type,
2708            annotations: _,
2709        } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2710        context_server::types::MessageContent::Audio {
2711            data,
2712            mime_type,
2713            annotations: _,
2714        } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2715        context_server::types::MessageContent::Resource {
2716            resource,
2717            annotations: _,
2718        } => {
2719            let mut link =
2720                acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2721            if let Some(mime_type) = resource.mime_type {
2722                link = link.mime_type(mime_type);
2723            }
2724            acp::ContentBlock::ResourceLink(link)
2725        }
2726    }
2727}