agent.rs

   1use crate::{
   2    AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
   3    DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
   4    MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
   5    ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
   6};
   7use acp_thread::AgentModelSelector;
   8use agent_client_protocol as acp;
   9use agent_settings::AgentSettings;
  10use anyhow::{Context as _, Result, anyhow};
  11use collections::{HashSet, IndexMap};
  12use fs::Fs;
  13use futures::channel::mpsc;
  14use futures::{StreamExt, future};
  15use gpui::{
  16    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  17};
  18use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  19use project::{Project, ProjectItem, ProjectPath, Worktree};
  20use prompt_store::{
  21    ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
  22};
  23use settings::update_settings_file;
  24use std::any::Any;
  25use std::collections::HashMap;
  26use std::path::Path;
  27use std::rc::Rc;
  28use std::sync::Arc;
  29use util::ResultExt;
  30
  31const RULES_FILE_NAMES: [&'static str; 9] = [
  32    ".rules",
  33    ".cursorrules",
  34    ".windsurfrules",
  35    ".clinerules",
  36    ".github/copilot-instructions.md",
  37    "CLAUDE.md",
  38    "AGENT.md",
  39    "AGENTS.md",
  40    "GEMINI.md",
  41];
  42
  43pub struct RulesLoadingError {
  44    pub message: SharedString,
  45}
  46
  47/// Holds both the internal Thread and the AcpThread for a session
  48struct Session {
  49    /// The internal thread that processes messages
  50    thread: Entity<Thread>,
  51    /// The ACP thread that handles protocol communication
  52    acp_thread: WeakEntity<acp_thread::AcpThread>,
  53    _subscription: Subscription,
  54}
  55
  56pub struct LanguageModels {
  57    /// Access language model by ID
  58    models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
  59    /// Cached list for returning language model information
  60    model_list: acp_thread::AgentModelList,
  61    refresh_models_rx: watch::Receiver<()>,
  62    refresh_models_tx: watch::Sender<()>,
  63}
  64
  65impl LanguageModels {
  66    fn new(cx: &App) -> Self {
  67        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  68        let mut this = Self {
  69            models: HashMap::default(),
  70            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  71            refresh_models_rx,
  72            refresh_models_tx,
  73        };
  74        this.refresh_list(cx);
  75        this
  76    }
  77
  78    fn refresh_list(&mut self, cx: &App) {
  79        let providers = LanguageModelRegistry::global(cx)
  80            .read(cx)
  81            .providers()
  82            .into_iter()
  83            .filter(|provider| provider.is_authenticated(cx))
  84            .collect::<Vec<_>>();
  85
  86        let mut language_model_list = IndexMap::default();
  87        let mut recommended_models = HashSet::default();
  88
  89        let mut recommended = Vec::new();
  90        for provider in &providers {
  91            for model in provider.recommended_models(cx) {
  92                recommended_models.insert(model.id());
  93                recommended.push(Self::map_language_model_to_info(&model, provider));
  94            }
  95        }
  96        if !recommended.is_empty() {
  97            language_model_list.insert(
  98                acp_thread::AgentModelGroupName("Recommended".into()),
  99                recommended,
 100            );
 101        }
 102
 103        let mut models = HashMap::default();
 104        for provider in providers {
 105            let mut provider_models = Vec::new();
 106            for model in provider.provided_models(cx) {
 107                let model_info = Self::map_language_model_to_info(&model, &provider);
 108                let model_id = model_info.id.clone();
 109                if !recommended_models.contains(&model.id()) {
 110                    provider_models.push(model_info);
 111                }
 112                models.insert(model_id, model);
 113            }
 114            if !provider_models.is_empty() {
 115                language_model_list.insert(
 116                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 117                    provider_models,
 118                );
 119            }
 120        }
 121
 122        self.models = models;
 123        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 124        self.refresh_models_tx.send(()).ok();
 125    }
 126
 127    fn watch(&self) -> watch::Receiver<()> {
 128        self.refresh_models_rx.clone()
 129    }
 130
 131    pub fn model_from_id(
 132        &self,
 133        model_id: &acp_thread::AgentModelId,
 134    ) -> Option<Arc<dyn LanguageModel>> {
 135        self.models.get(model_id).cloned()
 136    }
 137
 138    fn map_language_model_to_info(
 139        model: &Arc<dyn LanguageModel>,
 140        provider: &Arc<dyn LanguageModelProvider>,
 141    ) -> acp_thread::AgentModelInfo {
 142        acp_thread::AgentModelInfo {
 143            id: Self::model_id(model),
 144            name: model.name().0,
 145            icon: Some(provider.icon()),
 146        }
 147    }
 148
 149    fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
 150        acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
 151    }
 152}
 153
 154pub struct NativeAgent {
 155    /// Session ID -> Session mapping
 156    sessions: HashMap<acp::SessionId, Session>,
 157    /// Shared project context for all threads
 158    project_context: Entity<ProjectContext>,
 159    project_context_needs_refresh: watch::Sender<()>,
 160    _maintain_project_context: Task<Result<()>>,
 161    context_server_registry: Entity<ContextServerRegistry>,
 162    /// Shared templates for all threads
 163    templates: Arc<Templates>,
 164    /// Cached model information
 165    models: LanguageModels,
 166    project: Entity<Project>,
 167    prompt_store: Option<Entity<PromptStore>>,
 168    fs: Arc<dyn Fs>,
 169    _subscriptions: Vec<Subscription>,
 170}
 171
 172impl NativeAgent {
 173    pub async fn new(
 174        project: Entity<Project>,
 175        templates: Arc<Templates>,
 176        prompt_store: Option<Entity<PromptStore>>,
 177        fs: Arc<dyn Fs>,
 178        cx: &mut AsyncApp,
 179    ) -> Result<Entity<NativeAgent>> {
 180        log::info!("Creating new NativeAgent");
 181
 182        let project_context = cx
 183            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
 184            .await;
 185
 186        cx.new(|cx| {
 187            let mut subscriptions = vec![
 188                cx.subscribe(&project, Self::handle_project_event),
 189                cx.subscribe(
 190                    &LanguageModelRegistry::global(cx),
 191                    Self::handle_models_updated_event,
 192                ),
 193            ];
 194            if let Some(prompt_store) = prompt_store.as_ref() {
 195                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 196            }
 197
 198            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 199                watch::channel(());
 200            Self {
 201                sessions: HashMap::new(),
 202                project_context: cx.new(|_| project_context),
 203                project_context_needs_refresh: project_context_needs_refresh_tx,
 204                _maintain_project_context: cx.spawn(async move |this, cx| {
 205                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 206                }),
 207                context_server_registry: cx.new(|cx| {
 208                    ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
 209                }),
 210                templates,
 211                models: LanguageModels::new(cx),
 212                project,
 213                prompt_store,
 214                fs,
 215                _subscriptions: subscriptions,
 216            }
 217        })
 218    }
 219
 220    pub fn models(&self) -> &LanguageModels {
 221        &self.models
 222    }
 223
 224    async fn maintain_project_context(
 225        this: WeakEntity<Self>,
 226        mut needs_refresh: watch::Receiver<()>,
 227        cx: &mut AsyncApp,
 228    ) -> Result<()> {
 229        while needs_refresh.changed().await.is_ok() {
 230            let project_context = this
 231                .update(cx, |this, cx| {
 232                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 233                })?
 234                .await;
 235            this.update(cx, |this, cx| {
 236                this.project_context = cx.new(|_| project_context);
 237            })?;
 238        }
 239
 240        Ok(())
 241    }
 242
 243    fn build_project_context(
 244        project: &Entity<Project>,
 245        prompt_store: Option<&Entity<PromptStore>>,
 246        cx: &mut App,
 247    ) -> Task<ProjectContext> {
 248        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 249        let worktree_tasks = worktrees
 250            .into_iter()
 251            .map(|worktree| {
 252                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 253            })
 254            .collect::<Vec<_>>();
 255        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 256            prompt_store.read_with(cx, |prompt_store, cx| {
 257                let prompts = prompt_store.default_prompt_metadata();
 258                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 259                    let contents = prompt_store.load(prompt_metadata.id, cx);
 260                    async move { (contents.await, prompt_metadata) }
 261                });
 262                cx.background_spawn(future::join_all(load_tasks))
 263            })
 264        } else {
 265            Task::ready(vec![])
 266        };
 267
 268        cx.spawn(async move |_cx| {
 269            let (worktrees, default_user_rules) =
 270                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 271
 272            let worktrees = worktrees
 273                .into_iter()
 274                .map(|(worktree, _rules_error)| {
 275                    // TODO: show error message
 276                    // if let Some(rules_error) = rules_error {
 277                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 278                    // }
 279                    worktree
 280                })
 281                .collect::<Vec<_>>();
 282
 283            let default_user_rules = default_user_rules
 284                .into_iter()
 285                .flat_map(|(contents, prompt_metadata)| match contents {
 286                    Ok(contents) => Some(UserRulesContext {
 287                        uuid: match prompt_metadata.id {
 288                            PromptId::User { uuid } => uuid,
 289                            PromptId::EditWorkflow => return None,
 290                        },
 291                        title: prompt_metadata.title.map(|title| title.to_string()),
 292                        contents,
 293                    }),
 294                    Err(_err) => {
 295                        // TODO: show error message
 296                        // this.update(cx, |_, cx| {
 297                        //     cx.emit(RulesLoadingError {
 298                        //         message: format!("{err:?}").into(),
 299                        //     });
 300                        // })
 301                        // .ok();
 302                        None
 303                    }
 304                })
 305                .collect::<Vec<_>>();
 306
 307            ProjectContext::new(worktrees, default_user_rules)
 308        })
 309    }
 310
 311    fn load_worktree_info_for_system_prompt(
 312        worktree: Entity<Worktree>,
 313        project: Entity<Project>,
 314        cx: &mut App,
 315    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 316        let tree = worktree.read(cx);
 317        let root_name = tree.root_name().into();
 318        let abs_path = tree.abs_path();
 319
 320        let mut context = WorktreeContext {
 321            root_name,
 322            abs_path,
 323            rules_file: None,
 324        };
 325
 326        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 327        let Some(rules_task) = rules_task else {
 328            return Task::ready((context, None));
 329        };
 330
 331        cx.spawn(async move |_| {
 332            let (rules_file, rules_file_error) = match rules_task.await {
 333                Ok(rules_file) => (Some(rules_file), None),
 334                Err(err) => (
 335                    None,
 336                    Some(RulesLoadingError {
 337                        message: format!("{err}").into(),
 338                    }),
 339                ),
 340            };
 341            context.rules_file = rules_file;
 342            (context, rules_file_error)
 343        })
 344    }
 345
 346    fn load_worktree_rules_file(
 347        worktree: Entity<Worktree>,
 348        project: Entity<Project>,
 349        cx: &mut App,
 350    ) -> Option<Task<Result<RulesFileContext>>> {
 351        let worktree = worktree.read(cx);
 352        let worktree_id = worktree.id();
 353        let selected_rules_file = RULES_FILE_NAMES
 354            .into_iter()
 355            .filter_map(|name| {
 356                worktree
 357                    .entry_for_path(name)
 358                    .filter(|entry| entry.is_file())
 359                    .map(|entry| entry.path.clone())
 360            })
 361            .next();
 362
 363        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 364        // supported. This doesn't seem to occur often in GitHub repositories.
 365        selected_rules_file.map(|path_in_worktree| {
 366            let project_path = ProjectPath {
 367                worktree_id,
 368                path: path_in_worktree.clone(),
 369            };
 370            let buffer_task =
 371                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 372            let rope_task = cx.spawn(async move |cx| {
 373                buffer_task.await?.read_with(cx, |buffer, cx| {
 374                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 375                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 376                })?
 377            });
 378            // Build a string from the rope on a background thread.
 379            cx.background_spawn(async move {
 380                let (project_entry_id, rope) = rope_task.await?;
 381                anyhow::Ok(RulesFileContext {
 382                    path_in_worktree,
 383                    text: rope.to_string().trim().to_string(),
 384                    project_entry_id: project_entry_id.to_usize(),
 385                })
 386            })
 387        })
 388    }
 389
 390    fn handle_project_event(
 391        &mut self,
 392        _project: Entity<Project>,
 393        event: &project::Event,
 394        _cx: &mut Context<Self>,
 395    ) {
 396        match event {
 397            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 398                self.project_context_needs_refresh.send(()).ok();
 399            }
 400            project::Event::WorktreeUpdatedEntries(_, items) => {
 401                if items.iter().any(|(path, _, _)| {
 402                    RULES_FILE_NAMES
 403                        .iter()
 404                        .any(|name| path.as_ref() == Path::new(name))
 405                }) {
 406                    self.project_context_needs_refresh.send(()).ok();
 407                }
 408            }
 409            _ => {}
 410        }
 411    }
 412
 413    fn handle_prompts_updated_event(
 414        &mut self,
 415        _prompt_store: Entity<PromptStore>,
 416        _event: &prompt_store::PromptsUpdatedEvent,
 417        _cx: &mut Context<Self>,
 418    ) {
 419        self.project_context_needs_refresh.send(()).ok();
 420    }
 421
 422    fn handle_models_updated_event(
 423        &mut self,
 424        _registry: Entity<LanguageModelRegistry>,
 425        _event: &language_model::Event,
 426        cx: &mut Context<Self>,
 427    ) {
 428        self.models.refresh_list(cx);
 429
 430        let default_model = LanguageModelRegistry::read_global(cx)
 431            .default_model()
 432            .map(|m| m.model.clone());
 433
 434        for session in self.sessions.values_mut() {
 435            session.thread.update(cx, |thread, cx| {
 436                if thread.model().is_none()
 437                    && let Some(model) = default_model.clone()
 438                {
 439                    thread.set_model(model);
 440                    cx.notify();
 441                }
 442            });
 443        }
 444    }
 445}
 446
 447/// Wrapper struct that implements the AgentConnection trait
 448#[derive(Clone)]
 449pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 450
 451impl NativeAgentConnection {
 452    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 453        self.0
 454            .read(cx)
 455            .sessions
 456            .get(session_id)
 457            .map(|session| session.thread.clone())
 458    }
 459
 460    fn run_turn(
 461        &self,
 462        session_id: acp::SessionId,
 463        cx: &mut App,
 464        f: impl 'static
 465        + FnOnce(
 466            Entity<Thread>,
 467            &mut App,
 468        ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
 469    ) -> Task<Result<acp::PromptResponse>> {
 470        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 471            agent
 472                .sessions
 473                .get_mut(&session_id)
 474                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
 475        }) else {
 476            return Task::ready(Err(anyhow!("Session not found")));
 477        };
 478        log::debug!("Found session for: {}", session_id);
 479
 480        let mut response_stream = match f(thread, cx) {
 481            Ok(stream) => stream,
 482            Err(err) => return Task::ready(Err(err)),
 483        };
 484        cx.spawn(async move |cx| {
 485            // Handle response stream and forward to session.acp_thread
 486            while let Some(result) = response_stream.next().await {
 487                match result {
 488                    Ok(event) => {
 489                        log::trace!("Received completion event: {:?}", event);
 490
 491                        match event {
 492                            AgentResponseEvent::Text(text) => {
 493                                acp_thread.update(cx, |thread, cx| {
 494                                    thread.push_assistant_content_block(
 495                                        acp::ContentBlock::Text(acp::TextContent {
 496                                            text,
 497                                            annotations: None,
 498                                        }),
 499                                        false,
 500                                        cx,
 501                                    )
 502                                })?;
 503                            }
 504                            AgentResponseEvent::Thinking(text) => {
 505                                acp_thread.update(cx, |thread, cx| {
 506                                    thread.push_assistant_content_block(
 507                                        acp::ContentBlock::Text(acp::TextContent {
 508                                            text,
 509                                            annotations: None,
 510                                        }),
 511                                        true,
 512                                        cx,
 513                                    )
 514                                })?;
 515                            }
 516                            AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
 517                                tool_call,
 518                                options,
 519                                response,
 520                            }) => {
 521                                let recv = acp_thread.update(cx, |thread, cx| {
 522                                    thread.request_tool_call_authorization(tool_call, options, cx)
 523                                })?;
 524                                cx.background_spawn(async move {
 525                                    if let Some(recv) = recv.log_err()
 526                                        && let Some(option) = recv
 527                                            .await
 528                                            .context("authorization sender was dropped")
 529                                            .log_err()
 530                                    {
 531                                        response
 532                                            .send(option)
 533                                            .map(|_| anyhow!("authorization receiver was dropped"))
 534                                            .log_err();
 535                                    }
 536                                })
 537                                .detach();
 538                            }
 539                            AgentResponseEvent::ToolCall(tool_call) => {
 540                                acp_thread.update(cx, |thread, cx| {
 541                                    thread.upsert_tool_call(tool_call, cx)
 542                                })??;
 543                            }
 544                            AgentResponseEvent::ToolCallUpdate(update) => {
 545                                acp_thread.update(cx, |thread, cx| {
 546                                    thread.update_tool_call(update, cx)
 547                                })??;
 548                            }
 549                            AgentResponseEvent::Retry(status) => {
 550                                acp_thread.update(cx, |thread, cx| {
 551                                    thread.update_retry_status(status, cx)
 552                                })?;
 553                            }
 554                            AgentResponseEvent::Stop(stop_reason) => {
 555                                log::debug!("Assistant message complete: {:?}", stop_reason);
 556                                return Ok(acp::PromptResponse { stop_reason });
 557                            }
 558                        }
 559                    }
 560                    Err(e) => {
 561                        log::error!("Error in model response stream: {:?}", e);
 562                        return Err(e);
 563                    }
 564                }
 565            }
 566
 567            log::info!("Response stream completed");
 568            anyhow::Ok(acp::PromptResponse {
 569                stop_reason: acp::StopReason::EndTurn,
 570            })
 571        })
 572    }
 573}
 574
 575impl AgentModelSelector for NativeAgentConnection {
 576    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
 577        log::debug!("NativeAgentConnection::list_models called");
 578        let list = self.0.read(cx).models.model_list.clone();
 579        Task::ready(if list.is_empty() {
 580            Err(anyhow::anyhow!("No models available"))
 581        } else {
 582            Ok(list)
 583        })
 584    }
 585
 586    fn select_model(
 587        &self,
 588        session_id: acp::SessionId,
 589        model_id: acp_thread::AgentModelId,
 590        cx: &mut App,
 591    ) -> Task<Result<()>> {
 592        log::info!("Setting model for session {}: {}", session_id, model_id);
 593        let Some(thread) = self
 594            .0
 595            .read(cx)
 596            .sessions
 597            .get(&session_id)
 598            .map(|session| session.thread.clone())
 599        else {
 600            return Task::ready(Err(anyhow!("Session not found")));
 601        };
 602
 603        let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
 604            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
 605        };
 606
 607        thread.update(cx, |thread, _cx| {
 608            thread.set_model(model.clone());
 609        });
 610
 611        update_settings_file::<AgentSettings>(
 612            self.0.read(cx).fs.clone(),
 613            cx,
 614            move |settings, _cx| {
 615                settings.set_model(model);
 616            },
 617        );
 618
 619        Task::ready(Ok(()))
 620    }
 621
 622    fn selected_model(
 623        &self,
 624        session_id: &acp::SessionId,
 625        cx: &mut App,
 626    ) -> Task<Result<acp_thread::AgentModelInfo>> {
 627        let session_id = session_id.clone();
 628
 629        let Some(thread) = self
 630            .0
 631            .read(cx)
 632            .sessions
 633            .get(&session_id)
 634            .map(|session| session.thread.clone())
 635        else {
 636            return Task::ready(Err(anyhow!("Session not found")));
 637        };
 638        let Some(model) = thread.read(cx).model() else {
 639            return Task::ready(Err(anyhow!("Model not found")));
 640        };
 641        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
 642        else {
 643            return Task::ready(Err(anyhow!("Provider not found")));
 644        };
 645        Task::ready(Ok(LanguageModels::map_language_model_to_info(
 646            model, &provider,
 647        )))
 648    }
 649
 650    fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
 651        self.0.read(cx).models.watch()
 652    }
 653}
 654
 655impl acp_thread::AgentConnection for NativeAgentConnection {
 656    fn new_thread(
 657        self: Rc<Self>,
 658        project: Entity<Project>,
 659        cwd: &Path,
 660        cx: &mut App,
 661    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
 662        let agent = self.0.clone();
 663        log::info!("Creating new thread for project at: {:?}", cwd);
 664
 665        cx.spawn(async move |cx| {
 666            log::debug!("Starting thread creation in async context");
 667
 668            // Generate session ID
 669            let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
 670            log::info!("Created session with ID: {}", session_id);
 671
 672            // Create AcpThread
 673            let acp_thread = cx.update(|cx| {
 674                cx.new(|cx| {
 675                    acp_thread::AcpThread::new(
 676                        "agent2",
 677                        self.clone(),
 678                        project.clone(),
 679                        session_id.clone(),
 680                        cx,
 681                    )
 682                })
 683            })?;
 684            let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
 685
 686            // Create Thread
 687            let thread = agent.update(
 688                cx,
 689                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
 690                    // Fetch default model from registry settings
 691                    let registry = LanguageModelRegistry::read_global(cx);
 692
 693                    // Log available models for debugging
 694                    let available_count = registry.available_models(cx).count();
 695                    log::debug!("Total available models: {}", available_count);
 696
 697                    let default_model = registry.default_model().and_then(|default_model| {
 698                        agent
 699                            .models
 700                            .model_from_id(&LanguageModels::model_id(&default_model.model))
 701                    });
 702
 703                    let thread = cx.new(|cx| {
 704                        let mut thread = Thread::new(
 705                            project.clone(),
 706                            agent.project_context.clone(),
 707                            agent.context_server_registry.clone(),
 708                            action_log.clone(),
 709                            agent.templates.clone(),
 710                            default_model,
 711                            cx,
 712                        );
 713                        thread.add_tool(CopyPathTool::new(project.clone()));
 714                        thread.add_tool(CreateDirectoryTool::new(project.clone()));
 715                        thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
 716                        thread.add_tool(DiagnosticsTool::new(project.clone()));
 717                        thread.add_tool(EditFileTool::new(cx.entity()));
 718                        thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
 719                        thread.add_tool(FindPathTool::new(project.clone()));
 720                        thread.add_tool(GrepTool::new(project.clone()));
 721                        thread.add_tool(ListDirectoryTool::new(project.clone()));
 722                        thread.add_tool(MovePathTool::new(project.clone()));
 723                        thread.add_tool(NowTool);
 724                        thread.add_tool(OpenTool::new(project.clone()));
 725                        thread.add_tool(ReadFileTool::new(project.clone(), action_log));
 726                        thread.add_tool(TerminalTool::new(project.clone(), cx));
 727                        thread.add_tool(ThinkingTool);
 728                        thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
 729                        thread
 730                    });
 731
 732                    Ok(thread)
 733                },
 734            )??;
 735
 736            // Store the session
 737            agent.update(cx, |agent, cx| {
 738                agent.sessions.insert(
 739                    session_id,
 740                    Session {
 741                        thread,
 742                        acp_thread: acp_thread.downgrade(),
 743                        _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
 744                            this.sessions.remove(acp_thread.session_id());
 745                        }),
 746                    },
 747                );
 748            })?;
 749
 750            Ok(acp_thread)
 751        })
 752    }
 753
 754    fn auth_methods(&self) -> &[acp::AuthMethod] {
 755        &[] // No auth for in-process
 756    }
 757
 758    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 759        Task::ready(Ok(()))
 760    }
 761
 762    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
 763        Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
 764    }
 765
 766    fn prompt(
 767        &self,
 768        id: Option<acp_thread::UserMessageId>,
 769        params: acp::PromptRequest,
 770        cx: &mut App,
 771    ) -> Task<Result<acp::PromptResponse>> {
 772        let id = id.expect("UserMessageId is required");
 773        let session_id = params.session_id.clone();
 774        log::info!("Received prompt request for session: {}", session_id);
 775        log::debug!("Prompt blocks count: {}", params.prompt.len());
 776
 777        self.run_turn(session_id, cx, |thread, cx| {
 778            let content: Vec<UserMessageContent> = params
 779                .prompt
 780                .into_iter()
 781                .map(Into::into)
 782                .collect::<Vec<_>>();
 783            log::info!("Converted prompt to message: {} chars", content.len());
 784            log::debug!("Message id: {:?}", id);
 785            log::debug!("Message content: {:?}", content);
 786
 787            thread.update(cx, |thread, cx| thread.send(id, content, cx))
 788        })
 789    }
 790
 791    fn resume(
 792        &self,
 793        session_id: &acp::SessionId,
 794        _cx: &mut App,
 795    ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
 796        Some(Rc::new(NativeAgentSessionResume {
 797            connection: self.clone(),
 798            session_id: session_id.clone(),
 799        }) as _)
 800    }
 801
 802    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
 803        log::info!("Cancelling on session: {}", session_id);
 804        self.0.update(cx, |agent, cx| {
 805            if let Some(agent) = agent.sessions.get(session_id) {
 806                agent.thread.update(cx, |thread, _cx| thread.cancel());
 807            }
 808        });
 809    }
 810
 811    fn session_editor(
 812        &self,
 813        session_id: &agent_client_protocol::SessionId,
 814        cx: &mut App,
 815    ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
 816        self.0.update(cx, |agent, _cx| {
 817            agent
 818                .sessions
 819                .get(session_id)
 820                .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
 821        })
 822    }
 823
 824    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
 825        self
 826    }
 827}
 828
 829struct NativeAgentSessionEditor(Entity<Thread>);
 830
 831impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
 832    fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
 833        Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
 834    }
 835}
 836
 837struct NativeAgentSessionResume {
 838    connection: NativeAgentConnection,
 839    session_id: acp::SessionId,
 840}
 841
 842impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
 843    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
 844        self.connection
 845            .run_turn(self.session_id.clone(), cx, |thread, cx| {
 846                thread.update(cx, |thread, cx| thread.resume(cx))
 847            })
 848    }
 849}
 850
