agent.rs

   1mod db;
   2mod edit_agent;
   3mod history_store;
   4mod legacy_thread;
   5mod native_agent_server;
   6pub mod outline;
   7mod templates;
   8mod thread;
   9mod tools;
  10
  11#[cfg(test)]
  12mod tests;
  13
  14pub use db::*;
  15pub use history_store::*;
  16pub use native_agent_server::NativeAgentServer;
  17pub use templates::*;
  18pub use thread::*;
  19pub use tools::*;
  20
  21use acp_thread::{AcpThread, AgentModelSelector};
  22use agent_client_protocol as acp;
  23use anyhow::{Context as _, Result, anyhow};
  24use chrono::{DateTime, Utc};
  25use collections::{HashSet, IndexMap};
  26use fs::Fs;
  27use futures::channel::{mpsc, oneshot};
  28use futures::future::Shared;
  29use futures::{StreamExt, future};
  30use gpui::{
  31    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  32};
  33use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  34use project::{Project, ProjectItem, ProjectPath, Worktree};
  35use prompt_store::{
  36    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  37    WorktreeContext,
  38};
  39use serde::{Deserialize, Serialize};
  40use settings::{LanguageModelSelection, update_settings_file};
  41use std::any::Any;
  42use std::collections::HashMap;
  43use std::path::{Path, PathBuf};
  44use std::rc::Rc;
  45use std::sync::Arc;
  46use util::ResultExt;
  47use util::rel_path::RelPath;
  48
  49#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
  50pub struct ProjectSnapshot {
  51    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
  52    pub timestamp: DateTime<Utc>,
  53}
  54
  55pub struct RulesLoadingError {
  56    pub message: SharedString,
  57}
  58
  59/// Holds both the internal Thread and the AcpThread for a session
  60struct Session {
  61    /// The internal thread that processes messages
  62    thread: Entity<Thread>,
  63    /// The ACP thread that handles protocol communication
  64    acp_thread: WeakEntity<acp_thread::AcpThread>,
  65    pending_save: Task<()>,
  66    _subscriptions: Vec<Subscription>,
  67}
  68
  69pub struct LanguageModels {
  70    /// Access language model by ID
  71    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
  72    /// Cached list for returning language model information
  73    model_list: acp_thread::AgentModelList,
  74    refresh_models_rx: watch::Receiver<()>,
  75    refresh_models_tx: watch::Sender<()>,
  76    _authenticate_all_providers_task: Task<()>,
  77}
  78
  79impl LanguageModels {
  80    fn new(cx: &mut App) -> Self {
  81        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  82
  83        let mut this = Self {
  84            models: HashMap::default(),
  85            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  86            refresh_models_rx,
  87            refresh_models_tx,
  88            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
  89        };
  90        this.refresh_list(cx);
  91        this
  92    }
  93
  94    fn refresh_list(&mut self, cx: &App) {
  95        let providers = LanguageModelRegistry::global(cx)
  96            .read(cx)
  97            .providers()
  98            .into_iter()
  99            .filter(|provider| provider.is_authenticated(cx))
 100            .collect::<Vec<_>>();
 101
 102        let mut language_model_list = IndexMap::default();
 103        let mut recommended_models = HashSet::default();
 104
 105        let mut recommended = Vec::new();
 106        for provider in &providers {
 107            for model in provider.recommended_models(cx) {
 108                recommended_models.insert((model.provider_id(), model.id()));
 109                recommended.push(Self::map_language_model_to_info(&model, provider));
 110            }
 111        }
 112        if !recommended.is_empty() {
 113            language_model_list.insert(
 114                acp_thread::AgentModelGroupName("Recommended".into()),
 115                recommended,
 116            );
 117        }
 118
 119        let mut models = HashMap::default();
 120        for provider in providers {
 121            let mut provider_models = Vec::new();
 122            for model in provider.provided_models(cx) {
 123                let model_info = Self::map_language_model_to_info(&model, &provider);
 124                let model_id = model_info.id.clone();
 125                provider_models.push(model_info);
 126                models.insert(model_id, model);
 127            }
 128            if !provider_models.is_empty() {
 129                language_model_list.insert(
 130                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 131                    provider_models,
 132                );
 133            }
 134        }
 135
 136        self.models = models;
 137        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 138        self.refresh_models_tx.send(()).ok();
 139    }
 140
 141    fn watch(&self) -> watch::Receiver<()> {
 142        self.refresh_models_rx.clone()
 143    }
 144
 145    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 146        self.models.get(model_id).cloned()
 147    }
 148
 149    fn map_language_model_to_info(
 150        model: &Arc<dyn LanguageModel>,
 151        provider: &Arc<dyn LanguageModelProvider>,
 152    ) -> acp_thread::AgentModelInfo {
 153        acp_thread::AgentModelInfo {
 154            id: Self::model_id(model),
 155            name: model.name().0,
 156            description: None,
 157            icon: Some(provider.icon()),
 158        }
 159    }
 160
 161    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 162        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 163    }
 164
 165    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 166        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 167            .read(cx)
 168            .providers()
 169            .iter()
 170            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 171            .collect::<Vec<_>>();
 172
 173        cx.background_spawn(async move {
 174            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 175                if let Err(err) = authenticate_task.await {
 176                    match err {
 177                        language_model::AuthenticateError::CredentialsNotFound => {
 178                            // Since we're authenticating these providers in the
 179                            // background for the purposes of populating the
 180                            // language selector, we don't care about providers
 181                            // where the credentials are not found.
 182                        }
 183                        language_model::AuthenticateError::ConnectionRefused => {
 184                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 185                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 186                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 187                        }
 188                        _ => {
 189                            // Some providers have noisy failure states that we
 190                            // don't want to spam the logs with every time the
 191                            // language model selector is initialized.
 192                            //
 193                            // Ideally these should have more clear failure modes
 194                            // that we know are safe to ignore here, like what we do
 195                            // with `CredentialsNotFound` above.
 196                            match provider_id.0.as_ref() {
 197                                "lmstudio" | "ollama" => {
 198                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 199                                    //
 200                                    // These fail noisily, so we don't log them.
 201                                }
 202                                "copilot_chat" => {
 203                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 204                                }
 205                                _ => {
 206                                    log::error!(
 207                                        "Failed to authenticate provider: {}: {err:#}",
 208                                        provider_name.0
 209                                    );
 210                                }
 211                            }
 212                        }
 213                    }
 214                }
 215            }
 216        })
 217    }
 218}
 219
 220pub struct NativeAgent {
 221    /// Session ID -> Session mapping
 222    sessions: HashMap<acp::SessionId, Session>,
 223    history: Entity<HistoryStore>,
 224    /// Shared project context for all threads
 225    project_context: Entity<ProjectContext>,
 226    project_context_needs_refresh: watch::Sender<()>,
 227    _maintain_project_context: Task<Result<()>>,
 228    context_server_registry: Entity<ContextServerRegistry>,
 229    /// Shared templates for all threads
 230    templates: Arc<Templates>,
 231    /// Cached model information
 232    models: LanguageModels,
 233    project: Entity<Project>,
 234    prompt_store: Option<Entity<PromptStore>>,
 235    fs: Arc<dyn Fs>,
 236    _subscriptions: Vec<Subscription>,
 237}
 238
 239impl NativeAgent {
 240    pub async fn new(
 241        project: Entity<Project>,
 242        history: Entity<HistoryStore>,
 243        templates: Arc<Templates>,
 244        prompt_store: Option<Entity<PromptStore>>,
 245        fs: Arc<dyn Fs>,
 246        cx: &mut AsyncApp,
 247    ) -> Result<Entity<NativeAgent>> {
 248        log::debug!("Creating new NativeAgent");
 249
 250        let project_context = cx
 251            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
 252            .await;
 253
 254        cx.new(|cx| {
 255            let mut subscriptions = vec![
 256                cx.subscribe(&project, Self::handle_project_event),
 257                cx.subscribe(
 258                    &LanguageModelRegistry::global(cx),
 259                    Self::handle_models_updated_event,
 260                ),
 261            ];
 262            if let Some(prompt_store) = prompt_store.as_ref() {
 263                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 264            }
 265
 266            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 267                watch::channel(());
 268            Self {
 269                sessions: HashMap::new(),
 270                history,
 271                project_context: cx.new(|_| project_context),
 272                project_context_needs_refresh: project_context_needs_refresh_tx,
 273                _maintain_project_context: cx.spawn(async move |this, cx| {
 274                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 275                }),
 276                context_server_registry: cx.new(|cx| {
 277                    ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
 278                }),
 279                templates,
 280                models: LanguageModels::new(cx),
 281                project,
 282                prompt_store,
 283                fs,
 284                _subscriptions: subscriptions,
 285            }
 286        })
 287    }
 288
 289    fn register_session(
 290        &mut self,
 291        thread_handle: Entity<Thread>,
 292        cx: &mut Context<Self>,
 293    ) -> Entity<AcpThread> {
 294        let connection = Rc::new(NativeAgentConnection(cx.entity()));
 295
 296        let thread = thread_handle.read(cx);
 297        let session_id = thread.id().clone();
 298        let title = thread.title();
 299        let project = thread.project.clone();
 300        let action_log = thread.action_log.clone();
 301        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
 302        let acp_thread = cx.new(|cx| {
 303            acp_thread::AcpThread::new(
 304                title,
 305                connection,
 306                project.clone(),
 307                action_log.clone(),
 308                session_id.clone(),
 309                prompt_capabilities_rx,
 310                cx,
 311            )
 312        });
 313
 314        let registry = LanguageModelRegistry::read_global(cx);
 315        let summarization_model = registry.thread_summary_model().map(|c| c.model);
 316
 317        thread_handle.update(cx, |thread, cx| {
 318            thread.set_summarization_model(summarization_model, cx);
 319            thread.add_default_tools(
 320                Rc::new(AcpThreadEnvironment {
 321                    acp_thread: acp_thread.downgrade(),
 322                }) as _,
 323                cx,
 324            )
 325        });
 326
 327        let subscriptions = vec![
 328            cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
 329                this.sessions.remove(acp_thread.session_id());
 330            }),
 331            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
 332            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
 333            cx.observe(&thread_handle, move |this, thread, cx| {
 334                this.save_thread(thread, cx)
 335            }),
 336        ];
 337
 338        self.sessions.insert(
 339            session_id,
 340            Session {
 341                thread: thread_handle,
 342                acp_thread: acp_thread.downgrade(),
 343                _subscriptions: subscriptions,
 344                pending_save: Task::ready(()),
 345            },
 346        );
 347        acp_thread
 348    }
 349
 350    pub fn models(&self) -> &LanguageModels {
 351        &self.models
 352    }
 353
 354    async fn maintain_project_context(
 355        this: WeakEntity<Self>,
 356        mut needs_refresh: watch::Receiver<()>,
 357        cx: &mut AsyncApp,
 358    ) -> Result<()> {
 359        while needs_refresh.changed().await.is_ok() {
 360            let project_context = this
 361                .update(cx, |this, cx| {
 362                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 363                })?
 364                .await;
 365            this.update(cx, |this, cx| {
 366                this.project_context = cx.new(|_| project_context);
 367            })?;
 368        }
 369
 370        Ok(())
 371    }
 372
 373    fn build_project_context(
 374        project: &Entity<Project>,
 375        prompt_store: Option<&Entity<PromptStore>>,
 376        cx: &mut App,
 377    ) -> Task<ProjectContext> {
 378        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 379        let worktree_tasks = worktrees
 380            .into_iter()
 381            .map(|worktree| {
 382                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 383            })
 384            .collect::<Vec<_>>();
 385        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 386            prompt_store.read_with(cx, |prompt_store, cx| {
 387                let prompts = prompt_store.default_prompt_metadata();
 388                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 389                    let contents = prompt_store.load(prompt_metadata.id, cx);
 390                    async move { (contents.await, prompt_metadata) }
 391                });
 392                cx.background_spawn(future::join_all(load_tasks))
 393            })
 394        } else {
 395            Task::ready(vec![])
 396        };
 397
 398        cx.spawn(async move |_cx| {
 399            let (worktrees, default_user_rules) =
 400                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 401
 402            let worktrees = worktrees
 403                .into_iter()
 404                .map(|(worktree, _rules_error)| {
 405                    // TODO: show error message
 406                    // if let Some(rules_error) = rules_error {
 407                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 408                    // }
 409                    worktree
 410                })
 411                .collect::<Vec<_>>();
 412
 413            let default_user_rules = default_user_rules
 414                .into_iter()
 415                .flat_map(|(contents, prompt_metadata)| match contents {
 416                    Ok(contents) => Some(UserRulesContext {
 417                        uuid: prompt_metadata.id.user_id()?,
 418                        title: prompt_metadata.title.map(|title| title.to_string()),
 419                        contents,
 420                    }),
 421                    Err(_err) => {
 422                        // TODO: show error message
 423                        // this.update(cx, |_, cx| {
 424                        //     cx.emit(RulesLoadingError {
 425                        //         message: format!("{err:?}").into(),
 426                        //     });
 427                        // })
 428                        // .ok();
 429                        None
 430                    }
 431                })
 432                .collect::<Vec<_>>();
 433
 434            ProjectContext::new(worktrees, default_user_rules)
 435        })
 436    }
 437
 438    fn load_worktree_info_for_system_prompt(
 439        worktree: Entity<Worktree>,
 440        project: Entity<Project>,
 441        cx: &mut App,
 442    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 443        let tree = worktree.read(cx);
 444        let root_name = tree.root_name_str().into();
 445        let abs_path = tree.abs_path();
 446
 447        let mut context = WorktreeContext {
 448            root_name,
 449            abs_path,
 450            rules_file: None,
 451        };
 452
 453        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 454        let Some(rules_task) = rules_task else {
 455            return Task::ready((context, None));
 456        };
 457
 458        cx.spawn(async move |_| {
 459            let (rules_file, rules_file_error) = match rules_task.await {
 460                Ok(rules_file) => (Some(rules_file), None),
 461                Err(err) => (
 462                    None,
 463                    Some(RulesLoadingError {
 464                        message: format!("{err}").into(),
 465                    }),
 466                ),
 467            };
 468            context.rules_file = rules_file;
 469            (context, rules_file_error)
 470        })
 471    }
 472
 473    fn load_worktree_rules_file(
 474        worktree: Entity<Worktree>,
 475        project: Entity<Project>,
 476        cx: &mut App,
 477    ) -> Option<Task<Result<RulesFileContext>>> {
 478        let worktree = worktree.read(cx);
 479        let worktree_id = worktree.id();
 480        let selected_rules_file = RULES_FILE_NAMES
 481            .into_iter()
 482            .filter_map(|name| {
 483                worktree
 484                    .entry_for_path(RelPath::unix(name).unwrap())
 485                    .filter(|entry| entry.is_file())
 486                    .map(|entry| entry.path.clone())
 487            })
 488            .next();
 489
 490        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 491        // supported. This doesn't seem to occur often in GitHub repositories.
 492        selected_rules_file.map(|path_in_worktree| {
 493            let project_path = ProjectPath {
 494                worktree_id,
 495                path: path_in_worktree.clone(),
 496            };
 497            let buffer_task =
 498                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 499            let rope_task = cx.spawn(async move |cx| {
 500                buffer_task.await?.read_with(cx, |buffer, cx| {
 501                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 502                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 503                })?
 504            });
 505            // Build a string from the rope on a background thread.
 506            cx.background_spawn(async move {
 507                let (project_entry_id, rope) = rope_task.await?;
 508                anyhow::Ok(RulesFileContext {
 509                    path_in_worktree,
 510                    text: rope.to_string().trim().to_string(),
 511                    project_entry_id: project_entry_id.to_usize(),
 512                })
 513            })
 514        })
 515    }
 516
 517    fn handle_thread_title_updated(
 518        &mut self,
 519        thread: Entity<Thread>,
 520        _: &TitleUpdated,
 521        cx: &mut Context<Self>,
 522    ) {
 523        let session_id = thread.read(cx).id();
 524        let Some(session) = self.sessions.get(session_id) else {
 525            return;
 526        };
 527        let thread = thread.downgrade();
 528        let acp_thread = session.acp_thread.clone();
 529        cx.spawn(async move |_, cx| {
 530            let title = thread.read_with(cx, |thread, _| thread.title())?;
 531            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 532            task.await
 533        })
 534        .detach_and_log_err(cx);
 535    }
 536
 537    fn handle_thread_token_usage_updated(
 538        &mut self,
 539        thread: Entity<Thread>,
 540        usage: &TokenUsageUpdated,
 541        cx: &mut Context<Self>,
 542    ) {
 543        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 544            return;
 545        };
 546        session
 547            .acp_thread
 548            .update(cx, |acp_thread, cx| {
 549                acp_thread.update_token_usage(usage.0.clone(), cx);
 550            })
 551            .ok();
 552    }
 553
 554    fn handle_project_event(
 555        &mut self,
 556        _project: Entity<Project>,
 557        event: &project::Event,
 558        _cx: &mut Context<Self>,
 559    ) {
 560        match event {
 561            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 562                self.project_context_needs_refresh.send(()).ok();
 563            }
 564            project::Event::WorktreeUpdatedEntries(_, items) => {
 565                if items.iter().any(|(path, _, _)| {
 566                    RULES_FILE_NAMES
 567                        .iter()
 568                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 569                }) {
 570                    self.project_context_needs_refresh.send(()).ok();
 571                }
 572            }
 573            _ => {}
 574        }
 575    }
 576
 577    fn handle_prompts_updated_event(
 578        &mut self,
 579        _prompt_store: Entity<PromptStore>,
 580        _event: &prompt_store::PromptsUpdatedEvent,
 581        _cx: &mut Context<Self>,
 582    ) {
 583        self.project_context_needs_refresh.send(()).ok();
 584    }
 585
 586    fn handle_models_updated_event(
 587        &mut self,
 588        _registry: Entity<LanguageModelRegistry>,
 589        _event: &language_model::Event,
 590        cx: &mut Context<Self>,
 591    ) {
 592        self.models.refresh_list(cx);
 593
 594        let registry = LanguageModelRegistry::read_global(cx);
 595        let default_model = registry.default_model().map(|m| m.model);
 596        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 597
 598        for session in self.sessions.values_mut() {
 599            session.thread.update(cx, |thread, cx| {
 600                if thread.model().is_none()
 601                    && let Some(model) = default_model.clone()
 602                {
 603                    thread.set_model(model, cx);
 604                    cx.notify();
 605                }
 606                thread.set_summarization_model(summarization_model.clone(), cx);
 607            });
 608        }
 609    }
 610
 611    pub fn load_thread(
 612        &mut self,
 613        id: acp::SessionId,
 614        cx: &mut Context<Self>,
 615    ) -> Task<Result<Entity<Thread>>> {
 616        let database_future = ThreadsDatabase::connect(cx);
 617        cx.spawn(async move |this, cx| {
 618            let database = database_future.await.map_err(|err| anyhow!(err))?;
 619            let db_thread = database
 620                .load_thread(id.clone())
 621                .await?
 622                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 623
 624            this.update(cx, |this, cx| {
 625                let summarization_model = LanguageModelRegistry::read_global(cx)
 626                    .thread_summary_model()
 627                    .map(|c| c.model);
 628
 629                cx.new(|cx| {
 630                    let mut thread = Thread::from_db(
 631                        id.clone(),
 632                        db_thread,
 633                        this.project.clone(),
 634                        this.project_context.clone(),
 635                        this.context_server_registry.clone(),
 636                        this.templates.clone(),
 637                        cx,
 638                    );
 639                    thread.set_summarization_model(summarization_model, cx);
 640                    thread
 641                })
 642            })
 643        })
 644    }
 645
 646    pub fn open_thread(
 647        &mut self,
 648        id: acp::SessionId,
 649        cx: &mut Context<Self>,
 650    ) -> Task<Result<Entity<AcpThread>>> {
 651        let task = self.load_thread(id, cx);
 652        cx.spawn(async move |this, cx| {
 653            let thread = task.await?;
 654            let acp_thread =
 655                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
 656            let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
 657            cx.update(|cx| {
 658                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 659            })?
 660            .await?;
 661            Ok(acp_thread)
 662        })
 663    }
 664
 665    pub fn thread_summary(
 666        &mut self,
 667        id: acp::SessionId,
 668        cx: &mut Context<Self>,
 669    ) -> Task<Result<SharedString>> {
 670        let thread = self.open_thread(id.clone(), cx);
 671        cx.spawn(async move |this, cx| {
 672            let acp_thread = thread.await?;
 673            let result = this
 674                .update(cx, |this, cx| {
 675                    this.sessions
 676                        .get(&id)
 677                        .unwrap()
 678                        .thread
 679                        .update(cx, |thread, cx| thread.summary(cx))
 680                })?
 681                .await
 682                .context("Failed to generate summary")?;
 683            drop(acp_thread);
 684            Ok(result)
 685        })
 686    }
 687
 688    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 689        if thread.read(cx).is_empty() {
 690            return;
 691        }
 692
 693        let database_future = ThreadsDatabase::connect(cx);
 694        let (id, db_thread) =
 695            thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
 696        let Some(session) = self.sessions.get_mut(&id) else {
 697            return;
 698        };
 699        let history = self.history.clone();
 700        session.pending_save = cx.spawn(async move |_, cx| {
 701            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 702                return;
 703            };
 704            let db_thread = db_thread.await;
 705            database.save_thread(id, db_thread).await.log_err();
 706            history.update(cx, |history, cx| history.reload(cx)).ok();
 707        });
 708    }
 709}
 710
 711/// Wrapper struct that implements the AgentConnection trait
 712#[derive(Clone)]
 713pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 714
 715impl NativeAgentConnection {
 716    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 717        self.0
 718            .read(cx)
 719            .sessions
 720            .get(session_id)
 721            .map(|session| session.thread.clone())
 722    }
 723
 724    pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
 725        self.0.update(cx, |this, cx| this.load_thread(id, cx))
 726    }
 727
 728    fn run_turn(
 729        &self,
 730        session_id: acp::SessionId,
 731        cx: &mut App,
 732        f: impl 'static
 733        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
 734    ) -> Task<Result<acp::PromptResponse>> {
 735        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 736            agent
 737                .sessions
 738                .get_mut(&session_id)
 739                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
 740        }) else {
 741            return Task::ready(Err(anyhow!("Session not found")));
 742        };
 743        log::debug!("Found session for: {}", session_id);
 744
 745        let response_stream = match f(thread, cx) {
 746            Ok(stream) => stream,
 747            Err(err) => return Task::ready(Err(err)),
 748        };
 749        Self::handle_thread_events(response_stream, acp_thread, cx)
 750    }
 751
 752    fn handle_thread_events(
 753        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
 754        acp_thread: WeakEntity<AcpThread>,
 755        cx: &App,
 756    ) -> Task<Result<acp::PromptResponse>> {
 757        cx.spawn(async move |cx| {
 758            // Handle response stream and forward to session.acp_thread
 759            while let Some(result) = events.next().await {
 760                match result {
 761                    Ok(event) => {
 762                        log::trace!("Received completion event: {:?}", event);
 763
 764                        match event {
 765                            ThreadEvent::UserMessage(message) => {
 766                                acp_thread.update(cx, |thread, cx| {
 767                                    for content in message.content {
 768                                        thread.push_user_content_block(
 769                                            Some(message.id.clone()),
 770                                            content.into(),
 771                                            cx,
 772                                        );
 773                                    }
 774                                })?;
 775                            }
 776                            ThreadEvent::AgentText(text) => {
 777                                acp_thread.update(cx, |thread, cx| {
 778                                    thread.push_assistant_content_block(text.into(), false, cx)
 779                                })?;
 780                            }
 781                            ThreadEvent::AgentThinking(text) => {
 782                                acp_thread.update(cx, |thread, cx| {
 783                                    thread.push_assistant_content_block(text.into(), true, cx)
 784                                })?;
 785                            }
 786                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
 787                                tool_call,
 788                                options,
 789                                response,
 790                            }) => {
 791                                let outcome_task = acp_thread.update(cx, |thread, cx| {
 792                                    thread.request_tool_call_authorization(
 793                                        tool_call, options, true, cx,
 794                                    )
 795                                })??;
 796                                cx.background_spawn(async move {
 797                                    if let acp::RequestPermissionOutcome::Selected(
 798                                        acp::SelectedPermissionOutcome { option_id, .. },
 799                                    ) = outcome_task.await
 800                                    {
 801                                        response
 802                                            .send(option_id)
 803                                            .map(|_| anyhow!("authorization receiver was dropped"))
 804                                            .log_err();
 805                                    }
 806                                })
 807                                .detach();
 808                            }
 809                            ThreadEvent::ToolCall(tool_call) => {
 810                                acp_thread.update(cx, |thread, cx| {
 811                                    thread.upsert_tool_call(tool_call, cx)
 812                                })??;
 813                            }
 814                            ThreadEvent::ToolCallUpdate(update) => {
 815                                acp_thread.update(cx, |thread, cx| {
 816                                    thread.update_tool_call(update, cx)
 817                                })??;
 818                            }
 819                            ThreadEvent::Retry(status) => {
 820                                acp_thread.update(cx, |thread, cx| {
 821                                    thread.update_retry_status(status, cx)
 822                                })?;
 823                            }
 824                            ThreadEvent::Stop(stop_reason) => {
 825                                log::debug!("Assistant message complete: {:?}", stop_reason);
 826                                return Ok(acp::PromptResponse::new(stop_reason));
 827                            }
 828                        }
 829                    }
 830                    Err(e) => {
 831                        log::error!("Error in model response stream: {:?}", e);
 832                        return Err(e);
 833                    }
 834                }
 835            }
 836
 837            log::debug!("Response stream completed");
 838            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
 839        })
 840    }
 841}
 842
 843struct NativeAgentModelSelector {
 844    session_id: acp::SessionId,
 845    connection: NativeAgentConnection,
 846}
 847
 848impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
 849    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
 850        log::debug!("NativeAgentConnection::list_models called");
 851        let list = self.connection.0.read(cx).models.model_list.clone();
 852        Task::ready(if list.is_empty() {
 853            Err(anyhow::anyhow!("No models available"))
 854        } else {
 855            Ok(list)
 856        })
 857    }
 858
 859    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
 860        log::debug!(
 861            "Setting model for session {}: {}",
 862            self.session_id,
 863            model_id
 864        );
 865        let Some(thread) = self
 866            .connection
 867            .0
 868            .read(cx)
 869            .sessions
 870            .get(&self.session_id)
 871            .map(|session| session.thread.clone())
 872        else {
 873            return Task::ready(Err(anyhow!("Session not found")));
 874        };
 875
 876        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
 877            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
 878        };
 879
 880        thread.update(cx, |thread, cx| {
 881            thread.set_model(model.clone(), cx);
 882        });
 883
 884        update_settings_file(
 885            self.connection.0.read(cx).fs.clone(),
 886            cx,
 887            move |settings, _cx| {
 888                let provider = model.provider_id().0.to_string();
 889                let model = model.id().0.to_string();
 890                settings
 891                    .agent
 892                    .get_or_insert_default()
 893                    .set_model(LanguageModelSelection {
 894                        provider: provider.into(),
 895                        model,
 896                    });
 897            },
 898        );
 899
 900        Task::ready(Ok(()))
 901    }
 902
 903    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
 904        let Some(thread) = self
 905            .connection
 906            .0
 907            .read(cx)
 908            .sessions
 909            .get(&self.session_id)
 910            .map(|session| session.thread.clone())
 911        else {
 912            return Task::ready(Err(anyhow!("Session not found")));
 913        };
 914        let Some(model) = thread.read(cx).model() else {
 915            return Task::ready(Err(anyhow!("Model not found")));
 916        };
 917        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
 918        else {
 919            return Task::ready(Err(anyhow!("Provider not found")));
 920        };
 921        Task::ready(Ok(LanguageModels::map_language_model_to_info(
 922            model, &provider,
 923        )))
 924    }
 925
 926    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
 927        Some(self.connection.0.read(cx).models.watch())
 928    }
 929
 930    fn should_render_footer(&self) -> bool {
 931        true
 932    }
 933}
 934
 935impl acp_thread::AgentConnection for NativeAgentConnection {
 936    fn telemetry_id(&self) -> SharedString {
 937        "zed".into()
 938    }
 939
 940    fn new_thread(
 941        self: Rc<Self>,
 942        project: Entity<Project>,
 943        cwd: &Path,
 944        cx: &mut App,
 945    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
 946        let agent = self.0.clone();
 947        log::debug!("Creating new thread for project at: {:?}", cwd);
 948
 949        cx.spawn(async move |cx| {
 950            log::debug!("Starting thread creation in async context");
 951
 952            // Create Thread
 953            let thread = agent.update(
 954                cx,
 955                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
 956                    // Fetch default model from registry settings
 957                    let registry = LanguageModelRegistry::read_global(cx);
 958                    // Log available models for debugging
 959                    let available_count = registry.available_models(cx).count();
 960                    log::debug!("Total available models: {}", available_count);
 961
 962                    let default_model = registry.default_model().and_then(|default_model| {
 963                        agent
 964                            .models
 965                            .model_from_id(&LanguageModels::model_id(&default_model.model))
 966                    });
 967                    Ok(cx.new(|cx| {
 968                        Thread::new(
 969                            project.clone(),
 970                            agent.project_context.clone(),
 971                            agent.context_server_registry.clone(),
 972                            agent.templates.clone(),
 973                            default_model,
 974                            cx,
 975                        )
 976                    }))
 977                },
 978            )??;
 979            agent.update(cx, |agent, cx| agent.register_session(thread, cx))
 980        })
 981    }
 982
 983    fn auth_methods(&self) -> &[acp::AuthMethod] {
 984        &[] // No auth for in-process
 985    }
 986
 987    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 988        Task::ready(Ok(()))
 989    }
 990
 991    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
 992        Some(Rc::new(NativeAgentModelSelector {
 993            session_id: session_id.clone(),
 994            connection: self.clone(),
 995        }) as Rc<dyn AgentModelSelector>)
 996    }
 997
 998    fn prompt(
 999        &self,
1000        id: Option<acp_thread::UserMessageId>,
1001        params: acp::PromptRequest,
1002        cx: &mut App,
1003    ) -> Task<Result<acp::PromptResponse>> {
1004        let id = id.expect("UserMessageId is required");
1005        let session_id = params.session_id.clone();
1006        log::info!("Received prompt request for session: {}", session_id);
1007        log::debug!("Prompt blocks count: {}", params.prompt.len());
1008        let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1009
1010        self.run_turn(session_id, cx, move |thread, cx| {
1011            let content: Vec<UserMessageContent> = params
1012                .prompt
1013                .into_iter()
1014                .map(|block| UserMessageContent::from_content_block(block, path_style))
1015                .collect::<Vec<_>>();
1016            log::debug!("Converted prompt to message: {} chars", content.len());
1017            log::debug!("Message id: {:?}", id);
1018            log::debug!("Message content: {:?}", content);
1019
1020            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1021        })
1022    }
1023
1024    fn resume(
1025        &self,
1026        session_id: &acp::SessionId,
1027        _cx: &App,
1028    ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1029        Some(Rc::new(NativeAgentSessionResume {
1030            connection: self.clone(),
1031            session_id: session_id.clone(),
1032        }) as _)
1033    }
1034
1035    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1036        log::info!("Cancelling on session: {}", session_id);
1037        self.0.update(cx, |agent, cx| {
1038            if let Some(agent) = agent.sessions.get(session_id) {
1039                agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1040            }
1041        });
1042    }
1043
1044    fn truncate(
1045        &self,
1046        session_id: &agent_client_protocol::SessionId,
1047        cx: &App,
1048    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1049        self.0.read_with(cx, |agent, _cx| {
1050            agent.sessions.get(session_id).map(|session| {
1051                Rc::new(NativeAgentSessionTruncate {
1052                    thread: session.thread.clone(),
1053                    acp_thread: session.acp_thread.clone(),
1054                }) as _
1055            })
1056        })
1057    }
1058
1059    fn set_title(
1060        &self,
1061        session_id: &acp::SessionId,
1062        _cx: &App,
1063    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1064        Some(Rc::new(NativeAgentSessionSetTitle {
1065            connection: self.clone(),
1066            session_id: session_id.clone(),
1067        }) as _)
1068    }
1069
1070    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1071        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1072    }
1073
1074    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1075        self
1076    }
1077}
1078
1079impl acp_thread::AgentTelemetry for NativeAgentConnection {
1080    fn thread_data(
1081        &self,
1082        session_id: &acp::SessionId,
1083        cx: &mut App,
1084    ) -> Task<Result<serde_json::Value>> {
1085        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1086            return Task::ready(Err(anyhow!("Session not found")));
1087        };
1088
1089        let task = session.thread.read(cx).to_db(cx);
1090        cx.background_spawn(async move {
1091            serde_json::to_value(task.await).context("Failed to serialize thread")
1092        })
1093    }
1094}
1095
1096struct NativeAgentSessionTruncate {
1097    thread: Entity<Thread>,
1098    acp_thread: WeakEntity<AcpThread>,
1099}
1100
1101impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1102    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1103        match self.thread.update(cx, |thread, cx| {
1104            thread.truncate(message_id.clone(), cx)?;
1105            Ok(thread.latest_token_usage())
1106        }) {
1107            Ok(usage) => {
1108                self.acp_thread
1109                    .update(cx, |thread, cx| {
1110                        thread.update_token_usage(usage, cx);
1111                    })
1112                    .ok();
1113                Task::ready(Ok(()))
1114            }
1115            Err(error) => Task::ready(Err(error)),
1116        }
1117    }
1118}
1119
1120struct NativeAgentSessionResume {
1121    connection: NativeAgentConnection,
1122    session_id: acp::SessionId,
1123}
1124
1125impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1126    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1127        self.connection
1128            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1129                thread.update(cx, |thread, cx| thread.resume(cx))
1130            })
1131    }
1132}
1133
1134struct NativeAgentSessionSetTitle {
1135    connection: NativeAgentConnection,
1136    session_id: acp::SessionId,
1137}
1138
1139impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1140    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1141        let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1142            return Task::ready(Err(anyhow!("session not found")));
1143        };
1144        let thread = session.thread.clone();
1145        thread.update(cx, |thread, cx| thread.set_title(title, cx));
1146        Task::ready(Ok(()))
1147    }
1148}
1149
1150pub struct AcpThreadEnvironment {
1151    acp_thread: WeakEntity<AcpThread>,
1152}
1153
1154impl ThreadEnvironment for AcpThreadEnvironment {
1155    fn create_terminal(
1156        &self,
1157        command: String,
1158        cwd: Option<PathBuf>,
1159        output_byte_limit: Option<u64>,
1160        cx: &mut AsyncApp,
1161    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1162        let task = self.acp_thread.update(cx, |thread, cx| {
1163            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1164        });
1165
1166        let acp_thread = self.acp_thread.clone();
1167        cx.spawn(async move |cx| {
1168            let terminal = task?.await?;
1169
1170            let (drop_tx, drop_rx) = oneshot::channel();
1171            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?;
1172
1173            cx.spawn(async move |cx| {
1174                drop_rx.await.ok();
1175                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1176            })
1177            .detach();
1178
1179            let handle = AcpTerminalHandle {
1180                terminal,
1181                _drop_tx: Some(drop_tx),
1182            };
1183
1184            Ok(Rc::new(handle) as _)
1185        })
1186    }
1187}
1188
1189pub struct AcpTerminalHandle {
1190    terminal: Entity<acp_thread::Terminal>,
1191    _drop_tx: Option<oneshot::Sender<()>>,
1192}
1193
1194impl TerminalHandle for AcpTerminalHandle {
1195    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1196        self.terminal.read_with(cx, |term, _cx| term.id().clone())
1197    }
1198
1199    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1200        self.terminal
1201            .read_with(cx, |term, _cx| term.wait_for_exit())
1202    }
1203
1204    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1205        self.terminal
1206            .read_with(cx, |term, cx| term.current_output(cx))
1207    }
1208
1209    fn kill(&self, cx: &AsyncApp) -> Result<()> {
1210        cx.update(|cx| {
1211            self.terminal.update(cx, |terminal, cx| {
1212                terminal.kill(cx);
1213            });
1214        })?;
1215        Ok(())
1216    }
1217}
1218
// Tests that exercise `NativeAgent` / `NativeAgentConnection` end to end and inspect agent
// internals such as `sessions` and `project_context`.
#[cfg(test)]
mod internal_tests {
    use crate::HistoryEntryId;

    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    // The agent's project context should track the project's worktrees, and pick up a `.rules`
    // file created inside a worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project was created without worktrees, so the context starts empty.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        // Adding a worktree is reflected in the project context (no rules file yet).
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    // The model selector should list the test registry's models, grouped by provider name.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                history_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        // Expect exactly one group ("Fake") containing the single fake model.
        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    // Selecting a model should change the session thread's model and overwrite the
    // `agent.default_model` entry in the user settings file.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Seed the settings file with a different default model so the overwrite is observable.
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let pending work (such as the settings-file write triggered by `select_model`)
        // finish before reading the file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    // Round-trips a thread through the history store. Empty threads are not saved. A completed
    // turn is saved under the summary model's output as its title. Reopening the session
    // reproduces the same transcript.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(history_entries(&history_store, cx), vec![]);

        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // Spawn the send so the fake models can be driven while the turn is in flight.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Finish the assistant reply, then the summary request whose output becomes the title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            history_entries(&history_store, cx),
            vec![(
                HistoryEntryId::AcpThread(session_id.clone()),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    // Returns `(id, title)` for every entry currently in the history store, in `entries()` order.
    fn history_entries(
        history: &Entity<HistoryStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(HistoryEntryId, String)> {
        history.read_with(cx, |history, _| {
            history
                .entries()
                .map(|e| (e.id(), e.title().to_string()))
                .collect::<Vec<_>>()
        })
    }

    // Shared test setup: logging (ignoring "already initialized"), a test settings store,
    // and a test language model registry.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);

            LanguageModelRegistry::test(cx);
        });
    }
}