agent.rs

   1mod db;
   2mod edit_agent;
   3mod history_store;
   4mod legacy_thread;
   5mod native_agent_server;
   6pub mod outline;
   7mod templates;
   8mod thread;
   9mod tool_schema;
  10mod tools;
  11
  12#[cfg(test)]
  13mod tests;
  14
  15pub use db::*;
  16pub use history_store::*;
  17pub use native_agent_server::NativeAgentServer;
  18pub use templates::*;
  19pub use thread::*;
  20pub use tools::*;
  21
  22use acp_thread::{AcpThread, AgentModelSelector};
  23use agent_client_protocol as acp;
  24use anyhow::{Context as _, Result, anyhow};
  25use chrono::{DateTime, Utc};
  26use collections::{HashSet, IndexMap};
  27use fs::Fs;
  28use futures::channel::{mpsc, oneshot};
  29use futures::future::Shared;
  30use futures::{StreamExt, future};
  31use gpui::{
  32    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  33};
  34use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  35use project::{Project, ProjectItem, ProjectPath, Worktree};
  36use prompt_store::{
  37    ProjectContext, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
  38};
  39use serde::{Deserialize, Serialize};
  40use settings::{LanguageModelSelection, update_settings_file};
  41use std::any::Any;
  42use std::collections::HashMap;
  43use std::path::{Path, PathBuf};
  44use std::rc::Rc;
  45use std::sync::Arc;
  46use util::ResultExt;
  47use util::rel_path::RelPath;
  48
   /// Serializable (serde) capture of the project's worktrees together with a
   /// UTC timestamp.
   ///
   /// NOTE(review): construction happens outside this view; `timestamp` is
   /// presumably the moment the snapshot was taken — confirm at the call site.
   49#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
   50pub struct ProjectSnapshot {
       /// One entry per captured worktree (path plus optional git state).
   51    pub worktree_snapshots: Vec<WorktreeSnapshot>,
       /// UTC timestamp associated with this snapshot.
   52    pub timestamp: DateTime<Utc>,
   53}
  54
   /// Serializable (serde) description of a single worktree inside a
   /// [`ProjectSnapshot`].
   55#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
   56pub struct WorktreeSnapshot {
       /// The worktree's path, stored as a plain `String` rather than a `PathBuf`.
   57    pub worktree_path: String,
       /// Git metadata for the worktree, or `None` when none was captured
       /// (presumably when the worktree is not a git repository — TODO confirm).
   58    pub git_state: Option<GitState>,
   59}
  60
   /// Serializable (serde) git metadata for a worktree. Every field is optional,
   /// so partially-available repository information can still be recorded.
   61#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
   62pub struct GitState {
       /// URL of the repository remote, if known.
   63    pub remote_url: Option<String>,
       /// SHA of the `HEAD` commit, if known.
   64    pub head_sha: Option<String>,
       /// Name of the checked-out branch, if any.
   65    pub current_branch: Option<String>,
       /// Textual diff of the working tree, if captured
       /// (NOTE(review): presumably relative to `HEAD` — confirm).
   66    pub diff: Option<String>,
   67}
  68
   /// Candidate rules-file paths, relative to a worktree root.
   ///
   /// Order is significant: `NativeAgent::load_worktree_rules_file` picks the
   /// first name that exists as a *file* in the worktree, so earlier entries take
   /// precedence over later ones. The same list is consulted by
   /// `NativeAgent::handle_project_event` to decide whether an updated entry
   /// should trigger a project-context refresh.
   69const RULES_FILE_NAMES: [&str; 9] = [
   70    ".rules",
   71    ".cursorrules",
   72    ".windsurfrules",
   73    ".clinerules",
   74    ".github/copilot-instructions.md",
   75    "CLAUDE.md",
   76    "AGENT.md",
   77    "AGENTS.md",
   78    "GEMINI.md",
   79];
  80
   /// Error produced when a worktree's rules file cannot be loaded.
   ///
   /// NOTE(review): `NativeAgent::build_project_context` currently ignores these
   /// values (it binds them as `_rules_error`), so the message is never shown.
   81pub struct RulesLoadingError {
       /// Human-readable error text: the `Display` form of the underlying error.
   82    pub message: SharedString,
   83}
  84
   85/// Holds both the internal Thread and the AcpThread for a session
   ///
   /// Created by `NativeAgent::register_session` and stored in
   /// `NativeAgent::sessions`, keyed by the thread's session ID.
   86struct Session {
   87    /// The internal thread that processes messages
   88    thread: Entity<Thread>,
   89    /// The ACP thread that handles protocol communication
       ///
       /// Held weakly: the session is removed from the agent when this
       /// `AcpThread` is released (see the `observe_release` subscription).
   90    acp_thread: WeakEntity<acp_thread::AcpThread>,
       /// Most recent background save spawned by `NativeAgent::save_thread`;
       /// assigning a new task drops the previous one.
   91    pending_save: Task<()>,
       /// Subscriptions created in `register_session`: session removal on
       /// `AcpThread` release, title and token-usage forwarding, and saving the
       /// thread whenever it notifies. They live exactly as long as the session.
   92    _subscriptions: Vec<Subscription>,
   93}
  94
   /// Cache of the language models offered by authenticated providers.
   ///
   /// The cache is rebuilt by `refresh_list`, which runs at construction and
   /// whenever the language-model registry emits an event.
   95pub struct LanguageModels {
   96    /// Access language model by ID
       ///
       /// Keys have the form `"{provider_id}/{model_id}"` (see `model_id`).
   97    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
   98    /// Cached list for returning language model information
       ///
       /// Always the `Grouped` variant: an optional "Recommended" group followed
       /// by one group per authenticated provider.
   99    model_list: acp_thread::AgentModelList,
       /// Receiver cloned out by `watch()`; it observes each `refresh_list` run.
  100    refresh_models_rx: watch::Receiver<()>,
       /// Sender signalled at the end of every `refresh_list` run.
  101    refresh_models_tx: watch::Sender<()>,
       /// Background task that authenticates every registered provider; stored
       /// so that it lives as long as this struct.
  102    _authenticate_all_providers_task: Task<()>,
  103}
 104
 105impl LanguageModels {
 106    fn new(cx: &mut App) -> Self {
 107        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
 108
 109        let mut this = Self {
 110            models: HashMap::default(),
 111            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
 112            refresh_models_rx,
 113            refresh_models_tx,
 114            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
 115        };
 116        this.refresh_list(cx);
 117        this
 118    }
 119
 120    fn refresh_list(&mut self, cx: &App) {
 121        let providers = LanguageModelRegistry::global(cx)
 122            .read(cx)
 123            .providers()
 124            .into_iter()
 125            .filter(|provider| provider.is_authenticated(cx))
 126            .collect::<Vec<_>>();
 127
 128        let mut language_model_list = IndexMap::default();
 129        let mut recommended_models = HashSet::default();
 130
 131        let mut recommended = Vec::new();
 132        for provider in &providers {
 133            for model in provider.recommended_models(cx) {
 134                recommended_models.insert((model.provider_id(), model.id()));
 135                recommended.push(Self::map_language_model_to_info(&model, provider));
 136            }
 137        }
 138        if !recommended.is_empty() {
 139            language_model_list.insert(
 140                acp_thread::AgentModelGroupName("Recommended".into()),
 141                recommended,
 142            );
 143        }
 144
 145        let mut models = HashMap::default();
 146        for provider in providers {
 147            let mut provider_models = Vec::new();
 148            for model in provider.provided_models(cx) {
 149                let model_info = Self::map_language_model_to_info(&model, &provider);
 150                let model_id = model_info.id.clone();
 151                if !recommended_models.contains(&(model.provider_id(), model.id())) {
 152                    provider_models.push(model_info);
 153                }
 154                models.insert(model_id, model);
 155            }
 156            if !provider_models.is_empty() {
 157                language_model_list.insert(
 158                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 159                    provider_models,
 160                );
 161            }
 162        }
 163
 164        self.models = models;
 165        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 166        self.refresh_models_tx.send(()).ok();
 167    }
 168
 169    fn watch(&self) -> watch::Receiver<()> {
 170        self.refresh_models_rx.clone()
 171    }
 172
 173    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 174        self.models.get(model_id).cloned()
 175    }
 176
 177    fn map_language_model_to_info(
 178        model: &Arc<dyn LanguageModel>,
 179        provider: &Arc<dyn LanguageModelProvider>,
 180    ) -> acp_thread::AgentModelInfo {
 181        acp_thread::AgentModelInfo {
 182            id: Self::model_id(model),
 183            name: model.name().0,
 184            description: None,
 185            icon: Some(provider.icon()),
 186        }
 187    }
 188
 189    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 190        acp::ModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
 191    }
 192
 193    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 194        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 195            .read(cx)
 196            .providers()
 197            .iter()
 198            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 199            .collect::<Vec<_>>();
 200
 201        cx.background_spawn(async move {
 202            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 203                if let Err(err) = authenticate_task.await {
 204                    match err {
 205                        language_model::AuthenticateError::CredentialsNotFound => {
 206                            // Since we're authenticating these providers in the
 207                            // background for the purposes of populating the
 208                            // language selector, we don't care about providers
 209                            // where the credentials are not found.
 210                        }
 211                        language_model::AuthenticateError::ConnectionRefused => {
 212                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 213                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 214                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 215                        }
 216                        _ => {
 217                            // Some providers have noisy failure states that we
 218                            // don't want to spam the logs with every time the
 219                            // language model selector is initialized.
 220                            //
 221                            // Ideally these should have more clear failure modes
 222                            // that we know are safe to ignore here, like what we do
 223                            // with `CredentialsNotFound` above.
 224                            match provider_id.0.as_ref() {
 225                                "lmstudio" | "ollama" => {
 226                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 227                                    //
 228                                    // These fail noisily, so we don't log them.
 229                                }
 230                                "copilot_chat" => {
 231                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 232                                }
 233                                _ => {
 234                                    log::error!(
 235                                        "Failed to authenticate provider: {}: {err}",
 236                                        provider_name.0
 237                                    );
 238                                }
 239                            }
 240                        }
 241                    }
 242                }
 243            }
 244        })
 245    }
 246}
 247
  /// Agent state: the registered sessions plus the resources they share
  /// (project context, context-server registry, templates and models).
  248pub struct NativeAgent {
  249    /// Session ID -> Session mapping
  250    sessions: HashMap<acp::SessionId, Session>,
      /// Thread history store; `reload`ed after each background thread save.
  251    history: Entity<HistoryStore>,
  252    /// Shared project context for all threads
  253    project_context: Entity<ProjectContext>,
      /// Signalled by worktree, rules-file and prompt-store events to request a
      /// project-context rebuild.
  254    project_context_needs_refresh: watch::Sender<()>,
      /// Task running `maintain_project_context`, which performs those rebuilds.
  255    _maintain_project_context: Task<Result<()>>,
      /// Context-server registry built from the project's context server store.
  256    context_server_registry: Entity<ContextServerRegistry>,
  257    /// Shared templates for all threads
  258    templates: Arc<Templates>,
  259    /// Cached model information
  260    models: LanguageModels,
      /// The project this agent operates on.
  261    project: Entity<Project>,
      /// Optional prompt store supplying default user rules.
  262    prompt_store: Option<Entity<PromptStore>>,
      /// Filesystem handle supplied at construction (not used in this view).
  263    fs: Arc<dyn Fs>,
      /// Project, language-model registry and (optional) prompt-store subscriptions.
  264    _subscriptions: Vec<Subscription>,
  265}
 266
 267impl NativeAgent {
 268    pub async fn new(
 269        project: Entity<Project>,
 270        history: Entity<HistoryStore>,
 271        templates: Arc<Templates>,
 272        prompt_store: Option<Entity<PromptStore>>,
 273        fs: Arc<dyn Fs>,
 274        cx: &mut AsyncApp,
 275    ) -> Result<Entity<NativeAgent>> {
 276        log::debug!("Creating new NativeAgent");
 277
 278        let project_context = cx
 279            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
 280            .await;
 281
 282        cx.new(|cx| {
 283            let mut subscriptions = vec![
 284                cx.subscribe(&project, Self::handle_project_event),
 285                cx.subscribe(
 286                    &LanguageModelRegistry::global(cx),
 287                    Self::handle_models_updated_event,
 288                ),
 289            ];
 290            if let Some(prompt_store) = prompt_store.as_ref() {
 291                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 292            }
 293
 294            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 295                watch::channel(());
 296            Self {
 297                sessions: HashMap::new(),
 298                history,
 299                project_context: cx.new(|_| project_context),
 300                project_context_needs_refresh: project_context_needs_refresh_tx,
 301                _maintain_project_context: cx.spawn(async move |this, cx| {
 302                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 303                }),
 304                context_server_registry: cx.new(|cx| {
 305                    ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
 306                }),
 307                templates,
 308                models: LanguageModels::new(cx),
 309                project,
 310                prompt_store,
 311                fs,
 312                _subscriptions: subscriptions,
 313            }
 314        })
 315    }
 316
 317    fn register_session(
 318        &mut self,
 319        thread_handle: Entity<Thread>,
 320        cx: &mut Context<Self>,
 321    ) -> Entity<AcpThread> {
 322        let connection = Rc::new(NativeAgentConnection(cx.entity()));
 323
 324        let thread = thread_handle.read(cx);
 325        let session_id = thread.id().clone();
 326        let title = thread.title();
 327        let project = thread.project.clone();
 328        let action_log = thread.action_log.clone();
 329        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
 330        let acp_thread = cx.new(|cx| {
 331            acp_thread::AcpThread::new(
 332                title,
 333                connection,
 334                project.clone(),
 335                action_log.clone(),
 336                session_id.clone(),
 337                prompt_capabilities_rx,
 338                cx,
 339            )
 340        });
 341
 342        let registry = LanguageModelRegistry::read_global(cx);
 343        let summarization_model = registry.thread_summary_model().map(|c| c.model);
 344
 345        thread_handle.update(cx, |thread, cx| {
 346            thread.set_summarization_model(summarization_model, cx);
 347            thread.add_default_tools(
 348                Rc::new(AcpThreadEnvironment {
 349                    acp_thread: acp_thread.downgrade(),
 350                }) as _,
 351                cx,
 352            )
 353        });
 354
 355        let subscriptions = vec![
 356            cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
 357                this.sessions.remove(acp_thread.session_id());
 358            }),
 359            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
 360            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
 361            cx.observe(&thread_handle, move |this, thread, cx| {
 362                this.save_thread(thread, cx)
 363            }),
 364        ];
 365
 366        self.sessions.insert(
 367            session_id,
 368            Session {
 369                thread: thread_handle,
 370                acp_thread: acp_thread.downgrade(),
 371                _subscriptions: subscriptions,
 372                pending_save: Task::ready(()),
 373            },
 374        );
 375        acp_thread
 376    }
 377
 378    pub const fn models(&self) -> &LanguageModels {
 379        &self.models
 380    }
 381
 382    async fn maintain_project_context(
 383        this: WeakEntity<Self>,
 384        mut needs_refresh: watch::Receiver<()>,
 385        cx: &mut AsyncApp,
 386    ) -> Result<()> {
 387        while needs_refresh.changed().await.is_ok() {
 388            let project_context = this
 389                .update(cx, |this, cx| {
 390                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 391                })?
 392                .await;
 393            this.update(cx, |this, cx| {
 394                this.project_context = cx.new(|_| project_context);
 395            })?;
 396        }
 397
 398        Ok(())
 399    }
 400
 401    fn build_project_context(
 402        project: &Entity<Project>,
 403        prompt_store: Option<&Entity<PromptStore>>,
 404        cx: &mut App,
 405    ) -> Task<ProjectContext> {
 406        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 407        let worktree_tasks = worktrees
 408            .into_iter()
 409            .map(|worktree| {
 410                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 411            })
 412            .collect::<Vec<_>>();
 413        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 414            prompt_store.read_with(cx, |prompt_store, cx| {
 415                let prompts = prompt_store.default_prompt_metadata();
 416                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 417                    let contents = prompt_store.load(prompt_metadata.id, cx);
 418                    async move { (contents.await, prompt_metadata) }
 419                });
 420                cx.background_spawn(future::join_all(load_tasks))
 421            })
 422        } else {
 423            Task::ready(vec![])
 424        };
 425
 426        cx.spawn(async move |_cx| {
 427            let (worktrees, default_user_rules) =
 428                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 429
 430            let worktrees = worktrees
 431                .into_iter()
 432                .map(|(worktree, _rules_error)| {
 433                    // TODO: show error message
 434                    // if let Some(rules_error) = rules_error {
 435                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 436                    // }
 437                    worktree
 438                })
 439                .collect::<Vec<_>>();
 440
 441            let default_user_rules = default_user_rules
 442                .into_iter()
 443                .flat_map(|(contents, prompt_metadata)| match contents {
 444                    Ok(contents) => Some(UserRulesContext {
 445                        uuid: match prompt_metadata.id {
 446                            prompt_store::PromptId::User { uuid } => uuid,
 447                            prompt_store::PromptId::EditWorkflow => return None,
 448                        },
 449                        title: prompt_metadata.title.map(|title| title.to_string()),
 450                        contents,
 451                    }),
 452                    Err(_err) => {
 453                        // TODO: show error message
 454                        // this.update(cx, |_, cx| {
 455                        //     cx.emit(RulesLoadingError {
 456                        //         message: format!("{err:?}").into(),
 457                        //     });
 458                        // })
 459                        // .ok();
 460                        None
 461                    }
 462                })
 463                .collect::<Vec<_>>();
 464
 465            ProjectContext::new(worktrees, default_user_rules)
 466        })
 467    }
 468
 469    fn load_worktree_info_for_system_prompt(
 470        worktree: Entity<Worktree>,
 471        project: Entity<Project>,
 472        cx: &mut App,
 473    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 474        let tree = worktree.read(cx);
 475        let root_name = tree.root_name_str().into();
 476        let abs_path = tree.abs_path();
 477
 478        let mut context = WorktreeContext {
 479            root_name,
 480            abs_path,
 481            rules_file: None,
 482        };
 483
 484        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 485        let Some(rules_task) = rules_task else {
 486            return Task::ready((context, None));
 487        };
 488
 489        cx.spawn(async move |_| {
 490            let (rules_file, rules_file_error) = match rules_task.await {
 491                Ok(rules_file) => (Some(rules_file), None),
 492                Err(err) => (
 493                    None,
 494                    Some(RulesLoadingError {
 495                        message: format!("{err}").into(),
 496                    }),
 497                ),
 498            };
 499            context.rules_file = rules_file;
 500            (context, rules_file_error)
 501        })
 502    }
 503
 504    fn load_worktree_rules_file(
 505        worktree: Entity<Worktree>,
 506        project: Entity<Project>,
 507        cx: &mut App,
 508    ) -> Option<Task<Result<RulesFileContext>>> {
 509        let worktree = worktree.read(cx);
 510        let worktree_id = worktree.id();
 511        let selected_rules_file = RULES_FILE_NAMES
 512            .into_iter()
 513            .filter_map(|name| {
 514                worktree
 515                    .entry_for_path(RelPath::unix(name).unwrap())
 516                    .filter(|entry| entry.is_file())
 517                    .map(|entry| entry.path.clone())
 518            })
 519            .next();
 520
 521        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 522        // supported. This doesn't seem to occur often in GitHub repositories.
 523        selected_rules_file.map(|path_in_worktree| {
 524            let project_path = ProjectPath {
 525                worktree_id,
 526                path: path_in_worktree.clone(),
 527            };
 528            let buffer_task =
 529                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 530            let rope_task = cx.spawn(async move |cx| {
 531                buffer_task.await?.read_with(cx, |buffer, cx| {
 532                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 533                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 534                })?
 535            });
 536            // Build a string from the rope on a background thread.
 537            cx.background_spawn(async move {
 538                let (project_entry_id, rope) = rope_task.await?;
 539                anyhow::Ok(RulesFileContext {
 540                    path_in_worktree,
 541                    text: rope.to_string().trim().to_string(),
 542                    project_entry_id: project_entry_id.to_usize(),
 543                })
 544            })
 545        })
 546    }
 547
 548    fn handle_thread_title_updated(
 549        &mut self,
 550        thread: Entity<Thread>,
 551        _: &TitleUpdated,
 552        cx: &mut Context<Self>,
 553    ) {
 554        let session_id = thread.read(cx).id();
 555        let Some(session) = self.sessions.get(session_id) else {
 556            return;
 557        };
 558        let thread = thread.downgrade();
 559        let acp_thread = session.acp_thread.clone();
 560        cx.spawn(async move |_, cx| {
 561            let title = thread.read_with(cx, |thread, _| thread.title())?;
 562            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 563            task.await
 564        })
 565        .detach_and_log_err(cx);
 566    }
 567
 568    fn handle_thread_token_usage_updated(
 569        &mut self,
 570        thread: Entity<Thread>,
 571        usage: &TokenUsageUpdated,
 572        cx: &mut Context<Self>,
 573    ) {
 574        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 575            return;
 576        };
 577        session
 578            .acp_thread
 579            .update(cx, |acp_thread, cx| {
 580                acp_thread.update_token_usage(usage.0.clone(), cx);
 581            })
 582            .ok();
 583    }
 584
 585    fn handle_project_event(
 586        &mut self,
 587        _project: Entity<Project>,
 588        event: &project::Event,
 589        _cx: &mut Context<Self>,
 590    ) {
 591        match event {
 592            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 593                self.project_context_needs_refresh.send(()).ok();
 594            }
 595            project::Event::WorktreeUpdatedEntries(_, items) => {
 596                if items.iter().any(|(path, _, _)| {
 597                    RULES_FILE_NAMES
 598                        .iter()
 599                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 600                }) {
 601                    self.project_context_needs_refresh.send(()).ok();
 602                }
 603            }
 604            _ => {}
 605        }
 606    }
 607
 608    fn handle_prompts_updated_event(
 609        &mut self,
 610        _prompt_store: Entity<PromptStore>,
 611        _event: &prompt_store::PromptsUpdatedEvent,
 612        _cx: &mut Context<Self>,
 613    ) {
 614        self.project_context_needs_refresh.send(()).ok();
 615    }
 616
 617    fn handle_models_updated_event(
 618        &mut self,
 619        _registry: Entity<LanguageModelRegistry>,
 620        _event: &language_model::Event,
 621        cx: &mut Context<Self>,
 622    ) {
 623        self.models.refresh_list(cx);
 624
 625        let registry = LanguageModelRegistry::read_global(cx);
 626        let default_model = registry.default_model().map(|m| m.model);
 627        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 628
 629        for session in self.sessions.values_mut() {
 630            session.thread.update(cx, |thread, cx| {
 631                if thread.model().is_none()
 632                    && let Some(model) = default_model.clone()
 633                {
 634                    thread.set_model(model, cx);
 635                    cx.notify();
 636                }
 637                thread.set_summarization_model(summarization_model.clone(), cx);
 638            });
 639        }
 640    }
 641
 642    pub fn load_thread(
 643        &mut self,
 644        id: acp::SessionId,
 645        cx: &mut Context<Self>,
 646    ) -> Task<Result<Entity<Thread>>> {
 647        let database_future = ThreadsDatabase::connect(cx);
 648        cx.spawn(async move |this, cx| {
 649            let database = database_future.await.map_err(|err| anyhow!(err))?;
 650            let db_thread = database
 651                .load_thread(id.clone())
 652                .await?
 653                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 654
 655            this.update(cx, |this, cx| {
 656                let summarization_model = LanguageModelRegistry::read_global(cx)
 657                    .thread_summary_model()
 658                    .map(|c| c.model);
 659
 660                cx.new(|cx| {
 661                    let mut thread = Thread::from_db(
 662                        id.clone(),
 663                        db_thread,
 664                        this.project.clone(),
 665                        this.project_context.clone(),
 666                        this.context_server_registry.clone(),
 667                        this.templates.clone(),
 668                        cx,
 669                    );
 670                    thread.set_summarization_model(summarization_model, cx);
 671                    thread
 672                })
 673            })
 674        })
 675    }
 676
 677    pub fn open_thread(
 678        &mut self,
 679        id: acp::SessionId,
 680        cx: &mut Context<Self>,
 681    ) -> Task<Result<Entity<AcpThread>>> {
 682        let task = self.load_thread(id, cx);
 683        cx.spawn(async move |this, cx| {
 684            let thread = task.await?;
 685            let acp_thread =
 686                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
 687            let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
 688            cx.update(|cx| {
 689                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 690            })?
 691            .await?;
 692            Ok(acp_thread)
 693        })
 694    }
 695
 696    pub fn thread_summary(
 697        &mut self,
 698        id: acp::SessionId,
 699        cx: &mut Context<Self>,
 700    ) -> Task<Result<SharedString>> {
 701        let thread = self.open_thread(id.clone(), cx);
 702        cx.spawn(async move |this, cx| {
 703            let acp_thread = thread.await?;
 704            let result = this
 705                .update(cx, |this, cx| {
 706                    this.sessions
 707                        .get(&id)
 708                        .unwrap()
 709                        .thread
 710                        .update(cx, |thread, cx| thread.summary(cx))
 711                })?
 712                .await
 713                .context("Failed to generate summary")?;
 714            drop(acp_thread);
 715            Ok(result)
 716        })
 717    }
 718
 719    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 720        if thread.read(cx).is_empty() {
 721            return;
 722        }
 723
 724        let database_future = ThreadsDatabase::connect(cx);
 725        let (id, db_thread) =
 726            thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
 727        let Some(session) = self.sessions.get_mut(&id) else {
 728            return;
 729        };
 730        let history = self.history.clone();
 731        session.pending_save = cx.spawn(async move |_, cx| {
 732            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 733                return;
 734            };
 735            let db_thread = db_thread.await;
 736            database.save_thread(id, db_thread).await.log_err();
 737            history.update(cx, |history, cx| history.reload(cx)).ok();
 738        });
 739    }
 740}
 741
  742/// Wrapper struct that implements the AgentConnection trait
  ///
  /// Deriving `Clone` copies only the wrapped `Entity<NativeAgent>` handle, so
  /// every clone refers to the same agent.
  743#[derive(Clone)]
  744pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 745
 746impl NativeAgentConnection {
 747    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 748        self.0
 749            .read(cx)
 750            .sessions
 751            .get(session_id)
 752            .map(|session| session.thread.clone())
 753    }
 754
 755    pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
 756        self.0.update(cx, |this, cx| this.load_thread(id, cx))
 757    }
 758
 759    fn run_turn(
 760        &self,
 761        session_id: acp::SessionId,
 762        cx: &mut App,
 763        f: impl 'static
 764        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
 765    ) -> Task<Result<acp::PromptResponse>> {
 766        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 767            agent
 768                .sessions
 769                .get_mut(&session_id)
 770                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
 771        }) else {
 772            return Task::ready(Err(anyhow!("Session not found")));
 773        };
 774        log::debug!("Found session for: {}", session_id);
 775
 776        let response_stream = match f(thread, cx) {
 777            Ok(stream) => stream,
 778            Err(err) => return Task::ready(Err(err)),
 779        };
 780        Self::handle_thread_events(response_stream, acp_thread, cx)
 781    }
 782
 783    fn handle_thread_events(
 784        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
 785        acp_thread: WeakEntity<AcpThread>,
 786        cx: &App,
 787    ) -> Task<Result<acp::PromptResponse>> {
 788        cx.spawn(async move |cx| {
 789            // Handle response stream and forward to session.acp_thread
 790            while let Some(result) = events.next().await {
 791                match result {
 792                    Ok(event) => {
 793                        log::trace!("Received completion event: {:?}", event);
 794
 795                        match event {
 796                            ThreadEvent::UserMessage(message) => {
 797                                acp_thread.update(cx, |thread, cx| {
 798                                    for content in message.content {
 799                                        thread.push_user_content_block(
 800                                            Some(message.id.clone()),
 801                                            content.into(),
 802                                            cx,
 803                                        );
 804                                    }
 805                                })?;
 806                            }
 807                            ThreadEvent::AgentText(text) => {
 808                                acp_thread.update(cx, |thread, cx| {
 809                                    thread.push_assistant_content_block(
 810                                        acp::ContentBlock::Text(acp::TextContent {
 811                                            text,
 812                                            annotations: None,
 813                                            meta: None,
 814                                        }),
 815                                        false,
 816                                        cx,
 817                                    )
 818                                })?;
 819                            }
 820                            ThreadEvent::AgentThinking(text) => {
 821                                acp_thread.update(cx, |thread, cx| {
 822                                    thread.push_assistant_content_block(
 823                                        acp::ContentBlock::Text(acp::TextContent {
 824                                            text,
 825                                            annotations: None,
 826                                            meta: None,
 827                                        }),
 828                                        true,
 829                                        cx,
 830                                    )
 831                                })?;
 832                            }
 833                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
 834                                tool_call,
 835                                options,
 836                                response,
 837                            }) => {
 838                                let outcome_task = acp_thread.update(cx, |thread, cx| {
 839                                    thread.request_tool_call_authorization(
 840                                        tool_call, options, true, cx,
 841                                    )
 842                                })??;
 843                                cx.background_spawn(async move {
 844                                    if let acp::RequestPermissionOutcome::Selected { option_id } =
 845                                        outcome_task.await
 846                                    {
 847                                        response
 848                                            .send(option_id)
 849                                            .map(|_| anyhow!("authorization receiver was dropped"))
 850                                            .log_err();
 851                                    }
 852                                })
 853                                .detach();
 854                            }
 855                            ThreadEvent::ToolCall(tool_call) => {
 856                                acp_thread.update(cx, |thread, cx| {
 857                                    thread.upsert_tool_call(tool_call, cx)
 858                                })??;
 859                            }
 860                            ThreadEvent::ToolCallUpdate(update) => {
 861                                acp_thread.update(cx, |thread, cx| {
 862                                    thread.update_tool_call(update, cx)
 863                                })??;
 864                            }
 865                            ThreadEvent::Retry(status) => {
 866                                acp_thread.update(cx, |thread, cx| {
 867                                    thread.update_retry_status(status, cx)
 868                                })?;
 869                            }
 870                            ThreadEvent::Stop(stop_reason) => {
 871                                log::debug!("Assistant message complete: {:?}", stop_reason);
 872                                return Ok(acp::PromptResponse {
 873                                    stop_reason,
 874                                    meta: None,
 875                                });
 876                            }
 877                        }
 878                    }
 879                    Err(e) => {
 880                        log::error!("Error in model response stream: {:?}", e);
 881                        return Err(e);
 882                    }
 883                }
 884            }
 885
 886            log::debug!("Response stream completed");
 887            anyhow::Ok(acp::PromptResponse {
 888                stop_reason: acp::StopReason::EndTurn,
 889                meta: None,
 890            })
 891        })
 892    }
 893}
 894
 895struct NativeAgentModelSelector {
 896    session_id: acp::SessionId,
 897    connection: NativeAgentConnection,
 898}
 899
 900impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
 901    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
 902        log::debug!("NativeAgentConnection::list_models called");
 903        let list = self.connection.0.read(cx).models.model_list.clone();
 904        Task::ready(if list.is_empty() {
 905            Err(anyhow::anyhow!("No models available"))
 906        } else {
 907            Ok(list)
 908        })
 909    }
 910
 911    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
 912        log::debug!(
 913            "Setting model for session {}: {}",
 914            self.session_id,
 915            model_id
 916        );
 917        let Some(thread) = self
 918            .connection
 919            .0
 920            .read(cx)
 921            .sessions
 922            .get(&self.session_id)
 923            .map(|session| session.thread.clone())
 924        else {
 925            return Task::ready(Err(anyhow!("Session not found")));
 926        };
 927
 928        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
 929            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
 930        };
 931
 932        thread.update(cx, |thread, cx| {
 933            thread.set_model(model.clone(), cx);
 934        });
 935
 936        update_settings_file(
 937            self.connection.0.read(cx).fs.clone(),
 938            cx,
 939            move |settings, _cx| {
 940                let provider = model.provider_id().0.to_string();
 941                let model = model.id().0.to_string();
 942                settings
 943                    .agent
 944                    .get_or_insert_default()
 945                    .set_model(LanguageModelSelection {
 946                        provider: provider.into(),
 947                        model,
 948                    });
 949            },
 950        );
 951
 952        Task::ready(Ok(()))
 953    }
 954
 955    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
 956        let Some(thread) = self
 957            .connection
 958            .0
 959            .read(cx)
 960            .sessions
 961            .get(&self.session_id)
 962            .map(|session| session.thread.clone())
 963        else {
 964            return Task::ready(Err(anyhow!("Session not found")));
 965        };
 966        let Some(model) = thread.read(cx).model() else {
 967            return Task::ready(Err(anyhow!("Model not found")));
 968        };
 969        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
 970        else {
 971            return Task::ready(Err(anyhow!("Provider not found")));
 972        };
 973        Task::ready(Ok(LanguageModels::map_language_model_to_info(
 974            model, &provider,
 975        )))
 976    }
 977
 978    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
 979        Some(self.connection.0.read(cx).models.watch())
 980    }
 981}
 982
 983impl acp_thread::AgentConnection for NativeAgentConnection {
 984    fn new_thread(
 985        self: Rc<Self>,
 986        project: Entity<Project>,
 987        cwd: &Path,
 988        cx: &mut App,
 989    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
 990        let agent = self.0.clone();
 991        log::debug!("Creating new thread for project at: {:?}", cwd);
 992
 993        cx.spawn(async move |cx| {
 994            log::debug!("Starting thread creation in async context");
 995
 996            // Create Thread
 997            let thread = agent.update(
 998                cx,
 999                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
1000                    // Fetch default model from registry settings
1001                    let registry = LanguageModelRegistry::read_global(cx);
1002                    // Log available models for debugging
1003                    let available_count = registry.available_models(cx).count();
1004                    log::debug!("Total available models: {}", available_count);
1005
1006                    let default_model = registry.default_model().and_then(|default_model| {
1007                        agent
1008                            .models
1009                            .model_from_id(&LanguageModels::model_id(&default_model.model))
1010                    });
1011                    Ok(cx.new(|cx| {
1012                        Thread::new(
1013                            project.clone(),
1014                            agent.project_context.clone(),
1015                            agent.context_server_registry.clone(),
1016                            agent.templates.clone(),
1017                            default_model,
1018                            cx,
1019                        )
1020                    }))
1021                },
1022            )??;
1023            agent.update(cx, |agent, cx| agent.register_session(thread, cx))
1024        })
1025    }
1026
1027    fn auth_methods(&self) -> &[acp::AuthMethod] {
1028        &[] // No auth for in-process
1029    }
1030
1031    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1032        Task::ready(Ok(()))
1033    }
1034
1035    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1036        Some(Rc::new(NativeAgentModelSelector {
1037            session_id: session_id.clone(),
1038            connection: self.clone(),
1039        }) as Rc<dyn AgentModelSelector>)
1040    }
1041
1042    fn prompt(
1043        &self,
1044        id: Option<acp_thread::UserMessageId>,
1045        params: acp::PromptRequest,
1046        cx: &mut App,
1047    ) -> Task<Result<acp::PromptResponse>> {
1048        let id = id.expect("UserMessageId is required");
1049        let session_id = params.session_id.clone();
1050        log::info!("Received prompt request for session: {}", session_id);
1051        log::debug!("Prompt blocks count: {}", params.prompt.len());
1052
1053        self.run_turn(session_id, cx, |thread, cx| {
1054            let content: Vec<UserMessageContent> = params
1055                .prompt
1056                .into_iter()
1057                .map(Into::into)
1058                .collect::<Vec<_>>();
1059            log::debug!("Converted prompt to message: {} chars", content.len());
1060            log::debug!("Message id: {:?}", id);
1061            log::debug!("Message content: {:?}", content);
1062
1063            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1064        })
1065    }
1066
1067    fn resume(
1068        &self,
1069        session_id: &acp::SessionId,
1070        _cx: &App,
1071    ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1072        Some(Rc::new(NativeAgentSessionResume {
1073            connection: self.clone(),
1074            session_id: session_id.clone(),
1075        }) as _)
1076    }
1077
1078    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1079        log::info!("Cancelling on session: {}", session_id);
1080        self.0.update(cx, |agent, cx| {
1081            if let Some(agent) = agent.sessions.get(session_id) {
1082                agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1083            }
1084        });
1085    }
1086
1087    fn truncate(
1088        &self,
1089        session_id: &agent_client_protocol::SessionId,
1090        cx: &App,
1091    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1092        self.0.read_with(cx, |agent, _cx| {
1093            agent.sessions.get(session_id).map(|session| {
1094                Rc::new(NativeAgentSessionTruncate {
1095                    thread: session.thread.clone(),
1096                    acp_thread: session.acp_thread.clone(),
1097                }) as _
1098            })
1099        })
1100    }
1101
1102    fn set_title(
1103        &self,
1104        session_id: &acp::SessionId,
1105        _cx: &App,
1106    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1107        Some(Rc::new(NativeAgentSessionSetTitle {
1108            connection: self.clone(),
1109            session_id: session_id.clone(),
1110        }) as _)
1111    }
1112
1113    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1114        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1115    }
1116
1117    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1118        self
1119    }
1120}
1121
1122impl acp_thread::AgentTelemetry for NativeAgentConnection {
1123    fn agent_name(&self) -> String {
1124        "Zed".into()
1125    }
1126
1127    fn thread_data(
1128        &self,
1129        session_id: &acp::SessionId,
1130        cx: &mut App,
1131    ) -> Task<Result<serde_json::Value>> {
1132        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1133            return Task::ready(Err(anyhow!("Session not found")));
1134        };
1135
1136        let task = session.thread.read(cx).to_db(cx);
1137        cx.background_spawn(async move {
1138            serde_json::to_value(task.await).context("Failed to serialize thread")
1139        })
1140    }
1141}
1142
1143struct NativeAgentSessionTruncate {
1144    thread: Entity<Thread>,
1145    acp_thread: WeakEntity<AcpThread>,
1146}
1147
1148impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1149    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1150        match self.thread.update(cx, |thread, cx| {
1151            thread.truncate(message_id.clone(), cx)?;
1152            Ok(thread.latest_token_usage())
1153        }) {
1154            Ok(usage) => {
1155                self.acp_thread
1156                    .update(cx, |thread, cx| {
1157                        thread.update_token_usage(usage, cx);
1158                    })
1159                    .ok();
1160                Task::ready(Ok(()))
1161            }
1162            Err(error) => Task::ready(Err(error)),
1163        }
1164    }
1165}
1166
1167struct NativeAgentSessionResume {
1168    connection: NativeAgentConnection,
1169    session_id: acp::SessionId,
1170}
1171
1172impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1173    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1174        self.connection
1175            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1176                thread.update(cx, |thread, cx| thread.resume(cx))
1177            })
1178    }
1179}
1180
1181struct NativeAgentSessionSetTitle {
1182    connection: NativeAgentConnection,
1183    session_id: acp::SessionId,
1184}
1185
1186impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1187    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1188        let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1189            return Task::ready(Err(anyhow!("session not found")));
1190        };
1191        let thread = session.thread.clone();
1192        thread.update(cx, |thread, cx| thread.set_title(title, cx));
1193        Task::ready(Ok(()))
1194    }
1195}
1196
1197pub struct AcpThreadEnvironment {
1198    acp_thread: WeakEntity<AcpThread>,
1199}
1200
1201impl ThreadEnvironment for AcpThreadEnvironment {
1202    fn create_terminal(
1203        &self,
1204        command: String,
1205        cwd: Option<PathBuf>,
1206        output_byte_limit: Option<u64>,
1207        cx: &mut AsyncApp,
1208    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1209        let task = self.acp_thread.update(cx, |thread, cx| {
1210            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1211        });
1212
1213        let acp_thread = self.acp_thread.clone();
1214        cx.spawn(async move |cx| {
1215            let terminal = task?.await?;
1216
1217            let (drop_tx, drop_rx) = oneshot::channel();
1218            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?;
1219
1220            cx.spawn(async move |cx| {
1221                drop_rx.await.ok();
1222                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1223            })
1224            .detach();
1225
1226            let handle = AcpTerminalHandle {
1227                terminal,
1228                _drop_tx: Some(drop_tx),
1229            };
1230
1231            Ok(Rc::new(handle) as _)
1232        })
1233    }
1234}
1235
1236pub struct AcpTerminalHandle {
1237    terminal: Entity<acp_thread::Terminal>,
1238    _drop_tx: Option<oneshot::Sender<()>>,
1239}
1240
1241impl TerminalHandle for AcpTerminalHandle {
1242    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1243        self.terminal.read_with(cx, |term, _cx| term.id().clone())
1244    }
1245
1246    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1247        self.terminal
1248            .read_with(cx, |term, _cx| term.wait_for_exit())
1249    }
1250
1251    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1252        self.terminal
1253            .read_with(cx, |term, cx| term.current_output(cx))
1254    }
1255}
1256
// Tests for `NativeAgent` / `NativeAgentConnection`, run against a `FakeFs`, a test
// `Project`, and fake language models.
#[cfg(test)]
mod internal_tests {
    use crate::HistoryEntryId;

    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    // The agent's project context must follow the project's worktrees and pick up a
    // `.rules` file once it is created inside a worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project was created without worktrees, so the context starts empty.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        // Adding `/a` as a worktree adds a matching entry (no rules file yet).
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    // `list_models` on a session's model selector returns the models grouped by
    // provider; with the test registry that is a single "Fake" group.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                history_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId("fake/fake".into()),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    // Selecting a model sets it on the session's thread and overwrites the
    // `agent.default_model` entry (initially foo/bar) in the settings file.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId("fake/fake".into());
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let the settings-file write finish before reading the file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    // A thread with at least one exchange is saved to history under the title produced
    // by the summarization model, its session is removed once all handles are dropped,
    // and reopening it reproduces the same conversation.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(history_entries(&history_store, cx), vec![]);

        // Send a user message that mentions `/a/b.md` through a resource link.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
                        name: "b.md".into(),
                        uri: MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                        annotations: None,
                        description: None,
                        mime_type: None,
                        size: None,
                        title: None,
                        meta: None,
                    }),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Script the fake models: the main model answers, the summary model provides the
        // title that later shows up in history.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            history_entries(&history_store, cx),
            vec![(
                HistoryEntryId::AcpThread(session_id.clone()),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    // Collects `(id, title)` pairs for every entry currently in `history`.
    fn history_entries(
        history: &Entity<HistoryStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(HistoryEntryId, String)> {
        history.read_with(cx, |history, _| {
            history
                .entries()
                .map(|e| (e.id(), e.title().to_string()))
                .collect::<Vec<_>>()
        })
    }

    // Shared setup: logger (if not already initialized), a test `SettingsStore`, project,
    // agent, and language settings, plus the test `LanguageModelRegistry` (presumably the
    // source of the "Fake" provider the tests rely on).
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}