#[cfg(test)]
mod tests {
    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// Checks that the agent's `project_context` follows the project: it starts
    /// empty, gains an entry when worktree `/a` is added, and records the
    /// worktree's `.rules` file once that file is created.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        // The project is created with no worktrees; `/a` is added below.
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        // Add worktree `/a`, then let pending tasks run (`run_until_parked`)
        // before inspecting the context. No rules file exists yet.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            // The expected entry id comes from the worktree itself, so the
            // assertion does not hard-code an id value.
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// Checks that `list_models` returns a grouped list holding exactly one
    /// group, "Fake", with the single model `fake/fake`. NOTE(review): the fake
    /// model presumably comes from `LanguageModelRegistry::test` in `init_test`.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();

        // Only the grouped variant is expected here; anything else fails.
        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: AgentModelId("fake/fake".into()),
                    name: "Fake".into(),
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    /// Checks that selecting a model for a session both switches the session's
    /// thread to that model and writes the choice to the settings file as
    /// `agent.default_model`.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Seed the settings file with a different default (`foo`/`bar`) so the
        // assertions at the end prove the selection overwrote it.
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model. The id appears to be "<provider>/<model>"; the
        // assertions below expect provider "fake" and model "fake".
        let model_id = AgentModelId("fake/fake".into());
        cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // NOTE(review): presumably the settings write happens on a pending
        // task, so let queued work finish before reading the file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// Shared test setup. It installs logging, a test `SettingsStore` global,
    /// and the project, agent and language settings, then calls
    /// `LanguageModelRegistry::test`.
    fn init_test(cx: &mut TestAppContext) {
        // `try_init` fails if a logger is already installed (e.g. by another
        // test in the same process); that error is deliberately ignored.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}