agent.rs

   1mod db;
   2mod edit_agent;
   3mod legacy_thread;
   4mod native_agent_server;
   5pub mod outline;
   6mod templates;
   7#[cfg(test)]
   8mod tests;
   9mod thread;
  10mod thread_store;
  11mod tool_permissions;
  12mod tools;
  13
  14use context_server::ContextServerId;
  15pub use db::*;
  16pub use native_agent_server::NativeAgentServer;
  17pub use templates::*;
  18pub use thread::*;
  19pub use thread_store::*;
  20pub use tool_permissions::*;
  21pub use tools::*;
  22
  23use acp_thread::{
  24    AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
  25    AgentSessionListResponse, UserMessageId,
  26};
  27use agent_client_protocol as acp;
  28use anyhow::{Context as _, Result, anyhow};
  29use chrono::{DateTime, Utc};
  30use collections::{HashMap, HashSet, IndexMap};
  31use fs::Fs;
  32use futures::channel::{mpsc, oneshot};
  33use futures::future::Shared;
  34use futures::{StreamExt, future};
  35use gpui::{
  36    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  37};
  38use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  39use project::{Project, ProjectItem, ProjectPath, Worktree};
  40use prompt_store::{
  41    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  42    WorktreeContext,
  43};
  44use serde::{Deserialize, Serialize};
  45use settings::{LanguageModelSelection, update_settings_file};
  46use std::any::Any;
  47use std::path::{Path, PathBuf};
  48use std::rc::Rc;
  49use std::sync::Arc;
  50use util::ResultExt;
  51use util::rel_path::RelPath;
  52
/// Serializable snapshot of the project's worktrees, tagged with a UTC timestamp.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One telemetry snapshot per worktree (type from `project::telemetry_snapshot`).
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// UTC timestamp associated with the snapshot — presumably its capture time; confirm at
    /// the construction site.
    pub timestamp: DateTime<Utc>,
}
  58
  59pub struct RulesLoadingError {
  60    pub message: SharedString,
  61}
  62
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    ///
    /// Held weakly: when the `AcpThread` is released, `NativeAgent::register_session`'s
    /// `observe_release` callback removes this whole session from the agent.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Most recent save started by `NativeAgent::save_thread`. Each new save overwrites this
    /// field, dropping the previous task.
    pending_save: Task<()>,
    /// Subscriptions created in `NativeAgent::register_session`; never read, only held so they
    /// stay registered for as long as the session exists.
    _subscriptions: Vec<Subscription>,
}
  72
/// Cache of the language models offered by the visible, authenticated providers.
///
/// Rebuilt by `refresh_list`; consumers can watch for rebuilds via `watch()`.
pub struct LanguageModels {
    /// Models keyed by their `"<provider_id>/<model_id>"` identifier (see `model_id`).
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached, grouped list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver cloned out by `watch()`; signalled after every refresh.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender signalled at the end of every `refresh_list` call.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating every visible provider; held for the cache's lifetime
    /// (presumably so it is not cancelled early).
    _authenticate_all_providers_task: Task<()>,
}
  82
  83impl LanguageModels {
  84    fn new(cx: &mut App) -> Self {
  85        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  86
  87        let mut this = Self {
  88            models: HashMap::default(),
  89            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  90            refresh_models_rx,
  91            refresh_models_tx,
  92            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
  93        };
  94        this.refresh_list(cx);
  95        this
  96    }
  97
  98    fn refresh_list(&mut self, cx: &App) {
  99        let providers = LanguageModelRegistry::global(cx)
 100            .read(cx)
 101            .visible_providers()
 102            .into_iter()
 103            .filter(|provider| provider.is_authenticated(cx))
 104            .collect::<Vec<_>>();
 105
 106        let mut language_model_list = IndexMap::default();
 107        let mut recommended_models = HashSet::default();
 108
 109        let mut recommended = Vec::new();
 110        for provider in &providers {
 111            for model in provider.recommended_models(cx) {
 112                recommended_models.insert((model.provider_id(), model.id()));
 113                recommended.push(Self::map_language_model_to_info(&model, provider));
 114            }
 115        }
 116        if !recommended.is_empty() {
 117            language_model_list.insert(
 118                acp_thread::AgentModelGroupName("Recommended".into()),
 119                recommended,
 120            );
 121        }
 122
 123        let mut models = HashMap::default();
 124        for provider in providers {
 125            let mut provider_models = Vec::new();
 126            for model in provider.provided_models(cx) {
 127                let model_info = Self::map_language_model_to_info(&model, &provider);
 128                let model_id = model_info.id.clone();
 129                provider_models.push(model_info);
 130                models.insert(model_id, model);
 131            }
 132            if !provider_models.is_empty() {
 133                language_model_list.insert(
 134                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 135                    provider_models,
 136                );
 137            }
 138        }
 139
 140        self.models = models;
 141        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 142        self.refresh_models_tx.send(()).ok();
 143    }
 144
 145    fn watch(&self) -> watch::Receiver<()> {
 146        self.refresh_models_rx.clone()
 147    }
 148
 149    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 150        self.models.get(model_id).cloned()
 151    }
 152
 153    fn map_language_model_to_info(
 154        model: &Arc<dyn LanguageModel>,
 155        provider: &Arc<dyn LanguageModelProvider>,
 156    ) -> acp_thread::AgentModelInfo {
 157        acp_thread::AgentModelInfo {
 158            id: Self::model_id(model),
 159            name: model.name().0,
 160            description: None,
 161            icon: Some(match provider.icon() {
 162                IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
 163                IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
 164            }),
 165        }
 166    }
 167
 168    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 169        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 170    }
 171
 172    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 173        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 174            .read(cx)
 175            .visible_providers()
 176            .iter()
 177            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 178            .collect::<Vec<_>>();
 179
 180        cx.background_spawn(async move {
 181            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 182                if let Err(err) = authenticate_task.await {
 183                    match err {
 184                        language_model::AuthenticateError::CredentialsNotFound => {
 185                            // Since we're authenticating these providers in the
 186                            // background for the purposes of populating the
 187                            // language selector, we don't care about providers
 188                            // where the credentials are not found.
 189                        }
 190                        language_model::AuthenticateError::ConnectionRefused => {
 191                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 192                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 193                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 194                        }
 195                        _ => {
 196                            // Some providers have noisy failure states that we
 197                            // don't want to spam the logs with every time the
 198                            // language model selector is initialized.
 199                            //
 200                            // Ideally these should have more clear failure modes
 201                            // that we know are safe to ignore here, like what we do
 202                            // with `CredentialsNotFound` above.
 203                            match provider_id.0.as_ref() {
 204                                "lmstudio" | "ollama" => {
 205                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 206                                    //
 207                                    // These fail noisily, so we don't log them.
 208                                }
 209                                "copilot_chat" => {
 210                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 211                                }
 212                                _ => {
 213                                    log::error!(
 214                                        "Failed to authenticate provider: {}: {err:#}",
 215                                        provider_name.0
 216                                    );
 217                                }
 218                            }
 219                        }
 220                    }
 221                }
 222            }
 223        })
 224    }
 225}
 226
