agent.rs

   1mod db;
   2mod edit_agent;
   3mod history_store;
   4mod legacy_thread;
   5mod native_agent_server;
   6pub mod outline;
   7mod templates;
   8mod thread;
   9mod tools;
  10
  11#[cfg(test)]
  12mod tests;
  13
  14pub use db::*;
  15pub use history_store::*;
  16pub use native_agent_server::NativeAgentServer;
  17pub use templates::*;
  18pub use thread::*;
  19pub use tools::*;
  20
  21use acp_thread::{AcpThread, AgentModelIcon, AgentModelSelector};
  22use agent_client_protocol as acp;
  23use anyhow::{Context as _, Result, anyhow};
  24use chrono::{DateTime, Utc};
  25use collections::{HashSet, IndexMap};
  26use fs::Fs;
  27use futures::channel::{mpsc, oneshot};
  28use futures::future::Shared;
  29use futures::{StreamExt, future};
  30use gpui::{
  31    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  32};
  33use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  34use project::{Project, ProjectItem, ProjectPath, Worktree};
  35use prompt_store::{
  36    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  37    WorktreeContext,
  38};
  39use serde::{Deserialize, Serialize};
  40use settings::{LanguageModelSelection, update_settings_file};
  41use std::any::Any;
  42use std::collections::HashMap;
  43use std::path::{Path, PathBuf};
  44use std::rc::Rc;
  45use std::sync::Arc;
  46use util::ResultExt;
  47use util::rel_path::RelPath;
  48
  49#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
  50pub struct ProjectSnapshot {
  51    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
  52    pub timestamp: DateTime<Utc>,
  53}
  54
  55pub struct RulesLoadingError {
  56    pub message: SharedString,
  57}
  58
  59/// Holds both the internal Thread and the AcpThread for a session
  60struct Session {
  61    /// The internal thread that processes messages
  62    thread: Entity<Thread>,
  63    /// The ACP thread that handles protocol communication
  64    acp_thread: WeakEntity<acp_thread::AcpThread>,
  65    pending_save: Task<()>,
  66    _subscriptions: Vec<Subscription>,
  67}
  68
  69pub struct LanguageModels {
  70    /// Access language model by ID
  71    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
  72    /// Cached list for returning language model information
  73    model_list: acp_thread::AgentModelList,
  74    refresh_models_rx: watch::Receiver<()>,
  75    refresh_models_tx: watch::Sender<()>,
  76    _authenticate_all_providers_task: Task<()>,
  77}
  78
  79impl LanguageModels {
  80    fn new(cx: &mut App) -> Self {
  81        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  82
  83        let mut this = Self {
  84            models: HashMap::default(),
  85            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  86            refresh_models_rx,
  87            refresh_models_tx,
  88            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
  89        };
  90        this.refresh_list(cx);
  91        this
  92    }
  93
  94    fn refresh_list(&mut self, cx: &App) {
  95        let providers = LanguageModelRegistry::global(cx)
  96            .read(cx)
  97            .visible_providers()
  98            .into_iter()
  99            .filter(|provider| provider.is_authenticated(cx))
 100            .collect::<Vec<_>>();
 101
 102        let mut language_model_list = IndexMap::default();
 103        let mut recommended_models = HashSet::default();
 104
 105        let mut recommended = Vec::new();
 106        for provider in &providers {
 107            for model in provider.recommended_models(cx) {
 108                recommended_models.insert((model.provider_id(), model.id()));
 109                recommended.push(Self::map_language_model_to_info(&model, provider));
 110            }
 111        }
 112        if !recommended.is_empty() {
 113            language_model_list.insert(
 114                acp_thread::AgentModelGroupName("Recommended".into()),
 115                recommended,
 116            );
 117        }
 118
 119        let mut models = HashMap::default();
 120        for provider in providers {
 121            let mut provider_models = Vec::new();
 122            for model in provider.provided_models(cx) {
 123                let model_info = Self::map_language_model_to_info(&model, &provider);
 124                let model_id = model_info.id.clone();
 125                provider_models.push(model_info);
 126                models.insert(model_id, model);
 127            }
 128            if !provider_models.is_empty() {
 129                language_model_list.insert(
 130                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 131                    provider_models,
 132                );
 133            }
 134        }
 135
 136        self.models = models;
 137        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 138        self.refresh_models_tx.send(()).ok();
 139    }
 140
 141    fn watch(&self) -> watch::Receiver<()> {
 142        self.refresh_models_rx.clone()
 143    }
 144
 145    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 146        self.models.get(model_id).cloned()
 147    }
 148
 149    fn map_language_model_to_info(
 150        model: &Arc<dyn LanguageModel>,
 151        provider: &Arc<dyn LanguageModelProvider>,
 152    ) -> acp_thread::AgentModelInfo {
 153        let icon = if let Some(path) = provider.icon_path() {
 154            Some(AgentModelIcon::Path(path))
 155        } else {
 156            Some(AgentModelIcon::Named(provider.icon()))
 157        };
 158        acp_thread::AgentModelInfo {
 159            id: Self::model_id(model),
 160            name: model.name().0,
 161            description: None,
 162            icon,
 163        }
 164    }
 165
 166    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 167        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 168    }
 169
 170    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 171        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 172            .read(cx)
 173            .providers()
 174            .iter()
 175            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 176            .collect::<Vec<_>>();
 177
 178        cx.background_spawn(async move {
 179            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 180                if let Err(err) = authenticate_task.await {
 181                    match err {
 182                        language_model::AuthenticateError::CredentialsNotFound => {
 183                            // Since we're authenticating these providers in the
 184                            // background for the purposes of populating the
 185                            // language selector, we don't care about providers
 186                            // where the credentials are not found.
 187                        }
 188                        language_model::AuthenticateError::ConnectionRefused => {
 189                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 190                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 191                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 192                        }
 193                        _ => {
 194                            // Some providers have noisy failure states that we
 195                            // don't want to spam the logs with every time the
 196                            // language model selector is initialized.
 197                            //
 198                            // Ideally these should have more clear failure modes
 199                            // that we know are safe to ignore here, like what we do
 200                            // with `CredentialsNotFound` above.
 201                            match provider_id.0.as_ref() {
 202                                "lmstudio" | "ollama" => {
 203                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 204                                    //
 205                                    // These fail noisily, so we don't log them.
 206                                }
 207                                "copilot_chat" => {
 208                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 209                                }
 210                                _ => {
 211                                    log::error!(
 212                                        "Failed to authenticate provider: {}: {err:#}",
 213                                        provider_name.0
 214                                    );
 215                                }
 216                            }
 217                        }
 218                    }
 219                }
 220            }
 221        })
 222    }
 223}
 224
 225pub struct NativeAgent {
 226    /// Session ID -> Session mapping
 227    sessions: HashMap<acp::SessionId, Session>,
 228    history: Entity<HistoryStore>,
 229    /// Shared project context for all threads
 230    project_context: Entity<ProjectContext>,
 231    project_context_needs_refresh: watch::Sender<()>,
 232    _maintain_project_context: Task<Result<()>>,
 233    context_server_registry: Entity<ContextServerRegistry>,
 234    /// Shared templates for all threads
 235    templates: Arc<Templates>,
 236    /// Cached model information
 237    models: LanguageModels,
 238    project: Entity<Project>,
 239    prompt_store: Option<Entity<PromptStore>>,
 240    fs: Arc<dyn Fs>,
 241    _subscriptions: Vec<Subscription>,
 242}
 243
 244impl NativeAgent {
 245    pub async fn new(
 246        project: Entity<Project>,
 247        history: Entity<HistoryStore>,
 248        templates: Arc<Templates>,
 249        prompt_store: Option<Entity<PromptStore>>,
 250        fs: Arc<dyn Fs>,
 251        cx: &mut AsyncApp,
 252    ) -> Result<Entity<NativeAgent>> {
 253        log::debug!("Creating new NativeAgent");
 254
 255        let project_context = cx
 256            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
 257            .await;
 258
 259        cx.new(|cx| {
 260            let mut subscriptions = vec![
 261                cx.subscribe(&project, Self::handle_project_event),
 262                cx.subscribe(
 263                    &LanguageModelRegistry::global(cx),
 264                    Self::handle_models_updated_event,
 265                ),
 266            ];
 267            if let Some(prompt_store) = prompt_store.as_ref() {
 268                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 269            }
 270
 271            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 272                watch::channel(());
 273            Self {
 274                sessions: HashMap::new(),
 275                history,
 276                project_context: cx.new(|_| project_context),
 277                project_context_needs_refresh: project_context_needs_refresh_tx,
 278                _maintain_project_context: cx.spawn(async move |this, cx| {
 279                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 280                }),
 281                context_server_registry: cx.new(|cx| {
 282                    ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
 283                }),
 284                templates,
 285                models: LanguageModels::new(cx),
 286                project,
 287                prompt_store,
 288                fs,
 289                _subscriptions: subscriptions,
 290            }
 291        })
 292    }
 293
 294    fn register_session(
 295        &mut self,
 296        thread_handle: Entity<Thread>,
 297        cx: &mut Context<Self>,
 298    ) -> Entity<AcpThread> {
 299        let connection = Rc::new(NativeAgentConnection(cx.entity()));
 300
 301        let thread = thread_handle.read(cx);
 302        let session_id = thread.id().clone();
 303        let title = thread.title();
 304        let project = thread.project.clone();
 305        let action_log = thread.action_log.clone();
 306        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
 307        let acp_thread = cx.new(|cx| {
 308            acp_thread::AcpThread::new(
 309                title,
 310                connection,
 311                project.clone(),
 312                action_log.clone(),
 313                session_id.clone(),
 314                prompt_capabilities_rx,
 315                cx,
 316            )
 317        });
 318
 319        let registry = LanguageModelRegistry::read_global(cx);
 320        let summarization_model = registry.thread_summary_model().map(|c| c.model);
 321
 322        thread_handle.update(cx, |thread, cx| {
 323            thread.set_summarization_model(summarization_model, cx);
 324            thread.add_default_tools(
 325                Rc::new(AcpThreadEnvironment {
 326                    acp_thread: acp_thread.downgrade(),
 327                }) as _,
 328                cx,
 329            )
 330        });
 331
 332        let subscriptions = vec![
 333            cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
 334                this.sessions.remove(acp_thread.session_id());
 335            }),
 336            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
 337            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
 338            cx.observe(&thread_handle, move |this, thread, cx| {
 339                this.save_thread(thread, cx)
 340            }),
 341        ];
 342
 343        self.sessions.insert(
 344            session_id,
 345            Session {
 346                thread: thread_handle,
 347                acp_thread: acp_thread.downgrade(),
 348                _subscriptions: subscriptions,
 349                pending_save: Task::ready(()),
 350            },
 351        );
 352        acp_thread
 353    }
 354
 355    pub fn models(&self) -> &LanguageModels {
 356        &self.models
 357    }
 358
 359    async fn maintain_project_context(
 360        this: WeakEntity<Self>,
 361        mut needs_refresh: watch::Receiver<()>,
 362        cx: &mut AsyncApp,
 363    ) -> Result<()> {
 364        while needs_refresh.changed().await.is_ok() {
 365            let project_context = this
 366                .update(cx, |this, cx| {
 367                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 368                })?
 369                .await;
 370            this.update(cx, |this, cx| {
 371                this.project_context = cx.new(|_| project_context);
 372            })?;
 373        }
 374
 375        Ok(())
 376    }
 377
 378    fn build_project_context(
 379        project: &Entity<Project>,
 380        prompt_store: Option<&Entity<PromptStore>>,
 381        cx: &mut App,
 382    ) -> Task<ProjectContext> {
 383        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 384        let worktree_tasks = worktrees
 385            .into_iter()
 386            .map(|worktree| {
 387                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 388            })
 389            .collect::<Vec<_>>();
 390        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 391            prompt_store.read_with(cx, |prompt_store, cx| {
 392                let prompts = prompt_store.default_prompt_metadata();
 393                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 394                    let contents = prompt_store.load(prompt_metadata.id, cx);
 395                    async move { (contents.await, prompt_metadata) }
 396                });
 397                cx.background_spawn(future::join_all(load_tasks))
 398            })
 399        } else {
 400            Task::ready(vec![])
 401        };
 402
 403        cx.spawn(async move |_cx| {
 404            let (worktrees, default_user_rules) =
 405                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 406
 407            let worktrees = worktrees
 408                .into_iter()
 409                .map(|(worktree, _rules_error)| {
 410                    // TODO: show error message
 411                    // if let Some(rules_error) = rules_error {
 412                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 413                    // }
 414                    worktree
 415                })
 416                .collect::<Vec<_>>();
 417
 418            let default_user_rules = default_user_rules
 419                .into_iter()
 420                .flat_map(|(contents, prompt_metadata)| match contents {
 421                    Ok(contents) => Some(UserRulesContext {
 422                        uuid: match prompt_metadata.id {
 423                            prompt_store::PromptId::User { uuid } => uuid,
 424                            prompt_store::PromptId::EditWorkflow => return None,
 425                        },
 426                        title: prompt_metadata.title.map(|title| title.to_string()),
 427                        contents,
 428                    }),
 429                    Err(_err) => {
 430                        // TODO: show error message
 431                        // this.update(cx, |_, cx| {
 432                        //     cx.emit(RulesLoadingError {
 433                        //         message: format!("{err:?}").into(),
 434                        //     });
 435                        // })
 436                        // .ok();
 437                        None
 438                    }
 439                })
 440                .collect::<Vec<_>>();
 441
 442            ProjectContext::new(worktrees, default_user_rules)
 443        })
 444    }
 445
 446    fn load_worktree_info_for_system_prompt(
 447        worktree: Entity<Worktree>,
 448        project: Entity<Project>,
 449        cx: &mut App,
 450    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 451        let tree = worktree.read(cx);
 452        let root_name = tree.root_name_str().into();
 453        let abs_path = tree.abs_path();
 454
 455        let mut context = WorktreeContext {
 456            root_name,
 457            abs_path,
 458            rules_file: None,
 459        };
 460
 461        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 462        let Some(rules_task) = rules_task else {
 463            return Task::ready((context, None));
 464        };
 465
 466        cx.spawn(async move |_| {
 467            let (rules_file, rules_file_error) = match rules_task.await {
 468                Ok(rules_file) => (Some(rules_file), None),
 469                Err(err) => (
 470                    None,
 471                    Some(RulesLoadingError {
 472                        message: format!("{err}").into(),
 473                    }),
 474                ),
 475            };
 476            context.rules_file = rules_file;
 477            (context, rules_file_error)
 478        })
 479    }
 480
 481    fn load_worktree_rules_file(
 482        worktree: Entity<Worktree>,
 483        project: Entity<Project>,
 484        cx: &mut App,
 485    ) -> Option<Task<Result<RulesFileContext>>> {
 486        let worktree = worktree.read(cx);
 487        let worktree_id = worktree.id();
 488        let selected_rules_file = RULES_FILE_NAMES
 489            .into_iter()
 490            .filter_map(|name| {
 491                worktree
 492                    .entry_for_path(RelPath::unix(name).unwrap())
 493                    .filter(|entry| entry.is_file())
 494                    .map(|entry| entry.path.clone())
 495            })
 496            .next();
 497
 498        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 499        // supported. This doesn't seem to occur often in GitHub repositories.
 500        selected_rules_file.map(|path_in_worktree| {
 501            let project_path = ProjectPath {
 502                worktree_id,
 503                path: path_in_worktree.clone(),
 504            };
 505            let buffer_task =
 506                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 507            let rope_task = cx.spawn(async move |cx| {
 508                buffer_task.await?.read_with(cx, |buffer, cx| {
 509                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 510                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 511                })?
 512            });
 513            // Build a string from the rope on a background thread.
 514            cx.background_spawn(async move {
 515                let (project_entry_id, rope) = rope_task.await?;
 516                anyhow::Ok(RulesFileContext {
 517                    path_in_worktree,
 518                    text: rope.to_string().trim().to_string(),
 519                    project_entry_id: project_entry_id.to_usize(),
 520                })
 521            })
 522        })
 523    }
 524
 525    fn handle_thread_title_updated(
 526        &mut self,
 527        thread: Entity<Thread>,
 528        _: &TitleUpdated,
 529        cx: &mut Context<Self>,
 530    ) {
 531        let session_id = thread.read(cx).id();
 532        let Some(session) = self.sessions.get(session_id) else {
 533            return;
 534        };
 535        let thread = thread.downgrade();
 536        let acp_thread = session.acp_thread.clone();
 537        cx.spawn(async move |_, cx| {
 538            let title = thread.read_with(cx, |thread, _| thread.title())?;
 539            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 540            task.await
 541        })
 542        .detach_and_log_err(cx);
 543    }
 544
 545    fn handle_thread_token_usage_updated(
 546        &mut self,
 547        thread: Entity<Thread>,
 548        usage: &TokenUsageUpdated,
 549        cx: &mut Context<Self>,
 550    ) {
 551        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 552            return;
 553        };
 554        session
 555            .acp_thread
 556            .update(cx, |acp_thread, cx| {
 557                acp_thread.update_token_usage(usage.0.clone(), cx);
 558            })
 559            .ok();
 560    }
 561
 562    fn handle_project_event(
 563        &mut self,
 564        _project: Entity<Project>,
 565        event: &project::Event,
 566        _cx: &mut Context<Self>,
 567    ) {
 568        match event {
 569            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 570                self.project_context_needs_refresh.send(()).ok();
 571            }
 572            project::Event::WorktreeUpdatedEntries(_, items) => {
 573                if items.iter().any(|(path, _, _)| {
 574                    RULES_FILE_NAMES
 575                        .iter()
 576                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 577                }) {
 578                    self.project_context_needs_refresh.send(()).ok();
 579                }
 580            }
 581            _ => {}
 582        }
 583    }
 584
 585    fn handle_prompts_updated_event(
 586        &mut self,
 587        _prompt_store: Entity<PromptStore>,
 588        _event: &prompt_store::PromptsUpdatedEvent,
 589        _cx: &mut Context<Self>,
 590    ) {
 591        self.project_context_needs_refresh.send(()).ok();
 592    }
 593
 594    fn handle_models_updated_event(
 595        &mut self,
 596        _registry: Entity<LanguageModelRegistry>,
 597        _event: &language_model::Event,
 598        cx: &mut Context<Self>,
 599    ) {
 600        self.models.refresh_list(cx);
 601
 602        let registry = LanguageModelRegistry::read_global(cx);
 603        let default_model = registry.default_model().map(|m| m.model);
 604        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 605
 606        for session in self.sessions.values_mut() {
 607            session.thread.update(cx, |thread, cx| {
 608                if thread.model().is_none()
 609                    && let Some(model) = default_model.clone()
 610                {
 611                    thread.set_model(model, cx);
 612                    cx.notify();
 613                }
 614                thread.set_summarization_model(summarization_model.clone(), cx);
 615            });
 616        }
 617    }
 618
 619    pub fn load_thread(
 620        &mut self,
 621        id: acp::SessionId,
 622        cx: &mut Context<Self>,
 623    ) -> Task<Result<Entity<Thread>>> {
 624        let database_future = ThreadsDatabase::connect(cx);
 625        cx.spawn(async move |this, cx| {
 626            let database = database_future.await.map_err(|err| anyhow!(err))?;
 627            let db_thread = database
 628                .load_thread(id.clone())
 629                .await?
 630                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 631
 632            this.update(cx, |this, cx| {
 633                let summarization_model = LanguageModelRegistry::read_global(cx)
 634                    .thread_summary_model()
 635                    .map(|c| c.model);
 636
 637                cx.new(|cx| {
 638                    let mut thread = Thread::from_db(
 639                        id.clone(),
 640                        db_thread,
 641                        this.project.clone(),
 642                        this.project_context.clone(),
 643                        this.context_server_registry.clone(),
 644                        this.templates.clone(),
 645                        cx,
 646                    );
 647                    thread.set_summarization_model(summarization_model, cx);
 648                    thread
 649                })
 650            })
 651        })
 652    }
 653
 654    pub fn open_thread(
 655        &mut self,
 656        id: acp::SessionId,
 657        cx: &mut Context<Self>,
 658    ) -> Task<Result<Entity<AcpThread>>> {
 659        let task = self.load_thread(id, cx);
 660        cx.spawn(async move |this, cx| {
 661            let thread = task.await?;
 662            let acp_thread =
 663                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
 664            let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
 665            cx.update(|cx| {
 666                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 667            })?
 668            .await?;
 669            Ok(acp_thread)
 670        })
 671    }
 672
 673    pub fn thread_summary(
 674        &mut self,
 675        id: acp::SessionId,
 676        cx: &mut Context<Self>,
 677    ) -> Task<Result<SharedString>> {
 678        let thread = self.open_thread(id.clone(), cx);
 679        cx.spawn(async move |this, cx| {
 680            let acp_thread = thread.await?;
 681            let result = this
 682                .update(cx, |this, cx| {
 683                    this.sessions
 684                        .get(&id)
 685                        .unwrap()
 686                        .thread
 687                        .update(cx, |thread, cx| thread.summary(cx))
 688                })?
 689                .await
 690                .context("Failed to generate summary")?;
 691            drop(acp_thread);
 692            Ok(result)
 693        })
 694    }
 695
 696    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 697        if thread.read(cx).is_empty() {
 698            return;
 699        }
 700
 701        let database_future = ThreadsDatabase::connect(cx);
 702        let (id, db_thread) =
 703            thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
 704        let Some(session) = self.sessions.get_mut(&id) else {
 705            return;
 706        };
 707        let history = self.history.clone();
 708        session.pending_save = cx.spawn(async move |_, cx| {
 709            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 710                return;
 711            };
 712            let db_thread = db_thread.await;
 713            database.save_thread(id, db_thread).await.log_err();
 714            history.update(cx, |history, cx| history.reload(cx)).ok();
 715        });
 716    }
 717}
 718
 719/// Wrapper struct that implements the AgentConnection trait
 720#[derive(Clone)]
 721pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 722
 723impl NativeAgentConnection {
 724    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 725        self.0
 726            .read(cx)
 727            .sessions
 728            .get(session_id)
 729            .map(|session| session.thread.clone())
 730    }
 731
 732    pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
 733        self.0.update(cx, |this, cx| this.load_thread(id, cx))
 734    }
 735
 736    fn run_turn(
 737        &self,
 738        session_id: acp::SessionId,
 739        cx: &mut App,
 740        f: impl 'static
 741        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
 742    ) -> Task<Result<acp::PromptResponse>> {
 743        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 744            agent
 745                .sessions
 746                .get_mut(&session_id)
 747                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
 748        }) else {
 749            return Task::ready(Err(anyhow!("Session not found")));
 750        };
 751        log::debug!("Found session for: {}", session_id);
 752
 753        let response_stream = match f(thread, cx) {
 754            Ok(stream) => stream,
 755            Err(err) => return Task::ready(Err(err)),
 756        };
 757        Self::handle_thread_events(response_stream, acp_thread, cx)
 758    }
 759
 760    fn handle_thread_events(
 761        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
 762        acp_thread: WeakEntity<AcpThread>,
 763        cx: &App,
 764    ) -> Task<Result<acp::PromptResponse>> {
 765        cx.spawn(async move |cx| {
 766            // Handle response stream and forward to session.acp_thread
 767            while let Some(result) = events.next().await {
 768                match result {
 769                    Ok(event) => {
 770                        log::trace!("Received completion event: {:?}", event);
 771
 772                        match event {
 773                            ThreadEvent::UserMessage(message) => {
 774                                acp_thread.update(cx, |thread, cx| {
 775                                    for content in message.content {
 776                                        thread.push_user_content_block(
 777                                            Some(message.id.clone()),
 778                                            content.into(),
 779                                            cx,
 780                                        );
 781                                    }
 782                                })?;
 783                            }
 784                            ThreadEvent::AgentText(text) => {
 785                                acp_thread.update(cx, |thread, cx| {
 786                                    thread.push_assistant_content_block(text.into(), false, cx)
 787                                })?;
 788                            }
 789                            ThreadEvent::AgentThinking(text) => {
 790                                acp_thread.update(cx, |thread, cx| {
 791                                    thread.push_assistant_content_block(text.into(), true, cx)
 792                                })?;
 793                            }
 794                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
 795                                tool_call,
 796                                options,
 797                                response,
 798                            }) => {
 799                                let outcome_task = acp_thread.update(cx, |thread, cx| {
 800                                    thread.request_tool_call_authorization(
 801                                        tool_call, options, true, cx,
 802                                    )
 803                                })??;
 804                                cx.background_spawn(async move {
 805                                    if let acp::RequestPermissionOutcome::Selected(
 806                                        acp::SelectedPermissionOutcome { option_id, .. },
 807                                    ) = outcome_task.await
 808                                    {
 809                                        response
 810                                            .send(option_id)
 811                                            .map(|_| anyhow!("authorization receiver was dropped"))
 812                                            .log_err();
 813                                    }
 814                                })
 815                                .detach();
 816                            }
 817                            ThreadEvent::ToolCall(tool_call) => {
 818                                acp_thread.update(cx, |thread, cx| {
 819                                    thread.upsert_tool_call(tool_call, cx)
 820                                })??;
 821                            }
 822                            ThreadEvent::ToolCallUpdate(update) => {
 823                                acp_thread.update(cx, |thread, cx| {
 824                                    thread.update_tool_call(update, cx)
 825                                })??;
 826                            }
 827                            ThreadEvent::Retry(status) => {
 828                                acp_thread.update(cx, |thread, cx| {
 829                                    thread.update_retry_status(status, cx)
 830                                })?;
 831                            }
 832                            ThreadEvent::Stop(stop_reason) => {
 833                                log::debug!("Assistant message complete: {:?}", stop_reason);
 834                                return Ok(acp::PromptResponse::new(stop_reason));
 835                            }
 836                        }
 837                    }
 838                    Err(e) => {
 839                        log::error!("Error in model response stream: {:?}", e);
 840                        return Err(e);
 841                    }
 842                }
 843            }
 844
 845            log::debug!("Response stream completed");
 846            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
 847        })
 848    }
 849}
 850
 851struct NativeAgentModelSelector {
 852    session_id: acp::SessionId,
 853    connection: NativeAgentConnection,
 854}
 855
 856impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
 857    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
 858        log::debug!("NativeAgentConnection::list_models called");
 859        let list = self.connection.0.read(cx).models.model_list.clone();
 860        Task::ready(if list.is_empty() {
 861            Err(anyhow::anyhow!("No models available"))
 862        } else {
 863            Ok(list)
 864        })
 865    }
 866
 867    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
 868        log::debug!(
 869            "Setting model for session {}: {}",
 870            self.session_id,
 871            model_id
 872        );
 873        let Some(thread) = self
 874            .connection
 875            .0
 876            .read(cx)
 877            .sessions
 878            .get(&self.session_id)
 879            .map(|session| session.thread.clone())
 880        else {
 881            return Task::ready(Err(anyhow!("Session not found")));
 882        };
 883
 884        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
 885            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
 886        };
 887
 888        thread.update(cx, |thread, cx| {
 889            thread.set_model(model.clone(), cx);
 890        });
 891
 892        update_settings_file(
 893            self.connection.0.read(cx).fs.clone(),
 894            cx,
 895            move |settings, _cx| {
 896                let provider = model.provider_id().0.to_string();
 897                let model = model.id().0.to_string();
 898                settings
 899                    .agent
 900                    .get_or_insert_default()
 901                    .set_model(LanguageModelSelection {
 902                        provider: provider.into(),
 903                        model,
 904                    });
 905            },
 906        );
 907
 908        Task::ready(Ok(()))
 909    }
 910
 911    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
 912        let Some(thread) = self
 913            .connection
 914            .0
 915            .read(cx)
 916            .sessions
 917            .get(&self.session_id)
 918            .map(|session| session.thread.clone())
 919        else {
 920            return Task::ready(Err(anyhow!("Session not found")));
 921        };
 922        let Some(model) = thread.read(cx).model() else {
 923            return Task::ready(Err(anyhow!("Model not found")));
 924        };
 925        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
 926        else {
 927            return Task::ready(Err(anyhow!("Provider not found")));
 928        };
 929        Task::ready(Ok(LanguageModels::map_language_model_to_info(
 930            model, &provider,
 931        )))
 932    }
 933
 934    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
 935        Some(self.connection.0.read(cx).models.watch())
 936    }
 937
 938    fn should_render_footer(&self) -> bool {
 939        true
 940    }
 941}
 942
 943impl acp_thread::AgentConnection for NativeAgentConnection {
 944    fn telemetry_id(&self) -> SharedString {
 945        "zed".into()
 946    }
 947
 948    fn new_thread(
 949        self: Rc<Self>,
 950        project: Entity<Project>,
 951        cwd: &Path,
 952        cx: &mut App,
 953    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
 954        let agent = self.0.clone();
 955        log::debug!("Creating new thread for project at: {:?}", cwd);
 956
 957        cx.spawn(async move |cx| {
 958            log::debug!("Starting thread creation in async context");
 959
 960            // Create Thread
 961            let thread = agent.update(
 962                cx,
 963                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
 964                    // Fetch default model from registry settings
 965                    let registry = LanguageModelRegistry::read_global(cx);
 966                    // Log available models for debugging
 967                    let available_count = registry.available_models(cx).count();
 968                    log::debug!("Total available models: {}", available_count);
 969
 970                    let default_model = registry.default_model().and_then(|default_model| {
 971                        agent
 972                            .models
 973                            .model_from_id(&LanguageModels::model_id(&default_model.model))
 974                    });
 975                    Ok(cx.new(|cx| {
 976                        Thread::new(
 977                            project.clone(),
 978                            agent.project_context.clone(),
 979                            agent.context_server_registry.clone(),
 980                            agent.templates.clone(),
 981                            default_model,
 982                            cx,
 983                        )
 984                    }))
 985                },
 986            )??;
 987            agent.update(cx, |agent, cx| agent.register_session(thread, cx))
 988        })
 989    }
 990
 991    fn auth_methods(&self) -> &[acp::AuthMethod] {
 992        &[] // No auth for in-process
 993    }
 994
 995    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 996        Task::ready(Ok(()))
 997    }
 998
 999    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1000        Some(Rc::new(NativeAgentModelSelector {
1001            session_id: session_id.clone(),
1002            connection: self.clone(),
1003        }) as Rc<dyn AgentModelSelector>)
1004    }
1005
1006    fn prompt(
1007        &self,
1008        id: Option<acp_thread::UserMessageId>,
1009        params: acp::PromptRequest,
1010        cx: &mut App,
1011    ) -> Task<Result<acp::PromptResponse>> {
1012        let id = id.expect("UserMessageId is required");
1013        let session_id = params.session_id.clone();
1014        log::info!("Received prompt request for session: {}", session_id);
1015        log::debug!("Prompt blocks count: {}", params.prompt.len());
1016        let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1017
1018        self.run_turn(session_id, cx, move |thread, cx| {
1019            let content: Vec<UserMessageContent> = params
1020                .prompt
1021                .into_iter()
1022                .map(|block| UserMessageContent::from_content_block(block, path_style))
1023                .collect::<Vec<_>>();
1024            log::debug!("Converted prompt to message: {} chars", content.len());
1025            log::debug!("Message id: {:?}", id);
1026            log::debug!("Message content: {:?}", content);
1027
1028            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1029        })
1030    }
1031
1032    fn resume(
1033        &self,
1034        session_id: &acp::SessionId,
1035        _cx: &App,
1036    ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1037        Some(Rc::new(NativeAgentSessionResume {
1038            connection: self.clone(),
1039            session_id: session_id.clone(),
1040        }) as _)
1041    }
1042
1043    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1044        log::info!("Cancelling on session: {}", session_id);
1045        self.0.update(cx, |agent, cx| {
1046            if let Some(agent) = agent.sessions.get(session_id) {
1047                agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1048            }
1049        });
1050    }
1051
1052    fn truncate(
1053        &self,
1054        session_id: &agent_client_protocol::SessionId,
1055        cx: &App,
1056    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1057        self.0.read_with(cx, |agent, _cx| {
1058            agent.sessions.get(session_id).map(|session| {
1059                Rc::new(NativeAgentSessionTruncate {
1060                    thread: session.thread.clone(),
1061                    acp_thread: session.acp_thread.clone(),
1062                }) as _
1063            })
1064        })
1065    }
1066
1067    fn set_title(
1068        &self,
1069        session_id: &acp::SessionId,
1070        _cx: &App,
1071    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1072        Some(Rc::new(NativeAgentSessionSetTitle {
1073            connection: self.clone(),
1074            session_id: session_id.clone(),
1075        }) as _)
1076    }
1077
1078    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1079        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1080    }
1081
1082    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1083        self
1084    }
1085}
1086
1087impl acp_thread::AgentTelemetry for NativeAgentConnection {
1088    fn thread_data(
1089        &self,
1090        session_id: &acp::SessionId,
1091        cx: &mut App,
1092    ) -> Task<Result<serde_json::Value>> {
1093        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1094            return Task::ready(Err(anyhow!("Session not found")));
1095        };
1096
1097        let task = session.thread.read(cx).to_db(cx);
1098        cx.background_spawn(async move {
1099            serde_json::to_value(task.await).context("Failed to serialize thread")
1100        })
1101    }
1102}
1103
1104struct NativeAgentSessionTruncate {
1105    thread: Entity<Thread>,
1106    acp_thread: WeakEntity<AcpThread>,
1107}
1108
1109impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1110    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1111        match self.thread.update(cx, |thread, cx| {
1112            thread.truncate(message_id.clone(), cx)?;
1113            Ok(thread.latest_token_usage())
1114        }) {
1115            Ok(usage) => {
1116                self.acp_thread
1117                    .update(cx, |thread, cx| {
1118                        thread.update_token_usage(usage, cx);
1119                    })
1120                    .ok();
1121                Task::ready(Ok(()))
1122            }
1123            Err(error) => Task::ready(Err(error)),
1124        }
1125    }
1126}
1127
1128struct NativeAgentSessionResume {
1129    connection: NativeAgentConnection,
1130    session_id: acp::SessionId,
1131}
1132
1133impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1134    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1135        self.connection
1136            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1137                thread.update(cx, |thread, cx| thread.resume(cx))
1138            })
1139    }
1140}
1141
1142struct NativeAgentSessionSetTitle {
1143    connection: NativeAgentConnection,
1144    session_id: acp::SessionId,
1145}
1146
1147impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1148    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1149        let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1150            return Task::ready(Err(anyhow!("session not found")));
1151        };
1152        let thread = session.thread.clone();
1153        thread.update(cx, |thread, cx| thread.set_title(title, cx));
1154        Task::ready(Ok(()))
1155    }
1156}
1157
1158pub struct AcpThreadEnvironment {
1159    acp_thread: WeakEntity<AcpThread>,
1160}
1161
1162impl ThreadEnvironment for AcpThreadEnvironment {
1163    fn create_terminal(
1164        &self,
1165        command: String,
1166        cwd: Option<PathBuf>,
1167        output_byte_limit: Option<u64>,
1168        cx: &mut AsyncApp,
1169    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1170        let task = self.acp_thread.update(cx, |thread, cx| {
1171            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1172        });
1173
1174        let acp_thread = self.acp_thread.clone();
1175        cx.spawn(async move |cx| {
1176            let terminal = task?.await?;
1177
1178            let (drop_tx, drop_rx) = oneshot::channel();
1179            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?;
1180
1181            cx.spawn(async move |cx| {
1182                drop_rx.await.ok();
1183                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1184            })
1185            .detach();
1186
1187            let handle = AcpTerminalHandle {
1188                terminal,
1189                _drop_tx: Some(drop_tx),
1190            };
1191
1192            Ok(Rc::new(handle) as _)
1193        })
1194    }
1195}
1196
1197pub struct AcpTerminalHandle {
1198    terminal: Entity<acp_thread::Terminal>,
1199    _drop_tx: Option<oneshot::Sender<()>>,
1200}
1201
1202impl TerminalHandle for AcpTerminalHandle {
1203    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1204        self.terminal.read_with(cx, |term, _cx| term.id().clone())
1205    }
1206
1207    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1208        self.terminal
1209            .read_with(cx, |term, _cx| term.wait_for_exit())
1210    }
1211
1212    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1213        self.terminal
1214            .read_with(cx, |term, cx| term.current_output(cx))
1215    }
1216
1217    fn kill(&self, cx: &AsyncApp) -> Result<()> {
1218        cx.update(|cx| {
1219            self.terminal.update(cx, |terminal, cx| {
1220                terminal.kill(cx);
1221            });
1222        })?;
1223        Ok(())
1224    }
1225}
1226
#[cfg(test)]
mod internal_tests {
    // Tests that exercise `NativeAgent` / `NativeAgentConnection` directly, inspecting agent
    // state such as `sessions` and `project_context`.
    use crate::HistoryEntryId;

    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    // The agent's `project_context` should mirror the project's worktrees: it starts empty,
    // gains an entry when a worktree is added, and picks up a `.rules` file created later.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // No worktrees yet, so the context is empty.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        // Adding a worktree adds a matching entry, with no rules file yet.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    // With the test `LanguageModelRegistry` installed by `init_test`, the model selector of a
    // fresh session should list exactly the fake provider's model, grouped under "Fake".
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                history_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(AgentModelIcon::Named(ui::IconName::ZedAssistant)),
                }]
            )])
        );
    }

    // Selecting a model through the selector should switch the session's thread to that model
    // and overwrite the pre-existing `agent.default_model` entry in the user settings file.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed the settings file with a different model so the overwrite is observable.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let the settings-file write triggered by `select_model` complete.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    // Empty threads must not reach history even after their models are changed. After one
    // completed turn the thread is listed under its generated summary title; dropping the
    // ACP thread removes the session, and reopening it from history yields the same content.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(history_entries(&history_store, cx), vec![]);

        // Send a prompt that mentions `/a/b.md` via a resource link.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Complete the assistant reply, then the summary used as the history title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            history_entries(&history_store, cx),
            vec![(
                HistoryEntryId::AcpThread(session_id.clone()),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    // Returns the `(id, title)` pair of every entry currently in the history store.
    fn history_entries(
        history: &Entity<HistoryStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(HistoryEntryId, String)> {
        history.read_with(cx, |history, _| {
            history
                .entries()
                .map(|e| (e.id(), e.title().to_string()))
                .collect::<Vec<_>>()
        })
    }

    // Common test setup: best-effort logger init, a test `SettingsStore` as a global, and the
    // test `LanguageModelRegistry`.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);

            LanguageModelRegistry::test(cx);
        });
    }
}