agent.rs

   1mod db;
   2mod edit_agent;
   3mod legacy_thread;
   4mod native_agent_server;
   5pub mod outline;
   6mod pattern_extraction;
   7mod templates;
   8#[cfg(test)]
   9mod tests;
  10mod thread;
  11mod thread_store;
  12mod tool_permissions;
  13mod tools;
  14
  15use context_server::ContextServerId;
  16pub use db::*;
  17pub use native_agent_server::NativeAgentServer;
  18pub use pattern_extraction::*;
  19pub use shell_command_parser::extract_commands;
  20pub use templates::*;
  21pub use thread::*;
  22pub use thread_store::*;
  23pub use tool_permissions::*;
  24pub use tools::*;
  25
  26use acp_thread::{
  27    AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
  28    AgentSessionListResponse, UserMessageId,
  29};
  30use agent_client_protocol as acp;
  31use anyhow::{Context as _, Result, anyhow};
  32use chrono::{DateTime, Utc};
  33use collections::{HashMap, HashSet, IndexMap};
  34use fs::Fs;
  35use futures::channel::{mpsc, oneshot};
  36use futures::future::Shared;
  37use futures::{FutureExt as _, StreamExt as _, future};
  38use gpui::{
  39    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  40};
  41use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  42use project::{Project, ProjectItem, ProjectPath, Worktree};
  43use prompt_store::{
  44    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  45    WorktreeContext,
  46};
  47use serde::{Deserialize, Serialize};
  48use settings::{LanguageModelSelection, update_settings_file};
  49use std::any::Any;
  50use std::path::{Path, PathBuf};
  51use std::rc::Rc;
  52use std::sync::Arc;
  53use std::time::Duration;
  54use util::ResultExt;
  55use util::rel_path::RelPath;
  56
  57#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
  58pub struct ProjectSnapshot {
  59    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
  60    pub timestamp: DateTime<Utc>,
  61}
  62
  63pub struct RulesLoadingError {
  64    pub message: SharedString,
  65}
  66
  67/// Holds both the internal Thread and the AcpThread for a session
  68struct Session {
  69    /// The internal thread that processes messages
  70    thread: Entity<Thread>,
  71    /// The ACP thread that handles protocol communication
  72    acp_thread: Entity<acp_thread::AcpThread>,
  73    pending_save: Task<()>,
  74    _subscriptions: Vec<Subscription>,
  75}
  76
  77pub struct LanguageModels {
  78    /// Access language model by ID
  79    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
  80    /// Cached list for returning language model information
  81    model_list: acp_thread::AgentModelList,
  82    refresh_models_rx: watch::Receiver<()>,
  83    refresh_models_tx: watch::Sender<()>,
  84    _authenticate_all_providers_task: Task<()>,
  85}
  86
  87impl LanguageModels {
  88    fn new(cx: &mut App) -> Self {
  89        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  90
  91        let mut this = Self {
  92            models: HashMap::default(),
  93            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  94            refresh_models_rx,
  95            refresh_models_tx,
  96            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
  97        };
  98        this.refresh_list(cx);
  99        this
 100    }
 101
 102    fn refresh_list(&mut self, cx: &App) {
 103        let providers = LanguageModelRegistry::global(cx)
 104            .read(cx)
 105            .visible_providers()
 106            .into_iter()
 107            .filter(|provider| provider.is_authenticated(cx))
 108            .collect::<Vec<_>>();
 109
 110        let mut language_model_list = IndexMap::default();
 111        let mut recommended_models = HashSet::default();
 112
 113        let mut recommended = Vec::new();
 114        for provider in &providers {
 115            for model in provider.recommended_models(cx) {
 116                recommended_models.insert((model.provider_id(), model.id()));
 117                recommended.push(Self::map_language_model_to_info(&model, provider));
 118            }
 119        }
 120        if !recommended.is_empty() {
 121            language_model_list.insert(
 122                acp_thread::AgentModelGroupName("Recommended".into()),
 123                recommended,
 124            );
 125        }
 126
 127        let mut models = HashMap::default();
 128        for provider in providers {
 129            let mut provider_models = Vec::new();
 130            for model in provider.provided_models(cx) {
 131                let model_info = Self::map_language_model_to_info(&model, &provider);
 132                let model_id = model_info.id.clone();
 133                provider_models.push(model_info);
 134                models.insert(model_id, model);
 135            }
 136            if !provider_models.is_empty() {
 137                language_model_list.insert(
 138                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 139                    provider_models,
 140                );
 141            }
 142        }
 143
 144        self.models = models;
 145        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 146        self.refresh_models_tx.send(()).ok();
 147    }
 148
 149    fn watch(&self) -> watch::Receiver<()> {
 150        self.refresh_models_rx.clone()
 151    }
 152
 153    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 154        self.models.get(model_id).cloned()
 155    }
 156
 157    fn map_language_model_to_info(
 158        model: &Arc<dyn LanguageModel>,
 159        provider: &Arc<dyn LanguageModelProvider>,
 160    ) -> acp_thread::AgentModelInfo {
 161        acp_thread::AgentModelInfo {
 162            id: Self::model_id(model),
 163            name: model.name().0,
 164            description: None,
 165            icon: Some(match provider.icon() {
 166                IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
 167                IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
 168            }),
 169            is_latest: model.is_latest(),
 170            cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
 171        }
 172    }
 173
 174    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 175        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 176    }
 177
 178    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 179        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 180            .read(cx)
 181            .visible_providers()
 182            .iter()
 183            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 184            .collect::<Vec<_>>();
 185
 186        cx.background_spawn(async move {
 187            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 188                if let Err(err) = authenticate_task.await {
 189                    match err {
 190                        language_model::AuthenticateError::CredentialsNotFound => {
 191                            // Since we're authenticating these providers in the
 192                            // background for the purposes of populating the
 193                            // language selector, we don't care about providers
 194                            // where the credentials are not found.
 195                        }
 196                        language_model::AuthenticateError::ConnectionRefused => {
 197                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 198                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 199                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 200                        }
 201                        _ => {
 202                            // Some providers have noisy failure states that we
 203                            // don't want to spam the logs with every time the
 204                            // language model selector is initialized.
 205                            //
 206                            // Ideally these should have more clear failure modes
 207                            // that we know are safe to ignore here, like what we do
 208                            // with `CredentialsNotFound` above.
 209                            match provider_id.0.as_ref() {
 210                                "lmstudio" | "ollama" => {
 211                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 212                                    //
 213                                    // These fail noisily, so we don't log them.
 214                                }
 215                                "copilot_chat" => {
 216                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 217                                }
 218                                _ => {
 219                                    log::error!(
 220                                        "Failed to authenticate provider: {}: {err:#}",
 221                                        provider_name.0
 222                                    );
 223                                }
 224                            }
 225                        }
 226                    }
 227                }
 228            }
 229        })
 230    }
 231}
 232
 233pub struct NativeAgent {
 234    /// Session ID -> Session mapping
 235    sessions: HashMap<acp::SessionId, Session>,
 236    thread_store: Entity<ThreadStore>,
 237    /// Shared project context for all threads
 238    project_context: Entity<ProjectContext>,
 239    project_context_needs_refresh: watch::Sender<()>,
 240    _maintain_project_context: Task<Result<()>>,
 241    context_server_registry: Entity<ContextServerRegistry>,
 242    /// Shared templates for all threads
 243    templates: Arc<Templates>,
 244    /// Cached model information
 245    models: LanguageModels,
 246    project: Entity<Project>,
 247    prompt_store: Option<Entity<PromptStore>>,
 248    fs: Arc<dyn Fs>,
 249    _subscriptions: Vec<Subscription>,
 250}
 251
 252impl NativeAgent {
 253    pub async fn new(
 254        project: Entity<Project>,
 255        thread_store: Entity<ThreadStore>,
 256        templates: Arc<Templates>,
 257        prompt_store: Option<Entity<PromptStore>>,
 258        fs: Arc<dyn Fs>,
 259        cx: &mut AsyncApp,
 260    ) -> Result<Entity<NativeAgent>> {
 261        log::debug!("Creating new NativeAgent");
 262
 263        let project_context = cx
 264            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
 265            .await;
 266
 267        Ok(cx.new(|cx| {
 268            let context_server_store = project.read(cx).context_server_store();
 269            let context_server_registry =
 270                cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
 271
 272            let mut subscriptions = vec![
 273                cx.subscribe(&project, Self::handle_project_event),
 274                cx.subscribe(
 275                    &LanguageModelRegistry::global(cx),
 276                    Self::handle_models_updated_event,
 277                ),
 278                cx.subscribe(
 279                    &context_server_store,
 280                    Self::handle_context_server_store_updated,
 281                ),
 282                cx.subscribe(
 283                    &context_server_registry,
 284                    Self::handle_context_server_registry_event,
 285                ),
 286            ];
 287            if let Some(prompt_store) = prompt_store.as_ref() {
 288                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 289            }
 290
 291            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 292                watch::channel(());
 293            Self {
 294                sessions: HashMap::default(),
 295                thread_store,
 296                project_context: cx.new(|_| project_context),
 297                project_context_needs_refresh: project_context_needs_refresh_tx,
 298                _maintain_project_context: cx.spawn(async move |this, cx| {
 299                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 300                }),
 301                context_server_registry,
 302                templates,
 303                models: LanguageModels::new(cx),
 304                project,
 305                prompt_store,
 306                fs,
 307                _subscriptions: subscriptions,
 308            }
 309        }))
 310    }
 311
 312    fn new_session(
 313        &mut self,
 314        project: Entity<Project>,
 315        cx: &mut Context<Self>,
 316    ) -> Entity<AcpThread> {
 317        // Create Thread
 318        // Fetch default model from registry settings
 319        let registry = LanguageModelRegistry::read_global(cx);
 320        // Log available models for debugging
 321        let available_count = registry.available_models(cx).count();
 322        log::debug!("Total available models: {}", available_count);
 323
 324        let default_model = registry.default_model().and_then(|default_model| {
 325            self.models
 326                .model_from_id(&LanguageModels::model_id(&default_model.model))
 327        });
 328        let thread = cx.new(|cx| {
 329            Thread::new(
 330                project.clone(),
 331                self.project_context.clone(),
 332                self.context_server_registry.clone(),
 333                self.templates.clone(),
 334                default_model,
 335                cx,
 336            )
 337        });
 338
 339        self.register_session(thread, None, cx)
 340    }
 341
 342    fn register_session(
 343        &mut self,
 344        thread_handle: Entity<Thread>,
 345        allowed_tool_names: Option<Vec<&str>>,
 346        cx: &mut Context<Self>,
 347    ) -> Entity<AcpThread> {
 348        let connection = Rc::new(NativeAgentConnection(cx.entity()));
 349
 350        let thread = thread_handle.read(cx);
 351        let session_id = thread.id().clone();
 352        let parent_session_id = thread.parent_thread_id();
 353        let title = thread.title();
 354        let project = thread.project.clone();
 355        let action_log = thread.action_log.clone();
 356        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
 357        let acp_thread = cx.new(|cx| {
 358            acp_thread::AcpThread::new(
 359                parent_session_id,
 360                title,
 361                connection,
 362                project.clone(),
 363                action_log.clone(),
 364                session_id.clone(),
 365                prompt_capabilities_rx,
 366                cx,
 367            )
 368        });
 369
 370        let registry = LanguageModelRegistry::read_global(cx);
 371        let summarization_model = registry.thread_summary_model().map(|c| c.model);
 372
 373        let weak = cx.weak_entity();
 374        thread_handle.update(cx, |thread, cx| {
 375            thread.set_summarization_model(summarization_model, cx);
 376            thread.add_default_tools(
 377                allowed_tool_names,
 378                Rc::new(NativeThreadEnvironment {
 379                    acp_thread: acp_thread.downgrade(),
 380                    agent: weak,
 381                }) as _,
 382                cx,
 383            )
 384        });
 385
 386        let subscriptions = vec![
 387            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
 388            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
 389            cx.observe(&thread_handle, move |this, thread, cx| {
 390                this.save_thread(thread, cx)
 391            }),
 392        ];
 393
 394        self.sessions.insert(
 395            session_id,
 396            Session {
 397                thread: thread_handle,
 398                acp_thread: acp_thread.clone(),
 399                _subscriptions: subscriptions,
 400                pending_save: Task::ready(()),
 401            },
 402        );
 403
 404        self.update_available_commands(cx);
 405
 406        acp_thread
 407    }
 408
 409    pub fn models(&self) -> &LanguageModels {
 410        &self.models
 411    }
 412
 413    async fn maintain_project_context(
 414        this: WeakEntity<Self>,
 415        mut needs_refresh: watch::Receiver<()>,
 416        cx: &mut AsyncApp,
 417    ) -> Result<()> {
 418        while needs_refresh.changed().await.is_ok() {
 419            let project_context = this
 420                .update(cx, |this, cx| {
 421                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 422                })?
 423                .await;
 424            this.update(cx, |this, cx| {
 425                this.project_context = cx.new(|_| project_context);
 426            })?;
 427        }
 428
 429        Ok(())
 430    }
 431
 432    fn build_project_context(
 433        project: &Entity<Project>,
 434        prompt_store: Option<&Entity<PromptStore>>,
 435        cx: &mut App,
 436    ) -> Task<ProjectContext> {
 437        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 438        let worktree_tasks = worktrees
 439            .into_iter()
 440            .map(|worktree| {
 441                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 442            })
 443            .collect::<Vec<_>>();
 444        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 445            prompt_store.read_with(cx, |prompt_store, cx| {
 446                let prompts = prompt_store.default_prompt_metadata();
 447                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 448                    let contents = prompt_store.load(prompt_metadata.id, cx);
 449                    async move { (contents.await, prompt_metadata) }
 450                });
 451                cx.background_spawn(future::join_all(load_tasks))
 452            })
 453        } else {
 454            Task::ready(vec![])
 455        };
 456
 457        cx.spawn(async move |_cx| {
 458            let (worktrees, default_user_rules) =
 459                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 460
 461            let worktrees = worktrees
 462                .into_iter()
 463                .map(|(worktree, _rules_error)| {
 464                    // TODO: show error message
 465                    // if let Some(rules_error) = rules_error {
 466                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 467                    // }
 468                    worktree
 469                })
 470                .collect::<Vec<_>>();
 471
 472            let default_user_rules = default_user_rules
 473                .into_iter()
 474                .flat_map(|(contents, prompt_metadata)| match contents {
 475                    Ok(contents) => Some(UserRulesContext {
 476                        uuid: prompt_metadata.id.as_user()?,
 477                        title: prompt_metadata.title.map(|title| title.to_string()),
 478                        contents,
 479                    }),
 480                    Err(_err) => {
 481                        // TODO: show error message
 482                        // this.update(cx, |_, cx| {
 483                        //     cx.emit(RulesLoadingError {
 484                        //         message: format!("{err:?}").into(),
 485                        //     });
 486                        // })
 487                        // .ok();
 488                        None
 489                    }
 490                })
 491                .collect::<Vec<_>>();
 492
 493            ProjectContext::new(worktrees, default_user_rules)
 494        })
 495    }
 496
 497    fn load_worktree_info_for_system_prompt(
 498        worktree: Entity<Worktree>,
 499        project: Entity<Project>,
 500        cx: &mut App,
 501    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 502        let tree = worktree.read(cx);
 503        let root_name = tree.root_name_str().into();
 504        let abs_path = tree.abs_path();
 505
 506        let mut context = WorktreeContext {
 507            root_name,
 508            abs_path,
 509            rules_file: None,
 510        };
 511
 512        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 513        let Some(rules_task) = rules_task else {
 514            return Task::ready((context, None));
 515        };
 516
 517        cx.spawn(async move |_| {
 518            let (rules_file, rules_file_error) = match rules_task.await {
 519                Ok(rules_file) => (Some(rules_file), None),
 520                Err(err) => (
 521                    None,
 522                    Some(RulesLoadingError {
 523                        message: format!("{err}").into(),
 524                    }),
 525                ),
 526            };
 527            context.rules_file = rules_file;
 528            (context, rules_file_error)
 529        })
 530    }
 531
 532    fn load_worktree_rules_file(
 533        worktree: Entity<Worktree>,
 534        project: Entity<Project>,
 535        cx: &mut App,
 536    ) -> Option<Task<Result<RulesFileContext>>> {
 537        let worktree = worktree.read(cx);
 538        let worktree_id = worktree.id();
 539        let selected_rules_file = RULES_FILE_NAMES
 540            .into_iter()
 541            .filter_map(|name| {
 542                worktree
 543                    .entry_for_path(RelPath::unix(name).unwrap())
 544                    .filter(|entry| entry.is_file())
 545                    .map(|entry| entry.path.clone())
 546            })
 547            .next();
 548
 549        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 550        // supported. This doesn't seem to occur often in GitHub repositories.
 551        selected_rules_file.map(|path_in_worktree| {
 552            let project_path = ProjectPath {
 553                worktree_id,
 554                path: path_in_worktree.clone(),
 555            };
 556            let buffer_task =
 557                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 558            let rope_task = cx.spawn(async move |cx| {
 559                let buffer = buffer_task.await?;
 560                let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
 561                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 562                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 563                })?;
 564                anyhow::Ok((project_entry_id, rope))
 565            });
 566            // Build a string from the rope on a background thread.
 567            cx.background_spawn(async move {
 568                let (project_entry_id, rope) = rope_task.await?;
 569                anyhow::Ok(RulesFileContext {
 570                    path_in_worktree,
 571                    text: rope.to_string().trim().to_string(),
 572                    project_entry_id: project_entry_id.to_usize(),
 573                })
 574            })
 575        })
 576    }
 577
 578    fn handle_thread_title_updated(
 579        &mut self,
 580        thread: Entity<Thread>,
 581        _: &TitleUpdated,
 582        cx: &mut Context<Self>,
 583    ) {
 584        let session_id = thread.read(cx).id();
 585        let Some(session) = self.sessions.get(session_id) else {
 586            return;
 587        };
 588        let thread = thread.downgrade();
 589        let acp_thread = session.acp_thread.downgrade();
 590        cx.spawn(async move |_, cx| {
 591            let title = thread.read_with(cx, |thread, _| thread.title())?;
 592            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 593            task.await
 594        })
 595        .detach_and_log_err(cx);
 596    }
 597
 598    fn handle_thread_token_usage_updated(
 599        &mut self,
 600        thread: Entity<Thread>,
 601        usage: &TokenUsageUpdated,
 602        cx: &mut Context<Self>,
 603    ) {
 604        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 605            return;
 606        };
 607        session.acp_thread.update(cx, |acp_thread, cx| {
 608            acp_thread.update_token_usage(usage.0.clone(), cx);
 609        });
 610    }
 611
 612    fn handle_project_event(
 613        &mut self,
 614        _project: Entity<Project>,
 615        event: &project::Event,
 616        _cx: &mut Context<Self>,
 617    ) {
 618        match event {
 619            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 620                self.project_context_needs_refresh.send(()).ok();
 621            }
 622            project::Event::WorktreeUpdatedEntries(_, items) => {
 623                if items.iter().any(|(path, _, _)| {
 624                    RULES_FILE_NAMES
 625                        .iter()
 626                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 627                }) {
 628                    self.project_context_needs_refresh.send(()).ok();
 629                }
 630            }
 631            _ => {}
 632        }
 633    }
 634
 635    fn handle_prompts_updated_event(
 636        &mut self,
 637        _prompt_store: Entity<PromptStore>,
 638        _event: &prompt_store::PromptsUpdatedEvent,
 639        _cx: &mut Context<Self>,
 640    ) {
 641        self.project_context_needs_refresh.send(()).ok();
 642    }
 643
 644    fn handle_models_updated_event(
 645        &mut self,
 646        _registry: Entity<LanguageModelRegistry>,
 647        _event: &language_model::Event,
 648        cx: &mut Context<Self>,
 649    ) {
 650        self.models.refresh_list(cx);
 651
 652        let registry = LanguageModelRegistry::read_global(cx);
 653        let default_model = registry.default_model().map(|m| m.model);
 654        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 655
 656        for session in self.sessions.values_mut() {
 657            session.thread.update(cx, |thread, cx| {
 658                if thread.model().is_none()
 659                    && let Some(model) = default_model.clone()
 660                {
 661                    thread.set_model(model, cx);
 662                    cx.notify();
 663                }
 664                thread.set_summarization_model(summarization_model.clone(), cx);
 665            });
 666        }
 667    }
 668
 669    fn handle_context_server_store_updated(
 670        &mut self,
 671        _store: Entity<project::context_server_store::ContextServerStore>,
 672        _event: &project::context_server_store::ServerStatusChangedEvent,
 673        cx: &mut Context<Self>,
 674    ) {
 675        self.update_available_commands(cx);
 676    }
 677
 678    fn handle_context_server_registry_event(
 679        &mut self,
 680        _registry: Entity<ContextServerRegistry>,
 681        event: &ContextServerRegistryEvent,
 682        cx: &mut Context<Self>,
 683    ) {
 684        match event {
 685            ContextServerRegistryEvent::ToolsChanged => {}
 686            ContextServerRegistryEvent::PromptsChanged => {
 687                self.update_available_commands(cx);
 688            }
 689        }
 690    }
 691
 692    fn update_available_commands(&self, cx: &mut Context<Self>) {
 693        let available_commands = self.build_available_commands(cx);
 694        for session in self.sessions.values() {
 695            session.acp_thread.update(cx, |thread, cx| {
 696                thread
 697                    .handle_session_update(
 698                        acp::SessionUpdate::AvailableCommandsUpdate(
 699                            acp::AvailableCommandsUpdate::new(available_commands.clone()),
 700                        ),
 701                        cx,
 702                    )
 703                    .log_err();
 704            });
 705        }
 706    }
 707
 708    fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
 709        let registry = self.context_server_registry.read(cx);
 710
 711        let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
 712        for context_server_prompt in registry.prompts() {
 713            *prompt_name_counts
 714                .entry(context_server_prompt.prompt.name.as_str())
 715                .or_insert(0) += 1;
 716        }
 717
 718        registry
 719            .prompts()
 720            .flat_map(|context_server_prompt| {
 721                let prompt = &context_server_prompt.prompt;
 722
 723                let should_prefix = prompt_name_counts
 724                    .get(prompt.name.as_str())
 725                    .copied()
 726                    .unwrap_or(0)
 727                    > 1;
 728
 729                let name = if should_prefix {
 730                    format!("{}.{}", context_server_prompt.server_id, prompt.name)
 731                } else {
 732                    prompt.name.clone()
 733                };
 734
 735                let mut command = acp::AvailableCommand::new(
 736                    name,
 737                    prompt.description.clone().unwrap_or_default(),
 738                );
 739
 740                match prompt.arguments.as_deref() {
 741                    Some([arg]) => {
 742                        let hint = format!("<{}>", arg.name);
 743
 744                        command = command.input(acp::AvailableCommandInput::Unstructured(
 745                            acp::UnstructuredCommandInput::new(hint),
 746                        ));
 747                    }
 748                    Some([]) | None => {}
 749                    Some(_) => {
 750                        // skip >1 argument commands since we don't support them yet
 751                        return None;
 752                    }
 753                }
 754
 755                Some(command)
 756            })
 757            .collect()
 758    }
 759
 760    pub fn load_thread(
 761        &mut self,
 762        id: acp::SessionId,
 763        cx: &mut Context<Self>,
 764    ) -> Task<Result<Entity<Thread>>> {
 765        let database_future = ThreadsDatabase::connect(cx);
 766        cx.spawn(async move |this, cx| {
 767            let database = database_future.await.map_err(|err| anyhow!(err))?;
 768            let db_thread = database
 769                .load_thread(id.clone())
 770                .await?
 771                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 772
 773            this.update(cx, |this, cx| {
 774                let summarization_model = LanguageModelRegistry::read_global(cx)
 775                    .thread_summary_model()
 776                    .map(|c| c.model);
 777
 778                cx.new(|cx| {
 779                    let mut thread = Thread::from_db(
 780                        id.clone(),
 781                        db_thread,
 782                        this.project.clone(),
 783                        this.project_context.clone(),
 784                        this.context_server_registry.clone(),
 785                        this.templates.clone(),
 786                        cx,
 787                    );
 788                    thread.set_summarization_model(summarization_model, cx);
 789                    thread
 790                })
 791            })
 792        })
 793    }
 794
 795    pub fn open_thread(
 796        &mut self,
 797        id: acp::SessionId,
 798        cx: &mut Context<Self>,
 799    ) -> Task<Result<Entity<AcpThread>>> {
 800        if let Some(session) = self.sessions.get(&id) {
 801            return Task::ready(Ok(session.acp_thread.clone()));
 802        }
 803
 804        let task = self.load_thread(id, cx);
 805        cx.spawn(async move |this, cx| {
 806            let thread = task.await?;
 807            let acp_thread = this.update(cx, |this, cx| {
 808                this.register_session(thread.clone(), None, cx)
 809            })?;
 810            let events = thread.update(cx, |thread, cx| thread.replay(cx));
 811            cx.update(|cx| {
 812                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 813            })
 814            .await?;
 815            Ok(acp_thread)
 816        })
 817    }
 818
 819    pub fn thread_summary(
 820        &mut self,
 821        id: acp::SessionId,
 822        cx: &mut Context<Self>,
 823    ) -> Task<Result<SharedString>> {
 824        let thread = self.open_thread(id.clone(), cx);
 825        cx.spawn(async move |this, cx| {
 826            let acp_thread = thread.await?;
 827            let result = this
 828                .update(cx, |this, cx| {
 829                    this.sessions
 830                        .get(&id)
 831                        .unwrap()
 832                        .thread
 833                        .update(cx, |thread, cx| thread.summary(cx))
 834                })?
 835                .await
 836                .context("Failed to generate summary")?;
 837            drop(acp_thread);
 838            Ok(result)
 839        })
 840    }
 841
 842    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 843        if thread.read(cx).is_empty() {
 844            return;
 845        }
 846
 847        let database_future = ThreadsDatabase::connect(cx);
 848        let (id, db_thread) =
 849            thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
 850        let Some(session) = self.sessions.get_mut(&id) else {
 851            return;
 852        };
 853        let thread_store = self.thread_store.clone();
 854        session.pending_save = cx.spawn(async move |_, cx| {
 855            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 856                return;
 857            };
 858            let db_thread = db_thread.await;
 859            database.save_thread(id, db_thread).await.log_err();
 860            thread_store.update(cx, |store, cx| store.reload(cx));
 861        });
 862    }
 863
    /// Runs the MCP prompt `prompt_name` from context server `server_id` on session
    /// `session_id`, then starts a turn.
    ///
    /// Steps:
    /// 1. Fetch the prompt from the server with `arguments`.
    /// 2. Append the user's remaining content (every block of `original_content`
    ///    except the first) to the session's [`Thread`] under `message_id`.
    /// 3. Append each prompt message to both the `AcpThread` and the `Thread`.
    /// 4. If the last message came from the user (or there were no messages),
    ///    call `send_existing`; otherwise call `resume`.
    ///
    /// The resulting event stream is forwarded to the `AcpThread` until the turn stops.
    ///
    /// # Errors
    ///
    /// Fails if fetching the prompt fails, the agent is released, the session is not
    /// registered ("Failed to get session"), or starting the turn fails.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: agent_client_protocol::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let server_store = self.context_server_registry.read(cx).server_store().clone();
        let path_style = self.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Starts `true` so that a prompt with no messages still results in
            // `send_existing` below.
            let mut last_is_user = true;

            // The first block is skipped — presumably it is the `/command` text itself
            // (what `Command::parse` inspects); confirm with the caller. Only `thread`
            // receives this content here; NOTE(review): the AcpThread is assumed to
            // already display the user's original message.
            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                match role {
                    context_server::types::Role::User => {
                        // Each user message from the prompt gets its own fresh id, shared
                        // by the AcpThread entry and the Thread entry.
                        let id = acp_thread::UserMessageId::new();

                        // NOTE(review): the `true` argument presumably requests indented
                        // rendering for prompt-injected content — confirm.
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
 959}
 960
 961/// Wrapper struct that implements the AgentConnection trait
 962#[derive(Clone)]
 963pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 964
 965impl NativeAgentConnection {
 966    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 967        self.0
 968            .read(cx)
 969            .sessions
 970            .get(session_id)
 971            .map(|session| session.thread.clone())
 972    }
 973
 974    pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
 975        self.0.update(cx, |this, cx| this.load_thread(id, cx))
 976    }
 977
 978    fn run_turn(
 979        &self,
 980        session_id: acp::SessionId,
 981        cx: &mut App,
 982        f: impl 'static
 983        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
 984    ) -> Task<Result<acp::PromptResponse>> {
 985        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 986            agent
 987                .sessions
 988                .get_mut(&session_id)
 989                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
 990        }) else {
 991            return Task::ready(Err(anyhow!("Session not found")));
 992        };
 993        log::debug!("Found session for: {}", session_id);
 994
 995        let response_stream = match f(thread, cx) {
 996            Ok(stream) => stream,
 997            Err(err) => return Task::ready(Err(err)),
 998        };
 999        Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1000    }