/// The built-in agent.
///
/// Owns one [`Session`] per registered thread, plus the state those threads share: project
/// context, templates, the model cache, and the context server registry.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread store, reloaded after every thread save.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Sending on this wakes `maintain_project_context`, which rebuilds the project context.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context`, held for the agent's lifetime.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context servers; its prompts become the sessions' available commands.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Project whose worktrees feed the project context and whose events trigger refreshes.
    project: Entity<Project>,
    /// Optional source of default user rules for the project context.
    prompt_store: Option<Entity<PromptStore>>,
    /// File system handle (not read anywhere in the visible portion of this file).
    fs: Arc<dyn Fs>,
    /// Event subscriptions created in `new`, held for the agent's lifetime.
    _subscriptions: Vec<Subscription>,
}
 245
 246impl NativeAgent {
    /// Creates the agent entity for `project`.
    ///
    /// Builds the initial project context before creating the entity. Inside the entity it
    /// subscribes to project, language-model-registry, context-server and (if present)
    /// prompt-store events, and starts the background task that rebuilds the project context
    /// on demand.
    pub async fn new(
        project: Entity<Project>,
        thread_store: Entity<ThreadStore>,
        templates: Arc<Templates>,
        prompt_store: Option<Entity<PromptStore>>,
        fs: Arc<dyn Fs>,
        cx: &mut AsyncApp,
    ) -> Result<Entity<NativeAgent>> {
        log::debug!("Creating new NativeAgent");

        // Build the first project context up front so the entity starts out populated.
        let project_context = cx
            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
            .await;

        Ok(cx.new(|cx| {
            let context_server_store = project.read(cx).context_server_store();
            let context_server_registry =
                cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));

            let mut subscriptions = vec![
                cx.subscribe(&project, Self::handle_project_event),
                cx.subscribe(
                    &LanguageModelRegistry::global(cx),
                    Self::handle_models_updated_event,
                ),
                cx.subscribe(
                    &context_server_store,
                    Self::handle_context_server_store_updated,
                ),
                cx.subscribe(
                    &context_server_registry,
                    Self::handle_context_server_registry_event,
                ),
            ];
            if let Some(prompt_store) = prompt_store.as_ref() {
                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
            }

            // The sender stays on the agent; the receiver drives the maintenance task below.
            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
                watch::channel(());
            Self {
                sessions: HashMap::default(),
                thread_store,
                project_context: cx.new(|_| project_context),
                project_context_needs_refresh: project_context_needs_refresh_tx,
                _maintain_project_context: cx.spawn(async move |this, cx| {
                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
                }),
                context_server_registry,
                templates,
                models: LanguageModels::new(cx),
                project,
                prompt_store,
                fs,
                _subscriptions: subscriptions,
            }
        }))
    }
 305
    /// Wraps `thread_handle` in a new [`AcpThread`] and registers both as a session.
    ///
    /// Also installs the thread's default tools (backed by the new `AcpThread`), sets its
    /// summarization model, and subscribes to the thread's title, token-usage and change
    /// notifications. It then pushes the current available commands to all sessions and
    /// returns the new `AcpThread`.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let title = thread.title();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            acp_thread::AcpThread::new(
                title,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            )
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            // The tools' environment holds only a weak handle to the AcpThread.
            thread.add_default_tools(
                Rc::new(AcpThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                }) as _,
                cx,
            )
        });

        let subscriptions = vec![
            // Releasing the AcpThread ends the session.
            cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
                this.sessions.remove(acp_thread.session_id());
            }),
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            // Persist the thread every time it notifies.
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        // `insert` replaces any existing session with the same id, dropping its subscriptions.
        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.downgrade(),
                _subscriptions: subscriptions,
                pending_save: Task::ready(()),
            },
        );

        self.update_available_commands(cx);

        acp_thread
    }
 369
 370    pub fn models(&self) -> &LanguageModels {
 371        &self.models
 372    }
 373
 374    async fn maintain_project_context(
 375        this: WeakEntity<Self>,
 376        mut needs_refresh: watch::Receiver<()>,
 377        cx: &mut AsyncApp,
 378    ) -> Result<()> {
 379        while needs_refresh.changed().await.is_ok() {
 380            let project_context = this
 381                .update(cx, |this, cx| {
 382                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 383                })?
 384                .await;
 385            this.update(cx, |this, cx| {
 386                this.project_context = cx.new(|_| project_context);
 387            })?;
 388        }
 389
 390        Ok(())
 391    }
 392
 393    fn build_project_context(
 394        project: &Entity<Project>,
 395        prompt_store: Option<&Entity<PromptStore>>,
 396        cx: &mut App,
 397    ) -> Task<ProjectContext> {
 398        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 399        let worktree_tasks = worktrees
 400            .into_iter()
 401            .map(|worktree| {
 402                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 403            })
 404            .collect::<Vec<_>>();
 405        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 406            prompt_store.read_with(cx, |prompt_store, cx| {
 407                let prompts = prompt_store.default_prompt_metadata();
 408                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 409                    let contents = prompt_store.load(prompt_metadata.id, cx);
 410                    async move { (contents.await, prompt_metadata) }
 411                });
 412                cx.background_spawn(future::join_all(load_tasks))
 413            })
 414        } else {
 415            Task::ready(vec![])
 416        };
 417
 418        cx.spawn(async move |_cx| {
 419            let (worktrees, default_user_rules) =
 420                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 421
 422            let worktrees = worktrees
 423                .into_iter()
 424                .map(|(worktree, _rules_error)| {
 425                    // TODO: show error message
 426                    // if let Some(rules_error) = rules_error {
 427                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 428                    // }
 429                    worktree
 430                })
 431                .collect::<Vec<_>>();
 432
 433            let default_user_rules = default_user_rules
 434                .into_iter()
 435                .flat_map(|(contents, prompt_metadata)| match contents {
 436                    Ok(contents) => Some(UserRulesContext {
 437                        uuid: prompt_metadata.id.as_user()?,
 438                        title: prompt_metadata.title.map(|title| title.to_string()),
 439                        contents,
 440                    }),
 441                    Err(_err) => {
 442                        // TODO: show error message
 443                        // this.update(cx, |_, cx| {
 444                        //     cx.emit(RulesLoadingError {
 445                        //         message: format!("{err:?}").into(),
 446                        //     });
 447                        // })
 448                        // .ok();
 449                        None
 450                    }
 451                })
 452                .collect::<Vec<_>>();
 453
 454            ProjectContext::new(worktrees, default_user_rules)
 455        })
 456    }
 457
 458    fn load_worktree_info_for_system_prompt(
 459        worktree: Entity<Worktree>,
 460        project: Entity<Project>,
 461        cx: &mut App,
 462    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 463        let tree = worktree.read(cx);
 464        let root_name = tree.root_name_str().into();
 465        let abs_path = tree.abs_path();
 466
 467        let mut context = WorktreeContext {
 468            root_name,
 469            abs_path,
 470            rules_file: None,
 471        };
 472
 473        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 474        let Some(rules_task) = rules_task else {
 475            return Task::ready((context, None));
 476        };
 477
 478        cx.spawn(async move |_| {
 479            let (rules_file, rules_file_error) = match rules_task.await {
 480                Ok(rules_file) => (Some(rules_file), None),
 481                Err(err) => (
 482                    None,
 483                    Some(RulesLoadingError {
 484                        message: format!("{err}").into(),
 485                    }),
 486                ),
 487            };
 488            context.rules_file = rules_file;
 489            (context, rules_file_error)
 490        })
 491    }
 492
 493    fn load_worktree_rules_file(
 494        worktree: Entity<Worktree>,
 495        project: Entity<Project>,
 496        cx: &mut App,
 497    ) -> Option<Task<Result<RulesFileContext>>> {
 498        let worktree = worktree.read(cx);
 499        let worktree_id = worktree.id();
 500        let selected_rules_file = RULES_FILE_NAMES
 501            .into_iter()
 502            .filter_map(|name| {
 503                worktree
 504                    .entry_for_path(RelPath::unix(name).unwrap())
 505                    .filter(|entry| entry.is_file())
 506                    .map(|entry| entry.path.clone())
 507            })
 508            .next();
 509
 510        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 511        // supported. This doesn't seem to occur often in GitHub repositories.
 512        selected_rules_file.map(|path_in_worktree| {
 513            let project_path = ProjectPath {
 514                worktree_id,
 515                path: path_in_worktree.clone(),
 516            };
 517            let buffer_task =
 518                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 519            let rope_task = cx.spawn(async move |cx| {
 520                let buffer = buffer_task.await?;
 521                let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
 522                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 523                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 524                })?;
 525                anyhow::Ok((project_entry_id, rope))
 526            });
 527            // Build a string from the rope on a background thread.
 528            cx.background_spawn(async move {
 529                let (project_entry_id, rope) = rope_task.await?;
 530                anyhow::Ok(RulesFileContext {
 531                    path_in_worktree,
 532                    text: rope.to_string().trim().to_string(),
 533                    project_entry_id: project_entry_id.to_usize(),
 534                })
 535            })
 536        })
 537    }
 538
 539    fn handle_thread_title_updated(
 540        &mut self,
 541        thread: Entity<Thread>,
 542        _: &TitleUpdated,
 543        cx: &mut Context<Self>,
 544    ) {
 545        let session_id = thread.read(cx).id();
 546        let Some(session) = self.sessions.get(session_id) else {
 547            return;
 548        };
 549        let thread = thread.downgrade();
 550        let acp_thread = session.acp_thread.clone();
 551        cx.spawn(async move |_, cx| {
 552            let title = thread.read_with(cx, |thread, _| thread.title())?;
 553            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 554            task.await
 555        })
 556        .detach_and_log_err(cx);
 557    }
 558
 559    fn handle_thread_token_usage_updated(
 560        &mut self,
 561        thread: Entity<Thread>,
 562        usage: &TokenUsageUpdated,
 563        cx: &mut Context<Self>,
 564    ) {
 565        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 566            return;
 567        };
 568        session
 569            .acp_thread
 570            .update(cx, |acp_thread, cx| {
 571                acp_thread.update_token_usage(usage.0.clone(), cx);
 572            })
 573            .ok();
 574    }
 575
 576    fn handle_project_event(
 577        &mut self,
 578        _project: Entity<Project>,
 579        event: &project::Event,
 580        _cx: &mut Context<Self>,
 581    ) {
 582        match event {
 583            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 584                self.project_context_needs_refresh.send(()).ok();
 585            }
 586            project::Event::WorktreeUpdatedEntries(_, items) => {
 587                if items.iter().any(|(path, _, _)| {
 588                    RULES_FILE_NAMES
 589                        .iter()
 590                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 591                }) {
 592                    self.project_context_needs_refresh.send(()).ok();
 593                }
 594            }
 595            _ => {}
 596        }
 597    }
 598
 599    fn handle_prompts_updated_event(
 600        &mut self,
 601        _prompt_store: Entity<PromptStore>,
 602        _event: &prompt_store::PromptsUpdatedEvent,
 603        _cx: &mut Context<Self>,
 604    ) {
 605        self.project_context_needs_refresh.send(()).ok();
 606    }
 607
 608    fn handle_models_updated_event(
 609        &mut self,
 610        _registry: Entity<LanguageModelRegistry>,
 611        _event: &language_model::Event,
 612        cx: &mut Context<Self>,
 613    ) {
 614        self.models.refresh_list(cx);
 615
 616        let registry = LanguageModelRegistry::read_global(cx);
 617        let default_model = registry.default_model().map(|m| m.model);
 618        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 619
 620        for session in self.sessions.values_mut() {
 621            session.thread.update(cx, |thread, cx| {
 622                if thread.model().is_none()
 623                    && let Some(model) = default_model.clone()
 624                {
 625                    thread.set_model(model, cx);
 626                    cx.notify();
 627                }
 628                thread.set_summarization_model(summarization_model.clone(), cx);
 629            });
 630        }
 631    }
 632
 633    fn handle_context_server_store_updated(
 634        &mut self,
 635        _store: Entity<project::context_server_store::ContextServerStore>,
 636        _event: &project::context_server_store::Event,
 637        cx: &mut Context<Self>,
 638    ) {
 639        self.update_available_commands(cx);
 640    }
 641
 642    fn handle_context_server_registry_event(
 643        &mut self,
 644        _registry: Entity<ContextServerRegistry>,
 645        event: &ContextServerRegistryEvent,
 646        cx: &mut Context<Self>,
 647    ) {
 648        match event {
 649            ContextServerRegistryEvent::ToolsChanged => {}
 650            ContextServerRegistryEvent::PromptsChanged => {
 651                self.update_available_commands(cx);
 652            }
 653        }
 654    }
 655
 656    fn update_available_commands(&self, cx: &mut Context<Self>) {
 657        let available_commands = self.build_available_commands(cx);
 658        for session in self.sessions.values() {
 659            if let Some(acp_thread) = session.acp_thread.upgrade() {
 660                acp_thread.update(cx, |thread, cx| {
 661                    thread
 662                        .handle_session_update(
 663                            acp::SessionUpdate::AvailableCommandsUpdate(
 664                                acp::AvailableCommandsUpdate::new(available_commands.clone()),
 665                            ),
 666                            cx,
 667                        )
 668                        .log_err();
 669                });
 670            }
 671        }
 672    }
 673
 674    fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
 675        let registry = self.context_server_registry.read(cx);
 676
 677        let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
 678        for context_server_prompt in registry.prompts() {
 679            *prompt_name_counts
 680                .entry(context_server_prompt.prompt.name.as_str())
 681                .or_insert(0) += 1;
 682        }
 683
 684        registry
 685            .prompts()
 686            .flat_map(|context_server_prompt| {
 687                let prompt = &context_server_prompt.prompt;
 688
 689                let should_prefix = prompt_name_counts
 690                    .get(prompt.name.as_str())
 691                    .copied()
 692                    .unwrap_or(0)
 693                    > 1;
 694
 695                let name = if should_prefix {
 696                    format!("{}.{}", context_server_prompt.server_id, prompt.name)
 697                } else {
 698                    prompt.name.clone()
 699                };
 700
 701                let mut command = acp::AvailableCommand::new(
 702                    name,
 703                    prompt.description.clone().unwrap_or_default(),
 704                );
 705
 706                match prompt.arguments.as_deref() {
 707                    Some([arg]) => {
 708                        let hint = format!("<{}>", arg.name);
 709
 710                        command = command.input(acp::AvailableCommandInput::Unstructured(
 711                            acp::UnstructuredCommandInput::new(hint),
 712                        ));
 713                    }
 714                    Some([]) | None => {}
 715                    Some(_) => {
 716                        // skip >1 argument commands since we don't support them yet
 717                        return None;
 718                    }
 719                }
 720
 721                Some(command)
 722            })
 723            .collect()
 724    }
 725
    /// Loads the thread with `id` from the threads database and creates a [`Thread`] entity
    /// for it.
    ///
    /// The thread is wired to the agent's project, shared project context, context server
    /// registry and templates, and its summarization model is set from the registry. This does
    /// not register a session; `open_thread` does that.
    ///
    /// # Errors
    /// Fails if the database cannot be opened, if loading fails, if no thread with `id`
    /// exists, or if the agent has been released.
    pub fn load_thread(
        &mut self,
        id: acp::SessionId,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Thread>>> {
        let database_future = ThreadsDatabase::connect(cx);
        cx.spawn(async move |this, cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            // `load_thread` yields `None` when no row matches; turn that into an error.
            let db_thread = database
                .load_thread(id.clone())
                .await?
                .with_context(|| format!("no thread found with ID: {id:?}"))?;

            this.update(cx, |this, cx| {
                let summarization_model = LanguageModelRegistry::read_global(cx)
                    .thread_summary_model()
                    .map(|c| c.model);

                cx.new(|cx| {
                    let mut thread = Thread::from_db(
                        id.clone(),
                        db_thread,
                        this.project.clone(),
                        this.project_context.clone(),
                        this.context_server_registry.clone(),
                        this.templates.clone(),
                        cx,
                    );
                    thread.set_summarization_model(summarization_model, cx);
                    thread
                })
            })
        })
    }
 760
    /// Loads the thread with `id`, registers a session for it, and feeds the thread's replayed
    /// events into the new [`AcpThread`] before returning it.
    ///
    /// # Errors
    /// Fails if loading fails, if the agent has been released, or if handling the replayed
    /// events fails.
    pub fn open_thread(
        &mut self,
        id: acp::SessionId,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<AcpThread>>> {
        let task = self.load_thread(id, cx);
        cx.spawn(async move |this, cx| {
            let thread = task.await?;
            let acp_thread =
                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
            let events = thread.update(cx, |thread, cx| thread.replay(cx));
            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
            })
            .await?;
            Ok(acp_thread)
        })
    }
 779
 780    pub fn thread_summary(
 781        &mut self,
 782        id: acp::SessionId,
 783        cx: &mut Context<Self>,
 784    ) -> Task<Result<SharedString>> {
 785        let thread = self.open_thread(id.clone(), cx);
 786        cx.spawn(async move |this, cx| {
 787            let acp_thread = thread.await?;
 788            let result = this
 789                .update(cx, |this, cx| {
 790                    this.sessions
 791                        .get(&id)
 792                        .unwrap()
 793                        .thread
 794                        .update(cx, |thread, cx| thread.summary(cx))
 795                })?
 796                .await
 797                .context("Failed to generate summary")?;
 798            drop(acp_thread);
 799            Ok(result)
 800        })
 801    }
 802
 803    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 804        if thread.read(cx).is_empty() {
 805            return;
 806        }
 807
 808        let database_future = ThreadsDatabase::connect(cx);
 809        let (id, db_thread) =
 810            thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
 811        let Some(session) = self.sessions.get_mut(&id) else {
 812            return;
 813        };
 814        let thread_store = self.thread_store.clone();
 815        session.pending_save = cx.spawn(async move |_, cx| {
 816            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 817                return;
 818            };
 819            let db_thread = db_thread.await;
 820            database.save_thread(id, db_thread).await.log_err();
 821            thread_store.update(cx, |store, cx| store.reload(cx));
 822        });
 823    }
 824
 825    fn send_mcp_prompt(
 826        &self,
 827        message_id: UserMessageId,
 828        session_id: agent_client_protocol::SessionId,
 829        prompt_name: String,
 830        server_id: ContextServerId,
 831        arguments: HashMap<String, String>,
 832        original_content: Vec<acp::ContentBlock>,
 833        cx: &mut Context<Self>,
 834    ) -> Task<Result<acp::PromptResponse>> {
 835        let server_store = self.context_server_registry.read(cx).server_store().clone();
 836        let path_style = self.project.read(cx).path_style(cx);
 837
 838        cx.spawn(async move |this, cx| {
 839            let prompt =
 840                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
 841
 842            let (acp_thread, thread) = this.update(cx, |this, _cx| {
 843                let session = this
 844                    .sessions
 845                    .get(&session_id)
 846                    .context("Failed to get session")?;
 847                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
 848            })??;
 849
 850            let mut last_is_user = true;
 851
 852            thread.update(cx, |thread, cx| {
 853                thread.push_acp_user_block(
 854                    message_id,
 855                    original_content.into_iter().skip(1),
 856                    path_style,
 857                    cx,
 858                );
 859            });
 860
 861            for message in prompt.messages {
 862                let context_server::types::PromptMessage { role, content } = message;
 863                let block = mcp_message_content_to_acp_content_block(content);
 864
 865                match role {
 866                    context_server::types::Role::User => {
 867                        let id = acp_thread::UserMessageId::new();
 868
 869                        acp_thread.update(cx, |acp_thread, cx| {
 870                            acp_thread.push_user_content_block_with_indent(
 871                                Some(id.clone()),
 872                                block.clone(),
 873                                true,
 874                                cx,
 875                            );
 876                        })?;
 877
 878                        thread.update(cx, |thread, cx| {
 879                            thread.push_acp_user_block(id, [block], path_style, cx);
 880                        });
 881                    }
 882                    context_server::types::Role::Assistant => {
 883                        acp_thread.update(cx, |acp_thread, cx| {
 884                            acp_thread.push_assistant_content_block_with_indent(
 885                                block.clone(),
 886                                false,
 887                                true,
 888                                cx,
 889                            );
 890                        })?;
 891
 892                        thread.update(cx, |thread, cx| {
 893                            thread.push_acp_agent_block(block, cx);
 894                        });
 895                    }
 896                }
 897
 898                last_is_user = role == context_server::types::Role::User;
 899            }
 900
 901            let response_stream = thread.update(cx, |thread, cx| {
 902                if last_is_user {
 903                    thread.send_existing(cx)
 904                } else {
 905                    // Resume if MCP prompt did not end with a user message
 906                    thread.resume(cx)
 907                }
 908            })?;
 909
 910            cx.update(|cx| {
 911                NativeAgentConnection::handle_thread_events(response_stream, acp_thread, cx)
 912            })
 913            .await
 914        })
 915    }
 916}
 917
 918/// Wrapper struct that implements the AgentConnection trait
 919#[derive(Clone)]
 920pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 921
 922impl NativeAgentConnection {
 923    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 924        self.0
 925            .read(cx)
 926            .sessions
 927            .get(session_id)
 928            .map(|session| session.thread.clone())
 929    }
 930
 931    pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
 932        self.0.update(cx, |this, cx| this.load_thread(id, cx))
 933    }
 934
 935    fn run_turn(
 936        &self,
 937        session_id: acp::SessionId,
 938        cx: &mut App,
 939        f: impl 'static
 940        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
 941    ) -> Task<Result<acp::PromptResponse>> {
 942        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 943            agent
 944                .sessions
 945                .get_mut(&session_id)
 946                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
 947        }) else {
 948            return Task::ready(Err(anyhow!("Session not found")));
 949        };
 950        log::debug!("Found session for: {}", session_id);
 951
 952        let response_stream = match f(thread, cx) {
 953            Ok(stream) => stream,
 954            Err(err) => return Task::ready(Err(err)),
 955        };
 956        Self::handle_thread_events(response_stream, acp_thread, cx)
 957    }
 958
 959    fn handle_thread_events(
 960        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
 961        acp_thread: WeakEntity<AcpThread>,
 962        cx: &App,
 963    ) -> Task<Result<acp::PromptResponse>> {
 964        cx.spawn(async move |cx| {
 965            // Handle response stream and forward to session.acp_thread
 966            while let Some(result) = events.next().await {
 967                match result {
 968                    Ok(event) => {
 969                        log::trace!("Received completion event: {:?}", event);
 970
 971                        match event {
 972                            ThreadEvent::UserMessage(message) => {
 973                                acp_thread.update(cx, |thread, cx| {
 974                                    for content in message.content {
 975                                        thread.push_user_content_block(
 976                                            Some(message.id.clone()),
 977                                            content.into(),
 978                                            cx,
 979                                        );
 980                                    }
 981                                })?;
 982                            }
 983                            ThreadEvent::AgentText(text) => {
 984                                acp_thread.update(cx, |thread, cx| {
 985                                    thread.push_assistant_content_block(text.into(), false, cx)
 986                                })?;
 987                            }
 988                            ThreadEvent::AgentThinking(text) => {
 989                                acp_thread.update(cx, |thread, cx| {
 990                                    thread.push_assistant_content_block(text.into(), true, cx)
 991                                })?;
 992                            }
 993                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
 994                                tool_call,
 995                                options,
 996                                response,
 997                            }) => {
 998                                let outcome_task = acp_thread.update(cx, |thread, cx| {
 999                                    thread.request_tool_call_authorization(
1000                                        tool_call, options, true, cx,
1001                                    )
1002                                })??;
1003                                cx.background_spawn(async move {
1004                                    if let acp::RequestPermissionOutcome::Selected(
1005                                        acp::SelectedPermissionOutcome { option_id, .. },
1006                                    ) = outcome_task.await
1007                                    {
1008                                        response
1009                                            .send(option_id)
1010                                            .map(|_| anyhow!("authorization receiver was dropped"))
1011                                            .log_err();
1012                                    }
1013                                })
1014                                .detach();
1015                            }
1016                            ThreadEvent::ToolCall(tool_call) => {
1017                                acp_thread.update(cx, |thread, cx| {
1018                                    thread.upsert_tool_call(tool_call, cx)
1019                                })??;
1020                            }
1021                            ThreadEvent::ToolCallUpdate(update) => {
1022                                acp_thread.update(cx, |thread, cx| {
1023                                    thread.update_tool_call(update, cx)
1024                                })??;
1025                            }
1026                            ThreadEvent::Retry(status) => {
1027                                acp_thread.update(cx, |thread, cx| {
1028                                    thread.update_retry_status(status, cx)
1029                                })?;
1030                            }
1031                            ThreadEvent::Stop(stop_reason) => {
1032                                log::debug!("Assistant message complete: {:?}", stop_reason);
1033                                return Ok(acp::PromptResponse::new(stop_reason));
1034                            }
1035                        }
1036                    }
1037                    Err(e) => {
1038                        log::error!("Error in model response stream: {:?}", e);
1039                        return Err(e);
1040                    }
1041                }
1042            }
1043
1044            log::debug!("Response stream completed");
1045            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1046        })
1047    }
1048}
1049
1050struct Command<'a> {
1051    prompt_name: &'a str,
1052    arg_value: &'a str,
1053    explicit_server_id: Option<&'a str>,
1054}
1055
1056impl<'a> Command<'a> {
1057    fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1058        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1059            return None;
1060        };
1061        let text = text_content.text.trim();
1062        let command = text.strip_prefix('/')?;
1063        let (command, arg_value) = command
1064            .split_once(char::is_whitespace)
1065            .unwrap_or((command, ""));
1066
1067        if let Some((server_id, prompt_name)) = command.split_once('.') {
1068            Some(Self {
1069                prompt_name,
1070                arg_value,
1071                explicit_server_id: Some(server_id),
1072            })
1073        } else {
1074            Some(Self {
1075                prompt_name: command,
1076                arg_value,
1077                explicit_server_id: None,
1078            })
1079        }
1080    }
1081}
1082
1083struct NativeAgentModelSelector {
1084    session_id: acp::SessionId,
1085    connection: NativeAgentConnection,
1086}
1087
1088impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1089    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1090        log::debug!("NativeAgentConnection::list_models called");
1091        let list = self.connection.0.read(cx).models.model_list.clone();
1092        Task::ready(if list.is_empty() {
1093            Err(anyhow::anyhow!("No models available"))
1094        } else {
1095            Ok(list)
1096        })
1097    }
1098
1099    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1100        log::debug!(
1101            "Setting model for session {}: {}",
1102            self.session_id,
1103            model_id
1104        );
1105        let Some(thread) = self
1106            .connection
1107            .0
1108            .read(cx)
1109            .sessions
1110            .get(&self.session_id)
1111            .map(|session| session.thread.clone())
1112        else {
1113            return Task::ready(Err(anyhow!("Session not found")));
1114        };
1115
1116        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1117            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1118        };
1119
1120        thread.update(cx, |thread, cx| {
1121            thread.set_model(model.clone(), cx);
1122        });
1123
1124        update_settings_file(
1125            self.connection.0.read(cx).fs.clone(),
1126            cx,
1127            move |settings, _cx| {
1128                let provider = model.provider_id().0.to_string();
1129                let model = model.id().0.to_string();
1130                settings
1131                    .agent
1132                    .get_or_insert_default()
1133                    .set_model(LanguageModelSelection {
1134                        provider: provider.into(),
1135                        model,
1136                    });
1137            },
1138        );
1139
1140        Task::ready(Ok(()))
1141    }
1142
1143    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1144        let Some(thread) = self
1145            .connection
1146            .0
1147            .read(cx)
1148            .sessions
1149            .get(&self.session_id)
1150            .map(|session| session.thread.clone())
1151        else {
1152            return Task::ready(Err(anyhow!("Session not found")));
1153        };
1154        let Some(model) = thread.read(cx).model() else {
1155            return Task::ready(Err(anyhow!("Model not found")));
1156        };
1157        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1158        else {
1159            return Task::ready(Err(anyhow!("Provider not found")));
1160        };
1161        Task::ready(Ok(LanguageModels::map_language_model_to_info(
1162            model, &provider,
1163        )))
1164    }
1165
1166    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1167        Some(self.connection.0.read(cx).models.watch())
1168    }
1169
1170    fn should_render_footer(&self) -> bool {
1171        true
1172    }
1173}
1174
1175impl acp_thread::AgentConnection for NativeAgentConnection {
1176    fn telemetry_id(&self) -> SharedString {
1177        "zed".into()
1178    }
1179
1180    fn new_thread(
1181        self: Rc<Self>,
1182        project: Entity<Project>,
1183        cwd: &Path,
1184        cx: &mut App,
1185    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1186        let agent = self.0.clone();
1187        log::debug!("Creating new thread for project at: {:?}", cwd);
1188
1189        cx.spawn(async move |cx| {
1190            log::debug!("Starting thread creation in async context");
1191
1192            // Create Thread
1193            let thread = agent.update(cx, |agent, cx| {
1194                // Fetch default model from registry settings
1195                let registry = LanguageModelRegistry::read_global(cx);
1196                // Log available models for debugging
1197                let available_count = registry.available_models(cx).count();
1198                log::debug!("Total available models: {}", available_count);
1199
1200                let default_model = registry.default_model().and_then(|default_model| {
1201                    agent
1202                        .models
1203                        .model_from_id(&LanguageModels::model_id(&default_model.model))
1204                });
1205                cx.new(|cx| {
1206                    Thread::new(
1207                        project.clone(),
1208                        agent.project_context.clone(),
1209                        agent.context_server_registry.clone(),
1210                        agent.templates.clone(),
1211                        default_model,
1212                        cx,
1213                    )
1214                })
1215            });
1216            Ok(agent.update(cx, |agent, cx| agent.register_session(thread, cx)))
1217        })
1218    }
1219
1220    fn auth_methods(&self) -> &[acp::AuthMethod] {
1221        &[] // No auth for in-process
1222    }
1223
1224    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1225        Task::ready(Ok(()))
1226    }
1227
1228    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1229        Some(Rc::new(NativeAgentModelSelector {
1230            session_id: session_id.clone(),
1231            connection: self.clone(),
1232        }) as Rc<dyn AgentModelSelector>)
1233    }
1234
1235    fn prompt(
1236        &self,
1237        id: Option<acp_thread::UserMessageId>,
1238        params: acp::PromptRequest,
1239        cx: &mut App,
1240    ) -> Task<Result<acp::PromptResponse>> {
1241        let id = id.expect("UserMessageId is required");
1242        let session_id = params.session_id.clone();
1243        log::info!("Received prompt request for session: {}", session_id);
1244        log::debug!("Prompt blocks count: {}", params.prompt.len());
1245
1246        if let Some(parsed_command) = Command::parse(&params.prompt) {
1247            let registry = self.0.read(cx).context_server_registry.read(cx);
1248
1249            let explicit_server_id = parsed_command
1250                .explicit_server_id
1251                .map(|server_id| ContextServerId(server_id.into()));
1252
1253            if let Some(prompt) =
1254                registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1255            {
1256                let arguments = if !parsed_command.arg_value.is_empty()
1257                    && let Some(arg_name) = prompt
1258                        .prompt
1259                        .arguments
1260                        .as_ref()
1261                        .and_then(|args| args.first())
1262                        .map(|arg| arg.name.clone())
1263                {
1264                    HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1265                } else {
1266                    Default::default()
1267                };
1268
1269                let prompt_name = prompt.prompt.name.clone();
1270                let server_id = prompt.server_id.clone();
1271
1272                return self.0.update(cx, |agent, cx| {
1273                    agent.send_mcp_prompt(
1274                        id,
1275                        session_id.clone(),
1276                        prompt_name,
1277                        server_id,
1278                        arguments,
1279                        params.prompt,
1280                        cx,
1281                    )
1282                });
1283            };
1284        };
1285
1286        let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1287
1288        self.run_turn(session_id, cx, move |thread, cx| {
1289            let content: Vec<UserMessageContent> = params
1290                .prompt
1291                .into_iter()
1292                .map(|block| UserMessageContent::from_content_block(block, path_style))
1293                .collect::<Vec<_>>();
1294            log::debug!("Converted prompt to message: {} chars", content.len());
1295            log::debug!("Message id: {:?}", id);
1296            log::debug!("Message content: {:?}", content);
1297
1298            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1299        })
1300    }
1301
1302    fn resume(
1303        &self,
1304        session_id: &acp::SessionId,
1305        _cx: &App,
1306    ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1307        Some(Rc::new(NativeAgentSessionResume {
1308            connection: self.clone(),
1309            session_id: session_id.clone(),
1310        }) as _)
1311    }
1312
1313    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1314        log::info!("Cancelling on session: {}", session_id);
1315        self.0.update(cx, |agent, cx| {
1316            if let Some(agent) = agent.sessions.get(session_id) {
1317                agent
1318                    .thread
1319                    .update(cx, |thread, cx| thread.cancel(cx))
1320                    .detach();
1321            }
1322        });
1323    }
1324
1325    fn truncate(
1326        &self,
1327        session_id: &agent_client_protocol::SessionId,
1328        cx: &App,
1329    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1330        self.0.read_with(cx, |agent, _cx| {
1331            agent.sessions.get(session_id).map(|session| {
1332                Rc::new(NativeAgentSessionTruncate {
1333                    thread: session.thread.clone(),
1334                    acp_thread: session.acp_thread.clone(),
1335                }) as _
1336            })
1337        })
1338    }
1339
1340    fn set_title(
1341        &self,
1342        session_id: &acp::SessionId,
1343        _cx: &App,
1344    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1345        Some(Rc::new(NativeAgentSessionSetTitle {
1346            connection: self.clone(),
1347            session_id: session_id.clone(),
1348        }) as _)
1349    }
1350
1351    fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1352        let thread_store = self.0.read(cx).thread_store.clone();
1353        Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1354    }
1355
1356    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1357        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1358    }
1359
1360    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1361        self
1362    }
1363}
1364
1365impl acp_thread::AgentTelemetry for NativeAgentConnection {
1366    fn thread_data(
1367        &self,
1368        session_id: &acp::SessionId,
1369        cx: &mut App,
1370    ) -> Task<Result<serde_json::Value>> {
1371        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1372            return Task::ready(Err(anyhow!("Session not found")));
1373        };
1374
1375        let task = session.thread.read(cx).to_db(cx);
1376        cx.background_spawn(async move {
1377            serde_json::to_value(task.await).context("Failed to serialize thread")
1378        })
1379    }
1380}
1381
1382pub struct NativeAgentSessionList {
1383    thread_store: Entity<ThreadStore>,
1384    updates_rx: watch::Receiver<()>,
1385    _subscription: Subscription,
1386}
1387
1388impl NativeAgentSessionList {
1389    fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1390        let (mut tx, rx) = watch::channel(());
1391        let subscription = cx.observe(&thread_store, move |_, _| {
1392            tx.send(()).ok();
1393        });
1394        Self {
1395            thread_store,
1396            updates_rx: rx,
1397            _subscription: subscription,
1398        }
1399    }
1400
1401    fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1402        AgentSessionInfo {
1403            session_id: entry.id,
1404            cwd: None,
1405            title: Some(entry.title),
1406            updated_at: Some(entry.updated_at),
1407            meta: None,
1408        }
1409    }
1410
1411    pub fn thread_store(&self) -> &Entity<ThreadStore> {
1412        &self.thread_store
1413    }
1414}
1415
1416impl AgentSessionList for NativeAgentSessionList {
1417    fn list_sessions(
1418        &self,
1419        _request: AgentSessionListRequest,
1420        cx: &mut App,
1421    ) -> Task<Result<AgentSessionListResponse>> {
1422        let sessions = self
1423            .thread_store
1424            .read(cx)
1425            .entries()
1426            .map(Self::to_session_info)
1427            .collect();
1428        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1429    }
1430
1431    fn supports_delete(&self) -> bool {
1432        true
1433    }
1434
1435    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1436        self.thread_store
1437            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1438    }
1439
1440    fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1441        self.thread_store
1442            .update(cx, |store, cx| store.delete_threads(cx))
1443    }
1444
1445    fn watch(&self, _cx: &mut App) -> Option<watch::Receiver<()>> {
1446        Some(self.updates_rx.clone())
1447    }
1448
1449    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1450        self
1451    }
1452}
1453
1454struct NativeAgentSessionTruncate {
1455    thread: Entity<Thread>,
1456    acp_thread: WeakEntity<AcpThread>,
1457}
1458
1459impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1460    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1461        match self.thread.update(cx, |thread, cx| {
1462            thread.truncate(message_id.clone(), cx)?;
1463            Ok(thread.latest_token_usage())
1464        }) {
1465            Ok(usage) => {
1466                self.acp_thread
1467                    .update(cx, |thread, cx| {
1468                        thread.update_token_usage(usage, cx);
1469                    })
1470                    .ok();
1471                Task::ready(Ok(()))
1472            }
1473            Err(error) => Task::ready(Err(error)),
1474        }
1475    }
1476}
1477
1478struct NativeAgentSessionResume {
1479    connection: NativeAgentConnection,
1480    session_id: acp::SessionId,
1481}
1482
1483impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1484    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1485        self.connection
1486            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1487                thread.update(cx, |thread, cx| thread.resume(cx))
1488            })
1489    }
1490}
1491
1492struct NativeAgentSessionSetTitle {
1493    connection: NativeAgentConnection,
1494    session_id: acp::SessionId,
1495}
1496
1497impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1498    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1499        let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1500            return Task::ready(Err(anyhow!("session not found")));
1501        };
1502        let thread = session.thread.clone();
1503        thread.update(cx, |thread, cx| thread.set_title(title, cx));
1504        Task::ready(Ok(()))
1505    }
1506}
1507
1508pub struct AcpThreadEnvironment {
1509    acp_thread: WeakEntity<AcpThread>,
1510}
1511
1512impl ThreadEnvironment for AcpThreadEnvironment {
1513    fn create_terminal(
1514        &self,
1515        command: String,
1516        cwd: Option<PathBuf>,
1517        output_byte_limit: Option<u64>,
1518        cx: &mut AsyncApp,
1519    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1520        let task = self.acp_thread.update(cx, |thread, cx| {
1521            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1522        });
1523
1524        let acp_thread = self.acp_thread.clone();
1525        cx.spawn(async move |cx| {
1526            let terminal = task?.await?;
1527
1528            let (drop_tx, drop_rx) = oneshot::channel();
1529            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1530
1531            cx.spawn(async move |cx| {
1532                drop_rx.await.ok();
1533                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1534            })
1535            .detach();
1536
1537            let handle = AcpTerminalHandle {
1538                terminal,
1539                _drop_tx: Some(drop_tx),
1540            };
1541
1542            Ok(Rc::new(handle) as _)
1543        })
1544    }
1545}
1546
1547pub struct AcpTerminalHandle {
1548    terminal: Entity<acp_thread::Terminal>,
1549    _drop_tx: Option<oneshot::Sender<()>>,
1550}
1551
1552impl TerminalHandle for AcpTerminalHandle {
1553    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1554        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1555    }
1556
1557    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1558        Ok(self
1559            .terminal
1560            .read_with(cx, |term, _cx| term.wait_for_exit()))
1561    }
1562
1563    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1564        Ok(self
1565            .terminal
1566            .read_with(cx, |term, cx| term.current_output(cx)))
1567    }
1568
1569    fn kill(&self, cx: &AsyncApp) -> Result<()> {
1570        cx.update(|cx| {
1571            self.terminal.update(cx, |terminal, cx| {
1572                terminal.kill(cx);
1573            });
1574        });
1575        Ok(())
1576    }
1577
1578    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1579        Ok(self
1580            .terminal
1581            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1582    }
1583}
1584
// Tests that drive `NativeAgent` / `NativeAgentConnection` directly and inspect
// internal state such as `agent.sessions` and `agent.project_context`.
#[cfg(test)]
mod internal_tests {
    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    // Verifies that the agent's `project_context` follows the project: it starts
    // with no worktrees, then gains an entry when `/a` is added as a worktree, and
    // that entry picks up a `rules_file` once `/a/.rules` is created.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        // The project starts with no worktrees; `/a` is added explicitly below.
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        // Let the agent observe the new worktree before asserting.
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            // The expected `project_entry_id` comes from the worktree's own entry
            // for the new file.
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    // Verifies that a new session's model selector lists models grouped by
    // provider. It expects exactly one "Fake" group containing the `fake/fake`
    // model. That model presumably comes from `LanguageModelRegistry::test`
    // in `init_test` (confirm).
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                thread_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        // The native agent is expected to return a grouped list; any other
        // shape fails the test.
        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(acp_thread::AgentModelIcon::Named(
                        ui::IconName::ZedAssistant
                    )),
                }]
            )])
        );
    }

    // Verifies that selecting a model through the model selector does two things:
    // it switches the session thread's model, and it rewrites
    // `agent.default_model` in the user settings file. The file is seeded with
    // provider "foo" / model "bar" so an overwrite is observable.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let the settings write complete before reading the file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    // Verifies thread persistence end to end:
    // - an empty thread is not saved, even after its models are changed;
    // - after one exchange, the thread is saved with the summary model's output
    //   as its title;
    // - dropping the ACP thread removes the session from the agent;
    // - `open_thread` reloads the thread with identical markdown.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // The user message mentions `/a/b.md` through a file resource link.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // Spawn the send so the test can drive the fake models' completion
        // streams before awaiting it.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        // The summary model's output becomes the saved title (asserted below
        // via `thread_entries`).
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        // The test's own `thread` handle is dropped too, presumably so that no
        // test-held reference keeps the session alive.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    // Returns the `(session id, title)` pairs currently listed by `thread_store`.
    fn thread_entries(
        thread_store: &Entity<ThreadStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(acp::SessionId, String)> {
        thread_store.read_with(cx, |store, _| {
            store
                .entries()
                .map(|entry| (entry.id.clone(), entry.title.to_string()))
                .collect::<Vec<_>>()
        })
    }

    // Shared test setup. It enables logging (an error from an already-initialized
    // logger is ignored), installs a test `SettingsStore` global, and initializes
    // the test `LanguageModelRegistry`.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);

            LanguageModelRegistry::test(cx);
        });
    }
}
1962
1963fn mcp_message_content_to_acp_content_block(
1964    content: context_server::types::MessageContent,
1965) -> acp::ContentBlock {
1966    match content {
1967        context_server::types::MessageContent::Text {
1968            text,
1969            annotations: _,
1970        } => text.into(),
1971        context_server::types::MessageContent::Image {
1972            data,
1973            mime_type,
1974            annotations: _,
1975        } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
1976        context_server::types::MessageContent::Audio {
1977            data,
1978            mime_type,
1979            annotations: _,
1980        } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
1981        context_server::types::MessageContent::Resource {
1982            resource,
1983            annotations: _,
1984        } => {
1985            let mut link =
1986                acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
1987            if let Some(mime_type) = resource.mime_type {
1988                link = link.mime_type(mime_type);
1989            }
1990            acp::ContentBlock::ResourceLink(link)
1991        }
1992    }
1993}