agent.rs

   1mod db;
   2mod edit_agent;
   3mod legacy_thread;
   4mod native_agent_server;
   5pub mod outline;
   6mod pattern_extraction;
   7mod templates;
   8#[cfg(test)]
   9mod tests;
  10mod thread;
  11mod thread_store;
  12mod tool_permissions;
  13mod tools;
  14
  15use context_server::ContextServerId;
  16pub use db::*;
  17pub use native_agent_server::NativeAgentServer;
  18pub use pattern_extraction::*;
  19pub use shell_command_parser::extract_commands;
  20pub use templates::*;
  21pub use thread::*;
  22pub use thread_store::*;
  23pub use tool_permissions::*;
  24pub use tools::*;
  25
  26use acp_thread::{
  27    AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
  28    AgentSessionListResponse, TokenUsageRatio, UserMessageId,
  29};
  30use agent_client_protocol as acp;
  31use anyhow::{Context as _, Result, anyhow};
  32use chrono::{DateTime, Utc};
  33use collections::{HashMap, HashSet, IndexMap};
  34use fs::Fs;
  35use futures::channel::{mpsc, oneshot};
  36use futures::future::Shared;
  37use futures::{FutureExt as _, StreamExt as _, future};
  38use gpui::{
  39    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  40};
  41use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  42use project::{Project, ProjectItem, ProjectPath, Worktree};
  43use prompt_store::{
  44    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  45    WorktreeContext,
  46};
  47use serde::{Deserialize, Serialize};
  48use settings::{LanguageModelSelection, update_settings_file};
  49use std::any::Any;
  50use std::path::{Path, PathBuf};
  51use std::rc::Rc;
  52use std::sync::Arc;
  53use util::ResultExt;
  54use util::path_list::PathList;
  55use util::rel_path::RelPath;
  56
/// A point-in-time snapshot of the project's worktrees, built from
/// `project::telemetry_snapshot` data.
///
/// Derives `Serialize`/`Deserialize`, so it can be persisted or transmitted.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One telemetry snapshot per worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// Timestamp (UTC) associated with the snapshot — presumably the capture time; confirm at the construction site.
    pub timestamp: DateTime<Utc>,
}
  62
/// Describes a failure to load a rules file while building the system-prompt
/// project context (see `NativeAgent::load_worktree_info_for_system_prompt`).
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the underlying error's `Display` output).
    pub message: SharedString,
}
  66
/// Holds both the internal Thread and the AcpThread for a session.
struct Session {
    /// The internal thread that processes messages.
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Task for the most recent save of `thread`. Starts as `Task::ready(())`
    /// in `register_session`; presumably replaced by `save_thread` (defined
    /// elsewhere) — confirm.
    pending_save: Task<()>,
    /// Subscriptions created in `register_session` (title updates, token-usage
    /// updates, and the observer that calls `save_thread`), owned by the session.
    _subscriptions: Vec<Subscription>,
}
  76
/// Cache of the language models offered by the native agent, rebuilt from the
/// global `LanguageModelRegistry` by `LanguageModels::refresh_list`.
pub struct LanguageModels {
    /// Access language model by ID (`"<provider_id>/<model_id>"`, see `LanguageModels::model_id`).
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information: an optional
    /// "Recommended" group followed by one group per provider.
    model_list: acp_thread::AgentModelList,
    /// Receiver handed out (cloned) by `watch()`; signalled on every refresh.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender that `refresh_list` signals after rebuilding the cache.
    refresh_models_tx: watch::Sender<()>,
    /// Background task that authenticates every visible provider; owned here so it is kept for the cache's lifetime.
    _authenticate_all_providers_task: Task<()>,
}
  86
  87impl LanguageModels {
  88    fn new(cx: &mut App) -> Self {
  89        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  90
  91        let mut this = Self {
  92            models: HashMap::default(),
  93            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  94            refresh_models_rx,
  95            refresh_models_tx,
  96            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
  97        };
  98        this.refresh_list(cx);
  99        this
 100    }
 101
 102    fn refresh_list(&mut self, cx: &App) {
 103        let providers = LanguageModelRegistry::global(cx)
 104            .read(cx)
 105            .visible_providers()
 106            .into_iter()
 107            .filter(|provider| provider.is_authenticated(cx))
 108            .collect::<Vec<_>>();
 109
 110        let mut language_model_list = IndexMap::default();
 111        let mut recommended_models = HashSet::default();
 112
 113        let mut recommended = Vec::new();
 114        for provider in &providers {
 115            for model in provider.recommended_models(cx) {
 116                recommended_models.insert((model.provider_id(), model.id()));
 117                recommended.push(Self::map_language_model_to_info(&model, provider));
 118            }
 119        }
 120        if !recommended.is_empty() {
 121            language_model_list.insert(
 122                acp_thread::AgentModelGroupName("Recommended".into()),
 123                recommended,
 124            );
 125        }
 126
 127        let mut models = HashMap::default();
 128        for provider in providers {
 129            let mut provider_models = Vec::new();
 130            for model in provider.provided_models(cx) {
 131                let model_info = Self::map_language_model_to_info(&model, &provider);
 132                let model_id = model_info.id.clone();
 133                provider_models.push(model_info);
 134                models.insert(model_id, model);
 135            }
 136            if !provider_models.is_empty() {
 137                language_model_list.insert(
 138                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 139                    provider_models,
 140                );
 141            }
 142        }
 143
 144        self.models = models;
 145        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 146        self.refresh_models_tx.send(()).ok();
 147    }
 148
 149    fn watch(&self) -> watch::Receiver<()> {
 150        self.refresh_models_rx.clone()
 151    }
 152
 153    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 154        self.models.get(model_id).cloned()
 155    }
 156
 157    fn map_language_model_to_info(
 158        model: &Arc<dyn LanguageModel>,
 159        provider: &Arc<dyn LanguageModelProvider>,
 160    ) -> acp_thread::AgentModelInfo {
 161        acp_thread::AgentModelInfo {
 162            id: Self::model_id(model),
 163            name: model.name().0,
 164            description: None,
 165            icon: Some(match provider.icon() {
 166                IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
 167                IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
 168            }),
 169            is_latest: model.is_latest(),
 170            cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
 171        }
 172    }
 173
 174    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 175        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 176    }
 177
 178    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 179        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 180            .read(cx)
 181            .visible_providers()
 182            .iter()
 183            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 184            .collect::<Vec<_>>();
 185
 186        cx.background_spawn(async move {
 187            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 188                if let Err(err) = authenticate_task.await {
 189                    match err {
 190                        language_model::AuthenticateError::CredentialsNotFound => {
 191                            // Since we're authenticating these providers in the
 192                            // background for the purposes of populating the
 193                            // language selector, we don't care about providers
 194                            // where the credentials are not found.
 195                        }
 196                        language_model::AuthenticateError::ConnectionRefused => {
 197                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 198                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 199                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 200                        }
 201                        _ => {
 202                            // Some providers have noisy failure states that we
 203                            // don't want to spam the logs with every time the
 204                            // language model selector is initialized.
 205                            //
 206                            // Ideally these should have more clear failure modes
 207                            // that we know are safe to ignore here, like what we do
 208                            // with `CredentialsNotFound` above.
 209                            match provider_id.0.as_ref() {
 210                                "lmstudio" | "ollama" => {
 211                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 212                                    //
 213                                    // These fail noisily, so we don't log them.
 214                                }
 215                                "copilot_chat" => {
 216                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 217                                }
 218                                _ => {
 219                                    log::error!(
 220                                        "Failed to authenticate provider: {}: {err:#}",
 221                                        provider_name.0
 222                                    );
 223                                }
 224                            }
 225                        }
 226                    }
 227                }
 228            }
 229        })
 230    }
 231}
 232
