agent.rs

   1mod db;
   2mod edit_agent;
   3mod history_store;
   4mod legacy_thread;
   5mod native_agent_server;
   6pub mod outline;
   7mod templates;
   8mod thread;
   9mod tools;
  10
  11#[cfg(test)]
  12mod tests;
  13
  14pub use db::*;
  15pub use history_store::*;
  16pub use native_agent_server::NativeAgentServer;
  17pub use templates::*;
  18pub use thread::*;
  19pub use tools::*;
  20
  21use acp_thread::{AcpThread, AgentModelSelector};
  22use agent_client_protocol as acp;
  23use anyhow::{Context as _, Result, anyhow};
  24use chrono::{DateTime, Utc};
  25use collections::{HashSet, IndexMap};
  26use fs::Fs;
  27use futures::channel::{mpsc, oneshot};
  28use futures::future::Shared;
  29use futures::{StreamExt, future};
  30use gpui::{
  31    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  32};
  33use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  34use project::{Project, ProjectItem, ProjectPath, Worktree};
  35use prompt_store::{
  36    ProjectContext, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
  37};
  38use serde::{Deserialize, Serialize};
  39use settings::{LanguageModelSelection, update_settings_file};
  40use std::any::Any;
  41use std::collections::HashMap;
  42use std::path::{Path, PathBuf};
  43use std::rc::Rc;
  44use std::sync::Arc;
  45use util::ResultExt;
  46use util::rel_path::RelPath;
  47
  48#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
  49pub struct ProjectSnapshot {
  50    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
  51    pub timestamp: DateTime<Utc>,
  52}
  53
  54const RULES_FILE_NAMES: [&str; 9] = [
  55    ".rules",
  56    ".cursorrules",
  57    ".windsurfrules",
  58    ".clinerules",
  59    ".github/copilot-instructions.md",
  60    "CLAUDE.md",
  61    "AGENT.md",
  62    "AGENTS.md",
  63    "GEMINI.md",
  64];
  65
  66pub struct RulesLoadingError {
  67    pub message: SharedString,
  68}
  69
  70/// Holds both the internal Thread and the AcpThread for a session
  71struct Session {
  72    /// The internal thread that processes messages
  73    thread: Entity<Thread>,
  74    /// The ACP thread that handles protocol communication
  75    acp_thread: WeakEntity<acp_thread::AcpThread>,
  76    pending_save: Task<()>,
  77    _subscriptions: Vec<Subscription>,
  78}
  79
  80pub struct LanguageModels {
  81    /// Access language model by ID
  82    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
  83    /// Cached list for returning language model information
  84    model_list: acp_thread::AgentModelList,
  85    refresh_models_rx: watch::Receiver<()>,
  86    refresh_models_tx: watch::Sender<()>,
  87    _authenticate_all_providers_task: Task<()>,
  88}
  89
  90impl LanguageModels {
  91    fn new(cx: &mut App) -> Self {
  92        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  93
  94        let mut this = Self {
  95            models: HashMap::default(),
  96            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  97            refresh_models_rx,
  98            refresh_models_tx,
  99            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
 100        };
 101        this.refresh_list(cx);
 102        this
 103    }
 104
 105    fn refresh_list(&mut self, cx: &App) {
 106        let now = std::time::SystemTime::now()
 107            .duration_since(std::time::UNIX_EPOCH)
 108            .unwrap_or_default()
 109            .as_millis();
 110        eprintln!("[{}ms] LanguageModels::refresh_list called", now);
 111        let providers = LanguageModelRegistry::global(cx)
 112            .read(cx)
 113            .providers()
 114            .into_iter()
 115            .filter(|provider| provider.is_authenticated(cx))
 116            .collect::<Vec<_>>();
 117        eprintln!(
 118            "[{}ms] LanguageModels::refresh_list got {} authenticated providers",
 119            now,
 120            providers.len()
 121        );
 122
 123        let mut language_model_list = IndexMap::default();
 124        let mut recommended_models = HashSet::default();
 125
 126        let mut recommended = Vec::new();
 127        for provider in &providers {
 128            for model in provider.recommended_models(cx) {
 129                recommended_models.insert((model.provider_id(), model.id()));
 130                recommended.push(Self::map_language_model_to_info(&model, provider));
 131            }
 132        }
 133        if !recommended.is_empty() {
 134            language_model_list.insert(
 135                acp_thread::AgentModelGroupName("Recommended".into()),
 136                recommended,
 137            );
 138        }
 139
 140        let mut models = HashMap::default();
 141        for provider in providers {
 142            let mut provider_models = Vec::new();
 143            for model in provider.provided_models(cx) {
 144                let model_info = Self::map_language_model_to_info(&model, &provider);
 145                let model_id = model_info.id.clone();
 146                provider_models.push(model_info);
 147                models.insert(model_id, model);
 148            }
 149            if !provider_models.is_empty() {
 150                language_model_list.insert(
 151                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 152                    provider_models,
 153                );
 154            }
 155        }
 156
 157        self.models = models;
 158        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 159        let now = std::time::SystemTime::now()
 160            .duration_since(std::time::UNIX_EPOCH)
 161            .unwrap_or_default()
 162            .as_millis();
 163        eprintln!(
 164            "[{}ms] LanguageModels::refresh_list completed with {} models in list",
 165            now,
 166            self.models.len()
 167        );
 168        self.refresh_models_tx.send(()).ok();
 169    }
 170
 171    fn watch(&self) -> watch::Receiver<()> {
 172        self.refresh_models_rx.clone()
 173    }
 174
 175    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 176        self.models.get(model_id).cloned()
 177    }
 178
 179    fn map_language_model_to_info(
 180        model: &Arc<dyn LanguageModel>,
 181        provider: &Arc<dyn LanguageModelProvider>,
 182    ) -> acp_thread::AgentModelInfo {
 183        acp_thread::AgentModelInfo {
 184            id: Self::model_id(model),
 185            name: model.name().0,
 186            description: None,
 187            icon: Some(provider.icon()),
 188        }
 189    }
 190
 191    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 192        acp::ModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
 193    }
 194
 195    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 196        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 197            .read(cx)
 198            .providers()
 199            .iter()
 200            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 201            .collect::<Vec<_>>();
 202
 203        cx.background_spawn(async move {
 204            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 205                if let Err(err) = authenticate_task.await {
 206                    match err {
 207                        language_model::AuthenticateError::CredentialsNotFound => {
 208                            // Since we're authenticating these providers in the
 209                            // background for the purposes of populating the
 210                            // language selector, we don't care about providers
 211                            // where the credentials are not found.
 212                        }
 213                        language_model::AuthenticateError::ConnectionRefused => {
 214                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 215                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 216                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 217                        }
 218                        _ => {
 219                            // Some providers have noisy failure states that we
 220                            // don't want to spam the logs with every time the
 221                            // language model selector is initialized.
 222                            //
 223                            // Ideally these should have more clear failure modes
 224                            // that we know are safe to ignore here, like what we do
 225                            // with `CredentialsNotFound` above.
 226                            match provider_id.0.as_ref() {
 227                                "lmstudio" | "ollama" => {
 228                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 229                                    //
 230                                    // These fail noisily, so we don't log them.
 231                                }
 232                                "copilot_chat" => {
 233                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 234                                }
 235                                _ => {
 236                                    log::error!(
 237                                        "Failed to authenticate provider: {}: {err:#}",
 238                                        provider_name.0
 239                                    );
 240                                }
 241                            }
 242                        }
 243                    }
 244                }
 245            }
 246        })
 247    }
 248}
 249
 250pub struct NativeAgent {
 251    /// Session ID -> Session mapping
 252    sessions: HashMap<acp::SessionId, Session>,
 253    history: Entity<HistoryStore>,
 254    /// Shared project context for all threads
 255    project_context: Entity<ProjectContext>,
 256    project_context_needs_refresh: watch::Sender<()>,
 257    _maintain_project_context: Task<Result<()>>,
 258    context_server_registry: Entity<ContextServerRegistry>,
 259    /// Shared templates for all threads
 260    templates: Arc<Templates>,
 261    /// Cached model information
 262    models: LanguageModels,
 263    project: Entity<Project>,
 264    prompt_store: Option<Entity<PromptStore>>,
 265    fs: Arc<dyn Fs>,
 266    _subscriptions: Vec<Subscription>,
 267}
 268
 269impl NativeAgent {
 270    pub async fn new(
 271        project: Entity<Project>,
 272        history: Entity<HistoryStore>,
 273        templates: Arc<Templates>,
 274        prompt_store: Option<Entity<PromptStore>>,
 275        fs: Arc<dyn Fs>,
 276        cx: &mut AsyncApp,
 277    ) -> Result<Entity<NativeAgent>> {
 278        log::debug!("Creating new NativeAgent");
 279
 280        let project_context = cx
 281            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
 282            .await;
 283
 284        cx.new(|cx| {
 285            let mut subscriptions = vec![
 286                cx.subscribe(&project, Self::handle_project_event),
 287                cx.subscribe(
 288                    &LanguageModelRegistry::global(cx),
 289                    Self::handle_models_updated_event,
 290                ),
 291            ];
 292            if let Some(prompt_store) = prompt_store.as_ref() {
 293                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 294            }
 295
 296            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 297                watch::channel(());
 298            Self {
 299                sessions: HashMap::new(),
 300                history,
 301                project_context: cx.new(|_| project_context),
 302                project_context_needs_refresh: project_context_needs_refresh_tx,
 303                _maintain_project_context: cx.spawn(async move |this, cx| {
 304                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 305                }),
 306                context_server_registry: cx.new(|cx| {
 307                    ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
 308                }),
 309                templates,
 310                models: LanguageModels::new(cx),
 311                project,
 312                prompt_store,
 313                fs,
 314                _subscriptions: subscriptions,
 315            }
 316        })
 317    }
 318
 319    fn register_session(
 320        &mut self,
 321        thread_handle: Entity<Thread>,
 322        cx: &mut Context<Self>,
 323    ) -> Entity<AcpThread> {
 324        let connection = Rc::new(NativeAgentConnection(cx.entity()));
 325
 326        let thread = thread_handle.read(cx);
 327        let session_id = thread.id().clone();
 328        let title = thread.title();
 329        let project = thread.project.clone();
 330        let action_log = thread.action_log.clone();
 331        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
 332        let acp_thread = cx.new(|cx| {
 333            acp_thread::AcpThread::new(
 334                title,
 335                connection,
 336                project.clone(),
 337                action_log.clone(),
 338                session_id.clone(),
 339                prompt_capabilities_rx,
 340                cx,
 341            )
 342        });
 343
 344        let registry = LanguageModelRegistry::read_global(cx);
 345        let summarization_model = registry.thread_summary_model().map(|c| c.model);
 346
 347        thread_handle.update(cx, |thread, cx| {
 348            thread.set_summarization_model(summarization_model, cx);
 349            thread.add_default_tools(
 350                Rc::new(AcpThreadEnvironment {
 351                    acp_thread: acp_thread.downgrade(),
 352                }) as _,
 353                cx,
 354            )
 355        });
 356
 357        let subscriptions = vec![
 358            cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
 359                this.sessions.remove(acp_thread.session_id());
 360            }),
 361            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
 362            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
 363            cx.observe(&thread_handle, move |this, thread, cx| {
 364                this.save_thread(thread, cx)
 365            }),
 366        ];
 367
 368        self.sessions.insert(
 369            session_id,
 370            Session {
 371                thread: thread_handle,
 372                acp_thread: acp_thread.downgrade(),
 373                _subscriptions: subscriptions,
 374                pending_save: Task::ready(()),
 375            },
 376        );
 377        acp_thread
 378    }
 379
 380    pub fn models(&self) -> &LanguageModels {
 381        &self.models
 382    }
 383
 384    async fn maintain_project_context(
 385        this: WeakEntity<Self>,
 386        mut needs_refresh: watch::Receiver<()>,
 387        cx: &mut AsyncApp,
 388    ) -> Result<()> {
 389        while needs_refresh.changed().await.is_ok() {
 390            let project_context = this
 391                .update(cx, |this, cx| {
 392                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 393                })?
 394                .await;
 395            this.update(cx, |this, cx| {
 396                this.project_context = cx.new(|_| project_context);
 397            })?;
 398        }
 399
 400        Ok(())
 401    }
 402
 403    fn build_project_context(
 404        project: &Entity<Project>,
 405        prompt_store: Option<&Entity<PromptStore>>,
 406        cx: &mut App,
 407    ) -> Task<ProjectContext> {
 408        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 409        let worktree_tasks = worktrees
 410            .into_iter()
 411            .map(|worktree| {
 412                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 413            })
 414            .collect::<Vec<_>>();
 415        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 416            prompt_store.read_with(cx, |prompt_store, cx| {
 417                let prompts = prompt_store.default_prompt_metadata();
 418                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 419                    let contents = prompt_store.load(prompt_metadata.id, cx);
 420                    async move { (contents.await, prompt_metadata) }
 421                });
 422                cx.background_spawn(future::join_all(load_tasks))
 423            })
 424        } else {
 425            Task::ready(vec![])
 426        };
 427
 428        cx.spawn(async move |_cx| {
 429            let (worktrees, default_user_rules) =
 430                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 431
 432            let worktrees = worktrees
 433                .into_iter()
 434                .map(|(worktree, _rules_error)| {
 435                    // TODO: show error message
 436                    // if let Some(rules_error) = rules_error {
 437                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 438                    // }
 439                    worktree
 440                })
 441                .collect::<Vec<_>>();
 442
 443            let default_user_rules = default_user_rules
 444                .into_iter()
 445                .flat_map(|(contents, prompt_metadata)| match contents {
 446                    Ok(contents) => Some(UserRulesContext {
 447                        uuid: match prompt_metadata.id {
 448                            prompt_store::PromptId::User { uuid } => uuid,
 449                            prompt_store::PromptId::EditWorkflow => return None,
 450                        },
 451                        title: prompt_metadata.title.map(|title| title.to_string()),
 452                        contents,
 453                    }),
 454                    Err(_err) => {
 455                        // TODO: show error message
 456                        // this.update(cx, |_, cx| {
 457                        //     cx.emit(RulesLoadingError {
 458                        //         message: format!("{err:?}").into(),
 459                        //     });
 460                        // })
 461                        // .ok();
 462                        None
 463                    }
 464                })
 465                .collect::<Vec<_>>();
 466
 467            ProjectContext::new(worktrees, default_user_rules)
 468        })
 469    }
 470
 471    fn load_worktree_info_for_system_prompt(
 472        worktree: Entity<Worktree>,
 473        project: Entity<Project>,
 474        cx: &mut App,
 475    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 476        let tree = worktree.read(cx);
 477        let root_name = tree.root_name_str().into();
 478        let abs_path = tree.abs_path();
 479
 480        let mut context = WorktreeContext {
 481            root_name,
 482            abs_path,
 483            rules_file: None,
 484        };
 485
 486        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 487        let Some(rules_task) = rules_task else {
 488            return Task::ready((context, None));
 489        };
 490
 491        cx.spawn(async move |_| {
 492            let (rules_file, rules_file_error) = match rules_task.await {
 493                Ok(rules_file) => (Some(rules_file), None),
 494                Err(err) => (
 495                    None,
 496                    Some(RulesLoadingError {
 497                        message: format!("{err}").into(),
 498                    }),
 499                ),
 500            };
 501            context.rules_file = rules_file;
 502            (context, rules_file_error)
 503        })
 504    }
 505
 506    fn load_worktree_rules_file(
 507        worktree: Entity<Worktree>,
 508        project: Entity<Project>,
 509        cx: &mut App,
 510    ) -> Option<Task<Result<RulesFileContext>>> {
 511        let worktree = worktree.read(cx);
 512        let worktree_id = worktree.id();
 513        let selected_rules_file = RULES_FILE_NAMES
 514            .into_iter()
 515            .filter_map(|name| {
 516                worktree
 517                    .entry_for_path(RelPath::unix(name).unwrap())
 518                    .filter(|entry| entry.is_file())
 519                    .map(|entry| entry.path.clone())
 520            })
 521            .next();
 522
 523        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 524        // supported. This doesn't seem to occur often in GitHub repositories.
 525        selected_rules_file.map(|path_in_worktree| {
 526            let project_path = ProjectPath {
 527                worktree_id,
 528                path: path_in_worktree.clone(),
 529            };
 530            let buffer_task =
 531                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 532            let rope_task = cx.spawn(async move |cx| {
 533                buffer_task.await?.read_with(cx, |buffer, cx| {
 534                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 535                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 536                })?
 537            });
 538            // Build a string from the rope on a background thread.
 539            cx.background_spawn(async move {
 540                let (project_entry_id, rope) = rope_task.await?;
 541                anyhow::Ok(RulesFileContext {
 542                    path_in_worktree,
 543                    text: rope.to_string().trim().to_string(),
 544                    project_entry_id: project_entry_id.to_usize(),
 545                })
 546            })
 547        })
 548    }
 549
 550    fn handle_thread_title_updated(
 551        &mut self,
 552        thread: Entity<Thread>,
 553        _: &TitleUpdated,
 554        cx: &mut Context<Self>,
 555    ) {
 556        let session_id = thread.read(cx).id();
 557        let Some(session) = self.sessions.get(session_id) else {
 558            return;
 559        };
 560        let thread = thread.downgrade();
 561        let acp_thread = session.acp_thread.clone();
 562        cx.spawn(async move |_, cx| {
 563            let title = thread.read_with(cx, |thread, _| thread.title())?;
 564            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 565            task.await
 566        })
 567        .detach_and_log_err(cx);
 568    }
 569
 570    fn handle_thread_token_usage_updated(
 571        &mut self,
 572        thread: Entity<Thread>,
 573        usage: &TokenUsageUpdated,
 574        cx: &mut Context<Self>,
 575    ) {
 576        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 577            return;
 578        };
 579        session
 580            .acp_thread
 581            .update(cx, |acp_thread, cx| {
 582                acp_thread.update_token_usage(usage.0.clone(), cx);
 583            })
 584            .ok();
 585    }
 586
 587    fn handle_project_event(
 588        &mut self,
 589        _project: Entity<Project>,
 590        event: &project::Event,
 591        _cx: &mut Context<Self>,
 592    ) {
 593        match event {
 594            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 595                self.project_context_needs_refresh.send(()).ok();
 596            }
 597            project::Event::WorktreeUpdatedEntries(_, items) => {
 598                if items.iter().any(|(path, _, _)| {
 599                    RULES_FILE_NAMES
 600                        .iter()
 601                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 602                }) {
 603                    self.project_context_needs_refresh.send(()).ok();
 604                }
 605            }
 606            _ => {}
 607        }
 608    }
 609
 610    fn handle_prompts_updated_event(
 611        &mut self,
 612        _prompt_store: Entity<PromptStore>,
 613        _event: &prompt_store::PromptsUpdatedEvent,
 614        _cx: &mut Context<Self>,
 615    ) {
 616        self.project_context_needs_refresh.send(()).ok();
 617    }
 618
 619    fn handle_models_updated_event(
 620        &mut self,
 621        _registry: Entity<LanguageModelRegistry>,
 622        _event: &language_model::Event,
 623        cx: &mut Context<Self>,
 624    ) {
 625        let now = std::time::SystemTime::now()
 626            .duration_since(std::time::UNIX_EPOCH)
 627            .unwrap_or_default()
 628            .as_millis();
 629        eprintln!(
 630            "[{}ms] NativeAgent::handle_models_updated_event called",
 631            now
 632        );
 633        self.models.refresh_list(cx);
 634
 635        let registry = LanguageModelRegistry::read_global(cx);
 636        let default_model = registry.default_model().map(|m| m.model);
 637        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 638
 639        for session in self.sessions.values_mut() {
 640            session.thread.update(cx, |thread, cx| {
 641                if thread.model().is_none()
 642                    && let Some(model) = default_model.clone()
 643                {
 644                    thread.set_model(model, cx);
 645                    cx.notify();
 646                }
 647                thread.set_summarization_model(summarization_model.clone(), cx);
 648            });
 649        }
 650    }
 651
 652    pub fn load_thread(
 653        &mut self,
 654        id: acp::SessionId,
 655        cx: &mut Context<Self>,
 656    ) -> Task<Result<Entity<Thread>>> {
 657        let database_future = ThreadsDatabase::connect(cx);
 658        cx.spawn(async move |this, cx| {
 659            let database = database_future.await.map_err(|err| anyhow!(err))?;
 660            let db_thread = database
 661                .load_thread(id.clone())
 662                .await?
 663                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 664
 665            this.update(cx, |this, cx| {
 666                let summarization_model = LanguageModelRegistry::read_global(cx)
 667                    .thread_summary_model()
 668                    .map(|c| c.model);
 669
 670                cx.new(|cx| {
 671                    let mut thread = Thread::from_db(
 672                        id.clone(),
 673                        db_thread,
 674                        this.project.clone(),
 675                        this.project_context.clone(),
 676                        this.context_server_registry.clone(),
 677                        this.templates.clone(),
 678                        cx,
 679                    );
 680                    thread.set_summarization_model(summarization_model, cx);
 681                    thread
 682                })
 683            })
 684        })
 685    }
 686
 687    pub fn open_thread(
 688        &mut self,
 689        id: acp::SessionId,
 690        cx: &mut Context<Self>,
 691    ) -> Task<Result<Entity<AcpThread>>> {
 692        let task = self.load_thread(id, cx);
 693        cx.spawn(async move |this, cx| {
 694            let thread = task.await?;
 695            let acp_thread =
 696                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
 697            let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
 698            cx.update(|cx| {
 699                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 700            })?
 701            .await?;
 702            Ok(acp_thread)
 703        })
 704    }
 705
 706    pub fn thread_summary(
 707        &mut self,
 708        id: acp::SessionId,
 709        cx: &mut Context<Self>,
 710    ) -> Task<Result<SharedString>> {
 711        let thread = self.open_thread(id.clone(), cx);
 712        cx.spawn(async move |this, cx| {
 713            let acp_thread = thread.await?;
 714            let result = this
 715                .update(cx, |this, cx| {
 716                    this.sessions
 717                        .get(&id)
 718                        .unwrap()
 719                        .thread
 720                        .update(cx, |thread, cx| thread.summary(cx))
 721                })?
 722                .await
 723                .context("Failed to generate summary")?;
 724            drop(acp_thread);
 725            Ok(result)
 726        })
 727    }
 728
 729    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 730        if thread.read(cx).is_empty() {
 731            return;
 732        }
 733
 734        let database_future = ThreadsDatabase::connect(cx);
 735        let (id, db_thread) =
 736            thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
 737        let Some(session) = self.sessions.get_mut(&id) else {
 738            return;
 739        };
 740        let history = self.history.clone();
 741        session.pending_save = cx.spawn(async move |_, cx| {
 742            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 743                return;
 744            };
 745            let db_thread = db_thread.await;
 746            database.save_thread(id, db_thread).await.log_err();
 747            history.update(cx, |history, cx| history.reload(cx)).ok();
 748        });
 749    }
 750}
 751
 752/// Wrapper struct that implements the AgentConnection trait
 753#[derive(Clone)]
 754pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 755
 756impl NativeAgentConnection {
 757    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 758        self.0
 759            .read(cx)
 760            .sessions
 761            .get(session_id)
 762            .map(|session| session.thread.clone())
 763    }
 764
 765    pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
 766        self.0.update(cx, |this, cx| this.load_thread(id, cx))
 767    }
 768
 769    fn run_turn(
 770        &self,
 771        session_id: acp::SessionId,
 772        cx: &mut App,
 773        f: impl 'static
 774        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
 775    ) -> Task<Result<acp::PromptResponse>> {
 776        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 777            agent
 778                .sessions
 779                .get_mut(&session_id)
 780                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
 781        }) else {
 782            return Task::ready(Err(anyhow!("Session not found")));
 783        };
 784        log::debug!("Found session for: {}", session_id);
 785
 786        let response_stream = match f(thread, cx) {
 787            Ok(stream) => stream,
 788            Err(err) => return Task::ready(Err(err)),
 789        };
 790        Self::handle_thread_events(response_stream, acp_thread, cx)
 791    }
 792
 793    fn handle_thread_events(
 794        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
 795        acp_thread: WeakEntity<AcpThread>,
 796        cx: &App,
 797    ) -> Task<Result<acp::PromptResponse>> {
 798        cx.spawn(async move |cx| {
 799            // Handle response stream and forward to session.acp_thread
 800            while let Some(result) = events.next().await {
 801                match result {
 802                    Ok(event) => {
 803                        log::trace!("Received completion event: {:?}", event);
 804
 805                        match event {
 806                            ThreadEvent::UserMessage(message) => {
 807                                acp_thread.update(cx, |thread, cx| {
 808                                    for content in message.content {
 809                                        thread.push_user_content_block(
 810                                            Some(message.id.clone()),
 811                                            content.into(),
 812                                            cx,
 813                                        );
 814                                    }
 815                                })?;
 816                            }
 817                            ThreadEvent::AgentText(text) => {
 818                                acp_thread.update(cx, |thread, cx| {
 819                                    thread.push_assistant_content_block(
 820                                        acp::ContentBlock::Text(acp::TextContent {
 821                                            text,
 822                                            annotations: None,
 823                                            meta: None,
 824                                        }),
 825                                        false,
 826                                        cx,
 827                                    )
 828                                })?;
 829                            }
 830                            ThreadEvent::AgentThinking(text) => {
 831                                acp_thread.update(cx, |thread, cx| {
 832                                    thread.push_assistant_content_block(
 833                                        acp::ContentBlock::Text(acp::TextContent {
 834                                            text,
 835                                            annotations: None,
 836                                            meta: None,
 837                                        }),
 838                                        true,
 839                                        cx,
 840                                    )
 841                                })?;
 842                            }
 843                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
 844                                tool_call,
 845                                options,
 846                                response,
 847                            }) => {
 848                                let outcome_task = acp_thread.update(cx, |thread, cx| {
 849                                    thread.request_tool_call_authorization(
 850                                        tool_call, options, true, cx,
 851                                    )
 852                                })??;
 853                                cx.background_spawn(async move {
 854                                    if let acp::RequestPermissionOutcome::Selected { option_id } =
 855                                        outcome_task.await
 856                                    {
 857                                        response
 858                                            .send(option_id)
 859                                            .map(|_| anyhow!("authorization receiver was dropped"))
 860                                            .log_err();
 861                                    }
 862                                })
 863                                .detach();
 864                            }
 865                            ThreadEvent::ToolCall(tool_call) => {
 866                                acp_thread.update(cx, |thread, cx| {
 867                                    thread.upsert_tool_call(tool_call, cx)
 868                                })??;
 869                            }
 870                            ThreadEvent::ToolCallUpdate(update) => {
 871                                acp_thread.update(cx, |thread, cx| {
 872                                    thread.update_tool_call(update, cx)
 873                                })??;
 874                            }
 875                            ThreadEvent::Retry(status) => {
 876                                acp_thread.update(cx, |thread, cx| {
 877                                    thread.update_retry_status(status, cx)
 878                                })?;
 879                            }
 880                            ThreadEvent::Stop(stop_reason) => {
 881                                log::debug!("Assistant message complete: {:?}", stop_reason);
 882                                return Ok(acp::PromptResponse {
 883                                    stop_reason,
 884                                    meta: None,
 885                                });
 886                            }
 887                        }
 888                    }
 889                    Err(e) => {
 890                        log::error!("Error in model response stream: {:?}", e);
 891                        return Err(e);
 892                    }
 893                }
 894            }
 895
 896            log::debug!("Response stream completed");
 897            anyhow::Ok(acp::PromptResponse {
 898                stop_reason: acp::StopReason::EndTurn,
 899                meta: None,
 900            })
 901        })
 902    }
 903}
 904
 905struct NativeAgentModelSelector {
 906    session_id: acp::SessionId,
 907    connection: NativeAgentConnection,
 908}
 909
 910impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
 911    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
 912        log::debug!("NativeAgentConnection::list_models called");
 913        let list = self.connection.0.read(cx).models.model_list.clone();
 914        Task::ready(if list.is_empty() {
 915            Err(anyhow::anyhow!("No models available"))
 916        } else {
 917            Ok(list)
 918        })
 919    }
 920
 921    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
 922        log::debug!(
 923            "Setting model for session {}: {}",
 924            self.session_id,
 925            model_id
 926        );
 927        let Some(thread) = self
 928            .connection
 929            .0
 930            .read(cx)
 931            .sessions
 932            .get(&self.session_id)
 933            .map(|session| session.thread.clone())
 934        else {
 935            return Task::ready(Err(anyhow!("Session not found")));
 936        };
 937
 938        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
 939            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
 940        };
 941
 942        thread.update(cx, |thread, cx| {
 943            thread.set_model(model.clone(), cx);
 944        });
 945
 946        update_settings_file(
 947            self.connection.0.read(cx).fs.clone(),
 948            cx,
 949            move |settings, _cx| {
 950                let provider = model.provider_id().0.to_string();
 951                let model = model.id().0.to_string();
 952                settings
 953                    .agent
 954                    .get_or_insert_default()
 955                    .set_model(LanguageModelSelection {
 956                        provider: provider.into(),
 957                        model,
 958                    });
 959            },
 960        );
 961
 962        Task::ready(Ok(()))
 963    }
 964
 965    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
 966        let Some(thread) = self
 967            .connection
 968            .0
 969            .read(cx)
 970            .sessions
 971            .get(&self.session_id)
 972            .map(|session| session.thread.clone())
 973        else {
 974            return Task::ready(Err(anyhow!("Session not found")));
 975        };
 976        let Some(model) = thread.read(cx).model() else {
 977            return Task::ready(Err(anyhow!("Model not found")));
 978        };
 979        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
 980        else {
 981            return Task::ready(Err(anyhow!("Provider not found")));
 982        };
 983        Task::ready(Ok(LanguageModels::map_language_model_to_info(
 984            model, &provider,
 985        )))
 986    }
 987
 988    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
 989        Some(self.connection.0.read(cx).models.watch())
 990    }
 991
 992    fn should_render_footer(&self) -> bool {
 993        true
 994    }
 995}
 996
 997impl acp_thread::AgentConnection for NativeAgentConnection {
 998    fn telemetry_id(&self) -> &'static str {
 999        "zed"
1000    }
1001
1002    fn new_thread(
1003        self: Rc<Self>,
1004        project: Entity<Project>,
1005        cwd: &Path,
1006        cx: &mut App,
1007    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1008        let agent = self.0.clone();
1009        log::debug!("Creating new thread for project at: {:?}", cwd);
1010
1011        cx.spawn(async move |cx| {
1012            log::debug!("Starting thread creation in async context");
1013
1014            // Create Thread
1015            let thread = agent.update(
1016                cx,
1017                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
1018                    // Fetch default model from registry settings
1019                    let registry = LanguageModelRegistry::read_global(cx);
1020                    // Log available models for debugging
1021                    let available_count = registry.available_models(cx).count();
1022                    log::debug!("Total available models: {}", available_count);
1023
1024                    let default_model = registry.default_model().and_then(|default_model| {
1025                        agent
1026                            .models
1027                            .model_from_id(&LanguageModels::model_id(&default_model.model))
1028                    });
1029                    Ok(cx.new(|cx| {
1030                        Thread::new(
1031                            project.clone(),
1032                            agent.project_context.clone(),
1033                            agent.context_server_registry.clone(),
1034                            agent.templates.clone(),
1035                            default_model,
1036                            cx,
1037                        )
1038                    }))
1039                },
1040            )??;
1041            agent.update(cx, |agent, cx| agent.register_session(thread, cx))
1042        })
1043    }
1044
1045    fn auth_methods(&self) -> &[acp::AuthMethod] {
1046        &[] // No auth for in-process
1047    }
1048
1049    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1050        Task::ready(Ok(()))
1051    }
1052
1053    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1054        Some(Rc::new(NativeAgentModelSelector {
1055            session_id: session_id.clone(),
1056            connection: self.clone(),
1057        }) as Rc<dyn AgentModelSelector>)
1058    }
1059
1060    fn prompt(
1061        &self,
1062        id: Option<acp_thread::UserMessageId>,
1063        params: acp::PromptRequest,
1064        cx: &mut App,
1065    ) -> Task<Result<acp::PromptResponse>> {
1066        let id = id.expect("UserMessageId is required");
1067        let session_id = params.session_id.clone();
1068        log::info!("Received prompt request for session: {}", session_id);
1069        log::debug!("Prompt blocks count: {}", params.prompt.len());
1070        let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1071
1072        self.run_turn(session_id, cx, move |thread, cx| {
1073            let content: Vec<UserMessageContent> = params
1074                .prompt
1075                .into_iter()
1076                .map(|block| UserMessageContent::from_content_block(block, path_style))
1077                .collect::<Vec<_>>();
1078            log::debug!("Converted prompt to message: {} chars", content.len());
1079            log::debug!("Message id: {:?}", id);
1080            log::debug!("Message content: {:?}", content);
1081
1082            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1083        })
1084    }
1085
1086    fn resume(
1087        &self,
1088        session_id: &acp::SessionId,
1089        _cx: &App,
1090    ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1091        Some(Rc::new(NativeAgentSessionResume {
1092            connection: self.clone(),
1093            session_id: session_id.clone(),
1094        }) as _)
1095    }
1096
1097    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1098        log::info!("Cancelling on session: {}", session_id);
1099        self.0.update(cx, |agent, cx| {
1100            if let Some(agent) = agent.sessions.get(session_id) {
1101                agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1102            }
1103        });
1104    }
1105
1106    fn truncate(
1107        &self,
1108        session_id: &agent_client_protocol::SessionId,
1109        cx: &App,
1110    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1111        self.0.read_with(cx, |agent, _cx| {
1112            agent.sessions.get(session_id).map(|session| {
1113                Rc::new(NativeAgentSessionTruncate {
1114                    thread: session.thread.clone(),
1115                    acp_thread: session.acp_thread.clone(),
1116                }) as _
1117            })
1118        })
1119    }
1120
1121    fn set_title(
1122        &self,
1123        session_id: &acp::SessionId,
1124        _cx: &App,
1125    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1126        Some(Rc::new(NativeAgentSessionSetTitle {
1127            connection: self.clone(),
1128            session_id: session_id.clone(),
1129        }) as _)
1130    }
1131
1132    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1133        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1134    }
1135
1136    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1137        self
1138    }
1139}
1140
1141impl acp_thread::AgentTelemetry for NativeAgentConnection {
1142    fn thread_data(
1143        &self,
1144        session_id: &acp::SessionId,
1145        cx: &mut App,
1146    ) -> Task<Result<serde_json::Value>> {
1147        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1148            return Task::ready(Err(anyhow!("Session not found")));
1149        };
1150
1151        let task = session.thread.read(cx).to_db(cx);
1152        cx.background_spawn(async move {
1153            serde_json::to_value(task.await).context("Failed to serialize thread")
1154        })
1155    }
1156}
1157
1158struct NativeAgentSessionTruncate {
1159    thread: Entity<Thread>,
1160    acp_thread: WeakEntity<AcpThread>,
1161}
1162
1163impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1164    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1165        match self.thread.update(cx, |thread, cx| {
1166            thread.truncate(message_id.clone(), cx)?;
1167            Ok(thread.latest_token_usage())
1168        }) {
1169            Ok(usage) => {
1170                self.acp_thread
1171                    .update(cx, |thread, cx| {
1172                        thread.update_token_usage(usage, cx);
1173                    })
1174                    .ok();
1175                Task::ready(Ok(()))
1176            }
1177            Err(error) => Task::ready(Err(error)),
1178        }
1179    }
1180}
1181
1182struct NativeAgentSessionResume {
1183    connection: NativeAgentConnection,
1184    session_id: acp::SessionId,
1185}
1186
1187impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1188    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1189        self.connection
1190            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1191                thread.update(cx, |thread, cx| thread.resume(cx))
1192            })
1193    }
1194}
1195
1196struct NativeAgentSessionSetTitle {
1197    connection: NativeAgentConnection,
1198    session_id: acp::SessionId,
1199}
1200
1201impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1202    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1203        let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1204            return Task::ready(Err(anyhow!("session not found")));
1205        };
1206        let thread = session.thread.clone();
1207        thread.update(cx, |thread, cx| thread.set_title(title, cx));
1208        Task::ready(Ok(()))
1209    }
1210}
1211
1212pub struct AcpThreadEnvironment {
1213    acp_thread: WeakEntity<AcpThread>,
1214}
1215
1216impl ThreadEnvironment for AcpThreadEnvironment {
1217    fn create_terminal(
1218        &self,
1219        command: String,
1220        cwd: Option<PathBuf>,
1221        output_byte_limit: Option<u64>,
1222        cx: &mut AsyncApp,
1223    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1224        let task = self.acp_thread.update(cx, |thread, cx| {
1225            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1226        });
1227
1228        let acp_thread = self.acp_thread.clone();
1229        cx.spawn(async move |cx| {
1230            let terminal = task?.await?;
1231
1232            let (drop_tx, drop_rx) = oneshot::channel();
1233            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?;
1234
1235            cx.spawn(async move |cx| {
1236                drop_rx.await.ok();
1237                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1238            })
1239            .detach();
1240
1241            let handle = AcpTerminalHandle {
1242                terminal,
1243                _drop_tx: Some(drop_tx),
1244            };
1245
1246            Ok(Rc::new(handle) as _)
1247        })
1248    }
1249}
1250
1251pub struct AcpTerminalHandle {
1252    terminal: Entity<acp_thread::Terminal>,
1253    _drop_tx: Option<oneshot::Sender<()>>,
1254}
1255
1256impl TerminalHandle for AcpTerminalHandle {
1257    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1258        self.terminal.read_with(cx, |term, _cx| term.id().clone())
1259    }
1260
1261    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1262        self.terminal
1263            .read_with(cx, |term, _cx| term.wait_for_exit())
1264    }
1265
1266    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1267        self.terminal
1268            .read_with(cx, |term, cx| term.current_output(cx))
1269    }
1270}
1271
// Tests that reach into agent internals (e.g. `sessions`, `project_context`,
// `models`) and drive the agent through `NativeAgentConnection`.
#[cfg(test)]
mod internal_tests {
    use crate::HistoryEntryId;

    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    // The agent's project context has three checkpoints:
    // - No worktrees at start.
    // - An entry for a worktree once it is added.
    // - That entry's `rules_file` filled in once `/a/.rules` is created.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        // Let the agent observe the new worktree before checking its context.
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            // The expected context references the worktree entry id of the new file.
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    // The model selector of a fresh session lists the fake provider's single model,
    // grouped under the provider name.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                history_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId("fake/fake".into()),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    // Selecting a model switches the session's thread to it and overwrites the
    // `agent.default_model` entry already present in the settings file.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Seed the settings file with a different default (foo/bar) so the test can
        // observe it being replaced.
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId("fake/fake".into());
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Presumably lets the settings-file update finish before the file is read
        // back (confirm: `update_settings_file` may write asynchronously).
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    // Round-trips a thread through the history store:
    // - An empty thread produces no history entry, even after its models change.
    // - After one exchange, dropping the session leaves a history entry titled by
    //   the summary model.
    // - Reopening that entry restores the same conversation.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(history_entries(&history_store, cx), vec![]);

        // Send a prompt that mixes plain text with a file mention.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
                        name: "b.md".into(),
                        uri: MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                        annotations: None,
                        description: None,
                        mime_type: None,
                        size: None,
                        title: None,
                        meta: None,
                    }),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // The main model answers; then the summary model produces the text that
        // later shows up as the history entry's title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            history_entries(&history_store, cx),
            vec![(
                HistoryEntryId::AcpThread(session_id.clone()),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    // Returns `(id, title)` for every entry currently in `history`.
    fn history_entries(
        history: &Entity<HistoryStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(HistoryEntryId, String)> {
        history.read_with(cx, |history, _| {
            history
                .entries()
                .map(|e| (e.id(), e.title().to_string()))
                .collect::<Vec<_>>()
        })
    }

    // Shared setup: installs a test settings store and the test language-model
    // registry. Logger init is best-effort because every test calls it and only the
    // first call can succeed.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);

            LanguageModelRegistry::test(cx);
        });
    }
}