1001
1002    fn handle_thread_events(
1003        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1004        acp_thread: WeakEntity<AcpThread>,
1005        cx: &App,
1006    ) -> Task<Result<acp::PromptResponse>> {
1007        cx.spawn(async move |cx| {
1008            // Handle response stream and forward to session.acp_thread
1009            while let Some(result) = events.next().await {
1010                match result {
1011                    Ok(event) => {
1012                        log::trace!("Received completion event: {:?}", event);
1013
1014                        match event {
1015                            ThreadEvent::UserMessage(message) => {
1016                                acp_thread.update(cx, |thread, cx| {
1017                                    for content in message.content {
1018                                        thread.push_user_content_block(
1019                                            Some(message.id.clone()),
1020                                            content.into(),
1021                                            cx,
1022                                        );
1023                                    }
1024                                })?;
1025                            }
1026                            ThreadEvent::AgentText(text) => {
1027                                acp_thread.update(cx, |thread, cx| {
1028                                    thread.push_assistant_content_block(text.into(), false, cx)
1029                                })?;
1030                            }
1031                            ThreadEvent::AgentThinking(text) => {
1032                                acp_thread.update(cx, |thread, cx| {
1033                                    thread.push_assistant_content_block(text.into(), true, cx)
1034                                })?;
1035                            }
1036                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1037                                tool_call,
1038                                options,
1039                                response,
1040                                context: _,
1041                            }) => {
1042                                let outcome_task = acp_thread.update(cx, |thread, cx| {
1043                                    thread.request_tool_call_authorization(tool_call, options, cx)
1044                                })??;
1045                                cx.background_spawn(async move {
1046                                    if let acp::RequestPermissionOutcome::Selected(
1047                                        acp::SelectedPermissionOutcome { option_id, .. },
1048                                    ) = outcome_task.await
1049                                    {
1050                                        response
1051                                            .send(option_id)
1052                                            .map(|_| anyhow!("authorization receiver was dropped"))
1053                                            .log_err();
1054                                    }
1055                                })
1056                                .detach();
1057                            }
1058                            ThreadEvent::ToolCall(tool_call) => {
1059                                acp_thread.update(cx, |thread, cx| {
1060                                    thread.upsert_tool_call(tool_call, cx)
1061                                })??;
1062                            }
1063                            ThreadEvent::ToolCallUpdate(update) => {
1064                                acp_thread.update(cx, |thread, cx| {
1065                                    thread.update_tool_call(update, cx)
1066                                })??;
1067                            }
1068                            ThreadEvent::SubagentSpawned(session_id) => {
1069                                acp_thread.update(cx, |thread, cx| {
1070                                    thread.subagent_spawned(session_id, cx);
1071                                })?;
1072                            }
1073                            ThreadEvent::Retry(status) => {
1074                                acp_thread.update(cx, |thread, cx| {
1075                                    thread.update_retry_status(status, cx)
1076                                })?;
1077                            }
1078                            ThreadEvent::Stop(stop_reason) => {
1079                                log::debug!("Assistant message complete: {:?}", stop_reason);
1080                                return Ok(acp::PromptResponse::new(stop_reason));
1081                            }
1082                        }
1083                    }
1084                    Err(e) => {
1085                        log::error!("Error in model response stream: {:?}", e);
1086                        return Err(e);
1087                    }
1088                }
1089            }
1090
1091            log::debug!("Response stream completed");
1092            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1093        })
1094    }
1095}
1096
1097struct Command<'a> {
1098    prompt_name: &'a str,
1099    arg_value: &'a str,
1100    explicit_server_id: Option<&'a str>,
1101}
1102
1103impl<'a> Command<'a> {
1104    fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1105        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1106            return None;
1107        };
1108        let text = text_content.text.trim();
1109        let command = text.strip_prefix('/')?;
1110        let (command, arg_value) = command
1111            .split_once(char::is_whitespace)
1112            .unwrap_or((command, ""));
1113
1114        if let Some((server_id, prompt_name)) = command.split_once('.') {
1115            Some(Self {
1116                prompt_name,
1117                arg_value,
1118                explicit_server_id: Some(server_id),
1119            })
1120        } else {
1121            Some(Self {
1122                prompt_name: command,
1123                arg_value,
1124                explicit_server_id: None,
1125            })
1126        }
1127    }
1128}
1129
1130struct NativeAgentModelSelector {
1131    session_id: acp::SessionId,
1132    connection: NativeAgentConnection,
1133}
1134
1135impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1136    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1137        log::debug!("NativeAgentConnection::list_models called");
1138        let list = self.connection.0.read(cx).models.model_list.clone();
1139        Task::ready(if list.is_empty() {
1140            Err(anyhow::anyhow!("No models available"))
1141        } else {
1142            Ok(list)
1143        })
1144    }
1145
1146    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1147        log::debug!(
1148            "Setting model for session {}: {}",
1149            self.session_id,
1150            model_id
1151        );
1152        let Some(thread) = self
1153            .connection
1154            .0
1155            .read(cx)
1156            .sessions
1157            .get(&self.session_id)
1158            .map(|session| session.thread.clone())
1159        else {
1160            return Task::ready(Err(anyhow!("Session not found")));
1161        };
1162
1163        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1164            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1165        };
1166
1167        // We want to reset the effort level when switching models, as the currently-selected effort level may
1168        // not be compatible.
1169        let effort = model
1170            .default_effort_level()
1171            .map(|effort_level| effort_level.value.to_string());
1172
1173        thread.update(cx, |thread, cx| {
1174            thread.set_model(model.clone(), cx);
1175            thread.set_thinking_effort(effort.clone(), cx);
1176            thread.set_thinking_enabled(model.supports_thinking(), cx);
1177        });
1178
1179        update_settings_file(
1180            self.connection.0.read(cx).fs.clone(),
1181            cx,
1182            move |settings, cx| {
1183                let provider = model.provider_id().0.to_string();
1184                let model = model.id().0.to_string();
1185                let enable_thinking = thread.read(cx).thinking_enabled();
1186                settings
1187                    .agent
1188                    .get_or_insert_default()
1189                    .set_model(LanguageModelSelection {
1190                        provider: provider.into(),
1191                        model,
1192                        enable_thinking,
1193                        effort,
1194                    });
1195            },
1196        );
1197
1198        Task::ready(Ok(()))
1199    }
1200
1201    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1202        let Some(thread) = self
1203            .connection
1204            .0
1205            .read(cx)
1206            .sessions
1207            .get(&self.session_id)
1208            .map(|session| session.thread.clone())
1209        else {
1210            return Task::ready(Err(anyhow!("Session not found")));
1211        };
1212        let Some(model) = thread.read(cx).model() else {
1213            return Task::ready(Err(anyhow!("Model not found")));
1214        };
1215        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1216        else {
1217            return Task::ready(Err(anyhow!("Provider not found")));
1218        };
1219        Task::ready(Ok(LanguageModels::map_language_model_to_info(
1220            model, &provider,
1221        )))
1222    }
1223
1224    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1225        Some(self.connection.0.read(cx).models.watch())
1226    }
1227
1228    fn should_render_footer(&self) -> bool {
1229        true
1230    }
1231}
1232
1233impl acp_thread::AgentConnection for NativeAgentConnection {
1234    fn telemetry_id(&self) -> SharedString {
1235        "zed".into()
1236    }
1237
1238    fn new_session(
1239        self: Rc<Self>,
1240        project: Entity<Project>,
1241        cwd: &Path,
1242        cx: &mut App,
1243    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1244        log::debug!("Creating new thread for project at: {cwd:?}");
1245        Task::ready(Ok(self
1246            .0
1247            .update(cx, |agent, cx| agent.new_session(project, cx))))
1248    }
1249
1250    fn supports_load_session(&self) -> bool {
1251        true
1252    }
1253
1254    fn load_session(
1255        self: Rc<Self>,
1256        session: AgentSessionInfo,
1257        _project: Entity<Project>,
1258        _cwd: &Path,
1259        cx: &mut App,
1260    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1261        self.0
1262            .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1263    }
1264
1265    fn supports_close_session(&self) -> bool {
1266        true
1267    }
1268
1269    fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1270        self.0.update(cx, |agent, _cx| {
1271            agent.sessions.remove(session_id);
1272        });
1273        Task::ready(Ok(()))
1274    }
1275
1276    fn auth_methods(&self) -> &[acp::AuthMethod] {
1277        &[] // No auth for in-process
1278    }
1279
1280    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1281        Task::ready(Ok(()))
1282    }
1283
1284    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1285        Some(Rc::new(NativeAgentModelSelector {
1286            session_id: session_id.clone(),
1287            connection: self.clone(),
1288        }) as Rc<dyn AgentModelSelector>)
1289    }
1290
1291    fn prompt(
1292        &self,
1293        id: Option<acp_thread::UserMessageId>,
1294        params: acp::PromptRequest,
1295        cx: &mut App,
1296    ) -> Task<Result<acp::PromptResponse>> {
1297        let id = id.expect("UserMessageId is required");
1298        let session_id = params.session_id.clone();
1299        log::info!("Received prompt request for session: {}", session_id);
1300        log::debug!("Prompt blocks count: {}", params.prompt.len());
1301
1302        if let Some(parsed_command) = Command::parse(&params.prompt) {
1303            let registry = self.0.read(cx).context_server_registry.read(cx);
1304
1305            let explicit_server_id = parsed_command
1306                .explicit_server_id
1307                .map(|server_id| ContextServerId(server_id.into()));
1308
1309            if let Some(prompt) =
1310                registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1311            {
1312                let arguments = if !parsed_command.arg_value.is_empty()
1313                    && let Some(arg_name) = prompt
1314                        .prompt
1315                        .arguments
1316                        .as_ref()
1317                        .and_then(|args| args.first())
1318                        .map(|arg| arg.name.clone())
1319                {
1320                    HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1321                } else {
1322                    Default::default()
1323                };
1324
1325                let prompt_name = prompt.prompt.name.clone();
1326                let server_id = prompt.server_id.clone();
1327
1328                return self.0.update(cx, |agent, cx| {
1329                    agent.send_mcp_prompt(
1330                        id,
1331                        session_id.clone(),
1332                        prompt_name,
1333                        server_id,
1334                        arguments,
1335                        params.prompt,
1336                        cx,
1337                    )
1338                });
1339            };
1340        };
1341
1342        let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1343
1344        self.run_turn(session_id, cx, move |thread, cx| {
1345            let content: Vec<UserMessageContent> = params
1346                .prompt
1347                .into_iter()
1348                .map(|block| UserMessageContent::from_content_block(block, path_style))
1349                .collect::<Vec<_>>();
1350            log::debug!("Converted prompt to message: {} chars", content.len());
1351            log::debug!("Message id: {:?}", id);
1352            log::debug!("Message content: {:?}", content);
1353
1354            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1355        })
1356    }
1357
1358    fn retry(
1359        &self,
1360        session_id: &acp::SessionId,
1361        _cx: &App,
1362    ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1363        Some(Rc::new(NativeAgentSessionRetry {
1364            connection: self.clone(),
1365            session_id: session_id.clone(),
1366        }) as _)
1367    }
1368
1369    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1370        log::info!("Cancelling on session: {}", session_id);
1371        self.0.update(cx, |agent, cx| {
1372            if let Some(agent) = agent.sessions.get(session_id) {
1373                agent
1374                    .thread
1375                    .update(cx, |thread, cx| thread.cancel(cx))
1376                    .detach();
1377            }
1378        });
1379    }
1380
1381    fn truncate(
1382        &self,
1383        session_id: &agent_client_protocol::SessionId,
1384        cx: &App,
1385    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1386        self.0.read_with(cx, |agent, _cx| {
1387            agent.sessions.get(session_id).map(|session| {
1388                Rc::new(NativeAgentSessionTruncate {
1389                    thread: session.thread.clone(),
1390                    acp_thread: session.acp_thread.downgrade(),
1391                }) as _
1392            })
1393        })
1394    }
1395
1396    fn set_title(
1397        &self,
1398        session_id: &acp::SessionId,
1399        cx: &App,
1400    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1401        self.0.read_with(cx, |agent, _cx| {
1402            agent
1403                .sessions
1404                .get(session_id)
1405                .filter(|s| !s.thread.read(cx).is_subagent())
1406                .map(|session| {
1407                    Rc::new(NativeAgentSessionSetTitle {
1408                        thread: session.thread.clone(),
1409                    }) as _
1410                })
1411        })
1412    }
1413
1414    fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1415        let thread_store = self.0.read(cx).thread_store.clone();
1416        Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1417    }
1418
1419    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1420        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1421    }
1422
1423    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1424        self
1425    }
1426}
1427
1428impl acp_thread::AgentTelemetry for NativeAgentConnection {
1429    fn thread_data(
1430        &self,
1431        session_id: &acp::SessionId,
1432        cx: &mut App,
1433    ) -> Task<Result<serde_json::Value>> {
1434        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1435            return Task::ready(Err(anyhow!("Session not found")));
1436        };
1437
1438        let task = session.thread.read(cx).to_db(cx);
1439        cx.background_spawn(async move {
1440            serde_json::to_value(task.await).context("Failed to serialize thread")
1441        })
1442    }
1443}
1444
1445pub struct NativeAgentSessionList {
1446    thread_store: Entity<ThreadStore>,
1447    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1448    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1449    _subscription: Subscription,
1450}
1451
1452impl NativeAgentSessionList {
1453    fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1454        let (tx, rx) = smol::channel::unbounded();
1455        let this_tx = tx.clone();
1456        let subscription = cx.observe(&thread_store, move |_, _| {
1457            this_tx
1458                .try_send(acp_thread::SessionListUpdate::Refresh)
1459                .ok();
1460        });
1461        Self {
1462            thread_store,
1463            updates_tx: tx,
1464            updates_rx: rx,
1465            _subscription: subscription,
1466        }
1467    }
1468
1469    fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1470        AgentSessionInfo {
1471            session_id: entry.id,
1472            cwd: None,
1473            title: Some(entry.title),
1474            updated_at: Some(entry.updated_at),
1475            meta: None,
1476        }
1477    }
1478
1479    pub fn thread_store(&self) -> &Entity<ThreadStore> {
1480        &self.thread_store
1481    }
1482}
1483
1484impl AgentSessionList for NativeAgentSessionList {
1485    fn list_sessions(
1486        &self,
1487        _request: AgentSessionListRequest,
1488        cx: &mut App,
1489    ) -> Task<Result<AgentSessionListResponse>> {
1490        let sessions = self
1491            .thread_store
1492            .read(cx)
1493            .entries()
1494            .map(Self::to_session_info)
1495            .collect();
1496        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1497    }
1498
1499    fn supports_delete(&self) -> bool {
1500        true
1501    }
1502
1503    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1504        self.thread_store
1505            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1506    }
1507
1508    fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1509        self.thread_store
1510            .update(cx, |store, cx| store.delete_threads(cx))
1511    }
1512
1513    fn watch(
1514        &self,
1515        _cx: &mut App,
1516    ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1517        Some(self.updates_rx.clone())
1518    }
1519
1520    fn notify_refresh(&self) {
1521        self.updates_tx
1522            .try_send(acp_thread::SessionListUpdate::Refresh)
1523            .ok();
1524    }
1525
1526    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1527        self
1528    }
1529}
1530
1531struct NativeAgentSessionTruncate {
1532    thread: Entity<Thread>,
1533    acp_thread: WeakEntity<AcpThread>,
1534}
1535
1536impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1537    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1538        match self.thread.update(cx, |thread, cx| {
1539            thread.truncate(message_id.clone(), cx)?;
1540            Ok(thread.latest_token_usage())
1541        }) {
1542            Ok(usage) => {
1543                self.acp_thread
1544                    .update(cx, |thread, cx| {
1545                        thread.update_token_usage(usage, cx);
1546                    })
1547                    .ok();
1548                Task::ready(Ok(()))
1549            }
1550            Err(error) => Task::ready(Err(error)),
1551        }
1552    }
1553}
1554
1555struct NativeAgentSessionRetry {
1556    connection: NativeAgentConnection,
1557    session_id: acp::SessionId,
1558}
1559
1560impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1561    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1562        self.connection
1563            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1564                thread.update(cx, |thread, cx| thread.resume(cx))
1565            })
1566    }
1567}
1568
1569struct NativeAgentSessionSetTitle {
1570    thread: Entity<Thread>,
1571}
1572
1573impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1574    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1575        self.thread
1576            .update(cx, |thread, cx| thread.set_title(title, cx));
1577        Task::ready(Ok(()))
1578    }
1579}
1580
/// Per-thread environment holding weak handles to a [`NativeAgent`] and an
/// [`AcpThread`].
///
/// Both handles are `WeakEntity`s, so this value does not keep either entity alive.
pub struct NativeThreadEnvironment {
    /// The native agent; presumably the one that owns this thread's session — confirm.
    agent: WeakEntity<NativeAgent>,
    /// The ACP-facing thread this environment is associated with.
    acp_thread: WeakEntity<AcpThread>,
}
1585
1586impl NativeThreadEnvironment {
1587    pub(crate) fn create_subagent_thread(
1588        agent: WeakEntity<NativeAgent>,
1589        parent_thread_entity: Entity<Thread>,
1590        label: String,
1591        initial_prompt: String,
1592        timeout: Option<Duration>,
1593        allowed_tools: Option<Vec<String>>,
1594        cx: &mut App,
1595    ) -> Result<Rc<dyn SubagentHandle>> {
1596        let parent_thread = parent_thread_entity.read(cx);
1597        let current_depth = parent_thread.depth();
1598
1599        if current_depth >= MAX_SUBAGENT_DEPTH {
1600            return Err(anyhow!(
1601                "Maximum subagent depth ({}) reached",
1602                MAX_SUBAGENT_DEPTH
1603            ));
1604        }
1605
1606        let running_count = parent_thread.running_subagent_count();
1607        if running_count >= MAX_PARALLEL_SUBAGENTS {
1608            return Err(anyhow!(
1609                "Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.",
1610                MAX_PARALLEL_SUBAGENTS
1611            ));
1612        }
1613
1614        let allowed_tools = match allowed_tools {
1615            Some(tools) => {
1616                let parent_tool_names: std::collections::HashSet<&str> =
1617                    parent_thread.tools.keys().map(|s| s.as_str()).collect();
1618                Some(
1619                    tools
1620                        .into_iter()
1621                        .filter(|t| parent_tool_names.contains(t.as_str()))
1622                        .collect::<Vec<_>>(),
1623                )
1624            }
1625            None => Some(parent_thread.tools.keys().map(|s| s.to_string()).collect()),
1626        };
1627
1628        let subagent_thread: Entity<Thread> = cx.new(|cx| {
1629            let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1630            thread.set_title(label.into(), cx);
1631            thread
1632        });
1633
1634        let session_id = subagent_thread.read(cx).id().clone();
1635
1636        let acp_thread = agent.update(cx, |agent, cx| {
1637            agent.register_session(
1638                subagent_thread.clone(),
1639                allowed_tools
1640                    .as_ref()
1641                    .map(|v| v.iter().map(|s| s.as_str()).collect()),
1642                cx,
1643            )
1644        })?;
1645
1646        parent_thread_entity.update(cx, |parent_thread, _cx| {
1647            parent_thread.register_running_subagent(subagent_thread.downgrade())
1648        });
1649
1650        let task = acp_thread.update(cx, |agent, cx| agent.send(vec![initial_prompt.into()], cx));
1651
1652        let timeout_timer = timeout.map(|d| cx.background_executor().timer(d));
1653        let wait_for_prompt_to_complete = cx
1654            .background_spawn(async move {
1655                if let Some(timer) = timeout_timer {
1656                    futures::select! {
1657                        _ = timer.fuse() => SubagentInitialPromptResult::Timeout,
1658                        _ = task.fuse() => SubagentInitialPromptResult::Completed,
1659                    }
1660                } else {
1661                    task.await.log_err();
1662                    SubagentInitialPromptResult::Completed
1663                }
1664            })
1665            .shared();
1666
1667        let mut user_stop_rx: watch::Receiver<bool> =
1668            acp_thread.update(cx, |thread, _| thread.user_stop_receiver());
1669
1670        let user_cancelled = cx
1671            .background_spawn(async move {
1672                loop {
1673                    if *user_stop_rx.borrow() {
1674                        return;
1675                    }
1676                    if user_stop_rx.changed().await.is_err() {
1677                        std::future::pending::<()>().await;
1678                    }
1679                }
1680            })
1681            .shared();
1682
1683        Ok(Rc::new(NativeSubagentHandle {
1684            session_id,
1685            subagent_thread,
1686            parent_thread: parent_thread_entity.downgrade(),
1687            acp_thread,
1688            wait_for_prompt_to_complete,
1689            user_cancelled,
1690        }) as _)
1691    }
1692}
1693
1694impl ThreadEnvironment for NativeThreadEnvironment {
1695    fn create_terminal(
1696        &self,
1697        command: String,
1698        cwd: Option<PathBuf>,
1699        output_byte_limit: Option<u64>,
1700        cx: &mut AsyncApp,
1701    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1702        let task = self.acp_thread.update(cx, |thread, cx| {
1703            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1704        });
1705
1706        let acp_thread = self.acp_thread.clone();
1707        cx.spawn(async move |cx| {
1708            let terminal = task?.await?;
1709
1710            let (drop_tx, drop_rx) = oneshot::channel();
1711            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1712
1713            cx.spawn(async move |cx| {
1714                drop_rx.await.ok();
1715                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1716            })
1717            .detach();
1718
1719            let handle = AcpTerminalHandle {
1720                terminal,
1721                _drop_tx: Some(drop_tx),
1722            };
1723
1724            Ok(Rc::new(handle) as _)
1725        })
1726    }
1727
1728    fn create_subagent(
1729        &self,
1730        parent_thread_entity: Entity<Thread>,
1731        label: String,
1732        initial_prompt: String,
1733        timeout: Option<Duration>,
1734        allowed_tools: Option<Vec<String>>,
1735        cx: &mut App,
1736    ) -> Result<Rc<dyn SubagentHandle>> {
1737        Self::create_subagent_thread(
1738            self.agent.clone(),
1739            parent_thread_entity,
1740            label,
1741            initial_prompt,
1742            timeout,
1743            allowed_tools,
1744            cx,
1745        )
1746    }
1747}
1748
/// How waiting on a subagent's initial prompt ended.
// `Clone` is required because the waiting task is wrapped in `Shared`.
#[derive(Debug, Clone, Copy)]
enum SubagentInitialPromptResult {
    /// The initial prompt's future finished (successfully or not) before any timeout.
    Completed,
    /// The configured timeout elapsed before the initial prompt finished.
    Timeout,
}
1754
/// Handle to a subagent created by `NativeThreadEnvironment::create_subagent_thread`.
pub struct NativeSubagentHandle {
    /// Session id under which the subagent thread was registered with the agent.
    session_id: acp::SessionId,
    /// Parent thread, held weakly; used to unregister this subagent when it finishes.
    parent_thread: WeakEntity<Thread>,
    /// The subagent's own native thread.
    subagent_thread: Entity<Thread>,
    /// ACP thread returned by `register_session` for `subagent_thread`; prompts are sent here.
    acp_thread: Entity<AcpThread>,
    /// Resolves when the initial prompt finishes or its timeout elapses.
    wait_for_prompt_to_complete: Shared<Task<SubagentInitialPromptResult>>,
    /// Resolves once the ACP thread's user-stop flag is observed as `true`.
    user_cancelled: Shared<Task<()>>,
}
1763
1764impl SubagentHandle for NativeSubagentHandle {
1765    fn id(&self) -> acp::SessionId {
1766        self.session_id.clone()
1767    }
1768
1769    fn wait_for_summary(&self, summary_prompt: String, cx: &AsyncApp) -> Task<Result<String>> {
1770        let thread = self.subagent_thread.clone();
1771        let acp_thread = self.acp_thread.clone();
1772        let wait_for_prompt = self.wait_for_prompt_to_complete.clone();
1773
1774        let wait_for_summary_task = cx.spawn(async move |cx| {
1775            let timed_out = match wait_for_prompt.await {
1776                SubagentInitialPromptResult::Completed => false,
1777                SubagentInitialPromptResult::Timeout => true,
1778            };
1779
1780            let summary_prompt = if timed_out {
1781                thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1782                format!("{}\n{}", "The time to complete the task was exceeded. Stop with the task and follow the directions below:", summary_prompt)
1783            } else {
1784                summary_prompt
1785            };
1786
1787            acp_thread
1788                .update(cx, |thread, cx| thread.send(vec![summary_prompt.into()], cx))
1789                .await?;
1790
1791            thread.read_with(cx, |thread, _cx| {
1792                thread
1793                    .last_message()
1794                    .map(|m| m.to_markdown())
1795                    .context("No response from subagent")
1796            })
1797        });
1798
1799        let user_cancelled = self.user_cancelled.clone();
1800        let thread = self.subagent_thread.clone();
1801        let subagent_session_id = self.session_id.clone();
1802        let parent_thread = self.parent_thread.clone();
1803        cx.spawn(async move |cx| {
1804            let result = futures::select! {
1805                result = wait_for_summary_task.fuse() => result,
1806                _ = user_cancelled.fuse() => {
1807                    thread.update(cx, |thread, cx| thread.cancel(cx).detach());
1808                    Err(anyhow!("User cancelled"))
1809                },
1810            };
1811            parent_thread
1812                .update(cx, |parent_thread, cx| {
1813                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
1814                })
1815                .ok();
1816            result
1817        })
1818    }
1819}
1820
/// [`TerminalHandle`] backed by an ACP terminal entity.
pub struct AcpTerminalHandle {
    /// The underlying ACP terminal.
    terminal: Entity<acp_thread::Terminal>,
    /// Held only for its drop side effect. When this sender is dropped, the receiver awaited
    /// by a task spawned in `NativeThreadEnvironment::create_terminal` resolves, and that task
    /// releases the terminal from its ACP thread.
    _drop_tx: Option<oneshot::Sender<()>>,
}
1825
1826impl TerminalHandle for AcpTerminalHandle {
1827    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1828        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1829    }
1830
1831    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1832        Ok(self
1833            .terminal
1834            .read_with(cx, |term, _cx| term.wait_for_exit()))
1835    }
1836
1837    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1838        Ok(self
1839            .terminal
1840            .read_with(cx, |term, cx| term.current_output(cx)))
1841    }
1842
1843    fn kill(&self, cx: &AsyncApp) -> Result<()> {
1844        cx.update(|cx| {
1845            self.terminal.update(cx, |terminal, cx| {
1846                terminal.kill(cx);
1847            });
1848        });
1849        Ok(())
1850    }
1851
1852    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1853        Ok(self
1854            .terminal
1855            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1856    }
1857}
1858
1859#[cfg(test)]
1860mod internal_tests {
1861    use super::*;
1862    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1863    use fs::FakeFs;
1864    use gpui::TestAppContext;
1865    use indoc::formatdoc;
1866    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1867    use language_model::{LanguageModelProviderId, LanguageModelProviderName};
1868    use serde_json::json;
1869    use settings::SettingsStore;
1870    use util::{path, rel_path::rel_path};
1871
1872    #[gpui::test]
1873    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1874        init_test(cx);
1875        let fs = FakeFs::new(cx.executor());
1876        fs.insert_tree(
1877            "/",
1878            json!({
1879                "a": {}
1880            }),
1881        )
1882        .await;
1883        let project = Project::test(fs.clone(), [], cx).await;
1884        let thread_store = cx.new(|cx| ThreadStore::new(cx));
1885        let agent = NativeAgent::new(
1886            project.clone(),
1887            thread_store,
1888            Templates::new(),
1889            None,
1890            fs.clone(),
1891            &mut cx.to_async(),
1892        )
1893        .await
1894        .unwrap();
1895        agent.read_with(cx, |agent, cx| {
1896            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1897        });
1898
1899        let worktree = project
1900            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1901            .await
1902            .unwrap();
1903        cx.run_until_parked();
1904        agent.read_with(cx, |agent, cx| {
1905            assert_eq!(
1906                agent.project_context.read(cx).worktrees,
1907                vec![WorktreeContext {
1908                    root_name: "a".into(),
1909                    abs_path: Path::new("/a").into(),
1910                    rules_file: None
1911                }]
1912            )
1913        });
1914
1915        // Creating `/a/.rules` updates the project context.
1916        fs.insert_file("/a/.rules", Vec::new()).await;
1917        cx.run_until_parked();
1918        agent.read_with(cx, |agent, cx| {
1919            let rules_entry = worktree
1920                .read(cx)
1921                .entry_for_path(rel_path(".rules"))
1922                .unwrap();
1923            assert_eq!(
1924                agent.project_context.read(cx).worktrees,
1925                vec![WorktreeContext {
1926                    root_name: "a".into(),
1927                    abs_path: Path::new("/a").into(),
1928                    rules_file: Some(RulesFileContext {
1929                        path_in_worktree: rel_path(".rules").into(),
1930                        text: "".into(),
1931                        project_entry_id: rules_entry.id.to_usize()
1932                    })
1933                }]
1934            )
1935        });
1936    }
1937
1938    #[gpui::test]
1939    async fn test_listing_models(cx: &mut TestAppContext) {
1940        init_test(cx);
1941        let fs = FakeFs::new(cx.executor());
1942        fs.insert_tree("/", json!({ "a": {}  })).await;
1943        let project = Project::test(fs.clone(), [], cx).await;
1944        let thread_store = cx.new(|cx| ThreadStore::new(cx));
1945        let connection = NativeAgentConnection(
1946            NativeAgent::new(
1947                project.clone(),
1948                thread_store,
1949                Templates::new(),
1950                None,
1951                fs.clone(),
1952                &mut cx.to_async(),
1953            )
1954            .await
1955            .unwrap(),
1956        );
1957
1958        // Create a thread/session
1959        let acp_thread = cx
1960            .update(|cx| {
1961                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
1962            })
1963            .await
1964            .unwrap();
1965
1966        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1967
1968        let models = cx
1969            .update(|cx| {
1970                connection
1971                    .model_selector(&session_id)
1972                    .unwrap()
1973                    .list_models(cx)
1974            })
1975            .await
1976            .unwrap();
1977
1978        let acp_thread::AgentModelList::Grouped(models) = models else {
1979            panic!("Unexpected model group");
1980        };
1981        assert_eq!(
1982            models,
1983            IndexMap::from_iter([(
1984                AgentModelGroupName("Fake".into()),
1985                vec![AgentModelInfo {
1986                    id: acp::ModelId::new("fake/fake"),
1987                    name: "Fake".into(),
1988                    description: None,
1989                    icon: Some(acp_thread::AgentModelIcon::Named(
1990                        ui::IconName::ZedAssistant
1991                    )),
1992                    is_latest: false,
1993                    cost: None,
1994                }]
1995            )])
1996        );
1997    }
1998
1999    #[gpui::test]
2000    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
2001        init_test(cx);
2002        let fs = FakeFs::new(cx.executor());
2003        fs.create_dir(paths::settings_file().parent().unwrap())
2004            .await
2005            .unwrap();
2006        fs.insert_file(
2007            paths::settings_file(),
2008            json!({
2009                "agent": {
2010                    "default_model": {
2011                        "provider": "foo",
2012                        "model": "bar"
2013                    }
2014                }
2015            })
2016            .to_string()
2017            .into_bytes(),
2018        )
2019        .await;
2020        let project = Project::test(fs.clone(), [], cx).await;
2021
2022        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2023
2024        // Create the agent and connection
2025        let agent = NativeAgent::new(
2026            project.clone(),
2027            thread_store,
2028            Templates::new(),
2029            None,
2030            fs.clone(),
2031            &mut cx.to_async(),
2032        )
2033        .await
2034        .unwrap();
2035        let connection = NativeAgentConnection(agent.clone());
2036
2037        // Create a thread/session
2038        let acp_thread = cx
2039            .update(|cx| {
2040                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2041            })
2042            .await
2043            .unwrap();
2044
2045        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2046
2047        // Select a model
2048        let selector = connection.model_selector(&session_id).unwrap();
2049        let model_id = acp::ModelId::new("fake/fake");
2050        cx.update(|cx| selector.select_model(model_id.clone(), cx))
2051            .await
2052            .unwrap();
2053
2054        // Verify the thread has the selected model
2055        agent.read_with(cx, |agent, _| {
2056            let session = agent.sessions.get(&session_id).unwrap();
2057            session.thread.read_with(cx, |thread, _| {
2058                assert_eq!(thread.model().unwrap().id().0, "fake");
2059            });
2060        });
2061
2062        cx.run_until_parked();
2063
2064        // Verify settings file was updated
2065        let settings_content = fs.load(paths::settings_file()).await.unwrap();
2066        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2067
2068        // Check that the agent settings contain the selected model
2069        assert_eq!(
2070            settings_json["agent"]["default_model"]["model"],
2071            json!("fake")
2072        );
2073        assert_eq!(
2074            settings_json["agent"]["default_model"]["provider"],
2075            json!("fake")
2076        );
2077
2078        // Register a thinking model and select it.
2079        cx.update(|cx| {
2080            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2081                "fake-corp",
2082                "fake-thinking",
2083                "Fake Thinking",
2084                true,
2085            ));
2086            let thinking_provider = Arc::new(
2087                FakeLanguageModelProvider::new(
2088                    LanguageModelProviderId::from("fake-corp".to_string()),
2089                    LanguageModelProviderName::from("Fake Corp".to_string()),
2090                )
2091                .with_models(vec![thinking_model]),
2092            );
2093            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2094                registry.register_provider(thinking_provider, cx);
2095            });
2096        });
2097        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2098
2099        let selector = connection.model_selector(&session_id).unwrap();
2100        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2101            .await
2102            .unwrap();
2103        cx.run_until_parked();
2104
2105        // Verify enable_thinking was written to settings as true.
2106        let settings_content = fs.load(paths::settings_file()).await.unwrap();
2107        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2108        assert_eq!(
2109            settings_json["agent"]["default_model"]["enable_thinking"],
2110            json!(true),
2111            "selecting a thinking model should persist enable_thinking: true to settings"
2112        );
2113    }
2114
2115    #[gpui::test]
2116    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2117        init_test(cx);
2118        let fs = FakeFs::new(cx.executor());
2119        fs.create_dir(paths::settings_file().parent().unwrap())
2120            .await
2121            .unwrap();
2122        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2123        let project = Project::test(fs.clone(), [], cx).await;
2124
2125        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2126        let agent = NativeAgent::new(
2127            project.clone(),
2128            thread_store,
2129            Templates::new(),
2130            None,
2131            fs.clone(),
2132            &mut cx.to_async(),
2133        )
2134        .await
2135        .unwrap();
2136        let connection = NativeAgentConnection(agent.clone());
2137
2138        let acp_thread = cx
2139            .update(|cx| {
2140                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2141            })
2142            .await
2143            .unwrap();
2144        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2145
2146        // Register a second provider with a thinking model.
2147        cx.update(|cx| {
2148            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2149                "fake-corp",
2150                "fake-thinking",
2151                "Fake Thinking",
2152                true,
2153            ));
2154            let thinking_provider = Arc::new(
2155                FakeLanguageModelProvider::new(
2156                    LanguageModelProviderId::from("fake-corp".to_string()),
2157                    LanguageModelProviderName::from("Fake Corp".to_string()),
2158                )
2159                .with_models(vec![thinking_model]),
2160            );
2161            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2162                registry.register_provider(thinking_provider, cx);
2163            });
2164        });
2165        // Refresh the agent's model list so it picks up the new provider.
2166        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2167
2168        // Thread starts with thinking_enabled = false (the default).
2169        agent.read_with(cx, |agent, _| {
2170            let session = agent.sessions.get(&session_id).unwrap();
2171            session.thread.read_with(cx, |thread, _| {
2172                assert!(!thread.thinking_enabled(), "thinking defaults to false");
2173            });
2174        });
2175
2176        // Select the thinking model via select_model.
2177        let selector = connection.model_selector(&session_id).unwrap();
2178        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2179            .await
2180            .unwrap();
2181
2182        // select_model should have enabled thinking based on the model's supports_thinking().
2183        agent.read_with(cx, |agent, _| {
2184            let session = agent.sessions.get(&session_id).unwrap();
2185            session.thread.read_with(cx, |thread, _| {
2186                assert!(
2187                    thread.thinking_enabled(),
2188                    "select_model should enable thinking when model supports it"
2189                );
2190            });
2191        });
2192
2193        // Switch back to the non-thinking model.
2194        let selector = connection.model_selector(&session_id).unwrap();
2195        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2196            .await
2197            .unwrap();
2198
2199        // select_model should have disabled thinking.
2200        agent.read_with(cx, |agent, _| {
2201            let session = agent.sessions.get(&session_id).unwrap();
2202            session.thread.read_with(cx, |thread, _| {
2203                assert!(
2204                    !thread.thinking_enabled(),
2205                    "select_model should disable thinking when model does not support it"
2206                );
2207            });
2208        });
2209    }
2210
2211    #[gpui::test]
2212    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
2213        init_test(cx);
2214        let fs = FakeFs::new(cx.executor());
2215        fs.insert_tree("/", json!({ "a": {} })).await;
2216        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2217        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2218        let agent = NativeAgent::new(
2219            project.clone(),
2220            thread_store.clone(),
2221            Templates::new(),
2222            None,
2223            fs.clone(),
2224            &mut cx.to_async(),
2225        )
2226        .await
2227        .unwrap();
2228        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2229
2230        // Register a thinking model.
2231        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2232            "fake-corp",
2233            "fake-thinking",
2234            "Fake Thinking",
2235            true,
2236        ));
2237        let thinking_provider = Arc::new(
2238            FakeLanguageModelProvider::new(
2239                LanguageModelProviderId::from("fake-corp".to_string()),
2240                LanguageModelProviderName::from("Fake Corp".to_string()),
2241            )
2242            .with_models(vec![thinking_model.clone()]),
2243        );
2244        cx.update(|cx| {
2245            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2246                registry.register_provider(thinking_provider, cx);
2247            });
2248        });
2249        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2250
2251        // Create a thread and select the thinking model.
2252        let acp_thread = cx
2253            .update(|cx| {
2254                connection
2255                    .clone()
2256                    .new_session(project.clone(), Path::new("/a"), cx)
2257            })
2258            .await
2259            .unwrap();
2260        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2261
2262        let selector = connection.model_selector(&session_id).unwrap();
2263        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2264            .await
2265            .unwrap();
2266
2267        // Verify thinking is enabled after selecting the thinking model.
2268        let thread = agent.read_with(cx, |agent, _| {
2269            agent.sessions.get(&session_id).unwrap().thread.clone()
2270        });
2271        thread.read_with(cx, |thread, _| {
2272            assert!(
2273                thread.thinking_enabled(),
2274                "thinking should be enabled after selecting thinking model"
2275            );
2276        });
2277
2278        // Send a message so the thread gets persisted.
2279        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2280        let send = cx.foreground_executor().spawn(send);
2281        cx.run_until_parked();
2282
2283        thinking_model.send_last_completion_stream_text_chunk("Response.");
2284        thinking_model.end_last_completion_stream();
2285
2286        send.await.unwrap();
2287        cx.run_until_parked();
2288
2289        // Close the session so it can be reloaded from disk.
2290        cx.update(|cx| connection.clone().close_session(&session_id, cx))
2291            .await
2292            .unwrap();
2293        drop(thread);
2294        drop(acp_thread);
2295        agent.read_with(cx, |agent, _| {
2296            assert!(agent.sessions.is_empty());
2297        });
2298
2299        // Reload the thread and verify thinking_enabled is still true.
2300        let reloaded_acp_thread = agent
2301            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2302            .await
2303            .unwrap();
2304        let reloaded_thread = agent.read_with(cx, |agent, _| {
2305            agent.sessions.get(&session_id).unwrap().thread.clone()
2306        });
2307        reloaded_thread.read_with(cx, |thread, _| {
2308            assert!(
2309                thread.thinking_enabled(),
2310                "thinking_enabled should be preserved when reloading a thread with a thinking model"
2311            );
2312        });
2313
2314        drop(reloaded_acp_thread);
2315    }
2316
2317    #[gpui::test]
2318    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
2319        init_test(cx);
2320        let fs = FakeFs::new(cx.executor());
2321        fs.insert_tree("/", json!({ "a": {} })).await;
2322        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2323        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2324        let agent = NativeAgent::new(
2325            project.clone(),
2326            thread_store.clone(),
2327            Templates::new(),
2328            None,
2329            fs.clone(),
2330            &mut cx.to_async(),
2331        )
2332        .await
2333        .unwrap();
2334        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2335
2336        // Register a model where id() != name(), like real Anthropic models
2337        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
2338        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2339            "fake-corp",
2340            "custom-model-id",
2341            "Custom Model Display Name",
2342            false,
2343        ));
2344        let provider = Arc::new(
2345            FakeLanguageModelProvider::new(
2346                LanguageModelProviderId::from("fake-corp".to_string()),
2347                LanguageModelProviderName::from("Fake Corp".to_string()),
2348            )
2349            .with_models(vec![model.clone()]),
2350        );
2351        cx.update(|cx| {
2352            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2353                registry.register_provider(provider, cx);
2354            });
2355        });
2356        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2357
2358        // Create a thread and select the model.
2359        let acp_thread = cx
2360            .update(|cx| {
2361                connection
2362                    .clone()
2363                    .new_session(project.clone(), Path::new("/a"), cx)
2364            })
2365            .await
2366            .unwrap();
2367        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2368
2369        let selector = connection.model_selector(&session_id).unwrap();
2370        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
2371            .await
2372            .unwrap();
2373
2374        let thread = agent.read_with(cx, |agent, _| {
2375            agent.sessions.get(&session_id).unwrap().thread.clone()
2376        });
2377        thread.read_with(cx, |thread, _| {
2378            assert_eq!(
2379                thread.model().unwrap().id().0.as_ref(),
2380                "custom-model-id",
2381                "model should be set before persisting"
2382            );
2383        });
2384
2385        // Send a message so the thread gets persisted.
2386        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2387        let send = cx.foreground_executor().spawn(send);
2388        cx.run_until_parked();
2389
2390        model.send_last_completion_stream_text_chunk("Response.");
2391        model.end_last_completion_stream();
2392
2393        send.await.unwrap();
2394        cx.run_until_parked();
2395
2396        // Close the session so it can be reloaded from disk.
2397        cx.update(|cx| connection.clone().close_session(&session_id, cx))
2398            .await
2399            .unwrap();
2400        drop(thread);
2401        drop(acp_thread);
2402        agent.read_with(cx, |agent, _| {
2403            assert!(agent.sessions.is_empty());
2404        });
2405
2406        // Reload the thread and verify the model was preserved.
2407        let reloaded_acp_thread = agent
2408            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2409            .await
2410            .unwrap();
2411        let reloaded_thread = agent.read_with(cx, |agent, _| {
2412            agent.sessions.get(&session_id).unwrap().thread.clone()
2413        });
2414        reloaded_thread.read_with(cx, |thread, _| {
2415            let reloaded_model = thread
2416                .model()
2417                .expect("model should be present after reload");
2418            assert_eq!(
2419                reloaded_model.id().0.as_ref(),
2420                "custom-model-id",
2421                "reloaded thread should have the same model, not fall back to the default"
2422            );
2423        });
2424
2425        drop(reloaded_acp_thread);
2426    }
2427
    /// Round-trip test for persisting a native-agent thread through the thread store.
    ///
    /// Checks that:
    /// - a thread that has only had its models set (no messages) is not saved;
    /// - after one user/assistant exchange, the store lists the session under the
    ///   title streamed by the summarization model;
    /// - after closing the session, reopening the thread yields the same markdown
    ///   transcript, including the `@b.md` file mention.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        // Fake filesystem: project root `/a` containing `b.md`, which the user
        // message below mentions.
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Open a new session via the ACP connection, then look up the
        // agent-side thread entity that backs it.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a user message mixing plain text with a resource link to
        // `/a/b.md`; the transcript should render it as an `[@b.md](uri)` mention.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // Run the send on the executor so it stays in flight while the test
        // feeds the fake models' completion streams below.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Finish the assistant response.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        // Finish the summarization stream; this text is what the store later
        // reports as the thread title (asserted after the session is closed).
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        // Transcript expected both now and after reloading from the store.
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // Let outstanding work settle (presumably including the thread save)
        // before closing the session.
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the thread entities.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        // The store should list exactly this session, titled by the summary model.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        // Reopen the thread and check the transcript is unchanged.
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }
2566
2567    fn thread_entries(
2568        thread_store: &Entity<ThreadStore>,
2569        cx: &mut TestAppContext,
2570    ) -> Vec<(acp::SessionId, String)> {
2571        thread_store.read_with(cx, |store, _| {
2572            store
2573                .entries()
2574                .map(|entry| (entry.id.clone(), entry.title.to_string()))
2575                .collect::<Vec<_>>()
2576        })
2577    }
2578
2579    fn init_test(cx: &mut TestAppContext) {
2580        env_logger::try_init().ok();
2581        cx.update(|cx| {
2582            let settings_store = SettingsStore::test(cx);
2583            cx.set_global(settings_store);
2584
2585            LanguageModelRegistry::test(cx);
2586        });
2587    }
2588}
2589
2590fn mcp_message_content_to_acp_content_block(
2591    content: context_server::types::MessageContent,
2592) -> acp::ContentBlock {
2593    match content {
2594        context_server::types::MessageContent::Text {
2595            text,
2596            annotations: _,
2597        } => text.into(),
2598        context_server::types::MessageContent::Image {
2599            data,
2600            mime_type,
2601            annotations: _,
2602        } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2603        context_server::types::MessageContent::Audio {
2604            data,
2605            mime_type,
2606            annotations: _,
2607        } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2608        context_server::types::MessageContent::Resource {
2609            resource,
2610            annotations: _,
2611        } => {
2612            let mut link =
2613                acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2614            if let Some(mime_type) = resource.mime_type {
2615                link = link.mime_type(mime_type);
2616            }
2617            acp::ContentBlock::ResourceLink(link)
2618        }
2619    }
2620}