agent.rs

   1mod db;
   2mod edit_agent;
   3mod legacy_thread;
   4mod native_agent_server;
   5pub mod outline;
   6mod pattern_extraction;
   7mod templates;
   8#[cfg(test)]
   9mod tests;
  10mod thread;
  11mod thread_store;
  12mod tool_permissions;
  13mod tools;
  14
  15use context_server::ContextServerId;
  16pub use db::*;
  17use itertools::Itertools;
  18pub use native_agent_server::NativeAgentServer;
  19pub use pattern_extraction::*;
  20pub use shell_command_parser::extract_commands;
  21pub use templates::*;
  22pub use thread::*;
  23pub use thread_store::*;
  24pub use tool_permissions::*;
  25pub use tools::*;
  26
  27use acp_thread::{
  28    AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
  29    AgentSessionListResponse, TokenUsageRatio, UserMessageId,
  30};
  31use agent_client_protocol as acp;
  32use anyhow::{Context as _, Result, anyhow};
  33use chrono::{DateTime, Utc};
  34use collections::{HashMap, HashSet, IndexMap};
  35use fs::Fs;
  36use futures::channel::{mpsc, oneshot};
  37use futures::future::Shared;
  38use futures::{FutureExt as _, StreamExt as _, future};
  39use gpui::{
  40    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  41};
  42use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  43use project::{Project, ProjectItem, ProjectPath, Worktree};
  44use prompt_store::{
  45    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  46    WorktreeContext,
  47};
  48use serde::{Deserialize, Serialize};
  49use settings::{LanguageModelSelection, update_settings_file};
  50use std::any::Any;
  51use std::path::{Path, PathBuf};
  52use std::rc::Rc;
  53use std::sync::Arc;
  54use util::ResultExt;
  55use util::path_list::PathList;
  56use util::rel_path::RelPath;
  57
/// A point-in-time snapshot of the project's worktrees.
///
/// Derives `Serialize`/`Deserialize`, so it can be stored and read back.
/// NOTE(review): where it is persisted is not visible in this part of the file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One entry per worktree, in the project's telemetry snapshot format.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// When the snapshot was taken, in UTC.
    pub timestamp: DateTime<Utc>,
}
  63
/// Describes a worktree rules file that could not be loaded while the project
/// context was being built.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the underlying error's `Display`).
    pub message: SharedString,
}
  67
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Most recent save of `thread`; starts as an already-finished task when the
    /// session is registered. NOTE(review): presumably replaced by `save_thread`,
    /// which is not visible here.
    pending_save: Task<()>,
    /// Subscriptions to `thread`; they live exactly as long as the session does.
    _subscriptions: Vec<Subscription>,
}
  77
/// Cache of the language models offered by visible, authenticated providers,
/// plus a watch channel that fires every time the cache is rebuilt.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver half cloned out by `watch()`; observers wake after each refresh.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender half signalled at the end of every `refresh_list`.
    refresh_models_tx: watch::Sender<()>,
    /// Background provider-authentication task, stored so it lives as long as the cache.
    _authenticate_all_providers_task: Task<()>,
}
  87
  88impl LanguageModels {
  89    fn new(cx: &mut App) -> Self {
  90        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  91
  92        let mut this = Self {
  93            models: HashMap::default(),
  94            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  95            refresh_models_rx,
  96            refresh_models_tx,
  97            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
  98        };
  99        this.refresh_list(cx);
 100        this
 101    }
 102
 103    fn refresh_list(&mut self, cx: &App) {
 104        let providers = LanguageModelRegistry::global(cx)
 105            .read(cx)
 106            .visible_providers()
 107            .into_iter()
 108            .filter(|provider| provider.is_authenticated(cx))
 109            .collect::<Vec<_>>();
 110
 111        let mut language_model_list = IndexMap::default();
 112        let mut recommended_models = HashSet::default();
 113
 114        let mut recommended = Vec::new();
 115        for provider in &providers {
 116            for model in provider.recommended_models(cx) {
 117                recommended_models.insert((model.provider_id(), model.id()));
 118                recommended.push(Self::map_language_model_to_info(&model, provider));
 119            }
 120        }
 121        if !recommended.is_empty() {
 122            language_model_list.insert(
 123                acp_thread::AgentModelGroupName("Recommended".into()),
 124                recommended,
 125            );
 126        }
 127
 128        let mut models = HashMap::default();
 129        for provider in providers {
 130            let mut provider_models = Vec::new();
 131            for model in provider.provided_models(cx) {
 132                let model_info = Self::map_language_model_to_info(&model, &provider);
 133                let model_id = model_info.id.clone();
 134                provider_models.push(model_info);
 135                models.insert(model_id, model);
 136            }
 137            if !provider_models.is_empty() {
 138                language_model_list.insert(
 139                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 140                    provider_models,
 141                );
 142            }
 143        }
 144
 145        self.models = models;
 146        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 147        self.refresh_models_tx.send(()).ok();
 148    }
 149
 150    fn watch(&self) -> watch::Receiver<()> {
 151        self.refresh_models_rx.clone()
 152    }
 153
 154    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 155        self.models.get(model_id).cloned()
 156    }
 157
 158    fn map_language_model_to_info(
 159        model: &Arc<dyn LanguageModel>,
 160        provider: &Arc<dyn LanguageModelProvider>,
 161    ) -> acp_thread::AgentModelInfo {
 162        acp_thread::AgentModelInfo {
 163            id: Self::model_id(model),
 164            name: model.name().0,
 165            description: None,
 166            icon: Some(match provider.icon() {
 167                IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
 168                IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
 169            }),
 170            is_latest: model.is_latest(),
 171            cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
 172        }
 173    }
 174
 175    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 176        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 177    }
 178
 179    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 180        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 181            .read(cx)
 182            .visible_providers()
 183            .iter()
 184            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 185            .collect::<Vec<_>>();
 186
 187        cx.background_spawn(async move {
 188            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 189                if let Err(err) = authenticate_task.await {
 190                    match err {
 191                        language_model::AuthenticateError::CredentialsNotFound => {
 192                            // Since we're authenticating these providers in the
 193                            // background for the purposes of populating the
 194                            // language selector, we don't care about providers
 195                            // where the credentials are not found.
 196                        }
 197                        language_model::AuthenticateError::ConnectionRefused => {
 198                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 199                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 200                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 201                        }
 202                        _ => {
 203                            // Some providers have noisy failure states that we
 204                            // don't want to spam the logs with every time the
 205                            // language model selector is initialized.
 206                            //
 207                            // Ideally these should have more clear failure modes
 208                            // that we know are safe to ignore here, like what we do
 209                            // with `CredentialsNotFound` above.
 210                            match provider_id.0.as_ref() {
 211                                "lmstudio" | "ollama" => {
 212                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 213                                    //
 214                                    // These fail noisily, so we don't log them.
 215                                }
 216                                "copilot_chat" => {
 217                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 218                                }
 219                                _ => {
 220                                    log::error!(
 221                                        "Failed to authenticate provider: {}: {err:#}",
 222                                        provider_name.0
 223                                    );
 224                                }
 225                            }
 226                        }
 227                    }
 228                }
 229            }
 230        })
 231    }
 232}
 233