/// The native agent: owns one `Session` per open thread plus the state shared
/// by those threads (project context, templates, model cache, context-server
/// registry).
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread store handed in at construction (presumably backs persisted thread history — confirm).
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads. `maintain_project_context`
    /// replaces this field with a newly created entity on each rebuild.
    project_context: Entity<ProjectContext>,
    /// Sending `()` asks `maintain_project_context` to rebuild `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background loop running `maintain_project_context`.
    _maintain_project_context: Task<Result<()>>,
    /// Context-server registry; its prompts become the sessions' available commands.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent was created for.
    project: Entity<Project>,
    /// Optional prompt store; supplies the default user rules for the project context.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle supplied at construction.
    fs: Arc<dyn Fs>,
    /// Subscriptions to project, model-registry, context-server, and prompt-store events.
    _subscriptions: Vec<Subscription>,
}
 251
 252impl NativeAgent {
 253    pub async fn new(
 254        project: Entity<Project>,
 255        thread_store: Entity<ThreadStore>,
 256        templates: Arc<Templates>,
 257        prompt_store: Option<Entity<PromptStore>>,
 258        fs: Arc<dyn Fs>,
 259        cx: &mut AsyncApp,
 260    ) -> Result<Entity<NativeAgent>> {
 261        log::debug!("Creating new NativeAgent");
 262
 263        let project_context = cx
 264            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
 265            .await;
 266
 267        Ok(cx.new(|cx| {
 268            let context_server_store = project.read(cx).context_server_store();
 269            let context_server_registry =
 270                cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
 271
 272            let mut subscriptions = vec![
 273                cx.subscribe(&project, Self::handle_project_event),
 274                cx.subscribe(
 275                    &LanguageModelRegistry::global(cx),
 276                    Self::handle_models_updated_event,
 277                ),
 278                cx.subscribe(
 279                    &context_server_store,
 280                    Self::handle_context_server_store_updated,
 281                ),
 282                cx.subscribe(
 283                    &context_server_registry,
 284                    Self::handle_context_server_registry_event,
 285                ),
 286            ];
 287            if let Some(prompt_store) = prompt_store.as_ref() {
 288                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 289            }
 290
 291            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 292                watch::channel(());
 293            Self {
 294                sessions: HashMap::default(),
 295                thread_store,
 296                project_context: cx.new(|_| project_context),
 297                project_context_needs_refresh: project_context_needs_refresh_tx,
 298                _maintain_project_context: cx.spawn(async move |this, cx| {
 299                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 300                }),
 301                context_server_registry,
 302                templates,
 303                models: LanguageModels::new(cx),
 304                project,
 305                prompt_store,
 306                fs,
 307                _subscriptions: subscriptions,
 308            }
 309        }))
 310    }
 311
 312    fn new_session(
 313        &mut self,
 314        project: Entity<Project>,
 315        cx: &mut Context<Self>,
 316    ) -> Entity<AcpThread> {
 317        // Create Thread
 318        // Fetch default model from registry settings
 319        let registry = LanguageModelRegistry::read_global(cx);
 320        // Log available models for debugging
 321        let available_count = registry.available_models(cx).count();
 322        log::debug!("Total available models: {}", available_count);
 323
 324        let default_model = registry.default_model().and_then(|default_model| {
 325            self.models
 326                .model_from_id(&LanguageModels::model_id(&default_model.model))
 327        });
 328        let thread = cx.new(|cx| {
 329            Thread::new(
 330                project.clone(),
 331                self.project_context.clone(),
 332                self.context_server_registry.clone(),
 333                self.templates.clone(),
 334                default_model,
 335                cx,
 336            )
 337        });
 338
 339        self.register_session(thread, cx)
 340    }
 341
 342    fn register_session(
 343        &mut self,
 344        thread_handle: Entity<Thread>,
 345        cx: &mut Context<Self>,
 346    ) -> Entity<AcpThread> {
 347        let connection = Rc::new(NativeAgentConnection(cx.entity()));
 348
 349        let thread = thread_handle.read(cx);
 350        let session_id = thread.id().clone();
 351        let parent_session_id = thread.parent_thread_id();
 352        let title = thread.title();
 353        let project = thread.project.clone();
 354        let action_log = thread.action_log.clone();
 355        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
 356        let acp_thread = cx.new(|cx| {
 357            acp_thread::AcpThread::new(
 358                parent_session_id,
 359                title,
 360                connection,
 361                project.clone(),
 362                action_log.clone(),
 363                session_id.clone(),
 364                prompt_capabilities_rx,
 365                cx,
 366            )
 367        });
 368
 369        let registry = LanguageModelRegistry::read_global(cx);
 370        let summarization_model = registry.thread_summary_model().map(|c| c.model);
 371
 372        let weak = cx.weak_entity();
 373        let weak_thread = thread_handle.downgrade();
 374        thread_handle.update(cx, |thread, cx| {
 375            thread.set_summarization_model(summarization_model, cx);
 376            thread.add_default_tools(
 377                Rc::new(NativeThreadEnvironment {
 378                    acp_thread: acp_thread.downgrade(),
 379                    thread: weak_thread,
 380                    agent: weak,
 381                }) as _,
 382                cx,
 383            )
 384        });
 385
 386        let subscriptions = vec![
 387            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
 388            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
 389            cx.observe(&thread_handle, move |this, thread, cx| {
 390                this.save_thread(thread, cx)
 391            }),
 392        ];
 393
 394        self.sessions.insert(
 395            session_id,
 396            Session {
 397                thread: thread_handle,
 398                acp_thread: acp_thread.clone(),
 399                _subscriptions: subscriptions,
 400                pending_save: Task::ready(()),
 401            },
 402        );
 403
 404        self.update_available_commands(cx);
 405
 406        acp_thread
 407    }
 408
    /// Returns the agent's cache of available language models.
    pub fn models(&self) -> &LanguageModels {
        &self.models
    }
 412
 413    async fn maintain_project_context(
 414        this: WeakEntity<Self>,
 415        mut needs_refresh: watch::Receiver<()>,
 416        cx: &mut AsyncApp,
 417    ) -> Result<()> {
 418        while needs_refresh.changed().await.is_ok() {
 419            let project_context = this
 420                .update(cx, |this, cx| {
 421                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 422                })?
 423                .await;
 424            this.update(cx, |this, cx| {
 425                this.project_context = cx.new(|_| project_context);
 426            })?;
 427        }
 428
 429        Ok(())
 430    }
 431
 432    fn build_project_context(
 433        project: &Entity<Project>,
 434        prompt_store: Option<&Entity<PromptStore>>,
 435        cx: &mut App,
 436    ) -> Task<ProjectContext> {
 437        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 438        let worktree_tasks = worktrees
 439            .into_iter()
 440            .map(|worktree| {
 441                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 442            })
 443            .collect::<Vec<_>>();
 444        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 445            prompt_store.read_with(cx, |prompt_store, cx| {
 446                let prompts = prompt_store.default_prompt_metadata();
 447                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 448                    let contents = prompt_store.load(prompt_metadata.id, cx);
 449                    async move { (contents.await, prompt_metadata) }
 450                });
 451                cx.background_spawn(future::join_all(load_tasks))
 452            })
 453        } else {
 454            Task::ready(vec![])
 455        };
 456
 457        cx.spawn(async move |_cx| {
 458            let (worktrees, default_user_rules) =
 459                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 460
 461            let worktrees = worktrees
 462                .into_iter()
 463                .map(|(worktree, _rules_error)| {
 464                    // TODO: show error message
 465                    // if let Some(rules_error) = rules_error {
 466                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 467                    // }
 468                    worktree
 469                })
 470                .collect::<Vec<_>>();
 471
 472            let default_user_rules = default_user_rules
 473                .into_iter()
 474                .flat_map(|(contents, prompt_metadata)| match contents {
 475                    Ok(contents) => Some(UserRulesContext {
 476                        uuid: prompt_metadata.id.as_user()?,
 477                        title: prompt_metadata.title.map(|title| title.to_string()),
 478                        contents,
 479                    }),
 480                    Err(_err) => {
 481                        // TODO: show error message
 482                        // this.update(cx, |_, cx| {
 483                        //     cx.emit(RulesLoadingError {
 484                        //         message: format!("{err:?}").into(),
 485                        //     });
 486                        // })
 487                        // .ok();
 488                        None
 489                    }
 490                })
 491                .collect::<Vec<_>>();
 492
 493            ProjectContext::new(worktrees, default_user_rules)
 494        })
 495    }
 496
 497    fn load_worktree_info_for_system_prompt(
 498        worktree: Entity<Worktree>,
 499        project: Entity<Project>,
 500        cx: &mut App,
 501    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 502        let tree = worktree.read(cx);
 503        let root_name = tree.root_name_str().into();
 504        let abs_path = tree.abs_path();
 505
 506        let mut context = WorktreeContext {
 507            root_name,
 508            abs_path,
 509            rules_file: None,
 510        };
 511
 512        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 513        let Some(rules_task) = rules_task else {
 514            return Task::ready((context, None));
 515        };
 516
 517        cx.spawn(async move |_| {
 518            let (rules_file, rules_file_error) = match rules_task.await {
 519                Ok(rules_file) => (Some(rules_file), None),
 520                Err(err) => (
 521                    None,
 522                    Some(RulesLoadingError {
 523                        message: format!("{err}").into(),
 524                    }),
 525                ),
 526            };
 527            context.rules_file = rules_file;
 528            (context, rules_file_error)
 529        })
 530    }
 531
 532    fn load_worktree_rules_file(
 533        worktree: Entity<Worktree>,
 534        project: Entity<Project>,
 535        cx: &mut App,
 536    ) -> Option<Task<Result<RulesFileContext>>> {
 537        let worktree = worktree.read(cx);
 538        let worktree_id = worktree.id();
 539        let selected_rules_file = RULES_FILE_NAMES
 540            .into_iter()
 541            .filter_map(|name| {
 542                worktree
 543                    .entry_for_path(RelPath::unix(name).unwrap())
 544                    .filter(|entry| entry.is_file())
 545                    .map(|entry| entry.path.clone())
 546            })
 547            .next();
 548
 549        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 550        // supported. This doesn't seem to occur often in GitHub repositories.
 551        selected_rules_file.map(|path_in_worktree| {
 552            let project_path = ProjectPath {
 553                worktree_id,
 554                path: path_in_worktree.clone(),
 555            };
 556            let buffer_task =
 557                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 558            let rope_task = cx.spawn(async move |cx| {
 559                let buffer = buffer_task.await?;
 560                let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
 561                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 562                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 563                })?;
 564                anyhow::Ok((project_entry_id, rope))
 565            });
 566            // Build a string from the rope on a background thread.
 567            cx.background_spawn(async move {
 568                let (project_entry_id, rope) = rope_task.await?;
 569                anyhow::Ok(RulesFileContext {
 570                    path_in_worktree,
 571                    text: rope.to_string().trim().to_string(),
 572                    project_entry_id: project_entry_id.to_usize(),
 573                })
 574            })
 575        })
 576    }
 577
 578    fn handle_thread_title_updated(
 579        &mut self,
 580        thread: Entity<Thread>,
 581        _: &TitleUpdated,
 582        cx: &mut Context<Self>,
 583    ) {
 584        let session_id = thread.read(cx).id();
 585        let Some(session) = self.sessions.get(session_id) else {
 586            return;
 587        };
 588        let thread = thread.downgrade();
 589        let acp_thread = session.acp_thread.downgrade();
 590        cx.spawn(async move |_, cx| {
 591            let title = thread.read_with(cx, |thread, _| thread.title())?;
 592            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 593            task.await
 594        })
 595        .detach_and_log_err(cx);
 596    }
 597
 598    fn handle_thread_token_usage_updated(
 599        &mut self,
 600        thread: Entity<Thread>,
 601        usage: &TokenUsageUpdated,
 602        cx: &mut Context<Self>,
 603    ) {
 604        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 605            return;
 606        };
 607        session.acp_thread.update(cx, |acp_thread, cx| {
 608            acp_thread.update_token_usage(usage.0.clone(), cx);
 609        });
 610    }
 611
 612    fn handle_project_event(
 613        &mut self,
 614        _project: Entity<Project>,
 615        event: &project::Event,
 616        _cx: &mut Context<Self>,
 617    ) {
 618        match event {
 619            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 620                self.project_context_needs_refresh.send(()).ok();
 621            }
 622            project::Event::WorktreeUpdatedEntries(_, items) => {
 623                if items.iter().any(|(path, _, _)| {
 624                    RULES_FILE_NAMES
 625                        .iter()
 626                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 627                }) {
 628                    self.project_context_needs_refresh.send(()).ok();
 629                }
 630            }
 631            _ => {}
 632        }
 633    }
 634
 635    fn handle_prompts_updated_event(
 636        &mut self,
 637        _prompt_store: Entity<PromptStore>,
 638        _event: &prompt_store::PromptsUpdatedEvent,
 639        _cx: &mut Context<Self>,
 640    ) {
 641        self.project_context_needs_refresh.send(()).ok();
 642    }
 643
 644    fn handle_models_updated_event(
 645        &mut self,
 646        _registry: Entity<LanguageModelRegistry>,
 647        _event: &language_model::Event,
 648        cx: &mut Context<Self>,
 649    ) {
 650        self.models.refresh_list(cx);
 651
 652        let registry = LanguageModelRegistry::read_global(cx);
 653        let default_model = registry.default_model().map(|m| m.model);
 654        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 655
 656        for session in self.sessions.values_mut() {
 657            session.thread.update(cx, |thread, cx| {
 658                if thread.model().is_none()
 659                    && let Some(model) = default_model.clone()
 660                {
 661                    thread.set_model(model, cx);
 662                    cx.notify();
 663                }
 664                thread.set_summarization_model(summarization_model.clone(), cx);
 665            });
 666        }
 667    }
 668
 669    fn handle_context_server_store_updated(
 670        &mut self,
 671        _store: Entity<project::context_server_store::ContextServerStore>,
 672        _event: &project::context_server_store::ServerStatusChangedEvent,
 673        cx: &mut Context<Self>,
 674    ) {
 675        self.update_available_commands(cx);
 676    }
 677
 678    fn handle_context_server_registry_event(
 679        &mut self,
 680        _registry: Entity<ContextServerRegistry>,
 681        event: &ContextServerRegistryEvent,
 682        cx: &mut Context<Self>,
 683    ) {
 684        match event {
 685            ContextServerRegistryEvent::ToolsChanged => {}
 686            ContextServerRegistryEvent::PromptsChanged => {
 687                self.update_available_commands(cx);
 688            }
 689        }
 690    }
 691
 692    fn update_available_commands(&self, cx: &mut Context<Self>) {
 693        let available_commands = self.build_available_commands(cx);
 694        for session in self.sessions.values() {
 695            session.acp_thread.update(cx, |thread, cx| {
 696                thread
 697                    .handle_session_update(
 698                        acp::SessionUpdate::AvailableCommandsUpdate(
 699                            acp::AvailableCommandsUpdate::new(available_commands.clone()),
 700                        ),
 701                        cx,
 702                    )
 703                    .log_err();
 704            });
 705        }
 706    }
 707
 708    fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
 709        let registry = self.context_server_registry.read(cx);
 710
 711        let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
 712        for context_server_prompt in registry.prompts() {
 713            *prompt_name_counts
 714                .entry(context_server_prompt.prompt.name.as_str())
 715                .or_insert(0) += 1;
 716        }
 717
 718        registry
 719            .prompts()
 720            .flat_map(|context_server_prompt| {
 721                let prompt = &context_server_prompt.prompt;
 722
 723                let should_prefix = prompt_name_counts
 724                    .get(prompt.name.as_str())
 725                    .copied()
 726                    .unwrap_or(0)
 727                    > 1;
 728
 729                let name = if should_prefix {
 730                    format!("{}.{}", context_server_prompt.server_id, prompt.name)
 731                } else {
 732                    prompt.name.clone()
 733                };
 734
 735                let mut command = acp::AvailableCommand::new(
 736                    name,
 737                    prompt.description.clone().unwrap_or_default(),
 738                );
 739
 740                match prompt.arguments.as_deref() {
 741                    Some([arg]) => {
 742                        let hint = format!("<{}>", arg.name);
 743
 744                        command = command.input(acp::AvailableCommandInput::Unstructured(
 745                            acp::UnstructuredCommandInput::new(hint),
 746                        ));
 747                    }
 748                    Some([]) | None => {}
 749                    Some(_) => {
 750                        // skip >1 argument commands since we don't support them yet
 751                        return None;
 752                    }
 753                }
 754
 755                Some(command)
 756            })
 757            .collect()
 758    }
 759
 760    pub fn load_thread(
 761        &mut self,
 762        id: acp::SessionId,
 763        cx: &mut Context<Self>,
 764    ) -> Task<Result<Entity<Thread>>> {
 765        let database_future = ThreadsDatabase::connect(cx);
 766        cx.spawn(async move |this, cx| {
 767            let database = database_future.await.map_err(|err| anyhow!(err))?;
 768            let db_thread = database
 769                .load_thread(id.clone())
 770                .await?
 771                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 772
 773            this.update(cx, |this, cx| {
 774                let summarization_model = LanguageModelRegistry::read_global(cx)
 775                    .thread_summary_model()
 776                    .map(|c| c.model);
 777
 778                cx.new(|cx| {
 779                    let mut thread = Thread::from_db(
 780                        id.clone(),
 781                        db_thread,
 782                        this.project.clone(),
 783                        this.project_context.clone(),
 784                        this.context_server_registry.clone(),
 785                        this.templates.clone(),
 786                        cx,
 787                    );
 788                    thread.set_summarization_model(summarization_model, cx);
 789                    thread
 790                })
 791            })
 792        })
 793    }
 794
 795    pub fn open_thread(
 796        &mut self,
 797        id: acp::SessionId,
 798        cx: &mut Context<Self>,
 799    ) -> Task<Result<Entity<AcpThread>>> {
 800        if let Some(session) = self.sessions.get(&id) {
 801            return Task::ready(Ok(session.acp_thread.clone()));
 802        }
 803
 804        let task = self.load_thread(id, cx);
 805        cx.spawn(async move |this, cx| {
 806            let thread = task.await?;
 807            let acp_thread =
 808                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
 809            let events = thread.update(cx, |thread, cx| thread.replay(cx));
 810            cx.update(|cx| {
 811                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 812            })
 813            .await?;
 814            Ok(acp_thread)
 815        })
 816    }
 817
 818    pub fn thread_summary(
 819        &mut self,
 820        id: acp::SessionId,
 821        cx: &mut Context<Self>,
 822    ) -> Task<Result<SharedString>> {
 823        let thread = self.open_thread(id.clone(), cx);
 824        cx.spawn(async move |this, cx| {
 825            let acp_thread = thread.await?;
 826            let result = this
 827                .update(cx, |this, cx| {
 828                    this.sessions
 829                        .get(&id)
 830                        .unwrap()
 831                        .thread
 832                        .update(cx, |thread, cx| thread.summary(cx))
 833                })?
 834                .await
 835                .context("Failed to generate summary")?;
 836            drop(acp_thread);
 837            Ok(result)
 838        })
 839    }
 840
 841    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 842        if thread.read(cx).is_empty() {
 843            return;
 844        }
 845
 846        let database_future = ThreadsDatabase::connect(cx);
 847        let (id, db_thread) =
 848            thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
 849        let Some(session) = self.sessions.get_mut(&id) else {
 850            return;
 851        };
 852
 853        let folder_paths = PathList::new(
 854            &self
 855                .project
 856                .read(cx)
 857                .visible_worktrees(cx)
 858                .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
 859                .collect::<Vec<_>>(),
 860        );
 861
 862        let thread_store = self.thread_store.clone();
 863        session.pending_save = cx.spawn(async move |_, cx| {
 864            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 865                return;
 866            };
 867            let db_thread = db_thread.await;
 868            database
 869                .save_thread(id, db_thread, folder_paths)
 870                .await
 871                .log_err();
 872            thread_store.update(cx, |store, cx| store.reload(cx));
 873        });
 874    }
 875
    /// Expands the MCP prompt `prompt_name` from context server `server_id` into
    /// session `session_id` and starts a model turn on the result.
    ///
    /// `original_content` is the user's original prompt. Its first block holds
    /// the slash command and is skipped. The remaining blocks are pushed to the
    /// [`Thread`] as the user message `message_id`. Each message returned by
    /// the server is then appended, with its role preserved, to both the
    /// [`AcpThread`] (via the `*_with_indent` push helpers) and the [`Thread`].
    ///
    /// The turn is started with `send_existing` when the last MCP message is from
    /// the user, and with `resume` otherwise. Its events are forwarded to the
    /// `AcpThread`.
    ///
    /// Errors if the prompt cannot be fetched, the session does not exist, the
    /// agent has been released, or the turn cannot be started.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: agent_client_protocol::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let server_store = self.context_server_registry.read(cx).server_store().clone();
        let path_style = self.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Whether the final MCP message has the User role. It selects
            // `send_existing` vs `resume` below, and stays `true` when the
            // prompt returns no messages.
            let mut last_is_user = true;

            // Record the remaining prompt blocks (everything after the first,
            // which holds the slash command) as the user message `message_id`.
            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                match role {
                    context_server::types::Role::User => {
                        // Each user message gets a fresh id shared by the
                        // AcpThread entry and the Thread entry.
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            // Forward the turn's events to the AcpThread and resolve with its
            // stop reason.
            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
 971}
 972
 973/// Wrapper struct that implements the AgentConnection trait
 974#[derive(Clone)]
 975pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 976
 977impl NativeAgentConnection {
 978    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 979        self.0
 980            .read(cx)
 981            .sessions
 982            .get(session_id)
 983            .map(|session| session.thread.clone())
 984    }
 985
 986    pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
 987        self.0.update(cx, |this, cx| this.load_thread(id, cx))
 988    }
 989
 990    fn run_turn(
 991        &self,
 992        session_id: acp::SessionId,
 993        cx: &mut App,
 994        f: impl 'static
 995        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
 996    ) -> Task<Result<acp::PromptResponse>> {
 997        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 998            agent
 999                .sessions
1000                .get_mut(&session_id)
1001                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1002        }) else {
1003            return Task::ready(Err(anyhow!("Session not found")));
1004        };
1005        log::debug!("Found session for: {}", session_id);
1006
1007        let response_stream = match f(thread, cx) {
1008            Ok(stream) => stream,
1009            Err(err) => return Task::ready(Err(err)),
1010        };
1011        Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1012    }
1013
1014    fn handle_thread_events(
1015        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1016        acp_thread: WeakEntity<AcpThread>,
1017        cx: &App,
1018    ) -> Task<Result<acp::PromptResponse>> {
1019        cx.spawn(async move |cx| {
1020            // Handle response stream and forward to session.acp_thread
1021            while let Some(result) = events.next().await {
1022                match result {
1023                    Ok(event) => {
1024                        log::trace!("Received completion event: {:?}", event);
1025
1026                        match event {
1027                            ThreadEvent::UserMessage(message) => {
1028                                acp_thread.update(cx, |thread, cx| {
1029                                    for content in message.content {
1030                                        thread.push_user_content_block(
1031                                            Some(message.id.clone()),
1032                                            content.into(),
1033                                            cx,
1034                                        );
1035                                    }
1036                                })?;
1037                            }
1038                            ThreadEvent::AgentText(text) => {
1039                                acp_thread.update(cx, |thread, cx| {
1040                                    thread.push_assistant_content_block(text.into(), false, cx)
1041                                })?;
1042                            }
1043                            ThreadEvent::AgentThinking(text) => {
1044                                acp_thread.update(cx, |thread, cx| {
1045                                    thread.push_assistant_content_block(text.into(), true, cx)
1046                                })?;
1047                            }
1048                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1049                                tool_call,
1050                                options,
1051                                response,
1052                                context: _,
1053                            }) => {
1054                                let outcome_task = acp_thread.update(cx, |thread, cx| {
1055                                    thread.request_tool_call_authorization(tool_call, options, cx)
1056                                })??;
1057                                cx.background_spawn(async move {
1058                                    if let acp::RequestPermissionOutcome::Selected(
1059                                        acp::SelectedPermissionOutcome { option_id, .. },
1060                                    ) = outcome_task.await
1061                                    {
1062                                        response
1063                                            .send(option_id)
1064                                            .map(|_| anyhow!("authorization receiver was dropped"))
1065                                            .log_err();
1066                                    }
1067                                })
1068                                .detach();
1069                            }
1070                            ThreadEvent::ToolCall(tool_call) => {
1071                                acp_thread.update(cx, |thread, cx| {
1072                                    thread.upsert_tool_call(tool_call, cx)
1073                                })??;
1074                            }
1075                            ThreadEvent::ToolCallUpdate(update) => {
1076                                acp_thread.update(cx, |thread, cx| {
1077                                    thread.update_tool_call(update, cx)
1078                                })??;
1079                            }
1080                            ThreadEvent::SubagentSpawned(session_id) => {
1081                                acp_thread.update(cx, |thread, cx| {
1082                                    thread.subagent_spawned(session_id, cx);
1083                                })?;
1084                            }
1085                            ThreadEvent::Retry(status) => {
1086                                acp_thread.update(cx, |thread, cx| {
1087                                    thread.update_retry_status(status, cx)
1088                                })?;
1089                            }
1090                            ThreadEvent::Stop(stop_reason) => {
1091                                log::debug!("Assistant message complete: {:?}", stop_reason);
1092                                return Ok(acp::PromptResponse::new(stop_reason));
1093                            }
1094                        }
1095                    }
1096                    Err(e) => {
1097                        log::error!("Error in model response stream: {:?}", e);
1098                        return Err(e);
1099                    }
1100                }
1101            }
1102
1103            log::debug!("Response stream completed");
1104            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1105        })
1106    }
1107}
1108
1109struct Command<'a> {
1110    prompt_name: &'a str,
1111    arg_value: &'a str,
1112    explicit_server_id: Option<&'a str>,
1113}
1114
1115impl<'a> Command<'a> {
1116    fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1117        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1118            return None;
1119        };
1120        let text = text_content.text.trim();
1121        let command = text.strip_prefix('/')?;
1122        let (command, arg_value) = command
1123            .split_once(char::is_whitespace)
1124            .unwrap_or((command, ""));
1125
1126        if let Some((server_id, prompt_name)) = command.split_once('.') {
1127            Some(Self {
1128                prompt_name,
1129                arg_value,
1130                explicit_server_id: Some(server_id),
1131            })
1132        } else {
1133            Some(Self {
1134                prompt_name: command,
1135                arg_value,
1136                explicit_server_id: None,
1137            })
1138        }
1139    }
1140}
1141
1142struct NativeAgentModelSelector {
1143    session_id: acp::SessionId,
1144    connection: NativeAgentConnection,
1145}
1146
1147impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1148    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1149        log::debug!("NativeAgentConnection::list_models called");
1150        let list = self.connection.0.read(cx).models.model_list.clone();
1151        Task::ready(if list.is_empty() {
1152            Err(anyhow::anyhow!("No models available"))
1153        } else {
1154            Ok(list)
1155        })
1156    }
1157
1158    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1159        log::debug!(
1160            "Setting model for session {}: {}",
1161            self.session_id,
1162            model_id
1163        );
1164        let Some(thread) = self
1165            .connection
1166            .0
1167            .read(cx)
1168            .sessions
1169            .get(&self.session_id)
1170            .map(|session| session.thread.clone())
1171        else {
1172            return Task::ready(Err(anyhow!("Session not found")));
1173        };
1174
1175        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1176            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1177        };
1178
1179        // We want to reset the effort level when switching models, as the currently-selected effort level may
1180        // not be compatible.
1181        let effort = model
1182            .default_effort_level()
1183            .map(|effort_level| effort_level.value.to_string());
1184
1185        thread.update(cx, |thread, cx| {
1186            thread.set_model(model.clone(), cx);
1187            thread.set_thinking_effort(effort.clone(), cx);
1188            thread.set_thinking_enabled(model.supports_thinking(), cx);
1189        });
1190
1191        update_settings_file(
1192            self.connection.0.read(cx).fs.clone(),
1193            cx,
1194            move |settings, cx| {
1195                let provider = model.provider_id().0.to_string();
1196                let model = model.id().0.to_string();
1197                let enable_thinking = thread.read(cx).thinking_enabled();
1198                settings
1199                    .agent
1200                    .get_or_insert_default()
1201                    .set_model(LanguageModelSelection {
1202                        provider: provider.into(),
1203                        model,
1204                        enable_thinking,
1205                        effort,
1206                    });
1207            },
1208        );
1209
1210        Task::ready(Ok(()))
1211    }
1212
1213    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1214        let Some(thread) = self
1215            .connection
1216            .0
1217            .read(cx)
1218            .sessions
1219            .get(&self.session_id)
1220            .map(|session| session.thread.clone())
1221        else {
1222            return Task::ready(Err(anyhow!("Session not found")));
1223        };
1224        let Some(model) = thread.read(cx).model() else {
1225            return Task::ready(Err(anyhow!("Model not found")));
1226        };
1227        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1228        else {
1229            return Task::ready(Err(anyhow!("Provider not found")));
1230        };
1231        Task::ready(Ok(LanguageModels::map_language_model_to_info(
1232            model, &provider,
1233        )))
1234    }
1235
1236    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1237        Some(self.connection.0.read(cx).models.watch())
1238    }
1239
1240    fn should_render_footer(&self) -> bool {
1241        true
1242    }
1243}
1244
1245impl acp_thread::AgentConnection for NativeAgentConnection {
1246    fn telemetry_id(&self) -> SharedString {
1247        "zed".into()
1248    }
1249
1250    fn new_session(
1251        self: Rc<Self>,
1252        project: Entity<Project>,
1253        cwd: &Path,
1254        cx: &mut App,
1255    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1256        log::debug!("Creating new thread for project at: {cwd:?}");
1257        Task::ready(Ok(self
1258            .0
1259            .update(cx, |agent, cx| agent.new_session(project, cx))))
1260    }
1261
1262    fn supports_load_session(&self) -> bool {
1263        true
1264    }
1265
1266    fn load_session(
1267        self: Rc<Self>,
1268        session: AgentSessionInfo,
1269        _project: Entity<Project>,
1270        _cwd: &Path,
1271        cx: &mut App,
1272    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1273        self.0
1274            .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1275    }
1276
1277    fn supports_close_session(&self) -> bool {
1278        true
1279    }
1280
1281    fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1282        self.0.update(cx, |agent, _cx| {
1283            agent.sessions.remove(session_id);
1284        });
1285        Task::ready(Ok(()))
1286    }
1287
1288    fn auth_methods(&self) -> &[acp::AuthMethod] {
1289        &[] // No auth for in-process
1290    }
1291
1292    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1293        Task::ready(Ok(()))
1294    }
1295
1296    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1297        Some(Rc::new(NativeAgentModelSelector {
1298            session_id: session_id.clone(),
1299            connection: self.clone(),
1300        }) as Rc<dyn AgentModelSelector>)
1301    }
1302
1303    fn prompt(
1304        &self,
1305        id: Option<acp_thread::UserMessageId>,
1306        params: acp::PromptRequest,
1307        cx: &mut App,
1308    ) -> Task<Result<acp::PromptResponse>> {
1309        let id = id.expect("UserMessageId is required");
1310        let session_id = params.session_id.clone();
1311        log::info!("Received prompt request for session: {}", session_id);
1312        log::debug!("Prompt blocks count: {}", params.prompt.len());
1313
1314        if let Some(parsed_command) = Command::parse(&params.prompt) {
1315            let registry = self.0.read(cx).context_server_registry.read(cx);
1316
1317            let explicit_server_id = parsed_command
1318                .explicit_server_id
1319                .map(|server_id| ContextServerId(server_id.into()));
1320
1321            if let Some(prompt) =
1322                registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1323            {
1324                let arguments = if !parsed_command.arg_value.is_empty()
1325                    && let Some(arg_name) = prompt
1326                        .prompt
1327                        .arguments
1328                        .as_ref()
1329                        .and_then(|args| args.first())
1330                        .map(|arg| arg.name.clone())
1331                {
1332                    HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1333                } else {
1334                    Default::default()
1335                };
1336
1337                let prompt_name = prompt.prompt.name.clone();
1338                let server_id = prompt.server_id.clone();
1339
1340                return self.0.update(cx, |agent, cx| {
1341                    agent.send_mcp_prompt(
1342                        id,
1343                        session_id.clone(),
1344                        prompt_name,
1345                        server_id,
1346                        arguments,
1347                        params.prompt,
1348                        cx,
1349                    )
1350                });
1351            };
1352        };
1353
1354        let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1355
1356        self.run_turn(session_id, cx, move |thread, cx| {
1357            let content: Vec<UserMessageContent> = params
1358                .prompt
1359                .into_iter()
1360                .map(|block| UserMessageContent::from_content_block(block, path_style))
1361                .collect::<Vec<_>>();
1362            log::debug!("Converted prompt to message: {} chars", content.len());
1363            log::debug!("Message id: {:?}", id);
1364            log::debug!("Message content: {:?}", content);
1365
1366            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1367        })
1368    }
1369
1370    fn retry(
1371        &self,
1372        session_id: &acp::SessionId,
1373        _cx: &App,
1374    ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1375        Some(Rc::new(NativeAgentSessionRetry {
1376            connection: self.clone(),
1377            session_id: session_id.clone(),
1378        }) as _)
1379    }
1380
1381    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1382        log::info!("Cancelling on session: {}", session_id);
1383        self.0.update(cx, |agent, cx| {
1384            if let Some(session) = agent.sessions.get(session_id) {
1385                session
1386                    .thread
1387                    .update(cx, |thread, cx| thread.cancel(cx))
1388                    .detach();
1389            }
1390        });
1391    }
1392
1393    fn truncate(
1394        &self,
1395        session_id: &agent_client_protocol::SessionId,
1396        cx: &App,
1397    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1398        self.0.read_with(cx, |agent, _cx| {
1399            agent.sessions.get(session_id).map(|session| {
1400                Rc::new(NativeAgentSessionTruncate {
1401                    thread: session.thread.clone(),
1402                    acp_thread: session.acp_thread.downgrade(),
1403                }) as _
1404            })
1405        })
1406    }
1407
1408    fn set_title(
1409        &self,
1410        session_id: &acp::SessionId,
1411        cx: &App,
1412    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1413        self.0.read_with(cx, |agent, _cx| {
1414            agent
1415                .sessions
1416                .get(session_id)
1417                .filter(|s| !s.thread.read(cx).is_subagent())
1418                .map(|session| {
1419                    Rc::new(NativeAgentSessionSetTitle {
1420                        thread: session.thread.clone(),
1421                    }) as _
1422                })
1423        })
1424    }
1425
1426    fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1427        let thread_store = self.0.read(cx).thread_store.clone();
1428        Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1429    }
1430
1431    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1432        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1433    }
1434
1435    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1436        self
1437    }
1438}
1439
1440impl acp_thread::AgentTelemetry for NativeAgentConnection {
1441    fn thread_data(
1442        &self,
1443        session_id: &acp::SessionId,
1444        cx: &mut App,
1445    ) -> Task<Result<serde_json::Value>> {
1446        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1447            return Task::ready(Err(anyhow!("Session not found")));
1448        };
1449
1450        let task = session.thread.read(cx).to_db(cx);
1451        cx.background_spawn(async move {
1452            serde_json::to_value(task.await).context("Failed to serialize thread")
1453        })
1454    }
1455}
1456
1457pub struct NativeAgentSessionList {
1458    thread_store: Entity<ThreadStore>,
1459    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1460    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1461    _subscription: Subscription,
1462}
1463
1464impl NativeAgentSessionList {
1465    fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1466        let (tx, rx) = smol::channel::unbounded();
1467        let this_tx = tx.clone();
1468        let subscription = cx.observe(&thread_store, move |_, _| {
1469            this_tx
1470                .try_send(acp_thread::SessionListUpdate::Refresh)
1471                .ok();
1472        });
1473        Self {
1474            thread_store,
1475            updates_tx: tx,
1476            updates_rx: rx,
1477            _subscription: subscription,
1478        }
1479    }
1480
1481    fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1482        AgentSessionInfo {
1483            session_id: entry.id,
1484            cwd: None,
1485            title: Some(entry.title),
1486            updated_at: Some(entry.updated_at),
1487            meta: None,
1488        }
1489    }
1490
1491    pub fn thread_store(&self) -> &Entity<ThreadStore> {
1492        &self.thread_store
1493    }
1494}
1495
1496impl AgentSessionList for NativeAgentSessionList {
1497    fn list_sessions(
1498        &self,
1499        _request: AgentSessionListRequest,
1500        cx: &mut App,
1501    ) -> Task<Result<AgentSessionListResponse>> {
1502        let sessions = self
1503            .thread_store
1504            .read(cx)
1505            .entries()
1506            .map(Self::to_session_info)
1507            .collect();
1508        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1509    }
1510
1511    fn supports_delete(&self) -> bool {
1512        true
1513    }
1514
1515    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1516        self.thread_store
1517            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1518    }
1519
1520    fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1521        self.thread_store
1522            .update(cx, |store, cx| store.delete_threads(cx))
1523    }
1524
1525    fn watch(
1526        &self,
1527        _cx: &mut App,
1528    ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1529        Some(self.updates_rx.clone())
1530    }
1531
1532    fn notify_refresh(&self) {
1533        self.updates_tx
1534            .try_send(acp_thread::SessionListUpdate::Refresh)
1535            .ok();
1536    }
1537
1538    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1539        self
1540    }
1541}
1542
1543struct NativeAgentSessionTruncate {
1544    thread: Entity<Thread>,
1545    acp_thread: WeakEntity<AcpThread>,
1546}
1547
1548impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1549    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1550        match self.thread.update(cx, |thread, cx| {
1551            thread.truncate(message_id.clone(), cx)?;
1552            Ok(thread.latest_token_usage())
1553        }) {
1554            Ok(usage) => {
1555                self.acp_thread
1556                    .update(cx, |thread, cx| {
1557                        thread.update_token_usage(usage, cx);
1558                    })
1559                    .ok();
1560                Task::ready(Ok(()))
1561            }
1562            Err(error) => Task::ready(Err(error)),
1563        }
1564    }
1565}
1566
1567struct NativeAgentSessionRetry {
1568    connection: NativeAgentConnection,
1569    session_id: acp::SessionId,
1570}
1571
1572impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1573    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1574        self.connection
1575            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1576                thread.update(cx, |thread, cx| thread.resume(cx))
1577            })
1578    }
1579}
1580
1581struct NativeAgentSessionSetTitle {
1582    thread: Entity<Thread>,
1583}
1584
1585impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1586    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1587        self.thread
1588            .update(cx, |thread, cx| thread.set_title(title, cx));
1589        Task::ready(Ok(()))
1590    }
1591}
1592
/// Weak handles used to create and resume subagent sessions on behalf of a
/// parent thread.
pub struct NativeThreadEnvironment {
    /// Agent with which new subagent sessions are registered.
    agent: WeakEntity<NativeAgent>,
    /// Parent thread. Subagents are created from it, and their nesting is
    /// capped by its `depth()`.
    thread: WeakEntity<Thread>,
    /// Presumably the AcpThread paired with `thread`. It is not read in the
    /// methods visible here; confirm against `prompt_subagent` and callers.
    acp_thread: WeakEntity<AcpThread>,
}
1598
1599impl NativeThreadEnvironment {
1600    pub(crate) fn create_subagent_thread(
1601        &self,
1602        label: String,
1603        cx: &mut App,
1604    ) -> Result<Rc<dyn SubagentHandle>> {
1605        let Some(parent_thread_entity) = self.thread.upgrade() else {
1606            anyhow::bail!("Parent thread no longer exists".to_string());
1607        };
1608        let parent_thread = parent_thread_entity.read(cx);
1609        let current_depth = parent_thread.depth();
1610
1611        if current_depth >= MAX_SUBAGENT_DEPTH {
1612            return Err(anyhow!(
1613                "Maximum subagent depth ({}) reached",
1614                MAX_SUBAGENT_DEPTH
1615            ));
1616        }
1617
1618        let subagent_thread: Entity<Thread> = cx.new(|cx| {
1619            let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1620            thread.set_title(label.into(), cx);
1621            thread
1622        });
1623
1624        let session_id = subagent_thread.read(cx).id().clone();
1625
1626        let acp_thread = self.agent.update(cx, |agent, cx| {
1627            agent.register_session(subagent_thread.clone(), cx)
1628        })?;
1629
1630        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1631    }
1632
1633    pub(crate) fn resume_subagent_thread(
1634        &self,
1635        session_id: acp::SessionId,
1636        cx: &mut App,
1637    ) -> Result<Rc<dyn SubagentHandle>> {
1638        let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1639            let session = agent
1640                .sessions
1641                .get(&session_id)
1642                .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1643            anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1644        })??;
1645
1646        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1647    }
1648
1649    fn prompt_subagent(
1650        &self,
1651        session_id: acp::SessionId,
1652        subagent_thread: Entity<Thread>,
1653        acp_thread: Entity<acp_thread::AcpThread>,
1654    ) -> Result<Rc<dyn SubagentHandle>> {
1655        let Some(parent_thread_entity) = self.thread.upgrade() else {
1656            anyhow::bail!("Parent thread no longer exists".to_string());
1657        };
1658        Ok(Rc::new(NativeSubagentHandle::new(
1659            session_id,
1660            subagent_thread,
1661            acp_thread,
1662            parent_thread_entity,
1663        )) as _)
1664    }
1665}
1666
1667impl ThreadEnvironment for NativeThreadEnvironment {
1668    fn create_terminal(
1669        &self,
1670        command: String,
1671        cwd: Option<PathBuf>,
1672        output_byte_limit: Option<u64>,
1673        cx: &mut AsyncApp,
1674    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1675        let task = self.acp_thread.update(cx, |thread, cx| {
1676            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1677        });
1678
1679        let acp_thread = self.acp_thread.clone();
1680        cx.spawn(async move |cx| {
1681            let terminal = task?.await?;
1682
1683            let (drop_tx, drop_rx) = oneshot::channel();
1684            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1685
1686            cx.spawn(async move |cx| {
1687                drop_rx.await.ok();
1688                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1689            })
1690            .detach();
1691
1692            let handle = AcpTerminalHandle {
1693                terminal,
1694                _drop_tx: Some(drop_tx),
1695            };
1696
1697            Ok(Rc::new(handle) as _)
1698        })
1699    }
1700
1701    fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1702        self.create_subagent_thread(label, cx)
1703    }
1704
1705    fn resume_subagent(
1706        &self,
1707        session_id: acp::SessionId,
1708        cx: &mut App,
1709    ) -> Result<Rc<dyn SubagentHandle>> {
1710        self.resume_subagent_thread(session_id, cx)
1711    }
1712}
1713
1714#[derive(Debug, Clone)]
1715enum SubagentPromptResult {
1716    Completed,
1717    Cancelled,
1718    ContextWindowWarning,
1719    Error(String),
1720}
1721
1722pub struct NativeSubagentHandle {
1723    session_id: acp::SessionId,
1724    parent_thread: WeakEntity<Thread>,
1725    subagent_thread: Entity<Thread>,
1726    acp_thread: Entity<acp_thread::AcpThread>,
1727}
1728
1729impl NativeSubagentHandle {
1730    fn new(
1731        session_id: acp::SessionId,
1732        subagent_thread: Entity<Thread>,
1733        acp_thread: Entity<acp_thread::AcpThread>,
1734        parent_thread_entity: Entity<Thread>,
1735    ) -> Self {
1736        NativeSubagentHandle {
1737            session_id,
1738            subagent_thread,
1739            parent_thread: parent_thread_entity.downgrade(),
1740            acp_thread,
1741        }
1742    }
1743}
1744
1745impl SubagentHandle for NativeSubagentHandle {
1746    fn id(&self) -> acp::SessionId {
1747        self.session_id.clone()
1748    }
1749
1750    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
1751        let thread = self.subagent_thread.clone();
1752        let acp_thread = self.acp_thread.clone();
1753        let subagent_session_id = self.session_id.clone();
1754        let parent_thread = self.parent_thread.clone();
1755
1756        cx.spawn(async move |cx| {
1757            let (task, _subscription) = cx.update(|cx| {
1758                let ratio_before_prompt = thread
1759                    .read(cx)
1760                    .latest_token_usage()
1761                    .map(|usage| usage.ratio());
1762
1763                parent_thread
1764                    .update(cx, |parent_thread, _cx| {
1765                        parent_thread.register_running_subagent(thread.downgrade())
1766                    })
1767                    .ok();
1768
1769                let task = acp_thread.update(cx, |acp_thread, cx| {
1770                    acp_thread.send(vec![message.into()], cx)
1771                });
1772
1773                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
1774                let mut token_limit_tx = Some(token_limit_tx);
1775
1776                let subscription = cx.subscribe(
1777                    &thread,
1778                    move |_thread, event: &TokenUsageUpdated, _cx| {
1779                        if let Some(usage) = &event.0 {
1780                            let old_ratio = ratio_before_prompt
1781                                .clone()
1782                                .unwrap_or(TokenUsageRatio::Normal);
1783                            let new_ratio = usage.ratio();
1784                            if old_ratio == TokenUsageRatio::Normal
1785                                && new_ratio == TokenUsageRatio::Warning
1786                            {
1787                                if let Some(tx) = token_limit_tx.take() {
1788                                    tx.send(()).ok();
1789                                }
1790                            }
1791                        }
1792                    },
1793                );
1794
1795                let wait_for_prompt = cx
1796                    .background_spawn(async move {
1797                        futures::select! {
1798                            response = task.fuse() => match response {
1799                                Ok(Some(response)) => {
1800                                    match response.stop_reason {
1801                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
1802                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
1803                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
1804                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
1805                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
1806                                    }
1807                                }
1808                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
1809                                Err(error) => SubagentPromptResult::Error(error.to_string()),
1810                            },
1811                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
1812                        }
1813                    });
1814
1815                (wait_for_prompt, subscription)
1816            });
1817
1818            let result = match task.await {
1819                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
1820                    thread
1821                        .last_message()
1822                        .map(|m| m.to_markdown())
1823                        .context("No response from subagent")
1824                }),
1825                SubagentPromptResult::Cancelled => Err(anyhow!("User cancelled")),
1826                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
1827                SubagentPromptResult::ContextWindowWarning => {
1828                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1829                    Err(anyhow!(
1830                        "The agent is nearing the end of its context window and has been \
1831                         stopped. You can prompt the thread again to have the agent wrap up \
1832                         or hand off its work."
1833                    ))
1834                }
1835            };
1836
1837            parent_thread
1838                .update(cx, |parent_thread, cx| {
1839                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
1840                })
1841                .ok();
1842
1843            result
1844        })
1845    }
1846}
1847
1848pub struct AcpTerminalHandle {
1849    terminal: Entity<acp_thread::Terminal>,
1850    _drop_tx: Option<oneshot::Sender<()>>,
1851}
1852
1853impl TerminalHandle for AcpTerminalHandle {
1854    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1855        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1856    }
1857
1858    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1859        Ok(self
1860            .terminal
1861            .read_with(cx, |term, _cx| term.wait_for_exit()))
1862    }
1863
1864    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1865        Ok(self
1866            .terminal
1867            .read_with(cx, |term, cx| term.current_output(cx)))
1868    }
1869
1870    fn kill(&self, cx: &AsyncApp) -> Result<()> {
1871        cx.update(|cx| {
1872            self.terminal.update(cx, |terminal, cx| {
1873                terminal.kill(cx);
1874            });
1875        });
1876        Ok(())
1877    }
1878
1879    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1880        Ok(self
1881            .terminal
1882            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1883    }
1884}
1885
1886#[cfg(test)]
1887mod internal_tests {
1888    use super::*;
1889    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1890    use fs::FakeFs;
1891    use gpui::TestAppContext;
1892    use indoc::formatdoc;
1893    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1894    use language_model::{LanguageModelProviderId, LanguageModelProviderName};
1895    use serde_json::json;
1896    use settings::SettingsStore;
1897    use util::{path, rel_path::rel_path};
1898
1899    #[gpui::test]
1900    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1901        init_test(cx);
1902        let fs = FakeFs::new(cx.executor());
1903        fs.insert_tree(
1904            "/",
1905            json!({
1906                "a": {}
1907            }),
1908        )
1909        .await;
1910        let project = Project::test(fs.clone(), [], cx).await;
1911        let thread_store = cx.new(|cx| ThreadStore::new(cx));
1912        let agent = NativeAgent::new(
1913            project.clone(),
1914            thread_store,
1915            Templates::new(),
1916            None,
1917            fs.clone(),
1918            &mut cx.to_async(),
1919        )
1920        .await
1921        .unwrap();
1922        agent.read_with(cx, |agent, cx| {
1923            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1924        });
1925
1926        let worktree = project
1927            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1928            .await
1929            .unwrap();
1930        cx.run_until_parked();
1931        agent.read_with(cx, |agent, cx| {
1932            assert_eq!(
1933                agent.project_context.read(cx).worktrees,
1934                vec![WorktreeContext {
1935                    root_name: "a".into(),
1936                    abs_path: Path::new("/a").into(),
1937                    rules_file: None
1938                }]
1939            )
1940        });
1941
1942        // Creating `/a/.rules` updates the project context.
1943        fs.insert_file("/a/.rules", Vec::new()).await;
1944        cx.run_until_parked();
1945        agent.read_with(cx, |agent, cx| {
1946            let rules_entry = worktree
1947                .read(cx)
1948                .entry_for_path(rel_path(".rules"))
1949                .unwrap();
1950            assert_eq!(
1951                agent.project_context.read(cx).worktrees,
1952                vec![WorktreeContext {
1953                    root_name: "a".into(),
1954                    abs_path: Path::new("/a").into(),
1955                    rules_file: Some(RulesFileContext {
1956                        path_in_worktree: rel_path(".rules").into(),
1957                        text: "".into(),
1958                        project_entry_id: rules_entry.id.to_usize()
1959                    })
1960                }]
1961            )
1962        });
1963    }
1964
1965    #[gpui::test]
1966    async fn test_listing_models(cx: &mut TestAppContext) {
1967        init_test(cx);
1968        let fs = FakeFs::new(cx.executor());
1969        fs.insert_tree("/", json!({ "a": {}  })).await;
1970        let project = Project::test(fs.clone(), [], cx).await;
1971        let thread_store = cx.new(|cx| ThreadStore::new(cx));
1972        let connection = NativeAgentConnection(
1973            NativeAgent::new(
1974                project.clone(),
1975                thread_store,
1976                Templates::new(),
1977                None,
1978                fs.clone(),
1979                &mut cx.to_async(),
1980            )
1981            .await
1982            .unwrap(),
1983        );
1984
1985        // Create a thread/session
1986        let acp_thread = cx
1987            .update(|cx| {
1988                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
1989            })
1990            .await
1991            .unwrap();
1992
1993        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1994
1995        let models = cx
1996            .update(|cx| {
1997                connection
1998                    .model_selector(&session_id)
1999                    .unwrap()
2000                    .list_models(cx)
2001            })
2002            .await
2003            .unwrap();
2004
2005        let acp_thread::AgentModelList::Grouped(models) = models else {
2006            panic!("Unexpected model group");
2007        };
2008        assert_eq!(
2009            models,
2010            IndexMap::from_iter([(
2011                AgentModelGroupName("Fake".into()),
2012                vec![AgentModelInfo {
2013                    id: acp::ModelId::new("fake/fake"),
2014                    name: "Fake".into(),
2015                    description: None,
2016                    icon: Some(acp_thread::AgentModelIcon::Named(
2017                        ui::IconName::ZedAssistant
2018                    )),
2019                    is_latest: false,
2020                    cost: None,
2021                }]
2022            )])
2023        );
2024    }
2025
2026    #[gpui::test]
2027    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
2028        init_test(cx);
2029        let fs = FakeFs::new(cx.executor());
2030        fs.create_dir(paths::settings_file().parent().unwrap())
2031            .await
2032            .unwrap();
2033        fs.insert_file(
2034            paths::settings_file(),
2035            json!({
2036                "agent": {
2037                    "default_model": {
2038                        "provider": "foo",
2039                        "model": "bar"
2040                    }
2041                }
2042            })
2043            .to_string()
2044            .into_bytes(),
2045        )
2046        .await;
2047        let project = Project::test(fs.clone(), [], cx).await;
2048
2049        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2050
2051        // Create the agent and connection
2052        let agent = NativeAgent::new(
2053            project.clone(),
2054            thread_store,
2055            Templates::new(),
2056            None,
2057            fs.clone(),
2058            &mut cx.to_async(),
2059        )
2060        .await
2061        .unwrap();
2062        let connection = NativeAgentConnection(agent.clone());
2063
2064        // Create a thread/session
2065        let acp_thread = cx
2066            .update(|cx| {
2067                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2068            })
2069            .await
2070            .unwrap();
2071
2072        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2073
2074        // Select a model
2075        let selector = connection.model_selector(&session_id).unwrap();
2076        let model_id = acp::ModelId::new("fake/fake");
2077        cx.update(|cx| selector.select_model(model_id.clone(), cx))
2078            .await
2079            .unwrap();
2080
2081        // Verify the thread has the selected model
2082        agent.read_with(cx, |agent, _| {
2083            let session = agent.sessions.get(&session_id).unwrap();
2084            session.thread.read_with(cx, |thread, _| {
2085                assert_eq!(thread.model().unwrap().id().0, "fake");
2086            });
2087        });
2088
2089        cx.run_until_parked();
2090
2091        // Verify settings file was updated
2092        let settings_content = fs.load(paths::settings_file()).await.unwrap();
2093        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2094
2095        // Check that the agent settings contain the selected model
2096        assert_eq!(
2097            settings_json["agent"]["default_model"]["model"],
2098            json!("fake")
2099        );
2100        assert_eq!(
2101            settings_json["agent"]["default_model"]["provider"],
2102            json!("fake")
2103        );
2104
2105        // Register a thinking model and select it.
2106        cx.update(|cx| {
2107            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2108                "fake-corp",
2109                "fake-thinking",
2110                "Fake Thinking",
2111                true,
2112            ));
2113            let thinking_provider = Arc::new(
2114                FakeLanguageModelProvider::new(
2115                    LanguageModelProviderId::from("fake-corp".to_string()),
2116                    LanguageModelProviderName::from("Fake Corp".to_string()),
2117                )
2118                .with_models(vec![thinking_model]),
2119            );
2120            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2121                registry.register_provider(thinking_provider, cx);
2122            });
2123        });
2124        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2125
2126        let selector = connection.model_selector(&session_id).unwrap();
2127        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2128            .await
2129            .unwrap();
2130        cx.run_until_parked();
2131
2132        // Verify enable_thinking was written to settings as true.
2133        let settings_content = fs.load(paths::settings_file()).await.unwrap();
2134        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2135        assert_eq!(
2136            settings_json["agent"]["default_model"]["enable_thinking"],
2137            json!(true),
2138            "selecting a thinking model should persist enable_thinking: true to settings"
2139        );
2140    }
2141
2142    #[gpui::test]
2143    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2144        init_test(cx);
2145        let fs = FakeFs::new(cx.executor());
2146        fs.create_dir(paths::settings_file().parent().unwrap())
2147            .await
2148            .unwrap();
2149        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2150        let project = Project::test(fs.clone(), [], cx).await;
2151
2152        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2153        let agent = NativeAgent::new(
2154            project.clone(),
2155            thread_store,
2156            Templates::new(),
2157            None,
2158            fs.clone(),
2159            &mut cx.to_async(),
2160        )
2161        .await
2162        .unwrap();
2163        let connection = NativeAgentConnection(agent.clone());
2164
2165        let acp_thread = cx
2166            .update(|cx| {
2167                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2168            })
2169            .await
2170            .unwrap();
2171        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2172
2173        // Register a second provider with a thinking model.
2174        cx.update(|cx| {
2175            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2176                "fake-corp",
2177                "fake-thinking",
2178                "Fake Thinking",
2179                true,
2180            ));
2181            let thinking_provider = Arc::new(
2182                FakeLanguageModelProvider::new(
2183                    LanguageModelProviderId::from("fake-corp".to_string()),
2184                    LanguageModelProviderName::from("Fake Corp".to_string()),
2185                )
2186                .with_models(vec![thinking_model]),
2187            );
2188            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2189                registry.register_provider(thinking_provider, cx);
2190            });
2191        });
2192        // Refresh the agent's model list so it picks up the new provider.
2193        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2194
2195        // Thread starts with thinking_enabled = false (the default).
2196        agent.read_with(cx, |agent, _| {
2197            let session = agent.sessions.get(&session_id).unwrap();
2198            session.thread.read_with(cx, |thread, _| {
2199                assert!(!thread.thinking_enabled(), "thinking defaults to false");
2200            });
2201        });
2202
2203        // Select the thinking model via select_model.
2204        let selector = connection.model_selector(&session_id).unwrap();
2205        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2206            .await
2207            .unwrap();
2208
2209        // select_model should have enabled thinking based on the model's supports_thinking().
2210        agent.read_with(cx, |agent, _| {
2211            let session = agent.sessions.get(&session_id).unwrap();
2212            session.thread.read_with(cx, |thread, _| {
2213                assert!(
2214                    thread.thinking_enabled(),
2215                    "select_model should enable thinking when model supports it"
2216                );
2217            });
2218        });
2219
2220        // Switch back to the non-thinking model.
2221        let selector = connection.model_selector(&session_id).unwrap();
2222        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2223            .await
2224            .unwrap();
2225
2226        // select_model should have disabled thinking.
2227        agent.read_with(cx, |agent, _| {
2228            let session = agent.sessions.get(&session_id).unwrap();
2229            session.thread.read_with(cx, |thread, _| {
2230                assert!(
2231                    !thread.thinking_enabled(),
2232                    "select_model should disable thinking when model does not support it"
2233                );
2234            });
2235        });
2236    }
2237
2238    #[gpui::test]
2239    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
2240        init_test(cx);
2241        let fs = FakeFs::new(cx.executor());
2242        fs.insert_tree("/", json!({ "a": {} })).await;
2243        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2244        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2245        let agent = NativeAgent::new(
2246            project.clone(),
2247            thread_store.clone(),
2248            Templates::new(),
2249            None,
2250            fs.clone(),
2251            &mut cx.to_async(),
2252        )
2253        .await
2254        .unwrap();
2255        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2256
2257        // Register a thinking model.
2258        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2259            "fake-corp",
2260            "fake-thinking",
2261            "Fake Thinking",
2262            true,
2263        ));
2264        let thinking_provider = Arc::new(
2265            FakeLanguageModelProvider::new(
2266                LanguageModelProviderId::from("fake-corp".to_string()),
2267                LanguageModelProviderName::from("Fake Corp".to_string()),
2268            )
2269            .with_models(vec![thinking_model.clone()]),
2270        );
2271        cx.update(|cx| {
2272            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2273                registry.register_provider(thinking_provider, cx);
2274            });
2275        });
2276        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2277
2278        // Create a thread and select the thinking model.
2279        let acp_thread = cx
2280            .update(|cx| {
2281                connection
2282                    .clone()
2283                    .new_session(project.clone(), Path::new("/a"), cx)
2284            })
2285            .await
2286            .unwrap();
2287        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2288
2289        let selector = connection.model_selector(&session_id).unwrap();
2290        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2291            .await
2292            .unwrap();
2293
2294        // Verify thinking is enabled after selecting the thinking model.
2295        let thread = agent.read_with(cx, |agent, _| {
2296            agent.sessions.get(&session_id).unwrap().thread.clone()
2297        });
2298        thread.read_with(cx, |thread, _| {
2299            assert!(
2300                thread.thinking_enabled(),
2301                "thinking should be enabled after selecting thinking model"
2302            );
2303        });
2304
2305        // Send a message so the thread gets persisted.
2306        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2307        let send = cx.foreground_executor().spawn(send);
2308        cx.run_until_parked();
2309
2310        thinking_model.send_last_completion_stream_text_chunk("Response.");
2311        thinking_model.end_last_completion_stream();
2312
2313        send.await.unwrap();
2314        cx.run_until_parked();
2315
2316        // Close the session so it can be reloaded from disk.
2317        cx.update(|cx| connection.clone().close_session(&session_id, cx))
2318            .await
2319            .unwrap();
2320        drop(thread);
2321        drop(acp_thread);
2322        agent.read_with(cx, |agent, _| {
2323            assert!(agent.sessions.is_empty());
2324        });
2325
2326        // Reload the thread and verify thinking_enabled is still true.
2327        let reloaded_acp_thread = agent
2328            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2329            .await
2330            .unwrap();
2331        let reloaded_thread = agent.read_with(cx, |agent, _| {
2332            agent.sessions.get(&session_id).unwrap().thread.clone()
2333        });
2334        reloaded_thread.read_with(cx, |thread, _| {
2335            assert!(
2336                thread.thinking_enabled(),
2337                "thinking_enabled should be preserved when reloading a thread with a thinking model"
2338            );
2339        });
2340
2341        drop(reloaded_acp_thread);
2342    }
2343
2344    #[gpui::test]
2345    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
2346        init_test(cx);
2347        let fs = FakeFs::new(cx.executor());
2348        fs.insert_tree("/", json!({ "a": {} })).await;
2349        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2350        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2351        let agent = NativeAgent::new(
2352            project.clone(),
2353            thread_store.clone(),
2354            Templates::new(),
2355            None,
2356            fs.clone(),
2357            &mut cx.to_async(),
2358        )
2359        .await
2360        .unwrap();
2361        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2362
2363        // Register a model where id() != name(), like real Anthropic models
2364        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
2365        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2366            "fake-corp",
2367            "custom-model-id",
2368            "Custom Model Display Name",
2369            false,
2370        ));
2371        let provider = Arc::new(
2372            FakeLanguageModelProvider::new(
2373                LanguageModelProviderId::from("fake-corp".to_string()),
2374                LanguageModelProviderName::from("Fake Corp".to_string()),
2375            )
2376            .with_models(vec![model.clone()]),
2377        );
2378        cx.update(|cx| {
2379            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2380                registry.register_provider(provider, cx);
2381            });
2382        });
2383        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2384
2385        // Create a thread and select the model.
2386        let acp_thread = cx
2387            .update(|cx| {
2388                connection
2389                    .clone()
2390                    .new_session(project.clone(), Path::new("/a"), cx)
2391            })
2392            .await
2393            .unwrap();
2394        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2395
2396        let selector = connection.model_selector(&session_id).unwrap();
2397        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
2398            .await
2399            .unwrap();
2400
2401        let thread = agent.read_with(cx, |agent, _| {
2402            agent.sessions.get(&session_id).unwrap().thread.clone()
2403        });
2404        thread.read_with(cx, |thread, _| {
2405            assert_eq!(
2406                thread.model().unwrap().id().0.as_ref(),
2407                "custom-model-id",
2408                "model should be set before persisting"
2409            );
2410        });
2411
2412        // Send a message so the thread gets persisted.
2413        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2414        let send = cx.foreground_executor().spawn(send);
2415        cx.run_until_parked();
2416
2417        model.send_last_completion_stream_text_chunk("Response.");
2418        model.end_last_completion_stream();
2419
2420        send.await.unwrap();
2421        cx.run_until_parked();
2422
2423        // Close the session so it can be reloaded from disk.
2424        cx.update(|cx| connection.clone().close_session(&session_id, cx))
2425            .await
2426            .unwrap();
2427        drop(thread);
2428        drop(acp_thread);
2429        agent.read_with(cx, |agent, _| {
2430            assert!(agent.sessions.is_empty());
2431        });
2432
2433        // Reload the thread and verify the model was preserved.
2434        let reloaded_acp_thread = agent
2435            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2436            .await
2437            .unwrap();
2438        let reloaded_thread = agent.read_with(cx, |agent, _| {
2439            agent.sessions.get(&session_id).unwrap().thread.clone()
2440        });
2441        reloaded_thread.read_with(cx, |thread, _| {
2442            let reloaded_model = thread
2443                .model()
2444                .expect("model should be present after reload");
2445            assert_eq!(
2446                reloaded_model.id().0.as_ref(),
2447                "custom-model-id",
2448                "reloaded thread should have the same model, not fall back to the default"
2449            );
2450        });
2451
2452        drop(reloaded_acp_thread);
2453    }
2454
    /// Round-trips a native agent thread through the thread store.
    ///
    /// Checks that:
    /// - an empty thread is not persisted, even after its models are changed;
    /// - sending a message that contains a file mention renders the expected markdown;
    /// - the summarization model's output becomes the stored entry's title;
    /// - reopening the closed session from the store restores the same markdown.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        // Look up the native thread backing this ACP session so models can be set on it directly.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a user message that mixes plain text with a resource link to `/a/b.md`.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Stream and finish the assistant reply, then let the executor settle before
        // responding through the summarization model.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        // This text is asserted below as the title of the persisted thread-store entry.
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        // The mention is expected to render as a `[@b.md](uri)` link in the transcript.
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        // After closing, the agent should hold no live sessions.
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        // Reopening by session id should reproduce the same conversation markdown.
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }
2593
2594    fn thread_entries(
2595        thread_store: &Entity<ThreadStore>,
2596        cx: &mut TestAppContext,
2597    ) -> Vec<(acp::SessionId, String)> {
2598        thread_store.read_with(cx, |store, _| {
2599            store
2600                .entries()
2601                .map(|entry| (entry.id.clone(), entry.title.to_string()))
2602                .collect::<Vec<_>>()
2603        })
2604    }
2605
2606    fn init_test(cx: &mut TestAppContext) {
2607        env_logger::try_init().ok();
2608        cx.update(|cx| {
2609            let settings_store = SettingsStore::test(cx);
2610            cx.set_global(settings_store);
2611
2612            LanguageModelRegistry::test(cx);
2613        });
2614    }
2615}
2616
2617fn mcp_message_content_to_acp_content_block(
2618    content: context_server::types::MessageContent,
2619) -> acp::ContentBlock {
2620    match content {
2621        context_server::types::MessageContent::Text {
2622            text,
2623            annotations: _,
2624        } => text.into(),
2625        context_server::types::MessageContent::Image {
2626            data,
2627            mime_type,
2628            annotations: _,
2629        } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2630        context_server::types::MessageContent::Audio {
2631            data,
2632            mime_type,
2633            annotations: _,
2634        } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2635        context_server::types::MessageContent::Resource {
2636            resource,
2637            annotations: _,
2638        } => {
2639            let mut link =
2640                acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2641            if let Some(mime_type) = resource.mime_type {
2642                link = link.mime_type(mime_type);
2643            }
2644            acp::ContentBlock::ResourceLink(link)
2645        }
2646    }
2647}