/// The in-process agent: owns every open session plus the state its threads share
/// (project context, context-server registry, templates, and model cache).
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread store supplied to `NativeAgent::new`.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Sending on this asks `maintain_project_context` to rebuild `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background task running `maintain_project_context`.
    _maintain_project_context: Task<Result<()>>,
    /// Context servers whose prompts become the sessions' available commands.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Project the agent works in; its worktree events trigger context refreshes.
    project: Entity<Project>,
    /// Optional prompt store that supplies the default user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle supplied to `NativeAgent::new`.
    fs: Arc<dyn Fs>,
    /// Subscriptions created in `new`; they live as long as the agent.
    _subscriptions: Vec<Subscription>,
}
 252
 253impl NativeAgent {
 254    pub async fn new(
 255        project: Entity<Project>,
 256        thread_store: Entity<ThreadStore>,
 257        templates: Arc<Templates>,
 258        prompt_store: Option<Entity<PromptStore>>,
 259        fs: Arc<dyn Fs>,
 260        cx: &mut AsyncApp,
 261    ) -> Result<Entity<NativeAgent>> {
 262        log::debug!("Creating new NativeAgent");
 263
 264        let project_context = cx
 265            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
 266            .await;
 267
 268        Ok(cx.new(|cx| {
 269            let context_server_store = project.read(cx).context_server_store();
 270            let context_server_registry =
 271                cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
 272
 273            let mut subscriptions = vec![
 274                cx.subscribe(&project, Self::handle_project_event),
 275                cx.subscribe(
 276                    &LanguageModelRegistry::global(cx),
 277                    Self::handle_models_updated_event,
 278                ),
 279                cx.subscribe(
 280                    &context_server_store,
 281                    Self::handle_context_server_store_updated,
 282                ),
 283                cx.subscribe(
 284                    &context_server_registry,
 285                    Self::handle_context_server_registry_event,
 286                ),
 287            ];
 288            if let Some(prompt_store) = prompt_store.as_ref() {
 289                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 290            }
 291
 292            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 293                watch::channel(());
 294            Self {
 295                sessions: HashMap::default(),
 296                thread_store,
 297                project_context: cx.new(|_| project_context),
 298                project_context_needs_refresh: project_context_needs_refresh_tx,
 299                _maintain_project_context: cx.spawn(async move |this, cx| {
 300                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 301                }),
 302                context_server_registry,
 303                templates,
 304                models: LanguageModels::new(cx),
 305                project,
 306                prompt_store,
 307                fs,
 308                _subscriptions: subscriptions,
 309            }
 310        }))
 311    }
 312
 313    fn new_session(
 314        &mut self,
 315        project: Entity<Project>,
 316        cx: &mut Context<Self>,
 317    ) -> Entity<AcpThread> {
 318        // Create Thread
 319        // Fetch default model from registry settings
 320        let registry = LanguageModelRegistry::read_global(cx);
 321        // Log available models for debugging
 322        let available_count = registry.available_models(cx).count();
 323        log::debug!("Total available models: {}", available_count);
 324
 325        let default_model = registry.default_model().and_then(|default_model| {
 326            self.models
 327                .model_from_id(&LanguageModels::model_id(&default_model.model))
 328        });
 329        let thread = cx.new(|cx| {
 330            Thread::new(
 331                project.clone(),
 332                self.project_context.clone(),
 333                self.context_server_registry.clone(),
 334                self.templates.clone(),
 335                default_model,
 336                cx,
 337            )
 338        });
 339
 340        self.register_session(thread, cx)
 341    }
 342
    /// Wraps `thread_handle` in a new `AcpThread`, connects the two, and stores them
    /// as a session keyed by the thread's id. Returns the new ACP thread.
    ///
    /// An existing session with the same id is replaced (`HashMap::insert`).
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        // Copy out everything the ACP thread is seeded with; the shared borrow taken
        // by `read` must end before `cx.new` below borrows `cx` mutably.
        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let draft_prompt = thread.draft_prompt().map(Vec::from);
        let scroll_position = thread.ui_scroll_position();
        let token_usage = thread.latest_token_usage();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            let mut acp_thread = acp_thread::AcpThread::new(
                parent_session_id,
                title,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            );
            // Seed the ACP thread with the thread's draft, scroll position and usage.
            acp_thread.set_draft_prompt(draft_prompt);
            acp_thread.set_ui_scroll_position(scroll_position);
            acp_thread.update_token_usage(token_usage, cx);
            acp_thread
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        // The tool environment only holds weak handles to the ACP thread, the
        // thread and this agent.
        let weak = cx.weak_entity();
        let weak_thread = thread_handle.downgrade();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            thread.add_default_tools(
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    thread: weak_thread,
                    agent: weak,
                }) as _,
                cx,
            )
        });

        // Mirror title and token-usage events onto the ACP thread, and save the
        // thread every time it notifies.
        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                _subscriptions: subscriptions,
                pending_save: Task::ready(()),
            },
        );

        // Push the current command list to every session, including this new one.
        self.update_available_commands(cx);

        acp_thread
    }
 416
    /// Returns this agent's cache of available language models.
    pub fn models(&self) -> &LanguageModels {
        &self.models
    }
 420
 421    async fn maintain_project_context(
 422        this: WeakEntity<Self>,
 423        mut needs_refresh: watch::Receiver<()>,
 424        cx: &mut AsyncApp,
 425    ) -> Result<()> {
 426        while needs_refresh.changed().await.is_ok() {
 427            let project_context = this
 428                .update(cx, |this, cx| {
 429                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 430                })?
 431                .await;
 432            this.update(cx, |this, cx| {
 433                this.project_context = cx.new(|_| project_context);
 434            })?;
 435        }
 436
 437        Ok(())
 438    }
 439
 440    fn build_project_context(
 441        project: &Entity<Project>,
 442        prompt_store: Option<&Entity<PromptStore>>,
 443        cx: &mut App,
 444    ) -> Task<ProjectContext> {
 445        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 446        let worktree_tasks = worktrees
 447            .into_iter()
 448            .map(|worktree| {
 449                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 450            })
 451            .collect::<Vec<_>>();
 452        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 453            prompt_store.read_with(cx, |prompt_store, cx| {
 454                let prompts = prompt_store.default_prompt_metadata();
 455                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 456                    let contents = prompt_store.load(prompt_metadata.id, cx);
 457                    async move { (contents.await, prompt_metadata) }
 458                });
 459                cx.background_spawn(future::join_all(load_tasks))
 460            })
 461        } else {
 462            Task::ready(vec![])
 463        };
 464
 465        cx.spawn(async move |_cx| {
 466            let (worktrees, default_user_rules) =
 467                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 468
 469            let worktrees = worktrees
 470                .into_iter()
 471                .map(|(worktree, _rules_error)| {
 472                    // TODO: show error message
 473                    // if let Some(rules_error) = rules_error {
 474                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 475                    // }
 476                    worktree
 477                })
 478                .collect::<Vec<_>>();
 479
 480            let default_user_rules = default_user_rules
 481                .into_iter()
 482                .flat_map(|(contents, prompt_metadata)| match contents {
 483                    Ok(contents) => Some(UserRulesContext {
 484                        uuid: prompt_metadata.id.as_user()?,
 485                        title: prompt_metadata.title.map(|title| title.to_string()),
 486                        contents,
 487                    }),
 488                    Err(_err) => {
 489                        // TODO: show error message
 490                        // this.update(cx, |_, cx| {
 491                        //     cx.emit(RulesLoadingError {
 492                        //         message: format!("{err:?}").into(),
 493                        //     });
 494                        // })
 495                        // .ok();
 496                        None
 497                    }
 498                })
 499                .collect::<Vec<_>>();
 500
 501            ProjectContext::new(worktrees, default_user_rules)
 502        })
 503    }
 504
 505    fn load_worktree_info_for_system_prompt(
 506        worktree: Entity<Worktree>,
 507        project: Entity<Project>,
 508        cx: &mut App,
 509    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 510        let tree = worktree.read(cx);
 511        let root_name = tree.root_name_str().into();
 512        let abs_path = tree.abs_path();
 513
 514        let mut context = WorktreeContext {
 515            root_name,
 516            abs_path,
 517            rules_file: None,
 518        };
 519
 520        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 521        let Some(rules_task) = rules_task else {
 522            return Task::ready((context, None));
 523        };
 524
 525        cx.spawn(async move |_| {
 526            let (rules_file, rules_file_error) = match rules_task.await {
 527                Ok(rules_file) => (Some(rules_file), None),
 528                Err(err) => (
 529                    None,
 530                    Some(RulesLoadingError {
 531                        message: format!("{err}").into(),
 532                    }),
 533                ),
 534            };
 535            context.rules_file = rules_file;
 536            (context, rules_file_error)
 537        })
 538    }
 539
 540    fn load_worktree_rules_file(
 541        worktree: Entity<Worktree>,
 542        project: Entity<Project>,
 543        cx: &mut App,
 544    ) -> Option<Task<Result<RulesFileContext>>> {
 545        let worktree = worktree.read(cx);
 546        let worktree_id = worktree.id();
 547        let selected_rules_file = RULES_FILE_NAMES
 548            .into_iter()
 549            .filter_map(|name| {
 550                worktree
 551                    .entry_for_path(RelPath::unix(name).unwrap())
 552                    .filter(|entry| entry.is_file())
 553                    .map(|entry| entry.path.clone())
 554            })
 555            .next();
 556
 557        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 558        // supported. This doesn't seem to occur often in GitHub repositories.
 559        selected_rules_file.map(|path_in_worktree| {
 560            let project_path = ProjectPath {
 561                worktree_id,
 562                path: path_in_worktree.clone(),
 563            };
 564            let buffer_task =
 565                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 566            let rope_task = cx.spawn(async move |cx| {
 567                let buffer = buffer_task.await?;
 568                let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
 569                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 570                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 571                })?;
 572                anyhow::Ok((project_entry_id, rope))
 573            });
 574            // Build a string from the rope on a background thread.
 575            cx.background_spawn(async move {
 576                let (project_entry_id, rope) = rope_task.await?;
 577                anyhow::Ok(RulesFileContext {
 578                    path_in_worktree,
 579                    text: rope.to_string().trim().to_string(),
 580                    project_entry_id: project_entry_id.to_usize(),
 581                })
 582            })
 583        })
 584    }
 585
 586    fn handle_thread_title_updated(
 587        &mut self,
 588        thread: Entity<Thread>,
 589        _: &TitleUpdated,
 590        cx: &mut Context<Self>,
 591    ) {
 592        let session_id = thread.read(cx).id();
 593        let Some(session) = self.sessions.get(session_id) else {
 594            return;
 595        };
 596        let thread = thread.downgrade();
 597        let acp_thread = session.acp_thread.downgrade();
 598        cx.spawn(async move |_, cx| {
 599            let title = thread.read_with(cx, |thread, _| thread.title())?;
 600            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 601            task.await
 602        })
 603        .detach_and_log_err(cx);
 604    }
 605
 606    fn handle_thread_token_usage_updated(
 607        &mut self,
 608        thread: Entity<Thread>,
 609        usage: &TokenUsageUpdated,
 610        cx: &mut Context<Self>,
 611    ) {
 612        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 613            return;
 614        };
 615        session.acp_thread.update(cx, |acp_thread, cx| {
 616            acp_thread.update_token_usage(usage.0.clone(), cx);
 617        });
 618    }
 619
 620    fn handle_project_event(
 621        &mut self,
 622        _project: Entity<Project>,
 623        event: &project::Event,
 624        _cx: &mut Context<Self>,
 625    ) {
 626        match event {
 627            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 628                self.project_context_needs_refresh.send(()).ok();
 629            }
 630            project::Event::WorktreeUpdatedEntries(_, items) => {
 631                if items.iter().any(|(path, _, _)| {
 632                    RULES_FILE_NAMES
 633                        .iter()
 634                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 635                }) {
 636                    self.project_context_needs_refresh.send(()).ok();
 637                }
 638            }
 639            _ => {}
 640        }
 641    }
 642
    /// Requests a project-context rebuild whenever the prompt store's prompts change,
    /// because the default user rules in the context come from that store.
    fn handle_prompts_updated_event(
        &mut self,
        _prompt_store: Entity<PromptStore>,
        _event: &prompt_store::PromptsUpdatedEvent,
        _cx: &mut Context<Self>,
    ) {
        self.project_context_needs_refresh.send(()).ok();
    }
 651
 652    fn handle_models_updated_event(
 653        &mut self,
 654        _registry: Entity<LanguageModelRegistry>,
 655        _event: &language_model::Event,
 656        cx: &mut Context<Self>,
 657    ) {
 658        self.models.refresh_list(cx);
 659
 660        let registry = LanguageModelRegistry::read_global(cx);
 661        let default_model = registry.default_model().map(|m| m.model);
 662        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 663
 664        for session in self.sessions.values_mut() {
 665            session.thread.update(cx, |thread, cx| {
 666                if thread.model().is_none()
 667                    && let Some(model) = default_model.clone()
 668                {
 669                    thread.set_model(model, cx);
 670                    cx.notify();
 671                }
 672                thread.set_summarization_model(summarization_model.clone(), cx);
 673            });
 674        }
 675    }
 676
    /// Refreshes every session's available commands when a context server's status
    /// changes.
    fn handle_context_server_store_updated(
        &mut self,
        _store: Entity<project::context_server_store::ContextServerStore>,
        _event: &project::context_server_store::ServerStatusChangedEvent,
        cx: &mut Context<Self>,
    ) {
        self.update_available_commands(cx);
    }
 685
 686    fn handle_context_server_registry_event(
 687        &mut self,
 688        _registry: Entity<ContextServerRegistry>,
 689        event: &ContextServerRegistryEvent,
 690        cx: &mut Context<Self>,
 691    ) {
 692        match event {
 693            ContextServerRegistryEvent::ToolsChanged => {}
 694            ContextServerRegistryEvent::PromptsChanged => {
 695                self.update_available_commands(cx);
 696            }
 697        }
 698    }
 699
 700    fn update_available_commands(&self, cx: &mut Context<Self>) {
 701        let available_commands = self.build_available_commands(cx);
 702        for session in self.sessions.values() {
 703            session.acp_thread.update(cx, |thread, cx| {
 704                thread
 705                    .handle_session_update(
 706                        acp::SessionUpdate::AvailableCommandsUpdate(
 707                            acp::AvailableCommandsUpdate::new(available_commands.clone()),
 708                        ),
 709                        cx,
 710                    )
 711                    .log_err();
 712            });
 713        }
 714    }
 715
 716    fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
 717        let registry = self.context_server_registry.read(cx);
 718
 719        let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
 720        for context_server_prompt in registry.prompts() {
 721            *prompt_name_counts
 722                .entry(context_server_prompt.prompt.name.as_str())
 723                .or_insert(0) += 1;
 724        }
 725
 726        registry
 727            .prompts()
 728            .flat_map(|context_server_prompt| {
 729                let prompt = &context_server_prompt.prompt;
 730
 731                let should_prefix = prompt_name_counts
 732                    .get(prompt.name.as_str())
 733                    .copied()
 734                    .unwrap_or(0)
 735                    > 1;
 736
 737                let name = if should_prefix {
 738                    format!("{}.{}", context_server_prompt.server_id, prompt.name)
 739                } else {
 740                    prompt.name.clone()
 741                };
 742
 743                let mut command = acp::AvailableCommand::new(
 744                    name,
 745                    prompt.description.clone().unwrap_or_default(),
 746                );
 747
 748                match prompt.arguments.as_deref() {
 749                    Some([arg]) => {
 750                        let hint = format!("<{}>", arg.name);
 751
 752                        command = command.input(acp::AvailableCommandInput::Unstructured(
 753                            acp::UnstructuredCommandInput::new(hint),
 754                        ));
 755                    }
 756                    Some([]) | None => {}
 757                    Some(_) => {
 758                        // skip >1 argument commands since we don't support them yet
 759                        return None;
 760                    }
 761                }
 762
 763                Some(command)
 764            })
 765            .collect()
 766    }
 767
 768    pub fn load_thread(
 769        &mut self,
 770        id: acp::SessionId,
 771        cx: &mut Context<Self>,
 772    ) -> Task<Result<Entity<Thread>>> {
 773        let database_future = ThreadsDatabase::connect(cx);
 774        cx.spawn(async move |this, cx| {
 775            let database = database_future.await.map_err(|err| anyhow!(err))?;
 776            let db_thread = database
 777                .load_thread(id.clone())
 778                .await?
 779                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 780
 781            this.update(cx, |this, cx| {
 782                let summarization_model = LanguageModelRegistry::read_global(cx)
 783                    .thread_summary_model()
 784                    .map(|c| c.model);
 785
 786                cx.new(|cx| {
 787                    let mut thread = Thread::from_db(
 788                        id.clone(),
 789                        db_thread,
 790                        this.project.clone(),
 791                        this.project_context.clone(),
 792                        this.context_server_registry.clone(),
 793                        this.templates.clone(),
 794                        cx,
 795                    );
 796                    thread.set_summarization_model(summarization_model, cx);
 797                    thread
 798                })
 799            })
 800        })
 801    }
 802
 803    pub fn open_thread(
 804        &mut self,
 805        id: acp::SessionId,
 806        cx: &mut Context<Self>,
 807    ) -> Task<Result<Entity<AcpThread>>> {
 808        if let Some(session) = self.sessions.get(&id) {
 809            return Task::ready(Ok(session.acp_thread.clone()));
 810        }
 811
 812        let task = self.load_thread(id, cx);
 813        cx.spawn(async move |this, cx| {
 814            let thread = task.await?;
 815            let acp_thread =
 816                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
 817            let events = thread.update(cx, |thread, cx| thread.replay(cx));
 818            cx.update(|cx| {
 819                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 820            })
 821            .await?;
 822            Ok(acp_thread)
 823        })
 824    }
 825
 826    pub fn thread_summary(
 827        &mut self,
 828        id: acp::SessionId,
 829        cx: &mut Context<Self>,
 830    ) -> Task<Result<SharedString>> {
 831        let thread = self.open_thread(id.clone(), cx);
 832        cx.spawn(async move |this, cx| {
 833            let acp_thread = thread.await?;
 834            let result = this
 835                .update(cx, |this, cx| {
 836                    this.sessions
 837                        .get(&id)
 838                        .unwrap()
 839                        .thread
 840                        .update(cx, |thread, cx| thread.summary(cx))
 841                })?
 842                .await
 843                .context("Failed to generate summary")?;
 844            drop(acp_thread);
 845            Ok(result)
 846        })
 847    }
 848
 849    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 850        if thread.read(cx).is_empty() {
 851            return;
 852        }
 853
 854        let id = thread.read(cx).id().clone();
 855        let Some(session) = self.sessions.get_mut(&id) else {
 856            return;
 857        };
 858
 859        let folder_paths = PathList::new(
 860            &self
 861                .project
 862                .read(cx)
 863                .visible_worktrees(cx)
 864                .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
 865                .collect::<Vec<_>>(),
 866        );
 867
 868        let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
 869        let database_future = ThreadsDatabase::connect(cx);
 870        let db_thread = thread.update(cx, |thread, cx| {
 871            thread.set_draft_prompt(draft_prompt);
 872            thread.to_db(cx)
 873        });
 874        let thread_store = self.thread_store.clone();
 875        session.pending_save = cx.spawn(async move |_, cx| {
 876            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 877                return;
 878            };
 879            let db_thread = db_thread.await;
 880            database
 881                .save_thread(id, db_thread, folder_paths)
 882                .await
 883                .log_err();
 884            thread_store.update(cx, |store, cx| store.reload(cx));
 885        });
 886    }
 887
    /// Runs a context-server (MCP) prompt as a turn in session `session_id`.
    ///
    /// Fetches prompt `prompt_name` from server `server_id` with `arguments`,
    /// then performs three steps:
    /// 1. Pushes `original_content`, minus its first block, onto the `Thread`
    ///    as user message `message_id`. The first block is the `/command`
    ///    text that `Command::parse` matched in `NativeAgentConnection::prompt`.
    /// 2. Appends every message returned by the server to both the `AcpThread`
    ///    and the `Thread`, as user or assistant content according to its role.
    /// 3. Starts the turn with `Thread::send_existing` when the last appended
    ///    message came from the user (or none were returned). Otherwise it
    ///    uses `Thread::resume`. The resulting events are forwarded to the
    ///    `AcpThread`.
    ///
    /// Fails if the prompt cannot be fetched, the session is not registered,
    /// the agent has been dropped, or the turn cannot be started.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: agent_client_protocol::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let server_store = self.context_server_registry.read(cx).server_store().clone();
        let path_style = self.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Starts as `true` so that a prompt returning no messages still
            // starts the turn with `send_existing`.
            let mut last_is_user = true;

            // Skip the first block: it is the `/command` text that selected
            // this prompt.
            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                match role {
                    context_server::types::Role::User => {
                        // Each user message from the prompt gets a fresh id,
                        // shared by its AcpThread and Thread copies.
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        // NOTE(review): the `false` argument presumably marks this
                        // as non-thinking content (compare the `AgentText` handling
                        // in `handle_thread_events`) — confirm.
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
 983}
 984
 985/// Wrapper struct that implements the AgentConnection trait
 986#[derive(Clone)]
 987pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 988
 989impl NativeAgentConnection {
 990    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 991        self.0
 992            .read(cx)
 993            .sessions
 994            .get(session_id)
 995            .map(|session| session.thread.clone())
 996    }
 997
 998    pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
 999        self.0.update(cx, |this, cx| this.load_thread(id, cx))
1000    }
1001
1002    fn run_turn(
1003        &self,
1004        session_id: acp::SessionId,
1005        cx: &mut App,
1006        f: impl 'static
1007        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1008    ) -> Task<Result<acp::PromptResponse>> {
1009        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1010            agent
1011                .sessions
1012                .get_mut(&session_id)
1013                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1014        }) else {
1015            return Task::ready(Err(anyhow!("Session not found")));
1016        };
1017        log::debug!("Found session for: {}", session_id);
1018
1019        let response_stream = match f(thread, cx) {
1020            Ok(stream) => stream,
1021            Err(err) => return Task::ready(Err(err)),
1022        };
1023        Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1024    }
1025
1026    fn handle_thread_events(
1027        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1028        acp_thread: WeakEntity<AcpThread>,
1029        cx: &App,
1030    ) -> Task<Result<acp::PromptResponse>> {
1031        cx.spawn(async move |cx| {
1032            // Handle response stream and forward to session.acp_thread
1033            while let Some(result) = events.next().await {
1034                match result {
1035                    Ok(event) => {
1036                        log::trace!("Received completion event: {:?}", event);
1037
1038                        match event {
1039                            ThreadEvent::UserMessage(message) => {
1040                                acp_thread.update(cx, |thread, cx| {
1041                                    for content in message.content {
1042                                        thread.push_user_content_block(
1043                                            Some(message.id.clone()),
1044                                            content.into(),
1045                                            cx,
1046                                        );
1047                                    }
1048                                })?;
1049                            }
1050                            ThreadEvent::AgentText(text) => {
1051                                acp_thread.update(cx, |thread, cx| {
1052                                    thread.push_assistant_content_block(text.into(), false, cx)
1053                                })?;
1054                            }
1055                            ThreadEvent::AgentThinking(text) => {
1056                                acp_thread.update(cx, |thread, cx| {
1057                                    thread.push_assistant_content_block(text.into(), true, cx)
1058                                })?;
1059                            }
1060                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1061                                tool_call,
1062                                options,
1063                                response,
1064                                context: _,
1065                            }) => {
1066                                let outcome_task = acp_thread.update(cx, |thread, cx| {
1067                                    thread.request_tool_call_authorization(tool_call, options, cx)
1068                                })??;
1069                                cx.background_spawn(async move {
1070                                    if let acp::RequestPermissionOutcome::Selected(
1071                                        acp::SelectedPermissionOutcome { option_id, .. },
1072                                    ) = outcome_task.await
1073                                    {
1074                                        response
1075                                            .send(option_id)
1076                                            .map(|_| anyhow!("authorization receiver was dropped"))
1077                                            .log_err();
1078                                    }
1079                                })
1080                                .detach();
1081                            }
1082                            ThreadEvent::ToolCall(tool_call) => {
1083                                acp_thread.update(cx, |thread, cx| {
1084                                    thread.upsert_tool_call(tool_call, cx)
1085                                })??;
1086                            }
1087                            ThreadEvent::ToolCallUpdate(update) => {
1088                                acp_thread.update(cx, |thread, cx| {
1089                                    thread.update_tool_call(update, cx)
1090                                })??;
1091                            }
1092                            ThreadEvent::SubagentSpawned(session_id) => {
1093                                acp_thread.update(cx, |thread, cx| {
1094                                    thread.subagent_spawned(session_id, cx);
1095                                })?;
1096                            }
1097                            ThreadEvent::Retry(status) => {
1098                                acp_thread.update(cx, |thread, cx| {
1099                                    thread.update_retry_status(status, cx)
1100                                })?;
1101                            }
1102                            ThreadEvent::Stop(stop_reason) => {
1103                                log::debug!("Assistant message complete: {:?}", stop_reason);
1104                                return Ok(acp::PromptResponse::new(stop_reason));
1105                            }
1106                        }
1107                    }
1108                    Err(e) => {
1109                        log::error!("Error in model response stream: {:?}", e);
1110                        return Err(e);
1111                    }
1112                }
1113            }
1114
1115            log::debug!("Response stream completed");
1116            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1117        })
1118    }
1119}
1120
1121struct Command<'a> {
1122    prompt_name: &'a str,
1123    arg_value: &'a str,
1124    explicit_server_id: Option<&'a str>,
1125}
1126
1127impl<'a> Command<'a> {
1128    fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1129        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1130            return None;
1131        };
1132        let text = text_content.text.trim();
1133        let command = text.strip_prefix('/')?;
1134        let (command, arg_value) = command
1135            .split_once(char::is_whitespace)
1136            .unwrap_or((command, ""));
1137
1138        if let Some((server_id, prompt_name)) = command.split_once('.') {
1139            Some(Self {
1140                prompt_name,
1141                arg_value,
1142                explicit_server_id: Some(server_id),
1143            })
1144        } else {
1145            Some(Self {
1146                prompt_name: command,
1147                arg_value,
1148                explicit_server_id: None,
1149            })
1150        }
1151    }
1152}
1153
1154struct NativeAgentModelSelector {
1155    session_id: acp::SessionId,
1156    connection: NativeAgentConnection,
1157}
1158
1159impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1160    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1161        log::debug!("NativeAgentConnection::list_models called");
1162        let list = self.connection.0.read(cx).models.model_list.clone();
1163        Task::ready(if list.is_empty() {
1164            Err(anyhow::anyhow!("No models available"))
1165        } else {
1166            Ok(list)
1167        })
1168    }
1169
1170    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1171        log::debug!(
1172            "Setting model for session {}: {}",
1173            self.session_id,
1174            model_id
1175        );
1176        let Some(thread) = self
1177            .connection
1178            .0
1179            .read(cx)
1180            .sessions
1181            .get(&self.session_id)
1182            .map(|session| session.thread.clone())
1183        else {
1184            return Task::ready(Err(anyhow!("Session not found")));
1185        };
1186
1187        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1188            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1189        };
1190
1191        // We want to reset the effort level when switching models, as the currently-selected effort level may
1192        // not be compatible.
1193        let effort = model
1194            .default_effort_level()
1195            .map(|effort_level| effort_level.value.to_string());
1196
1197        thread.update(cx, |thread, cx| {
1198            thread.set_model(model.clone(), cx);
1199            thread.set_thinking_effort(effort.clone(), cx);
1200            thread.set_thinking_enabled(model.supports_thinking(), cx);
1201        });
1202
1203        update_settings_file(
1204            self.connection.0.read(cx).fs.clone(),
1205            cx,
1206            move |settings, cx| {
1207                let provider = model.provider_id().0.to_string();
1208                let model = model.id().0.to_string();
1209                let enable_thinking = thread.read(cx).thinking_enabled();
1210                settings
1211                    .agent
1212                    .get_or_insert_default()
1213                    .set_model(LanguageModelSelection {
1214                        provider: provider.into(),
1215                        model,
1216                        enable_thinking,
1217                        effort,
1218                    });
1219            },
1220        );
1221
1222        Task::ready(Ok(()))
1223    }
1224
1225    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1226        let Some(thread) = self
1227            .connection
1228            .0
1229            .read(cx)
1230            .sessions
1231            .get(&self.session_id)
1232            .map(|session| session.thread.clone())
1233        else {
1234            return Task::ready(Err(anyhow!("Session not found")));
1235        };
1236        let Some(model) = thread.read(cx).model() else {
1237            return Task::ready(Err(anyhow!("Model not found")));
1238        };
1239        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1240        else {
1241            return Task::ready(Err(anyhow!("Provider not found")));
1242        };
1243        Task::ready(Ok(LanguageModels::map_language_model_to_info(
1244            model, &provider,
1245        )))
1246    }
1247
1248    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1249        Some(self.connection.0.read(cx).models.watch())
1250    }
1251
1252    fn should_render_footer(&self) -> bool {
1253        true
1254    }
1255}
1256
1257impl acp_thread::AgentConnection for NativeAgentConnection {
1258    fn telemetry_id(&self) -> SharedString {
1259        "zed".into()
1260    }
1261
1262    fn new_session(
1263        self: Rc<Self>,
1264        project: Entity<Project>,
1265        cwd: &Path,
1266        cx: &mut App,
1267    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1268        log::debug!("Creating new thread for project at: {cwd:?}");
1269        Task::ready(Ok(self
1270            .0
1271            .update(cx, |agent, cx| agent.new_session(project, cx))))
1272    }
1273
1274    fn supports_load_session(&self) -> bool {
1275        true
1276    }
1277
1278    fn load_session(
1279        self: Rc<Self>,
1280        session: AgentSessionInfo,
1281        _project: Entity<Project>,
1282        _cwd: &Path,
1283        cx: &mut App,
1284    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1285        self.0
1286            .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1287    }
1288
1289    fn supports_close_session(&self) -> bool {
1290        true
1291    }
1292
1293    fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1294        self.0.update(cx, |agent, _cx| {
1295            agent.sessions.remove(session_id);
1296        });
1297        Task::ready(Ok(()))
1298    }
1299
1300    fn auth_methods(&self) -> &[acp::AuthMethod] {
1301        &[] // No auth for in-process
1302    }
1303
1304    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1305        Task::ready(Ok(()))
1306    }
1307
1308    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1309        Some(Rc::new(NativeAgentModelSelector {
1310            session_id: session_id.clone(),
1311            connection: self.clone(),
1312        }) as Rc<dyn AgentModelSelector>)
1313    }
1314
1315    fn prompt(
1316        &self,
1317        id: Option<acp_thread::UserMessageId>,
1318        params: acp::PromptRequest,
1319        cx: &mut App,
1320    ) -> Task<Result<acp::PromptResponse>> {
1321        let id = id.expect("UserMessageId is required");
1322        let session_id = params.session_id.clone();
1323        log::info!("Received prompt request for session: {}", session_id);
1324        log::debug!("Prompt blocks count: {}", params.prompt.len());
1325
1326        if let Some(parsed_command) = Command::parse(&params.prompt) {
1327            let registry = self.0.read(cx).context_server_registry.read(cx);
1328
1329            let explicit_server_id = parsed_command
1330                .explicit_server_id
1331                .map(|server_id| ContextServerId(server_id.into()));
1332
1333            if let Some(prompt) =
1334                registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1335            {
1336                let arguments = if !parsed_command.arg_value.is_empty()
1337                    && let Some(arg_name) = prompt
1338                        .prompt
1339                        .arguments
1340                        .as_ref()
1341                        .and_then(|args| args.first())
1342                        .map(|arg| arg.name.clone())
1343                {
1344                    HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1345                } else {
1346                    Default::default()
1347                };
1348
1349                let prompt_name = prompt.prompt.name.clone();
1350                let server_id = prompt.server_id.clone();
1351
1352                return self.0.update(cx, |agent, cx| {
1353                    agent.send_mcp_prompt(
1354                        id,
1355                        session_id.clone(),
1356                        prompt_name,
1357                        server_id,
1358                        arguments,
1359                        params.prompt,
1360                        cx,
1361                    )
1362                });
1363            };
1364        };
1365
1366        let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1367
1368        self.run_turn(session_id, cx, move |thread, cx| {
1369            let content: Vec<UserMessageContent> = params
1370                .prompt
1371                .into_iter()
1372                .map(|block| UserMessageContent::from_content_block(block, path_style))
1373                .collect::<Vec<_>>();
1374            log::debug!("Converted prompt to message: {} chars", content.len());
1375            log::debug!("Message id: {:?}", id);
1376            log::debug!("Message content: {:?}", content);
1377
1378            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1379        })
1380    }
1381
1382    fn retry(
1383        &self,
1384        session_id: &acp::SessionId,
1385        _cx: &App,
1386    ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1387        Some(Rc::new(NativeAgentSessionRetry {
1388            connection: self.clone(),
1389            session_id: session_id.clone(),
1390        }) as _)
1391    }
1392
1393    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1394        log::info!("Cancelling on session: {}", session_id);
1395        self.0.update(cx, |agent, cx| {
1396            if let Some(session) = agent.sessions.get(session_id) {
1397                session
1398                    .thread
1399                    .update(cx, |thread, cx| thread.cancel(cx))
1400                    .detach();
1401            }
1402        });
1403    }
1404
1405    fn truncate(
1406        &self,
1407        session_id: &agent_client_protocol::SessionId,
1408        cx: &App,
1409    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1410        self.0.read_with(cx, |agent, _cx| {
1411            agent.sessions.get(session_id).map(|session| {
1412                Rc::new(NativeAgentSessionTruncate {
1413                    thread: session.thread.clone(),
1414                    acp_thread: session.acp_thread.downgrade(),
1415                }) as _
1416            })
1417        })
1418    }
1419
1420    fn set_title(
1421        &self,
1422        session_id: &acp::SessionId,
1423        cx: &App,
1424    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1425        self.0.read_with(cx, |agent, _cx| {
1426            agent
1427                .sessions
1428                .get(session_id)
1429                .filter(|s| !s.thread.read(cx).is_subagent())
1430                .map(|session| {
1431                    Rc::new(NativeAgentSessionSetTitle {
1432                        thread: session.thread.clone(),
1433                    }) as _
1434                })
1435        })
1436    }
1437
1438    fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1439        let thread_store = self.0.read(cx).thread_store.clone();
1440        Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1441    }
1442
1443    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1444        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1445    }
1446
1447    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1448        self
1449    }
1450}
1451
1452impl acp_thread::AgentTelemetry for NativeAgentConnection {
1453    fn thread_data(
1454        &self,
1455        session_id: &acp::SessionId,
1456        cx: &mut App,
1457    ) -> Task<Result<serde_json::Value>> {
1458        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1459            return Task::ready(Err(anyhow!("Session not found")));
1460        };
1461
1462        let task = session.thread.read(cx).to_db(cx);
1463        cx.background_spawn(async move {
1464            serde_json::to_value(task.await).context("Failed to serialize thread")
1465        })
1466    }
1467}
1468
1469pub struct NativeAgentSessionList {
1470    thread_store: Entity<ThreadStore>,
1471    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1472    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1473    _subscription: Subscription,
1474}
1475
1476impl NativeAgentSessionList {
1477    fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1478        let (tx, rx) = smol::channel::unbounded();
1479        let this_tx = tx.clone();
1480        let subscription = cx.observe(&thread_store, move |_, _| {
1481            this_tx
1482                .try_send(acp_thread::SessionListUpdate::Refresh)
1483                .ok();
1484        });
1485        Self {
1486            thread_store,
1487            updates_tx: tx,
1488            updates_rx: rx,
1489            _subscription: subscription,
1490        }
1491    }
1492
1493    pub fn thread_store(&self) -> &Entity<ThreadStore> {
1494        &self.thread_store
1495    }
1496}
1497
1498impl AgentSessionList for NativeAgentSessionList {
1499    fn list_sessions(
1500        &self,
1501        _request: AgentSessionListRequest,
1502        cx: &mut App,
1503    ) -> Task<Result<AgentSessionListResponse>> {
1504        let sessions = self
1505            .thread_store
1506            .read(cx)
1507            .entries()
1508            .map(|entry| AgentSessionInfo::from(&entry))
1509            .collect();
1510        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1511    }
1512
1513    fn supports_delete(&self) -> bool {
1514        true
1515    }
1516
1517    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1518        self.thread_store
1519            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1520    }
1521
1522    fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1523        self.thread_store
1524            .update(cx, |store, cx| store.delete_threads(cx))
1525    }
1526
1527    fn watch(
1528        &self,
1529        _cx: &mut App,
1530    ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1531        Some(self.updates_rx.clone())
1532    }
1533
1534    fn notify_refresh(&self) {
1535        self.updates_tx
1536            .try_send(acp_thread::SessionListUpdate::Refresh)
1537            .ok();
1538    }
1539
1540    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1541        self
1542    }
1543}
1544
1545struct NativeAgentSessionTruncate {
1546    thread: Entity<Thread>,
1547    acp_thread: WeakEntity<AcpThread>,
1548}
1549
1550impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1551    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1552        match self.thread.update(cx, |thread, cx| {
1553            thread.truncate(message_id.clone(), cx)?;
1554            Ok(thread.latest_token_usage())
1555        }) {
1556            Ok(usage) => {
1557                self.acp_thread
1558                    .update(cx, |thread, cx| {
1559                        thread.update_token_usage(usage, cx);
1560                    })
1561                    .ok();
1562                Task::ready(Ok(()))
1563            }
1564            Err(error) => Task::ready(Err(error)),
1565        }
1566    }
1567}
1568
1569struct NativeAgentSessionRetry {
1570    connection: NativeAgentConnection,
1571    session_id: acp::SessionId,
1572}
1573
1574impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1575    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1576        self.connection
1577            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1578                thread.update(cx, |thread, cx| thread.resume(cx))
1579            })
1580    }
1581}
1582
1583struct NativeAgentSessionSetTitle {
1584    thread: Entity<Thread>,
1585}
1586
1587impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1588    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1589        self.thread
1590            .update(cx, |thread, cx| thread.set_title(title, cx));
1591        Task::ready(Ok(()))
1592    }
1593}
1594
/// Weak handles to the agent, a thread, and an `AcpThread`, used to create and
/// resume subagent threads.
///
/// Only `WeakEntity`s are held, so this does not keep any of them alive;
/// `create_subagent_thread` fails once the thread or agent has been dropped.
pub struct NativeThreadEnvironment {
    // Agent that new subagent sessions are registered with.
    agent: WeakEntity<NativeAgent>,
    // Parent thread of subagents created through this environment.
    thread: WeakEntity<Thread>,
    // NOTE(review): not used in the visible methods; presumably the parent
    // thread's AcpThread — confirm.
    acp_thread: WeakEntity<AcpThread>,
}
1600
1601impl NativeThreadEnvironment {
1602    pub(crate) fn create_subagent_thread(
1603        &self,
1604        label: String,
1605        cx: &mut App,
1606    ) -> Result<Rc<dyn SubagentHandle>> {
1607        let Some(parent_thread_entity) = self.thread.upgrade() else {
1608            anyhow::bail!("Parent thread no longer exists".to_string());
1609        };
1610        let parent_thread = parent_thread_entity.read(cx);
1611        let current_depth = parent_thread.depth();
1612
1613        if current_depth >= MAX_SUBAGENT_DEPTH {
1614            return Err(anyhow!(
1615                "Maximum subagent depth ({}) reached",
1616                MAX_SUBAGENT_DEPTH
1617            ));
1618        }
1619
1620        let subagent_thread: Entity<Thread> = cx.new(|cx| {
1621            let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1622            thread.set_title(label.into(), cx);
1623            thread
1624        });
1625
1626        let session_id = subagent_thread.read(cx).id().clone();
1627
1628        let acp_thread = self.agent.update(cx, |agent, cx| {
1629            agent.register_session(subagent_thread.clone(), cx)
1630        })?;
1631
1632        let depth = current_depth + 1;
1633
1634        telemetry::event!(
1635            "Subagent Started",
1636            session = parent_thread_entity.read(cx).id().to_string(),
1637            subagent_session = session_id.to_string(),
1638            depth,
1639            is_resumed = false,
1640        );
1641
1642        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1643    }
1644
1645    pub(crate) fn resume_subagent_thread(
1646        &self,
1647        session_id: acp::SessionId,
1648        cx: &mut App,
1649    ) -> Result<Rc<dyn SubagentHandle>> {
1650        let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1651            let session = agent
1652                .sessions
1653                .get(&session_id)
1654                .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1655            anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1656        })??;
1657
1658        let depth = subagent_thread.read(cx).depth();
1659
1660        if let Some(parent_thread_entity) = self.thread.upgrade() {
1661            telemetry::event!(
1662                "Subagent Started",
1663                session = parent_thread_entity.read(cx).id().to_string(),
1664                subagent_session = session_id.to_string(),
1665                depth,
1666                is_resumed = true,
1667            );
1668        }
1669
1670        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1671    }
1672
1673    fn prompt_subagent(
1674        &self,
1675        session_id: acp::SessionId,
1676        subagent_thread: Entity<Thread>,
1677        acp_thread: Entity<acp_thread::AcpThread>,
1678    ) -> Result<Rc<dyn SubagentHandle>> {
1679        let Some(parent_thread_entity) = self.thread.upgrade() else {
1680            anyhow::bail!("Parent thread no longer exists".to_string());
1681        };
1682        Ok(Rc::new(NativeSubagentHandle::new(
1683            session_id,
1684            subagent_thread,
1685            acp_thread,
1686            parent_thread_entity,
1687        )) as _)
1688    }
1689}
1690
1691impl ThreadEnvironment for NativeThreadEnvironment {
1692    fn create_terminal(
1693        &self,
1694        command: String,
1695        cwd: Option<PathBuf>,
1696        output_byte_limit: Option<u64>,
1697        cx: &mut AsyncApp,
1698    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1699        let task = self.acp_thread.update(cx, |thread, cx| {
1700            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1701        });
1702
1703        let acp_thread = self.acp_thread.clone();
1704        cx.spawn(async move |cx| {
1705            let terminal = task?.await?;
1706
1707            let (drop_tx, drop_rx) = oneshot::channel();
1708            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1709
1710            cx.spawn(async move |cx| {
1711                drop_rx.await.ok();
1712                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1713            })
1714            .detach();
1715
1716            let handle = AcpTerminalHandle {
1717                terminal,
1718                _drop_tx: Some(drop_tx),
1719            };
1720
1721            Ok(Rc::new(handle) as _)
1722        })
1723    }
1724
1725    fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1726        self.create_subagent_thread(label, cx)
1727    }
1728
1729    fn resume_subagent(
1730        &self,
1731        session_id: acp::SessionId,
1732        cx: &mut App,
1733    ) -> Result<Rc<dyn SubagentHandle>> {
1734        self.resume_subagent_thread(session_id, cx)
1735    }
1736}
1737
/// Outcome of waiting on a single prompt sent to a subagent.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The turn finished with a stop reason that is treated as success.
    Completed,
    /// The turn stopped with `acp::StopReason::Cancelled`.
    Cancelled,
    /// The subagent's token usage was reported as leaving the `Normal` range
    /// while the prompt was running; the prompt is abandoned and the subagent
    /// is cancelled.
    ContextWindowWarning,
    /// The turn failed or stopped for a reason reported to the caller as an
    /// error; the string is the user-facing message.
    Error(String),
}
1745
/// Handle the parent thread uses to talk to one subagent session.
pub struct NativeSubagentHandle {
    /// Session id of the subagent.
    session_id: acp::SessionId,
    /// Parent thread; held weakly and used to register and unregister the
    /// subagent as running while a prompt is in flight.
    parent_thread: WeakEntity<Thread>,
    /// The subagent's native thread.
    subagent_thread: Entity<Thread>,
    /// The subagent's ACP thread, through which prompts are sent.
    acp_thread: Entity<acp_thread::AcpThread>,
}
1752
1753impl NativeSubagentHandle {
1754    fn new(
1755        session_id: acp::SessionId,
1756        subagent_thread: Entity<Thread>,
1757        acp_thread: Entity<acp_thread::AcpThread>,
1758        parent_thread_entity: Entity<Thread>,
1759    ) -> Self {
1760        NativeSubagentHandle {
1761            session_id,
1762            subagent_thread,
1763            parent_thread: parent_thread_entity.downgrade(),
1764            acp_thread,
1765        }
1766    }
1767}
1768
1769impl SubagentHandle for NativeSubagentHandle {
1770    fn id(&self) -> acp::SessionId {
1771        self.session_id.clone()
1772    }
1773
1774    fn num_entries(&self, cx: &App) -> usize {
1775        self.acp_thread.read(cx).entries().len()
1776    }
1777
1778    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
1779        let thread = self.subagent_thread.clone();
1780        let acp_thread = self.acp_thread.clone();
1781        let subagent_session_id = self.session_id.clone();
1782        let parent_thread = self.parent_thread.clone();
1783
1784        cx.spawn(async move |cx| {
1785            let (task, _subscription) = cx.update(|cx| {
1786                let ratio_before_prompt = thread
1787                    .read(cx)
1788                    .latest_token_usage()
1789                    .map(|usage| usage.ratio());
1790
1791                parent_thread
1792                    .update(cx, |parent_thread, _cx| {
1793                        parent_thread.register_running_subagent(thread.downgrade())
1794                    })
1795                    .ok();
1796
1797                let task = acp_thread.update(cx, |acp_thread, cx| {
1798                    acp_thread.send(vec![message.into()], cx)
1799                });
1800
1801                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
1802                let mut token_limit_tx = Some(token_limit_tx);
1803
1804                let subscription = cx.subscribe(
1805                    &thread,
1806                    move |_thread, event: &TokenUsageUpdated, _cx| {
1807                        if let Some(usage) = &event.0 {
1808                            let old_ratio = ratio_before_prompt
1809                                .clone()
1810                                .unwrap_or(TokenUsageRatio::Normal);
1811                            let new_ratio = usage.ratio();
1812                            if old_ratio == TokenUsageRatio::Normal
1813                                && new_ratio == TokenUsageRatio::Warning
1814                            {
1815                                if let Some(tx) = token_limit_tx.take() {
1816                                    tx.send(()).ok();
1817                                }
1818                            }
1819                        }
1820                    },
1821                );
1822
1823                let wait_for_prompt = cx
1824                    .background_spawn(async move {
1825                        futures::select! {
1826                            response = task.fuse() => match response {
1827                                Ok(Some(response)) => {
1828                                    match response.stop_reason {
1829                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
1830                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
1831                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
1832                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
1833                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
1834                                    }
1835                                }
1836                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
1837                                Err(error) => SubagentPromptResult::Error(error.to_string()),
1838                            },
1839                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
1840                        }
1841                    });
1842
1843                (wait_for_prompt, subscription)
1844            });
1845
1846            let result = match task.await {
1847                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
1848                    thread
1849                        .last_message()
1850                        .and_then(|message| {
1851                            let content = message.as_agent_message()?
1852                                .content
1853                                .iter()
1854                                .filter_map(|c| match c {
1855                                    AgentMessageContent::Text(text) => Some(text.as_str()),
1856                                    _ => None,
1857                                })
1858                                .join("\n\n");
1859                            if content.is_empty() {
1860                                None
1861                            } else {
1862                                Some( content)
1863                            }
1864                        })
1865                        .context("No response from subagent")
1866                }),
1867                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
1868                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
1869                SubagentPromptResult::ContextWindowWarning => {
1870                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1871                    Err(anyhow!(
1872                        "The agent is nearing the end of its context window and has been \
1873                         stopped. You can prompt the thread again to have the agent wrap up \
1874                         or hand off its work."
1875                    ))
1876                }
1877            };
1878
1879            parent_thread
1880                .update(cx, |parent_thread, cx| {
1881                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
1882                })
1883                .ok();
1884
1885            result
1886        })
1887    }
1888}
1889
/// `TerminalHandle` backed by an ACP thread terminal.
pub struct AcpTerminalHandle {
    /// The underlying ACP terminal entity.
    terminal: Entity<acp_thread::Terminal>,
    /// Never sent on. When `create_terminal` builds this handle, a detached
    /// task awaits the matching receiver. Dropping this sender (i.e. dropping
    /// the handle) completes that receiver, and the task then releases the
    /// terminal from the ACP thread. Keep this field last so it drops after
    /// `terminal`.
    // NOTE(review): the reason for wrapping the sender in `Option` is not
    // visible in this chunk.
    _drop_tx: Option<oneshot::Sender<()>>,
}
1894
1895impl TerminalHandle for AcpTerminalHandle {
1896    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1897        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1898    }
1899
1900    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1901        Ok(self
1902            .terminal
1903            .read_with(cx, |term, _cx| term.wait_for_exit()))
1904    }
1905
1906    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1907        Ok(self
1908            .terminal
1909            .read_with(cx, |term, cx| term.current_output(cx)))
1910    }
1911
1912    fn kill(&self, cx: &AsyncApp) -> Result<()> {
1913        cx.update(|cx| {
1914            self.terminal.update(cx, |terminal, cx| {
1915                terminal.kill(cx);
1916            });
1917        });
1918        Ok(())
1919    }
1920
1921    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1922        Ok(self
1923            .terminal
1924            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1925    }
1926}
1927
1928#[cfg(test)]
1929mod internal_tests {
1930    use super::*;
1931    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1932    use fs::FakeFs;
1933    use gpui::TestAppContext;
1934    use indoc::formatdoc;
1935    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1936    use language_model::{
1937        LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
1938    };
1939    use serde_json::json;
1940    use settings::SettingsStore;
1941    use util::{path, rel_path::rel_path};
1942
1943    #[gpui::test]
1944    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1945        init_test(cx);
1946        let fs = FakeFs::new(cx.executor());
1947        fs.insert_tree(
1948            "/",
1949            json!({
1950                "a": {}
1951            }),
1952        )
1953        .await;
1954        let project = Project::test(fs.clone(), [], cx).await;
1955        let thread_store = cx.new(|cx| ThreadStore::new(cx));
1956        let agent = NativeAgent::new(
1957            project.clone(),
1958            thread_store,
1959            Templates::new(),
1960            None,
1961            fs.clone(),
1962            &mut cx.to_async(),
1963        )
1964        .await
1965        .unwrap();
1966        agent.read_with(cx, |agent, cx| {
1967            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1968        });
1969
1970        let worktree = project
1971            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1972            .await
1973            .unwrap();
1974        cx.run_until_parked();
1975        agent.read_with(cx, |agent, cx| {
1976            assert_eq!(
1977                agent.project_context.read(cx).worktrees,
1978                vec![WorktreeContext {
1979                    root_name: "a".into(),
1980                    abs_path: Path::new("/a").into(),
1981                    rules_file: None
1982                }]
1983            )
1984        });
1985
1986        // Creating `/a/.rules` updates the project context.
1987        fs.insert_file("/a/.rules", Vec::new()).await;
1988        cx.run_until_parked();
1989        agent.read_with(cx, |agent, cx| {
1990            let rules_entry = worktree
1991                .read(cx)
1992                .entry_for_path(rel_path(".rules"))
1993                .unwrap();
1994            assert_eq!(
1995                agent.project_context.read(cx).worktrees,
1996                vec![WorktreeContext {
1997                    root_name: "a".into(),
1998                    abs_path: Path::new("/a").into(),
1999                    rules_file: Some(RulesFileContext {
2000                        path_in_worktree: rel_path(".rules").into(),
2001                        text: "".into(),
2002                        project_entry_id: rules_entry.id.to_usize()
2003                    })
2004                }]
2005            )
2006        });
2007    }
2008
2009    #[gpui::test]
2010    async fn test_listing_models(cx: &mut TestAppContext) {
2011        init_test(cx);
2012        let fs = FakeFs::new(cx.executor());
2013        fs.insert_tree("/", json!({ "a": {}  })).await;
2014        let project = Project::test(fs.clone(), [], cx).await;
2015        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2016        let connection = NativeAgentConnection(
2017            NativeAgent::new(
2018                project.clone(),
2019                thread_store,
2020                Templates::new(),
2021                None,
2022                fs.clone(),
2023                &mut cx.to_async(),
2024            )
2025            .await
2026            .unwrap(),
2027        );
2028
2029        // Create a thread/session
2030        let acp_thread = cx
2031            .update(|cx| {
2032                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2033            })
2034            .await
2035            .unwrap();
2036
2037        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2038
2039        let models = cx
2040            .update(|cx| {
2041                connection
2042                    .model_selector(&session_id)
2043                    .unwrap()
2044                    .list_models(cx)
2045            })
2046            .await
2047            .unwrap();
2048
2049        let acp_thread::AgentModelList::Grouped(models) = models else {
2050            panic!("Unexpected model group");
2051        };
2052        assert_eq!(
2053            models,
2054            IndexMap::from_iter([(
2055                AgentModelGroupName("Fake".into()),
2056                vec![AgentModelInfo {
2057                    id: acp::ModelId::new("fake/fake"),
2058                    name: "Fake".into(),
2059                    description: None,
2060                    icon: Some(acp_thread::AgentModelIcon::Named(
2061                        ui::IconName::ZedAssistant
2062                    )),
2063                    is_latest: false,
2064                    cost: None,
2065                }]
2066            )])
2067        );
2068    }
2069
2070    #[gpui::test]
2071    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
2072        init_test(cx);
2073        let fs = FakeFs::new(cx.executor());
2074        fs.create_dir(paths::settings_file().parent().unwrap())
2075            .await
2076            .unwrap();
2077        fs.insert_file(
2078            paths::settings_file(),
2079            json!({
2080                "agent": {
2081                    "default_model": {
2082                        "provider": "foo",
2083                        "model": "bar"
2084                    }
2085                }
2086            })
2087            .to_string()
2088            .into_bytes(),
2089        )
2090        .await;
2091        let project = Project::test(fs.clone(), [], cx).await;
2092
2093        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2094
2095        // Create the agent and connection
2096        let agent = NativeAgent::new(
2097            project.clone(),
2098            thread_store,
2099            Templates::new(),
2100            None,
2101            fs.clone(),
2102            &mut cx.to_async(),
2103        )
2104        .await
2105        .unwrap();
2106        let connection = NativeAgentConnection(agent.clone());
2107
2108        // Create a thread/session
2109        let acp_thread = cx
2110            .update(|cx| {
2111                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2112            })
2113            .await
2114            .unwrap();
2115
2116        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2117
2118        // Select a model
2119        let selector = connection.model_selector(&session_id).unwrap();
2120        let model_id = acp::ModelId::new("fake/fake");
2121        cx.update(|cx| selector.select_model(model_id.clone(), cx))
2122            .await
2123            .unwrap();
2124
2125        // Verify the thread has the selected model
2126        agent.read_with(cx, |agent, _| {
2127            let session = agent.sessions.get(&session_id).unwrap();
2128            session.thread.read_with(cx, |thread, _| {
2129                assert_eq!(thread.model().unwrap().id().0, "fake");
2130            });
2131        });
2132
2133        cx.run_until_parked();
2134
2135        // Verify settings file was updated
2136        let settings_content = fs.load(paths::settings_file()).await.unwrap();
2137        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2138
2139        // Check that the agent settings contain the selected model
2140        assert_eq!(
2141            settings_json["agent"]["default_model"]["model"],
2142            json!("fake")
2143        );
2144        assert_eq!(
2145            settings_json["agent"]["default_model"]["provider"],
2146            json!("fake")
2147        );
2148
2149        // Register a thinking model and select it.
2150        cx.update(|cx| {
2151            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2152                "fake-corp",
2153                "fake-thinking",
2154                "Fake Thinking",
2155                true,
2156            ));
2157            let thinking_provider = Arc::new(
2158                FakeLanguageModelProvider::new(
2159                    LanguageModelProviderId::from("fake-corp".to_string()),
2160                    LanguageModelProviderName::from("Fake Corp".to_string()),
2161                )
2162                .with_models(vec![thinking_model]),
2163            );
2164            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2165                registry.register_provider(thinking_provider, cx);
2166            });
2167        });
2168        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2169
2170        let selector = connection.model_selector(&session_id).unwrap();
2171        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2172            .await
2173            .unwrap();
2174        cx.run_until_parked();
2175
2176        // Verify enable_thinking was written to settings as true.
2177        let settings_content = fs.load(paths::settings_file()).await.unwrap();
2178        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2179        assert_eq!(
2180            settings_json["agent"]["default_model"]["enable_thinking"],
2181            json!(true),
2182            "selecting a thinking model should persist enable_thinking: true to settings"
2183        );
2184    }
2185
2186    #[gpui::test]
2187    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2188        init_test(cx);
2189        let fs = FakeFs::new(cx.executor());
2190        fs.create_dir(paths::settings_file().parent().unwrap())
2191            .await
2192            .unwrap();
2193        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2194        let project = Project::test(fs.clone(), [], cx).await;
2195
2196        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2197        let agent = NativeAgent::new(
2198            project.clone(),
2199            thread_store,
2200            Templates::new(),
2201            None,
2202            fs.clone(),
2203            &mut cx.to_async(),
2204        )
2205        .await
2206        .unwrap();
2207        let connection = NativeAgentConnection(agent.clone());
2208
2209        let acp_thread = cx
2210            .update(|cx| {
2211                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2212            })
2213            .await
2214            .unwrap();
2215        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2216
2217        // Register a second provider with a thinking model.
2218        cx.update(|cx| {
2219            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2220                "fake-corp",
2221                "fake-thinking",
2222                "Fake Thinking",
2223                true,
2224            ));
2225            let thinking_provider = Arc::new(
2226                FakeLanguageModelProvider::new(
2227                    LanguageModelProviderId::from("fake-corp".to_string()),
2228                    LanguageModelProviderName::from("Fake Corp".to_string()),
2229                )
2230                .with_models(vec![thinking_model]),
2231            );
2232            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2233                registry.register_provider(thinking_provider, cx);
2234            });
2235        });
2236        // Refresh the agent's model list so it picks up the new provider.
2237        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2238
2239        // Thread starts with thinking_enabled = false (the default).
2240        agent.read_with(cx, |agent, _| {
2241            let session = agent.sessions.get(&session_id).unwrap();
2242            session.thread.read_with(cx, |thread, _| {
2243                assert!(!thread.thinking_enabled(), "thinking defaults to false");
2244            });
2245        });
2246
2247        // Select the thinking model via select_model.
2248        let selector = connection.model_selector(&session_id).unwrap();
2249        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2250            .await
2251            .unwrap();
2252
2253        // select_model should have enabled thinking based on the model's supports_thinking().
2254        agent.read_with(cx, |agent, _| {
2255            let session = agent.sessions.get(&session_id).unwrap();
2256            session.thread.read_with(cx, |thread, _| {
2257                assert!(
2258                    thread.thinking_enabled(),
2259                    "select_model should enable thinking when model supports it"
2260                );
2261            });
2262        });
2263
2264        // Switch back to the non-thinking model.
2265        let selector = connection.model_selector(&session_id).unwrap();
2266        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2267            .await
2268            .unwrap();
2269
2270        // select_model should have disabled thinking.
2271        agent.read_with(cx, |agent, _| {
2272            let session = agent.sessions.get(&session_id).unwrap();
2273            session.thread.read_with(cx, |thread, _| {
2274                assert!(
2275                    !thread.thinking_enabled(),
2276                    "select_model should disable thinking when model does not support it"
2277                );
2278            });
2279        });
2280    }
2281
2282    #[gpui::test]
2283    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
2284        init_test(cx);
2285        let fs = FakeFs::new(cx.executor());
2286        fs.insert_tree("/", json!({ "a": {} })).await;
2287        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2288        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2289        let agent = NativeAgent::new(
2290            project.clone(),
2291            thread_store.clone(),
2292            Templates::new(),
2293            None,
2294            fs.clone(),
2295            &mut cx.to_async(),
2296        )
2297        .await
2298        .unwrap();
2299        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2300
2301        // Register a thinking model.
2302        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2303            "fake-corp",
2304            "fake-thinking",
2305            "Fake Thinking",
2306            true,
2307        ));
2308        let thinking_provider = Arc::new(
2309            FakeLanguageModelProvider::new(
2310                LanguageModelProviderId::from("fake-corp".to_string()),
2311                LanguageModelProviderName::from("Fake Corp".to_string()),
2312            )
2313            .with_models(vec![thinking_model.clone()]),
2314        );
2315        cx.update(|cx| {
2316            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2317                registry.register_provider(thinking_provider, cx);
2318            });
2319        });
2320        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2321
2322        // Create a thread and select the thinking model.
2323        let acp_thread = cx
2324            .update(|cx| {
2325                connection
2326                    .clone()
2327                    .new_session(project.clone(), Path::new("/a"), cx)
2328            })
2329            .await
2330            .unwrap();
2331        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2332
2333        let selector = connection.model_selector(&session_id).unwrap();
2334        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2335            .await
2336            .unwrap();
2337
2338        // Verify thinking is enabled after selecting the thinking model.
2339        let thread = agent.read_with(cx, |agent, _| {
2340            agent.sessions.get(&session_id).unwrap().thread.clone()
2341        });
2342        thread.read_with(cx, |thread, _| {
2343            assert!(
2344                thread.thinking_enabled(),
2345                "thinking should be enabled after selecting thinking model"
2346            );
2347        });
2348
2349        // Send a message so the thread gets persisted.
2350        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2351        let send = cx.foreground_executor().spawn(send);
2352        cx.run_until_parked();
2353
2354        thinking_model.send_last_completion_stream_text_chunk("Response.");
2355        thinking_model.end_last_completion_stream();
2356
2357        send.await.unwrap();
2358        cx.run_until_parked();
2359
2360        // Close the session so it can be reloaded from disk.
2361        cx.update(|cx| connection.clone().close_session(&session_id, cx))
2362            .await
2363            .unwrap();
2364        drop(thread);
2365        drop(acp_thread);
2366        agent.read_with(cx, |agent, _| {
2367            assert!(agent.sessions.is_empty());
2368        });
2369
2370        // Reload the thread and verify thinking_enabled is still true.
2371        let reloaded_acp_thread = agent
2372            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2373            .await
2374            .unwrap();
2375        let reloaded_thread = agent.read_with(cx, |agent, _| {
2376            agent.sessions.get(&session_id).unwrap().thread.clone()
2377        });
2378        reloaded_thread.read_with(cx, |thread, _| {
2379            assert!(
2380                thread.thinking_enabled(),
2381                "thinking_enabled should be preserved when reloading a thread with a thinking model"
2382            );
2383        });
2384
2385        drop(reloaded_acp_thread);
2386    }
2387
    // Checks that the model selected for a thread is restored when the session is
    // closed and the thread is reopened from the thread store. The fake model is
    // registered with an id ("custom-model-id") that differs from its display
    // name ("Custom Model Display Name"), and the assertions compare `id()`, not
    // the display name.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Refresh the agent's model list so it reflects the provider registered above.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // The selector id used here has the form "<provider id>/<model id>".
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        // Sanity check: the selection took effect before anything is persisted.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future is spawned so the fake model's completion stream can be
        // driven below while the send is still pending.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's handles, then confirm the agent tracks no sessions.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        drop(reloaded_acp_thread);
    }
2498
    // Round-trips a thread through the thread store. It verifies two things:
    // - An empty thread is not saved, even after it is mutated.
    // - After the session is closed and reopened, these are all restored: the
    //   messages, the title (produced by the summary model), the draft prompt,
    //   the token usage, and the UI scroll position.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a user message that mixes plain text with a file mention.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Finish the main model's reply and report token usage; both the reply
        // text and these token counts are asserted again after the reload.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
            language_model::TokenUsage {
                input_tokens: 150,
                output_tokens: 75,
                ..Default::default()
            },
        ));
        model.end_last_completion_stream();
        cx.run_until_parked();
        // The summary model's output is what `thread_entries` later reports as
        // the saved thread's title.
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Set a draft prompt with rich content blocks before saving.
        let draft_blocks = vec![
            acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
            acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
            acp::ContentBlock::Text(acp::TextContent::new(" please")),
        ];
        acp_thread.update(cx, |thread, _cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()));
        });
        thread.update(cx, |thread, _cx| {
            thread.set_ui_scroll_position(Some(gpui::ListOffset {
                item_ix: 5,
                offset_in_item: gpui::px(12.5),
            }));
        });
        // NOTE(review): presumably this notification is what prompts a save that
        // includes the draft prompt and scroll position set above — confirm.
        thread.update(cx, |_thread, cx| cx.notify());
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // Ensure the draft prompt with rich content blocks survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
        });

        // Ensure token usage survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let usage = thread
                .token_usage()
                .expect("token usage should be restored after reload");
            assert_eq!(usage.input_tokens, 150);
            assert_eq!(usage.output_tokens, 75);
        });

        // Ensure scroll position survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let scroll = thread
                .ui_scroll_position()
                .expect("scroll position should be restored after reload");
            assert_eq!(scroll.item_ix, 5);
            assert_eq!(scroll.offset_in_item, gpui::px(12.5));
        });
    }
2685
2686    fn thread_entries(
2687        thread_store: &Entity<ThreadStore>,
2688        cx: &mut TestAppContext,
2689    ) -> Vec<(acp::SessionId, String)> {
2690        thread_store.read_with(cx, |store, _| {
2691            store
2692                .entries()
2693                .map(|entry| (entry.id.clone(), entry.title.to_string()))
2694                .collect::<Vec<_>>()
2695        })
2696    }
2697
2698    fn init_test(cx: &mut TestAppContext) {
2699        env_logger::try_init().ok();
2700        cx.update(|cx| {
2701            let settings_store = SettingsStore::test(cx);
2702            cx.set_global(settings_store);
2703
2704            LanguageModelRegistry::test(cx);
2705        });
2706    }
2707}
2708
2709fn mcp_message_content_to_acp_content_block(
2710    content: context_server::types::MessageContent,
2711) -> acp::ContentBlock {
2712    match content {
2713        context_server::types::MessageContent::Text {
2714            text,
2715            annotations: _,
2716        } => text.into(),
2717        context_server::types::MessageContent::Image {
2718            data,
2719            mime_type,
2720            annotations: _,
2721        } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2722        context_server::types::MessageContent::Audio {
2723            data,
2724            mime_type,
2725            annotations: _,
2726        } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2727        context_server::types::MessageContent::Resource {
2728            resource,
2729            annotations: _,
2730        } => {
2731            let mut link =
2732                acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2733            if let Some(mime_type) = resource.mime_type {
2734                link = link.mime_type(mime_type);
2735            }
2736            acp::ContentBlock::ResourceLink(link)
2737        }
2738    }
2739}